diff --git a/synapse/__init__.py b/synapse/__init__.py
index 9c75a0a27f..119359be68 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -17,14 +17,22 @@
""" This is a reference implementation of a Matrix home server.
"""
+import sys
+
+# Check that we're not running on an unsupported Python version.
+if sys.version_info < (3, 5):
+ print("Synapse requires Python 3.5 or above.")
+ sys.exit(1)
+
try:
from twisted.internet import protocol
from twisted.internet.protocol import Factory
from twisted.names.dns import DNSDatagramProtocol
+
protocol.Factory.noisy = False
Factory.noisy = False
DNSDatagramProtocol.noisy = False
except ImportError:
pass
-__version__ = "1.0.0rc3"
+__version__ = "1.0.0"
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 6e93f5a0c6..bdcd915bbe 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -57,18 +57,18 @@ def request_registration(
nonce = r.json()["nonce"]
- mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1)
+ mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=hashlib.sha1)
- mac.update(nonce.encode('utf8'))
+ mac.update(nonce.encode("utf8"))
mac.update(b"\x00")
- mac.update(user.encode('utf8'))
+ mac.update(user.encode("utf8"))
mac.update(b"\x00")
- mac.update(password.encode('utf8'))
+ mac.update(password.encode("utf8"))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
if user_type:
mac.update(b"\x00")
- mac.update(user_type.encode('utf8'))
+ mac.update(user_type.encode("utf8"))
mac = mac.hexdigest()
@@ -134,8 +134,9 @@ def register_new_user(user, password, server_location, shared_secret, admin, use
else:
admin = False
- request_registration(user, password, server_location, shared_secret,
- bool(admin), user_type)
+ request_registration(
+ user, password, server_location, shared_secret, bool(admin), user_type
+ )
def main():
@@ -189,7 +190,7 @@ def main():
group.add_argument(
"-c",
"--config",
- type=argparse.FileType('r'),
+ type=argparse.FileType("r"),
help="Path to server config file. Used to read in shared secret.",
)
@@ -200,7 +201,7 @@ def main():
parser.add_argument(
"server_url",
default="https://localhost:8448",
- nargs='?',
+ nargs="?",
help="URL to use to talk to the home server. Defaults to "
" 'https://localhost:8448'.",
)
@@ -220,8 +221,9 @@ def main():
if args.admin or args.no_admin:
admin = args.admin
- register_new_user(args.user, args.password, args.server_url, secret,
- admin, args.user_type)
+ register_new_user(
+ args.user, args.password, args.server_url, secret, admin, args.user_type
+ )
if __name__ == "__main__":
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 79e2808dc5..86f145649c 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -36,8 +36,11 @@ logger = logging.getLogger(__name__)
AuthEventTypes = (
- EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
- EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
+ EventTypes.Create,
+ EventTypes.Member,
+ EventTypes.PowerLevels,
+ EventTypes.JoinRules,
+ EventTypes.RoomHistoryVisibility,
EventTypes.ThirdPartyInvite,
)
@@ -54,6 +57,7 @@ class Auth(object):
FIXME: This class contains a mix of functions for authenticating users
of our client-server API and authenticating events added to room graphs.
"""
+
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
@@ -70,15 +74,12 @@ class Auth(object):
def check_from_context(self, room_version, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.compute_auth_events(
- event, prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {
- (e.type, e.state_key): e for e in itervalues(auth_events)
- }
+ auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
self.check(
- room_version, event,
- auth_events=auth_events, do_sig_check=do_sig_check,
+ room_version, event, auth_events=auth_events, do_sig_check=do_sig_check
)
def check(self, room_version, event, auth_events, do_sig_check=True):
@@ -115,15 +116,10 @@ class Auth(object):
the room.
"""
if current_state:
- member = current_state.get(
- (EventTypes.Member, user_id),
- None
- )
+ member = current_state.get((EventTypes.Member, user_id), None)
else:
member = yield self.state.get_current_state(
- room_id=room_id,
- event_type=EventTypes.Member,
- state_key=user_id
+ room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
self._check_joined_room(member, user_id, room_id)
@@ -143,23 +139,17 @@ class Auth(object):
the room. This will be the leave event if they have left the room.
"""
member = yield self.state.get_current_state(
- room_id=room_id,
- event_type=EventTypes.Member,
- state_key=user_id
+ room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
if membership not in (Membership.JOIN, Membership.LEAVE):
- raise AuthError(403, "User %s not in room %s" % (
- user_id, room_id
- ))
+ raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
if membership == Membership.LEAVE:
forgot = yield self.store.did_forget(user_id, room_id)
if forgot:
- raise AuthError(403, "User %s not in room %s" % (
- user_id, room_id
- ))
+ raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
defer.returnValue(member)
@@ -171,9 +161,9 @@ class Auth(object):
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
- raise AuthError(403, "User %s not in room %s (%s)" % (
- user_id, room_id, repr(member)
- ))
+ raise AuthError(
+ 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
+ )
def can_federate(self, event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
@@ -185,11 +175,7 @@ class Auth(object):
@defer.inlineCallbacks
def get_user_by_req(
- self,
- request,
- allow_guest=False,
- rights="access",
- allow_expired=False,
+ self, request, allow_guest=False, rights="access", allow_expired=False
):
""" Get a registered user's ID.
@@ -209,9 +195,8 @@ class Auth(object):
try:
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
- b"User-Agent",
- default=[b""]
- )[0].decode('ascii', 'surrogateescape')
+ b"User-Agent", default=[b""]
+ )[0].decode("ascii", "surrogateescape")
access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
@@ -243,11 +228,12 @@ class Auth(object):
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
- if expiration_ts is not None and self.clock.time_msec() >= expiration_ts:
+ if (
+ expiration_ts is not None
+ and self.clock.time_msec() >= expiration_ts
+ ):
raise AuthError(
- 403,
- "User account has expired",
- errcode=Codes.EXPIRED_ACCOUNT,
+ 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
# device_id may not be present if get_user_by_access_token has been
@@ -265,18 +251,23 @@ class Auth(object):
if is_guest and not allow_guest:
raise AuthError(
- 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+ 403,
+ "Guest access not allowed",
+ errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
request.authenticated_entity = user.to_string()
- defer.returnValue(synapse.types.create_requester(
- user, token_id, is_guest, device_id, app_service=app_service)
+ defer.returnValue(
+ synapse.types.create_requester(
+ user, token_id, is_guest, device_id, app_service=app_service
+ )
)
except KeyError:
raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
- errcode=Codes.MISSING_TOKEN
+ self.TOKEN_NOT_FOUND_HTTP_STATUS,
+ "Missing access token.",
+ errcode=Codes.MISSING_TOKEN,
)
@defer.inlineCallbacks
@@ -297,20 +288,14 @@ class Auth(object):
if b"user_id" not in request.args:
defer.returnValue((app_service.sender, app_service))
- user_id = request.args[b"user_id"][0].decode('utf8')
+ user_id = request.args[b"user_id"][0].decode("utf8")
if app_service.sender == user_id:
defer.returnValue((app_service.sender, app_service))
if not app_service.is_interested_in_user(user_id):
- raise AuthError(
- 403,
- "Application service cannot masquerade as this user."
- )
+ raise AuthError(403, "Application service cannot masquerade as this user.")
if not (yield self.store.get_user_by_id(user_id)):
- raise AuthError(
- 403,
- "Application service has not registered this user"
- )
+ raise AuthError(403, "Application service has not registered this user")
defer.returnValue((user_id, app_service))
@defer.inlineCallbacks
@@ -368,13 +353,13 @@ class Auth(object):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unknown user_id %s" % user_id,
- errcode=Codes.UNKNOWN_TOKEN
+ errcode=Codes.UNKNOWN_TOKEN,
)
if not stored_user["is_guest"]:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Guest access token used for regular user",
- errcode=Codes.UNKNOWN_TOKEN
+ errcode=Codes.UNKNOWN_TOKEN,
)
ret = {
"user": user,
@@ -402,8 +387,9 @@ class Auth(object):
) as e:
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
- errcode=Codes.UNKNOWN_TOKEN
+ self.TOKEN_NOT_FOUND_HTTP_STATUS,
+ "Invalid macaroon passed.",
+ errcode=Codes.UNKNOWN_TOKEN,
)
def _parse_and_validate_macaroon(self, token, rights="access"):
@@ -441,13 +427,13 @@ class Auth(object):
guest = True
self.validate_macaroon(
- macaroon, rights, self.hs.config.expire_access_token,
- user_id=user_id,
+ macaroon, rights, self.hs.config.expire_access_token, user_id=user_id
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
- errcode=Codes.UNKNOWN_TOKEN
+ self.TOKEN_NOT_FOUND_HTTP_STATUS,
+ "Invalid macaroon passed.",
+ errcode=Codes.UNKNOWN_TOKEN,
)
if not has_expiry and rights == "access":
@@ -472,10 +458,11 @@ class Auth(object):
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
- return caveat.caveat_id[len(user_prefix):]
+ return caveat.caveat_id[len(user_prefix) :]
raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
- errcode=Codes.UNKNOWN_TOKEN
+ self.TOKEN_NOT_FOUND_HTTP_STATUS,
+ "No user caveat in macaroon",
+ errcode=Codes.UNKNOWN_TOKEN,
)
def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
@@ -522,7 +509,7 @@ class Auth(object):
prefix = "time < "
if not caveat.startswith(prefix):
return False
- expiry = int(caveat[len(prefix):])
+ expiry = int(caveat[len(prefix) :])
now = self.hs.get_clock().time_msec()
return now < expiry
@@ -554,14 +541,12 @@ class Auth(object):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
- errcode=Codes.UNKNOWN_TOKEN
+ errcode=Codes.UNKNOWN_TOKEN,
)
request.authenticated_entity = service.sender
return defer.succeed(service)
except KeyError:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
- )
+ raise AuthError(self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.")
def is_server_admin(self, user):
""" Check if the given user is a local server admin.
@@ -581,19 +566,19 @@ class Auth(object):
auth_ids = []
- key = (EventTypes.PowerLevels, "", )
+ key = (EventTypes.PowerLevels, "")
power_level_event_id = current_state_ids.get(key)
if power_level_event_id:
auth_ids.append(power_level_event_id)
- key = (EventTypes.JoinRules, "", )
+ key = (EventTypes.JoinRules, "")
join_rule_event_id = current_state_ids.get(key)
- key = (EventTypes.Member, event.sender, )
+ key = (EventTypes.Member, event.sender)
member_event_id = current_state_ids.get(key)
- key = (EventTypes.Create, "", )
+ key = (EventTypes.Create, "")
create_event_id = current_state_ids.get(key)
if create_event_id:
auth_ids.append(create_event_id)
@@ -619,7 +604,7 @@ class Auth(object):
auth_ids.append(member_event_id)
if for_verification:
- key = (EventTypes.Member, event.state_key, )
+ key = (EventTypes.Member, event.state_key)
existing_event_id = current_state_ids.get(key)
if existing_event_id:
auth_ids.append(existing_event_id)
@@ -628,7 +613,7 @@ class Auth(object):
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
- event.content["third_party_invite"]["signed"]["token"]
+ event.content["third_party_invite"]["signed"]["token"],
)
third_party_invite_id = current_state_ids.get(key)
if third_party_invite_id:
@@ -684,7 +669,7 @@ class Auth(object):
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = event_auth.get_send_level(
- EventTypes.Aliases, "", power_level_event,
+ EventTypes.Aliases, "", power_level_event
)
user_level = event_auth.get_user_power_level(user_id, auth_events)
@@ -692,7 +677,7 @@ class Auth(object):
raise AuthError(
403,
"This server requires you to be a moderator in the room to"
- " edit its room list entry"
+ " edit its room list entry",
)
@staticmethod
@@ -742,7 +727,7 @@ class Auth(object):
)
parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
- return parts[1].decode('ascii')
+ return parts[1].decode("ascii")
else:
raise AuthError(
token_not_found_http_status,
@@ -755,10 +740,10 @@ class Auth(object):
raise AuthError(
token_not_found_http_status,
"Missing access token.",
- errcode=Codes.MISSING_TOKEN
+ errcode=Codes.MISSING_TOKEN,
)
- return query_params[0].decode('ascii')
+ return query_params[0].decode("ascii")
@defer.inlineCallbacks
def check_in_room_or_world_readable(self, room_id, user_id):
@@ -785,8 +770,8 @@ class Auth(object):
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
- visibility and
- visibility.content["history_visibility"] == "world_readable"
+ visibility
+ and visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return
@@ -820,10 +805,11 @@ class Auth(object):
if self.hs.config.hs_disabled:
raise ResourceLimitError(
- 403, self.hs.config.hs_disabled_message,
+ 403,
+ self.hs.config.hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact=self.hs.config.admin_contact,
- limit_type=self.hs.config.hs_disabled_limit_type
+ limit_type=self.hs.config.hs_disabled_limit_type,
)
if self.hs.config.limit_usage_by_mau is True:
assert not (user_id and threepid)
@@ -848,8 +834,9 @@ class Auth(object):
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise ResourceLimitError(
- 403, "Monthly Active User Limit Exceeded",
+ 403,
+ "Monthly Active User Limit Exceeded",
admin_contact=self.hs.config.admin_contact,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
- limit_type="monthly_active_user"
+ limit_type="monthly_active_user",
)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index ee129c8689..3ffde0d7fc 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -18,7 +18,7 @@
"""Contains constants from the specification."""
# the "depth" field on events is limited to 2**63 - 1
-MAX_DEPTH = 2**63 - 1
+MAX_DEPTH = 2 ** 63 - 1
# the maximum length for a room alias is 255 characters
MAX_ALIAS_LENGTH = 255
@@ -30,39 +30,41 @@ MAX_USERID_LENGTH = 255
class Membership(object):
"""Represents the membership states of a user in a room."""
- INVITE = u"invite"
- JOIN = u"join"
- KNOCK = u"knock"
- LEAVE = u"leave"
- BAN = u"ban"
+
+ INVITE = "invite"
+ JOIN = "join"
+ KNOCK = "knock"
+ LEAVE = "leave"
+ BAN = "ban"
LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
class PresenceState(object):
"""Represents the presence state of a user."""
- OFFLINE = u"offline"
- UNAVAILABLE = u"unavailable"
- ONLINE = u"online"
+
+ OFFLINE = "offline"
+ UNAVAILABLE = "unavailable"
+ ONLINE = "online"
class JoinRules(object):
- PUBLIC = u"public"
- KNOCK = u"knock"
- INVITE = u"invite"
- PRIVATE = u"private"
+ PUBLIC = "public"
+ KNOCK = "knock"
+ INVITE = "invite"
+ PRIVATE = "private"
class LoginType(object):
- PASSWORD = u"m.login.password"
- EMAIL_IDENTITY = u"m.login.email.identity"
- MSISDN = u"m.login.msisdn"
- RECAPTCHA = u"m.login.recaptcha"
- TERMS = u"m.login.terms"
- DUMMY = u"m.login.dummy"
+ PASSWORD = "m.login.password"
+ EMAIL_IDENTITY = "m.login.email.identity"
+ MSISDN = "m.login.msisdn"
+ RECAPTCHA = "m.login.recaptcha"
+ TERMS = "m.login.terms"
+ DUMMY = "m.login.dummy"
# Only for C/S API v1
- APPLICATION_SERVICE = u"m.login.application_service"
- SHARED_SECRET = u"org.matrix.login.shared_secret"
+ APPLICATION_SERVICE = "m.login.application_service"
+ SHARED_SECRET = "org.matrix.login.shared_secret"
class EventTypes(object):
@@ -118,6 +120,7 @@ class UserTypes(object):
"""Allows for user type specific behaviour. With the benefit of hindsight
'admin' and 'guest' users should also be UserTypes. Normal users are type None
"""
+
SUPPORT = "support"
ALL_USER_TYPES = (SUPPORT,)
@@ -125,6 +128,7 @@ class UserTypes(object):
class RelationTypes(object):
"""The types of relations known to this server.
"""
+
ANNOTATION = "m.annotation"
REPLACE = "m.replace"
REFERENCE = "m.reference"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 66201d6efe..28b5c2af9b 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -70,6 +70,7 @@ class CodeMessageException(RuntimeError):
code (int): HTTP error code
msg (str): string describing the error
"""
+
def __init__(self, code, msg):
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
self.code = code
@@ -83,6 +84,7 @@ class SynapseError(CodeMessageException):
Attributes:
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
"""
+
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
"""Constructs a synapse error.
@@ -95,10 +97,7 @@ class SynapseError(CodeMessageException):
self.errcode = errcode
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- )
+ return cs_error(self.msg, self.errcode)
class ProxiedRequestError(SynapseError):
@@ -107,27 +106,23 @@ class ProxiedRequestError(SynapseError):
Attributes:
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
"""
+
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
- super(ProxiedRequestError, self).__init__(
- code, msg, errcode
- )
+ super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None:
self._additional_fields = {}
else:
self._additional_fields = dict(additional_fields)
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- **self._additional_fields
- )
+ return cs_error(self.msg, self.errcode, **self._additional_fields)
class ConsentNotGivenError(SynapseError):
"""The error returned to the client when the user has not consented to the
privacy policy.
"""
+
def __init__(self, msg, consent_uri):
"""Constructs a ConsentNotGivenError
@@ -136,22 +131,17 @@ class ConsentNotGivenError(SynapseError):
consent_url (str): The URL where the user can give their consent
"""
super(ConsentNotGivenError, self).__init__(
- code=http_client.FORBIDDEN,
- msg=msg,
- errcode=Codes.CONSENT_NOT_GIVEN
+ code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN
)
self._consent_uri = consent_uri
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- consent_uri=self._consent_uri
- )
+ return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
+
pass
@@ -190,15 +180,17 @@ class InteractiveAuthIncompleteError(Exception):
result (dict): the server response to the request, which should be
passed back to the client
"""
+
def __init__(self, result):
super(InteractiveAuthIncompleteError, self).__init__(
- "Interactive auth not yet complete",
+ "Interactive auth not yet complete"
)
self.result = result
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
+
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
@@ -207,21 +199,14 @@ class UnrecognizedRequestError(SynapseError):
message = "Unrecognized request"
else:
message = args[0]
- super(UnrecognizedRequestError, self).__init__(
- 400,
- message,
- **kwargs
- )
+ super(UnrecognizedRequestError, self).__init__(400, message, **kwargs)
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
+
def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
- super(NotFoundError, self).__init__(
- 404,
- msg,
- errcode=errcode
- )
+ super(NotFoundError, self).__init__(404, msg, errcode=errcode)
class AuthError(SynapseError):
@@ -238,8 +223,11 @@ class ResourceLimitError(SynapseError):
Any error raised when there is a problem with resource usage.
For instance, the monthly active user limit for the server has been exceeded
"""
+
def __init__(
- self, code, msg,
+ self,
+ code,
+ msg,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact=None,
limit_type=None,
@@ -253,7 +241,7 @@ class ResourceLimitError(SynapseError):
self.msg,
self.errcode,
admin_contact=self.admin_contact,
- limit_type=self.limit_type
+ limit_type=self.limit_type,
)
@@ -268,6 +256,7 @@ class EventSizeError(SynapseError):
class EventStreamError(SynapseError):
"""An error raised when there a problem with the event stream."""
+
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.BAD_PAGINATION
@@ -276,47 +265,53 @@ class EventStreamError(SynapseError):
class LoginError(SynapseError):
"""An error raised when there was a problem logging in."""
+
pass
class StoreError(SynapseError):
"""An error raised when there was a problem storing some data."""
+
pass
class InvalidCaptchaError(SynapseError):
- def __init__(self, code=400, msg="Invalid captcha.", error_url=None,
- errcode=Codes.CAPTCHA_INVALID):
+ def __init__(
+ self,
+ code=400,
+ msg="Invalid captcha.",
+ error_url=None,
+ errcode=Codes.CAPTCHA_INVALID,
+ ):
super(InvalidCaptchaError, self).__init__(code, msg, errcode)
self.error_url = error_url
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- error_url=self.error_url,
- )
+ return cs_error(self.msg, self.errcode, error_url=self.error_url)
class LimitExceededError(SynapseError):
"""A client has sent too many requests and is being throttled.
"""
- def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None,
- errcode=Codes.LIMIT_EXCEEDED):
+
+ def __init__(
+ self,
+ code=429,
+ msg="Too Many Requests",
+ retry_after_ms=None,
+ errcode=Codes.LIMIT_EXCEEDED,
+ ):
super(LimitExceededError, self).__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- retry_after_ms=self.retry_after_ms,
- )
+ return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
class RoomKeysVersionError(SynapseError):
"""A client has tried to upload to a non-current version of the room_keys store
"""
+
def __init__(self, current_version):
"""
Args:
@@ -331,6 +326,7 @@ class RoomKeysVersionError(SynapseError):
class UnsupportedRoomVersionError(SynapseError):
"""The client's request to create a room used a room version that the server does
not support."""
+
def __init__(self):
super(UnsupportedRoomVersionError, self).__init__(
code=400,
@@ -354,22 +350,19 @@ class IncompatibleRoomVersionError(SynapseError):
Unlike UnsupportedRoomVersionError, it is specific to the case of the make_join
failing.
"""
+
def __init__(self, room_version):
super(IncompatibleRoomVersionError, self).__init__(
code=400,
msg="Your homeserver does not support the features required to "
- "join this room",
+ "join this room",
errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
)
self._room_version = room_version
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- room_version=self._room_version,
- )
+ return cs_error(self.msg, self.errcode, room_version=self._room_version)
class RequestSendFailed(RuntimeError):
@@ -380,11 +373,11 @@ class RequestSendFailed(RuntimeError):
networking (e.g. DNS failures, connection timeouts etc), versus unexpected
errors (like programming errors).
"""
+
def __init__(self, inner_exception, can_retry):
super(RequestSendFailed, self).__init__(
- "Failed to send request: %s: %s" % (
- type(inner_exception).__name__, inner_exception,
- )
+ "Failed to send request: %s: %s"
+ % (type(inner_exception).__name__, inner_exception)
)
self.inner_exception = inner_exception
self.can_retry = can_retry
@@ -428,7 +421,7 @@ class FederationError(RuntimeError):
self.affected = affected
self.source = source
- msg = "%s %s: %s" % (level, code, reason,)
+ msg = "%s %s: %s" % (level, code, reason)
super(FederationError, self).__init__(msg)
def get_dict(self):
@@ -448,6 +441,7 @@ class HttpResponseException(CodeMessageException):
Attributes:
response (bytes): body of response
"""
+
def __init__(self, code, msg, response):
"""
@@ -486,7 +480,7 @@ class HttpResponseException(CodeMessageException):
if not isinstance(j, dict):
j = {}
- errcode = j.pop('errcode', Codes.UNKNOWN)
- errmsg = j.pop('error', self.msg)
+ errcode = j.pop("errcode", Codes.UNKNOWN)
+ errmsg = j.pop("error", self.msg)
return ProxiedRequestError(self.code, errmsg, errcode, j)
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 3906475403..9b3daca29b 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -28,117 +28,55 @@ FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
- "limit": {
- "type": "number"
- },
- "senders": {
- "$ref": "#/definitions/user_id_array"
- },
- "not_senders": {
- "$ref": "#/definitions/user_id_array"
- },
+ "limit": {"type": "number"},
+ "senders": {"$ref": "#/definitions/user_id_array"},
+ "not_senders": {"$ref": "#/definitions/user_id_array"},
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
- "types": {
- "type": "array",
- "items": {
- "type": "string"
- }
- },
- "not_types": {
- "type": "array",
- "items": {
- "type": "string"
- }
- }
- }
+ "types": {"type": "array", "items": {"type": "string"}},
+ "not_types": {"type": "array", "items": {"type": "string"}},
+ },
}
ROOM_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
- "not_rooms": {
- "$ref": "#/definitions/room_id_array"
- },
- "rooms": {
- "$ref": "#/definitions/room_id_array"
- },
- "ephemeral": {
- "$ref": "#/definitions/room_event_filter"
- },
- "include_leave": {
- "type": "boolean"
- },
- "state": {
- "$ref": "#/definitions/room_event_filter"
- },
- "timeline": {
- "$ref": "#/definitions/room_event_filter"
- },
- "account_data": {
- "$ref": "#/definitions/room_event_filter"
- },
- }
+ "not_rooms": {"$ref": "#/definitions/room_id_array"},
+ "rooms": {"$ref": "#/definitions/room_id_array"},
+ "ephemeral": {"$ref": "#/definitions/room_event_filter"},
+ "include_leave": {"type": "boolean"},
+ "state": {"$ref": "#/definitions/room_event_filter"},
+ "timeline": {"$ref": "#/definitions/room_event_filter"},
+ "account_data": {"$ref": "#/definitions/room_event_filter"},
+ },
}
ROOM_EVENT_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
- "limit": {
- "type": "number"
- },
- "senders": {
- "$ref": "#/definitions/user_id_array"
- },
- "not_senders": {
- "$ref": "#/definitions/user_id_array"
- },
- "types": {
- "type": "array",
- "items": {
- "type": "string"
- }
- },
- "not_types": {
- "type": "array",
- "items": {
- "type": "string"
- }
- },
- "rooms": {
- "$ref": "#/definitions/room_id_array"
- },
- "not_rooms": {
- "$ref": "#/definitions/room_id_array"
- },
- "contains_url": {
- "type": "boolean"
- },
- "lazy_load_members": {
- "type": "boolean"
- },
- "include_redundant_members": {
- "type": "boolean"
- },
- }
+ "limit": {"type": "number"},
+ "senders": {"$ref": "#/definitions/user_id_array"},
+ "not_senders": {"$ref": "#/definitions/user_id_array"},
+ "types": {"type": "array", "items": {"type": "string"}},
+ "not_types": {"type": "array", "items": {"type": "string"}},
+ "rooms": {"$ref": "#/definitions/room_id_array"},
+ "not_rooms": {"$ref": "#/definitions/room_id_array"},
+ "contains_url": {"type": "boolean"},
+ "lazy_load_members": {"type": "boolean"},
+ "include_redundant_members": {"type": "boolean"},
+ },
}
USER_ID_ARRAY_SCHEMA = {
"type": "array",
- "items": {
- "type": "string",
- "format": "matrix_user_id"
- }
+ "items": {"type": "string", "format": "matrix_user_id"},
}
ROOM_ID_ARRAY_SCHEMA = {
"type": "array",
- "items": {
- "type": "string",
- "format": "matrix_room_id"
- }
+ "items": {"type": "string", "format": "matrix_room_id"},
}
USER_FILTER_SCHEMA = {
@@ -150,22 +88,13 @@ USER_FILTER_SCHEMA = {
"user_id_array": USER_ID_ARRAY_SCHEMA,
"filter": FILTER_SCHEMA,
"room_filter": ROOM_FILTER_SCHEMA,
- "room_event_filter": ROOM_EVENT_FILTER_SCHEMA
+ "room_event_filter": ROOM_EVENT_FILTER_SCHEMA,
},
"properties": {
- "presence": {
- "$ref": "#/definitions/filter"
- },
- "account_data": {
- "$ref": "#/definitions/filter"
- },
- "room": {
- "$ref": "#/definitions/room_filter"
- },
- "event_format": {
- "type": "string",
- "enum": ["client", "federation"]
- },
+ "presence": {"$ref": "#/definitions/filter"},
+ "account_data": {"$ref": "#/definitions/filter"},
+ "room": {"$ref": "#/definitions/room_filter"},
+ "event_format": {"type": "string", "enum": ["client", "federation"]},
"event_fields": {
"type": "array",
"items": {
@@ -177,26 +106,25 @@ USER_FILTER_SCHEMA = {
#
# Note that because this is a regular expression, we have to escape
# each backslash in the pattern.
- "pattern": r"^((?!\\\\).)*$"
- }
- }
+ "pattern": r"^((?!\\\\).)*$",
+ },
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
-@FormatChecker.cls_checks('matrix_room_id')
+@FormatChecker.cls_checks("matrix_room_id")
def matrix_room_id_validator(room_id_str):
return RoomID.from_string(room_id_str)
-@FormatChecker.cls_checks('matrix_user_id')
+@FormatChecker.cls_checks("matrix_user_id")
def matrix_user_id_validator(user_id_str):
return UserID.from_string(user_id_str)
class Filtering(object):
-
def __init__(self, hs):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
@@ -228,8 +156,9 @@ class Filtering(object):
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
try:
- jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
- format_checker=FormatChecker())
+ jsonschema.validate(
+ user_filter_json, USER_FILTER_SCHEMA, format_checker=FormatChecker()
+ )
except jsonschema.ValidationError as e:
raise SynapseError(400, str(e))
@@ -240,10 +169,9 @@ class FilterCollection(object):
room_filter_json = self._filter_json.get("room", {})
- self._room_filter = Filter({
- k: v for k, v in room_filter_json.items()
- if k in ("rooms", "not_rooms")
- })
+ self._room_filter = Filter(
+ {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
+ )
self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(room_filter_json.get("state", {}))
@@ -252,9 +180,7 @@ class FilterCollection(object):
self._presence_filter = Filter(filter_json.get("presence", {}))
self._account_data = Filter(filter_json.get("account_data", {}))
- self.include_leave = filter_json.get("room", {}).get(
- "include_leave", False
- )
+ self.include_leave = filter_json.get("room", {}).get("include_leave", False)
self.event_fields = filter_json.get("event_fields", [])
self.event_format = filter_json.get("event_format", "client")
@@ -299,22 +225,22 @@ class FilterCollection(object):
def blocks_all_presence(self):
return (
- self._presence_filter.filters_all_types() or
- self._presence_filter.filters_all_senders()
+ self._presence_filter.filters_all_types()
+ or self._presence_filter.filters_all_senders()
)
def blocks_all_room_ephemeral(self):
return (
- self._room_ephemeral_filter.filters_all_types() or
- self._room_ephemeral_filter.filters_all_senders() or
- self._room_ephemeral_filter.filters_all_rooms()
+ self._room_ephemeral_filter.filters_all_types()
+ or self._room_ephemeral_filter.filters_all_senders()
+ or self._room_ephemeral_filter.filters_all_rooms()
)
def blocks_all_room_timeline(self):
return (
- self._room_timeline_filter.filters_all_types() or
- self._room_timeline_filter.filters_all_senders() or
- self._room_timeline_filter.filters_all_rooms()
+ self._room_timeline_filter.filters_all_types()
+ or self._room_timeline_filter.filters_all_senders()
+ or self._room_timeline_filter.filters_all_rooms()
)
@@ -375,12 +301,7 @@ class Filter(object):
# check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), text_type)
- return self.check_fields(
- room_id,
- sender,
- ev_type,
- contains_url,
- )
+ return self.check_fields(room_id, sender, ev_type, contains_url)
def check_fields(self, room_id, sender, event_type, contains_url):
"""Checks whether the filter matches the given event fields.
@@ -391,7 +312,7 @@ class Filter(object):
literal_keys = {
"rooms": lambda v: room_id == v,
"senders": lambda v: sender == v,
- "types": lambda v: _matches_wildcard(event_type, v)
+ "types": lambda v: _matches_wildcard(event_type, v),
}
for name, match_func in literal_keys.items():
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 296c4a1c17..172841f595 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -44,29 +44,25 @@ class Ratelimiter(object):
"""
self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.get(
- key, (0., time_now_s, None),
+ key, (0.0, time_now_s, None)
)
time_delta = time_now_s - time_start
sent_count = message_count - time_delta * rate_hz
if sent_count < 0:
allowed = True
time_start = time_now_s
- message_count = 1.
- elif sent_count > burst_count - 1.:
+ message_count = 1.0
+ elif sent_count > burst_count - 1.0:
allowed = False
else:
allowed = True
message_count += 1
if update:
- self.message_counts[key] = (
- message_count, time_start, rate_hz
- )
+ self.message_counts[key] = (message_count, time_start, rate_hz)
if rate_hz > 0:
- time_allowed = (
- time_start + (message_count - burst_count + 1) / rate_hz
- )
+ time_allowed = time_start + (message_count - burst_count + 1) / rate_hz
if time_allowed < time_now_s:
time_allowed = time_now_s
else:
@@ -76,9 +72,7 @@ class Ratelimiter(object):
def prune_message_counts(self, time_now_s):
for key in list(self.message_counts.keys()):
- message_count, time_start, rate_hz = (
- self.message_counts[key]
- )
+ message_count, time_start, rate_hz = self.message_counts[key]
time_delta = time_now_s - time_start
if message_count - time_delta * rate_hz > 0:
break
@@ -92,5 +86,5 @@ class Ratelimiter(object):
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now_s)),
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index d644803d38..95292b7dec 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -19,9 +19,10 @@ class EventFormatVersions(object):
"""This is an internal enum for tracking the version of the event format,
independently from the room version.
"""
- V1 = 1 # $id:server event id format
- V2 = 2 # MSC1659-style $hash event id format: introduced for room v3
- V3 = 3 # MSC1884-style $hash format: introduced for room v4
+
+ V1 = 1 # $id:server event id format
+ V2 = 2 # MSC1659-style $hash event id format: introduced for room v3
+ V3 = 3 # MSC1884-style $hash format: introduced for room v4
KNOWN_EVENT_FORMAT_VERSIONS = {
@@ -33,8 +34,9 @@ KNOWN_EVENT_FORMAT_VERSIONS = {
class StateResolutionVersions(object):
"""Enum to identify the state resolution algorithms"""
- V1 = 1 # room v1 state res
- V2 = 2 # MSC1442 state res: room v2 and later
+
+ V1 = 1 # room v1 state res
+ V2 = 2 # MSC1442 state res: room v2 and later
class RoomDisposition(object):
@@ -46,10 +48,10 @@ class RoomDisposition(object):
class RoomVersion(object):
"""An object which describes the unique attributes of a room version."""
- identifier = attr.ib() # str; the identifier for this version
- disposition = attr.ib() # str; one of the RoomDispositions
- event_format = attr.ib() # int; one of the EventFormatVersions
- state_res = attr.ib() # int; one of the StateResolutionVersions
+ identifier = attr.ib() # str; the identifier for this version
+ disposition = attr.ib() # str; one of the RoomDispositions
+ event_format = attr.ib() # int; one of the EventFormatVersions
+ state_res = attr.ib() # int; one of the StateResolutionVersions
enforce_key_validity = attr.ib() # bool
@@ -92,11 +94,12 @@ class RoomVersions(object):
KNOWN_ROOM_VERSIONS = {
- v.identifier: v for v in (
+ v.identifier: v
+ for v in (
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.V3,
RoomVersions.V4,
RoomVersions.V5,
)
-} # type: dict[str, RoomVersion]
+} # type: dict[str, RoomVersion]
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index e16c386a14..ff1f39e86c 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -42,13 +42,9 @@ class ConsentURIBuilder(object):
hs_config (synapse.config.homeserver.HomeServerConfig):
"""
if hs_config.form_secret is None:
- raise ConfigError(
- "form_secret not set in config",
- )
+ raise ConfigError("form_secret not set in config")
if hs_config.public_baseurl is None:
- raise ConfigError(
- "public_baseurl not set in config",
- )
+ raise ConfigError("public_baseurl not set in config")
self._hmac_secret = hs_config.form_secret.encode("utf-8")
self._public_baseurl = hs_config.public_baseurl
@@ -64,15 +60,10 @@ class ConsentURIBuilder(object):
(str) the URI where the user can do consent
"""
mac = hmac.new(
- key=self._hmac_secret,
- msg=user_id.encode('ascii'),
- digestmod=sha256,
+ key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256
).hexdigest()
consent_uri = "%s_matrix/consent?%s" % (
self._public_baseurl,
- urlencode({
- "u": user_id,
- "h": mac
- }),
+ urlencode({"u": user_id, "h": mac}),
)
return consent_uri
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index f56f5fcc13..d877c77834 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -43,7 +43,7 @@ def check_bind_error(e, address, bind_addresses):
address (str): Address on which binding was attempted.
bind_addresses (list): Addresses on which the service listens.
"""
- if address == '0.0.0.0' and '::' in bind_addresses:
- logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
+ if address == "0.0.0.0" and "::" in bind_addresses:
+ logger.warn("Failed to listen on 0.0.0.0, continuing because listening on [::]")
else:
raise e
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 8cc990399f..d50a9840d4 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -19,7 +19,6 @@ import signal
import sys
import traceback
-import psutil
from daemonize import Daemonize
from twisted.internet import defer, error, reactor
@@ -68,21 +67,13 @@ def start_worker_reactor(appname, config):
gc_thresholds=config.gc_thresholds,
pid_file=config.worker_pid_file,
daemonize=config.worker_daemonize,
- cpu_affinity=config.worker_cpu_affinity,
print_pidfile=config.print_pidfile,
logger=logger,
)
def start_reactor(
- appname,
- soft_file_limit,
- gc_thresholds,
- pid_file,
- daemonize,
- cpu_affinity,
- print_pidfile,
- logger,
+ appname, soft_file_limit, gc_thresholds, pid_file, daemonize, print_pidfile, logger
):
""" Run the reactor in the main process
@@ -95,7 +86,6 @@ def start_reactor(
gc_thresholds:
pid_file (str): name of pid file to write to if daemonize is True
daemonize (bool): true to run the reactor in a background process
- cpu_affinity (int|None): cpu affinity mask
print_pidfile (bool): whether to print the pid file, if daemonize is True
logger (logging.Logger): logger instance to pass to Daemonize
"""
@@ -109,20 +99,6 @@ def start_reactor(
# between the sentinel and `run` logcontexts.
with PreserveLoggingContext():
logger.info("Running")
- if cpu_affinity is not None:
- # Turn the bitmask into bits, reverse it so we go from 0 up
- mask_to_bits = bin(cpu_affinity)[2:][::-1]
-
- cpus = []
- cpu_num = 0
-
- for i in mask_to_bits:
- if i == "1":
- cpus.append(cpu_num)
- cpu_num += 1
-
- p = psutil.Process()
- p.cpu_affinity(cpus)
change_resource_limit(soft_file_limit)
if gc_thresholds:
@@ -149,10 +125,10 @@ def start_reactor(
def quit_with_error(error_string):
message_lines = error_string.split("\n")
line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
- sys.stderr.write("*" * line_length + '\n')
+ sys.stderr.write("*" * line_length + "\n")
for line in message_lines:
sys.stderr.write(" %s\n" % (line.rstrip(),))
- sys.stderr.write("*" * line_length + '\n')
+ sys.stderr.write("*" * line_length + "\n")
sys.exit(1)
@@ -178,14 +154,7 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
r = []
for address in bind_addresses:
try:
- r.append(
- reactor.listenTCP(
- port,
- factory,
- backlog,
- address
- )
- )
+ r.append(reactor.listenTCP(port, factory, backlog, address))
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
@@ -205,13 +174,7 @@ def listen_ssl(
for address in bind_addresses:
try:
r.append(
- reactor.listenSSL(
- port,
- factory,
- context_factory,
- backlog,
- address
- )
+ reactor.listenSSL(port, factory, context_factory, backlog, address)
)
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
@@ -243,15 +206,13 @@ def refresh_certificate(hs):
if isinstance(i.factory, TLSMemoryBIOFactory):
addr = i.getHost()
logger.info(
- "Replacing TLS context factory on [%s]:%i", addr.host, addr.port,
+ "Replacing TLS context factory on [%s]:%i", addr.host, addr.port
)
# We want to replace TLS factories with a new one, with the new
# TLS configuration. We do this by reaching in and pulling out
# the wrappedFactory, and then re-wrapping it.
i.factory = TLSMemoryBIOFactory(
- hs.tls_server_context_factory,
- False,
- i.factory.wrappedFactory
+ hs.tls_server_context_factory, False, i.factory.wrappedFactory
)
logger.info("Context factories updated.")
@@ -267,6 +228,7 @@ def start(hs, listeners=None):
try:
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
+
def handle_sighup(*args, **kwargs):
for i in _sighup_callbacks:
i(hs)
@@ -302,10 +264,8 @@ def setup_sentry(hs):
return
import sentry_sdk
- sentry_sdk.init(
- dsn=hs.config.sentry_dsn,
- release=get_version_string(synapse),
- )
+
+ sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse))
# We set some default tags that give some context to this instance
with sentry_sdk.configure_scope() as scope:
@@ -326,7 +286,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
many DNS queries at once
"""
new_resolver = _LimitedHostnameResolver(
- reactor.nameResolver, max_dns_requests_in_flight,
+ reactor.nameResolver, max_dns_requests_in_flight
)
reactor.installNameResolver(new_resolver)
@@ -339,11 +299,17 @@ class _LimitedHostnameResolver(object):
def __init__(self, resolver, max_dns_requests_in_flight):
self._resolver = resolver
self._limiter = Linearizer(
- name="dns_client_limiter", max_count=max_dns_requests_in_flight,
+ name="dns_client_limiter", max_count=max_dns_requests_in_flight
)
- def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
- addressTypes=None, transportSemantics='TCP'):
+ def resolveHostName(
+ self,
+ resolutionReceiver,
+ hostName,
+ portNumber=0,
+ addressTypes=None,
+ transportSemantics="TCP",
+ ):
# We need this function to return `resolutionReceiver` so we do all the
# actual logic involving deferreds in a separate function.
@@ -363,8 +329,14 @@ class _LimitedHostnameResolver(object):
return resolutionReceiver
@defer.inlineCallbacks
- def _resolve(self, resolutionReceiver, hostName, portNumber=0,
- addressTypes=None, transportSemantics='TCP'):
+ def _resolve(
+ self,
+ resolutionReceiver,
+ hostName,
+ portNumber=0,
+ addressTypes=None,
+ transportSemantics="TCP",
+ ):
with (yield self._limiter.queue(())):
# resolveHostName doesn't return a Deferred, so we need to hook into
@@ -374,8 +346,7 @@ class _LimitedHostnameResolver(object):
receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred)
self._resolver.resolveHostName(
- receiver, hostName, portNumber,
- addressTypes, transportSemantics,
+ receiver, hostName, portNumber, addressTypes, transportSemantics
)
yield deferred
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 33107f56d1..9120bdb143 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -44,7 +44,9 @@ logger = logging.getLogger("synapse.app.appservice")
class AppserviceSlaveStore(
- DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
+ DirectoryStore,
+ SlavedEventStore,
+ SlavedApplicationServiceStore,
SlavedRegistrationStore,
):
pass
@@ -74,7 +76,7 @@ class AppserviceServer(HomeServer):
listener_config,
root_resource,
self.version_string,
- )
+ ),
)
logger.info("Synapse appservice now listening on port %d", port)
@@ -88,18 +90,19 @@ class AppserviceServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -132,9 +135,7 @@ class ASReplicationHandler(ReplicationClientHandler):
def start(config_options):
try:
- config = HomeServerConfig.load_config(
- "Synapse appservice", config_options
- )
+ config = HomeServerConfig.load_config("Synapse appservice", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -173,6 +174,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-appservice", config)
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index a16e037f32..90bc79cdda 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -37,6 +37,7 @@ from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.groups import SlavedGroupServerStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
@@ -52,6 +53,7 @@ from synapse.rest.client.v1.room import (
PublicRoomListRestServlet,
RoomEventContextServlet,
RoomMemberListRestServlet,
+ RoomMessageListRestServlet,
RoomStateRestServlet,
)
from synapse.rest.client.v1.voip import VoipRestServlet
@@ -74,6 +76,7 @@ class ClientReaderSlavedStore(
SlavedDeviceStore,
SlavedReceiptsStore,
SlavedPushRuleStore,
+ SlavedGroupServerStore,
SlavedAccountDataStore,
SlavedEventStore,
SlavedKeyStore,
@@ -109,6 +112,7 @@ class ClientReaderServer(HomeServer):
JoinedRoomMemberListRestServlet(self).register(resource)
RoomStateRestServlet(self).register(resource)
RoomEventContextServlet(self).register(resource)
+ RoomMessageListRestServlet(self).register(resource)
RegisterRestServlet(self).register(resource)
LoginRestServlet(self).register(resource)
ThreepidRestServlet(self).register(resource)
@@ -118,9 +122,7 @@ class ClientReaderServer(HomeServer):
PushRuleRestServlet(self).register(resource)
VersionsRestServlet().register(resource)
- resources.update({
- "/_matrix/client": resource,
- })
+ resources.update({"/_matrix/client": resource})
root_resource = create_resource_tree(resources, NoResource())
@@ -133,7 +135,7 @@ class ClientReaderServer(HomeServer):
listener_config,
root_resource,
self.version_string,
- )
+ ),
)
logger.info("Synapse client reader now listening on port %d", port)
@@ -147,18 +149,19 @@ class ClientReaderServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -170,9 +173,7 @@ class ClientReaderServer(HomeServer):
def start(config_options):
try:
- config = HomeServerConfig.load_config(
- "Synapse client reader", config_options
- )
+ config = HomeServerConfig.load_config("Synapse client reader", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -199,6 +200,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-client-reader", config)
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index b8e5196152..ff522e4499 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -109,12 +109,14 @@ class EventCreatorServer(HomeServer):
ProfileAvatarURLRestServlet(self).register(resource)
ProfileDisplaynameRestServlet(self).register(resource)
ProfileRestServlet(self).register(resource)
- resources.update({
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- })
+ resources.update(
+ {
+ "/_matrix/client/r0": resource,
+ "/_matrix/client/unstable": resource,
+ "/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
+ }
+ )
root_resource = create_resource_tree(resources, NoResource())
@@ -127,7 +129,7 @@ class EventCreatorServer(HomeServer):
listener_config,
root_resource,
self.version_string,
- )
+ ),
)
logger.info("Synapse event creator now listening on port %d", port)
@@ -141,18 +143,19 @@ class EventCreatorServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -164,9 +167,7 @@ class EventCreatorServer(HomeServer):
def start(config_options):
try:
- config = HomeServerConfig.load_config(
- "Synapse event creator", config_options
- )
+ config = HomeServerConfig.load_config("Synapse event creator", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -198,6 +199,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-event-creator", config)
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 7da79dc827..9421420930 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -86,19 +86,18 @@ class FederationReaderServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "federation":
- resources.update({
- FEDERATION_PREFIX: TransportLayerServer(self),
- })
+ resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
if name == "openid" and "federation" not in res["names"]:
# Only load the openid resource separately if federation resource
# is not specified since federation resource includes openid
# resource.
- resources.update({
- FEDERATION_PREFIX: TransportLayerServer(
- self,
- servlet_groups=["openid"],
- ),
- })
+ resources.update(
+ {
+ FEDERATION_PREFIX: TransportLayerServer(
+ self, servlet_groups=["openid"]
+ )
+ }
+ )
if name in ["keys", "federation"]:
resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
@@ -115,7 +114,7 @@ class FederationReaderServer(HomeServer):
root_resource,
self.version_string,
),
- reactor=self.get_reactor()
+ reactor=self.get_reactor(),
)
logger.info("Synapse federation reader now listening on port %d", port)
@@ -129,18 +128,19 @@ class FederationReaderServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -181,6 +181,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-federation-reader", config)
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 1d43f2b075..969be58d0b 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -52,8 +52,13 @@ logger = logging.getLogger("synapse.app.federation_sender")
class FederationSenderSlaveStore(
- SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore,
- SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
+ SlavedDeviceInboxStore,
+ SlavedTransactionStore,
+ SlavedReceiptsStore,
+ SlavedEventStore,
+ SlavedRegistrationStore,
+ SlavedDeviceStore,
+ SlavedPresenceStore,
):
def __init__(self, db_conn, hs):
super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
@@ -65,10 +70,7 @@ class FederationSenderSlaveStore(
self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
def _get_federation_out_pos(self, db_conn):
- sql = (
- "SELECT stream_id FROM federation_stream_position"
- " WHERE type = ?"
- )
+ sql = "SELECT stream_id FROM federation_stream_position" " WHERE type = ?"
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
@@ -103,7 +105,7 @@ class FederationSenderServer(HomeServer):
listener_config,
root_resource,
self.version_string,
- )
+ ),
)
logger.info("Synapse federation_sender now listening on port %d", port)
@@ -117,18 +119,19 @@ class FederationSenderServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -151,7 +154,9 @@ class FederationSenderReplicationHandler(ReplicationClientHandler):
self.send_handler.process_replication_rows(stream_name, token, rows)
def get_streams_to_replicate(self):
- args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
+ args = super(
+ FederationSenderReplicationHandler, self
+ ).get_streams_to_replicate()
args.update(self.send_handler.stream_positions())
return args
@@ -203,6 +208,7 @@ class FederationSenderHandler(object):
"""Processes the replication stream and forwards the appropriate entries
to the federation sender.
"""
+
def __init__(self, hs, replication_client):
self.store = hs.get_datastore()
self._is_mine_id = hs.is_mine_id
@@ -241,7 +247,7 @@ class FederationSenderHandler(object):
# ... and when new receipts happen
elif stream_name == ReceiptsStream.NAME:
run_as_background_process(
- "process_receipts_for_federation", self._on_new_receipts, rows,
+ "process_receipts_for_federation", self._on_new_receipts, rows
)
@defer.inlineCallbacks
@@ -278,12 +284,14 @@ class FederationSenderHandler(object):
# We ACK this token over replication so that the master can drop
# its in memory queues
- self.replication_client.send_federation_ack(self.federation_position)
+ self.replication_client.send_federation_ack(
+ self.federation_position
+ )
self._last_ack = self.federation_position
except Exception:
logger.exception("Error updating federation stream position")
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 6504da5278..2fd7d57ebf 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -62,14 +62,11 @@ class PresenceStatusStubServlet(RestServlet):
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
- headers = {
- "Authorization": auth_headers,
- }
+ headers = {"Authorization": auth_headers}
try:
result = yield self.http_client.get_json(
- self.main_uri + request.uri.decode('ascii'),
- headers=headers,
+ self.main_uri + request.uri.decode("ascii"), headers=headers
)
except HttpResponseException as e:
raise e.to_synapse_error()
@@ -105,18 +102,19 @@ class KeyUploadServlet(RestServlet):
if device_id is not None:
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
- if (requester.device_id is not None and
- device_id != requester.device_id):
- logger.warning("Client uploading keys for a different device "
- "(logged in as %s, uploading for %s)",
- requester.device_id, device_id)
+ if requester.device_id is not None and device_id != requester.device_id:
+ logger.warning(
+ "Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id,
+ device_id,
+ )
else:
device_id = requester.device_id
if device_id is None:
raise SynapseError(
- 400,
- "To upload keys, you must pass device_id when authenticating"
+ 400, "To upload keys, you must pass device_id when authenticating"
)
if body:
@@ -124,13 +122,9 @@ class KeyUploadServlet(RestServlet):
# Pass through the auth headers, if any, in case the access token
# is there.
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
- headers = {
- "Authorization": auth_headers,
- }
+ headers = {"Authorization": auth_headers}
result = yield self.http_client.post_json_get_json(
- self.main_uri + request.uri.decode('ascii'),
- body,
- headers=headers,
+ self.main_uri + request.uri.decode("ascii"), body, headers=headers
)
defer.returnValue((200, result))
@@ -171,12 +165,14 @@ class FrontendProxyServer(HomeServer):
if not self.config.use_presence:
PresenceStatusStubServlet(self).register(resource)
- resources.update({
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- })
+ resources.update(
+ {
+ "/_matrix/client/r0": resource,
+ "/_matrix/client/unstable": resource,
+ "/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
+ }
+ )
root_resource = create_resource_tree(resources, NoResource())
@@ -190,7 +186,7 @@ class FrontendProxyServer(HomeServer):
root_resource,
self.version_string,
),
- reactor=self.get_reactor()
+ reactor=self.get_reactor(),
)
logger.info("Synapse client reader now listening on port %d", port)
@@ -204,18 +200,19 @@ class FrontendProxyServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -227,9 +224,7 @@ class FrontendProxyServer(HomeServer):
def start(config_options):
try:
- config = HomeServerConfig.load_config(
- "Synapse frontend proxy", config_options
- )
+ config = HomeServerConfig.load_config("Synapse frontend proxy", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -258,6 +253,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-frontend-proxy", config)
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index df524a23dd..49da105cf6 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -101,13 +101,12 @@ class SynapseHomeServer(HomeServer):
# Skip loading openid resource if federation is defined
# since federation resource will include openid
continue
- resources.update(self._configure_named_resource(
- name, res.get("compress", False),
- ))
+ resources.update(
+ self._configure_named_resource(name, res.get("compress", False))
+ )
additional_resources = listener_config.get("additional_resources", {})
- logger.debug("Configuring additional resources: %r",
- additional_resources)
+ logger.debug("Configuring additional resources: %r", additional_resources)
module_api = ModuleApi(self, self.get_auth_handler())
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
@@ -174,60 +173,67 @@ class SynapseHomeServer(HomeServer):
if compress:
client_resource = gz_wrap(client_resource)
- resources.update({
- "/_matrix/client/api/v1": client_resource,
- "/_synapse/password_reset": client_resource,
- "/_matrix/client/r0": client_resource,
- "/_matrix/client/unstable": client_resource,
- "/_matrix/client/v2_alpha": client_resource,
- "/_matrix/client/versions": client_resource,
- "/.well-known/matrix/client": WellKnownResource(self),
- "/_synapse/admin": AdminRestResource(self),
- })
+ resources.update(
+ {
+ "/_matrix/client/api/v1": client_resource,
+ "/_matrix/client/r0": client_resource,
+ "/_matrix/client/unstable": client_resource,
+ "/_matrix/client/v2_alpha": client_resource,
+ "/_matrix/client/versions": client_resource,
+ "/.well-known/matrix/client": WellKnownResource(self),
+ "/_synapse/admin": AdminRestResource(self),
+ }
+ )
if self.get_config().saml2_enabled:
from synapse.rest.saml2 import SAML2Resource
+
resources["/_matrix/saml2"] = SAML2Resource(self)
if name == "consent":
from synapse.rest.consent.consent_resource import ConsentResource
+
consent_resource = ConsentResource(self)
if compress:
consent_resource = gz_wrap(consent_resource)
- resources.update({
- "/_matrix/consent": consent_resource,
- })
+ resources.update({"/_matrix/consent": consent_resource})
if name == "federation":
- resources.update({
- FEDERATION_PREFIX: TransportLayerServer(self),
- })
+ resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
if name == "openid":
- resources.update({
- FEDERATION_PREFIX: TransportLayerServer(self, servlet_groups=["openid"]),
- })
+ resources.update(
+ {
+ FEDERATION_PREFIX: TransportLayerServer(
+ self, servlet_groups=["openid"]
+ )
+ }
+ )
if name in ["static", "client"]:
- resources.update({
- STATIC_PREFIX: File(
- os.path.join(os.path.dirname(synapse.__file__), "static")
- ),
- })
+ resources.update(
+ {
+ STATIC_PREFIX: File(
+ os.path.join(os.path.dirname(synapse.__file__), "static")
+ )
+ }
+ )
if name in ["media", "federation", "client"]:
if self.get_config().enable_media_repo:
media_repo = self.get_media_repository_resource()
- resources.update({
- MEDIA_PREFIX: media_repo,
- LEGACY_MEDIA_PREFIX: media_repo,
- CONTENT_REPO_PREFIX: ContentRepoResource(
- self, self.config.uploads_path
- ),
- })
+ resources.update(
+ {
+ MEDIA_PREFIX: media_repo,
+ LEGACY_MEDIA_PREFIX: media_repo,
+ CONTENT_REPO_PREFIX: ContentRepoResource(
+ self, self.config.uploads_path
+ ),
+ }
+ )
elif name == "media":
raise ConfigError(
- "'media' resource conflicts with enable_media_repo=False",
+ "'media' resource conflicts with enable_media_repo=False"
)
if name in ["keys", "federation"]:
@@ -258,18 +264,14 @@ class SynapseHomeServer(HomeServer):
for listener in listeners:
if listener["type"] == "http":
- self._listening_services.extend(
- self._listener_http(config, listener)
- )
+ self._listening_services.extend(self._listener_http(config, listener))
elif listener["type"] == "manhole":
listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "replication":
services = listen_tcp(
@@ -278,16 +280,17 @@ class SynapseHomeServer(HomeServer):
ReplicationStreamProtocolFactory(self),
)
for s in services:
- reactor.addSystemEventTrigger(
- "before", "shutdown", s.stopListening,
- )
+ reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -313,7 +316,7 @@ current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
registered_reserved_users_mau_gauge = Gauge(
"synapse_admin_mau:registered_reserved_users",
- "Registered users with reserved threepids"
+ "Registered users with reserved threepids",
)
@@ -328,8 +331,7 @@ def setup(config_options):
"""
try:
config = HomeServerConfig.load_or_generate_config(
- "Synapse Homeserver",
- config_options,
+ "Synapse Homeserver", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
@@ -340,10 +342,7 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
- synapse.config.logger.setup_logging(
- config,
- use_worker_options=False
- )
+ synapse.config.logger.setup_logging(config, use_worker_options=False)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
@@ -358,7 +357,7 @@ def setup(config_options):
database_engine=database_engine,
)
- logger.info("Preparing database: %s...", config.database_config['name'])
+ logger.info("Preparing database: %s...", config.database_config["name"])
try:
with hs.get_db_conn(run_new_connection=False) as db_conn:
@@ -376,7 +375,7 @@ def setup(config_options):
)
sys.exit(1)
- logger.info("Database prepared in %s.", config.database_config['name'])
+ logger.info("Database prepared in %s.", config.database_config["name"])
hs.setup()
hs.setup_master()
@@ -392,9 +391,7 @@ def setup(config_options):
acme = hs.get_acme_handler()
# Check how long the certificate is active for.
- cert_days_remaining = hs.config.is_disk_cert_valid(
- allow_self_signed=False
- )
+ cert_days_remaining = hs.config.is_disk_cert_valid(allow_self_signed=False)
# We want to reprovision if cert_days_remaining is None (meaning no
# certificate exists), or the days remaining number it returns
@@ -402,8 +399,8 @@ def setup(config_options):
provision = False
if (
- cert_days_remaining is None or
- cert_days_remaining < hs.config.acme_reprovision_threshold
+ cert_days_remaining is None
+ or cert_days_remaining < hs.config.acme_reprovision_threshold
):
provision = True
@@ -434,10 +431,7 @@ def setup(config_options):
yield do_acme()
# Check if it needs to be reprovisioned every day.
- hs.get_clock().looping_call(
- reprovision_acme,
- 24 * 60 * 60 * 1000
- )
+ hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
_base.start(hs, config.listeners)
@@ -464,6 +458,7 @@ class SynapseService(service.Service):
A twisted Service class that will start synapse. Used to run synapse
via twistd and a .tac.
"""
+
def __init__(self, config):
self.config = config
@@ -480,6 +475,7 @@ class SynapseService(service.Service):
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
+
def profile(func):
from cProfile import Profile
from threading import current_thread
@@ -490,13 +486,14 @@ def run(hs):
func(*args, **kargs)
profile.disable()
ident = current_thread().ident
- profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
- hs.hostname, func.__name__, ident
- ))
+ profile.dump_stats(
+ "/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident)
+ )
return profiled
from twisted.python.threadpool import ThreadPool
+
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
@@ -541,7 +538,10 @@ def run(hs):
stats["total_room_count"] = room_count
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
- stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
+ stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
+ stats[
+ "daily_active_rooms"
+ ] = yield hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users()
@@ -565,8 +565,7 @@ def run(hs):
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
yield hs.get_simple_http_client().put_json(
- "https://matrix.org/report-usage-stats/push",
- stats
+ "https://matrix.org/report-usage-stats/push", stats
)
except Exception as e:
logger.warn("Error reporting stats: %s", e)
@@ -581,14 +580,11 @@ def run(hs):
logger.info("report_stats can use psutil")
stats_process.append(process)
except (AttributeError):
- logger.warning(
- "Unable to read memory/cpu stats. Disabling reporting."
- )
+ logger.warning("Unable to read memory/cpu stats. Disabling reporting.")
def generate_user_daily_visit_stats():
return run_as_background_process(
- "generate_user_daily_visits",
- hs.get_datastore().generate_user_daily_visits,
+ "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits
)
# Rather than update on per session basis, batch up the requests.
@@ -599,9 +595,9 @@ def run(hs):
# monthly active user limiting functionality
def reap_monthly_active_users():
return run_as_background_process(
- "reap_monthly_active_users",
- hs.get_datastore().reap_monthly_active_users,
+ "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users
)
+
clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
reap_monthly_active_users()
@@ -619,8 +615,7 @@ def run(hs):
def start_generate_monthly_active_users():
return run_as_background_process(
- "generate_monthly_active_users",
- generate_monthly_active_users,
+ "generate_monthly_active_users", generate_monthly_active_users
)
start_generate_monthly_active_users()
@@ -646,7 +641,6 @@ def run(hs):
gc_thresholds=hs.config.gc_thresholds,
pid_file=hs.config.pid_file,
daemonize=hs.config.daemonize,
- cpu_affinity=hs.config.cpu_affinity,
print_pidfile=hs.config.print_pidfile,
logger=logger,
)
@@ -660,5 +654,5 @@ def main():
run(hs)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index d4cc4e9443..cf0e2036c3 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -72,13 +72,15 @@ class MediaRepositoryServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "media":
media_repo = self.get_media_repository_resource()
- resources.update({
- MEDIA_PREFIX: media_repo,
- LEGACY_MEDIA_PREFIX: media_repo,
- CONTENT_REPO_PREFIX: ContentRepoResource(
- self, self.config.uploads_path
- ),
- })
+ resources.update(
+ {
+ MEDIA_PREFIX: media_repo,
+ LEGACY_MEDIA_PREFIX: media_repo,
+ CONTENT_REPO_PREFIX: ContentRepoResource(
+ self, self.config.uploads_path
+ ),
+ }
+ )
root_resource = create_resource_tree(resources, NoResource())
@@ -91,7 +93,7 @@ class MediaRepositoryServer(HomeServer):
listener_config,
root_resource,
self.version_string,
- )
+ ),
)
logger.info("Synapse media repository now listening on port %d", port)
@@ -105,18 +107,19 @@ class MediaRepositoryServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -164,6 +167,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-media-repository", config)
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index cbf0d67f51..df29ea5ecb 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -46,36 +46,27 @@ logger = logging.getLogger("synapse.app.pusher")
class PusherSlaveStore(
- SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
- SlavedAccountDataStore
+ SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore, SlavedAccountDataStore
):
- update_pusher_last_stream_ordering_and_success = (
- __func__(DataStore.update_pusher_last_stream_ordering_and_success)
+ update_pusher_last_stream_ordering_and_success = __func__(
+ DataStore.update_pusher_last_stream_ordering_and_success
)
- update_pusher_failing_since = (
- __func__(DataStore.update_pusher_failing_since)
- )
+ update_pusher_failing_since = __func__(DataStore.update_pusher_failing_since)
- update_pusher_last_stream_ordering = (
- __func__(DataStore.update_pusher_last_stream_ordering)
+ update_pusher_last_stream_ordering = __func__(
+ DataStore.update_pusher_last_stream_ordering
)
- get_throttle_params_by_room = (
- __func__(DataStore.get_throttle_params_by_room)
- )
+ get_throttle_params_by_room = __func__(DataStore.get_throttle_params_by_room)
- set_throttle_params = (
- __func__(DataStore.set_throttle_params)
- )
+ set_throttle_params = __func__(DataStore.set_throttle_params)
- get_time_of_last_push_action_before = (
- __func__(DataStore.get_time_of_last_push_action_before)
+ get_time_of_last_push_action_before = __func__(
+ DataStore.get_time_of_last_push_action_before
)
- get_profile_displayname = (
- __func__(DataStore.get_profile_displayname)
- )
+ get_profile_displayname = __func__(DataStore.get_profile_displayname)
class PusherServer(HomeServer):
@@ -105,7 +96,7 @@ class PusherServer(HomeServer):
listener_config,
root_resource,
self.version_string,
- )
+ ),
)
logger.info("Synapse pusher now listening on port %d", port)
@@ -119,18 +110,19 @@ class PusherServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -161,9 +153,7 @@ class PusherReplicationHandler(ReplicationClientHandler):
else:
yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == "events":
- yield self.pusher_pool.on_new_notifications(
- token, token,
- )
+ yield self.pusher_pool.on_new_notifications(token, token)
elif stream_name == "receipts":
yield self.pusher_pool.on_new_receipts(
token, token, set(row.room_id for row in rows)
@@ -188,9 +178,7 @@ class PusherReplicationHandler(ReplicationClientHandler):
def start(config_options):
try:
- config = HomeServerConfig.load_config(
- "Synapse pusher", config_options
- )
+ config = HomeServerConfig.load_config("Synapse pusher", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -234,6 +222,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-pusher", config)
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
ps = start(sys.argv[1:])
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 5388def28a..858949910d 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -98,10 +98,7 @@ class SynchrotronPresence(object):
self.notifier = hs.get_notifier()
active_presence = self.store.take_presence_startup_info()
- self.user_to_current_state = {
- state.user_id: state
- for state in active_presence
- }
+ self.user_to_current_state = {state.user_id: state for state in active_presence}
# user_id -> last_sync_ms. Lists the users that have stopped syncing
# but we haven't notified the master of that yet
@@ -196,17 +193,26 @@ class SynchrotronPresence(object):
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key", stream_id, rooms=room_ids_to_states.keys(),
- users=users_to_states.keys()
+ "presence_key",
+ stream_id,
+ rooms=room_ids_to_states.keys(),
+ users=users_to_states.keys(),
)
@defer.inlineCallbacks
def process_replication_rows(self, token, rows):
- states = [UserPresenceState(
- row.user_id, row.state, row.last_active_ts,
- row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg,
- row.currently_active
- ) for row in rows]
+ states = [
+ UserPresenceState(
+ row.user_id,
+ row.state,
+ row.last_active_ts,
+ row.last_federation_update_ts,
+ row.last_user_sync_ts,
+ row.status_msg,
+ row.currently_active,
+ )
+ for row in rows
+ ]
for state in states:
self.user_to_current_state[state.user_id] = state
@@ -217,7 +223,8 @@ class SynchrotronPresence(object):
def get_currently_syncing_users(self):
if self.hs.config.use_presence:
return [
- user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
+ user_id
+ for user_id, count in iteritems(self.user_to_num_current_syncs)
if count > 0
]
else:
@@ -281,12 +288,14 @@ class SynchrotronServer(HomeServer):
events.register_servlets(self, resource)
InitialSyncRestServlet(self).register(resource)
RoomInitialSyncRestServlet(self).register(resource)
- resources.update({
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- })
+ resources.update(
+ {
+ "/_matrix/client/r0": resource,
+ "/_matrix/client/unstable": resource,
+ "/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
+ }
+ )
root_resource = create_resource_tree(resources, NoResource())
@@ -299,7 +308,7 @@ class SynchrotronServer(HomeServer):
listener_config,
root_resource,
self.version_string,
- )
+ ),
)
logger.info("Synapse synchrotron now listening on port %d", port)
@@ -313,18 +322,19 @@ class SynchrotronServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -382,40 +392,36 @@ class SyncReplicationHandler(ReplicationClientHandler):
)
elif stream_name == "push_rules":
self.notifier.on_new_event(
- "push_rules_key", token, users=[row.user_id for row in rows],
+ "push_rules_key", token, users=[row.user_id for row in rows]
)
- elif stream_name in ("account_data", "tag_account_data",):
+ elif stream_name in ("account_data", "tag_account_data"):
self.notifier.on_new_event(
- "account_data_key", token, users=[row.user_id for row in rows],
+ "account_data_key", token, users=[row.user_id for row in rows]
)
elif stream_name == "receipts":
self.notifier.on_new_event(
- "receipt_key", token, rooms=[row.room_id for row in rows],
+ "receipt_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == "typing":
self.typing_handler.process_replication_rows(token, rows)
self.notifier.on_new_event(
- "typing_key", token, rooms=[row.room_id for row in rows],
+ "typing_key", token, rooms=[row.room_id for row in rows]
)
elif stream_name == "to_device":
entities = [row.entity for row in rows if row.entity.startswith("@")]
if entities:
- self.notifier.on_new_event(
- "to_device_key", token, users=entities,
- )
+ self.notifier.on_new_event("to_device_key", token, users=entities)
elif stream_name == "device_lists":
all_room_ids = set()
for row in rows:
room_ids = yield self.store.get_rooms_for_user(row.user_id)
all_room_ids.update(room_ids)
- self.notifier.on_new_event(
- "device_list_key", token, rooms=all_room_ids,
- )
+ self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
elif stream_name == "presence":
yield self.presence_handler.process_replication_rows(token, rows)
elif stream_name == "receipts":
self.notifier.on_new_event(
- "groups_key", token, users=[row.user_id for row in rows],
+ "groups_key", token, users=[row.user_id for row in rows]
)
except Exception:
logger.exception("Error processing replication")
@@ -423,9 +429,7 @@ class SyncReplicationHandler(ReplicationClientHandler):
def start(config_options):
try:
- config = HomeServerConfig.load_config(
- "Synapse synchrotron", config_options
- )
+ config = HomeServerConfig.load_config("Synapse synchrotron", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -453,6 +457,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-synchrotron", config)
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 355f5aa71d..2d9d2e1bbc 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -66,14 +66,16 @@ class UserDirectorySlaveStore(
events_max = self._stream_id_gen.get_current_token()
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
- db_conn, "current_state_delta_stream",
+ db_conn,
+ "current_state_delta_stream",
entity_column="room_id",
stream_column="stream_id",
max_value=events_max, # As we share the stream id with events token
limit=1000,
)
self._curr_state_delta_stream_cache = StreamChangeCache(
- "_curr_state_delta_stream_cache", min_curr_state_delta_id,
+ "_curr_state_delta_stream_cache",
+ min_curr_state_delta_id,
prefilled_cache=curr_state_delta_prefill,
)
@@ -110,12 +112,14 @@ class UserDirectoryServer(HomeServer):
elif name == "client":
resource = JsonResource(self, canonical_json=False)
user_directory.register_servlets(self, resource)
- resources.update({
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- })
+ resources.update(
+ {
+ "/_matrix/client/r0": resource,
+ "/_matrix/client/unstable": resource,
+ "/_matrix/client/v2_alpha": resource,
+ "/_matrix/client/api/v1": resource,
+ }
+ )
root_resource = create_resource_tree(resources, NoResource())
@@ -128,7 +132,7 @@ class UserDirectoryServer(HomeServer):
listener_config,
root_resource,
self.version_string,
- )
+ ),
)
logger.info("Synapse user_dir now listening on port %d", port)
@@ -142,18 +146,19 @@ class UserDirectoryServer(HomeServer):
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warn(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -186,9 +191,7 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
def start(config_options):
try:
- config = HomeServerConfig.load_config(
- "Synapse user directory", config_options
- )
+ config = HomeServerConfig.load_config("Synapse user directory", config_options)
except ConfigError as e:
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
@@ -227,6 +230,6 @@ def start(config_options):
_base.start_worker_reactor("synapse-user-dir", config)
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 57ed8a3ca2..b26a31dd54 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -48,9 +48,7 @@ class AppServiceTransaction(object):
A Deferred which resolves to True if the transaction was sent.
"""
return as_api.push_bulk(
- service=self.service,
- events=self.events,
- txn_id=self.id
+ service=self.service, events=self.events, txn_id=self.id
)
def complete(self, store):
@@ -64,10 +62,7 @@ class AppServiceTransaction(object):
Returns:
A Deferred which resolves to True if the transaction was completed.
"""
- return store.complete_appservice_txn(
- service=self.service,
- txn_id=self.id
- )
+ return store.complete_appservice_txn(service=self.service, txn_id=self.id)
class ApplicationService(object):
@@ -76,6 +71,7 @@ class ApplicationService(object):
Provides methods to check if this service is "interested" in events.
"""
+
NS_USERS = "users"
NS_ALIASES = "aliases"
NS_ROOMS = "rooms"
@@ -84,9 +80,19 @@ class ApplicationService(object):
# values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
- def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
- sender=None, id=None, protocols=None, rate_limited=True,
- ip_range_whitelist=None):
+ def __init__(
+ self,
+ token,
+ hostname,
+ url=None,
+ namespaces=None,
+ hs_token=None,
+ sender=None,
+ id=None,
+ protocols=None,
+ rate_limited=True,
+ ip_range_whitelist=None,
+ ):
self.token = token
self.url = url
self.hs_token = hs_token
@@ -128,9 +134,7 @@ class ApplicationService(object):
if not isinstance(regex_obj, dict):
raise ValueError("Expected dict regex for ns '%s'" % ns)
if not isinstance(regex_obj.get("exclusive"), bool):
- raise ValueError(
- "Expected bool for 'exclusive' in ns '%s'" % ns
- )
+ raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
group_id = regex_obj.get("group_id")
if group_id:
if not isinstance(group_id, str):
@@ -153,9 +157,7 @@ class ApplicationService(object):
if isinstance(regex, string_types):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
- raise ValueError(
- "Expected string for 'regex' in ns '%s'" % ns
- )
+ raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
return namespaces
def _matches_regex(self, test_string, namespace_key):
@@ -178,8 +180,9 @@ class ApplicationService(object):
if self.is_interested_in_user(event.sender):
defer.returnValue(True)
# also check m.room.member state key
- if (event.type == EventTypes.Member and
- self.is_interested_in_user(event.state_key)):
+ if event.type == EventTypes.Member and self.is_interested_in_user(
+ event.state_key
+ ):
defer.returnValue(True)
if not store:
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 9ccc5a80fc..571881775b 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -32,19 +32,17 @@ logger = logging.getLogger(__name__)
sent_transactions_counter = Counter(
"synapse_appservice_api_sent_transactions",
"Number of /transactions/ requests sent",
- ["service"]
+ ["service"],
)
failed_transactions_counter = Counter(
"synapse_appservice_api_failed_transactions",
"Number of /transactions/ requests that failed to send",
- ["service"]
+ ["service"],
)
sent_events_counter = Counter(
- "synapse_appservice_api_sent_events",
- "Number of events sent to the AS",
- ["service"]
+ "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"]
)
HOUR_IN_MS = 60 * 60 * 1000
@@ -92,8 +90,9 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
- self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta",
- timeout_ms=HOUR_IN_MS)
+ self.protocol_meta_cache = ResponseCache(
+ hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
+ )
@defer.inlineCallbacks
def query_user(self, service, user_id):
@@ -102,9 +101,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None
try:
- response = yield self.get_json(uri, {
- "access_token": service.hs_token
- })
+ response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
defer.returnValue(True)
except CodeMessageException as e:
@@ -123,9 +120,7 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try:
- response = yield self.get_json(uri, {
- "access_token": service.hs_token
- })
+ response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
defer.returnValue(True)
except CodeMessageException as e:
@@ -144,9 +139,7 @@ class ApplicationServiceApi(SimpleHttpClient):
elif kind == ThirdPartyEntityKind.LOCATION:
required_field = "alias"
else:
- raise ValueError(
- "Unrecognised 'kind' argument %r to query_3pe()", kind
- )
+ raise ValueError("Unrecognised 'kind' argument %r to query_3pe()", kind)
if service.url is None:
defer.returnValue([])
@@ -154,14 +147,13 @@ class ApplicationServiceApi(SimpleHttpClient):
service.url,
APP_SERVICE_PREFIX,
kind,
- urllib.parse.quote(protocol)
+ urllib.parse.quote(protocol),
)
try:
response = yield self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
- "query_3pe to %s returned an invalid response %r",
- uri, response
+ "query_3pe to %s returned an invalid response %r", uri, response
)
defer.returnValue([])
@@ -171,8 +163,7 @@ class ApplicationServiceApi(SimpleHttpClient):
ret.append(r)
else:
logger.warning(
- "query_3pe to %s returned an invalid result %r",
- uri, r
+ "query_3pe to %s returned an invalid result %r", uri, r
)
defer.returnValue(ret)
@@ -189,27 +180,27 @@ class ApplicationServiceApi(SimpleHttpClient):
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
- urllib.parse.quote(protocol)
+ urllib.parse.quote(protocol),
)
try:
info = yield self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
- logger.warning("query_3pe_protocol to %s did not return a"
- " valid result", uri)
+ logger.warning(
+ "query_3pe_protocol to %s did not return a" " valid result", uri
+ )
defer.returnValue(None)
for instance in info.get("instances", []):
network_id = instance.get("network_id", None)
if network_id is not None:
instance["instance_id"] = ThirdPartyInstanceID(
- service.id, network_id,
+ service.id, network_id
).to_string()
defer.returnValue(info)
except Exception as ex:
- logger.warning("query_3pe_protocol to %s threw exception %s",
- uri, ex)
+ logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex)
defer.returnValue(None)
key = (service.id, protocol)
@@ -223,22 +214,19 @@ class ApplicationServiceApi(SimpleHttpClient):
events = self._serialize(events)
if txn_id is None:
- logger.warning("push_bulk: Missing txn ID sending events to %s",
- service.url)
+ logger.warning(
+ "push_bulk: Missing txn ID sending events to %s", service.url
+ )
txn_id = str(0)
txn_id = str(txn_id)
- uri = service.url + ("/transactions/%s" %
- urllib.parse.quote(txn_id))
+ uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try:
yield self.put_json(
uri=uri,
- json_body={
- "events": events
- },
- args={
- "access_token": service.hs_token
- })
+ json_body={"events": events},
+ args={"access_token": service.hs_token},
+ )
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
defer.returnValue(True)
@@ -252,6 +240,4 @@ class ApplicationServiceApi(SimpleHttpClient):
def _serialize(self, events):
time_now = self.clock.time_msec()
- return [
- serialize_event(e, time_now, as_client_event=True) for e in events
- ]
+ return [serialize_event(e, time_now, as_client_event=True) for e in events]
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 685f15c061..b54bf5411f 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -112,15 +112,14 @@ class _ServiceQueuer(object):
return
run_as_background_process(
- "as-sender-%s" % (service.id, ),
- self._send_request, service,
+ "as-sender-%s" % (service.id,), self._send_request, service
)
@defer.inlineCallbacks
def _send_request(self, service):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
- assert(service.id not in self.requests_in_flight)
+ assert service.id not in self.requests_in_flight
self.requests_in_flight.add(service.id)
try:
@@ -137,7 +136,6 @@ class _ServiceQueuer(object):
class _TransactionController(object):
-
def __init__(self, clock, store, as_api, recoverer_fn):
self.clock = clock
self.store = store
@@ -149,10 +147,7 @@ class _TransactionController(object):
@defer.inlineCallbacks
def send(self, service, events):
try:
- txn = yield self.store.create_appservice_txn(
- service=service,
- events=events
- )
+ txn = yield self.store.create_appservice_txn(service=service, events=events)
service_is_up = yield self._is_service_up(service)
if service_is_up:
sent = yield txn.send(self.as_api)
@@ -167,12 +162,12 @@ class _TransactionController(object):
@defer.inlineCallbacks
def on_recovered(self, recoverer):
self.recoverers.remove(recoverer)
- logger.info("Successfully recovered application service AS ID %s",
- recoverer.service.id)
+ logger.info(
+ "Successfully recovered application service AS ID %s", recoverer.service.id
+ )
logger.info("Remaining active recoverers: %s", len(self.recoverers))
yield self.store.set_appservice_state(
- recoverer.service,
- ApplicationServiceState.UP
+ recoverer.service, ApplicationServiceState.UP
)
def add_recoverers(self, recoverers):
@@ -184,13 +179,10 @@ class _TransactionController(object):
@defer.inlineCallbacks
def _start_recoverer(self, service):
try:
- yield self.store.set_appservice_state(
- service,
- ApplicationServiceState.DOWN
- )
+ yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
logger.info(
"Application service falling behind. Starting recoverer. AS ID %s",
- service.id
+ service.id,
)
recoverer = self.recoverer_fn(service, self.on_recovered)
self.add_recoverers([recoverer])
@@ -205,19 +197,16 @@ class _TransactionController(object):
class _Recoverer(object):
-
@staticmethod
@defer.inlineCallbacks
def start(clock, store, as_api, callback):
- services = yield store.get_appservices_by_state(
- ApplicationServiceState.DOWN
- )
- recoverers = [
- _Recoverer(clock, store, as_api, s, callback) for s in services
- ]
+ services = yield store.get_appservices_by_state(ApplicationServiceState.DOWN)
+ recoverers = [_Recoverer(clock, store, as_api, s, callback) for s in services]
for r in recoverers:
- logger.info("Starting recoverer for AS ID %s which was marked as "
- "DOWN", r.service.id)
+ logger.info(
+ "Starting recoverer for AS ID %s which was marked as " "DOWN",
+ r.service.id,
+ )
r.recover()
defer.returnValue(recoverers)
@@ -232,9 +221,9 @@ class _Recoverer(object):
def recover(self):
def _retry():
run_as_background_process(
- "as-recoverer-%s" % (self.service.id,),
- self.retry,
+ "as-recoverer-%s" % (self.service.id,), self.retry
)
+
self.clock.call_later((2 ** self.backoff_counter), _retry)
def _backoff(self):
@@ -248,8 +237,9 @@ class _Recoverer(object):
try:
txn = yield self.store.get_oldest_unsent_txn(self.service)
if txn:
- logger.info("Retrying transaction %s for AS ID %s",
- txn.id, txn.service.id)
+ logger.info(
+ "Retrying transaction %s for AS ID %s", txn.id, txn.service.id
+ )
sent = yield txn.send(self.as_api)
if sent:
yield txn.complete(self.store)
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index f7d7f153bb..965478d8d5 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -134,11 +136,6 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
- @staticmethod
- def read_config_file(file_path):
- with open(file_path) as file_stream:
- return yaml.safe_load(file_stream)
-
def invoke_all(self, name, *args, **kargs):
results = []
for cls in type(self).mro():
@@ -153,12 +150,12 @@ class Config(object):
server_name,
generate_secrets=False,
report_stats=None,
+ open_private_ports=False,
):
"""Build a default configuration file
- This is used both when the user explicitly asks us to generate a config file
- (eg with --generate_config), and before loading the config at runtime (to give
- a base which the config files override)
+ This is used when the user explicitly asks us to generate a config file
+ (eg with --generate_config).
Args:
config_dir_path (str): The path where the config files are kept. Used to
@@ -177,25 +174,33 @@ class Config(object):
report_stats (bool|None): Initial setting for the report_stats setting.
If None, report_stats will be left unset.
+ open_private_ports (bool): True to leave private ports (such as the non-TLS
+ HTTP listener) open to the internet.
+
Returns:
str: the yaml config file
"""
- default_config = "\n\n".join(
+ return "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
- "default_config",
+ "generate_config_section",
config_dir_path=config_dir_path,
data_dir_path=data_dir_path,
server_name=server_name,
generate_secrets=generate_secrets,
report_stats=report_stats,
+ open_private_ports=open_private_ports,
)
)
- return default_config
-
@classmethod
def load_config(cls, description, argv):
+ """Parse the commandline and config files
+
+ Doesn't support config-file-generation: used by the worker apps.
+
+ Returns: Config object.
+ """
config_parser = argparse.ArgumentParser(description=description)
config_parser.add_argument(
"-c",
@@ -210,7 +215,7 @@ class Config(object):
"--keys-directory",
metavar="DIRECTORY",
help="Where files such as certs and signing keys are stored when"
- " their location is given explicitly in the config."
+ " their location is not given explicitly in the config."
" Defaults to the directory containing the last config file",
)
@@ -222,8 +227,19 @@ class Config(object):
config_files = find_config_files(search_paths=config_args.config_path)
- obj.read_config_files(
- config_files, keys_directory=config_args.keys_directory, generate_keys=False
+ if not config_files:
+ config_parser.error("Must supply a config file.")
+
+ if config_args.keys_directory:
+ config_dir_path = config_args.keys_directory
+ else:
+ config_dir_path = os.path.dirname(config_files[-1])
+ config_dir_path = os.path.abspath(config_dir_path)
+ data_dir_path = os.getcwd()
+
+ config_dict = read_config_files(config_files)
+ obj.parse_config_dict(
+ config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
)
obj.invoke_all("read_arguments", config_args)
@@ -232,6 +248,12 @@ class Config(object):
@classmethod
def load_or_generate_config(cls, description, argv):
+ """Parse the commandline and config files
+
+ Supports generation of config files, so is used for the main homeserver app.
+
+ Returns: Config object, or None if --generate-config or --generate-keys was set
+ """
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument(
"-c",
@@ -241,37 +263,74 @@ class Config(object):
help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files.",
)
- config_parser.add_argument(
+
+ generate_group = config_parser.add_argument_group("Config generation")
+ generate_group.add_argument(
"--generate-config",
action="store_true",
- help="Generate a config file for the server name",
+ help="Generate a config file, then exit.",
)
- config_parser.add_argument(
+ generate_group.add_argument(
+ "--generate-missing-configs",
+ "--generate-keys",
+ action="store_true",
+ help="Generate any missing additional config files, then exit.",
+ )
+ generate_group.add_argument(
+ "-H", "--server-name", help="The server name to generate a config file for."
+ )
+ generate_group.add_argument(
"--report-stats",
action="store",
- help="Whether the generated config reports anonymized usage statistics",
+ help="Whether the generated config reports anonymized usage statistics.",
choices=["yes", "no"],
)
- config_parser.add_argument(
- "--generate-keys",
- action="store_true",
- help="Generate any missing key files then exit",
- )
- config_parser.add_argument(
+ generate_group.add_argument(
+ "--config-directory",
"--keys-directory",
metavar="DIRECTORY",
- help="Used with 'generate-*' options to specify where files such as"
- " signing keys should be stored, unless explicitly"
- " specified in the config.",
+ help=(
+ "Specify where additional config files such as signing keys and log"
+ " config should be stored. Defaults to the same directory as the last"
+ " config file."
+ ),
)
- config_parser.add_argument(
- "-H", "--server-name", help="The server name to generate a config file for"
+ generate_group.add_argument(
+ "--data-directory",
+ metavar="DIRECTORY",
+ help=(
+ "Specify where data such as the media store and database file should be"
+ " stored. Defaults to the current working directory."
+ ),
+ )
+ generate_group.add_argument(
+ "--open-private-ports",
+ action="store_true",
+ help=(
+ "Leave private ports (such as the non-TLS HTTP listener) open to the"
+ " internet. Do not use this unless you know what you are doing."
+ ),
)
+
config_args, remaining_args = config_parser.parse_known_args(argv)
config_files = find_config_files(search_paths=config_args.config_path)
- generate_keys = config_args.generate_keys
+ if not config_files:
+ config_parser.error(
+ "Must supply a config file.\nA config file can be automatically"
+ ' generated using "--generate-config -H SERVER_NAME'
+ ' -c CONFIG-FILE"'
+ )
+
+ if config_args.config_directory:
+ config_dir_path = config_args.config_directory
+ else:
+ config_dir_path = os.path.dirname(config_files[-1])
+ config_dir_path = os.path.abspath(config_dir_path)
+ data_dir_path = os.getcwd()
+
+ generate_missing_configs = config_args.generate_missing_configs
obj = cls()
@@ -281,19 +340,16 @@ class Config(object):
"Please specify either --report-stats=yes or --report-stats=no\n\n"
+ MISSING_REPORT_STATS_SPIEL
)
- if not config_files:
- config_parser.error(
- "Must supply a config file.\nA config file can be automatically"
- " generated using \"--generate-config -H SERVER_NAME"
- " -c CONFIG-FILE\""
- )
+
(config_path,) = config_files
if not cls.path_exists(config_path):
- if config_args.keys_directory:
- config_dir_path = config_args.keys_directory
+ print("Generating config file %s" % (config_path,))
+
+ if config_args.data_directory:
+ data_dir_path = config_args.data_directory
else:
- config_dir_path = os.path.dirname(config_path)
- config_dir_path = os.path.abspath(config_dir_path)
+ data_dir_path = os.getcwd()
+ data_dir_path = os.path.abspath(data_dir_path)
server_name = config_args.server_name
if not server_name:
@@ -304,22 +360,21 @@ class Config(object):
config_str = obj.generate_config(
config_dir_path=config_dir_path,
- data_dir_path=os.getcwd(),
+ data_dir_path=data_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
generate_secrets=True,
+ open_private_ports=config_args.open_private_ports,
)
if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
- config_file.write(
- "# vim:ft=yaml\n\n"
- )
+ config_file.write("# vim:ft=yaml\n\n")
config_file.write(config_str)
- config = yaml.safe_load(config_str)
- obj.invoke_all("generate_files", config)
+ config_dict = yaml.safe_load(config_str)
+ obj.generate_missing_files(config_dict, config_dir_path)
print(
(
@@ -333,12 +388,12 @@ class Config(object):
else:
print(
(
- "Config file %r already exists. Generating any missing key"
+ "Config file %r already exists. Generating any missing config"
" files."
)
% (config_path,)
)
- generate_keys = True
+ generate_missing_configs = True
parser = argparse.ArgumentParser(
parents=[config_parser],
@@ -349,66 +404,63 @@ class Config(object):
obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args)
- if not config_files:
- config_parser.error(
- "Must supply a config file.\nA config file can be automatically"
- " generated using \"--generate-config -H SERVER_NAME"
- " -c CONFIG-FILE\""
- )
-
- obj.read_config_files(
- config_files,
- keys_directory=config_args.keys_directory,
- generate_keys=generate_keys,
- )
-
- if generate_keys:
+ config_dict = read_config_files(config_files)
+ if generate_missing_configs:
+ obj.generate_missing_files(config_dict, config_dir_path)
return None
+ obj.parse_config_dict(
+ config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
+ )
obj.invoke_all("read_arguments", args)
return obj
- def read_config_files(self, config_files, keys_directory=None, generate_keys=False):
- if not keys_directory:
- keys_directory = os.path.dirname(config_files[-1])
+ def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
+ """Read the information from the config dict into this Config object.
- self.config_dir_path = os.path.abspath(keys_directory)
-
- specified_config = {}
- for config_file in config_files:
- yaml_config = self.read_config_file(config_file)
- specified_config.update(yaml_config)
+ Args:
+ config_dict (dict): Configuration data, as read from the yaml
- if "server_name" not in specified_config:
- raise ConfigError(MISSING_SERVER_NAME)
+ config_dir_path (str): The path where the config files are kept. Used to
+ create filenames for things like the log config and the signing key.
- server_name = specified_config["server_name"]
- config_string = self.generate_config(
- config_dir_path=self.config_dir_path,
- data_dir_path=os.getcwd(),
- server_name=server_name,
- generate_secrets=False,
+ data_dir_path (str): The path where the data files are kept. Used to create
+ filenames for things like the database and media store.
+ """
+ self.invoke_all(
+ "read_config",
+ config_dict,
+ config_dir_path=config_dir_path,
+ data_dir_path=data_dir_path,
)
- config = yaml.safe_load(config_string)
- config.pop("log_config")
- config.update(specified_config)
- if "report_stats" not in config:
- raise ConfigError(
- MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
- + "\n"
- + MISSING_REPORT_STATS_SPIEL
- )
+ def generate_missing_files(self, config_dict, config_dir_path):
+ self.invoke_all("generate_files", config_dict, config_dir_path)
- if generate_keys:
- self.invoke_all("generate_files", config)
- return
- self.parse_config_dict(config)
+def read_config_files(config_files):
+ """Read the config files into a dict
- def parse_config_dict(self, config_dict):
- self.invoke_all("read_config", config_dict)
+ Args:
+ config_files (iterable[str]): A list of the config files to read
+
+ Returns: dict
+ """
+ specified_config = {}
+ for config_file in config_files:
+ with open(config_file) as file_stream:
+ yaml_config = yaml.safe_load(file_stream)
+ specified_config.update(yaml_config)
+
+ if "server_name" not in specified_config:
+ raise ConfigError(MISSING_SERVER_NAME)
+
+ if "report_stats" not in specified_config:
+ raise ConfigError(
+ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_SPIEL
+ )
+ return specified_config
def find_config_files(search_paths):
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 5eb4f86fa2..dddea79a8a 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -18,17 +18,19 @@ from ._base import Config
class ApiConfig(Config):
+ def read_config(self, config, **kwargs):
+ self.room_invite_state_types = config.get(
+ "room_invite_state_types",
+ [
+ EventTypes.JoinRules,
+ EventTypes.CanonicalAlias,
+ EventTypes.RoomAvatar,
+ EventTypes.RoomEncryption,
+ EventTypes.Name,
+ ],
+ )
- def read_config(self, config):
- self.room_invite_state_types = config.get("room_invite_state_types", [
- EventTypes.JoinRules,
- EventTypes.CanonicalAlias,
- EventTypes.RoomAvatar,
- EventTypes.RoomEncryption,
- EventTypes.Name,
- ])
-
- def default_config(cls, **kwargs):
+ def generate_config_section(cls, **kwargs):
return """\
## API Configuration ##
@@ -40,4 +42,6 @@ class ApiConfig(Config):
# - "{RoomAvatar}"
# - "{RoomEncryption}"
# - "{Name}"
- """.format(**vars(EventTypes))
+ """.format(
+ **vars(EventTypes)
+ )
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 7e89d345d8..8387ff6805 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -29,13 +29,12 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
-
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
- def default_config(cls, **kwargs):
+ def generate_config_section(cls, **kwargs):
return """\
# A list of application service config files to use
#
@@ -53,9 +52,7 @@ class AppServiceConfig(Config):
def load_appservices(hostname, config_files):
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
- logger.warning(
- "Expected %s to be a list of AS config files.", config_files
- )
+ logger.warning("Expected %s to be a list of AS config files.", config_files)
return []
# Dicts of value -> filename
@@ -66,22 +63,20 @@ def load_appservices(hostname, config_files):
for config_file in config_files:
try:
- with open(config_file, 'r') as f:
- appservice = _load_appservice(
- hostname, yaml.safe_load(f), config_file
- )
+ with open(config_file, "r") as f:
+ appservice = _load_appservice(hostname, yaml.safe_load(f), config_file)
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
- "%s (files: %s, %s)" % (
- appservice.id, config_file, seen_ids[appservice.id],
- )
+ "%s (files: %s, %s)"
+ % (appservice.id, config_file, seen_ids[appservice.id])
)
seen_ids[appservice.id] = config_file
if appservice.token in seen_as_tokens:
raise ConfigError(
"Cannot reuse as_token across application services: "
- "%s (files: %s, %s)" % (
+ "%s (files: %s, %s)"
+ % (
appservice.token,
config_file,
seen_as_tokens[appservice.token],
@@ -98,28 +93,26 @@ def load_appservices(hostname, config_files):
def _load_appservice(hostname, as_info, config_filename):
- required_string_fields = [
- "id", "as_token", "hs_token", "sender_localpart"
- ]
+ required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"]
for field in required_string_fields:
if not isinstance(as_info.get(field), string_types):
- raise KeyError("Required string field: '%s' (%s)" % (
- field, config_filename,
- ))
+ raise KeyError(
+ "Required string field: '%s' (%s)" % (field, config_filename)
+ )
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
- if (not isinstance(as_info.get("url"), string_types) and
- as_info.get("url", "") is not None):
+ if (
+ not isinstance(as_info.get("url"), string_types)
+ and as_info.get("url", "") is not None
+ ):
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
localpart = as_info["sender_localpart"]
if urlparse.quote(localpart) != localpart:
- raise ValueError(
- "sender_localpart needs characters which are not URL encoded."
- )
+ raise ValueError("sender_localpart needs characters which are not URL encoded.")
user = UserID(localpart, hostname)
user_id = user.to_string()
@@ -138,13 +131,12 @@ def _load_appservice(hostname, as_info, config_filename):
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
- "Expected namespace entry in %s to be an object,"
- " but got %s", ns, regex_obj
+ "Expected namespace entry in %s to be an object," " but got %s",
+ ns,
+ regex_obj,
)
if not isinstance(regex_obj.get("regex"), string_types):
- raise ValueError(
- "Missing/bad type 'regex' key in %s", regex_obj
- )
+ raise ValueError("Missing/bad type 'regex' key in %s", regex_obj)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
@@ -167,10 +159,8 @@ def _load_appservice(hostname, as_info, config_filename):
)
ip_range_whitelist = None
- if as_info.get('ip_range_whitelist'):
- ip_range_whitelist = IPSet(
- as_info.get('ip_range_whitelist')
- )
+ if as_info.get("ip_range_whitelist"):
+ ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist"))
return ApplicationService(
token=as_info["as_token"],
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index f7eebf26d2..8dac8152cf 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -16,8 +16,7 @@ from ._base import Config
class CaptchaConfig(Config):
-
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.recaptcha_private_key = config.get("recaptcha_private_key")
self.recaptcha_public_key = config.get("recaptcha_public_key")
self.enable_registration_captcha = config.get(
@@ -29,7 +28,7 @@ class CaptchaConfig(Config):
"https://www.recaptcha.net/recaptcha/api/siteverify",
)
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this.
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 609c0815c8..ebe34d933b 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -22,7 +22,7 @@ class CasConfig(Config):
cas_server_url: URL of CAS server
"""
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
cas_config = config.get("cas_config", None)
if cas_config:
self.cas_enabled = cas_config.get("enabled", True)
@@ -35,7 +35,7 @@ class CasConfig(Config):
self.cas_service_url = None
self.cas_required_attributes = {}
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Enable CAS for registration and login.
#
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index abeb0180d3..94916f3a49 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -84,35 +84,32 @@ class ConsentConfig(Config):
self.user_consent_at_registration = False
self.user_consent_policy_name = "Privacy Policy"
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
consent_config = config.get("user_consent")
if consent_config is None:
return
self.user_consent_version = str(consent_config["version"])
- self.user_consent_template_dir = self.abspath(
- consent_config["template_dir"]
- )
+ self.user_consent_template_dir = self.abspath(consent_config["template_dir"])
if not path.isdir(self.user_consent_template_dir):
raise ConfigError(
- "Could not find template directory '%s'" % (
- self.user_consent_template_dir,
- ),
+ "Could not find template directory '%s'"
+ % (self.user_consent_template_dir,)
)
self.user_consent_server_notice_content = consent_config.get(
- "server_notice_content",
+ "server_notice_content"
)
self.block_events_without_consent_error = consent_config.get(
- "block_events_error",
+ "block_events_error"
+ )
+ self.user_consent_server_notice_to_guests = bool(
+ consent_config.get("send_server_notice_to_guests", False)
+ )
+ self.user_consent_at_registration = bool(
+ consent_config.get("require_at_registration", False)
)
- self.user_consent_server_notice_to_guests = bool(consent_config.get(
- "send_server_notice_to_guests", False,
- ))
- self.user_consent_at_registration = bool(consent_config.get(
- "require_at_registration", False,
- ))
self.user_consent_policy_name = consent_config.get(
- "policy_name", "Privacy Policy",
+ "policy_name", "Privacy Policy"
)
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return DEFAULT_CONFIG
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 3c27ed6b4a..bcb2089dd7 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -18,37 +18,30 @@ from ._base import Config
class DatabaseConfig(Config):
-
- def read_config(self, config):
- self.event_cache_size = self.parse_size(
- config.get("event_cache_size", "10K")
- )
+ def read_config(self, config, **kwargs):
+ self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
self.database_config = config.get("database")
if self.database_config is None:
- self.database_config = {
- "name": "sqlite3",
- "args": {},
- }
+ self.database_config = {"name": "sqlite3", "args": {}}
name = self.database_config.get("name", None)
if name == "psycopg2":
pass
elif name == "sqlite3":
- self.database_config.setdefault("args", {}).update({
- "cp_min": 1,
- "cp_max": 1,
- "check_same_thread": False,
- })
+ self.database_config.setdefault("args", {}).update(
+ {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
+ )
else:
raise RuntimeError("Unsupported database type '%s'" % (name,))
self.set_databasepath(config.get("database_path"))
- def default_config(self, data_dir_path, **kwargs):
+ def generate_config_section(self, data_dir_path, **kwargs):
database_path = os.path.join(data_dir_path, "homeserver.db")
- return """\
+ return (
+ """\
## Database ##
database:
@@ -62,7 +55,9 @@ class DatabaseConfig(Config):
# Number of events to cache in memory.
#
#event_cache_size: 10K
- """ % locals()
+ """
+ % locals()
+ )
def read_arguments(self, args):
self.set_databasepath(args.database_path)
@@ -77,6 +72,8 @@ class DatabaseConfig(Config):
def add_arguments(self, parser):
db_group = parser.add_argument_group("database")
db_group.add_argument(
- "-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
- help="The path to a sqlite database to use."
+ "-d",
+ "--database-path",
+ metavar="SQLITE_DATABASE_PATH",
+ help="The path to a sqlite database to use.",
)
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index ae04252906..cf39936da7 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -19,18 +19,15 @@ from __future__ import print_function
# This file can't be called email.py because if it is, we cannot:
import email.utils
-import logging
import os
import pkg_resources
from ._base import Config, ConfigError
-logger = logging.getLogger(__name__)
-
class EmailConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
# TODO: We should separate better the email configuration from the notification
# and account validity config.
@@ -59,7 +56,7 @@ class EmailConfig(Config):
if self.email_notif_from is not None:
# make sure it's valid
parsed = email.utils.parseaddr(self.email_notif_from)
- if parsed[1] == '':
+ if parsed[1] == "":
raise RuntimeError("Invalid notif_from address")
template_dir = email_config.get("template_dir")
@@ -68,27 +65,27 @@ class EmailConfig(Config):
# (Note that loading as package_resources with jinja.PackageLoader doesn't
# work for the same reason.)
if not template_dir:
- template_dir = pkg_resources.resource_filename(
- 'synapse', 'res/templates'
- )
+ template_dir = pkg_resources.resource_filename("synapse", "res/templates")
self.email_template_dir = os.path.abspath(template_dir)
self.email_enable_notifs = email_config.get("enable_notifs", False)
- account_validity_renewal_enabled = config.get(
- "account_validity", {},
- ).get("renew_at")
+ account_validity_renewal_enabled = config.get("account_validity", {}).get(
+ "renew_at"
+ )
email_trust_identity_server_for_password_resets = email_config.get(
- "trust_identity_server_for_password_resets", False,
+ "trust_identity_server_for_password_resets", False
)
self.email_password_reset_behaviour = (
"remote" if email_trust_identity_server_for_password_resets else "local"
)
+ self.password_resets_were_disabled_due_to_email_config = False
if self.email_password_reset_behaviour == "local" and email_config == {}:
- logger.warn(
- "User password resets have been disabled due to lack of email config"
- )
+ # We cannot warn the user this has happened here
+ # Instead do so when a user attempts to reset their password
+ self.password_resets_were_disabled_due_to_email_config = True
+
self.email_password_reset_behaviour = "off"
# Get lifetime of a validation token in milliseconds
@@ -104,62 +101,59 @@ class EmailConfig(Config):
# make sure we can import the required deps
import jinja2
import bleach
+
# prevent unused warnings
jinja2
bleach
if self.email_password_reset_behaviour == "local":
- required = [
- "smtp_host",
- "smtp_port",
- "notif_from",
- ]
+ required = ["smtp_host", "smtp_port", "notif_from"]
missing = []
for k in required:
if k not in email_config:
missing.append(k)
- if (len(missing) > 0):
+ if len(missing) > 0:
raise RuntimeError(
"email.password_reset_behaviour is set to 'local' "
- "but required keys are missing: %s" %
- (", ".join(["email." + k for k in missing]),)
+ "but required keys are missing: %s"
+ % (", ".join(["email." + k for k in missing]),)
)
# Templates for password reset emails
self.email_password_reset_template_html = email_config.get(
- "password_reset_template_html", "password_reset.html",
+ "password_reset_template_html", "password_reset.html"
)
self.email_password_reset_template_text = email_config.get(
- "password_reset_template_text", "password_reset.txt",
+ "password_reset_template_text", "password_reset.txt"
)
self.email_password_reset_failure_template = email_config.get(
- "password_reset_failure_template", "password_reset_failure.html",
+ "password_reset_failure_template", "password_reset_failure.html"
)
# This template does not support any replaceable variables, so we will
# read it from the disk once during setup
email_password_reset_success_template = email_config.get(
- "password_reset_success_template", "password_reset_success.html",
+ "password_reset_success_template", "password_reset_success.html"
)
# Check templates exist
- for f in [self.email_password_reset_template_html,
- self.email_password_reset_template_text,
- self.email_password_reset_failure_template,
- email_password_reset_success_template]:
+ for f in [
+ self.email_password_reset_template_html,
+ self.email_password_reset_template_text,
+ self.email_password_reset_failure_template,
+ email_password_reset_success_template,
+ ]:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
- raise ConfigError("Unable to find template file %s" % (p, ))
+ raise ConfigError("Unable to find template file %s" % (p,))
# Retrieve content of web templates
filepath = os.path.join(
- self.email_template_dir,
- email_password_reset_success_template,
+ self.email_template_dir, email_password_reset_success_template
)
self.email_password_reset_success_html_content = self.read_file(
- filepath,
- "email.password_reset_template_success_html",
+ filepath, "email.password_reset_template_success_html"
)
if config.get("public_baseurl") is None:
@@ -183,10 +177,10 @@ class EmailConfig(Config):
if k not in email_config:
missing.append(k)
- if (len(missing) > 0):
+ if len(missing) > 0:
raise RuntimeError(
- "email.enable_notifs is True but required keys are missing: %s" %
- (", ".join(["email." + k for k in missing]),)
+ "email.enable_notifs is True but required keys are missing: %s"
+ % (", ".join(["email." + k for k in missing]),)
)
if config.get("public_baseurl") is None:
@@ -200,29 +194,27 @@ class EmailConfig(Config):
for f in self.email_notif_template_text, self.email_notif_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
- raise ConfigError("Unable to find email template file %s" % (p, ))
+ raise ConfigError("Unable to find email template file %s" % (p,))
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
)
- self.email_riot_base_url = email_config.get(
- "riot_base_url", None
- )
+ self.email_riot_base_url = email_config.get("riot_base_url", None)
if account_validity_renewal_enabled:
self.email_expiry_template_html = email_config.get(
- "expiry_template_html", "notice_expiry.html",
+ "expiry_template_html", "notice_expiry.html"
)
self.email_expiry_template_text = email_config.get(
- "expiry_template_text", "notice_expiry.txt",
+ "expiry_template_text", "notice_expiry.txt"
)
for f in self.email_expiry_template_text, self.email_expiry_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
- raise ConfigError("Unable to find email template file %s" % (p, ))
+ raise ConfigError("Unable to find email template file %s" % (p,))
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Enable sending emails for password resets, notification events or
# account expiry notices
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
index e4be172a79..2a522b5f44 100644
--- a/synapse/config/groups.py
+++ b/synapse/config/groups.py
@@ -17,11 +17,11 @@ from ._base import Config
class GroupsConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "")
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
# Uncomment to allow non-server-admin users to create groups on this server
#
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 5c4fc8ff21..acadef4fd3 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -38,6 +38,7 @@ from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
from .stats import StatsConfig
+from .third_party_event_rules import ThirdPartyRulesConfig
from .tls import TlsConfig
from .user_directory import UserDirectoryConfig
from .voip import VoipConfig
@@ -73,5 +74,6 @@ class HomeServerConfig(
StatsConfig,
ServerNoticesConfig,
RoomDirectoryConfig,
+ ThirdPartyRulesConfig,
):
pass
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index ecb4124096..36d87cef03 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -15,17 +15,15 @@
from ._base import Config, ConfigError
-MISSING_JWT = (
- """Missing jwt library. This is required for jwt login.
+MISSING_JWT = """Missing jwt library. This is required for jwt login.
Install by running:
pip install pyjwt
"""
-)
class JWTConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
jwt_config = config.get("jwt_config", None)
if jwt_config:
self.jwt_enabled = jwt_config.get("enabled", False)
@@ -34,6 +32,7 @@ class JWTConfig(Config):
try:
import jwt
+
jwt # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_JWT)
@@ -42,7 +41,7 @@ class JWTConfig(Config):
self.jwt_secret = None
self.jwt_algorithm = None
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
# The JWT needs to contain a globally unique "sub" (subject) claim.
#
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 424875feae..8fc74f9cdf 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -65,13 +65,18 @@ class TrustedKeyServer(object):
class KeyConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, config_dir_path, **kwargs):
# the signing key can be specified inline or in a separate file
if "signing_key" in config:
self.signing_key = read_signing_keys([config["signing_key"]])
else:
- self.signing_key_path = config["signing_key_path"]
- self.signing_key = self.read_signing_key(self.signing_key_path)
+ signing_key_path = config.get("signing_key_path")
+ if signing_key_path is None:
+ signing_key_path = os.path.join(
+ config_dir_path, config["server_name"] + ".signing.key"
+ )
+
+ self.signing_key = self.read_signing_key(signing_key_path)
self.old_signing_keys = self.read_old_signing_keys(
config.get("old_signing_keys", {})
@@ -117,7 +122,7 @@ class KeyConfig(Config):
# falsification of values
self.form_secret = config.get("form_secret", None)
- def default_config(
+ def generate_config_section(
self, config_dir_path, server_name, generate_secrets=False, **kwargs
):
base_key_name = os.path.join(config_dir_path, server_name)
@@ -237,10 +242,18 @@ class KeyConfig(Config):
)
return keys
- def generate_files(self, config):
- signing_key_path = config["signing_key_path"]
+ def generate_files(self, config, config_dir_path):
+ if "signing_key" in config:
+ return
+
+ signing_key_path = config.get("signing_key_path")
+ if signing_key_path is None:
+ signing_key_path = os.path.join(
+ config_dir_path, config["server_name"] + ".signing.key"
+ )
if not self.path_exists(signing_key_path):
+ print("Generating signing key file %s" % (signing_key_path,))
with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4)
write_signing_keys(signing_key_file, (generate_signing_key(key_id),))
@@ -348,9 +361,8 @@ def _parse_key_servers(key_servers, federation_verify_certificates):
result.verify_keys[key_id] = verify_key
- if (
- not federation_verify_certificates and
- not server.get("accept_keys_insecurely")
+ if not federation_verify_certificates and not server.get(
+ "accept_keys_insecurely"
):
_assert_keyserver_has_verify_keys(result)
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index c1febbe9d3..931aec41c0 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -29,7 +29,8 @@ from synapse.util.versionstring import get_version_string
from ._base import Config
-DEFAULT_LOG_CONFIG = Template("""
+DEFAULT_LOG_CONFIG = Template(
+ """
version: 1
formatters:
@@ -68,26 +69,29 @@ loggers:
root:
level: INFO
handlers: [file, console]
-""")
+"""
+)
class LoggingConfig(Config):
-
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.verbosity = config.get("verbose", 0)
self.no_redirect_stdio = config.get("no_redirect_stdio", False)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
log_config = os.path.join(config_dir_path, server_name + ".log.config")
- return """\
+ return (
+ """\
## Logging ##
# A yaml python logging config file
#
log_config: "%(log_config)s"
- """ % locals()
+ """
+ % locals()
+ )
def read_arguments(self, args):
if args.verbose is not None:
@@ -102,32 +106,43 @@ class LoggingConfig(Config):
def add_arguments(cls, parser):
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
- '-v', '--verbose', dest="verbose", action='count',
+ "-v",
+ "--verbose",
+ dest="verbose",
+ action="count",
help="The verbosity level. Specify multiple times to increase "
- "verbosity. (Ignored if --log-config is specified.)"
+ "verbosity. (Ignored if --log-config is specified.)",
)
logging_group.add_argument(
- '-f', '--log-file', dest="log_file",
- help="File to log to. (Ignored if --log-config is specified.)"
+ "-f",
+ "--log-file",
+ dest="log_file",
+ help="File to log to. (Ignored if --log-config is specified.)",
)
logging_group.add_argument(
- '--log-config', dest="log_config", default=None,
- help="Python logging config file"
+ "--log-config",
+ dest="log_config",
+ default=None,
+ help="Python logging config file",
)
logging_group.add_argument(
- '-n', '--no-redirect-stdio',
- action='store_true', default=None,
- help="Do not redirect stdout/stderr to the log"
+ "-n",
+ "--no-redirect-stdio",
+ action="store_true",
+ default=None,
+ help="Do not redirect stdout/stderr to the log",
)
- def generate_files(self, config):
+ def generate_files(self, config, config_dir_path):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")
+ print(
+ "Generating log config file %s which will log to %s"
+ % (log_config, log_file)
+ )
with open(log_config, "w") as log_config_file:
- log_config_file.write(
- DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
- )
+ log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
def setup_logging(config, use_worker_options=False):
@@ -143,10 +158,8 @@ def setup_logging(config, use_worker_options=False):
register_sighup (func | None): Function to call to register a
sighup handler.
"""
- log_config = (config.worker_log_config if use_worker_options
- else config.log_config)
- log_file = (config.worker_log_file if use_worker_options
- else config.log_file)
+ log_config = config.worker_log_config if use_worker_options else config.log_config
+ log_file = config.worker_log_file if use_worker_options else config.log_file
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
@@ -164,23 +177,23 @@ def setup_logging(config, use_worker_options=False):
if config.verbosity > 1:
level_for_storage = logging.DEBUG
- logger = logging.getLogger('')
+ logger = logging.getLogger("")
logger.setLevel(level)
- logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
+ logging.getLogger("synapse.storage.SQL").setLevel(level_for_storage)
formatter = logging.Formatter(log_format)
if log_file:
# TODO: Customisable file size / backup count
handler = logging.handlers.RotatingFileHandler(
- log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
- encoding='utf8'
+ log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, encoding="utf8"
)
def sighup(signum, stack):
logger.info("Closing log file due to SIGHUP")
handler.doRollover()
logger.info("Opened new log file due to SIGHUP")
+
else:
handler = logging.StreamHandler()
@@ -193,8 +206,9 @@ def setup_logging(config, use_worker_options=False):
logger.addHandler(handler)
else:
+
def load_log_config():
- with open(log_config, 'r') as f:
+ with open(log_config, "r") as f:
logging.config.dictConfig(yaml.safe_load(f))
def sighup(*args):
@@ -209,10 +223,7 @@ def setup_logging(config, use_worker_options=False):
# make sure that the first thing we log is a thing we can grep backwards
# for
logging.warn("***** STARTING SERVER *****")
- logging.warn(
- "Server %s version %s",
- sys.argv[0], get_version_string(synapse),
- )
+ logging.warn("Server %s version %s", sys.argv[0], get_version_string(synapse))
logging.info("Server hostname: %s", config.server_name)
# It's critical to point twisted's internal logging somewhere, otherwise it
@@ -242,8 +253,7 @@ def setup_logging(config, use_worker_options=False):
return observer(event)
globalLogBeginner.beginLoggingTo(
- [_log],
- redirectStandardIO=not config.no_redirect_stdio,
+ [_log], redirectStandardIO=not config.no_redirect_stdio
)
if not config.no_redirect_stdio:
print("Redirected stdout/stderr to logs")
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 2de51979d8..3698441963 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -15,15 +15,13 @@
from ._base import Config, ConfigError
-MISSING_SENTRY = (
- """Missing sentry-sdk library. This is required to enable sentry
+MISSING_SENTRY = """Missing sentry-sdk library. This is required to enable sentry
integration.
"""
-)
class MetricsConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.enable_metrics = config.get("enable_metrics", False)
self.report_stats = config.get("report_stats", None)
self.metrics_port = config.get("metrics_port")
@@ -39,10 +37,10 @@ class MetricsConfig(Config):
self.sentry_dsn = config["sentry"].get("dsn")
if not self.sentry_dsn:
raise ConfigError(
- "sentry.dsn field is required when sentry integration is enabled",
+ "sentry.dsn field is required when sentry integration is enabled"
)
- def default_config(self, report_stats=None, **kwargs):
+ def generate_config_section(self, report_stats=None, **kwargs):
res = """\
## Metrics ###
@@ -66,6 +64,6 @@ class MetricsConfig(Config):
if report_stats is None:
res += "# report_stats: true|false\n"
else:
- res += "report_stats: %s\n" % ('true' if report_stats else 'false')
+ res += "report_stats: %s\n" % ("true" if report_stats else "false")
return res
diff --git a/synapse/config/password.py b/synapse/config/password.py
index eea59e772b..598f84fc0c 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -20,7 +20,7 @@ class PasswordConfig(Config):
"""Password login configuration
"""
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
password_config = config.get("password_config", {})
if password_config is None:
password_config = {}
@@ -28,7 +28,7 @@ class PasswordConfig(Config):
self.password_enabled = password_config.get("enabled", True)
self.password_pepper = password_config.get("pepper", "")
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
# Uncomment to disable password login
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index f0a6be0679..788c39c9fb 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -17,11 +17,11 @@ from synapse.util.module_loader import load_module
from ._base import Config
-LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
+LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.password_providers = []
providers = []
@@ -29,28 +29,24 @@ class PasswordAuthProviderConfig(Config):
# param.
ldap_config = config.get("ldap_config", {})
if ldap_config.get("enabled", False):
- providers.append({
- 'module': LDAP_PROVIDER,
- 'config': ldap_config,
- })
+ providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
providers.extend(config.get("password_providers", []))
for provider in providers:
- mod_name = provider['module']
+ mod_name = provider["module"]
# This is for backwards compat when the ldap auth provider resided
# in this package.
if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
mod_name = LDAP_PROVIDER
- (provider_class, provider_config) = load_module({
- "module": mod_name,
- "config": provider['config'],
- })
+ (provider_class, provider_config) = load_module(
+ {"module": mod_name, "config": provider["config"]}
+ )
self.password_providers.append((provider_class, provider_config))
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
#password_providers:
# - module: "ldap_auth_provider.LdapAuthProvider"
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 62c0060c9c..1b932722a5 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -18,7 +18,7 @@ from ._base import Config
class PushConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
@@ -42,7 +42,7 @@ class PushConfig(Config):
)
self.push_include_content = not redact_content
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Clients requesting push notifications can either have the body of
# the message sent in the notification poke along with other details
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 5a9adac480..8c587f3fd2 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -36,7 +36,7 @@ class FederationRateLimitConfig(object):
class RatelimitConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
# Load the new-style messages config if it exists. Otherwise fall back
# to the old method.
@@ -80,7 +80,7 @@ class RatelimitConfig(Config):
"federation_rr_transactions_per_room_per_second", 50
)
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
## Ratelimiting ##
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index aad3400819..4a59e6ec90 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -23,7 +23,7 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config):
def __init__(self, config, synapse_config):
self.enabled = config.get("enabled", False)
- self.renew_by_email_enabled = ("renew_at" in config)
+ self.renew_by_email_enabled = "renew_at" in config
if self.enabled:
if "period" in config:
@@ -39,15 +39,14 @@ class AccountValidityConfig(Config):
else:
self.renew_email_subject = "Renew your %(app)s account"
- self.startup_job_max_delta = self.period * 10. / 100.
+ self.startup_job_max_delta = self.period * 10.0 / 100.0
if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
class RegistrationConfig(Config):
-
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.enable_registration = bool(
strtobool(str(config.get("enable_registration", False)))
)
@@ -57,7 +56,7 @@ class RegistrationConfig(Config):
)
self.account_validity = AccountValidityConfig(
- config.get("account_validity", {}), config,
+ config.get("account_validity", {}), config
)
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
@@ -67,35 +66,37 @@ class RegistrationConfig(Config):
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
- "trusted_third_party_id_servers",
- ["matrix.org", "vector.im"],
+ "trusted_third_party_id_servers", ["matrix.org", "vector.im"]
)
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
- self.invite_3pid_guest = (
- self.allow_guest_access and config.get("invite_3pid_guest", False)
+ self.invite_3pid_guest = self.allow_guest_access and config.get(
+ "invite_3pid_guest", False
)
self.auto_join_rooms = config.get("auto_join_rooms", [])
for room_alias in self.auto_join_rooms:
if not RoomAlias.is_valid(room_alias):
- raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
+ raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
- self.disable_msisdn_registration = (
- config.get("disable_msisdn_registration", False)
+ self.disable_msisdn_registration = config.get(
+ "disable_msisdn_registration", False
)
- def default_config(self, generate_secrets=False, **kwargs):
+ def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
random_string_with_symbols(50),
)
else:
- registration_shared_secret = '# registration_shared_secret: <PRIVATE STRING>'
+ registration_shared_secret = (
+ "# registration_shared_secret: <PRIVATE STRING>"
+ )
- return """\
+ return (
+ """\
## Registration ##
#
# Registration can be rate-limited using the parameters in the "Ratelimiting"
@@ -217,17 +218,19 @@ class RegistrationConfig(Config):
# users cannot be auto-joined since they do not exist.
#
#autocreate_auto_join_rooms: true
- """ % locals()
+ """
+ % locals()
+ )
def add_arguments(self, parser):
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
- "--enable-registration", action="store_true", default=None,
- help="Enable registration for new users."
+ "--enable-registration",
+ action="store_true",
+ default=None,
+ help="Enable registration for new users.",
)
def read_arguments(self, args):
if args.enable_registration is not None:
- self.enable_registration = bool(
- strtobool(str(args.enable_registration))
- )
+ self.enable_registration = bool(strtobool(str(args.enable_registration)))
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index fbfcecc240..80a628d9b0 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -20,27 +20,11 @@ from synapse.util.module_loader import load_module
from ._base import Config, ConfigError
DEFAULT_THUMBNAIL_SIZES = [
- {
- "width": 32,
- "height": 32,
- "method": "crop",
- }, {
- "width": 96,
- "height": 96,
- "method": "crop",
- }, {
- "width": 320,
- "height": 240,
- "method": "scale",
- }, {
- "width": 640,
- "height": 480,
- "method": "scale",
- }, {
- "width": 800,
- "height": 600,
- "method": "scale"
- },
+ {"width": 32, "height": 32, "method": "crop"},
+ {"width": 96, "height": 96, "method": "crop"},
+ {"width": 320, "height": 240, "method": "scale"},
+ {"width": 640, "height": 480, "method": "scale"},
+ {"width": 800, "height": 600, "method": "scale"},
]
THUMBNAIL_SIZE_YAML = """\
@@ -49,19 +33,15 @@ THUMBNAIL_SIZE_YAML = """\
# method: %(method)s
"""
-MISSING_NETADDR = (
- "Missing netaddr library. This is required for URL preview API."
-)
+MISSING_NETADDR = "Missing netaddr library. This is required for URL preview API."
-MISSING_LXML = (
- """Missing lxml library. This is required for URL preview API.
+MISSING_LXML = """Missing lxml library. This is required for URL preview API.
Install by running:
pip install lxml
Requires libxslt1-dev system package.
"""
-)
ThumbnailRequirement = namedtuple(
@@ -69,7 +49,8 @@ ThumbnailRequirement = namedtuple(
)
MediaStorageProviderConfig = namedtuple(
- "MediaStorageProviderConfig", (
+ "MediaStorageProviderConfig",
+ (
"store_local", # Whether to store newly uploaded local files
"store_remote", # Whether to store newly downloaded remote files
"store_synchronous", # Whether to wait for successful storage for local uploads
@@ -100,18 +81,19 @@ def parse_thumbnail_requirements(thumbnail_sizes):
requirements.setdefault("image/gif", []).append(png_thumbnail)
requirements.setdefault("image/png", []).append(png_thumbnail)
return {
- media_type: tuple(thumbnails)
- for media_type, thumbnails in requirements.items()
+ media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items()
}
class ContentRepositoryConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M"))
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
- self.media_store_path = self.ensure_directory(config["media_store_path"])
+ self.media_store_path = self.ensure_directory(
+ config.get("media_store_path", "media_store")
+ )
backup_media_store_path = config.get("backup_media_store_path")
@@ -127,15 +109,15 @@ class ContentRepositoryConfig(Config):
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
)
- storage_providers = [{
- "module": "file_system",
- "store_local": True,
- "store_synchronous": synchronous_backup_media_store,
- "store_remote": True,
- "config": {
- "directory": backup_media_store_path,
+ storage_providers = [
+ {
+ "module": "file_system",
+ "store_local": True,
+ "store_synchronous": synchronous_backup_media_store,
+ "store_remote": True,
+ "config": {"directory": backup_media_store_path},
}
- }]
+ ]
# This is a list of config that can be used to create the storage
# providers. The entries are tuples of (Class, class_config,
@@ -165,18 +147,19 @@ class ContentRepositoryConfig(Config):
)
self.media_storage_providers.append(
- (provider_class, parsed_config, wrapper_config,)
+ (provider_class, parsed_config, wrapper_config)
)
- self.uploads_path = self.ensure_directory(config["uploads_path"])
+ self.uploads_path = self.ensure_directory(config.get("uploads_path", "uploads"))
self.dynamic_thumbnails = config.get("dynamic_thumbnails", False)
self.thumbnail_requirements = parse_thumbnail_requirements(
- config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES),
+ config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES)
)
self.url_preview_enabled = config.get("url_preview_enabled", False)
if self.url_preview_enabled:
try:
import lxml
+
lxml # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_LXML)
@@ -199,17 +182,15 @@ class ContentRepositoryConfig(Config):
# we always blacklist '0.0.0.0' and '::', which are supposed to be
# unroutable addresses.
- self.url_preview_ip_range_blacklist.update(['0.0.0.0', '::'])
+ self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"])
self.url_preview_ip_range_whitelist = IPSet(
config.get("url_preview_ip_range_whitelist", ())
)
- self.url_preview_url_blacklist = config.get(
- "url_preview_url_blacklist", ()
- )
+ self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ())
- def default_config(self, data_dir_path, **kwargs):
+ def generate_config_section(self, data_dir_path, **kwargs):
media_store = os.path.join(data_dir_path, "media_store")
uploads_path = os.path.join(data_dir_path, "uploads")
@@ -219,7 +200,8 @@ class ContentRepositoryConfig(Config):
# strip final NL
formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1]
- return r"""
+ return (
+ r"""
# Directory where uploaded images and attachments are stored.
#
media_store_path: "%(media_store)s"
@@ -342,4 +324,6 @@ class ContentRepositoryConfig(Config):
# The largest allowed URL preview spidering size in bytes
#
#max_spider_size: 10M
- """ % locals()
+ """
+ % locals()
+ )
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 8a9fded4c5..a92693017b 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -19,10 +19,8 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
- def read_config(self, config):
- self.enable_room_list_search = config.get(
- "enable_room_list_search", True,
- )
+ def read_config(self, config, **kwargs):
+ self.enable_room_list_search = config.get("enable_room_list_search", True)
alias_creation_rules = config.get("alias_creation_rules")
@@ -33,11 +31,7 @@ class RoomDirectoryConfig(Config):
]
else:
self._alias_creation_rules = [
- _RoomDirectoryRule(
- "alias_creation_rules", {
- "action": "allow",
- }
- )
+ _RoomDirectoryRule("alias_creation_rules", {"action": "allow"})
]
room_list_publication_rules = config.get("room_list_publication_rules")
@@ -49,14 +43,10 @@ class RoomDirectoryConfig(Config):
]
else:
self._room_list_publication_rules = [
- _RoomDirectoryRule(
- "room_list_publication_rules", {
- "action": "allow",
- }
- )
+ _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
]
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Uncomment to disable searching the public room list. When disabled
# blocks searching local and remote room lists for local and remote
@@ -178,8 +168,7 @@ class _RoomDirectoryRule(object):
self.action = action
else:
raise ConfigError(
- "%s rules can only have action of 'allow'"
- " or 'deny'" % (option_name,)
+ "%s rules can only have action of 'allow'" " or 'deny'" % (option_name,)
)
self._alias_matches_all = alias == "*"
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index a6ff62df09..463b5fdd68 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -18,7 +18,7 @@ from ._base import Config, ConfigError
class SAML2Config(Config):
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.saml2_enabled = False
saml2_config = config.get("saml2_config")
@@ -34,6 +34,7 @@ class SAML2Config(Config):
self.saml2_enabled = True
import saml2.config
+
self.saml2_sp_config = saml2.config.SPConfig()
self.saml2_sp_config.load(self._default_saml_config_dict())
self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
@@ -47,29 +48,26 @@ class SAML2Config(Config):
public_baseurl = self.public_baseurl
if public_baseurl is None:
- raise ConfigError(
- "saml2_config requires a public_baseurl to be set"
- )
+ raise ConfigError("saml2_config requires a public_baseurl to be set")
metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
response_url = public_baseurl + "_matrix/saml2/authn_response"
return {
"entityid": metadata_url,
-
"service": {
"sp": {
"endpoints": {
"assertion_consumer_service": [
- (response_url, saml2.BINDING_HTTP_POST),
- ],
+ (response_url, saml2.BINDING_HTTP_POST)
+ ]
},
"required_attributes": ["uid"],
"optional_attributes": ["mail", "surname", "givenname"],
- },
- }
+ }
+ },
}
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
# Enable SAML2 for registration and login. Uses pysaml2.
#
@@ -112,4 +110,6 @@ class SAML2Config(Config):
# # separate pysaml2 configuration file:
# #
# config_path: "%(config_dir_path)s/sp_conf.py"
- """ % {"config_dir_path": config_dir_path}
+ """ % {
+ "config_dir_path": config_dir_path
+ }
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7d56e2d141..2a74dea2ea 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -34,14 +34,13 @@ logger = logging.Logger(__name__)
#
# We later check for errors when binding to 0.0.0.0 and ignore them if :: is also in
# in the list.
-DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0']
+DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
DEFAULT_ROOM_VERSION = "4"
class ServerConfig(Config):
-
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.server_name = config["server_name"]
self.server_context = config.get("server_context", None)
@@ -58,7 +57,6 @@ class ServerConfig(Config):
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl")
- self.cpu_affinity = config.get("cpu_affinity")
# Whether to send federation traffic out in this process. This only
# applies to some federation traffic, and so shouldn't be used to
@@ -81,27 +79,45 @@ class ServerConfig(Config):
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API.
self.require_auth_for_profile_requests = config.get(
- "require_auth_for_profile_requests", False,
+ "require_auth_for_profile_requests", False
)
- # If set to 'True', requires authentication to access the server's
- # public rooms directory through the client API, and forbids any other
- # homeserver to fetch it via federation.
- self.restrict_public_rooms_to_local_users = config.get(
- "restrict_public_rooms_to_local_users", False,
- )
+ if "restrict_public_rooms_to_local_users" in config and (
+ "allow_public_rooms_without_auth" in config
+ or "allow_public_rooms_over_federation" in config
+ ):
+ raise ConfigError(
+ "Can't use 'restrict_public_rooms_to_local_users' if"
+ " 'allow_public_rooms_without_auth' and/or"
+ " 'allow_public_rooms_over_federation' is set."
+ )
- default_room_version = config.get(
- "default_room_version", DEFAULT_ROOM_VERSION,
- )
+ # Check if the legacy "restrict_public_rooms_to_local_users" flag is set. This
+ # flag is now obsolete but we need to check it for backward-compatibility.
+ if config.get("restrict_public_rooms_to_local_users", False):
+ self.allow_public_rooms_without_auth = False
+ self.allow_public_rooms_over_federation = False
+ else:
+ # If set to 'False', requires authentication to access the server's public
+ # rooms directory through the client API. Defaults to 'True'.
+ self.allow_public_rooms_without_auth = config.get(
+ "allow_public_rooms_without_auth", True
+ )
+ # If set to 'False', forbids any other homeserver to fetch the server's public
+ # rooms directory via federation. Defaults to 'True'.
+ self.allow_public_rooms_over_federation = config.get(
+ "allow_public_rooms_over_federation", True
+ )
+
+ default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION)
# Ensure room version is a str
default_room_version = str(default_room_version)
if default_room_version not in KNOWN_ROOM_VERSIONS:
raise ConfigError(
- "Unknown default_room_version: %s, known room versions: %s" %
- (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
+ "Unknown default_room_version: %s, known room versions: %s"
+ % (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
)
# Get the actual room version object rather than just the identifier
@@ -116,31 +132,25 @@ class ServerConfig(Config):
# Whether we should block invites sent to users on this server
# (other than those sent by local server admins)
- self.block_non_admin_invites = config.get(
- "block_non_admin_invites", False,
- )
+ self.block_non_admin_invites = config.get("block_non_admin_invites", False)
# Whether to enable experimental MSC1849 (aka relations) support
self.experimental_msc1849_support_enabled = config.get(
- "experimental_msc1849_support_enabled", False,
+ "experimental_msc1849_support_enabled", False
)
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
self.max_mau_value = 0
if self.limit_usage_by_mau:
- self.max_mau_value = config.get(
- "max_mau_value", 0,
- )
+ self.max_mau_value = config.get("max_mau_value", 0)
self.mau_stats_only = config.get("mau_stats_only", False)
self.mau_limits_reserved_threepids = config.get(
"mau_limit_reserved_threepids", []
)
- self.mau_trial_days = config.get(
- "mau_trial_days", 0,
- )
+ self.mau_trial_days = config.get("mau_trial_days", 0)
# Options to disable HS
self.hs_disabled = config.get("hs_disabled", False)
@@ -153,9 +163,7 @@ class ServerConfig(Config):
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None
- federation_domain_whitelist = config.get(
- "federation_domain_whitelist", None,
- )
+ federation_domain_whitelist = config.get("federation_domain_whitelist", None)
if federation_domain_whitelist is not None:
# turn the whitelist into a hash for speed of lookup
@@ -165,7 +173,7 @@ class ServerConfig(Config):
self.federation_domain_whitelist[domain] = True
self.federation_ip_range_blacklist = config.get(
- "federation_ip_range_blacklist", [],
+ "federation_ip_range_blacklist", []
)
# Attempt to create an IPSet from the given ranges
@@ -178,13 +186,12 @@ class ServerConfig(Config):
self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
except Exception as e:
raise ConfigError(
- "Invalid range(s) provided in "
- "federation_ip_range_blacklist: %s" % e
+ "Invalid range(s) provided in " "federation_ip_range_blacklist: %s" % e
)
if self.public_baseurl is not None:
- if self.public_baseurl[-1] != '/':
- self.public_baseurl += '/'
+ if self.public_baseurl[-1] != "/":
+ self.public_baseurl += "/"
self.start_pushers = config.get("start_pushers", True)
# (undocumented) option for torturing the worker-mode replication a bit,
@@ -195,7 +202,7 @@ class ServerConfig(Config):
# Whether to require a user to be in the room to add an alias to it.
# Defaults to True.
self.require_membership_for_aliases = config.get(
- "require_membership_for_aliases", True,
+ "require_membership_for_aliases", True
)
# Whether to allow per-room membership profiles through the send of membership
@@ -227,9 +234,9 @@ class ServerConfig(Config):
# if we still have an empty list of addresses, use the default list
if not bind_addresses:
- if listener['type'] == 'metrics':
+ if listener["type"] == "metrics":
# the metrics listener doesn't support IPv6
- bind_addresses.append('0.0.0.0')
+ bind_addresses.append("0.0.0.0")
else:
bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
@@ -249,78 +256,80 @@ class ServerConfig(Config):
bind_host = config.get("bind_host", "")
gzip_responses = config.get("gzip_responses", True)
- self.listeners.append({
- "port": bind_port,
- "bind_addresses": [bind_host],
- "tls": True,
- "type": "http",
- "resources": [
- {
- "names": ["client"],
- "compress": gzip_responses,
- },
- {
- "names": ["federation"],
- "compress": False,
- }
- ]
- })
-
- unsecure_port = config.get("unsecure_port", bind_port - 400)
- if unsecure_port:
- self.listeners.append({
- "port": unsecure_port,
+ self.listeners.append(
+ {
+ "port": bind_port,
"bind_addresses": [bind_host],
- "tls": False,
+ "tls": True,
"type": "http",
"resources": [
- {
- "names": ["client"],
- "compress": gzip_responses,
- },
- {
- "names": ["federation"],
- "compress": False,
- }
- ]
- })
+ {"names": ["client"], "compress": gzip_responses},
+ {"names": ["federation"], "compress": False},
+ ],
+ }
+ )
+
+ unsecure_port = config.get("unsecure_port", bind_port - 400)
+ if unsecure_port:
+ self.listeners.append(
+ {
+ "port": unsecure_port,
+ "bind_addresses": [bind_host],
+ "tls": False,
+ "type": "http",
+ "resources": [
+ {"names": ["client"], "compress": gzip_responses},
+ {"names": ["federation"], "compress": False},
+ ],
+ }
+ )
manhole = config.get("manhole")
if manhole:
- self.listeners.append({
- "port": manhole,
- "bind_addresses": ["127.0.0.1"],
- "type": "manhole",
- "tls": False,
- })
+ self.listeners.append(
+ {
+ "port": manhole,
+ "bind_addresses": ["127.0.0.1"],
+ "type": "manhole",
+ "tls": False,
+ }
+ )
metrics_port = config.get("metrics_port")
if metrics_port:
logger.warn(
- ("The metrics_port configuration option is deprecated in Synapse 0.31 "
- "in favour of a listener. Please see "
- "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
- " on how to configure the new listener."))
-
- self.listeners.append({
- "port": metrics_port,
- "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
- "tls": False,
- "type": "http",
- "resources": [
- {
- "names": ["metrics"],
- "compress": False,
- },
- ]
- })
+ (
+ "The metrics_port configuration option is deprecated in Synapse 0.31 "
+ "in favour of a listener. Please see "
+ "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
+ " on how to configure the new listener."
+ )
+ )
+
+ self.listeners.append(
+ {
+ "port": metrics_port,
+ "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
+ "tls": False,
+ "type": "http",
+ "resources": [{"names": ["metrics"], "compress": False}],
+ }
+ )
_check_resource_config(self.listeners)
+ # An experimental option to try and periodically clean up extremities
+ # by sending dummy events.
+ self.cleanup_extremities_with_dummy_events = config.get(
+ "cleanup_extremities_with_dummy_events", False
+ )
+
def has_tls_listener(self):
return any(l["tls"] for l in self.listeners)
- def default_config(self, server_name, data_dir_path, **kwargs):
+ def generate_config_section(
+ self, server_name, data_dir_path, open_private_ports, **kwargs
+ ):
_, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None:
unsecure_port = bind_port - 400
@@ -333,7 +342,15 @@ class ServerConfig(Config):
# Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
# default config string
default_room_version = DEFAULT_ROOM_VERSION
- return """\
+
+ unsecure_http_binding = "port: %i\n tls: false" % (unsecure_port,)
+ if not open_private_ports:
+ unsecure_http_binding += (
+ "\n bind_addresses: ['::1', '127.0.0.1']"
+ )
+
+ return (
+ """\
## Server ##
# The domain name of the server, with optional explicit port.
@@ -347,29 +364,6 @@ class ServerConfig(Config):
#
pid_file: %(pid_file)s
- # CPU affinity mask. Setting this restricts the CPUs on which the
- # process will be scheduled. It is represented as a bitmask, with the
- # lowest order bit corresponding to the first logical CPU and the
- # highest order bit corresponding to the last logical CPU. Not all CPUs
- # may exist on a given system but a mask may specify more CPUs than are
- # present.
- #
- # For example:
- # 0x00000001 is processor #0,
- # 0x00000003 is processors #0 and #1,
- # 0xFFFFFFFF is all processors (#0 through #31).
- #
- # Pinning a Python process to a single CPU is desirable, because Python
- # is inherently single-threaded due to the GIL, and can suffer a
- # 30-40%% slowdown due to cache blow-out and thread context switching
- # if the scheduler happens to schedule the underlying threads across
- # different cores. See
- # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
- #
- # This setting requires the affinity package to be installed!
- #
- #cpu_affinity: 0xFFFFFFFF
-
# The path to the web client which will be served at /_matrix/client/
# if 'webclient' is configured under the 'listeners' configuration.
#
@@ -401,11 +395,15 @@ class ServerConfig(Config):
#
#require_auth_for_profile_requests: true
- # If set to 'true', requires authentication to access the server's
- # public rooms directory through the client API, and forbids any other
- # homeserver to fetch it via federation. Defaults to 'false'.
+ # If set to 'false', requires authentication to access the server's public rooms
+ # directory through the client API. Defaults to 'true'.
#
- #restrict_public_rooms_to_local_users: true
+ #allow_public_rooms_without_auth: false
+
+ # If set to 'false', forbids any other homeserver to fetch the server's public
+ # rooms directory via federation. Defaults to 'true'.
+ #
+ #allow_public_rooms_over_federation: false
# The default room version for newly created rooms.
#
@@ -546,9 +544,7 @@ class ServerConfig(Config):
# If you plan to use a reverse proxy, please see
# https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst.
#
- - port: %(unsecure_port)s
- tls: false
- bind_addresses: ['::1', '127.0.0.1']
+ - %(unsecure_http_binding)s
type: http
x_forwarded: true
@@ -556,7 +552,7 @@ class ServerConfig(Config):
- names: [client, federation]
compress: false
- # example additonal_resources:
+ # example additional_resources:
#
#additional_resources:
# "/_matrix/my/custom/endpoint":
@@ -631,7 +627,9 @@ class ServerConfig(Config):
# Defaults to 'true'.
#
#allow_per_room_profiles: false
- """ % locals()
+ """
+ % locals()
+ )
def read_arguments(self, args):
if args.manhole is not None:
@@ -643,17 +641,26 @@ class ServerConfig(Config):
def add_arguments(self, parser):
server_group = parser.add_argument_group("server")
- server_group.add_argument("-D", "--daemonize", action='store_true',
- default=None,
- help="Daemonize the home server")
- server_group.add_argument("--print-pidfile", action='store_true',
- default=None,
- help="Print the path to the pidfile just"
- " before daemonizing")
- server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
- type=int,
- help="Turn on the twisted telnet manhole"
- " service on the given port.")
+ server_group.add_argument(
+ "-D",
+ "--daemonize",
+ action="store_true",
+ default=None,
+ help="Daemonize the home server",
+ )
+ server_group.add_argument(
+ "--print-pidfile",
+ action="store_true",
+ default=None,
+ help="Print the path to the pidfile just" " before daemonizing",
+ )
+ server_group.add_argument(
+ "--manhole",
+ metavar="PORT",
+ dest="manhole",
+ type=int,
+ help="Turn on the twisted telnet manhole" " service on the given port.",
+ )
def is_threepid_reserved(reserved_threepids, threepid):
@@ -667,7 +674,7 @@ def is_threepid_reserved(reserved_threepids, threepid):
"""
for tp in reserved_threepids:
- if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']):
+ if threepid["medium"] == tp["medium"] and threepid["address"] == tp["address"]:
return True
return False
@@ -680,9 +687,7 @@ def read_gc_thresholds(thresholds):
return None
try:
assert len(thresholds) == 3
- return (
- int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
- )
+ return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2]))
except Exception:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
@@ -700,22 +705,22 @@ def _warn_if_webclient_configured(listeners):
for listener in listeners:
for res in listener.get("resources", []):
for name in res.get("names", []):
- if name == 'webclient':
+ if name == "webclient":
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
KNOWN_RESOURCES = (
- 'client',
- 'consent',
- 'federation',
- 'keys',
- 'media',
- 'metrics',
- 'openid',
- 'replication',
- 'static',
- 'webclient',
+ "client",
+ "consent",
+ "federation",
+ "keys",
+ "media",
+ "metrics",
+ "openid",
+ "replication",
+ "static",
+ "webclient",
)
@@ -729,11 +734,9 @@ def _check_resource_config(listeners):
for resource in resource_names:
if resource not in KNOWN_RESOURCES:
- raise ConfigError(
- "Unknown listener resource '%s'" % (resource, )
- )
+ raise ConfigError("Unknown listener resource '%s'" % (resource,))
if resource == "consent":
try:
- check_requirements('resources.consent')
+ check_requirements("resources.consent")
except DependencyException as e:
raise ConfigError(e.message)
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
index 529dc0a617..eaac3d73bc 100644
--- a/synapse/config/server_notices_config.py
+++ b/synapse/config/server_notices_config.py
@@ -58,6 +58,7 @@ class ServerNoticesConfig(Config):
The name to use for the server notices room.
None if server notices are not enabled.
"""
+
def __init__(self):
super(ServerNoticesConfig, self).__init__()
self.server_notices_mxid = None
@@ -65,23 +66,17 @@ class ServerNoticesConfig(Config):
self.server_notices_mxid_avatar_url = None
self.server_notices_room_name = None
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
c = config.get("server_notices")
if c is None:
return
- mxid_localpart = c['system_mxid_localpart']
- self.server_notices_mxid = UserID(
- mxid_localpart, self.server_name,
- ).to_string()
- self.server_notices_mxid_display_name = c.get(
- 'system_mxid_display_name', None,
- )
- self.server_notices_mxid_avatar_url = c.get(
- 'system_mxid_avatar_url', None,
- )
+ mxid_localpart = c["system_mxid_localpart"]
+ self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string()
+ self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None)
+ self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None)
# todo: i18n
- self.server_notices_room_name = c.get('room_name', "Server Notices")
+ self.server_notices_room_name = c.get("room_name", "Server Notices")
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return DEFAULT_CONFIG
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 1502e9faba..e40797ab50 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -19,14 +19,14 @@ from ._base import Config
class SpamCheckerConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.spam_checker = None
provider = config.get("spam_checker", None)
if provider is not None:
self.spam_checker = load_module(provider)
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
#spam_checker:
# module: "my_custom_project.SuperSpamChecker"
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index 80fc1b9dd0..b518a3ed9c 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -25,7 +25,7 @@ class StatsConfig(Config):
Configuration for the behaviour of synapse's stats engine
"""
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.stats_enabled = True
self.stats_bucket_size = 86400
self.stats_retention = sys.maxsize
@@ -42,7 +42,7 @@ class StatsConfig(Config):
/ 1000
)
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Local statistics collection. Used in populating the room directory.
#
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
new file mode 100644
index 0000000000..b3431441b9
--- /dev/null
+++ b/synapse/config/third_party_event_rules.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.module_loader import load_module
+
+from ._base import Config
+
+
+class ThirdPartyRulesConfig(Config):
+ def read_config(self, config, **kwargs):
+ self.third_party_event_rules = None
+
+ provider = config.get("third_party_event_rules", None)
+ if provider is not None:
+ self.third_party_event_rules = load_module(provider)
+
+ def generate_config_section(self, **kwargs):
+ return """\
+ # Server admins can define a Python module that implements extra rules for
+ # allowing or denying incoming events. In order to work, this module needs to
+ # override the methods defined in synapse/events/third_party_rules.py.
+ #
+ # This feature is designed to be used in closed federations only, where each
+ # participating server enforces the same rules.
+ #
+ #third_party_event_rules:
+ # module: "my_custom_project.SuperRulesSet"
+ # config:
+ # example_option: 'things'
+ """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 658f9dd361..8fcf801418 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
class TlsConfig(Config):
- def read_config(self, config):
+ def read_config(self, config, config_dir_path, **kwargs):
acme_config = config.get("acme", None)
if acme_config is None:
@@ -42,14 +42,18 @@ class TlsConfig(Config):
self.acme_enabled = acme_config.get("enabled", False)
# hyperlink complains on py2 if this is not a Unicode
- self.acme_url = six.text_type(acme_config.get(
- "url", u"https://acme-v01.api.letsencrypt.org/directory"
- ))
+ self.acme_url = six.text_type(
+ acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory")
+ )
self.acme_port = acme_config.get("port", 80)
- self.acme_bind_addresses = acme_config.get("bind_addresses", ['::', '0.0.0.0'])
+ self.acme_bind_addresses = acme_config.get("bind_addresses", ["::", "0.0.0.0"])
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
self.acme_domain = acme_config.get("domain", config.get("server_name"))
+ self.acme_account_key_file = self.abspath(
+ acme_config.get("account_key_file", config_dir_path + "/client.key")
+ )
+
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
@@ -74,12 +78,12 @@ class TlsConfig(Config):
# Whether to verify certificates on outbound federation traffic
self.federation_verify_certificates = config.get(
- "federation_verify_certificates", True,
+ "federation_verify_certificates", True
)
# Whitelist of domains to not verify certificates for
fed_whitelist_entries = config.get(
- "federation_certificate_verification_whitelist", [],
+ "federation_certificate_verification_whitelist", []
)
# Support globs (*) in whitelist values
@@ -90,9 +94,7 @@ class TlsConfig(Config):
self.federation_certificate_verification_whitelist.append(entry_regex)
# List of custom certificate authorities for federation traffic validation
- custom_ca_list = config.get(
- "federation_custom_ca_list", None,
- )
+ custom_ca_list = config.get("federation_custom_ca_list", None)
# Read in and parse custom CA certificates
self.federation_ca_trust_root = None
@@ -101,8 +103,10 @@ class TlsConfig(Config):
# A trustroot cannot be generated without any CA certificates.
# Raise an error if this option has been specified without any
# corresponding certificates.
- raise ConfigError("federation_custom_ca_list specified without "
- "any certificate files")
+ raise ConfigError(
+ "federation_custom_ca_list specified without "
+ "any certificate files"
+ )
certs = []
for ca_file in custom_ca_list:
@@ -114,8 +118,9 @@ class TlsConfig(Config):
cert_base = Certificate.loadPEM(content)
certs.append(cert_base)
except Exception as e:
- raise ConfigError("Error parsing custom CA certificate file %s: %s"
- % (ca_file, e))
+ raise ConfigError(
+ "Error parsing custom CA certificate file %s: %s" % (ca_file, e)
+ )
self.federation_ca_trust_root = trustRootFromCertificates(certs)
@@ -146,17 +151,21 @@ class TlsConfig(Config):
return None
try:
- with open(self.tls_certificate_file, 'rb') as f:
+ with open(self.tls_certificate_file, "rb") as f:
cert_pem = f.read()
except Exception as e:
- raise ConfigError("Failed to read existing certificate file %s: %s"
- % (self.tls_certificate_file, e))
+ raise ConfigError(
+ "Failed to read existing certificate file %s: %s"
+ % (self.tls_certificate_file, e)
+ )
try:
tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
except Exception as e:
- raise ConfigError("Failed to parse existing certificate file %s: %s"
- % (self.tls_certificate_file, e))
+ raise ConfigError(
+ "Failed to parse existing certificate file %s: %s"
+ % (self.tls_certificate_file, e)
+ )
if not allow_self_signed:
if tls_certificate.get_subject() == tls_certificate.get_issuer():
@@ -166,7 +175,7 @@ class TlsConfig(Config):
# YYYYMMDDhhmmssZ -- in UTC
expires_on = datetime.strptime(
- tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
+ tls_certificate.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ"
)
now = datetime.utcnow()
days_remaining = (expires_on - now).days
@@ -191,7 +200,8 @@ class TlsConfig(Config):
except Exception as e:
logger.info(
"Unable to read TLS certificate (%s). Ignoring as no "
- "tls listeners enabled.", e,
+ "tls listeners enabled.",
+ e,
)
self.tls_fingerprints = list(self._original_tls_fingerprints)
@@ -205,18 +215,21 @@ class TlsConfig(Config):
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
if sha256_fingerprint not in sha256_fingerprints:
- self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
+ self.tls_fingerprints.append({"sha256": sha256_fingerprint})
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(
+ self, config_dir_path, server_name, data_dir_path, **kwargs
+ ):
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key"
+ default_acme_account_file = os.path.join(data_dir_path, "acme_account.key")
# this is to avoid the max line length. Sorrynotsorry
proxypassline = (
- 'ProxyPass /.well-known/acme-challenge '
- 'http://localhost:8009/.well-known/acme-challenge'
+ "ProxyPass /.well-known/acme-challenge "
+ "http://localhost:8009/.well-known/acme-challenge"
)
return (
@@ -337,6 +350,13 @@ class TlsConfig(Config):
#
#domain: matrix.example.com
+ # file to use for the account key. This will be generated if it doesn't
+ # exist.
+ #
+ # If unspecified, we will use CONFDIR/client.key.
+ #
+ account_key_file: %(default_acme_account_file)s
+
# List of allowed TLS fingerprints for this server to publish along
# with the signing keys for this server. Other matrix servers that
# make HTTPS requests to this server will check that the TLS
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 023997ccde..f6313e17d4 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -21,19 +21,19 @@ class UserDirectoryConfig(Config):
Configuration for the behaviour of the /user_directory API
"""
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
user_directory_config = config.get("user_directory", None)
if user_directory_config:
- self.user_directory_search_enabled = (
- user_directory_config.get("enabled", True)
+ self.user_directory_search_enabled = user_directory_config.get(
+ "enabled", True
)
- self.user_directory_search_all_users = (
- user_directory_config.get("search_all_users", False)
+ self.user_directory_search_all_users = user_directory_config.get(
+ "search_all_users", False
)
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# User Directory configuration
#
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index 2a1f005a37..2ca0e1cf70 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -16,18 +16,17 @@ from ._base import Config
class VoipConfig(Config):
-
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret")
self.turn_username = config.get("turn_username")
self.turn_password = config.get("turn_password")
self.turn_user_lifetime = self.parse_duration(
- config.get("turn_user_lifetime", "1h"),
+ config.get("turn_user_lifetime", "1h")
)
self.turn_allow_guests = config.get("turn_allow_guests", True)
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
## TURN ##
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index bfbd8b6c91..3b75471d85 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -21,7 +21,7 @@ class WorkerConfig(Config):
They have their own pid_file and listener configuration. They use the
replication_url to talk to the main synapse process."""
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.worker_app = config.get("worker_app")
# Canonicalise worker_app so that master always has None
@@ -46,18 +46,19 @@ class WorkerConfig(Config):
self.worker_name = config.get("worker_name", self.worker_app)
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
- self.worker_cpu_affinity = config.get("worker_cpu_affinity")
# This option is really only here to support `--manhole` command line
# argument.
manhole = config.get("worker_manhole")
if manhole:
- self.worker_listeners.append({
- "port": manhole,
- "bind_addresses": ["127.0.0.1"],
- "type": "manhole",
- "tls": False,
- })
+ self.worker_listeners.append(
+ {
+ "port": manhole,
+ "bind_addresses": ["127.0.0.1"],
+ "type": "manhole",
+ "tls": False,
+ }
+ )
if self.worker_listeners:
for listener in self.worker_listeners:
@@ -67,7 +68,7 @@ class WorkerConfig(Config):
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
- bind_addresses.append('')
+ bind_addresses.append("")
def read_arguments(self, args):
# We support a bunch of command line arguments that override options in
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 99a586655b..41eabbe717 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -46,9 +46,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
if name not in hashes:
raise SynapseError(
400,
- "Algorithm %s not in hashes %s" % (
- name, list(hashes),
- ),
+ "Algorithm %s not in hashes %s" % (name, list(hashes)),
Codes.UNAUTHORIZED,
)
message_hash_base64 = hashes[name]
@@ -56,9 +54,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
message_hash_bytes = decode_base64(message_hash_base64)
except Exception:
raise SynapseError(
- 400,
- "Invalid base64: %s" % (message_hash_base64,),
- Codes.UNAUTHORIZED,
+ 400, "Invalid base64: %s" % (message_hash_base64,), Codes.UNAUTHORIZED
)
return message_hash_bytes == expected_hash
@@ -135,8 +131,9 @@ def compute_event_signature(event_dict, signature_name, signing_key):
return redact_json["signatures"]
-def add_hashes_and_signatures(event_dict, signature_name, signing_key,
- hash_algorithm=hashlib.sha256):
+def add_hashes_and_signatures(
+ event_dict, signature_name, signing_key, hash_algorithm=hashlib.sha256
+):
"""Add content hash and sign the event
Args:
@@ -153,7 +150,5 @@ def add_hashes_and_signatures(event_dict, signature_name, signing_key,
event_dict.setdefault("hashes", {})[name] = encode_base64(digest)
event_dict["signatures"] = compute_event_signature(
- event_dict,
- signature_name=signature_name,
- signing_key=signing_key,
+ event_dict, signature_name=signature_name, signing_key=signing_key
)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 6f603f1961..10c2eb7f0f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -505,7 +505,7 @@ class BaseV2KeyFetcher(object):
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
- ts_valid_until_ms = response_json[u"valid_until_ts"]
+ ts_valid_until_ms = response_json["valid_until_ts"]
# start by extracting the keys from the response, since they may be required
# to validate the signature on the response.
@@ -614,10 +614,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
results = yield logcontext.make_deferred_yieldable(
defer.gatherResults(
- [
- run_in_background(get_key, server)
- for server in self.key_servers
- ],
+ [run_in_background(get_key, server) for server in self.key_servers],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
@@ -630,9 +627,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(
- self, keys_to_fetch, key_server
- ):
+ def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
"""
Args:
keys_to_fetch (dict[str, dict[str, int]]):
@@ -661,9 +656,9 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
destination=perspective_name,
path="/_matrix/key/v2/query",
data={
- u"server_keys": {
+ "server_keys": {
server_name: {
- key_id: {u"minimum_valid_until_ts": min_valid_ts}
+ key_id: {"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
}
for server_name, server_keys in keys_to_fetch.items()
@@ -690,10 +685,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
try:
- self._validate_perspectives_response(
- key_server,
- response,
- )
+ self._validate_perspectives_response(key_server, response)
processed_response = yield self.process_v2_response(
perspective_name, response, time_added_ms=time_now_ms
@@ -720,9 +712,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
defer.returnValue(keys)
- def _validate_perspectives_response(
- self, key_server, response,
- ):
+ def _validate_perspectives_response(self, key_server, response):
"""Optionally check the signature on the result of a /key/query request
Args:
@@ -739,13 +729,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return
if (
- u"signatures" not in response
- or perspective_name not in response[u"signatures"]
+ "signatures" not in response
+ or perspective_name not in response["signatures"]
):
raise KeyLookupError("Response not signed by the notary server")
verified = False
- for key_id in response[u"signatures"][perspective_name]:
+ for key_id in response["signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(response, perspective_name, perspective_keys[key_id])
verified = True
@@ -754,7 +744,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
raise KeyLookupError(
"Response not signed with a known key: signed with: %r, known keys: %r"
% (
- list(response[u"signatures"][perspective_name].keys()),
+ list(response["signatures"][perspective_name].keys()),
list(perspective_keys.keys()),
)
)
@@ -826,7 +816,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
ignore_backoff=True,
-
# we only give the remote server 10s to respond. It should be an
# easy request to handle, so if it doesn't reply within 10s, it's
# probably not going to.
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 203490fc36..cd52e3f867 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -85,17 +85,14 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
- 403,
- "Creation event's room_id domain does not match sender's"
+ 403, "Creation event's room_id domain does not match sender's"
)
room_version = event.content.get("room_version", "1")
if room_version not in KNOWN_ROOM_VERSIONS:
raise AuthError(
- 403,
- "room appears to have unsupported version %s" % (
- room_version,
- ))
+ 403, "room appears to have unsupported version %s" % (room_version,)
+ )
# FIXME
logger.debug("Allowing! %s", event)
return
@@ -103,46 +100,30 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
- raise AuthError(
- 403,
- "No create event in auth events",
- )
+ raise AuthError(403, "No create event in auth events")
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not _can_federate(event, auth_events):
- raise AuthError(
- 403,
- "This room has been marked as unfederatable."
- )
+ raise AuthError(403, "This room has been marked as unfederatable.")
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
- raise AuthError(
- 403,
- "Alias event must be a state event",
- )
+ raise AuthError(403, "Alias event must be a state event")
if not event.state_key:
- raise AuthError(
- 403,
- "Alias event must have non-empty state_key"
- )
+ raise AuthError(403, "Alias event must have non-empty state_key")
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
- 403,
- "Alias event's state_key does not match sender's domain"
+ 403, "Alias event's state_key does not match sender's domain"
)
logger.debug("Allowing! %s", event)
return
if logger.isEnabledFor(logging.DEBUG):
- logger.debug(
- "Auth events: %s",
- [a.event_id for a in auth_events.values()]
- )
+ logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
if event.type == EventTypes.Member:
_is_membership_change_allowed(event, auth_events)
@@ -159,9 +140,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
- raise AuthError(
- 403, "You don't have permission to invite users",
- )
+ raise AuthError(403, "You don't have permission to invite users")
else:
logger.debug("Allowing! %s", event)
return
@@ -207,7 +186,7 @@ def _is_membership_change_allowed(event, auth_events):
# Check if this is the room creator joining:
if len(event.prev_event_ids()) == 1 and Membership.JOIN == membership:
# Get room creation event:
- key = (EventTypes.Create, "", )
+ key = (EventTypes.Create, "")
create = auth_events.get(key)
if create and event.prev_event_ids()[0] == create.event_id:
if create.content["creator"] == event.state_key:
@@ -219,38 +198,31 @@ def _is_membership_change_allowed(event, auth_events):
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not _can_federate(event, auth_events):
- raise AuthError(
- 403,
- "This room has been marked as unfederatable."
- )
+ raise AuthError(403, "This room has been marked as unfederatable.")
# get info about the caller
- key = (EventTypes.Member, event.user_id, )
+ key = (EventTypes.Member, event.user_id)
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
- key = (EventTypes.Member, target_user_id, )
+ key = (EventTypes.Member, target_user_id)
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
- key = (EventTypes.JoinRules, "", )
+ key = (EventTypes.JoinRules, "")
join_rule_event = auth_events.get(key)
if join_rule_event:
- join_rule = join_rule_event.content.get(
- "join_rule", JoinRules.INVITE
- )
+ join_rule = join_rule_event.content.get("join_rule", JoinRules.INVITE)
else:
join_rule = JoinRules.INVITE
user_level = get_user_power_level(event.user_id, auth_events)
- target_level = get_user_power_level(
- target_user_id, auth_events
- )
+ target_level = get_user_power_level(target_user_id, auth_events)
# FIXME (erikj): What should we do here as the default?
ban_level = _get_named_level(auth_events, "ban", 50)
@@ -266,29 +238,26 @@ def _is_membership_change_allowed(event, auth_events):
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
- }
+ },
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not _verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
- raise AuthError(
- 403, "%s is banned from the room" % (target_user_id,)
- )
+ raise AuthError(403, "%s is banned from the room" % (target_user_id,))
return
if Membership.JOIN != membership:
- if (caller_invited
- and Membership.LEAVE == membership
- and target_user_id == event.user_id):
+ if (
+ caller_invited
+ and Membership.LEAVE == membership
+ and target_user_id == event.user_id
+ ):
return
if not caller_in_room: # caller isn't joined
- raise AuthError(
- 403,
- "%s not in room %s." % (event.user_id, event.room_id,)
- )
+ raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id))
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
@@ -296,19 +265,14 @@ def _is_membership_change_allowed(event, auth_events):
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
- raise AuthError(
- 403, "%s is banned from the room" % (target_user_id,)
- )
+ raise AuthError(403, "%s is banned from the room" % (target_user_id,))
elif target_in_room: # the target is already in the room.
- raise AuthError(403, "%s is already in the room." %
- target_user_id)
+ raise AuthError(403, "%s is already in the room." % target_user_id)
else:
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
- raise AuthError(
- 403, "You don't have permission to invite users",
- )
+ raise AuthError(403, "You don't have permission to invite users")
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
@@ -329,16 +293,12 @@ def _is_membership_change_allowed(event, auth_events):
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
- raise AuthError(
- 403, "You cannot unban user %s." % (target_user_id,)
- )
+ raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
- raise AuthError(
- 403, "You cannot kick user %s." % target_user_id
- )
+ raise AuthError(403, "You cannot kick user %s." % target_user_id)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
@@ -347,21 +307,17 @@ def _is_membership_change_allowed(event, auth_events):
def _check_event_sender_in_room(event, auth_events):
- key = (EventTypes.Member, event.user_id, )
+ key = (EventTypes.Member, event.user_id)
member_event = auth_events.get(key)
- return _check_joined_room(
- member_event,
- event.user_id,
- event.room_id
- )
+ return _check_joined_room(member_event, event.user_id, event.room_id)
def _check_joined_room(member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
- raise AuthError(403, "User %s not in room %s (%s)" % (
- user_id, room_id, repr(member)
- ))
+ raise AuthError(
+ 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
+ )
def get_send_level(etype, state_key, power_levels_event):
@@ -402,26 +358,21 @@ def get_send_level(etype, state_key, power_levels_event):
def _can_send_event(event, auth_events):
power_levels_event = _get_power_level_event(auth_events)
- send_level = get_send_level(
- event.type, event.get("state_key"), power_levels_event,
- )
+ send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
- "You don't have permission to post that to the room. " +
- "user_level (%d) < send_level (%d)" % (user_level, send_level)
+ "You don't have permission to post that to the room. "
+ + "user_level (%d) < send_level (%d)" % (user_level, send_level),
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
- raise AuthError(
- 403,
- "You are not allowed to set others state"
- )
+ raise AuthError(403, "You are not allowed to set others state")
return True
@@ -459,10 +410,7 @@ def check_redaction(room_version, event, auth_events):
event.internal_metadata.recheck_redaction = True
return True
- raise AuthError(
- 403,
- "You don't have permission to redact events"
- )
+ raise AuthError(403, "You don't have permission to redact events")
def _check_power_levels(event, auth_events):
@@ -479,7 +427,7 @@ def _check_power_levels(event, auth_events):
except Exception:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
- key = (event.type, event.state_key, )
+ key = (event.type, event.state_key)
current_state = auth_events.get(key)
if not current_state:
@@ -500,16 +448,12 @@ def _check_power_levels(event, auth_events):
old_list = current_state.content.get("users", {})
for user in set(list(old_list) + list(user_list)):
- levels_to_check.append(
- (user, "users")
- )
+ levels_to_check.append((user, "users"))
old_list = current_state.content.get("events", {})
new_list = event.content.get("events", {})
for ev_id in set(list(old_list) + list(new_list)):
- levels_to_check.append(
- (ev_id, "events")
- )
+ levels_to_check.append((ev_id, "events"))
old_state = current_state.content
new_state = event.content
@@ -540,7 +484,7 @@ def _check_power_levels(event, auth_events):
raise AuthError(
403,
"You don't have permission to remove ops level equal "
- "to your own"
+ "to your own",
)
# Check if the old and new levels are greater than the user level
@@ -550,8 +494,7 @@ def _check_power_levels(event, auth_events):
if old_level_too_big or new_level_too_big:
raise AuthError(
403,
- "You don't have permission to add ops level greater "
- "than your own"
+ "You don't have permission to add ops level greater " "than your own",
)
@@ -587,10 +530,9 @@ def get_user_power_level(user_id, auth_events):
# some things which call this don't pass the create event: hack around
# that.
- key = (EventTypes.Create, "", )
+ key = (EventTypes.Create, "")
create_event = auth_events.get(key)
- if (create_event is not None and
- create_event.content["creator"] == user_id):
+ if create_event is not None and create_event.content["creator"] == user_id:
return 100
else:
return 0
@@ -636,9 +578,7 @@ def _verify_third_party_invite(event, auth_events):
token = signed["token"]
- invite_event = auth_events.get(
- (EventTypes.ThirdPartyInvite, token,)
- )
+ invite_event = auth_events.get((EventTypes.ThirdPartyInvite, token))
if not invite_event:
return False
@@ -661,8 +601,7 @@ def _verify_third_party_invite(event, auth_events):
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
- key_name,
- decode_base64(public_key)
+ key_name, decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
@@ -671,7 +610,7 @@ def _verify_third_party_invite(event, auth_events):
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
- except (KeyError, SignatureVerifyException,):
+ except (KeyError, SignatureVerifyException):
continue
return False
@@ -679,9 +618,7 @@ def _verify_third_party_invite(event, auth_events):
def get_public_keys(invite_event):
public_keys = []
if "public_key" in invite_event.content:
- o = {
- "public_key": invite_event.content["public_key"],
- }
+ o = {"public_key": invite_event.content["public_key"]}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
@@ -702,22 +639,22 @@ def auth_types_for_event(event):
auth_types = []
- auth_types.append((EventTypes.PowerLevels, "", ))
- auth_types.append((EventTypes.Member, event.sender, ))
- auth_types.append((EventTypes.Create, "", ))
+ auth_types.append((EventTypes.PowerLevels, ""))
+ auth_types.append((EventTypes.Member, event.sender))
+ auth_types.append((EventTypes.Create, ""))
if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
- auth_types.append((EventTypes.JoinRules, "", ))
+ auth_types.append((EventTypes.JoinRules, ""))
- auth_types.append((EventTypes.Member, event.state_key, ))
+ auth_types.append((EventTypes.Member, event.state_key))
if membership == Membership.INVITE:
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
- event.content["third_party_invite"]["signed"]["token"]
+ event.content["third_party_invite"]["signed"]["token"],
)
auth_types.append(key)
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 1edd19cc13..d3de70e671 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -92,6 +92,18 @@ class _EventInternalMetadata(object):
"""
return getattr(self, "soft_failed", False)
+ def should_proactively_send(self):
+ """Whether the event, if ours, should be sent to other clients and
+ servers.
+
+ This is used for sending dummy events internally. Servers and clients
+ can still explicitly fetch the event.
+
+ Returns:
+ bool
+ """
+ return getattr(self, "proactively_send", True)
+
def _event_dict_property(key):
# We want to be able to use hasattr with the event dict properties.
@@ -115,25 +127,25 @@ def _event_dict_property(key):
except KeyError:
raise AttributeError(key)
- return property(
- getter,
- setter,
- delete,
- )
+ return property(getter, setter, delete)
class EventBase(object):
- def __init__(self, event_dict, signatures={}, unsigned={},
- internal_metadata_dict={}, rejected_reason=None):
+ def __init__(
+ self,
+ event_dict,
+ signatures={},
+ unsigned={},
+ internal_metadata_dict={},
+ rejected_reason=None,
+ ):
self.signatures = signatures
self.unsigned = unsigned
self.rejected_reason = rejected_reason
self._event_dict = event_dict
- self.internal_metadata = _EventInternalMetadata(
- internal_metadata_dict
- )
+ self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
@@ -156,10 +168,7 @@ class EventBase(object):
def get_dict(self):
d = dict(self._event_dict)
- d.update({
- "signatures": self.signatures,
- "unsigned": dict(self.unsigned),
- })
+ d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
return d
@@ -346,6 +355,7 @@ class FrozenEventV2(EventBase):
class FrozenEventV3(FrozenEventV2):
"""FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
+
format_version = EventFormatVersions.V3 # All events of this type are V3
@property
@@ -402,6 +412,4 @@ def event_type_from_format_version(format_version):
elif format_version == EventFormatVersions.V3:
return FrozenEventV3
else:
- raise Exception(
- "No event format %r" % (format_version,)
- )
+ raise Exception("No event format %r" % (format_version,))
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 546b6f4982..db011e0407 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -78,7 +78,9 @@ class EventBuilder(object):
_redacts = attr.ib(default=None)
_origin_server_ts = attr.ib(default=None)
- internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
+ internal_metadata = attr.ib(
+ default=attr.Factory(lambda: _EventInternalMetadata({}))
+ )
@property
def state_key(self):
@@ -102,11 +104,9 @@ class EventBuilder(object):
"""
state_ids = yield self._state.get_current_state_ids(
- self.room_id, prev_event_ids,
- )
- auth_ids = yield self._auth.compute_auth_events(
- self, state_ids,
+ self.room_id, prev_event_ids
)
+ auth_ids = yield self._auth.compute_auth_events(self, state_ids)
if self.format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)
@@ -115,9 +115,7 @@ class EventBuilder(object):
auth_events = auth_ids
prev_events = prev_event_ids
- old_depth = yield self._store.get_max_depth_of(
- prev_event_ids,
- )
+ old_depth = yield self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not
@@ -217,9 +215,14 @@ class EventBuilderFactory(object):
)
-def create_local_event_from_event_dict(clock, hostname, signing_key,
- format_version, event_dict,
- internal_metadata_dict=None):
+def create_local_event_from_event_dict(
+ clock,
+ hostname,
+ signing_key,
+ format_version,
+ event_dict,
+ internal_metadata_dict=None,
+):
"""Takes a fully formed event dict, ensuring that fields like `origin`
and `origin_server_ts` have correct values for a locally produced event,
then signs and hashes it.
@@ -237,9 +240,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
"""
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
- raise Exception(
- "No event format defined for version %r" % (format_version,)
- )
+ raise Exception("No event format defined for version %r" % (format_version,))
if internal_metadata_dict is None:
internal_metadata_dict = {}
@@ -258,13 +259,9 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
event_dict.setdefault("signatures", {})
- add_hashes_and_signatures(
- event_dict,
- hostname,
- signing_key,
- )
+ add_hashes_and_signatures(event_dict, hostname, signing_key)
return event_type_from_format_version(format_version)(
- event_dict, internal_metadata_dict=internal_metadata_dict,
+ event_dict, internal_metadata_dict=internal_metadata_dict
)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index fa09c132a0..a96cdada3d 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -88,8 +88,9 @@ class EventContext(object):
self.app_service = None
@staticmethod
- def with_state(state_group, current_state_ids, prev_state_ids,
- prev_group=None, delta_ids=None):
+ def with_state(
+ state_group, current_state_ids, prev_state_ids, prev_group=None, delta_ids=None
+ ):
context = EventContext()
# The current state including the current event
@@ -132,17 +133,19 @@ class EventContext(object):
else:
prev_state_id = None
- defer.returnValue({
- "prev_state_id": prev_state_id,
- "event_type": event.type,
- "event_state_key": event.state_key if event.is_state() else None,
- "state_group": self.state_group,
- "rejected": self.rejected,
- "prev_group": self.prev_group,
- "delta_ids": _encode_state_dict(self.delta_ids),
- "prev_state_events": self.prev_state_events,
- "app_service_id": self.app_service.id if self.app_service else None
- })
+ defer.returnValue(
+ {
+ "prev_state_id": prev_state_id,
+ "event_type": event.type,
+ "event_state_key": event.state_key if event.is_state() else None,
+ "state_group": self.state_group,
+ "rejected": self.rejected,
+ "prev_group": self.prev_group,
+ "delta_ids": _encode_state_dict(self.delta_ids),
+ "prev_state_events": self.prev_state_events,
+ "app_service_id": self.app_service.id if self.app_service else None,
+ }
+ )
@staticmethod
def deserialize(store, input):
@@ -194,7 +197,7 @@ class EventContext(object):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store,
+ self._fill_out_state, store
)
yield make_deferred_yieldable(self._fetching_state_deferred)
@@ -214,7 +217,7 @@ class EventContext(object):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store,
+ self._fill_out_state, store
)
yield make_deferred_yieldable(self._fetching_state_deferred)
@@ -240,9 +243,7 @@ class EventContext(object):
if self.state_group is None:
return
- self._current_state_ids = yield store.get_state_ids_for_group(
- self.state_group,
- )
+ self._current_state_ids = yield store.get_state_ids_for_group(self.state_group)
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
@@ -252,8 +253,9 @@ class EventContext(object):
self._prev_state_ids = self._current_state_ids
@defer.inlineCallbacks
- def update_state(self, state_group, prev_state_ids, current_state_ids,
- prev_group, delta_ids):
+ def update_state(
+ self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
+ ):
"""Replace the state in the context
"""
@@ -279,10 +281,7 @@ def _encode_state_dict(state_dict):
if state_dict is None:
return None
- return [
- (etype, state_key, v)
- for (etype, state_key), v in iteritems(state_dict)
- ]
+ return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)]
def _decode_state_dict(input):
@@ -291,4 +290,4 @@ def _decode_state_dict(input):
if input is None:
return None
- return frozendict({(etype, state_key,): v for etype, state_key, v in input})
+ return frozendict({(etype, state_key): v for etype, state_key, v in input})
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 6058077f75..129771f183 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -60,7 +60,9 @@ class SpamChecker(object):
if self.spam_checker is None:
return True
- return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+ return self.spam_checker.user_may_invite(
+ inviter_userid, invitee_userid, room_id
+ )
def user_may_create_room(self, userid):
"""Checks if a given user may create a room
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
new file mode 100644
index 0000000000..8f5d95696b
--- /dev/null
+++ b/synapse/events/third_party_rules.py
@@ -0,0 +1,113 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+
+class ThirdPartyEventRules(object):
+ """Allows server admins to provide a Python module implementing an extra
+ set of rules to apply when processing events.
+
+ This is designed to help admins of closed federations with enforcing custom
+ behaviours.
+ """
+
+ def __init__(self, hs):
+ self.third_party_rules = None
+
+ self.store = hs.get_datastore()
+
+ module = None
+ config = None
+ if hs.config.third_party_event_rules:
+ module, config = hs.config.third_party_event_rules
+
+ if module is not None:
+ self.third_party_rules = module(
+ config=config, http_client=hs.get_simple_http_client()
+ )
+
+ @defer.inlineCallbacks
+ def check_event_allowed(self, event, context):
+ """Check if a provided event should be allowed in the given context.
+
+ Args:
+ event (synapse.events.EventBase): The event to be checked.
+ context (synapse.events.snapshot.EventContext): The context of the event.
+
+ Returns:
+ defer.Deferred[bool]: True if the event should be allowed, False if not.
+ """
+ if self.third_party_rules is None:
+ defer.returnValue(True)
+
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
+ # Retrieve the state events from the database.
+ state_events = {}
+ for key, event_id in prev_state_ids.items():
+ state_events[key] = yield self.store.get_event(event_id, allow_none=True)
+
+ ret = yield self.third_party_rules.check_event_allowed(event, state_events)
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def on_create_room(self, requester, config, is_requester_admin):
+ """Intercept requests to create room to allow, deny or update the
+ request config.
+
+ Args:
+ requester (Requester)
+ config (dict): The creation config from the client.
+ is_requester_admin (bool): If the requester is an admin
+
+ Returns:
+ defer.Deferred
+ """
+
+ if self.third_party_rules is None:
+ return
+
+ yield self.third_party_rules.on_create_room(
+ requester, config, is_requester_admin
+ )
+
+ @defer.inlineCallbacks
+ def check_threepid_can_be_invited(self, medium, address, room_id):
+ """Check if a provided 3PID can be invited in the given room.
+
+ Args:
+ medium (str): The 3PID's medium.
+ address (str): The 3PID's address.
+ room_id (str): The room we want to invite the threepid to.
+
+ Returns:
+ defer.Deferred[bool], True if the 3PID can be invited, False if not.
+ """
+
+ if self.third_party_rules is None:
+ defer.returnValue(True)
+
+ state_ids = yield self.store.get_filtered_current_state_ids(room_id)
+ room_state_events = yield self.store.get_events(state_ids.values())
+
+ state_events = {}
+ for key, event_id in state_ids.items():
+ state_events[key] = room_state_events[event_id]
+
+ ret = yield self.third_party_rules.check_threepid_can_be_invited(
+ medium, address, state_events
+ )
+ defer.returnValue(ret)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index e2d4384de1..f24f0c16f0 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -31,7 +31,7 @@ from . import EventBase
# by a match for 'stuff'.
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
-SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
+SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
def prune_event(event):
@@ -51,6 +51,7 @@ def prune_event(event):
pruned_event_dict = prune_event_dict(event.get_dict())
from . import event_type_from_format_version
+
return event_type_from_format_version(event.format_version)(
pruned_event_dict, event.internal_metadata.get_dict()
)
@@ -116,11 +117,7 @@ def prune_event_dict(event_dict):
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
- allowed_fields = {
- k: v
- for k, v in event_dict.items()
- if k in allowed_keys
- }
+ allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
allowed_fields["content"] = new_content
@@ -205,7 +202,7 @@ def only_fields(dictionary, fields):
# for each element of the output array of arrays:
# remove escaping so we can use the right key names.
split_fields[:] = [
- [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
+ [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
]
output = {}
@@ -226,7 +223,10 @@ def format_event_for_client_v1(d):
d["user_id"] = sender
copy_keys = (
- "age", "redacted_because", "replaces_state", "prev_content",
+ "age",
+ "redacted_because",
+ "replaces_state",
+ "prev_content",
"invite_room_state",
)
for key in copy_keys:
@@ -238,8 +238,13 @@ def format_event_for_client_v1(d):
def format_event_for_client_v2(d):
drop_keys = (
- "auth_events", "prev_events", "hashes", "signatures", "depth",
- "origin", "prev_state",
+ "auth_events",
+ "prev_events",
+ "hashes",
+ "signatures",
+ "depth",
+ "origin",
+ "prev_state",
)
for key in drop_keys:
d.pop(key, None)
@@ -252,9 +257,15 @@ def format_event_for_client_v2_without_room_id(d):
return d
-def serialize_event(e, time_now_ms, as_client_event=True,
- event_format=format_event_for_client_v1,
- token_id=None, only_event_fields=None, is_invite=False):
+def serialize_event(
+ e,
+ time_now_ms,
+ as_client_event=True,
+ event_format=format_event_for_client_v1,
+ token_id=None,
+ only_event_fields=None,
+ is_invite=False,
+):
"""Serialize event for clients
Args:
@@ -288,8 +299,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if "redacted_because" in e.unsigned:
d["unsigned"]["redacted_because"] = serialize_event(
- e.unsigned["redacted_because"], time_now_ms,
- event_format=event_format
+ e.unsigned["redacted_because"], time_now_ms, event_format=event_format
)
if token_id is not None:
@@ -308,8 +318,9 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d = event_format(d)
if only_event_fields:
- if (not isinstance(only_event_fields, list) or
- not all(isinstance(f, string_types) for f in only_event_fields)):
+ if not isinstance(only_event_fields, list) or not all(
+ isinstance(f, string_types) for f in only_event_fields
+ ):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
@@ -352,11 +363,9 @@ class EventClientSerializer(object):
# If MSC1849 is enabled then we need to look if thre are any relations
# we need to bundle in with the event
if self.experimental_msc1849_support_enabled and bundle_aggregations:
- annotations = yield self.store.get_aggregation_groups_for_event(
- event_id,
- )
+ annotations = yield self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event(
- event_id, RelationTypes.REFERENCE, direction="f",
+ event_id, RelationTypes.REFERENCE, direction="f"
)
if annotations.chunk:
@@ -383,9 +392,7 @@ class EventClientSerializer(object):
serialized_event["content"].pop("m.relates_to", None)
r = serialized_event["unsigned"].setdefault("m.relations", {})
- r[RelationTypes.REPLACE] = {
- "event_id": edit.event_id,
- }
+ r[RelationTypes.REPLACE] = {"event_id": edit.event_id}
defer.returnValue(serialized_event)
@@ -401,6 +408,5 @@ class EventClientSerializer(object):
Deferred[list[dict]]: The list of serialized events
"""
return yieldable_gather_results(
- self.serialize_event, events,
- time_now=time_now, **kwargs
+ self.serialize_event, events, time_now=time_now, **kwargs
)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 711af512b2..f7ffd1d561 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -48,9 +48,7 @@ class EventValidator(object):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
- event_strings = [
- "origin",
- ]
+ event_strings = ["origin"]
for s in event_strings:
if not isinstance(getattr(event, s), string_types):
@@ -62,8 +60,10 @@ class EventValidator(object):
if len(alias) > MAX_ALIAS_LENGTH:
raise SynapseError(
400,
- ("Can't create aliases longer than"
- " %d characters" % (MAX_ALIAS_LENGTH,)),
+ (
+ "Can't create aliases longer than"
+ " %d characters" % (MAX_ALIAS_LENGTH,)
+ ),
Codes.INVALID_PARAM,
)
@@ -76,11 +76,7 @@ class EventValidator(object):
event (EventBuilder|FrozenEvent)
"""
- strings = [
- "room_id",
- "sender",
- "type",
- ]
+ strings = ["room_id", "sender", "type"]
if hasattr(event, "state_key"):
strings.append("state_key")
@@ -93,10 +89,7 @@ class EventValidator(object):
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
- strings = [
- "body",
- "msgtype",
- ]
+ strings = ["body", "msgtype"]
self._ensure_strings(event.content, strings)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index fc5cfb7d83..1e925b19e7 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -44,8 +44,9 @@ class FederationBase(object):
self._clock = hs.get_clock()
@defer.inlineCallbacks
- def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
- outlier=False, include_none=False):
+ def _check_sigs_and_hash_and_fetch(
+ self, origin, pdus, room_version, outlier=False, include_none=False
+ ):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@@ -79,9 +80,7 @@ class FederationBase(object):
if not res:
# Check local db.
res = yield self.store.get_event(
- pdu.event_id,
- allow_rejected=True,
- allow_none=True,
+ pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:
@@ -98,23 +97,16 @@ class FederationBase(object):
if not res:
logger.warn(
- "Failed to find copy of %s with valid signature",
- pdu.event_id,
+ "Failed to find copy of %s with valid signature", pdu.event_id
)
defer.returnValue(res)
handle = logcontext.preserve_fn(handle_check_result)
- deferreds2 = [
- handle(pdu, deferred)
- for pdu, deferred in zip(pdus, deferreds)
- ]
+ deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
valid_pdus = yield logcontext.make_deferred_yieldable(
- defer.gatherResults(
- deferreds2,
- consumeErrors=True,
- )
+ defer.gatherResults(deferreds2, consumeErrors=True)
).addErrback(unwrapFirstError)
if include_none:
@@ -124,7 +116,7 @@ class FederationBase(object):
def _check_sigs_and_hash(self, room_version, pdu):
return logcontext.make_deferred_yieldable(
- self._check_sigs_and_hashes(room_version, [pdu])[0],
+ self._check_sigs_and_hashes(room_version, [pdu])[0]
)
def _check_sigs_and_hashes(self, room_version, pdus):
@@ -159,11 +151,9 @@ class FederationBase(object):
# received event was probably a redacted copy (but we then use our
# *actual* redacted copy to be on the safe side.)
redacted_event = prune_event(pdu)
- if (
- set(redacted_event.keys()) == set(pdu.keys()) and
- set(six.iterkeys(redacted_event.content))
- == set(six.iterkeys(pdu.content))
- ):
+ if set(redacted_event.keys()) == set(pdu.keys()) and set(
+ six.iterkeys(redacted_event.content)
+ ) == set(six.iterkeys(pdu.content)):
logger.info(
"Event %s seems to have been redacted; using our redacted "
"copy",
@@ -172,14 +162,15 @@ class FederationBase(object):
else:
logger.warning(
"Event %s content has been tampered, redacting",
- pdu.event_id, pdu.get_pdu_json(),
+ pdu.event_id,
)
return redacted_event
if self.spam_checker.check_event_for_spam(pdu):
logger.warn(
"Event contains spam, redacting %s: %s",
- pdu.event_id, pdu.get_pdu_json()
+ pdu.event_id,
+ pdu.get_pdu_json(),
)
return prune_event(pdu)
@@ -190,23 +181,24 @@ class FederationBase(object):
with logcontext.PreserveLoggingContext(ctx):
logger.warn(
"Signature check failed for %s: %s",
- pdu.event_id, failure.getErrorMessage(),
+ pdu.event_id,
+ failure.getErrorMessage(),
)
return failure
for deferred, pdu in zip(deferreds, pdus):
deferred.addCallbacks(
- callback, errback,
- callbackArgs=[pdu],
- errbackArgs=[pdu],
+ callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
)
return deferreds
-class PduToCheckSig(namedtuple("PduToCheckSig", [
- "pdu", "redacted_pdu_json", "sender_domain", "deferreds",
-])):
+class PduToCheckSig(
+ namedtuple(
+ "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
+ )
+):
pass
@@ -260,10 +252,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
- pdus_to_check_sender = [
- p for p in pdus_to_check
- if not _is_invite_via_3pid(p.pdu)
- ]
+ pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
more_deferreds = keyring.verify_json_objects_for_server(
[
@@ -297,7 +286,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
# (ie, the room version uses old-style non-hash event IDs).
if v.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
- p for p in pdus_to_check
+ p
+ for p in pdus_to_check
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
@@ -315,10 +305,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
def event_err(e, pdu_to_check):
errmsg = (
- "event id %s: unable to verify signature for event id domain: %s" % (
- pdu_to_check.pdu.event_id,
- e.getErrorMessage(),
- )
+ "event id %s: unable to verify signature for event id domain: %s"
+ % (pdu_to_check.pdu.event_id, e.getErrorMessage())
)
# XX as above: not really sure if these are the right codes
raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)
@@ -368,21 +356,18 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
- assert_params_in_dict(pdu_json, ('type', 'depth'))
+ assert_params_in_dict(pdu_json, ("type", "depth"))
- depth = pdu_json['depth']
+ depth = pdu_json["depth"]
if not isinstance(depth, six.integer_types):
- raise SynapseError(400, "Depth %r not an intger" % (depth, ),
- Codes.BAD_JSON)
+ raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON)
if depth < 0:
raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
elif depth > MAX_DEPTH:
raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
- event = event_type_from_format_version(event_format_version)(
- pdu_json,
- )
+ event = event_type_from_format_version(event_format_version)(pdu_json)
event.internal_metadata.outlier = outlier
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 70573746d6..3883eb525e 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -57,6 +57,7 @@ class InvalidResponseError(RuntimeError):
"""Helper for _try_destination_list: indicates that the server returned a response
we couldn't parse
"""
+
pass
@@ -65,9 +66,7 @@ class FederationClient(FederationBase):
super(FederationClient, self).__init__(hs)
self.pdu_destination_tried = {}
- self._clock.looping_call(
- self._clear_tried_cache, 60 * 1000,
- )
+ self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
@@ -99,8 +98,14 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict
@log_function
- def make_query(self, destination, query_type, args,
- retry_on_dns_fail=False, ignore_backoff=False):
+ def make_query(
+ self,
+ destination,
+ query_type,
+ args,
+ retry_on_dns_fail=False,
+ ignore_backoff=False,
+ ):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
@@ -120,7 +125,10 @@ class FederationClient(FederationBase):
sent_queries_counter.labels(query_type).inc()
return self.transport_layer.make_query(
- destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
+ destination,
+ query_type,
+ args,
+ retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
@@ -137,9 +145,7 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.labels("client_device_keys").inc()
- return self.transport_layer.query_client_keys(
- destination, content, timeout
- )
+ return self.transport_layer.query_client_keys(destination, content, timeout)
@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
@@ -147,9 +153,7 @@ class FederationClient(FederationBase):
server.
"""
sent_queries_counter.labels("user_devices").inc()
- return self.transport_layer.query_user_devices(
- destination, user_id, timeout
- )
+ return self.transport_layer.query_user_devices(destination, user_id, timeout)
@log_function
def claim_client_keys(self, destination, content, timeout):
@@ -164,9 +168,7 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
- return self.transport_layer.claim_client_keys(
- destination, content, timeout
- )
+ return self.transport_layer.claim_client_keys(destination, content, timeout)
@defer.inlineCallbacks
@log_function
@@ -191,7 +193,8 @@ class FederationClient(FederationBase):
return
transaction_data = yield self.transport_layer.backfill(
- dest, room_id, extremities, limit)
+ dest, room_id, extremities, limit
+ )
logger.debug("backfill transaction_data=%s", repr(transaction_data))
@@ -204,17 +207,19 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- self._check_sigs_and_hashes(room_version, pdus),
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
+ pdus[:] = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True
+ ).addErrback(unwrapFirstError)
+ )
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
- def get_pdu(self, destinations, event_id, room_version, outlier=False,
- timeout=None):
+ def get_pdu(
+ self, destinations, event_id, room_version, outlier=False, timeout=None
+ ):
"""Requests the PDU with given origin and ID from the remote home
servers.
@@ -255,7 +260,7 @@ class FederationClient(FederationBase):
try:
transaction_data = yield self.transport_layer.get_event(
- destination, event_id, timeout=timeout,
+ destination, event_id, timeout=timeout
)
logger.debug(
@@ -282,8 +287,7 @@ class FederationClient(FederationBase):
except SynapseError as e:
logger.info(
- "Failed to get PDU %s from %s because %s",
- event_id, destination, e,
+ "Failed to get PDU %s from %s because %s", event_id, destination, e
)
continue
except NotRetryingDestination as e:
@@ -296,8 +300,7 @@ class FederationClient(FederationBase):
pdu_attempts[destination] = now
logger.info(
- "Failed to get PDU %s from %s because %s",
- event_id, destination, e,
+ "Failed to get PDU %s from %s because %s", event_id, destination, e
)
continue
@@ -326,7 +329,7 @@ class FederationClient(FederationBase):
# we have most of the state and auth_chain already.
# However, this may 404 if the other side has an old synapse.
result = yield self.transport_layer.get_room_state_ids(
- destination, room_id, event_id=event_id,
+ destination, room_id, event_id=event_id
)
state_event_ids = result["pdu_ids"]
@@ -340,12 +343,10 @@ class FederationClient(FederationBase):
logger.warning(
"Failed to fetch missing state/auth events for %s: %s",
room_id,
- failed_to_fetch
+ failed_to_fetch,
)
- event_map = {
- ev.event_id: ev for ev in fetched_events
- }
+ event_map = {ev.event_id: ev for ev in fetched_events}
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
auth_chain = [
@@ -362,15 +363,14 @@ class FederationClient(FederationBase):
raise e
result = yield self.transport_layer.get_room_state(
- destination, room_id, event_id=event_id,
+ destination, room_id, event_id=event_id
)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
pdus = [
- event_from_pdu_json(p, format_ver, outlier=True)
- for p in result["pdus"]
+ event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"]
]
auth_chain = [
@@ -378,9 +378,9 @@ class FederationClient(FederationBase):
for p in result.get("auth_chain", [])
]
- seen_events = yield self.store.get_events([
- ev.event_id for ev in itertools.chain(pdus, auth_chain)
- ])
+ seen_events = yield self.store.get_events(
+ [ev.event_id for ev in itertools.chain(pdus, auth_chain)]
+ )
signed_pdus = yield self._check_sigs_and_hash_and_fetch(
destination,
@@ -442,7 +442,7 @@ class FederationClient(FederationBase):
batch_size = 20
missing_events = list(missing_events)
for i in range(0, len(missing_events), batch_size):
- batch = set(missing_events[i:i + batch_size])
+ batch = set(missing_events[i : i + batch_size])
deferreds = [
run_in_background(
@@ -470,21 +470,17 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
- res = yield self.transport_layer.get_event_auth(
- destination, room_id, event_id,
- )
+ res = yield self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
auth_chain = [
- event_from_pdu_json(p, format_ver, outlier=True)
- for p in res["auth_chain"]
+ event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain,
- outlier=True, room_version=room_version,
+ destination, auth_chain, outlier=True, room_version=room_version
)
signed_auth.sort(key=lambda e: e.depth)
@@ -527,28 +523,26 @@ class FederationClient(FederationBase):
res = yield callback(destination)
defer.returnValue(res)
except InvalidResponseError as e:
- logger.warn(
- "Failed to %s via %s: %s",
- description, destination, e,
- )
+ logger.warn("Failed to %s via %s: %s", description, destination, e)
except HttpResponseException as e:
if not 500 <= e.code < 600:
raise e.to_synapse_error()
else:
logger.warn(
"Failed to %s via %s: %i %s",
- description, destination, e.code, e.args[0],
+ description,
+ destination,
+ e.code,
+ e.args[0],
)
except Exception:
- logger.warn(
- "Failed to %s via %s",
- description, destination, exc_info=1,
- )
+ logger.warn("Failed to %s via %s", description, destination, exc_info=1)
- raise RuntimeError("Failed to %s via any server" % (description, ))
+ raise RuntimeError("Failed to %s via any server" % (description,))
- def make_membership_event(self, destinations, room_id, user_id, membership,
- content, params):
+ def make_membership_event(
+ self, destinations, room_id, user_id, membership, content, params
+ ):
"""
Creates an m.room.member event, with context, without participating in the room.
@@ -584,14 +578,14 @@ class FederationClient(FederationBase):
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
- "make_membership_event called with membership='%s', must be one of %s" %
- (membership, ",".join(valid_memberships))
+ "make_membership_event called with membership='%s', must be one of %s"
+ % (membership, ",".join(valid_memberships))
)
@defer.inlineCallbacks
def send_request(destination):
ret = yield self.transport_layer.make_membership_event(
- destination, room_id, user_id, membership, params,
+ destination, room_id, user_id, membership, params
)
# Note: If not supplied, the room version may be either v1 or v2,
@@ -614,16 +608,17 @@ class FederationClient(FederationBase):
pdu_dict["prev_state"] = []
ev = builder.create_local_event_from_event_dict(
- self._clock, self.hostname, self.signing_key,
- format_version=event_format, event_dict=pdu_dict,
+ self._clock,
+ self.hostname,
+ self.signing_key,
+ format_version=event_format,
+ event_dict=pdu_dict,
)
- defer.returnValue(
- (destination, ev, event_format)
- )
+ defer.returnValue((destination, ev, event_format))
return self._try_destination_list(
- "make_" + membership, destinations, send_request,
+ "make_" + membership, destinations, send_request
)
def send_join(self, destinations, pdu, event_format_version):
@@ -655,9 +650,7 @@ class FederationClient(FederationBase):
create_event = e
break
else:
- raise InvalidResponseError(
- "no %s in auth chain" % (EventTypes.Create,),
- )
+ raise InvalidResponseError("no %s in auth chain" % (EventTypes.Create,))
# the room version should be sane.
room_version = create_event.content.get("room_version", "1")
@@ -665,9 +658,8 @@ class FederationClient(FederationBase):
# This shouldn't be possible, because the remote server should have
# rejected the join attempt during make_join.
raise InvalidResponseError(
- "room appears to have unsupported version %s" % (
- room_version,
- ))
+ "room appears to have unsupported version %s" % (room_version,)
+ )
@defer.inlineCallbacks
def send_request(destination):
@@ -691,10 +683,7 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
- pdus = {
- p.event_id: p
- for p in itertools.chain(state, auth_chain)
- }
+ pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
room_version = None
for e in state:
@@ -710,15 +699,13 @@ class FederationClient(FederationBase):
raise SynapseError(400, "No create event in state")
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination, list(pdus.values()),
+ destination,
+ list(pdus.values()),
outlier=True,
room_version=room_version,
)
- valid_pdus_map = {
- p.event_id: p
- for p in valid_pdus
- }
+ valid_pdus_map = {p.event_id: p for p in valid_pdus}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
@@ -741,11 +728,14 @@ class FederationClient(FederationBase):
check_authchain_validity(signed_auth)
- defer.returnValue({
- "state": signed_state,
- "auth_chain": signed_auth,
- "origin": destination,
- })
+ defer.returnValue(
+ {
+ "state": signed_state,
+ "auth_chain": signed_auth,
+ "origin": destination,
+ }
+ )
+
return self._try_destination_list("send_join", destinations, send_request)
@defer.inlineCallbacks
@@ -854,6 +844,7 @@ class FederationClient(FederationBase):
Fails with a ``RuntimeError`` if no servers were reachable.
"""
+
@defer.inlineCallbacks
def send_request(destination):
time_now = self._clock.time_msec()
@@ -869,14 +860,23 @@ class FederationClient(FederationBase):
return self._try_destination_list("send_leave", destinations, send_request)
- def get_public_rooms(self, destination, limit=None, since_token=None,
- search_filter=None, include_all_networks=False,
- third_party_instance_id=None):
+ def get_public_rooms(
+ self,
+ destination,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ):
if destination == self.server_name:
return
return self.transport_layer.get_public_rooms(
- destination, limit, since_token, search_filter,
+ destination,
+ limit,
+ since_token,
+ search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
@@ -891,9 +891,7 @@ class FederationClient(FederationBase):
"""
time_now = self._clock.time_msec()
- send_content = {
- "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
- }
+ send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]}
code, content = yield self.transport_layer.send_query_auth(
destination=destination,
@@ -905,13 +903,10 @@ class FederationClient(FederationBase):
room_version = yield self.store.get_room_version(room_id)
format_ver = room_version_to_event_format(room_version)
- auth_chain = [
- event_from_pdu_json(e, format_ver)
- for e in content["auth_chain"]
- ]
+ auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True, room_version=room_version,
+ destination, auth_chain, outlier=True, room_version=room_version
)
signed_auth.sort(key=lambda e: e.depth)
@@ -925,8 +920,16 @@ class FederationClient(FederationBase):
defer.returnValue(ret)
@defer.inlineCallbacks
- def get_missing_events(self, destination, room_id, earliest_events_ids,
- latest_events, limit, min_depth, timeout):
+ def get_missing_events(
+ self,
+ destination,
+ room_id,
+ earliest_events_ids,
+ latest_events,
+ limit,
+ min_depth,
+ timeout,
+ ):
"""Tries to fetch events we are missing. This is called when we receive
an event without having received all of its ancestors.
@@ -957,12 +960,11 @@ class FederationClient(FederationBase):
format_ver = room_version_to_event_format(room_version)
events = [
- event_from_pdu_json(e, format_ver)
- for e in content.get("events", [])
+ event_from_pdu_json(e, format_ver) for e in content.get("events", [])
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
- destination, events, outlier=False, room_version=room_version,
+ destination, events, outlier=False, room_version=room_version
)
except HttpResponseException as e:
if not e.code == 400:
@@ -982,17 +984,14 @@ class FederationClient(FederationBase):
try:
yield self.transport_layer.exchange_third_party_invite(
- destination=destination,
- room_id=room_id,
- event_dict=event_dict,
+ destination=destination, room_id=room_id, event_dict=event_dict
)
defer.returnValue(None)
except CodeMessageException:
raise
except Exception as e:
logger.exception(
- "Failed to send_third_party_invite via %s: %s",
- destination, str(e)
+ "Failed to send_third_party_invite via %s: %s", destination, str(e)
)
raise RuntimeError("Failed to send to any server.")
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 4c28c1dc3c..2e0cebb638 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -69,7 +69,6 @@ received_queries_counter = Counter(
class FederationServer(FederationBase):
-
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
@@ -118,11 +117,13 @@ class FederationServer(FederationBase):
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
- with (yield self._transaction_linearizer.queue(
- (origin, transaction.transaction_id),
- )):
+ with (
+ yield self._transaction_linearizer.queue(
+ (origin, transaction.transaction_id)
+ )
+ ):
result = yield self._handle_incoming_transaction(
- origin, transaction, request_time,
+ origin, transaction, request_time
)
defer.returnValue(result)
@@ -144,7 +145,7 @@ class FederationServer(FederationBase):
if response:
logger.debug(
"[%s] We've already responded to this request",
- transaction.transaction_id
+ transaction.transaction_id,
)
defer.returnValue(response)
return
@@ -152,18 +153,15 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
# Reject if PDU count > 50 and EDU count > 100
- if (len(transaction.pdus) > 50
- or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
+ if len(transaction.pdus) > 50 or (
+ hasattr(transaction, "edus") and len(transaction.edus) > 100
+ ):
- logger.info(
- "Transaction PDU or EDU count too large. Returning 400",
- )
+ logger.info("Transaction PDU or EDU count too large. Returning 400")
response = {}
yield self.transaction_actions.set_response(
- origin,
- transaction,
- 400, response
+ origin, transaction, 400, response
)
defer.returnValue((400, response))
@@ -230,9 +228,7 @@ class FederationServer(FederationBase):
try:
yield self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
- logger.warn(
- "Ignoring PDUs for room %s from banned server", room_id,
- )
+ logger.warn("Ignoring PDUs for room %s from banned server", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
@@ -242,9 +238,7 @@ class FederationServer(FederationBase):
event_id = pdu.event_id
with nested_logging_context(event_id):
try:
- yield self._handle_received_pdu(
- origin, pdu
- )
+ yield self._handle_received_pdu(origin, pdu)
pdu_results[event_id] = {}
except FederationError as e:
logger.warn("Error handling PDU %s: %s", event_id, e)
@@ -259,29 +253,18 @@ class FederationServer(FederationBase):
)
yield concurrently_execute(
- process_pdus_for_room, pdus_by_room.keys(),
- TRANSACTION_CONCURRENCY_LIMIT,
+ process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
if hasattr(transaction, "edus"):
for edu in (Edu(**x) for x in transaction.edus):
- yield self.received_edu(
- origin,
- edu.edu_type,
- edu.content
- )
+ yield self.received_edu(origin, edu.edu_type, edu.content)
- response = {
- "pdus": pdu_results,
- }
+ response = {"pdus": pdu_results}
logger.debug("Returning: %s", str(response))
- yield self.transaction_actions.set_response(
- origin,
- transaction,
- 200, response
- )
+ yield self.transaction_actions.set_response(origin, transaction, 200, response)
defer.returnValue((200, response))
@defer.inlineCallbacks
@@ -311,7 +294,8 @@ class FederationServer(FederationBase):
resp = yield self._state_resp_cache.wrap(
(room_id, event_id),
self._on_context_state_request_compute,
- room_id, event_id,
+ room_id,
+ event_id,
)
defer.returnValue((200, resp))
@@ -328,24 +312,17 @@ class FederationServer(FederationBase):
if not in_room:
raise AuthError(403, "Host not in room.")
- state_ids = yield self.handler.get_state_ids_for_pdu(
- room_id, event_id,
- )
+ state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
- defer.returnValue((200, {
- "pdu_ids": state_ids,
- "auth_chain_ids": auth_chain_ids,
- }))
+ defer.returnValue(
+ (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids})
+ )
@defer.inlineCallbacks
def _on_context_state_request_compute(self, room_id, event_id):
- pdus = yield self.handler.get_state_for_pdu(
- room_id, event_id,
- )
- auth_chain = yield self.store.get_auth_chain(
- [pdu.event_id for pdu in pdus]
- )
+ pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
+ auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
for event in auth_chain:
# We sign these again because there was a bug where we
@@ -355,14 +332,16 @@ class FederationServer(FederationBase):
compute_event_signature(
event.get_pdu_json(),
self.hs.hostname,
- self.hs.config.signing_key[0]
+ self.hs.config.signing_key[0],
)
)
- defer.returnValue({
- "pdus": [pdu.get_pdu_json() for pdu in pdus],
- "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
- })
+ defer.returnValue(
+ {
+ "pdus": [pdu.get_pdu_json() for pdu in pdus],
+ "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
+ }
+ )
@defer.inlineCallbacks
@log_function
@@ -370,9 +349,7 @@ class FederationServer(FederationBase):
pdu = yield self.handler.get_persisted_pdu(origin, event_id)
if pdu:
- defer.returnValue(
- (200, self._transaction_from_pdus([pdu]).get_dict())
- )
+ defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict()))
else:
defer.returnValue((404, ""))
@@ -394,10 +371,9 @@ class FederationServer(FederationBase):
pdu = yield self.handler.on_make_join_request(room_id, user_id)
time_now = self._clock.time_msec()
- defer.returnValue({
- "event": pdu.get_pdu_json(time_now),
- "room_version": room_version,
- })
+ defer.returnValue(
+ {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+ )
@defer.inlineCallbacks
def on_invite_request(self, origin, content, room_version):
@@ -431,12 +407,17 @@ class FederationServer(FederationBase):
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
- defer.returnValue((200, {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [
- p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
- ],
- }))
+ defer.returnValue(
+ (
+ 200,
+ {
+ "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+ "auth_chain": [
+ p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
+ ],
+ },
+ )
+ )
@defer.inlineCallbacks
def on_make_leave_request(self, origin, room_id, user_id):
@@ -447,10 +428,9 @@ class FederationServer(FederationBase):
room_version = yield self.store.get_room_version(room_id)
time_now = self._clock.time_msec()
- defer.returnValue({
- "event": pdu.get_pdu_json(time_now),
- "room_version": room_version,
- })
+ defer.returnValue(
+ {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+ )
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content, room_id):
@@ -475,9 +455,7 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
- res = {
- "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
- }
+ res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
defer.returnValue((200, res))
@defer.inlineCallbacks
@@ -508,12 +486,11 @@ class FederationServer(FederationBase):
format_ver = room_version_to_event_format(room_version)
auth_chain = [
- event_from_pdu_json(e, format_ver)
- for e in content["auth_chain"]
+ event_from_pdu_json(e, format_ver) for e in content["auth_chain"]
]
signed_auth = yield self._check_sigs_and_hash_and_fetch(
- origin, auth_chain, outlier=True, room_version=room_version,
+ origin, auth_chain, outlier=True, room_version=room_version
)
ret = yield self.handler.on_query_auth(
@@ -527,17 +504,12 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
send_content = {
- "auth_chain": [
- e.get_pdu_json(time_now)
- for e in ret["auth_chain"]
- ],
+ "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]],
"rejects": ret.get("rejects", []),
"missing": ret.get("missing", []),
}
- defer.returnValue(
- (200, send_content)
- )
+ defer.returnValue((200, send_content))
@log_function
def on_query_client_keys(self, origin, content):
@@ -566,20 +538,23 @@ class FederationServer(FederationBase):
logger.info(
"Claimed one-time-keys: %s",
- ",".join((
- "%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in iteritems(json_result)
- for device_id, device_keys in iteritems(user_keys)
- for key_id, _ in iteritems(device_keys)
- )),
+ ",".join(
+ (
+ "%s for %s:%s" % (key_id, user_id, device_id)
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
+ )
+ ),
)
defer.returnValue({"one_time_keys": json_result})
@defer.inlineCallbacks
@log_function
- def on_get_missing_events(self, origin, room_id, earliest_events,
- latest_events, limit):
+ def on_get_missing_events(
+ self, origin, room_id, earliest_events, latest_events, limit
+ ):
with (yield self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
yield self.check_server_matches_acl(origin_host, room_id)
@@ -587,11 +562,13 @@ class FederationServer(FederationBase):
logger.info(
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
" limit: %d",
- earliest_events, latest_events, limit,
+ earliest_events,
+ latest_events,
+ limit,
)
missing_events = yield self.handler.on_get_missing_events(
- origin, room_id, earliest_events, latest_events, limit,
+ origin, room_id, earliest_events, latest_events, limit
)
if len(missing_events) < 5:
@@ -603,9 +580,9 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
- defer.returnValue({
- "events": [ev.get_pdu_json(time_now) for ev in missing_events],
- })
+ defer.returnValue(
+ {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
+ )
@log_function
def on_openid_userinfo(self, token):
@@ -666,22 +643,17 @@ class FederationServer(FederationBase):
# origin. See bug #1893. This is also true for some third party
# invites).
if not (
- pdu.type == 'm.room.member' and
- pdu.content and
- pdu.content.get("membership", None) in (
- Membership.JOIN, Membership.INVITE,
- )
+ pdu.type == "m.room.member"
+ and pdu.content
+ and pdu.content.get("membership", None)
+ in (Membership.JOIN, Membership.INVITE)
):
logger.info(
- "Discarding PDU %s from invalid origin %s",
- pdu.event_id, origin
+ "Discarding PDU %s from invalid origin %s", pdu.event_id, origin
)
return
else:
- logger.info(
- "Accepting join PDU %s from %s",
- pdu.event_id, origin
- )
+ logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
room_version = yield self.store.get_room_version(pdu.room_id)
@@ -690,33 +662,19 @@ class FederationServer(FederationBase):
try:
pdu = yield self._check_sigs_and_hash(room_version, pdu)
except SynapseError as e:
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=pdu.event_id,
- )
+ raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
- yield self.handler.on_receive_pdu(
- origin, pdu, sent_to_us_directly=True,
- )
+ yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
@defer.inlineCallbacks
def exchange_third_party_invite(
- self,
- sender_user_id,
- target_user_id,
- room_id,
- signed,
+ self, sender_user_id, target_user_id, room_id, signed
):
ret = yield self.handler.exchange_third_party_invite(
- sender_user_id,
- target_user_id,
- room_id,
- signed,
+ sender_user_id, target_user_id, room_id, signed
)
defer.returnValue(ret)
@@ -771,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event):
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
- if server_name[0] == '[':
+ if server_name[0] == "[":
return False
# check for ipv4 literals. We can just lift the routine from twisted.
@@ -805,7 +763,9 @@ def server_matches_acl_event(server_name, acl_event):
def _acl_entry_matches(server_name, acl_entry):
if not isinstance(acl_entry, six.string_types):
- logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
+ logger.warn(
+ "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
+ )
return False
regex = glob_to_regex(acl_entry)
return regex.match(server_name)
@@ -815,6 +775,7 @@ class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
"""
+
def __init__(self):
self.edu_handlers = {}
self.query_handlers = {}
@@ -848,9 +809,7 @@ class FederationHandlerRegistry(object):
on and the result used as the response to the query request.
"""
if query_type in self.query_handlers:
- raise KeyError(
- "Already have a Query handler for %s" % (query_type,)
- )
+ raise KeyError("Already have a Query handler for %s" % (query_type,))
logger.info("Registering federation query handler for %r", query_type)
@@ -905,14 +864,10 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
handler = self.edu_handlers.get(edu_type)
if handler:
return super(ReplicationFederationHandlerRegistry, self).on_edu(
- edu_type, origin, content,
+ edu_type, origin, content
)
- return self._send_edu(
- edu_type=edu_type,
- origin=origin,
- content=content,
- )
+ return self._send_edu(edu_type=edu_type, origin=origin, content=content)
def on_query(self, query_type, args):
"""Overrides FederationHandlerRegistry
@@ -921,7 +876,4 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
if handler:
return handler(args)
- return self._get_query_client(
- query_type=query_type,
- args=args,
- )
+ return self._get_query_client(query_type=query_type, args=args)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 74ffd13b4f..7535f79203 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -46,12 +46,9 @@ class TransactionActions(object):
response code and response body.
"""
if not transaction.transaction_id:
- raise RuntimeError("Cannot persist a transaction with no "
- "transaction_id")
+ raise RuntimeError("Cannot persist a transaction with no " "transaction_id")
- return self.store.get_received_txn_response(
- transaction.transaction_id, origin
- )
+ return self.store.get_received_txn_response(transaction.transaction_id, origin)
@log_function
def set_response(self, origin, transaction, code, response):
@@ -61,14 +58,10 @@ class TransactionActions(object):
Deferred
"""
if not transaction.transaction_id:
- raise RuntimeError("Cannot persist a transaction with no "
- "transaction_id")
+ raise RuntimeError("Cannot persist a transaction with no " "transaction_id")
return self.store.set_received_txn_response(
- transaction.transaction_id,
- origin,
- code,
- response,
+ transaction.transaction_id, origin, code, response
)
@defer.inlineCallbacks
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 0240b339b0..454456a52d 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -77,12 +77,22 @@ class FederationRemoteSendQueue(object):
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
def register(name, queue):
- LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,),
- "", [], lambda: len(queue))
+ LaterGauge(
+ "synapse_federation_send_queue_%s_size" % (queue_name,),
+ "",
+ [],
+ lambda: len(queue),
+ )
for queue_name in [
- "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
- "edus", "device_messages", "pos_time", "presence_destinations",
+ "presence_map",
+ "presence_changed",
+ "keyed_edu",
+ "keyed_edu_changed",
+ "edus",
+ "device_messages",
+ "pos_time",
+ "presence_destinations",
]:
register(queue_name, getattr(self, queue_name))
@@ -121,9 +131,7 @@ class FederationRemoteSendQueue(object):
del self.presence_changed[key]
user_ids = set(
- user_id
- for uids in self.presence_changed.values()
- for user_id in uids
+ user_id for uids in self.presence_changed.values() for user_id in uids
)
keys = self.presence_destinations.keys()
@@ -285,19 +293,21 @@ class FederationRemoteSendQueue(object):
]
for (key, user_id) in dest_user_ids:
- rows.append((key, PresenceRow(
- state=self.presence_map[user_id],
- )))
+ rows.append((key, PresenceRow(state=self.presence_map[user_id])))
# Fetch presence to send to destinations
i = self.presence_destinations.bisect_right(from_token)
j = self.presence_destinations.bisect_right(to_token) + 1
for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
- rows.append((pos, PresenceDestinationsRow(
- state=self.presence_map[user_id],
- destinations=list(dests),
- )))
+ rows.append(
+ (
+ pos,
+ PresenceDestinationsRow(
+ state=self.presence_map[user_id], destinations=list(dests)
+ ),
+ )
+ )
# Fetch changes keyed edus
i = self.keyed_edu_changed.bisect_right(from_token)
@@ -308,10 +318,14 @@ class FederationRemoteSendQueue(object):
keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
for ((destination, edu_key), pos) in iteritems(keyed_edus):
- rows.append((pos, KeyedEduRow(
- key=edu_key,
- edu=self.keyed_edu[(destination, edu_key)],
- )))
+ rows.append(
+ (
+ pos,
+ KeyedEduRow(
+ key=edu_key, edu=self.keyed_edu[(destination, edu_key)]
+ ),
+ )
+ )
# Fetch changed edus
i = self.edus.bisect_right(from_token)
@@ -327,9 +341,7 @@ class FederationRemoteSendQueue(object):
device_messages = {v: k for k, v in self.device_messages.items()[i:j]}
for (destination, pos) in iteritems(device_messages):
- rows.append((pos, DeviceRow(
- destination=destination,
- )))
+ rows.append((pos, DeviceRow(destination=destination)))
# Sort rows based on pos
rows.sort()
@@ -377,16 +389,14 @@ class BaseFederationRow(object):
raise NotImplementedError()
-class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
- "state", # UserPresenceState
-))):
+class PresenceRow(
+ BaseFederationRow, namedtuple("PresenceRow", ("state",)) # UserPresenceState
+):
TypeId = "p"
@staticmethod
def from_data(data):
- return PresenceRow(
- state=UserPresenceState.from_dict(data)
- )
+ return PresenceRow(state=UserPresenceState.from_dict(data))
def to_data(self):
return self.state.as_dict()
@@ -395,33 +405,35 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
buff.presence.append(self.state)
-class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", (
- "state", # UserPresenceState
- "destinations", # list[str]
-))):
+class PresenceDestinationsRow(
+ BaseFederationRow,
+ namedtuple(
+ "PresenceDestinationsRow",
+ ("state", "destinations"), # UserPresenceState # list[str]
+ ),
+):
TypeId = "pd"
@staticmethod
def from_data(data):
return PresenceDestinationsRow(
- state=UserPresenceState.from_dict(data["state"]),
- destinations=data["dests"],
+ state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
)
def to_data(self):
- return {
- "state": self.state.as_dict(),
- "dests": self.destinations,
- }
+ return {"state": self.state.as_dict(), "dests": self.destinations}
def add_to_buffer(self, buff):
buff.presence_destinations.append((self.state, self.destinations))
-class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
- "key", # tuple(str) - the edu key passed to send_edu
- "edu", # Edu
-))):
+class KeyedEduRow(
+ BaseFederationRow,
+ namedtuple(
+ "KeyedEduRow",
+ ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu
+ ),
+):
"""Streams EDUs that have an associated key that is ued to clobber. For example,
typing EDUs clobber based on room_id.
"""
@@ -430,28 +442,19 @@ class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
@staticmethod
def from_data(data):
- return KeyedEduRow(
- key=tuple(data["key"]),
- edu=Edu(**data["edu"]),
- )
+ return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
def to_data(self):
- return {
- "key": self.key,
- "edu": self.edu.get_internal_dict(),
- }
+ return {"key": self.key, "edu": self.edu.get_internal_dict()}
def add_to_buffer(self, buff):
- buff.keyed_edus.setdefault(
- self.edu.destination, {}
- )[self.key] = self.edu
+ buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
-class EduRow(BaseFederationRow, namedtuple("EduRow", (
- "edu", # Edu
-))):
+class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
"""Streams EDUs that don't have keys. See KeyedEduRow
"""
+
TypeId = "e"
@staticmethod
@@ -465,13 +468,12 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
- "destination", # str
-))):
+class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ("destination",))): # str
"""Streams the fact that either a) there is pending to device messages for
users on the remote, or b) a local users device has changed and needs to
be sent to the remote.
"""
+
TypeId = "d"
@staticmethod
@@ -487,23 +489,20 @@ class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
TypeToRow = {
Row.TypeId: Row
- for Row in (
- PresenceRow,
- PresenceDestinationsRow,
- KeyedEduRow,
- EduRow,
- DeviceRow,
- )
+ for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow)
}
-ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
- "presence", # list(UserPresenceState)
- "presence_destinations", # list of tuples of UserPresenceState and destinations
- "keyed_edus", # dict of destination -> { key -> Edu }
- "edus", # dict of destination -> [Edu]
- "device_destinations", # set of destinations
-))
+ParsedFederationStreamData = namedtuple(
+ "ParsedFederationStreamData",
+ (
+ "presence", # list(UserPresenceState)
+ "presence_destinations", # list of tuples of UserPresenceState and destinations
+ "keyed_edus", # dict of destination -> { key -> Edu }
+ "edus", # dict of destination -> [Edu]
+ "device_destinations", # set of destinations
+ ),
+)
def process_rows_for_federation(transaction_queue, rows):
@@ -542,7 +541,7 @@ def process_rows_for_federation(transaction_queue, rows):
for state, destinations in buff.presence_destinations:
transaction_queue.send_presence_to_destinations(
- states=[state], destinations=destinations,
+ states=[state], destinations=destinations
)
for destination, edu_map in iteritems(buff.keyed_edus):
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4f0f939102..766c5a37cd 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -44,8 +44,8 @@ sent_pdus_destination_dist_count = Counter(
)
sent_pdus_destination_dist_total = Counter(
- "synapse_federation_client_sent_pdu_destinations:total", ""
- "Total number of PDUs queued for sending across all destinations",
+ "synapse_federation_client_sent_pdu_destinations:total",
+ "" "Total number of PDUs queued for sending across all destinations",
)
@@ -63,14 +63,15 @@ class FederationSender(object):
self._transaction_manager = TransactionManager(hs)
# map from destination to PerDestinationQueue
- self._per_destination_queues = {} # type: dict[str, PerDestinationQueue]
+ self._per_destination_queues = {} # type: dict[str, PerDestinationQueue]
LaterGauge(
"synapse_federation_transaction_queue_pending_destinations",
"",
[],
lambda: sum(
- 1 for d in self._per_destination_queues.values()
+ 1
+ for d in self._per_destination_queues.values()
if d.transmission_loop_running
),
)
@@ -108,8 +109,9 @@ class FederationSender(object):
# awaiting a call to flush_read_receipts_for_room. The presence of an entry
# here for a given room means that we are rate-limiting RR flushes to that room,
# and that there is a pending call to _flush_rrs_for_room in the system.
- self._queues_awaiting_rr_flush_by_room = {
- } # type: dict[str, set[PerDestinationQueue]]
+ self._queues_awaiting_rr_flush_by_room = (
+ {}
+ ) # type: dict[str, set[PerDestinationQueue]]
self._rr_txn_interval_per_room_ms = (
1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second
@@ -141,8 +143,7 @@ class FederationSender(object):
# fire off a processing loop in the background
run_as_background_process(
- "process_event_queue_for_federation",
- self._process_event_queue_loop,
+ "process_event_queue_for_federation", self._process_event_queue_loop
)
@defer.inlineCallbacks
@@ -152,7 +153,7 @@ class FederationSender(object):
while True:
last_token = yield self.store.get_federation_out_pos("events")
next_token, events = yield self.store.get_all_new_events_stream(
- last_token, self._last_poked_id, limit=100,
+ last_token, self._last_poked_id, limit=100
)
logger.debug("Handling %s -> %s", last_token, next_token)
@@ -168,6 +169,9 @@ class FederationSender(object):
if not is_mine and send_on_behalf_of is None:
return
+ if not event.internal_metadata.should_proactively_send():
+ return
+
try:
# Get the state from before the event.
# We need to make sure that this is the state from before
@@ -176,7 +180,7 @@ class FederationSender(object):
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = yield self.state.get_current_hosts_in_room(
- event.room_id, latest_event_ids=event.prev_event_ids(),
+ event.room_id, latest_event_ids=event.prev_event_ids()
)
except Exception:
logger.exception(
@@ -206,37 +210,40 @@ class FederationSender(object):
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
- yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.run_in_background(handle_room_events, evs)
- for evs in itervalues(events_by_room)
- ],
- consumeErrors=True
- ))
-
- yield self.store.update_federation_out_pos(
- "events", next_token
+ yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ logcontext.run_in_background(handle_room_events, evs)
+ for evs in itervalues(events_by_room)
+ ],
+ consumeErrors=True,
+ )
)
+ yield self.store.update_federation_out_pos("events", next_token)
+
if events:
now = self.clock.time_msec()
ts = yield self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_lag.labels(
- "federation_sender").set(now - ts)
+ "federation_sender"
+ ).set(now - ts)
synapse.metrics.event_processing_last_ts.labels(
- "federation_sender").set(ts)
+ "federation_sender"
+ ).set(ts)
events_processed_counter.inc(len(events))
- event_processing_loop_room_count.labels(
- "federation_sender"
- ).inc(len(events_by_room))
+ event_processing_loop_room_count.labels("federation_sender").inc(
+ len(events_by_room)
+ )
event_processing_loop_counter.labels("federation_sender").inc()
synapse.metrics.event_processing_positions.labels(
- "federation_sender").set(next_token)
+ "federation_sender"
+ ).set(next_token)
finally:
self._is_processing = False
@@ -309,9 +316,7 @@ class FederationSender(object):
if not domains:
return
- queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(
- room_id
- )
+ queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id)
# if there is no flush yet scheduled, we will send out these receipts with
# immediate flushes, and schedule the next flush for this room.
@@ -374,10 +379,9 @@ class FederationSender(object):
# updates in quick succession are correctly handled.
# We only want to send presence for our own users, so lets always just
# filter here just in case.
- self.pending_presence.update({
- state.user_id: state for state in states
- if self.is_mine_id(state.user_id)
- })
+ self.pending_presence.update(
+ {state.user_id: state for state in states if self.is_mine_id(state.user_id)}
+ )
# We then handle the new pending presence in batches, first figuring
# out the destinations we need to send each state to and then poking it
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 564c57203d..9aab12c0d3 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -189,11 +189,21 @@ class PerDestinationQueue(object):
pending_pdus = []
while True:
- device_message_edus, device_stream_id, dev_list_id = (
- # We have to keep 2 free slots for presence and rr_edus
- yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2)
+ # We have to keep 2 free slots for presence and rr_edus
+ limit = MAX_EDUS_PER_TRANSACTION - 2
+
+ device_update_edus, dev_list_id = (
+ yield self._get_device_update_edus(limit)
+ )
+
+ limit -= len(device_update_edus)
+
+ to_device_edus, device_stream_id = (
+ yield self._get_to_device_message_edus(limit)
)
+ pending_edus = device_update_edus + to_device_edus
+
# BEGIN CRITICAL SECTION
#
# In order to avoid a race condition, we need to make sure that
@@ -208,10 +218,6 @@ class PerDestinationQueue(object):
# We can only include at most 50 PDUs per transactions
pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
- pending_edus = []
-
- # We can only include at most 100 EDUs per transactions
- # rr_edus and pending_presence take at most one slot each
pending_edus.extend(self._get_rr_edus(force_flush=False))
pending_presence = self._pending_presence
self._pending_presence = {}
@@ -232,7 +238,6 @@ class PerDestinationQueue(object):
)
)
- pending_edus.extend(device_message_edus)
pending_edus.extend(
self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
)
@@ -272,10 +277,13 @@ class PerDestinationQueue(object):
sent_edus_by_type.labels(edu.edu_type).inc()
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
- if device_message_edus:
+ if to_device_edus:
yield self._store.delete_device_msgs_for_remote(
self._destination, device_stream_id
)
+
+ # also mark the device updates as sent
+ if device_update_edus:
logger.info(
"Marking as sent %r %r", self._destination, dev_list_id
)
@@ -347,12 +355,12 @@ class PerDestinationQueue(object):
return pending_edus
@defer.inlineCallbacks
- def _get_new_device_messages(self, limit):
+ def _get_device_update_edus(self, limit):
last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination
now_stream_id, results = yield self._store.get_devices_by_remote(
- self._destination, last_device_list, limit=limit,
+ self._destination, last_device_list, limit=limit
)
edus = [
Edu(
@@ -366,15 +374,16 @@ class PerDestinationQueue(object):
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+ defer.returnValue((edus, now_stream_id))
+
+ @defer.inlineCallbacks
+ def _get_to_device_message_edus(self, limit):
last_device_stream_id = self._last_device_stream_id
to_device_stream_id = self._store.get_to_device_stream_token()
contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
- self._destination,
- last_device_stream_id,
- to_device_stream_id,
- limit - len(edus),
+ self._destination, last_device_stream_id, to_device_stream_id, limit
)
- edus.extend(
+ edus = [
Edu(
origin=self._server_name,
destination=self._destination,
@@ -382,6 +391,6 @@ class PerDestinationQueue(object):
content=content,
)
for content in contents
- )
+ ]
- defer.returnValue((edus, stream_id, now_stream_id))
+ defer.returnValue((edus, stream_id))
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 35e6b8ff5b..c987bb9a0d 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -29,9 +29,10 @@ class TransactionManager(object):
shared between PerDestinationQueue objects
"""
+
def __init__(self, hs):
self._server_name = hs.hostname
- self.clock = hs.get_clock() # nb must be called this for @measure_func
+ self.clock = hs.get_clock() # nb must be called this for @measure_func
self._store = hs.get_datastore()
self._transaction_actions = TransactionActions(self._store)
self._transport_layer = hs.get_federation_transport_client()
@@ -55,9 +56,9 @@ class TransactionManager(object):
txn_id = str(self._next_txn_id)
logger.debug(
- "TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d)",
- destination, txn_id,
+ "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
+ destination,
+ txn_id,
len(pdus),
len(edus),
)
@@ -79,9 +80,9 @@ class TransactionManager(object):
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
- "TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d)",
- destination, txn_id,
+ "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
+ destination,
+ txn_id,
transaction.transaction_id,
len(pdus),
len(edus),
@@ -112,20 +113,12 @@ class TransactionManager(object):
response = e.response
if e.code in (401, 404, 429) or 500 <= e.code:
- logger.info(
- "TX [%s] {%s} got %d response",
- destination, txn_id, code
- )
+ logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
raise e
- logger.info(
- "TX [%s] {%s} got %d response",
- destination, txn_id, code
- )
+ logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
- yield self._transaction_actions.delivered(
- transaction, code, response
- )
+ yield self._transaction_actions.delivered(transaction, code, response)
logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id)
@@ -134,13 +127,18 @@ class TransactionManager(object):
if "error" in r:
logger.warn(
"TX [%s] {%s} Remote returned error for %s: %s",
- destination, txn_id, e_id, r,
+ destination,
+ txn_id,
+ e_id,
+ r,
)
else:
for p in pdus:
logger.warn(
"TX [%s] {%s} Failed to send event %s",
- destination, txn_id, p.event_id,
+ destination,
+ txn_id,
+ p.event_id,
)
success = False
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index e424c40fdf..aecd142309 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -48,12 +48,13 @@ class TransportLayerClient(object):
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
- logger.debug("get_room_state dest=%s, room=%s",
- destination, room_id)
+ logger.debug("get_room_state dest=%s, room=%s", destination, room_id)
path = _create_v1_path("/state/%s", room_id)
return self.client.get_json(
- destination, path=path, args={"event_id": event_id},
+ destination,
+ path=path,
+ args={"event_id": event_id},
try_trailing_slash_on_400=True,
)
@@ -71,12 +72,13 @@ class TransportLayerClient(object):
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
- logger.debug("get_room_state_ids dest=%s, room=%s",
- destination, room_id)
+ logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
path = _create_v1_path("/state_ids/%s", room_id)
return self.client.get_json(
- destination, path=path, args={"event_id": event_id},
+ destination,
+ path=path,
+ args={"event_id": event_id},
try_trailing_slash_on_400=True,
)
@@ -94,13 +96,11 @@ class TransportLayerClient(object):
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
- logger.debug("get_pdu dest=%s, event_id=%s",
- destination, event_id)
+ logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
path = _create_v1_path("/event/%s", event_id)
return self.client.get_json(
- destination, path=path, timeout=timeout,
- try_trailing_slash_on_400=True,
+ destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
)
@log_function
@@ -119,7 +119,10 @@ class TransportLayerClient(object):
"""
logger.debug(
"backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s",
- destination, room_id, repr(event_tuples), str(limit)
+ destination,
+ room_id,
+ repr(event_tuples),
+ str(limit),
)
if not event_tuples:
@@ -128,16 +131,10 @@ class TransportLayerClient(object):
path = _create_v1_path("/backfill/%s", room_id)
- args = {
- "v": event_tuples,
- "limit": [str(limit)],
- }
+ args = {"v": event_tuples, "limit": [str(limit)]}
return self.client.get_json(
- destination,
- path=path,
- args=args,
- try_trailing_slash_on_400=True,
+ destination, path=path, args=args, try_trailing_slash_on_400=True
)
@defer.inlineCallbacks
@@ -163,7 +160,8 @@ class TransportLayerClient(object):
"""
logger.debug(
"send_data dest=%s, txid=%s",
- transaction.destination, transaction.transaction_id
+ transaction.destination,
+ transaction.transaction_id,
)
if transaction.destination == self.server_name:
@@ -189,8 +187,9 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
- def make_query(self, destination, query_type, args, retry_on_dns_fail,
- ignore_backoff=False):
+ def make_query(
+ self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
+ ):
path = _create_v1_path("/query/%s", query_type)
content = yield self.client.get_json(
@@ -235,8 +234,8 @@ class TransportLayerClient(object):
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
- "make_membership_event called with membership='%s', must be one of %s" %
- (membership, ",".join(valid_memberships))
+ "make_membership_event called with membership='%s', must be one of %s"
+ % (membership, ",".join(valid_memberships))
)
path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
@@ -268,9 +267,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
- destination=destination,
- path=path,
- data=content,
+ destination=destination, path=path, data=content
)
defer.returnValue(response)
@@ -284,7 +281,6 @@ class TransportLayerClient(object):
destination=destination,
path=path,
data=content,
-
# we want to do our best to send this through. The problem is
# that if it fails, we won't retry it later, so if the remote
# server was just having a momentary blip, the room will be out of
@@ -300,10 +296,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
defer.returnValue(response)
@@ -314,26 +307,27 @@ class TransportLayerClient(object):
path = _create_v2_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
defer.returnValue(response)
@defer.inlineCallbacks
@log_function
- def get_public_rooms(self, remote_server, limit, since_token,
- search_filter=None, include_all_networks=False,
- third_party_instance_id=None):
+ def get_public_rooms(
+ self,
+ remote_server,
+ limit,
+ since_token,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ):
path = _create_v1_path("/publicRooms")
- args = {
- "include_all_networks": "true" if include_all_networks else "false",
- }
+ args = {"include_all_networks": "true" if include_all_networks else "false"}
if third_party_instance_id:
- args["third_party_instance_id"] = third_party_instance_id,
+ args["third_party_instance_id"] = (third_party_instance_id,)
if limit:
args["limit"] = [str(limit)]
if since_token:
@@ -342,10 +336,7 @@ class TransportLayerClient(object):
# TODO(erikj): Actually send the search_filter across federation.
response = yield self.client.get_json(
- destination=remote_server,
- path=path,
- args=args,
- ignore_backoff=True,
+ destination=remote_server, path=path, args=args, ignore_backoff=True
)
defer.returnValue(response)
@@ -353,12 +344,10 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
- path = _create_v1_path("/exchange_third_party_invite/%s", room_id,)
+ path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
response = yield self.client.put_json(
- destination=destination,
- path=path,
- data=event_dict,
+ destination=destination, path=path, data=event_dict
)
defer.returnValue(response)
@@ -368,10 +357,7 @@ class TransportLayerClient(object):
def get_event_auth(self, destination, room_id, event_id):
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
- content = yield self.client.get_json(
- destination=destination,
- path=path,
- )
+ content = yield self.client.get_json(destination=destination, path=path)
defer.returnValue(content)
@@ -381,9 +367,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)
content = yield self.client.post_json(
- destination=destination,
- path=path,
- data=content,
+ destination=destination, path=path, data=content
)
defer.returnValue(content)
@@ -416,10 +400,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/user/keys/query")
content = yield self.client.post_json(
- destination=destination,
- path=path,
- data=query_content,
- timeout=timeout,
+ destination=destination, path=path, data=query_content, timeout=timeout
)
defer.returnValue(content)
@@ -443,9 +424,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/user/devices/%s", user_id)
content = yield self.client.get_json(
- destination=destination,
- path=path,
- timeout=timeout,
+ destination=destination, path=path, timeout=timeout
)
defer.returnValue(content)
@@ -479,18 +458,23 @@ class TransportLayerClient(object):
path = _create_v1_path("/user/keys/claim")
content = yield self.client.post_json(
- destination=destination,
- path=path,
- data=query_content,
- timeout=timeout,
+ destination=destination, path=path, data=query_content, timeout=timeout
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
- def get_missing_events(self, destination, room_id, earliest_events,
- latest_events, limit, min_depth, timeout):
- path = _create_v1_path("/get_missing_events/%s", room_id,)
+ def get_missing_events(
+ self,
+ destination,
+ room_id,
+ earliest_events,
+ latest_events,
+ limit,
+ min_depth,
+ timeout,
+ ):
+ path = _create_v1_path("/get_missing_events/%s", room_id)
content = yield self.client.post_json(
destination=destination,
@@ -510,7 +494,7 @@ class TransportLayerClient(object):
def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile
"""
- path = _create_v1_path("/groups/%s/profile", group_id,)
+ path = _create_v1_path("/groups/%s/profile", group_id)
return self.client.get_json(
destination=destination,
@@ -529,7 +513,7 @@ class TransportLayerClient(object):
requester_user_id (str)
content (dict): The new profile of the group
"""
- path = _create_v1_path("/groups/%s/profile", group_id,)
+ path = _create_v1_path("/groups/%s/profile", group_id)
return self.client.post_json(
destination=destination,
@@ -543,7 +527,7 @@ class TransportLayerClient(object):
def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary
"""
- path = _create_v1_path("/groups/%s/summary", group_id,)
+ path = _create_v1_path("/groups/%s/summary", group_id)
return self.client.get_json(
destination=destination,
@@ -556,7 +540,7 @@ class TransportLayerClient(object):
def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group
"""
- path = _create_v1_path("/groups/%s/rooms", group_id,)
+ path = _create_v1_path("/groups/%s/rooms", group_id)
return self.client.get_json(
destination=destination,
@@ -565,11 +549,12 @@ class TransportLayerClient(object):
ignore_backoff=True,
)
- def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
- content):
+ def add_room_to_group(
+ self, destination, group_id, requester_user_id, room_id, content
+ ):
"""Add a room to a group
"""
- path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.post_json(
destination=destination,
@@ -579,13 +564,13 @@ class TransportLayerClient(object):
ignore_backoff=True,
)
- def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
- config_key, content):
+ def update_room_in_group(
+ self, destination, group_id, requester_user_id, room_id, config_key, content
+ ):
"""Update room in group
"""
path = _create_v1_path(
- "/groups/%s/room/%s/config/%s",
- group_id, room_id, config_key,
+ "/groups/%s/room/%s/config/%s", group_id, room_id, config_key
)
return self.client.post_json(
@@ -599,7 +584,7 @@ class TransportLayerClient(object):
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group
"""
- path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.delete_json(
destination=destination,
@@ -612,7 +597,7 @@ class TransportLayerClient(object):
def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group
"""
- path = _create_v1_path("/groups/%s/users", group_id,)
+ path = _create_v1_path("/groups/%s/users", group_id)
return self.client.get_json(
destination=destination,
@@ -625,7 +610,7 @@ class TransportLayerClient(object):
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group
"""
- path = _create_v1_path("/groups/%s/invited_users", group_id,)
+ path = _create_v1_path("/groups/%s/invited_users", group_id)
return self.client.get_json(
destination=destination,
@@ -638,16 +623,10 @@ class TransportLayerClient(object):
def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite
"""
- path = _create_v1_path(
- "/groups/%s/users/%s/accept_invite",
- group_id, user_id,
- )
+ path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
@@ -657,14 +636,13 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
+ def invite_to_group(
+ self, destination, group_id, user_id, requester_user_id, content
+ ):
"""Invite a user to a group
"""
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
@@ -686,15 +664,13 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def remove_user_from_group(self, destination, group_id, requester_user_id,
- user_id, content):
+ def remove_user_from_group(
+ self, destination, group_id, requester_user_id, user_id, content
+ ):
"""Remove a user fron a group
"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
@@ -708,8 +684,9 @@ class TransportLayerClient(object):
)
@log_function
- def remove_user_from_group_notification(self, destination, group_id, user_id,
- content):
+ def remove_user_from_group_notification(
+ self, destination, group_id, user_id, content
+ ):
"""Sent by group server to inform a user's server that they have been
kicked from the group.
"""
@@ -717,10 +694,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
@@ -732,24 +706,24 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def update_group_summary_room(self, destination, group_id, user_id, room_id,
- category_id, content):
+ def update_group_summary_room(
+ self, destination, group_id, user_id, room_id, category_id, content
+ ):
"""Update a room entry in a group summary
"""
if category_id:
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
- group_id, category_id, room_id,
+ group_id,
+ category_id,
+ room_id,
)
else:
- path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
return self.client.post_json(
destination=destination,
@@ -760,17 +734,20 @@ class TransportLayerClient(object):
)
@log_function
- def delete_group_summary_room(self, destination, group_id, user_id, room_id,
- category_id):
+ def delete_group_summary_room(
+ self, destination, group_id, user_id, room_id, category_id
+ ):
"""Delete a room entry in a group summary
"""
if category_id:
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
- group_id, category_id, room_id,
+ group_id,
+ category_id,
+ room_id,
)
else:
- path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
return self.client.delete_json(
destination=destination,
@@ -783,7 +760,7 @@ class TransportLayerClient(object):
def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group
"""
- path = _create_v1_path("/groups/%s/categories", group_id,)
+ path = _create_v1_path("/groups/%s/categories", group_id)
return self.client.get_json(
destination=destination,
@@ -796,7 +773,7 @@ class TransportLayerClient(object):
def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group
"""
- path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
+ path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.get_json(
destination=destination,
@@ -806,11 +783,12 @@ class TransportLayerClient(object):
)
@log_function
- def update_group_category(self, destination, group_id, requester_user_id, category_id,
- content):
+ def update_group_category(
+ self, destination, group_id, requester_user_id, category_id, content
+ ):
"""Update a category in a group
"""
- path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
+ path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.post_json(
destination=destination,
@@ -821,11 +799,12 @@ class TransportLayerClient(object):
)
@log_function
- def delete_group_category(self, destination, group_id, requester_user_id,
- category_id):
+ def delete_group_category(
+ self, destination, group_id, requester_user_id, category_id
+ ):
"""Delete a category in a group
"""
- path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
+ path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.delete_json(
destination=destination,
@@ -838,7 +817,7 @@ class TransportLayerClient(object):
def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group
"""
- path = _create_v1_path("/groups/%s/roles", group_id,)
+ path = _create_v1_path("/groups/%s/roles", group_id)
return self.client.get_json(
destination=destination,
@@ -851,7 +830,7 @@ class TransportLayerClient(object):
def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info
"""
- path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
+ path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.get_json(
destination=destination,
@@ -861,11 +840,12 @@ class TransportLayerClient(object):
)
@log_function
- def update_group_role(self, destination, group_id, requester_user_id, role_id,
- content):
+ def update_group_role(
+ self, destination, group_id, requester_user_id, role_id, content
+ ):
"""Update a role in a group
"""
- path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
+ path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.post_json(
destination=destination,
@@ -879,7 +859,7 @@ class TransportLayerClient(object):
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group
"""
- path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
+ path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.delete_json(
destination=destination,
@@ -889,17 +869,17 @@ class TransportLayerClient(object):
)
@log_function
- def update_group_summary_user(self, destination, group_id, requester_user_id,
- user_id, role_id, content):
+ def update_group_summary_user(
+ self, destination, group_id, requester_user_id, user_id, role_id, content
+ ):
"""Update a users entry in a group
"""
if role_id:
path = _create_v1_path(
- "/groups/%s/summary/roles/%s/users/%s",
- group_id, role_id, user_id,
+ "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
)
else:
- path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
+ path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -910,11 +890,10 @@ class TransportLayerClient(object):
)
@log_function
- def set_group_join_policy(self, destination, group_id, requester_user_id,
- content):
+ def set_group_join_policy(self, destination, group_id, requester_user_id, content):
"""Sets the join policy for a group
"""
- path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,)
+ path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
return self.client.put_json(
destination=destination,
@@ -925,17 +904,17 @@ class TransportLayerClient(object):
)
@log_function
- def delete_group_summary_user(self, destination, group_id, requester_user_id,
- user_id, role_id):
+ def delete_group_summary_user(
+ self, destination, group_id, requester_user_id, user_id, role_id
+ ):
"""Delete a users entry in a group
"""
if role_id:
path = _create_v1_path(
- "/groups/%s/summary/roles/%s/users/%s",
- group_id, role_id, user_id,
+ "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
)
else:
- path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
+ path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
return self.client.delete_json(
destination=destination,
@@ -953,10 +932,7 @@ class TransportLayerClient(object):
content = {"user_ids": user_ids}
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@@ -975,9 +951,8 @@ def _create_v1_path(path, *args):
Returns:
str
"""
- return (
- FEDERATION_V1_PREFIX
- + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+ return FEDERATION_V1_PREFIX + path % tuple(
+ urllib.parse.quote(arg, "") for arg in args
)
@@ -996,7 +971,6 @@ def _create_v2_path(path, *args):
Returns:
str
"""
- return (
- FEDERATION_V2_PREFIX
- + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+ return FEDERATION_V2_PREFIX + path % tuple(
+ urllib.parse.quote(arg, "") for arg in args
)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 949a5fb2aa..955f0f4308 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -66,8 +66,7 @@ class TransportLayerServer(JsonResource):
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
- self.clock,
- config=hs.config.rc_federation,
+ self.clock, config=hs.config.rc_federation
)
self.register_servlets()
@@ -84,11 +83,13 @@ class TransportLayerServer(JsonResource):
class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
+
pass
class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
+
pass
@@ -105,8 +106,8 @@ class Authenticator(object):
def authenticate_request(self, request, content):
now = self._clock.time_msec()
json_request = {
- "method": request.method.decode('ascii'),
- "uri": request.uri.decode('ascii'),
+ "method": request.method.decode("ascii"),
+ "uri": request.uri.decode("ascii"),
"destination": self.server_name,
"signatures": {},
}
@@ -120,7 +121,7 @@ class Authenticator(object):
if not auth_headers:
raise NoAuthenticationError(
- 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+ 401, "Missing Authorization headers", Codes.UNAUTHORIZED
)
for auth in auth_headers:
@@ -130,14 +131,14 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig
if (
- self.federation_domain_whitelist is not None and
- origin not in self.federation_domain_whitelist
+ self.federation_domain_whitelist is not None
+ and origin not in self.federation_domain_whitelist
):
raise FederationDeniedError(origin)
if not json_request["signatures"]:
raise NoAuthenticationError(
- 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+ 401, "Missing Authorization headers", Codes.UNAUTHORIZED
)
yield self.keyring.verify_json_for_server(
@@ -177,12 +178,12 @@ def _parse_auth_header(header_bytes):
AuthenticationError if the header could not be parsed
"""
try:
- header_str = header_bytes.decode('utf-8')
+ header_str = header_bytes.decode("utf-8")
params = header_str.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
- if value.startswith("\""):
+ if value.startswith('"'):
return value[1:-1]
else:
return value
@@ -198,11 +199,11 @@ def _parse_auth_header(header_bytes):
except Exception as e:
logger.warn(
"Error parsing auth header '%s': %s",
- header_bytes.decode('ascii', 'replace'),
+ header_bytes.decode("ascii", "replace"),
e,
)
raise AuthenticationError(
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
@@ -242,6 +243,7 @@ class BaseFederationServlet(object):
Exception: other exceptions will be caught, logged, and a 500 will be
returned.
"""
+
REQUIRE_AUTH = True
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
@@ -293,9 +295,7 @@ class BaseFederationServlet(object):
origin, content, request.args, *args, **kwargs
)
else:
- response = yield func(
- origin, content, request.args, *args, **kwargs
- )
+ response = yield func(origin, content, request.args, *args, **kwargs)
defer.returnValue(response)
@@ -343,14 +343,12 @@ class FederationSendServlet(BaseFederationServlet):
try:
transaction_data = content
- logger.debug(
- "Decoded %s: %s",
- transaction_id, str(transaction_data)
- )
+ logger.debug("Decoded %s: %s", transaction_id, str(transaction_data))
logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d)",
- transaction_id, origin,
+ transaction_id,
+ origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
)
@@ -361,8 +359,7 @@ class FederationSendServlet(BaseFederationServlet):
# Add some extra data to the transaction dict that isn't included
# in the request body.
transaction_data.update(
- transaction_id=transaction_id,
- destination=self.server_name
+ transaction_id=transaction_id, destination=self.server_name
)
except Exception as e:
@@ -372,7 +369,7 @@ class FederationSendServlet(BaseFederationServlet):
try:
code, response = yield self.handler.on_incoming_transaction(
- origin, transaction_data,
+ origin, transaction_data
)
except Exception:
logger.exception("on_incoming_transaction failed")
@@ -416,7 +413,7 @@ class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/?"
def on_GET(self, origin, content, query, context):
- versions = [x.decode('ascii') for x in query[b"v"]]
+ versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None)
if not limit:
@@ -432,7 +429,7 @@ class FederationQueryServlet(BaseFederationServlet):
def on_GET(self, origin, content, query, query_type):
return self.handler.on_query_request(
query_type,
- {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()}
+ {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()},
)
@@ -456,15 +453,14 @@ class FederationMakeJoinServlet(BaseFederationServlet):
Deferred[(int, object)|None]: either (response code, response object) to
return a JSON response, or None if the request has already been handled.
"""
- versions = query.get(b'ver')
+ versions = query.get(b"ver")
if versions is not None:
supported_versions = [v.decode("utf-8") for v in versions]
else:
supported_versions = ["1"]
content = yield self.handler.on_make_join_request(
- origin, context, user_id,
- supported_versions=supported_versions,
+ origin, context, user_id, supported_versions=supported_versions
)
defer.returnValue((200, content))
@@ -474,9 +470,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
@defer.inlineCallbacks
def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_leave_request(
- origin, context, user_id,
- )
+ content = yield self.handler.on_make_leave_request(origin, context, user_id)
defer.returnValue((200, content))
@@ -517,7 +511,7 @@ class FederationV1InviteServlet(BaseFederationServlet):
# state resolution algorithm, and we don't use that for processing
# invites
content = yield self.handler.on_invite_request(
- origin, content, room_version=RoomVersions.V1.identifier,
+ origin, content, room_version=RoomVersions.V1.identifier
)
# V1 federation API is defined to return a content of `[200, {...}]`
@@ -545,7 +539,7 @@ class FederationV2InviteServlet(BaseFederationServlet):
event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
content = yield self.handler.on_invite_request(
- origin, event, room_version=room_version,
+ origin, event, room_version=room_version
)
defer.returnValue((200, content))
@@ -629,8 +623,10 @@ class On3pidBindServlet(BaseFederationServlet):
for invite in content["invites"]:
try:
if "signed" not in invite or "token" not in invite["signed"]:
- message = ("Rejecting received notification of third-"
- "party invite without signed: %s" % (invite,))
+ message = (
+ "Rejecting received notification of third-"
+ "party invite without signed: %s" % (invite,)
+ )
logger.info(message)
raise SynapseError(400, message)
yield self.handler.exchange_third_party_invite(
@@ -671,18 +667,23 @@ class OpenIdUserInfo(BaseFederationServlet):
def on_GET(self, origin, content, query):
token = query.get(b"access_token", [None])[0]
if token is None:
- defer.returnValue((401, {
- "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
- }))
+ defer.returnValue(
+ (401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"})
+ )
return
- user_id = yield self.handler.on_openid_userinfo(token.decode('ascii'))
+ user_id = yield self.handler.on_openid_userinfo(token.decode("ascii"))
if user_id is None:
- defer.returnValue((401, {
- "errcode": "M_UNKNOWN_TOKEN",
- "error": "Access Token unknown or expired"
- }))
+ defer.returnValue(
+ (
+ 401,
+ {
+ "errcode": "M_UNKNOWN_TOKEN",
+ "error": "Access Token unknown or expired",
+ },
+ )
+ )
defer.returnValue((200, {"sub": user_id}))
@@ -720,15 +721,15 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
- def __init__(self, handler, authenticator, ratelimiter, server_name, deny_access):
+ def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
super(PublicRoomList, self).__init__(
- handler, authenticator, ratelimiter, server_name,
+ handler, authenticator, ratelimiter, server_name
)
- self.deny_access = deny_access
+ self.allow_access = allow_access
@defer.inlineCallbacks
def on_GET(self, origin, content, query):
- if self.deny_access:
+ if not self.allow_access:
raise FederationDeniedError(origin)
limit = parse_integer_from_args(query, "limit", 0)
@@ -748,9 +749,7 @@ class PublicRoomList(BaseFederationServlet):
network_tuple = ThirdPartyInstanceID(None, None)
data = yield self.handler.get_local_public_room_list(
- limit, since_token,
- network_tuple=network_tuple,
- from_federation=True,
+ limit, since_token, network_tuple=network_tuple, from_federation=True
)
defer.returnValue((200, data))
@@ -761,17 +760,18 @@ class FederationVersionServlet(BaseFederationServlet):
REQUIRE_AUTH = False
def on_GET(self, origin, content, query):
- return defer.succeed((200, {
- "server": {
- "name": "Synapse",
- "version": get_version_string(synapse)
- },
- }))
+ return defer.succeed(
+ (
+ 200,
+ {"server": {"name": "Synapse", "version": get_version_string(synapse)}},
+ )
+ )
class FederationGroupsProfileServlet(BaseFederationServlet):
"""Get/set the basic profile of a group on behalf of a user
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/profile"
@defer.inlineCallbacks
@@ -780,9 +780,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_group_profile(
- group_id, requester_user_id
- )
+ new_content = yield self.handler.get_group_profile(group_id, requester_user_id)
defer.returnValue((200, new_content))
@@ -808,9 +806,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_group_summary(
- group_id, requester_user_id
- )
+ new_content = yield self.handler.get_group_summary(group_id, requester_user_id)
defer.returnValue((200, new_content))
@@ -818,6 +814,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet):
class FederationGroupsRoomsServlet(BaseFederationServlet):
"""Get the rooms in a group on behalf of a user
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
@defer.inlineCallbacks
@@ -826,9 +823,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_rooms_in_group(
- group_id, requester_user_id
- )
+ new_content = yield self.handler.get_rooms_in_group(group_id, requester_user_id)
defer.returnValue((200, new_content))
@@ -836,6 +831,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
"""Add/remove room from group
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
@defer.inlineCallbacks
@@ -857,7 +853,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.remove_room_from_group(
- group_id, requester_user_id, room_id,
+ group_id, requester_user_id, room_id
)
defer.returnValue((200, new_content))
@@ -866,6 +862,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet):
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
"""Update room config in group
"""
+
PATH = (
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)"
@@ -878,7 +875,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin")
result = yield self.groups_handler.update_room_in_group(
- group_id, requester_user_id, room_id, config_key, content,
+ group_id, requester_user_id, room_id, config_key, content
)
defer.returnValue((200, result))
@@ -887,6 +884,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
class FederationGroupsUsersServlet(BaseFederationServlet):
"""Get the users in a group on behalf of a user
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/users"
@defer.inlineCallbacks
@@ -895,9 +893,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_users_in_group(
- group_id, requester_user_id
- )
+ new_content = yield self.handler.get_users_in_group(group_id, requester_user_id)
defer.returnValue((200, new_content))
@@ -905,6 +901,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet):
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
"""Get the users that have been invited to a group
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
@defer.inlineCallbacks
@@ -923,6 +920,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
class FederationGroupsInviteServlet(BaseFederationServlet):
"""Ask a group server to invite someone to the group
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@defer.inlineCallbacks
@@ -932,7 +930,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.invite_to_group(
- group_id, user_id, requester_user_id, content,
+ group_id, user_id, requester_user_id, content
)
defer.returnValue((200, new_content))
@@ -941,6 +939,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet):
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
"""Accept an invitation from the group server
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
@defer.inlineCallbacks
@@ -948,9 +947,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
- new_content = yield self.handler.accept_invite(
- group_id, user_id, content,
- )
+ new_content = yield self.handler.accept_invite(group_id, user_id, content)
defer.returnValue((200, new_content))
@@ -958,6 +955,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
class FederationGroupsJoinServlet(BaseFederationServlet):
"""Attempt to join a group
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
@defer.inlineCallbacks
@@ -965,9 +963,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
- new_content = yield self.handler.join_group(
- group_id, user_id, content,
- )
+ new_content = yield self.handler.join_group(group_id, user_id, content)
defer.returnValue((200, new_content))
@@ -975,6 +971,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet):
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
"""Leave or kick a user from the group
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@defer.inlineCallbacks
@@ -984,7 +981,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
raise SynapseError(403, "requester_user_id doesn't match origin")
new_content = yield self.handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content,
+ group_id, user_id, requester_user_id, content
)
defer.returnValue((200, new_content))
@@ -993,6 +990,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet):
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
"""A group server has invited a local user
"""
+
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
@defer.inlineCallbacks
@@ -1000,9 +998,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
- new_content = yield self.handler.on_invite(
- group_id, user_id, content,
- )
+ new_content = yield self.handler.on_invite(group_id, user_id, content)
defer.returnValue((200, new_content))
@@ -1010,6 +1006,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet):
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
"""A group server has removed a local user
"""
+
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
@defer.inlineCallbacks
@@ -1018,7 +1015,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
raise SynapseError(403, "user_id doesn't match origin")
new_content = yield self.handler.user_removed_from_group(
- group_id, user_id, content,
+ group_id, user_id, content
)
defer.returnValue((200, new_content))
@@ -1027,6 +1024,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
"""A group or user's server renews their attestation
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
@defer.inlineCallbacks
@@ -1047,6 +1045,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
- /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id
"""
+
PATH = (
"/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?"
@@ -1063,7 +1062,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.update_group_summary_room(
- group_id, requester_user_id,
+ group_id,
+ requester_user_id,
room_id=room_id,
category_id=category_id,
content=content,
@@ -1081,9 +1081,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.delete_group_summary_room(
- group_id, requester_user_id,
- room_id=room_id,
- category_id=category_id,
+ group_id, requester_user_id, room_id=room_id, category_id=category_id
)
defer.returnValue((200, resp))
@@ -1092,9 +1090,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
class FederationGroupsCategoriesServlet(BaseFederationServlet):
"""Get all categories for a group
"""
- PATH = (
- "/groups/(?P<group_id>[^/]*)/categories/?"
- )
+
+ PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
@@ -1102,9 +1099,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_categories(
- group_id, requester_user_id,
- )
+ resp = yield self.handler.get_group_categories(group_id, requester_user_id)
defer.returnValue((200, resp))
@@ -1112,9 +1107,8 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
class FederationGroupsCategoryServlet(BaseFederationServlet):
"""Add/remove/get a category in a group
"""
- PATH = (
- "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
- )
+
+ PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id, category_id):
@@ -1138,7 +1132,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.upsert_group_category(
- group_id, requester_user_id, category_id, content,
+ group_id, requester_user_id, category_id, content
)
defer.returnValue((200, resp))
@@ -1153,7 +1147,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
raise SynapseError(400, "category_id cannot be empty string")
resp = yield self.handler.delete_group_category(
- group_id, requester_user_id, category_id,
+ group_id, requester_user_id, category_id
)
defer.returnValue((200, resp))
@@ -1162,9 +1156,8 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
class FederationGroupsRolesServlet(BaseFederationServlet):
"""Get roles in a group
"""
- PATH = (
- "/groups/(?P<group_id>[^/]*)/roles/?"
- )
+
+ PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id):
@@ -1172,9 +1165,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_roles(
- group_id, requester_user_id,
- )
+ resp = yield self.handler.get_group_roles(group_id, requester_user_id)
defer.returnValue((200, resp))
@@ -1182,9 +1173,8 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
class FederationGroupsRoleServlet(BaseFederationServlet):
"""Add/remove/get a role in a group
"""
- PATH = (
- "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
- )
+
+ PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
@defer.inlineCallbacks
def on_GET(self, origin, content, query, group_id, role_id):
@@ -1192,9 +1182,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_role(
- group_id, requester_user_id, role_id
- )
+ resp = yield self.handler.get_group_role(group_id, requester_user_id, role_id)
defer.returnValue((200, resp))
@@ -1208,7 +1196,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.update_group_role(
- group_id, requester_user_id, role_id, content,
+ group_id, requester_user_id, role_id, content
)
defer.returnValue((200, resp))
@@ -1223,7 +1211,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.delete_group_role(
- group_id, requester_user_id, role_id,
+ group_id, requester_user_id, role_id
)
defer.returnValue((200, resp))
@@ -1236,6 +1224,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
- /groups/:group/summary/users/:user_id
- /groups/:group/summary/roles/:role/users/:user_id
"""
+
PATH = (
"/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?"
@@ -1252,7 +1241,8 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.update_group_summary_user(
- group_id, requester_user_id,
+ group_id,
+ requester_user_id,
user_id=user_id,
role_id=role_id,
content=content,
@@ -1270,9 +1260,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
raise SynapseError(400, "role_id cannot be empty string")
resp = yield self.handler.delete_group_summary_user(
- group_id, requester_user_id,
- user_id=user_id,
- role_id=role_id,
+ group_id, requester_user_id, user_id=user_id, role_id=role_id
)
defer.returnValue((200, resp))
@@ -1281,14 +1269,13 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
"""Get roles in a group
"""
- PATH = (
- "/get_groups_publicised"
- )
+
+ PATH = "/get_groups_publicised"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
resp = yield self.handler.bulk_get_publicised_groups(
- content["user_ids"], proxy=False,
+ content["user_ids"], proxy=False
)
defer.returnValue((200, resp))
@@ -1297,6 +1284,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
"""Sets whether a group is joinable without an invite or knock
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
@defer.inlineCallbacks
@@ -1317,6 +1305,7 @@ class RoomComplexityServlet(BaseFederationServlet):
Indicates to other servers how complex (and therefore likely
resource-intensive) a public room this server knows about is.
"""
+
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
@@ -1325,9 +1314,7 @@ class RoomComplexityServlet(BaseFederationServlet):
store = self.handler.hs.get_datastore()
- is_public = yield store.is_room_world_readable_or_publicly_joinable(
- room_id
- )
+ is_public = yield store.is_room_world_readable_or_publicly_joinable(room_id)
if not is_public:
raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
@@ -1362,13 +1349,9 @@ FEDERATION_SERVLET_CLASSES = (
RoomComplexityServlet,
)
-OPENID_SERVLET_CLASSES = (
- OpenIdUserInfo,
-)
+OPENID_SERVLET_CLASSES = (OpenIdUserInfo,)
-ROOM_LIST_CLASSES = (
- PublicRoomList,
-)
+ROOM_LIST_CLASSES = (PublicRoomList,)
GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsProfileServlet,
@@ -1399,9 +1382,7 @@ GROUP_LOCAL_SERVLET_CLASSES = (
)
-GROUP_ATTESTATION_SERVLET_CLASSES = (
- FederationGroupsRenewAttestaionServlet,
-)
+GROUP_ATTESTATION_SERVLET_CLASSES = (FederationGroupsRenewAttestaionServlet,)
DEFAULT_SERVLET_GROUPS = (
"federation",
@@ -1455,7 +1436,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
- deny_access=hs.config.restrict_public_rooms_to_local_users,
+ allow_access=hs.config.allow_public_rooms_over_federation,
).register(resource)
if "group_server" in servlet_groups:
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 025a79c022..14aad8f09d 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -32,21 +32,11 @@ class Edu(JsonEncodedObject):
internal ID or previous references graph.
"""
- valid_keys = [
- "origin",
- "destination",
- "edu_type",
- "content",
- ]
+ valid_keys = ["origin", "destination", "edu_type", "content"]
- required_keys = [
- "edu_type",
- ]
+ required_keys = ["edu_type"]
- internal_keys = [
- "origin",
- "destination",
- ]
+ internal_keys = ["origin", "destination"]
class Transaction(JsonEncodedObject):
@@ -75,10 +65,7 @@ class Transaction(JsonEncodedObject):
"edus",
]
- internal_keys = [
- "transaction_id",
- "destination",
- ]
+ internal_keys = ["transaction_id", "destination"]
required_keys = [
"transaction_id",
@@ -98,9 +85,7 @@ class Transaction(JsonEncodedObject):
del kwargs["edus"]
super(Transaction, self).__init__(
- transaction_id=transaction_id,
- pdus=pdus,
- **kwargs
+ transaction_id=transaction_id, pdus=pdus, **kwargs
)
@staticmethod
@@ -109,13 +94,9 @@ class Transaction(JsonEncodedObject):
transaction_id and origin_server_ts keys.
"""
if "origin_server_ts" not in kwargs:
- raise KeyError(
- "Require 'origin_server_ts' to construct a Transaction"
- )
+ raise KeyError("Require 'origin_server_ts' to construct a Transaction")
if "transaction_id" not in kwargs:
- raise KeyError(
- "Require 'transaction_id' to construct a Transaction"
- )
+ raise KeyError("Require 'transaction_id' to construct a Transaction")
kwargs["pdus"] = [p.get_pdu_json() for p in pdus]
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index e5dda1975f..e73757570c 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -42,7 +42,7 @@ from signedjson.sign import sign_json
from twisted.internet import defer
-from synapse.api.errors import RequestSendFailed, SynapseError
+from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
from synapse.util.logcontext import run_in_background
@@ -65,6 +65,7 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning(object):
"""Creates and verifies group attestations.
"""
+
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.clock = hs.get_clock()
@@ -113,11 +114,15 @@ class GroupAttestationSigning(object):
validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
valid_until_ms = int(self.clock.time_msec() + validity_period)
- return sign_json({
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": valid_until_ms,
- }, self.server_name, self.signing_key)
+ return sign_json(
+ {
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": valid_until_ms,
+ },
+ self.server_name,
+ self.signing_key,
+ )
class GroupAttestionRenewer(object):
@@ -132,9 +137,10 @@ class GroupAttestionRenewer(object):
self.is_mine_id = hs.is_mine_id
self.attestations = hs.get_groups_attestation_signing()
- self._renew_attestations_loop = self.clock.looping_call(
- self._start_renew_attestations, 30 * 60 * 1000,
- )
+ if not hs.config.worker_app:
+ self._renew_attestations_loop = self.clock.looping_call(
+ self._start_renew_attestations, 30 * 60 * 1000
+ )
@defer.inlineCallbacks
def on_renew_attestation(self, group_id, user_id, content):
@@ -146,9 +152,7 @@ class GroupAttestionRenewer(object):
raise SynapseError(400, "Neither user not group are on this server")
yield self.attestations.verify_attestation(
- attestation,
- user_id=user_id,
- group_id=group_id,
+ attestation, user_id=user_id, group_id=group_id
)
yield self.store.update_remote_attestion(group_id, user_id, attestation)
@@ -179,7 +183,8 @@ class GroupAttestionRenewer(object):
else:
logger.warn(
"Incorrectly trying to do attestations for user: %r in %r",
- user_id, group_id,
+ user_id,
+ group_id,
)
yield self.store.remove_attestation_renewal(group_id, user_id)
return
@@ -187,21 +192,20 @@ class GroupAttestionRenewer(object):
attestation = self.attestations.create_attestation(group_id, user_id)
yield self.transport_client.renew_group_attestation(
- destination, group_id, user_id,
- content={"attestation": attestation},
+ destination, group_id, user_id, content={"attestation": attestation}
)
yield self.store.update_attestation_renewal(
group_id, user_id, attestation
)
- except RequestSendFailed as e:
+ except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
- "Failed to renew attestation of %r in %r: %s",
- user_id, group_id, e,
+ "Failed to renew attestation of %r in %r: %s", user_id, group_id, e
)
except Exception:
- logger.exception("Error renewing attestation of %r in %r",
- user_id, group_id)
+ logger.exception(
+ "Error renewing attestation of %r in %r", user_id, group_id
+ )
for row in rows:
group_id = row["group_id"]
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 817be40360..168c9e3f84 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -54,8 +54,9 @@ class GroupsServerHandler(object):
hs.get_groups_attestation_renewer()
@defer.inlineCallbacks
- def check_group_is_ours(self, group_id, requester_user_id,
- and_exists=False, and_is_admin=None):
+ def check_group_is_ours(
+ self, group_id, requester_user_id, and_exists=False, and_is_admin=None
+ ):
"""Check that the group is ours, and optionally if it exists.
If group does exist then return group.
@@ -73,7 +74,9 @@ class GroupsServerHandler(object):
if and_exists and not group:
raise SynapseError(404, "Unknown group")
- is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+ is_user_in_group = yield self.store.is_user_in_group(
+ requester_user_id, group_id
+ )
if group and not is_user_in_group and not group["is_public"]:
raise SynapseError(404, "Unknown group")
@@ -96,25 +99,27 @@ class GroupsServerHandler(object):
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+ is_user_in_group = yield self.store.is_user_in_group(
+ requester_user_id, group_id
+ )
profile = yield self.get_group_profile(group_id, requester_user_id)
users, roles = yield self.store.get_users_for_summary_by_role(
- group_id, include_private=is_user_in_group,
+ group_id, include_private=is_user_in_group
)
# TODO: Add profiles to users
rooms, categories = yield self.store.get_rooms_for_summary_by_category(
- group_id, include_private=is_user_in_group,
+ group_id, include_private=is_user_in_group
)
for room_entry in rooms:
room_id = room_entry["room_id"]
joined_users = yield self.store.get_users_in_room(room_id)
entry = yield self.room_list_handler.generate_room_entry(
- room_id, len(joined_users), with_alias=False, allow_private=True,
+ room_id, len(joined_users), with_alias=False, allow_private=True
)
entry = dict(entry) # so we don't change whats cached
entry.pop("room_id", None)
@@ -134,7 +139,7 @@ class GroupsServerHandler(object):
entry["attestation"] = attestation
else:
entry["attestation"] = self.attestations.create_attestation(
- group_id, user_id,
+ group_id, user_id
)
user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
@@ -143,34 +148,34 @@ class GroupsServerHandler(object):
users.sort(key=lambda e: e.get("order", 0))
membership_info = yield self.store.get_users_membership_info_in_group(
- group_id, requester_user_id,
+ group_id, requester_user_id
)
- defer.returnValue({
- "profile": profile,
- "users_section": {
- "users": users,
- "roles": roles,
- "total_user_count_estimate": 0, # TODO
- },
- "rooms_section": {
- "rooms": rooms,
- "categories": categories,
- "total_room_count_estimate": 0, # TODO
- },
- "user": membership_info,
- })
+ defer.returnValue(
+ {
+ "profile": profile,
+ "users_section": {
+ "users": users,
+ "roles": roles,
+ "total_user_count_estimate": 0, # TODO
+ },
+ "rooms_section": {
+ "rooms": rooms,
+ "categories": categories,
+ "total_room_count_estimate": 0, # TODO
+ },
+ "user": membership_info,
+ }
+ )
@defer.inlineCallbacks
- def update_group_summary_room(self, group_id, requester_user_id,
- room_id, category_id, content):
+ def update_group_summary_room(
+ self, group_id, requester_user_id, room_id, category_id, content
+ ):
"""Add/update a room to the group summary
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
RoomID.from_string(room_id) # Ensure valid room id
@@ -190,21 +195,17 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
- def delete_group_summary_room(self, group_id, requester_user_id,
- room_id, category_id):
+ def delete_group_summary_room(
+ self, group_id, requester_user_id, room_id, category_id
+ ):
"""Remove a room from the summary
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_room_from_summary(
- group_id=group_id,
- room_id=room_id,
- category_id=category_id,
+ group_id=group_id, room_id=room_id, category_id=category_id
)
defer.returnValue({})
@@ -223,9 +224,7 @@ class GroupsServerHandler(object):
join_policy = _parse_join_policy_from_contents(content)
if join_policy is None:
- raise SynapseError(
- 400, "No value specified for 'm.join_policy'"
- )
+ raise SynapseError(400, "No value specified for 'm.join_policy'")
yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
@@ -237,9 +236,7 @@ class GroupsServerHandler(object):
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- categories = yield self.store.get_group_categories(
- group_id=group_id,
- )
+ categories = yield self.store.get_group_categories(group_id=group_id)
defer.returnValue({"categories": categories})
@defer.inlineCallbacks
@@ -249,8 +246,7 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = yield self.store.get_group_category(
- group_id=group_id,
- category_id=category_id,
+ group_id=group_id, category_id=category_id
)
defer.returnValue(res)
@@ -260,10 +256,7 @@ class GroupsServerHandler(object):
"""Add/Update a group category
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
is_public = _parse_visibility_from_contents(content)
@@ -283,15 +276,11 @@ class GroupsServerHandler(object):
"""Delete a group category
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_group_category(
- group_id=group_id,
- category_id=category_id,
+ group_id=group_id, category_id=category_id
)
defer.returnValue({})
@@ -302,9 +291,7 @@ class GroupsServerHandler(object):
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- roles = yield self.store.get_group_roles(
- group_id=group_id,
- )
+ roles = yield self.store.get_group_roles(group_id=group_id)
defer.returnValue({"roles": roles})
@defer.inlineCallbacks
@@ -313,10 +300,7 @@ class GroupsServerHandler(object):
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- res = yield self.store.get_group_role(
- group_id=group_id,
- role_id=role_id,
- )
+ res = yield self.store.get_group_role(group_id=group_id, role_id=role_id)
defer.returnValue(res)
@defer.inlineCallbacks
@@ -324,10 +308,7 @@ class GroupsServerHandler(object):
"""Add/update a role in a group
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
is_public = _parse_visibility_from_contents(content)
@@ -335,10 +316,7 @@ class GroupsServerHandler(object):
profile = content.get("profile")
yield self.store.upsert_group_role(
- group_id=group_id,
- role_id=role_id,
- is_public=is_public,
- profile=profile,
+ group_id=group_id, role_id=role_id, is_public=is_public, profile=profile
)
defer.returnValue({})
@@ -348,26 +326,21 @@ class GroupsServerHandler(object):
"""Remove role from group
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
- yield self.store.remove_group_role(
- group_id=group_id,
- role_id=role_id,
- )
+ yield self.store.remove_group_role(group_id=group_id, role_id=role_id)
defer.returnValue({})
@defer.inlineCallbacks
- def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
- content):
+ def update_group_summary_user(
+ self, group_id, requester_user_id, user_id, role_id, content
+ ):
"""Add/update a users entry in the group summary
"""
yield self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
order = content.get("order", None)
@@ -389,13 +362,11 @@ class GroupsServerHandler(object):
"""Remove a user from the group summary
"""
yield self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_user_from_summary(
- group_id=group_id,
- user_id=user_id,
- role_id=role_id,
+ group_id=group_id, user_id=user_id, role_id=role_id
)
defer.returnValue({})
@@ -411,8 +382,11 @@ class GroupsServerHandler(object):
if group:
cols = [
- "name", "short_description", "long_description",
- "avatar_url", "is_public",
+ "name",
+ "short_description",
+ "long_description",
+ "avatar_url",
+ "is_public",
]
group_description = {key: group[key] for key in cols}
group_description["is_openly_joinable"] = group["join_policy"] == "open"
@@ -426,12 +400,11 @@ class GroupsServerHandler(object):
"""Update the group profile
"""
yield self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
profile = {}
- for keyname in ("name", "avatar_url", "short_description",
- "long_description"):
+ for keyname in ("name", "avatar_url", "short_description", "long_description"):
if keyname in content:
value = content[keyname]
if not isinstance(value, string_types):
@@ -449,10 +422,12 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+ is_user_in_group = yield self.store.is_user_in_group(
+ requester_user_id, group_id
+ )
user_results = yield self.store.get_users_in_group(
- group_id, include_private=is_user_in_group,
+ group_id, include_private=is_user_in_group
)
chunk = []
@@ -470,24 +445,25 @@ class GroupsServerHandler(object):
entry["is_privileged"] = bool(is_privileged)
if not self.is_mine_id(g_user_id):
- attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
+ attestation = yield self.store.get_remote_attestation(
+ group_id, g_user_id
+ )
if not attestation:
continue
entry["attestation"] = attestation
else:
entry["attestation"] = self.attestations.create_attestation(
- group_id, g_user_id,
+ group_id, g_user_id
)
chunk.append(entry)
# TODO: If admin add lists of users whose attestations have timed out
- defer.returnValue({
- "chunk": chunk,
- "total_user_count_estimate": len(user_results),
- })
+ defer.returnValue(
+ {"chunk": chunk, "total_user_count_estimate": len(user_results)}
+ )
@defer.inlineCallbacks
def get_invited_users_in_group(self, group_id, requester_user_id):
@@ -498,7 +474,9 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+ is_user_in_group = yield self.store.is_user_in_group(
+ requester_user_id, group_id
+ )
if not is_user_in_group:
raise SynapseError(403, "User not in group")
@@ -508,9 +486,7 @@ class GroupsServerHandler(object):
user_profiles = []
for user_id in invited_users:
- user_profile = {
- "user_id": user_id
- }
+ user_profile = {"user_id": user_id}
try:
profile = yield self.profile_handler.get_profile_from_cache(user_id)
user_profile.update(profile)
@@ -518,10 +494,9 @@ class GroupsServerHandler(object):
logger.warn("Error getting profile for %s: %s", user_id, e)
user_profiles.append(user_profile)
- defer.returnValue({
- "chunk": user_profiles,
- "total_user_count_estimate": len(invited_users),
- })
+ defer.returnValue(
+ {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
+ )
@defer.inlineCallbacks
def get_rooms_in_group(self, group_id, requester_user_id):
@@ -532,10 +507,12 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+ is_user_in_group = yield self.store.is_user_in_group(
+ requester_user_id, group_id
+ )
room_results = yield self.store.get_rooms_in_group(
- group_id, include_private=is_user_in_group,
+ group_id, include_private=is_user_in_group
)
chunk = []
@@ -544,7 +521,7 @@ class GroupsServerHandler(object):
joined_users = yield self.store.get_users_in_room(room_id)
entry = yield self.room_list_handler.generate_room_entry(
- room_id, len(joined_users), with_alias=False, allow_private=True,
+ room_id, len(joined_users), with_alias=False, allow_private=True
)
if not entry:
@@ -556,10 +533,9 @@ class GroupsServerHandler(object):
chunk.sort(key=lambda e: -e["num_joined_members"])
- defer.returnValue({
- "chunk": chunk,
- "total_room_count_estimate": len(room_results),
- })
+ defer.returnValue(
+ {"chunk": chunk, "total_room_count_estimate": len(room_results)}
+ )
@defer.inlineCallbacks
def add_room_to_group(self, group_id, requester_user_id, room_id, content):
@@ -578,8 +554,9 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
- def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
- content):
+ def update_room_in_group(
+ self, group_id, requester_user_id, room_id, config_key, content
+ ):
"""Update room in group
"""
RoomID.from_string(room_id) # Ensure valid room id
@@ -592,8 +569,7 @@ class GroupsServerHandler(object):
is_public = _parse_visibility_dict(content)
yield self.store.update_room_in_group_visibility(
- group_id, room_id,
- is_public=is_public,
+ group_id, room_id, is_public=is_public
)
else:
raise SynapseError(400, "Uknown config option")
@@ -625,10 +601,7 @@ class GroupsServerHandler(object):
# TODO: Check if user is already invited
content = {
- "profile": {
- "name": group["name"],
- "avatar_url": group["avatar_url"],
- },
+ "profile": {"name": group["name"], "avatar_url": group["avatar_url"]},
"inviter": requester_user_id,
}
@@ -638,9 +611,7 @@ class GroupsServerHandler(object):
local_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
- content.update({
- "attestation": local_attestation,
- })
+ content.update({"attestation": local_attestation})
res = yield self.transport_client.invite_to_group_notification(
get_domain_from_id(user_id), group_id, user_id, content
@@ -658,31 +629,24 @@ class GroupsServerHandler(object):
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
- remote_attestation,
- user_id=user_id,
- group_id=group_id,
+ remote_attestation, user_id=user_id, group_id=group_id
)
else:
remote_attestation = None
yield self.store.add_user_to_group(
- group_id, user_id,
+ group_id,
+ user_id,
is_admin=False,
is_public=False, # TODO
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
elif res["state"] == "invite":
- yield self.store.add_group_invite(
- group_id, user_id,
- )
- defer.returnValue({
- "state": "invite"
- })
+ yield self.store.add_group_invite(group_id, user_id)
+ defer.returnValue({"state": "invite"})
elif res["state"] == "reject":
- defer.returnValue({
- "state": "reject"
- })
+ defer.returnValue({"state": "reject"})
else:
raise SynapseError(502, "Unknown state returned by HS")
@@ -693,16 +657,12 @@ class GroupsServerHandler(object):
See accept_invite, join_group.
"""
if not self.hs.is_mine_id(user_id):
- local_attestation = self.attestations.create_attestation(
- group_id, user_id,
- )
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
- remote_attestation,
- user_id=user_id,
- group_id=group_id,
+ remote_attestation, user_id=user_id, group_id=group_id
)
else:
local_attestation = None
@@ -711,7 +671,8 @@ class GroupsServerHandler(object):
is_public = _parse_visibility_from_contents(content)
yield self.store.add_user_to_group(
- group_id, user_id,
+ group_id,
+ user_id,
is_admin=False,
is_public=is_public,
local_attestation=local_attestation,
@@ -731,17 +692,14 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_invited = yield self.store.is_user_invited_to_local_group(
- group_id, requester_user_id,
+ group_id, requester_user_id
)
if not is_invited:
raise SynapseError(403, "User not invited to group")
local_attestation = yield self._add_user(group_id, requester_user_id, content)
- defer.returnValue({
- "state": "join",
- "attestation": local_attestation,
- })
+ defer.returnValue({"state": "join", "attestation": local_attestation})
@defer.inlineCallbacks
def join_group(self, group_id, requester_user_id, content):
@@ -753,15 +711,12 @@ class GroupsServerHandler(object):
group_info = yield self.check_group_is_ours(
group_id, requester_user_id, and_exists=True
)
- if group_info['join_policy'] != "open":
+ if group_info["join_policy"] != "open":
raise SynapseError(403, "Group is not publicly joinable")
local_attestation = yield self._add_user(group_id, requester_user_id, content)
- defer.returnValue({
- "state": "join",
- "attestation": local_attestation,
- })
+ defer.returnValue({"state": "join", "attestation": local_attestation})
@defer.inlineCallbacks
def knock(self, group_id, requester_user_id, content):
@@ -800,9 +755,7 @@ class GroupsServerHandler(object):
is_kick = True
- yield self.store.remove_user_from_group(
- group_id, user_id,
- )
+ yield self.store.remove_user_from_group(group_id, user_id)
if is_kick:
if self.hs.is_mine_id(user_id):
@@ -830,19 +783,20 @@ class GroupsServerHandler(object):
if group:
raise SynapseError(400, "Group already exists")
- is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
+ is_admin = yield self.auth.is_server_admin(
+ UserID.from_string(requester_user_id)
+ )
if not is_admin:
if not self.hs.config.enable_group_creation:
raise SynapseError(
- 403, "Only a server admin can create groups on this server",
+ 403, "Only a server admin can create groups on this server"
)
localpart = group_id_obj.localpart
if not localpart.startswith(self.hs.config.group_creation_prefix):
raise SynapseError(
400,
- "Can only create groups with prefix %r on this server" % (
- self.hs.config.group_creation_prefix,
- ),
+ "Can only create groups with prefix %r on this server"
+ % (self.hs.config.group_creation_prefix,),
)
profile = content.get("profile", {})
@@ -865,21 +819,19 @@ class GroupsServerHandler(object):
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
- remote_attestation,
- user_id=requester_user_id,
- group_id=group_id,
+ remote_attestation, user_id=requester_user_id, group_id=group_id
)
local_attestation = self.attestations.create_attestation(
- group_id,
- requester_user_id,
+ group_id, requester_user_id
)
else:
local_attestation = None
remote_attestation = None
yield self.store.add_user_to_group(
- group_id, requester_user_id,
+ group_id,
+ requester_user_id,
is_admin=True,
is_public=True, # TODO
local_attestation=local_attestation,
@@ -893,9 +845,7 @@ class GroupsServerHandler(object):
avatar_url=user_profile.get("avatar_url"),
)
- defer.returnValue({
- "group_id": group_id,
- })
+ defer.returnValue({"group_id": group_id})
@defer.inlineCallbacks
def delete_group(self, group_id, requester_user_id):
@@ -911,29 +861,22 @@ class GroupsServerHandler(object):
Deferred
"""
- yield self.check_group_is_ours(
- group_id, requester_user_id,
- and_exists=True,
- )
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
# Only server admins or group admins can delete groups.
- is_admin = yield self.store.is_user_admin_in_group(
- group_id, requester_user_id
- )
+ is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id)
if not is_admin:
is_admin = yield self.auth.is_server_admin(
- UserID.from_string(requester_user_id),
+ UserID.from_string(requester_user_id)
)
if not is_admin:
raise SynapseError(403, "User is not an admin")
# Before deleting the group lets kick everyone out of it
- users = yield self.store.get_users_in_group(
- group_id, include_private=True,
- )
+ users = yield self.store.get_users_in_group(group_id, include_private=True)
@defer.inlineCallbacks
def _kick_user_from_group(user_id):
@@ -989,9 +932,7 @@ def _parse_join_policy_dict(join_policy_dict):
return "invite"
if join_policy_type not in ("invite", "open"):
- raise SynapseError(
- 400, "Synapse only supports 'invite'/'open' join rule"
- )
+ raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule")
return join_policy_type
@@ -1018,7 +959,5 @@ def _parse_visibility_dict(visibility):
return True
if vis_type not in ("public", "private"):
- raise SynapseError(
- 400, "Synapse only supports 'public'/'private' visibility"
- )
+ raise SynapseError(400, "Synapse only supports 'public'/'private' visibility")
return vis_type == "public"
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index dca337ec61..c29c78bd65 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -94,14 +94,15 @@ class BaseHandler(object):
burst_count = self.hs.config.rc_message.burst_count
allowed, time_allowed = self.ratelimiter.can_do_action(
- user_id, time_now,
+ user_id,
+ time_now,
rate_hz=messages_per_second,
burst_count=burst_count,
update=update,
)
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now)),
+ retry_after_ms=int(1000 * (time_allowed - time_now))
)
@defer.inlineCallbacks
@@ -139,7 +140,7 @@ class BaseHandler(object):
if member_event.content["membership"] not in {
Membership.JOIN,
- Membership.INVITE
+ Membership.INVITE,
}:
continue
@@ -156,8 +157,7 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
- requester = synapse.types.create_requester(
- target_user, is_guest=True)
+ requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler()
yield handler.update_membership(
requester,
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 7fa5d44d29..e62e6cab77 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -20,7 +20,7 @@ class AccountDataEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- def get_current_key(self, direction='f'):
+ def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()
@defer.inlineCallbacks
@@ -34,29 +34,22 @@ class AccountDataEventSource(object):
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items():
- results.append({
- "type": "m.tag",
- "content": {"tags": room_tags},
- "room_id": room_id,
- })
+ results.append(
+ {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
+ )
account_data, room_account_data = (
yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
)
for account_data_type, content in account_data.items():
- results.append({
- "type": account_data_type,
- "content": content,
- })
+ results.append({"type": account_data_type, "content": content})
for room_id, account_data in room_account_data.items():
for account_data_type, content in account_data.items():
- results.append({
- "type": account_data_type,
- "content": content,
- "room_id": room_id,
- })
+ results.append(
+ {"type": account_data_type, "content": content, "room_id": room_id}
+ )
defer.returnValue((results, current_stream_id))
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 261446517d..0719da3ab7 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -49,12 +49,10 @@ class AccountValidityHandler(object):
app_name = self.hs.config.email_app_name
self._subject = self._account_validity.renew_email_subject % {
- "app": app_name,
+ "app": app_name
}
- self._from_string = self.hs.config.email_notif_from % {
- "app": app_name,
- }
+ self._from_string = self.hs.config.email_notif_from % {"app": app_name}
except Exception:
# If substitution failed, fall back to the bare strings.
self._subject = self._account_validity.renew_email_subject
@@ -69,10 +67,7 @@ class AccountValidityHandler(object):
)
# Check the renewal emails to send and send them every 30min.
- self.clock.looping_call(
- self.send_renewal_emails,
- 30 * 60 * 1000,
- )
+ self.clock.looping_call(self.send_renewal_emails, 30 * 60 * 1000)
@defer.inlineCallbacks
def send_renewal_emails(self):
@@ -86,8 +81,7 @@ class AccountValidityHandler(object):
if expiring_users:
for user in expiring_users:
yield self._send_renewal_email(
- user_id=user["user_id"],
- expiration_ts=user["expiration_ts_ms"],
+ user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)
@defer.inlineCallbacks
@@ -110,6 +104,9 @@ class AccountValidityHandler(object):
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
# account manually.
+ # We don't need to do a specific check to make sure the account isn't
+ # deactivated, as a deactivated account isn't supposed to have any
+ # email address attached to it.
if not addresses:
return
@@ -143,32 +140,33 @@ class AccountValidityHandler(object):
for address in addresses:
raw_to = email.utils.parseaddr(address)[1]
- multipart_msg = MIMEMultipart('alternative')
- multipart_msg['Subject'] = self._subject
- multipart_msg['From'] = self._from_string
- multipart_msg['To'] = address
- multipart_msg['Date'] = email.utils.formatdate()
- multipart_msg['Message-ID'] = email.utils.make_msgid()
+ multipart_msg = MIMEMultipart("alternative")
+ multipart_msg["Subject"] = self._subject
+ multipart_msg["From"] = self._from_string
+ multipart_msg["To"] = address
+ multipart_msg["Date"] = email.utils.formatdate()
+ multipart_msg["Message-ID"] = email.utils.make_msgid()
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)
logger.info("Sending renewal email to %s", address)
- yield make_deferred_yieldable(self.sendmail(
- self.hs.config.email_smtp_host,
- self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
- reactor=self.hs.get_reactor(),
- port=self.hs.config.email_smtp_port,
- requireAuthentication=self.hs.config.email_smtp_user is not None,
- username=self.hs.config.email_smtp_user,
- password=self.hs.config.email_smtp_pass,
- requireTransportSecurity=self.hs.config.require_transport_security
- ))
-
- yield self.store.set_renewal_mail_status(
- user_id=user_id,
- email_sent=True,
- )
+ yield make_deferred_yieldable(
+ self.sendmail(
+ self.hs.config.email_smtp_host,
+ self._raw_from,
+ raw_to,
+ multipart_msg.as_string().encode("utf8"),
+ reactor=self.hs.get_reactor(),
+ port=self.hs.config.email_smtp_port,
+ requireAuthentication=self.hs.config.email_smtp_user is not None,
+ username=self.hs.config.email_smtp_user,
+ password=self.hs.config.email_smtp_pass,
+ requireTransportSecurity=self.hs.config.require_transport_security,
+ )
+ )
+
+ yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
@defer.inlineCallbacks
def _get_email_addresses_for_user(self, user_id):
@@ -245,9 +243,7 @@ class AccountValidityHandler(object):
expiration_ts = self.clock.time_msec() + self._account_validity.period
yield self.store.set_account_validity_for_user(
- user_id=user_id,
- expiration_ts=expiration_ts,
- email_sent=email_sent,
+ user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
defer.returnValue(expiration_ts)
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 813777bf18..fbef2f3d38 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -15,14 +15,9 @@
import logging
-import attr
-from zope.interface import implementer
-
import twisted
import twisted.internet.error
from twisted.internet import defer
-from twisted.python.filepath import FilePath
-from twisted.python.url import URL
from twisted.web import server, static
from twisted.web.resource import Resource
@@ -30,27 +25,6 @@ from synapse.app import check_bind_error
logger = logging.getLogger(__name__)
-try:
- from txacme.interfaces import ICertificateStore
-
- @attr.s
- @implementer(ICertificateStore)
- class ErsatzStore(object):
- """
- A store that only stores in memory.
- """
-
- certs = attr.ib(default=attr.Factory(dict))
-
- def store(self, server_name, pem_objects):
- self.certs[server_name] = [o.as_bytes() for o in pem_objects]
- return defer.succeed(None)
-
-
-except ImportError:
- # txacme is missing
- pass
-
class AcmeHandler(object):
def __init__(self, hs):
@@ -60,6 +34,7 @@ class AcmeHandler(object):
@defer.inlineCallbacks
def start_listening(self):
+ from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
# from eliot import add_destinations
@@ -67,50 +42,27 @@ class AcmeHandler(object):
#
# add_destinations(TwistedDestination())
- from txacme.challenges import HTTP01Responder
- from txacme.service import AcmeIssuingService
- from txacme.endpoint import load_or_create_client_key
- from txacme.client import Client
- from josepy.jwa import RS256
-
- self._store = ErsatzStore()
- responder = HTTP01Responder()
-
- self._issuer = AcmeIssuingService(
- cert_store=self._store,
- client_creator=(
- lambda: Client.from_url(
- reactor=self.reactor,
- url=URL.from_text(self.hs.config.acme_url),
- key=load_or_create_client_key(
- FilePath(self.hs.config.config_dir_path)
- ),
- alg=RS256,
- )
- ),
- clock=self.reactor,
- responders=[responder],
+ well_known = Resource()
+
+ self._issuer = acme_issuing_service.create_issuing_service(
+ self.reactor,
+ acme_url=self.hs.config.acme_url,
+ account_key_file=self.hs.config.acme_account_key_file,
+ well_known_resource=well_known,
)
- well_known = Resource()
- well_known.putChild(b'acme-challenge', responder.resource)
responder_resource = Resource()
- responder_resource.putChild(b'.well-known', well_known)
- responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
-
+ responder_resource.putChild(b".well-known", well_known)
+ responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain"))
srv = server.Site(responder_resource)
bind_addresses = self.hs.config.acme_bind_addresses
for host in bind_addresses:
logger.info(
- "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port,
+ "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
)
try:
- self.reactor.listenTCP(
- self.hs.config.acme_port,
- srv,
- interface=host,
- )
+ self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)
@@ -132,7 +84,7 @@ class AcmeHandler(object):
logger.exception("Fail!")
raise
logger.warning("Reprovisioned %s, saving.", self._acme_domain)
- cert_chain = self._store.certs[self._acme_domain]
+ cert_chain = self._issuer.cert_store.certs[self._acme_domain]
try:
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
new file mode 100644
index 0000000000..e1d4224e74
--- /dev/null
+++ b/synapse/handlers/acme_issuing_service.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utility function to create an ACME issuing service.
+
+This file contains the unconditional imports on the acme and cryptography bits that we
+only need (and may only have available) if we are doing ACME, so is designed to be
+imported conditionally.
+"""
+import logging
+
+import attr
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from josepy import JWKRSA
+from josepy.jwa import RS256
+from txacme.challenges import HTTP01Responder
+from txacme.client import Client
+from txacme.interfaces import ICertificateStore
+from txacme.service import AcmeIssuingService
+from txacme.util import generate_private_key
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.python.filepath import FilePath
+from twisted.python.url import URL
+
+logger = logging.getLogger(__name__)
+
+
+def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+ """Create an ACME issuing service, and attach it to a web Resource
+
+ Args:
+ reactor: twisted reactor
+ acme_url (str): URL to use to request certificates
+ account_key_file (str): where to store the account key
+ well_known_resource (twisted.web.IResource): web resource for .well-known.
+ we will attach a child resource for "acme-challenge".
+
+ Returns:
+ AcmeIssuingService
+ """
+ responder = HTTP01Responder()
+
+ well_known_resource.putChild(b"acme-challenge", responder.resource)
+
+ store = ErsatzStore()
+
+ return AcmeIssuingService(
+ cert_store=store,
+ client_creator=(
+ lambda: Client.from_url(
+ reactor=reactor,
+ url=URL.from_text(acme_url),
+ key=load_or_create_client_key(account_key_file),
+ alg=RS256,
+ )
+ ),
+ clock=reactor,
+ responders=[responder],
+ )
+
+
+@attr.s
+@implementer(ICertificateStore)
+class ErsatzStore(object):
+ """
+ A store that only stores in memory.
+ """
+
+ certs = attr.ib(default=attr.Factory(dict))
+
+ def store(self, server_name, pem_objects):
+ self.certs[server_name] = [o.as_bytes() for o in pem_objects]
+ return defer.succeed(None)
+
+
+def load_or_create_client_key(key_file):
+ """Load the ACME account key from a file, creating it if it does not exist.
+
+ Args:
+ key_file (str): name of the file to use as the account key
+ """
+ # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
+ # hardcode the 'client.key' filename
+ acme_key_file = FilePath(key_file)
+ if acme_key_file.exists():
+ logger.info("Loading ACME account key from '%s'", acme_key_file)
+ key = serialization.load_pem_private_key(
+ acme_key_file.getContent(), password=None, backend=default_backend()
+ )
+ else:
+ logger.info("Saving new ACME account key to '%s'", acme_key_file)
+ key = generate_private_key("rsa")
+ acme_key_file.setContent(
+ key.private_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PrivateFormat.TraditionalOpenSSL,
+ encryption_algorithm=serialization.NoEncryption(),
+ )
+ )
+ return JWKRSA(key=key)
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 5d629126fc..941ebfa107 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -23,7 +23,6 @@ logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler):
-
def __init__(self, hs):
super(AdminHandler, self).__init__(hs)
@@ -33,23 +32,17 @@ class AdminHandler(BaseHandler):
sessions = yield self.store.get_user_ip_and_agents(user)
for session in sessions:
- connections.append({
- "ip": session["ip"],
- "last_seen": session["last_seen"],
- "user_agent": session["user_agent"],
- })
+ connections.append(
+ {
+ "ip": session["ip"],
+ "last_seen": session["last_seen"],
+ "user_agent": session["user_agent"],
+ }
+ )
ret = {
"user_id": user.to_string(),
- "devices": {
- "": {
- "sessions": [
- {
- "connections": connections,
- }
- ]
- },
- },
+ "devices": {"": {"sessions": [{"connections": connections}]}},
}
defer.returnValue(ret)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 17eedf4dbf..5cc89d43f6 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -38,7 +38,6 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed
class ApplicationServicesHandler(object):
-
def __init__(self, hs):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
@@ -101,9 +100,10 @@ class ApplicationServicesHandler(object):
yield self._check_user_exists(event.state_key)
if not self.started_scheduler:
+
def start_scheduler():
return self.scheduler.start().addErrback(
- log_failure, "Application Services Failure",
+ log_failure, "Application Services Failure"
)
run_as_background_process("as_scheduler", start_scheduler)
@@ -118,10 +118,15 @@ class ApplicationServicesHandler(object):
for event in events:
yield handle_event(event)
- yield make_deferred_yieldable(defer.gatherResults([
- run_in_background(handle_room_events, evs)
- for evs in itervalues(events_by_room)
- ], consumeErrors=True))
+ yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(handle_room_events, evs)
+ for evs in itervalues(events_by_room)
+ ],
+ consumeErrors=True,
+ )
+ )
yield self.store.set_appservice_last_pos(upper_bound)
@@ -129,20 +134,23 @@ class ApplicationServicesHandler(object):
ts = yield self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_positions.labels(
- "appservice_sender").set(upper_bound)
+ "appservice_sender"
+ ).set(upper_bound)
events_processed_counter.inc(len(events))
- event_processing_loop_room_count.labels(
- "appservice_sender"
- ).inc(len(events_by_room))
+ event_processing_loop_room_count.labels("appservice_sender").inc(
+ len(events_by_room)
+ )
event_processing_loop_counter.labels("appservice_sender").inc()
synapse.metrics.event_processing_lag.labels(
- "appservice_sender").set(now - ts)
+ "appservice_sender"
+ ).set(now - ts)
synapse.metrics.event_processing_last_ts.labels(
- "appservice_sender").set(ts)
+ "appservice_sender"
+ ).set(ts)
finally:
self.is_processing = False
@@ -155,13 +163,9 @@ class ApplicationServicesHandler(object):
Returns:
True if this user exists on at least one application service.
"""
- user_query_services = yield self._get_services_for_user(
- user_id=user_id
- )
+ user_query_services = yield self._get_services_for_user(user_id=user_id)
for user_service in user_query_services:
- is_known_user = yield self.appservice_api.query_user(
- user_service, user_id
- )
+ is_known_user = yield self.appservice_api.query_user(user_service, user_id)
if is_known_user:
defer.returnValue(True)
defer.returnValue(False)
@@ -179,9 +183,7 @@ class ApplicationServicesHandler(object):
room_alias_str = room_alias.to_string()
services = self.store.get_app_services()
alias_query_services = [
- s for s in services if (
- s.is_interested_in_alias(room_alias_str)
- )
+ s for s in services if (s.is_interested_in_alias(room_alias_str))
]
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
@@ -189,22 +191,24 @@ class ApplicationServicesHandler(object):
)
if is_known_alias:
# the alias exists now so don't query more ASes.
- result = yield self.store.get_association_from_room_alias(
- room_alias
- )
+ result = yield self.store.get_association_from_room_alias(room_alias)
defer.returnValue(result)
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
- results = yield make_deferred_yieldable(defer.DeferredList([
- run_in_background(
- self.appservice_api.query_3pe,
- service, kind, protocol, fields,
+ results = yield make_deferred_yieldable(
+ defer.DeferredList(
+ [
+ run_in_background(
+ self.appservice_api.query_3pe, service, kind, protocol, fields
+ )
+ for service in services
+ ],
+ consumeErrors=True,
)
- for service in services
- ], consumeErrors=True))
+ )
ret = []
for (success, result) in results:
@@ -276,18 +280,12 @@ class ApplicationServicesHandler(object):
def _get_services_for_user(self, user_id):
services = self.store.get_app_services()
- interested_list = [
- s for s in services if (
- s.is_interested_in_user(user_id)
- )
- ]
+ interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
return defer.succeed(interested_list)
def _get_services_for_3pn(self, protocol):
services = self.store.get_app_services()
- interested_list = [
- s for s in services if s.is_interested_in_protocol(protocol)
- ]
+ interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
return defer.succeed(interested_list)
@defer.inlineCallbacks
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a0cf37a9f9..97b21c4093 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -134,13 +134,9 @@ class AuthHandler(BaseHandler):
"""
# build a list of supported flows
- flows = [
- [login_type] for login_type in self._supported_login_types
- ]
+ flows = [[login_type] for login_type in self._supported_login_types]
- result, params, _ = yield self.check_auth(
- flows, request_body, clientip,
- )
+ result, params, _ = yield self.check_auth(flows, request_body, clientip)
# find the completed login type
for login_type in self._supported_login_types:
@@ -151,9 +147,7 @@ class AuthHandler(BaseHandler):
break
else:
# this can't happen
- raise Exception(
- "check_auth returned True but no successful login type",
- )
+ raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
if user_id != requester.user.to_string():
@@ -215,11 +209,11 @@ class AuthHandler(BaseHandler):
authdict = None
sid = None
- if clientdict and 'auth' in clientdict:
- authdict = clientdict['auth']
- del clientdict['auth']
- if 'session' in authdict:
- sid = authdict['session']
+ if clientdict and "auth" in clientdict:
+ authdict = clientdict["auth"]
+ del clientdict["auth"]
+ if "session" in authdict:
+ sid = authdict["session"]
session = self._get_session_info(sid)
if len(clientdict) > 0:
@@ -232,27 +226,27 @@ class AuthHandler(BaseHandler):
# on a home server.
# Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary.
- session['clientdict'] = clientdict
+ session["clientdict"] = clientdict
self._save_session(session)
- elif 'clientdict' in session:
- clientdict = session['clientdict']
+ elif "clientdict" in session:
+ clientdict = session["clientdict"]
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session),
+ self._auth_dict_for_flows(flows, session)
)
- if 'creds' not in session:
- session['creds'] = {}
- creds = session['creds']
+ if "creds" not in session:
+ session["creds"] = {}
+ creds = session["creds"]
# check auth type currently being presented
errordict = {}
- if 'type' in authdict:
- login_type = authdict['type']
+ if "type" in authdict:
+ login_type = authdict["type"]
try:
result = yield self._check_auth_dict(
- authdict, clientip, password_servlet=password_servlet,
+ authdict, clientip, password_servlet=password_servlet
)
if result:
creds[login_type] = result
@@ -281,16 +275,15 @@ class AuthHandler(BaseHandler):
# and is not sensitive).
logger.info(
"Auth completed with creds: %r. Client dict has keys: %r",
- creds, list(clientdict)
+ creds,
+ list(clientdict),
)
- defer.returnValue((creds, clientdict, session['id']))
+ defer.returnValue((creds, clientdict, session["id"]))
ret = self._auth_dict_for_flows(flows, session)
- ret['completed'] = list(creds)
+ ret["completed"] = list(creds)
ret.update(errordict)
- raise InteractiveAuthIncompleteError(
- ret,
- )
+ raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
@@ -300,15 +293,13 @@ class AuthHandler(BaseHandler):
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
- if 'session' not in authdict:
+ if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
- sess = self._get_session_info(
- authdict['session']
- )
- if 'creds' not in sess:
- sess['creds'] = {}
- creds = sess['creds']
+ sess = self._get_session_info(authdict["session"])
+ if "creds" not in sess:
+ sess["creds"] = {}
+ creds = sess["creds"]
result = yield self.checkers[stagetype](authdict, clientip)
if result:
@@ -329,10 +320,10 @@ class AuthHandler(BaseHandler):
not send a session ID, returns None.
"""
sid = None
- if clientdict and 'auth' in clientdict:
- authdict = clientdict['auth']
- if 'session' in authdict:
- sid = authdict['session']
+ if clientdict and "auth" in clientdict:
+ authdict = clientdict["auth"]
+ if "session" in authdict:
+ sid = authdict["session"]
return sid
def set_session_data(self, session_id, key, value):
@@ -347,7 +338,7 @@ class AuthHandler(BaseHandler):
value (any): The data to store
"""
sess = self._get_session_info(session_id)
- sess.setdefault('serverdict', {})[key] = value
+ sess.setdefault("serverdict", {})[key] = value
self._save_session(sess)
def get_session_data(self, session_id, key, default=None):
@@ -360,7 +351,7 @@ class AuthHandler(BaseHandler):
default (any): Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
- return sess.setdefault('serverdict', {}).get(key, default)
+ return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks
def _check_auth_dict(self, authdict, clientip, password_servlet=False):
@@ -378,15 +369,13 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
- login_type = authdict['type']
+ login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
# XXX: Temporary workaround for having Synapse handle password resets
# See AuthHandler.check_auth for further details
res = yield checker(
- authdict,
- clientip=clientip,
- password_servlet=password_servlet,
+ authdict, clientip=clientip, password_servlet=password_servlet
)
defer.returnValue(res)
@@ -408,13 +397,11 @@ class AuthHandler(BaseHandler):
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
- 400, "Captcha response is required",
- errcode=Codes.CAPTCHA_NEEDED
+ 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
- "Submitting recaptcha response %s with remoteip %s",
- user_response, clientip
+ "Submitting recaptcha response %s with remoteip %s", user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
@@ -424,34 +411,34 @@ class AuthHandler(BaseHandler):
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
- 'secret': self.hs.config.recaptcha_private_key,
- 'response': user_response,
- 'remoteip': clientip,
- }
+ "secret": self.hs.config.recaptcha_private_key,
+ "response": user_response,
+ "remoteip": clientip,
+ },
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json.loads(data)
- if 'success' in resp_body:
+ if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
# intend the CAPTCHA to be presented by whatever client the
# user is using, we just care that they have completed a CAPTCHA.
logger.info(
"%s reCAPTCHA from hostname %s",
- "Successful" if resp_body['success'] else "Failed",
- resp_body.get('hostname')
+ "Successful" if resp_body["success"] else "Failed",
+ resp_body.get("hostname"),
)
- if resp_body['success']:
+ if resp_body["success"]:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
def _check_email_identity(self, authdict, **kwargs):
- return self._check_threepid('email', authdict, **kwargs)
+ return self._check_threepid("email", authdict, **kwargs)
def _check_msisdn(self, authdict, **kwargs):
- return self._check_threepid('msisdn', authdict)
+ return self._check_threepid("msisdn", authdict)
def _check_dummy_auth(self, authdict, **kwargs):
return defer.succeed(True)
@@ -461,10 +448,10 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs):
- if 'threepid_creds' not in authdict:
+ if "threepid_creds" not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
- threepid_creds = authdict['threepid_creds']
+ threepid_creds = authdict["threepid_creds"]
identity_handler = self.hs.get_handlers().identity_handler
@@ -482,31 +469,36 @@ class AuthHandler(BaseHandler):
validated=True,
)
- threepid = {
- "medium": row["medium"],
- "address": row["address"],
- "validated_at": row["validated_at"],
- } if row else None
+ threepid = (
+ {
+ "medium": row["medium"],
+ "address": row["address"],
+ "validated_at": row["validated_at"],
+ }
+ if row
+ else None
+ )
if row:
# Valid threepid returned, delete from the db
yield self.store.delete_threepid_session(threepid_creds["sid"])
else:
- raise SynapseError(400, "Password resets are not enabled on this homeserver")
+ raise SynapseError(
+ 400, "Password resets are not enabled on this homeserver"
+ )
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
- if threepid['medium'] != medium:
+ if threepid["medium"] != medium:
raise LoginError(
401,
- "Expecting threepid of type '%s', got '%s'" % (
- medium, threepid['medium'],
- ),
- errcode=Codes.UNAUTHORIZED
+ "Expecting threepid of type '%s', got '%s'"
+ % (medium, threepid["medium"]),
+ errcode=Codes.UNAUTHORIZED,
)
- threepid['threepid_creds'] = authdict['threepid_creds']
+ threepid["threepid_creds"] = authdict["threepid_creds"]
defer.returnValue(threepid)
@@ -520,13 +512,14 @@ class AuthHandler(BaseHandler):
"version": self.hs.config.user_consent_version,
"en": {
"name": self.hs.config.user_consent_policy_name,
- "url": "%s_matrix/consent?v=%s" % (
+ "url": "%s_matrix/consent?v=%s"
+ % (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
},
- },
- },
+ }
+ }
}
def _auth_dict_for_flows(self, flows, session):
@@ -547,9 +540,9 @@ class AuthHandler(BaseHandler):
params[stage] = get_params[stage]()
return {
- "session": session['id'],
+ "session": session["id"],
"flows": [{"stages": f} for f in public_flows],
- "params": params
+ "params": params,
}
def _get_session_info(self, session_id):
@@ -560,9 +553,7 @@ class AuthHandler(BaseHandler):
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
- self.sessions[session_id] = {
- "id": session_id,
- }
+ self.sessions[session_id] = {"id": session_id}
return self.sessions[session_id]
@@ -652,7 +643,8 @@ class AuthHandler(BaseHandler):
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
- user_id, user_infos.keys()
+ user_id,
+ user_infos.keys(),
)
defer.returnValue(result)
@@ -690,12 +682,10 @@ class AuthHandler(BaseHandler):
user is too high too proceed.
"""
- if username.startswith('@'):
+ if username.startswith("@"):
qualified_user_id = username
else:
- qualified_user_id = UserID(
- username, self.hs.hostname
- ).to_string()
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
self.ratelimit_login_per_account(qualified_user_id)
@@ -713,17 +703,15 @@ class AuthHandler(BaseHandler):
raise SynapseError(400, "Missing parameter: password")
for provider in self.password_providers:
- if (hasattr(provider, "check_password")
- and login_type == LoginType.PASSWORD):
+ if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
- is_valid = yield provider.check_password(
- qualified_user_id, password,
- )
+ is_valid = yield provider.check_password(qualified_user_id, password)
if is_valid:
defer.returnValue((qualified_user_id, None))
- if (not hasattr(provider, "get_supported_login_types")
- or not hasattr(provider, "check_auth")):
+ if not hasattr(provider, "get_supported_login_types") or not hasattr(
+ provider, "check_auth"
+ ):
# this password provider doesn't understand custom login types
continue
@@ -744,15 +732,12 @@ class AuthHandler(BaseHandler):
login_dict[f] = login_submission[f]
if missing_fields:
raise SynapseError(
- 400, "Missing parameters for login type %s: %s" % (
- login_type,
- missing_fields,
- ),
+ 400,
+ "Missing parameters for login type %s: %s"
+ % (login_type, missing_fields),
)
- result = yield provider.check_auth(
- username, login_type, login_dict,
- )
+ result = yield provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
@@ -762,7 +747,7 @@ class AuthHandler(BaseHandler):
known_login_type = True
canonical_user_id = yield self._check_local_password(
- qualified_user_id, password,
+ qualified_user_id, password
)
if canonical_user_id:
@@ -773,7 +758,8 @@ class AuthHandler(BaseHandler):
# unknown username or invalid password.
self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(), time_now_s=self._clock.time(),
+ qualified_user_id.lower(),
+ time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=True,
@@ -781,10 +767,7 @@ class AuthHandler(BaseHandler):
# We raise a 403 here, but note that if we're doing user-interactive
# login, it turns all LoginErrors into a 401 anyway.
- raise LoginError(
- 403, "Invalid password",
- errcode=Codes.FORBIDDEN
- )
+ raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def check_password_provider_3pid(self, medium, address, password):
@@ -810,9 +793,7 @@ class AuthHandler(BaseHandler):
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
- result = yield provider.check_3pid_auth(
- medium, address, password,
- )
+ result = yield provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
@@ -853,8 +834,7 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def issue_access_token(self, user_id, device_id=None):
access_token = self.macaroon_gen.generate_access_token(user_id)
- yield self.store.add_access_token_to_user(user_id, access_token,
- device_id)
+ yield self.store.add_access_token_to_user(user_id, access_token, device_id)
defer.returnValue(access_token)
@defer.inlineCallbacks
@@ -896,12 +876,13 @@ class AuthHandler(BaseHandler):
# delete pushers associated with this access token
if user_info["token_id"] is not None:
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
- str(user_info["user"]), (user_info["token_id"], )
+ str(user_info["user"]), (user_info["token_id"],)
)
@defer.inlineCallbacks
- def delete_access_tokens_for_user(self, user_id, except_token_id=None,
- device_id=None):
+ def delete_access_tokens_for_user(
+ self, user_id, except_token_id=None, device_id=None
+ ):
"""Invalidate access tokens belonging to a user
Args:
@@ -915,7 +896,7 @@ class AuthHandler(BaseHandler):
Deferred
"""
tokens_and_devices = yield self.store.user_delete_access_tokens(
- user_id, except_token_id=except_token_id, device_id=device_id,
+ user_id, except_token_id=except_token_id, device_id=device_id
)
# see if any of our auth providers want to know about this
@@ -923,14 +904,12 @@ class AuthHandler(BaseHandler):
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
yield provider.on_logged_out(
- user_id=user_id,
- device_id=device_id,
- access_token=token,
+ user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
- user_id, (token_id for _, token_id, _ in tokens_and_devices),
+ user_id, (token_id for _, token_id, _ in tokens_and_devices)
)
@defer.inlineCallbacks
@@ -944,12 +923,11 @@ class AuthHandler(BaseHandler):
# of specific types of threepid (and fixes the fact that checking
# for the presence of an email address during password reset was
# case sensitive).
- if medium == 'email':
+ if medium == "email":
address = address.lower()
yield self.store.user_add_threepid(
- user_id, medium, address, validated_at,
- self.hs.get_clock().time_msec()
+ user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
@defer.inlineCallbacks
@@ -973,22 +951,15 @@ class AuthHandler(BaseHandler):
"""
# 'Canonicalise' email addresses as per above
- if medium == 'email':
+ if medium == "email":
address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler
result = yield identity_handler.try_unbind_threepid(
- user_id,
- {
- 'medium': medium,
- 'address': address,
- 'id_server': id_server,
- },
+ user_id, {"medium": medium, "address": address, "id_server": id_server}
)
- yield self.store.user_delete_threepid(
- user_id, medium, address,
- )
+ yield self.store.user_delete_threepid(user_id, medium, address)
defer.returnValue(result)
def _save_session(self, session):
@@ -1006,14 +977,15 @@ class AuthHandler(BaseHandler):
Returns:
Deferred(unicode): Hashed password.
"""
+
def _do_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.hashpw(
- pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
+ pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
- ).decode('ascii')
+ ).decode("ascii")
return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash)
@@ -1027,18 +999,19 @@ class AuthHandler(BaseHandler):
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
"""
+
def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
- pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
- stored_hash
+ pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
+ stored_hash,
)
if stored_hash:
if not isinstance(stored_hash, bytes):
- stored_hash = stored_hash.encode('ascii')
+ stored_hash = stored_hash.encode("ascii")
return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else:
@@ -1058,14 +1031,16 @@ class AuthHandler(BaseHandler):
for this user is too high too proceed.
"""
self._failed_attempts_ratelimiter.ratelimit(
- user_id.lower(), time_now_s=self._clock.time(),
+ user_id.lower(),
+ time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
update=False,
)
self._account_ratelimiter.ratelimit(
- user_id.lower(), time_now_s=self._clock.time(),
+ user_id.lower(),
+ time_now_s=self._clock.time(),
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
update=True,
@@ -1083,9 +1058,9 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
- macaroon.add_first_party_caveat("nonce = %s" % (
- stringutils.random_string_with_symbols(16),
- ))
+ macaroon.add_first_party_caveat(
+ "nonce = %s" % (stringutils.random_string_with_symbols(16),)
+ )
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
@@ -1116,7 +1091,8 @@ class MacaroonGenerator(object):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
- key=self.hs.config.macaroon_secret_key)
+ key=self.hs.config.macaroon_secret_key,
+ )
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 6a91f7698e..e8f9da6098 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -27,6 +28,7 @@ logger = logging.getLogger(__name__)
class DeactivateAccountHandler(BaseHandler):
"""Handler which deals with deactivating user accounts."""
+
def __init__(self, hs):
super(DeactivateAccountHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
@@ -42,6 +44,8 @@ class DeactivateAccountHandler(BaseHandler):
# it left off (if it has work left to do).
hs.get_reactor().callWhenRunning(self._start_user_parting)
+ self._account_validity_enabled = hs.config.account_validity.enabled
+
@defer.inlineCallbacks
def deactivate_account(self, user_id, erase_data, id_server=None):
"""Deactivate a user's account
@@ -75,9 +79,9 @@ class DeactivateAccountHandler(BaseHandler):
result = yield self._identity_handler.try_unbind_threepid(
user_id,
{
- 'medium': threepid['medium'],
- 'address': threepid['address'],
- 'id_server': id_server,
+ "medium": threepid["medium"],
+ "address": threepid["address"],
+ "id_server": id_server,
},
)
identity_server_supports_unbinding &= result
@@ -86,7 +90,7 @@ class DeactivateAccountHandler(BaseHandler):
logger.exception("Failed to remove threepid from ID server")
raise SynapseError(400, "Failed to remove threepid from ID server")
yield self.store.user_delete_threepid(
- user_id, threepid['medium'], threepid['address'],
+ user_id, threepid["medium"], threepid["address"]
)
# delete any devices belonging to the user, which will also
@@ -114,6 +118,13 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running)
self._start_user_parting()
+ # Remove all information on the user from the account_validity table.
+ if self._account_validity_enabled:
+ yield self.store.delete_account_validity_for_user(user_id)
+
+ # Mark the user as deactivated.
+ yield self.store.set_user_deactivated_status(user_id, True)
+
defer.returnValue(identity_server_supports_unbinding)
def _start_user_parting(self):
@@ -173,5 +184,6 @@ class DeactivateAccountHandler(BaseHandler):
except Exception:
logger.exception(
"Failed to part user %r from room %r: ignoring and continuing",
- user_id, room_id,
+ user_id,
+ room_id,
)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b398848079..f59d0479b5 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -58,9 +58,7 @@ class DeviceWorkerHandler(BaseHandler):
device_map = yield self.store.get_devices_by_user(user_id)
- ips = yield self.store.get_last_client_ip_by_device(
- user_id, device_id=None
- )
+ ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None)
devices = list(device_map.values())
for device in devices:
@@ -85,9 +83,7 @@ class DeviceWorkerHandler(BaseHandler):
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
- ips = yield self.store.get_last_client_ip_by_device(
- user_id, device_id,
- )
+ ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
defer.returnValue(device)
@@ -114,13 +110,11 @@ class DeviceWorkerHandler(BaseHandler):
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user(
- user_id, from_token.room_key, now_room_key,
+ user_id, from_token.room_key, now_room_key
)
rooms_changed.update(event.room_id for event in member_events)
- stream_ordering = RoomStreamToken.parse_stream_token(
- from_token.room_key
- ).stream
+ stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
possibly_changed = set(changed)
possibly_left = set()
@@ -206,10 +200,9 @@ class DeviceWorkerHandler(BaseHandler):
possibly_joined = []
possibly_left = []
- defer.returnValue({
- "changed": list(possibly_joined),
- "left": list(possibly_left),
- })
+ defer.returnValue(
+ {"changed": list(possibly_joined), "left": list(possibly_left)}
+ )
class DeviceHandler(DeviceWorkerHandler):
@@ -223,17 +216,18 @@ class DeviceHandler(DeviceWorkerHandler):
federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
- "m.device_list_update", self._edu_updater.incoming_device_list_update,
+ "m.device_list_update", self._edu_updater.incoming_device_list_update
)
federation_registry.register_query_handler(
- "user_devices", self.on_federation_query_user_devices,
+ "user_devices", self.on_federation_query_user_devices
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks
- def check_device_registered(self, user_id, device_id,
- initial_device_display_name=None):
+ def check_device_registered(
+ self, user_id, device_id, initial_device_display_name=None
+ ):
"""
If the given device has not been registered, register it with the
supplied display name.
@@ -297,12 +291,10 @@ class DeviceHandler(DeviceWorkerHandler):
raise
yield self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id,
+ user_id, device_id=device_id
)
- yield self.store.delete_e2e_keys_by_device(
- user_id=user_id, device_id=device_id
- )
+ yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
yield self.notify_device_update(user_id, [device_id])
@@ -349,7 +341,7 @@ class DeviceHandler(DeviceWorkerHandler):
# considered as part of a critical path.
for device_id in device_ids:
yield self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id,
+ user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
@@ -372,9 +364,7 @@ class DeviceHandler(DeviceWorkerHandler):
try:
yield self.store.update_device(
- user_id,
- device_id,
- new_display_name=content.get("display_name")
+ user_id, device_id, new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError as e:
@@ -404,29 +394,26 @@ class DeviceHandler(DeviceWorkerHandler):
for device_id in device_ids:
logger.debug(
- "Notifying about update %r/%r, ID: %r", user_id, device_id,
- position,
+ "Notifying about update %r/%r, ID: %r", user_id, device_id, position
)
room_ids = yield self.store.get_rooms_for_user(user_id)
- yield self.notifier.on_new_event(
- "device_list_key", position, rooms=room_ids,
- )
+ yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)
if hosts:
- logger.info("Sending device list update notif for %r to: %r", user_id, hosts)
+ logger.info(
+ "Sending device list update notif for %r to: %r", user_id, hosts
+ )
for host in hosts:
self.federation_sender.send_device_messages(host)
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- defer.returnValue({
- "user_id": user_id,
- "stream_id": stream_id,
- "devices": devices,
- })
+ defer.returnValue(
+ {"user_id": user_id, "stream_id": stream_id, "devices": devices}
+ )
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@@ -440,10 +427,7 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
- device.update({
- "last_seen_ts": ip.get("last_seen"),
- "last_seen_ip": ip.get("ip"),
- })
+ device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
class DeviceListEduUpdater(object):
@@ -481,13 +465,15 @@ class DeviceListEduUpdater(object):
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
prev_ids = edu_content.pop("prev_id", [])
- prev_ids = [str(p) for p in prev_ids] # They may come as ints
+ prev_ids = [str(p) for p in prev_ids] # They may come as ints
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning(
"Got device list update edu for %r/%r from %r",
- user_id, device_id, origin,
+ user_id,
+ device_id,
+ origin,
)
return
@@ -497,13 +483,12 @@ class DeviceListEduUpdater(object):
# probably won't get any further updates.
logger.warning(
"Got device list update edu for %r/%r, but don't share a room",
- user_id, device_id,
+ user_id,
+ device_id,
)
return
- logger.debug(
- "Received device list update for %r/%r", user_id, device_id,
- )
+ logger.debug("Received device list update for %r/%r", user_id, device_id)
self._pending_updates.setdefault(user_id, []).append(
(device_id, stream_id, prev_ids, edu_content)
@@ -525,7 +510,10 @@ class DeviceListEduUpdater(object):
for device_id, stream_id, prev_ids, content in pending_updates:
logger.debug(
"Handling update %r/%r, ID: %r, prev: %r ",
- user_id, device_id, stream_id, prev_ids,
+ user_id,
+ device_id,
+ stream_id,
+ prev_ids,
)
# Given a list of updates we check if we need to resync. This
@@ -540,13 +528,13 @@ class DeviceListEduUpdater(object):
try:
result = yield self.federation.query_user_devices(origin, user_id)
except (
- NotRetryingDestination, RequestSendFailed, HttpResponseException,
+ NotRetryingDestination,
+ RequestSendFailed,
+ HttpResponseException,
):
# TODO: Remember that we are now out of sync and try again
# later
- logger.warn(
- "Failed to handle device list update for %s", user_id,
- )
+ logger.warn("Failed to handle device list update for %s", user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
# is out of date. If we bail then we will retry the resync
@@ -582,18 +570,21 @@ class DeviceListEduUpdater(object):
if len(devices) > 1000:
logger.warn(
"Ignoring device list snapshot for %s as it has >1K devs (%d)",
- user_id, len(devices)
+ user_id,
+ len(devices),
)
devices = []
for device in devices:
logger.debug(
"Handling resync update %r/%r, ID: %r",
- user_id, device["device_id"], stream_id,
+ user_id,
+ device["device_id"],
+ stream_id,
)
yield self.store.update_remote_device_list_cache(
- user_id, devices, stream_id,
+ user_id, devices, stream_id
)
device_ids = [device["device_id"] for device in devices]
yield self.device_handler.notify_device_update(user_id, device_ids)
@@ -606,7 +597,7 @@ class DeviceListEduUpdater(object):
# change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
- user_id, device_id, content, stream_id,
+ user_id, device_id, content, stream_id
)
yield self.device_handler.notify_device_update(
@@ -624,14 +615,9 @@ class DeviceListEduUpdater(object):
"""
seen_updates = self._seen_updates.get(user_id, set())
- extremity = yield self.store.get_device_list_last_stream_id_for_remote(
- user_id
- )
+ extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)
- logger.debug(
- "Current extremity for %r: %r",
- user_id, extremity,
- )
+ logger.debug("Current extremity for %r: %r", user_id, extremity)
stream_id_in_updates = set() # stream_ids in updates list
for _, stream_id, prev_ids, _ in updates:
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 2e2e5261de..e1ebb6346c 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -25,7 +25,6 @@ logger = logging.getLogger(__name__)
class DeviceMessageHandler(object):
-
def __init__(self, hs):
"""
Args:
@@ -47,15 +46,15 @@ class DeviceMessageHandler(object):
if origin != get_domain_from_id(sender_user_id):
logger.warn(
"Dropping device message from %r with spoofed sender %r",
- origin, sender_user_id
+ origin,
+ sender_user_id,
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
- logger.warning("Request for keys for non-local user %s",
- user_id)
+ logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
messages_by_device = {
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a12f9508d8..42d5b3db30 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -36,7 +36,6 @@ logger = logging.getLogger(__name__)
class DirectoryHandler(BaseHandler):
-
def __init__(self, hs):
super(DirectoryHandler, self).__init__(hs)
@@ -77,15 +76,19 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(400, "Failed to get server list")
yield self.store.create_room_alias_association(
- room_alias,
- room_id,
- servers,
- creator=creator,
+ room_alias, room_id, servers, creator=creator
)
@defer.inlineCallbacks
- def create_association(self, requester, room_alias, room_id, servers=None,
- send_event=True, check_membership=True):
+ def create_association(
+ self,
+ requester,
+ room_alias,
+ room_id,
+ servers=None,
+ send_event=True,
+ check_membership=True,
+ ):
"""Attempt to create a new alias
Args:
@@ -115,49 +118,40 @@ class DirectoryHandler(BaseHandler):
if service:
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
- 400, "This application service has not reserved"
- " this kind of alias.", errcode=Codes.EXCLUSIVE
+ 400,
+ "This application service has not reserved" " this kind of alias.",
+ errcode=Codes.EXCLUSIVE,
)
else:
if self.require_membership and check_membership:
rooms_for_user = yield self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user:
raise AuthError(
- 403,
- "You must be in the room to create an alias for it",
+ 403, "You must be in the room to create an alias for it"
)
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
- raise AuthError(
- 403, "This user is not permitted to create this alias",
- )
+ raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
- user_id, room_id, room_alias.to_string(),
+ user_id, room_id, room_alias.to_string()
):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
- raise SynapseError(
- 403, "Not allowed to create alias",
- )
+ raise SynapseError(403, "Not allowed to create alias")
- can_create = yield self.can_modify_alias(
- room_alias,
- user_id=user_id
- )
+ can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
- 400, "This alias is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "This alias is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE,
)
yield self._create_association(room_alias, room_id, servers, creator=user_id)
if send_event:
- yield self.send_room_alias_update_event(
- requester,
- room_id
- )
+ yield self.send_room_alias_update_event(requester, room_id)
@defer.inlineCallbacks
def delete_association(self, requester, room_alias, send_event=True):
@@ -194,34 +188,24 @@ class DirectoryHandler(BaseHandler):
raise
if not can_delete:
- raise AuthError(
- 403, "You don't have permission to delete the alias.",
- )
+ raise AuthError(403, "You don't have permission to delete the alias.")
- can_delete = yield self.can_modify_alias(
- room_alias,
- user_id=user_id
- )
+ can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
- 400, "This alias is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "This alias is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE,
)
room_id = yield self._delete_association(room_alias)
try:
if send_event:
- yield self.send_room_alias_update_event(
- requester,
- room_id
- )
+ yield self.send_room_alias_update_event(requester, room_id)
yield self._update_canonical_alias(
- requester,
- requester.user.to_string(),
- room_id,
- room_alias,
+ requester, requester.user.to_string(), room_id, room_alias
)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@@ -234,7 +218,7 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(
400,
"This application service has not reserved this kind of alias",
- errcode=Codes.EXCLUSIVE
+ errcode=Codes.EXCLUSIVE,
)
yield self._delete_association(room_alias)
@@ -251,9 +235,7 @@ class DirectoryHandler(BaseHandler):
def get_association(self, room_alias):
room_id = None
if self.hs.is_mine(room_alias):
- result = yield self.get_association_from_room_alias(
- room_alias
- )
+ result = yield self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
@@ -263,9 +245,7 @@ class DirectoryHandler(BaseHandler):
result = yield self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
- args={
- "room_alias": room_alias.to_string(),
- },
+ args={"room_alias": room_alias.to_string()},
retry_on_dns_fail=False,
ignore_backoff=True,
)
@@ -284,7 +264,7 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(
404,
"Room alias %s not found" % (room_alias.to_string(),),
- Codes.NOT_FOUND
+ Codes.NOT_FOUND,
)
users = yield self.state.get_current_users_in_room(room_id)
@@ -293,41 +273,28 @@ class DirectoryHandler(BaseHandler):
# If this server is in the list of servers, return it first.
if self.server_name in servers:
- servers = (
- [self.server_name] +
- [s for s in servers if s != self.server_name]
- )
+ servers = [self.server_name] + [s for s in servers if s != self.server_name]
else:
servers = list(servers)
- defer.returnValue({
- "room_id": room_id,
- "servers": servers,
- })
+ defer.returnValue({"room_id": room_id, "servers": servers})
return
@defer.inlineCallbacks
def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias):
- raise SynapseError(
- 400, "Room Alias is not hosted on this Home Server"
- )
+ raise SynapseError(400, "Room Alias is not hosted on this Home Server")
- result = yield self.get_association_from_room_alias(
- room_alias
- )
+ result = yield self.get_association_from_room_alias(room_alias)
if result is not None:
- defer.returnValue({
- "room_id": result.room_id,
- "servers": result.servers,
- })
+ defer.returnValue({"room_id": result.room_id, "servers": result.servers})
else:
raise SynapseError(
404,
"Room alias %r not found" % (room_alias.to_string(),),
- Codes.NOT_FOUND
+ Codes.NOT_FOUND,
)
@defer.inlineCallbacks
@@ -343,7 +310,7 @@ class DirectoryHandler(BaseHandler):
"sender": requester.user.to_string(),
"content": {"aliases": aliases},
},
- ratelimit=False
+ ratelimit=False,
)
@defer.inlineCallbacks
@@ -365,14 +332,12 @@ class DirectoryHandler(BaseHandler):
"sender": user_id,
"content": {},
},
- ratelimit=False
+ ratelimit=False,
)
@defer.inlineCallbacks
def get_association_from_room_alias(self, room_alias):
- result = yield self.store.get_association_from_room_alias(
- room_alias
- )
+ result = yield self.store.get_association_from_room_alias(room_alias)
if not result:
# Query AS to see if it exists
as_handler = self.appservice_handler
@@ -421,8 +386,7 @@ class DirectoryHandler(BaseHandler):
if not self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
- 403,
- "This user is not permitted to publish rooms to the room list"
+ 403, "This user is not permitted to publish rooms to the room list"
)
if requester.is_guest:
@@ -434,8 +398,7 @@ class DirectoryHandler(BaseHandler):
if visibility == "public" and not self.enable_room_list_search:
# The room list has been disabled.
raise AuthError(
- 403,
- "This user is not permitted to publish rooms to the room list"
+ 403, "This user is not permitted to publish rooms to the room list"
)
room = yield self.store.get_room(room_id)
@@ -452,20 +415,19 @@ class DirectoryHandler(BaseHandler):
room_aliases.append(canonical_alias)
if not self.config.is_publishing_room_allowed(
- user_id, room_id, room_aliases,
+ user_id, room_id, room_aliases
):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
- raise SynapseError(
- 403, "Not allowed to publish room",
- )
+ raise SynapseError(403, "Not allowed to publish room")
yield self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks
- def edit_published_appservice_room_list(self, appservice_id, network_id,
- room_id, visibility):
+ def edit_published_appservice_room_list(
+ self, appservice_id, network_id, room_id, visibility
+ ):
"""Add or remove a room from the appservice/network specific public
room list.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 9dc46aa15f..807900fe52 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -99,9 +99,7 @@ class E2eKeysHandler(object):
query_list.append((user_id, None))
user_ids_not_in_cache, remote_results = (
- yield self.store.get_user_devices_from_cache(
- query_list
- )
+ yield self.store.get_user_devices_from_cache(query_list)
)
for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
@@ -126,9 +124,7 @@ class E2eKeysHandler(object):
destination_query = remote_queries_not_in_cache[destination]
try:
remote_result = yield self.federation.query_client_keys(
- destination,
- {"device_keys": destination_query},
- timeout=timeout
+ destination, {"device_keys": destination_query}, timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
@@ -138,14 +134,17 @@ class E2eKeysHandler(object):
except Exception as e:
failures[destination] = _exception_to_failure(e)
- yield make_deferred_yieldable(defer.gatherResults([
- run_in_background(do_remote_query, destination)
- for destination in remote_queries_not_in_cache
- ], consumeErrors=True))
+ yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(do_remote_query, destination)
+ for destination in remote_queries_not_in_cache
+ ],
+ consumeErrors=True,
+ )
+ )
- defer.returnValue({
- "device_keys": results, "failures": failures,
- })
+ defer.returnValue({"device_keys": results, "failures": failures})
@defer.inlineCallbacks
def query_local_devices(self, query):
@@ -165,8 +164,7 @@ class E2eKeysHandler(object):
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
- logger.warning("Request for keys for non-local user %s",
- user_id)
+ logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
if not device_ids:
@@ -231,9 +229,7 @@ class E2eKeysHandler(object):
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
- destination,
- {"one_time_keys": device_keys},
- timeout=timeout
+ destination, {"one_time_keys": device_keys}, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
@@ -241,25 +237,29 @@ class E2eKeysHandler(object):
except Exception as e:
failures[destination] = _exception_to_failure(e)
- yield make_deferred_yieldable(defer.gatherResults([
- run_in_background(claim_client_keys, destination)
- for destination in remote_queries
- ], consumeErrors=True))
+ yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(claim_client_keys, destination)
+ for destination in remote_queries
+ ],
+ consumeErrors=True,
+ )
+ )
logger.info(
"Claimed one-time-keys: %s",
- ",".join((
- "%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in iteritems(json_result)
- for device_id, device_keys in iteritems(user_keys)
- for key_id, _ in iteritems(device_keys)
- )),
+ ",".join(
+ (
+ "%s for %s:%s" % (key_id, user_id, device_id)
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
+ )
+ ),
)
- defer.returnValue({
- "one_time_keys": json_result,
- "failures": failures
- })
+ defer.returnValue({"one_time_keys": json_result, "failures": failures})
@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
@@ -270,11 +270,13 @@ class E2eKeysHandler(object):
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
- device_id, user_id, time_now
+ device_id,
+ user_id,
+ time_now,
)
# TODO: Sign the JSON with the server key
changed = yield self.store.set_e2e_device_keys(
- user_id, device_id, time_now, device_keys,
+ user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
@@ -283,7 +285,7 @@ class E2eKeysHandler(object):
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
yield self._upload_one_time_keys_for_user(
- user_id, device_id, time_now, one_time_keys,
+ user_id, device_id, time_now, one_time_keys
)
# the device should have been registered already, but it may have been
@@ -298,20 +300,22 @@ class E2eKeysHandler(object):
defer.returnValue({"one_time_key_counts": result})
@defer.inlineCallbacks
- def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
- one_time_keys):
+ def _upload_one_time_keys_for_user(
+ self, user_id, device_id, time_now, one_time_keys
+ ):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
- one_time_keys.keys(), device_id, user_id, time_now,
+ one_time_keys.keys(),
+ device_id,
+ user_id,
+ time_now,
)
# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
- key_list.append((
- algorithm, key_id, key_obj
- ))
+ key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys.
existing_key_map = yield self.store.get_e2e_one_time_keys(
@@ -325,42 +329,35 @@ class E2eKeysHandler(object):
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
- ("One time key %s:%s already exists. "
- "Old key: %s; new key: %r") %
- (algorithm, key_id, ex_json, key)
+ (
+ "One time key %s:%s already exists. "
+ "Old key: %s; new key: %r"
+ )
+ % (algorithm, key_id, ex_json, key),
)
else:
- new_keys.append((
- algorithm, key_id, encode_canonical_json(key).decode('ascii')))
+ new_keys.append(
+ (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
+ )
- yield self.store.add_e2e_one_time_keys(
- user_id, device_id, time_now, new_keys
- )
+ yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
def _exception_to_failure(e):
if isinstance(e, CodeMessageException):
- return {
- "status": e.code, "message": str(e),
- }
+ return {"status": e.code, "message": str(e)}
if isinstance(e, NotRetryingDestination):
- return {
- "status": 503, "message": "Not ready for retry",
- }
+ return {"status": 503, "message": "Not ready for retry"}
if isinstance(e, FederationDeniedError):
- return {
- "status": 403, "message": "Federation Denied",
- }
+ return {"status": 403, "message": "Federation Denied"}
# include ConnectionRefused and other errors
#
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which json then fails to serialize.
- return {
- "status": 503, "message": str(e),
- }
+ return {"status": 503, "message": str(e)}
def _one_time_keys_match(old_key_json, new_key):
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 7bc174070e..ebd807bca6 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -152,14 +152,14 @@ class E2eRoomKeysHandler(object):
else:
raise
- if version_info['version'] != version:
+ if version_info["version"] != version:
# Check that the version we're trying to upload actually exists
try:
version_info = yield self.store.get_e2e_room_keys_version_info(
- user_id, version,
+ user_id, version
)
# if we get this far, the version must exist
- raise RoomKeysVersionError(current_version=version_info['version'])
+ raise RoomKeysVersionError(current_version=version_info["version"])
except StoreError as e:
if e.code == 404:
raise NotFoundError("Version '%s' not found" % (version,))
@@ -168,8 +168,8 @@ class E2eRoomKeysHandler(object):
# go through the room_keys.
# XXX: this should/could be done concurrently, given we're in a lock.
- for room_id, room in iteritems(room_keys['rooms']):
- for session_id, session in iteritems(room['sessions']):
+ for room_id, room in iteritems(room_keys["rooms"]):
+ for session_id, session in iteritems(room["sessions"]):
yield self._upload_room_key(
user_id, version, room_id, session_id, session
)
@@ -223,14 +223,14 @@ class E2eRoomKeysHandler(object):
# spelt out with if/elifs rather than nested boolean expressions
# purely for legibility.
- if room_key['is_verified'] and not current_room_key['is_verified']:
+ if room_key["is_verified"] and not current_room_key["is_verified"]:
return True
elif (
- room_key['first_message_index'] <
- current_room_key['first_message_index']
+ room_key["first_message_index"]
+ < current_room_key["first_message_index"]
):
return True
- elif room_key['forwarded_count'] < current_room_key['forwarded_count']:
+ elif room_key["forwarded_count"] < current_room_key["forwarded_count"]:
return True
else:
return False
@@ -328,16 +328,10 @@ class E2eRoomKeysHandler(object):
A deferred of an empty dict.
"""
if "version" not in version_info:
- raise SynapseError(
- 400,
- "Missing version in body",
- Codes.MISSING_PARAM
- )
+ raise SynapseError(400, "Missing version in body", Codes.MISSING_PARAM)
if version_info["version"] != version:
raise SynapseError(
- 400,
- "Version in body does not match",
- Codes.INVALID_PARAM
+ 400, "Version in body does not match", Codes.INVALID_PARAM
)
with (yield self._upload_linearizer.queue(user_id)):
try:
@@ -350,12 +344,10 @@ class E2eRoomKeysHandler(object):
else:
raise
if old_info["algorithm"] != version_info["algorithm"]:
- raise SynapseError(
- 400,
- "Algorithm does not match",
- Codes.INVALID_PARAM
- )
+ raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)
- yield self.store.update_e2e_room_keys_version(user_id, version, version_info)
+ yield self.store.update_e2e_room_keys_version(
+ user_id, version, version_info
+ )
defer.returnValue({})
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index eb525070cf..5836d3c639 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -31,7 +31,6 @@ logger = logging.getLogger(__name__)
class EventStreamHandler(BaseHandler):
-
def __init__(self, hs):
super(EventStreamHandler, self).__init__(hs)
@@ -53,9 +52,17 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
- def get_stream(self, auth_user_id, pagin_config, timeout=0,
- as_client_event=True, affect_presence=True,
- only_keys=None, room_id=None, is_guest=False):
+ def get_stream(
+ self,
+ auth_user_id,
+ pagin_config,
+ timeout=0,
+ as_client_event=True,
+ affect_presence=True,
+ only_keys=None,
+ room_id=None,
+ is_guest=False,
+ ):
"""Fetches the events stream for a given user.
If `only_keys` is not None, events from keys will be sent down.
@@ -73,7 +80,7 @@ class EventStreamHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
context = yield presence_handler.user_syncing(
- auth_user_id, affect_presence=affect_presence,
+ auth_user_id, affect_presence=affect_presence
)
with context:
if timeout:
@@ -85,9 +92,12 @@ class EventStreamHandler(BaseHandler):
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
events, tokens = yield self.notifier.get_events_for(
- auth_user, pagin_config, timeout,
+ auth_user,
+ pagin_config,
+ timeout,
only_keys=only_keys,
- is_guest=is_guest, explicit_room_id=room_id
+ is_guest=is_guest,
+ explicit_room_id=room_id,
)
# When the user joins a new room, or another user joins a currently
@@ -102,17 +112,15 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
- users = yield self.state.get_current_users_in_room(event.room_id)
- states = yield presence_handler.get_states(
- users,
- as_event=True,
+ users = yield self.state.get_current_users_in_room(
+ event.room_id
)
+ states = yield presence_handler.get_states(users, as_event=True)
to_add.extend(states)
else:
ev = yield presence_handler.get_state(
- UserID.from_string(event.state_key),
- as_event=True,
+ UserID.from_string(event.state_key), as_event=True
)
to_add.append(ev)
@@ -121,7 +129,9 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec()
chunks = yield self._event_serializer.serialize_events(
- events, time_now, as_client_event=as_client_event,
+ events,
+ time_now,
+ as_client_event=as_client_event,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_aggregations=False,
@@ -137,7 +147,6 @@ class EventStreamHandler(BaseHandler):
class EventHandler(BaseHandler):
-
@defer.inlineCallbacks
def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event.
@@ -164,16 +173,10 @@ class EventHandler(BaseHandler):
is_peeking = user.to_string() not in users
filtered = yield filter_events_for_client(
- self.store,
- user.to_string(),
- [event],
- is_peeking=is_peeking
+ self.store, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:
- raise AuthError(
- 403,
- "You don't have permission to access that event."
- )
+ raise AuthError(403, "You don't have permission to access that event.")
defer.returnValue(event)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ac5ca79143..02d397c498 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,6 +34,7 @@ from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.api.errors import (
AuthError,
CodeMessageException,
+ Codes,
FederationDeniedError,
FederationError,
RequestSendFailed,
@@ -80,7 +82,7 @@ def shortstr(iterable, maxitems=5):
items = list(itertools.islice(iterable, maxitems + 1))
if len(items) <= maxitems:
return str(items)
- return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]"
+ return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
class FederationHandler(BaseHandler):
@@ -113,24 +115,24 @@ class FederationHandler(BaseHandler):
self.config = hs.config
self.http_client = hs.get_simple_http_client()
- self._send_events_to_master = (
- ReplicationFederationSendEventsRestServlet.make_client(hs)
+ self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
+ hs
)
- self._notify_user_membership_change = (
- ReplicationUserJoinedLeftRoomRestServlet.make_client(hs)
+ self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
+ hs
)
- self._clean_room_for_join_client = (
- ReplicationCleanRoomRestServlet.make_client(hs)
+ self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
+ hs
)
# When joining a room we need to queue any events for that room up
self.room_queues = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
+ self.third_party_event_rules = hs.get_third_party_event_rules()
+
@defer.inlineCallbacks
- def on_receive_pdu(
- self, origin, pdu, sent_to_us_directly=False,
- ):
+ def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -147,26 +149,19 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info(
- "[%s %s] handling received PDU: %s",
- room_id, event_id, pdu,
- )
+ logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
# We reprocess pdus when we have seen them only as outliers
existing = yield self.store.get_event(
- event_id,
- allow_none=True,
- allow_rejected=True,
+ event_id, allow_none=True, allow_rejected=True
)
# FIXME: Currently we fetch an event again when we already have it
# if it has been marked as an outlier.
- already_seen = (
- existing and (
- not existing.internal_metadata.is_outlier()
- or pdu.internal_metadata.is_outlier()
- )
+ already_seen = existing and (
+ not existing.internal_metadata.is_outlier()
+ or pdu.internal_metadata.is_outlier()
)
if already_seen:
logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
@@ -178,20 +173,19 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
- logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id)
- raise FederationError(
- "ERROR",
- err.code,
- err.msg,
- affected=pdu.event_id,
+ logger.warn(
+ "[%s %s] Received event failed sanity checks", room_id, event_id
)
+ raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if room_id in self.room_queues:
logger.info(
"[%s %s] Queuing PDU from %s for now: join in progress",
- room_id, event_id, origin,
+ room_id,
+ event_id,
+ origin,
)
self.room_queues[room_id].append((pdu, origin))
return
@@ -202,14 +196,13 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = yield self.auth.check_host_in_room(
- room_id,
- self.server_name
- )
+ is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
- room_id, event_id, origin,
+ room_id,
+ event_id,
+ origin,
)
defer.returnValue(None)
@@ -219,14 +212,9 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
- min_depth = yield self.get_min_depth_for_context(
- pdu.room_id
- )
+ min_depth = yield self.get_min_depth_for_context(pdu.room_id)
- logger.debug(
- "[%s %s] min_depth: %d",
- room_id, event_id, min_depth,
- )
+ logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
seen = yield self.store.have_seen_events(prevs)
@@ -244,12 +232,17 @@ class FederationHandler(BaseHandler):
# at a time.
logger.info(
"[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
- room_id, event_id, len(missing_prevs), shortstr(missing_prevs),
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
)
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
- room_id, event_id, len(missing_prevs),
+ room_id,
+ event_id,
+ len(missing_prevs),
)
yield self._get_missing_events_for_pdu(
@@ -263,12 +256,16 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
logger.info(
"[%s %s] Found all missing prev_events",
- room_id, event_id,
+ room_id,
+ event_id,
)
elif missing_prevs:
logger.info(
"[%s %s] Not recursively fetching %d missing prev_events: %s",
- room_id, event_id, len(missing_prevs), shortstr(missing_prevs),
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
)
if prevs - seen:
@@ -299,7 +296,10 @@ class FederationHandler(BaseHandler):
if sent_to_us_directly:
logger.warn(
"[%s %s] Rejecting: failed to fetch %d prev events: %s",
- room_id, event_id, len(prevs - seen), shortstr(prevs - seen)
+ room_id,
+ event_id,
+ len(prevs - seen),
+ shortstr(prevs - seen),
)
raise FederationError(
"ERROR",
@@ -314,9 +314,7 @@ class FederationHandler(BaseHandler):
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
auth_chains = set()
- event_map = {
- event_id: pdu,
- }
+ event_map = {event_id: pdu}
try:
# Get the state of the events we know about
ours = yield self.store.get_state_groups_ids(room_id, seen)
@@ -333,7 +331,9 @@ class FederationHandler(BaseHandler):
for p in prevs - seen:
logger.info(
"[%s %s] Requesting state at missing prev_event %s",
- room_id, event_id, p,
+ room_id,
+ event_id,
+ p,
)
room_version = yield self.store.get_room_version(room_id)
@@ -344,19 +344,19 @@ class FederationHandler(BaseHandler):
# by the get_pdu_cache in federation_client.
remote_state, got_auth_chain = (
yield self.federation_client.get_state_for_room(
- origin, room_id, p,
+ origin, room_id, p
)
)
# we want the state *after* p; get_state_for_room returns the
# state *before* p.
remote_event = yield self.federation_client.get_pdu(
- [origin], p, room_version, outlier=True,
+ [origin], p, room_version, outlier=True
)
if remote_event is None:
raise Exception(
- "Unable to get missing prev_event %s" % (p, )
+ "Unable to get missing prev_event %s" % (p,)
)
if remote_event.is_state():
@@ -376,7 +376,9 @@ class FederationHandler(BaseHandler):
event_map[x.event_id] = x
state_map = yield resolve_events_with_store(
- room_version, state_maps, event_map,
+ room_version,
+ state_maps,
+ event_map,
state_res_store=StateResolutionStore(self.store),
)
@@ -392,15 +394,15 @@ class FederationHandler(BaseHandler):
)
event_map.update(evs)
- state = [
- event_map[e] for e in six.itervalues(state_map)
- ]
+ state = [event_map[e] for e in six.itervalues(state_map)]
auth_chain = list(auth_chains)
except Exception:
logger.warn(
"[%s %s] Error attempting to resolve state at missing "
"prev_events",
- room_id, event_id, exc_info=True,
+ room_id,
+ event_id,
+ exc_info=True,
)
raise FederationError(
"ERROR",
@@ -410,10 +412,7 @@ class FederationHandler(BaseHandler):
)
yield self._process_received_pdu(
- origin,
- pdu,
- state=state,
- auth_chain=auth_chain,
+ origin, pdu, state=state, auth_chain=auth_chain
)
@defer.inlineCallbacks
@@ -443,7 +442,10 @@ class FederationHandler(BaseHandler):
logger.info(
"[%s %s]: Requesting missing events between %s and %s",
- room_id, event_id, shortstr(latest), event_id,
+ room_id,
+ event_id,
+ shortstr(latest),
+ event_id,
)
# XXX: we set timeout to 10s to help workaround
@@ -494,19 +496,29 @@ class FederationHandler(BaseHandler):
#
# All that said: Let's try increasing the timout to 60s and see what happens.
- missing_events = yield self.federation_client.get_missing_events(
- origin,
- room_id,
- earliest_events_ids=list(latest),
- latest_events=[pdu],
- limit=10,
- min_depth=min_depth,
- timeout=60000,
- )
+ try:
+ missing_events = yield self.federation_client.get_missing_events(
+ origin,
+ room_id,
+ earliest_events_ids=list(latest),
+ latest_events=[pdu],
+ limit=10,
+ min_depth=min_depth,
+ timeout=60000,
+ )
+ except RequestSendFailed as e:
+ # We failed to get the missing events, but since we need to handle
+ # the case of `get_missing_events` not returning the necessary
+ # events anyway, it is safe to simply log the error and continue.
+ logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e)
+ return
logger.info(
"[%s %s]: Got %d prev_events: %s",
- room_id, event_id, len(missing_events), shortstr(missing_events),
+ room_id,
+ event_id,
+ len(missing_events),
+ shortstr(missing_events),
)
# We want to sort these by depth so we process them and
@@ -516,20 +528,20 @@ class FederationHandler(BaseHandler):
for ev in missing_events:
logger.info(
"[%s %s] Handling received prev_event %s",
- room_id, event_id, ev.event_id,
+ room_id,
+ event_id,
+ ev.event_id,
)
with logcontext.nested_logging_context(ev.event_id):
try:
- yield self.on_receive_pdu(
- origin,
- ev,
- sent_to_us_directly=False,
- )
+ yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warn(
"[%s %s] Received prev_event %s failed history check.",
- room_id, event_id, ev.event_id,
+ room_id,
+ event_id,
+ ev.event_id,
)
else:
raise
@@ -542,10 +554,7 @@ class FederationHandler(BaseHandler):
room_id = event.room_id
event_id = event.event_id
- logger.debug(
- "[%s %s] Processing event: %s",
- room_id, event_id, event,
- )
+ logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
event_ids = set()
if state:
@@ -567,43 +576,32 @@ class FederationHandler(BaseHandler):
e.internal_metadata.outlier = True
auth_ids = e.auth_event_ids()
auth = {
- (e.type, e.state_key): e for e in auth_chain
+ (e.type, e.state_key): e
+ for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
- event_infos.append({
- "event": e,
- "auth_events": auth,
- })
+ event_infos.append({"event": e, "auth_events": auth})
seen_ids.add(e.event_id)
logger.info(
"[%s %s] persisting newly-received auth/state events %s",
- room_id, event_id, [e["event"].event_id for e in event_infos]
+ room_id,
+ event_id,
+ [e["event"].event_id for e in event_infos],
)
yield self._handle_new_events(origin, event_infos)
try:
- context = yield self._handle_new_event(
- origin,
- event,
- state=state,
- )
+ context = yield self._handle_new_event(origin, event, state=state)
except AuthError as e:
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=event.event_id,
- )
+ raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
room = yield self.store.get_room(room_id)
if not room:
try:
yield self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False,
+ room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:
logger.exception("Failed to store room.")
@@ -617,12 +615,10 @@ class FederationHandler(BaseHandler):
prev_state_ids = yield context.get_prev_state_ids(self.store)
- prev_state_id = prev_state_ids.get(
- (event.type, event.state_key)
- )
+ prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
prev_state = yield self.store.get_event(
- prev_state_id, allow_none=True,
+ prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
newly_joined = False
@@ -653,10 +649,7 @@ class FederationHandler(BaseHandler):
room_version = yield self.store.get_room_version(room_id)
events = yield self.federation_client.backfill(
- dest,
- room_id,
- limit=limit,
- extremities=extremities,
+ dest, room_id, limit=limit, extremities=extremities
)
# ideally we'd sanity check the events here for excess prev_events etc,
@@ -683,16 +676,9 @@ class FederationHandler(BaseHandler):
event_ids = set(e.event_id for e in events)
- edges = [
- ev.event_id
- for ev in events
- if set(ev.prev_event_ids()) - event_ids
- ]
+ edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
- logger.info(
- "backfill: Got %d events with %d edges",
- len(events), len(edges),
- )
+ logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
# For each edge get the current state.
@@ -701,9 +687,7 @@ class FederationHandler(BaseHandler):
events_to_state = {}
for e_id in edges:
state, auth = yield self.federation_client.get_state_for_room(
- destination=dest,
- room_id=room_id,
- event_id=e_id
+ destination=dest, room_id=room_id, event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
@@ -712,12 +696,14 @@ class FederationHandler(BaseHandler):
required_auth = set(
a_id
- for event in events + list(state_events.values()) + list(auth_events.values())
+ for event in events
+ + list(state_events.values())
+ + list(auth_events.values())
for a_id in event.auth_event_ids()
)
- auth_events.update({
- e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
- })
+ auth_events.update(
+ {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
+ )
missing_auth = required_auth - set(auth_events)
failed_to_fetch = set()
@@ -736,27 +722,30 @@ class FederationHandler(BaseHandler):
if missing_auth - failed_to_fetch:
logger.info(
"Fetching missing auth for backfill: %r",
- missing_auth - failed_to_fetch
+ missing_auth - failed_to_fetch,
)
- results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.run_in_background(
- self.federation_client.get_pdu,
- [dest],
- event_id,
- room_version=room_version,
- outlier=True,
- timeout=10000,
- )
- for event_id in missing_auth - failed_to_fetch
- ],
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ results = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ logcontext.run_in_background(
+ self.federation_client.get_pdu,
+ [dest],
+ event_id,
+ room_version=room_version,
+ outlier=True,
+ timeout=10000,
+ )
+ for event_id in missing_auth - failed_to_fetch
+ ],
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a})
required_auth.update(
a_id
- for event in results if event
+ for event in results
+ if event
for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)
@@ -788,15 +777,19 @@ class FederationHandler(BaseHandler):
continue
a.internal_metadata.outlier = True
- ev_infos.append({
- "event": a,
- "auth_events": {
- (auth_events[a_id].type, auth_events[a_id].state_key):
- auth_events[a_id]
- for a_id in a.auth_event_ids()
- if a_id in auth_events
+ ev_infos.append(
+ {
+ "event": a,
+ "auth_events": {
+ (
+ auth_events[a_id].type,
+ auth_events[a_id].state_key,
+ ): auth_events[a_id]
+ for a_id in a.auth_event_ids()
+ if a_id in auth_events
+ },
}
- })
+ )
# Step 1b: persist the events in the chunk we fetched state for (i.e.
# the backwards extremities) as non-outliers.
@@ -804,23 +797,24 @@ class FederationHandler(BaseHandler):
# For paranoia we ensure that these events are marked as
# non-outliers
ev = event_map[e_id]
- assert(not ev.internal_metadata.is_outlier())
-
- ev_infos.append({
- "event": ev,
- "state": events_to_state[e_id],
- "auth_events": {
- (auth_events[a_id].type, auth_events[a_id].state_key):
- auth_events[a_id]
- for a_id in ev.auth_event_ids()
- if a_id in auth_events
+ assert not ev.internal_metadata.is_outlier()
+
+ ev_infos.append(
+ {
+ "event": ev,
+ "state": events_to_state[e_id],
+ "auth_events": {
+ (
+ auth_events[a_id].type,
+ auth_events[a_id].state_key,
+ ): auth_events[a_id]
+ for a_id in ev.auth_event_ids()
+ if a_id in auth_events
+ },
}
- })
+ )
- yield self._handle_new_events(
- dest, ev_infos,
- backfilled=True,
- )
+ yield self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -831,14 +825,12 @@ class FederationHandler(BaseHandler):
# For paranoia we ensure that these events are marked as
# non-outliers
- assert(not event.internal_metadata.is_outlier())
+ assert not event.internal_metadata.is_outlier()
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
- yield self._handle_new_event(
- dest, event, backfilled=True,
- )
+ yield self._handle_new_event(dest, event, backfilled=True)
defer.returnValue(events)
@@ -847,9 +839,7 @@ class FederationHandler(BaseHandler):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
- extremities = yield self.store.get_oldest_events_with_depth_in_room(
- room_id
- )
+ extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
@@ -881,31 +871,27 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
- forward_events = yield self.store.get_successor_events(
- list(extremities),
- )
+ forward_events = yield self.store.get_successor_events(list(extremities))
extremities_events = yield self.store.get_events(
- forward_events,
- check_redacted=False,
- get_prev_content=False,
+ forward_events, check_redacted=False, get_prev_content=False
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = yield filter_events_for_server(
- self.store, self.server_name, list(extremities_events.values()),
- redact=False, check_history_visibility_only=True,
+ self.store,
+ self.server_name,
+ list(extremities_events.values()),
+ redact=False,
+ check_history_visibility_only=True,
)
if not filtered_extremities:
defer.returnValue(False)
# Check if we reached a point where we should start backfilling.
- sorted_extremeties_tuple = sorted(
- extremities.items(),
- key=lambda e: -int(e[1])
- )
+ sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
max_depth = sorted_extremeties_tuple[0][1]
# We don't want to specify too many extremities as it causes the backfill
@@ -914,8 +900,7 @@ class FederationHandler(BaseHandler):
if current_depth > max_depth:
logger.debug(
- "Not backfilling as we don't need to. %d < %d",
- max_depth, current_depth,
+ "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
)
return
@@ -940,8 +925,7 @@ class FederationHandler(BaseHandler):
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in iteritems(state)
- if e_type == EventTypes.Member
- and event.membership == Membership.JOIN
+ if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
joined_domains = {}
@@ -961,8 +945,7 @@ class FederationHandler(BaseHandler):
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
- domain for domain, depth in curr_domains
- if domain != self.server_name
+ domain for domain, depth in curr_domains if domain != self.server_name
]
@defer.inlineCallbacks
@@ -971,28 +954,20 @@ class FederationHandler(BaseHandler):
for dom in domains:
try:
yield self.backfill(
- dom, room_id,
- limit=100,
- extremities=extremities,
+ dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
defer.returnValue(True)
except SynapseError as e:
- logger.info(
- "Failed to backfill from %s because %s",
- dom, e,
- )
+ logger.info("Failed to backfill from %s because %s", dom, e)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
- logger.info(
- "Failed to backfill from %s because %s",
- dom, e,
- )
+ logger.info("Failed to backfill from %s because %s", dom, e)
continue
except NotRetryingDestination as e:
logger.info(str(e))
@@ -1001,10 +976,7 @@ class FederationHandler(BaseHandler):
logger.info(e)
continue
except Exception as e:
- logger.exception(
- "Failed to backfill from %s because %s",
- dom, e,
- )
+ logger.exception("Failed to backfill from %s because %s", dom, e)
continue
defer.returnValue(False)
@@ -1025,10 +997,11 @@ class FederationHandler(BaseHandler):
resolve = logcontext.preserve_fn(
self.state_handler.resolve_state_groups_for_events
)
- states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [resolve(room_id, [e]) for e in event_ids],
- consumeErrors=True,
- ))
+ states = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
+ )
+ )
# dict[str, dict[tuple, str]], a map from event_id to state map of
# event_ids.
@@ -1036,23 +1009,23 @@ class FederationHandler(BaseHandler):
state_map = yield self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
- get_prev_content=False
+ get_prev_content=False,
)
states = {
key: {
k: state_map[e_id]
for k, e_id in iteritems(state_dict)
if e_id in state_map
- } for key, state_dict in iteritems(states)
+ }
+ for key, state_dict in iteritems(states)
}
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
- success = yield try_backfill([
- dom for dom, _ in likely_domains
- if dom not in tried_domains
- ])
+ success = yield try_backfill(
+ [dom for dom, _ in likely_domains if dom not in tried_domains]
+ )
if success:
defer.returnValue(True)
@@ -1077,20 +1050,20 @@ class FederationHandler(BaseHandler):
SynapseError if the event does not pass muster
"""
if len(ev.prev_event_ids()) > 20:
- logger.warn("Rejecting event %s which has %i prev_events",
- ev.event_id, len(ev.prev_event_ids()))
- raise SynapseError(
- http_client.BAD_REQUEST,
- "Too many prev_events",
+ logger.warn(
+ "Rejecting event %s which has %i prev_events",
+ ev.event_id,
+ len(ev.prev_event_ids()),
)
+ raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events")
if len(ev.auth_event_ids()) > 10:
- logger.warn("Rejecting event %s which has %i auth_events",
- ev.event_id, len(ev.auth_event_ids()))
- raise SynapseError(
- http_client.BAD_REQUEST,
- "Too many auth_events",
+ logger.warn(
+ "Rejecting event %s which has %i auth_events",
+ ev.event_id,
+ len(ev.auth_event_ids()),
)
+ raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events")
@defer.inlineCallbacks
def send_invite(self, target_host, event):
@@ -1102,7 +1075,7 @@ class FederationHandler(BaseHandler):
destination=target_host,
room_id=event.room_id,
event_id=event.event_id,
- pdu=event
+ pdu=event,
)
defer.returnValue(pdu)
@@ -1111,8 +1084,7 @@ class FederationHandler(BaseHandler):
def on_event_auth(self, event_id):
event = yield self.store.get_event(event_id)
auth = yield self.store.get_auth_chain(
- [auth_id for auth_id in event.auth_event_ids()],
- include_given=True
+ [auth_id for auth_id in event.auth_event_ids()], include_given=True
)
defer.returnValue([e for e in auth])
@@ -1138,15 +1110,13 @@ class FederationHandler(BaseHandler):
joinee,
"join",
content,
- params={
- "ver": KNOWN_ROOM_VERSIONS,
- },
+ params={"ver": KNOWN_ROOM_VERSIONS},
)
# This shouldn't happen, because the RoomMemberHandler has a
# linearizer lock which only allows one operation per user per room
# at a time - so this is just paranoia.
- assert (room_id not in self.room_queues)
+ assert room_id not in self.room_queues
self.room_queues[room_id] = []
@@ -1163,7 +1133,7 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
ret = yield self.federation_client.send_join(
- target_hosts, event, event_format_version,
+ target_hosts, event, event_format_version
)
origin = ret["origin"]
@@ -1182,17 +1152,13 @@ class FederationHandler(BaseHandler):
try:
yield self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False
+ room_id=room_id, room_creator_user_id="", is_public=False
)
except Exception:
# FIXME
pass
- yield self._persist_auth_tree(
- origin, auth_chain, state, event
- )
+ yield self._persist_auth_tree(origin, auth_chain, state, event)
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
@@ -1219,14 +1185,18 @@ class FederationHandler(BaseHandler):
"""
for p, origin in room_queue:
try:
- logger.info("Processing queued PDU %s which was received "
- "while we were joining %s", p.event_id, p.room_id)
+ logger.info(
+ "Processing queued PDU %s which was received "
+ "while we were joining %s",
+ p.event_id,
+ p.room_id,
+ )
with logcontext.nested_logging_context(p.event_id):
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warn(
- "Error handling queued PDU %s from %s: %s",
- p.event_id, origin, e)
+ "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
+ )
@defer.inlineCallbacks
@log_function
@@ -1247,21 +1217,30 @@ class FederationHandler(BaseHandler):
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
- }
+ },
)
try:
event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ builder=builder
)
except AuthError as e:
logger.warn("Failed to create join %r because %s", event, e)
raise e
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info("Creation of join %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
yield self.auth.check_from_context(
- room_version, event, context, do_sig_check=False,
+ room_version, event, context, do_sig_check=False
)
defer.returnValue(event)
@@ -1296,9 +1275,16 @@ class FederationHandler(BaseHandler):
# would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin
- context = yield self._handle_new_event(
- origin, event
+ context = yield self._handle_new_event(origin, event)
+
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
)
+ if not event_allowed:
+ logger.info("Sending of join %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -1318,10 +1304,7 @@ class FederationHandler(BaseHandler):
state = yield self.store.get_events(list(prev_state_ids.values()))
- defer.returnValue({
- "state": list(state.values()),
- "auth_chain": auth_chain,
- })
+ defer.returnValue({"state": list(state.values()), "auth_chain": auth_chain})
@defer.inlineCallbacks
def on_invite_request(self, origin, pdu):
@@ -1342,7 +1325,7 @@ class FederationHandler(BaseHandler):
raise SynapseError(403, "This server does not accept room invites")
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id,
+ event.sender, event.state_key, event.room_id
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1354,26 +1337,23 @@ class FederationHandler(BaseHandler):
sender_domain = get_domain_from_id(event.sender)
if sender_domain != origin:
- raise SynapseError(400, "The invite event was not from the server sending it")
+ raise SynapseError(
+ 400, "The invite event was not from the server sending it"
+ )
if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server")
# block any attempts to invite the server notices mxid
if event.state_key == self._server_notices_mxid:
- raise SynapseError(
- http_client.FORBIDDEN,
- "Cannot invite this user",
- )
+ raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True
event.signatures.update(
compute_event_signature(
- event.get_pdu_json(),
- self.hs.hostname,
- self.hs.config.signing_key[0]
+ event.get_pdu_json(), self.hs.hostname, self.hs.config.signing_key[0]
)
)
@@ -1385,10 +1365,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts,
- room_id,
- user_id,
- "leave"
+ target_hosts, room_id, user_id, "leave"
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -1403,10 +1380,7 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
- yield self.federation_client.send_leave(
- target_hosts,
- event
- )
+ yield self.federation_client.send_leave(target_hosts, event)
context = yield self.state_handler.compute_event_context(event)
yield self.persist_events_and_notify([(event, context)])
@@ -1414,25 +1388,21 @@ class FederationHandler(BaseHandler):
defer.returnValue(event)
@defer.inlineCallbacks
- def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
- content={}, params=None):
+ def _make_and_verify_event(
+ self, target_hosts, room_id, user_id, membership, content={}, params=None
+ ):
origin, event, format_ver = yield self.federation_client.make_membership_event(
- target_hosts,
- room_id,
- user_id,
- membership,
- content,
- params=params,
+ target_hosts, room_id, user_id, membership, content, params=params
)
logger.debug("Got response to make_%s: %s", membership, event)
# We should assert some things.
# FIXME: Do this in a nicer way
- assert(event.type == EventTypes.Member)
- assert(event.user_id == user_id)
- assert(event.state_key == user_id)
- assert(event.room_id == room_id)
+ assert event.type == EventTypes.Member
+ assert event.user_id == user_id
+ assert event.state_key == user_id
+ assert event.room_id == room_id
defer.returnValue((origin, event, format_ver))
@defer.inlineCallbacks
@@ -1451,18 +1421,27 @@ class FederationHandler(BaseHandler):
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
- }
+ },
)
event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ builder=builder
)
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.warning("Creation of leave %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
yield self.auth.check_from_context(
- room_version, event, context, do_sig_check=False,
+ room_version, event, context, do_sig_check=False
)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
@@ -1484,9 +1463,16 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
- yield self._handle_new_event(
- origin, event
+ context = yield self._handle_new_event(origin, event)
+
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
)
+ if not event_allowed:
+ logger.info("Sending of leave %s forbidden by third-party rules", event)
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s",
@@ -1502,18 +1488,14 @@ class FederationHandler(BaseHandler):
"""
event = yield self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id,
+ event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.store.get_state_groups(
- room_id, [event_id]
- )
+ state_groups = yield self.store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
- results = {
- (e.type, e.state_key): e for e in state
- }
+ results = {(e.type, e.state_key): e for e in state}
if event.is_state():
# Get previous state
@@ -1535,12 +1517,10 @@ class FederationHandler(BaseHandler):
"""Returns the state at the event. i.e. not including said event.
"""
event = yield self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id,
+ event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.store.get_state_groups_ids(
- room_id, [event_id]
- )
+ state_groups = yield self.store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@@ -1566,11 +1546,7 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield self.store.get_backfill_events(
- room_id,
- pdu_list,
- limit
- )
+ events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
events = yield filter_events_for_server(self.store, origin, events)
@@ -1594,22 +1570,15 @@ class FederationHandler(BaseHandler):
AuthError if the server is not currently in the room
"""
event = yield self.store.get_event(
- event_id,
- allow_none=True,
- allow_rejected=True,
+ event_id, allow_none=True, allow_rejected=True
)
if event:
- in_room = yield self.auth.check_host_in_room(
- event.room_id,
- origin
- )
+ in_room = yield self.auth.check_host_in_room(event.room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield filter_events_for_server(
- self.store, origin, [event],
- )
+ events = yield filter_events_for_server(self.store, origin, [event])
event = events[0]
defer.returnValue(event)
else:
@@ -1619,13 +1588,11 @@ class FederationHandler(BaseHandler):
return self.store.get_min_depth(context)
@defer.inlineCallbacks
- def _handle_new_event(self, origin, event, state=None, auth_events=None,
- backfilled=False):
+ def _handle_new_event(
+ self, origin, event, state=None, auth_events=None, backfilled=False
+ ):
context = yield self._prep_event(
- origin, event,
- state=state,
- auth_events=auth_events,
- backfilled=backfilled,
+ origin, event, state=state, auth_events=auth_events, backfilled=backfilled
)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
@@ -1638,15 +1605,13 @@ class FederationHandler(BaseHandler):
)
yield self.persist_events_and_notify(
- [(event, context)],
- backfilled=backfilled,
+ [(event, context)], backfilled=backfilled
)
success = True
finally:
if not success:
logcontext.run_in_background(
- self.store.remove_push_actions_from_staging,
- event.event_id,
+ self.store.remove_push_actions_from_staging, event.event_id
)
defer.returnValue(context)
@@ -1674,12 +1639,15 @@ class FederationHandler(BaseHandler):
)
defer.returnValue(res)
- contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.run_in_background(prep, ev_info)
- for ev_info in event_infos
- ], consumeErrors=True,
- ))
+ contexts = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ logcontext.run_in_background(prep, ev_info)
+ for ev_info in event_infos
+ ],
+ consumeErrors=True,
+ )
+ )
yield self.persist_events_and_notify(
[
@@ -1714,8 +1682,7 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id] = ctx
event_map = {
- e.event_id: e
- for e in itertools.chain(auth_events, state, [event])
+ e.event_id: e for e in itertools.chain(auth_events, state, [event])
}
create_event = None
@@ -1730,7 +1697,7 @@ class FederationHandler(BaseHandler):
raise SynapseError(400, "No create event in state")
room_version = create_event.content.get(
- "room_version", RoomVersions.V1.identifier,
+ "room_version", RoomVersions.V1.identifier
)
missing_auth_events = set()
@@ -1741,11 +1708,7 @@ class FederationHandler(BaseHandler):
for e_id in missing_auth_events:
m_ev = yield self.federation_client.get_pdu(
- [origin],
- e_id,
- room_version=room_version,
- outlier=True,
- timeout=10000,
+ [origin], e_id, room_version=room_version, outlier=True, timeout=10000
)
if m_ev and m_ev.event_id == e_id:
event_map[e_id] = m_ev
@@ -1770,10 +1733,7 @@ class FederationHandler(BaseHandler):
# cause SynapseErrors in auth.check. We don't want to give up
# the attempt to federate altogether in such cases.
- logger.warn(
- "Rejecting %s because %s",
- e.event_id, err.msg
- )
+ logger.warn("Rejecting %s because %s", e.event_id, err.msg)
if e == event:
raise
@@ -1783,16 +1743,14 @@ class FederationHandler(BaseHandler):
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
- ],
+ ]
)
new_event_context = yield self.state_handler.compute_event_context(
event, old_state=state
)
- yield self.persist_events_and_notify(
- [(event, new_event_context)],
- )
+ yield self.persist_events_and_notify([(event, new_event_context)])
@defer.inlineCallbacks
def _prep_event(self, origin, event, state, auth_events, backfilled):
@@ -1808,40 +1766,30 @@ class FederationHandler(BaseHandler):
Returns:
Deferred, which resolves to synapse.events.snapshot.EventContext
"""
- context = yield self.state_handler.compute_event_context(
- event, old_state=state,
- )
+ context = yield self.state_handler.compute_event_context(event, old_state=state)
if not auth_events:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {
- (e.type, e.state_key): e for e in auth_events.values()
- }
+ auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = yield self.store.get_event(
- event.prev_event_ids()[0],
- allow_none=True,
+ event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
try:
- yield self.do_auth(
- origin, event, context, auth_events=auth_events
- )
+ yield self.do_auth(origin, event, context, auth_events=auth_events)
except AuthError as e:
- logger.warn(
- "[%s %s] Rejecting: %s",
- event.room_id, event.event_id, e.msg
- )
+ logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg)
context.rejected = RejectedReason.AUTH_ERROR
@@ -1872,9 +1820,7 @@ class FederationHandler(BaseHandler):
# "soft-fail" the event.
do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
if do_soft_fail_check:
- extrem_ids = yield self.store.get_latest_event_ids_in_room(
- event.room_id,
- )
+ extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids)
prev_event_ids = set(event.prev_event_ids())
@@ -1902,31 +1848,31 @@ class FederationHandler(BaseHandler):
# like bans, especially with state res v2.
state_sets = yield self.store.get_state_groups(
- event.room_id, extrem_ids,
+ event.room_id, extrem_ids
)
state_sets = list(state_sets.values())
state_sets.append(state)
current_state_ids = yield self.state_handler.resolve_events(
- room_version, state_sets, event,
+ room_version, state_sets, event
)
current_state_ids = {
k: e.event_id for k, e in iteritems(current_state_ids)
}
else:
current_state_ids = yield self.state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids,
+ event.room_id, latest_event_ids=extrem_ids
)
logger.debug(
"Doing soft-fail check for %s: state %s",
- event.event_id, current_state_ids,
+ event.event_id,
+ current_state_ids,
)
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [
- e for k, e in iteritems(current_state_ids)
- if k in auth_types
+ e for k, e in iteritems(current_state_ids) if k in auth_types
]
current_auth_events = yield self.store.get_events(current_state_ids)
@@ -1937,19 +1883,14 @@ class FederationHandler(BaseHandler):
try:
self.auth.check(room_version, event, auth_events=current_auth_events)
except AuthError as e:
- logger.warn(
- "Soft-failing %r because %s",
- event, e,
- )
+ logger.warn("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
@defer.inlineCallbacks
- def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects,
- missing):
- in_room = yield self.auth.check_host_in_room(
- room_id,
- origin
- )
+ def on_query_auth(
+ self, origin, event_id, room_id, remote_auth_chain, rejects, missing
+ ):
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -1967,28 +1908,23 @@ class FederationHandler(BaseHandler):
# Now get the current auth_chain for the event.
local_auth_chain = yield self.store.get_auth_chain(
- [auth_id for auth_id in event.auth_event_ids()],
- include_given=True
+ [auth_id for auth_id in event.auth_event_ids()], include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
- ret = yield self.construct_auth_difference(
- local_auth_chain, remote_auth_chain
- )
+ ret = yield self.construct_auth_difference(local_auth_chain, remote_auth_chain)
logger.debug("on_query_auth returning: %s", ret)
defer.returnValue(ret)
@defer.inlineCallbacks
- def on_get_missing_events(self, origin, room_id, earliest_events,
- latest_events, limit):
- in_room = yield self.auth.check_host_in_room(
- room_id,
- origin
- )
+ def on_get_missing_events(
+ self, origin, room_id, earliest_events, latest_events, limit
+ ):
+ in_room = yield self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2002,7 +1938,7 @@ class FederationHandler(BaseHandler):
)
missing_events = yield filter_events_for_server(
- self.store, origin, missing_events,
+ self.store, origin, missing_events
)
defer.returnValue(missing_events)
@@ -2090,25 +2026,17 @@ class FederationHandler(BaseHandler):
if missing_auth:
# TODO: can we use store.have_seen_events here instead?
- have_events = yield self.store.get_seen_events_with_rejections(
- missing_auth
- )
+ have_events = yield self.store.get_seen_events_with_rejections(missing_auth)
logger.debug("Got events %s from store", have_events)
missing_auth.difference_update(have_events.keys())
else:
have_events = {}
- have_events.update({
- e.event_id: ""
- for e in auth_events.values()
- })
+ have_events.update({e.event_id: "" for e in auth_events.values()})
if missing_auth:
# If we don't have all the auth events, we need to get them.
- logger.info(
- "auth_events contains unknown events: %s",
- missing_auth,
- )
+ logger.info("auth_events contains unknown events: %s", missing_auth)
try:
try:
remote_auth_chain = yield self.federation_client.get_event_auth(
@@ -2134,18 +2062,16 @@ class FederationHandler(BaseHandler):
try:
auth_ids = e.auth_event_ids()
auth = {
- (e.type, e.state_key): e for e in remote_auth_chain
+ (e.type, e.state_key): e
+ for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
e.internal_metadata.outlier = True
logger.debug(
- "do_auth %s missing_auth: %s",
- event.event_id, e.event_id
- )
- yield self._handle_new_event(
- origin, e, auth_events=auth
+ "do_auth %s missing_auth: %s", event.event_id, e.event_id
)
+ yield self._handle_new_event(origin, e, auth_events=auth)
if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
@@ -2181,35 +2107,36 @@ class FederationHandler(BaseHandler):
room_version = yield self.store.get_room_version(event.room_id)
different_events = yield logcontext.make_deferred_yieldable(
- defer.gatherResults([
- logcontext.run_in_background(
- self.store.get_event,
- d,
- allow_none=True,
- allow_rejected=False,
- )
- for d in different_auth
- if d in have_events and not have_events[d]
- ], consumeErrors=True)
+ defer.gatherResults(
+ [
+ logcontext.run_in_background(
+ self.store.get_event, d, allow_none=True, allow_rejected=False
+ )
+ for d in different_auth
+ if d in have_events and not have_events[d]
+ ],
+ consumeErrors=True,
+ )
).addErrback(unwrapFirstError)
if different_events:
local_view = dict(auth_events)
remote_view = dict(auth_events)
- remote_view.update({
- (d.type, d.state_key): d for d in different_events if d
- })
+ remote_view.update(
+ {(d.type, d.state_key): d for d in different_events if d}
+ )
new_state = yield self.state_handler.resolve_events(
room_version,
[list(local_view.values()), list(remote_view.values())],
- event
+ event,
)
logger.info(
"After state res: updating auth_events with new state %s",
{
- (d.type, d.state_key): d.event_id for d in new_state.values()
+ (d.type, d.state_key): d.event_id
+ for d in new_state.values()
if auth_events.get((d.type, d.state_key)) != d
},
)
@@ -2221,7 +2148,7 @@ class FederationHandler(BaseHandler):
)
yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
+ event, context, auth_events, event_key
)
if not different_auth:
@@ -2255,21 +2182,14 @@ class FederationHandler(BaseHandler):
prev_state_ids = yield context.get_prev_state_ids(self.store)
# 1. Get what we think is the auth chain.
- auth_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids
- )
- local_auth_chain = yield self.store.get_auth_chain(
- auth_ids, include_given=True
- )
+ auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids)
+ local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True)
try:
# 2. Get remote difference.
try:
result = yield self.federation_client.query_auth(
- origin,
- event.room_id,
- event.event_id,
- local_auth_chain,
+ origin, event.room_id, event.event_id, local_auth_chain
)
except RequestSendFailed as e:
# The other side isn't around or doesn't implement the
@@ -2294,19 +2214,15 @@ class FederationHandler(BaseHandler):
auth = {
(e.type, e.state_key): e
for e in result["auth_chain"]
- if e.event_id in auth_ids
- or event.type == EventTypes.Create
+ if e.event_id in auth_ids or event.type == EventTypes.Create
}
ev.internal_metadata.outlier = True
logger.debug(
- "do_auth %s different_auth: %s",
- event.event_id, e.event_id
+ "do_auth %s different_auth: %s", event.event_id, e.event_id
)
- yield self._handle_new_event(
- origin, ev, auth_events=auth
- )
+ yield self._handle_new_event(origin, ev, auth_events=auth)
if ev.event_id in event_auth_events:
auth_events[(ev.type, ev.state_key)] = ev
@@ -2321,12 +2237,11 @@ class FederationHandler(BaseHandler):
# TODO.
yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
+ event, context, auth_events, event_key
)
@defer.inlineCallbacks
- def _update_context_for_auth_events(self, event, context, auth_events,
- event_key):
+ def _update_context_for_auth_events(self, event, context, auth_events, event_key):
"""Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
@@ -2343,8 +2258,7 @@ class FederationHandler(BaseHandler):
this will not be included in the current_state in the context.
"""
state_updates = {
- k: a.event_id for k, a in iteritems(auth_events)
- if k != event_key
+ k: a.event_id for k, a in iteritems(auth_events) if k != event_key
}
current_state_ids = yield context.get_current_state_ids(self.store)
current_state_ids = dict(current_state_ids)
@@ -2354,9 +2268,7 @@ class FederationHandler(BaseHandler):
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = dict(prev_state_ids)
- prev_state_ids.update({
- k: a.event_id for k, a in iteritems(auth_events)
- })
+ prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
# create a new state group as a delta from the existing one.
prev_group = context.state_group
@@ -2505,30 +2417,23 @@ class FederationHandler(BaseHandler):
logger.debug("construct_auth_difference returning")
- defer.returnValue({
- "auth_chain": local_auth,
- "rejects": {
- e.event_id: {
- "reason": reason_map[e.event_id],
- "proof": None,
- }
- for e in base_remote_rejected
- },
- "missing": [e.event_id for e in missing_locals],
- })
+ defer.returnValue(
+ {
+ "auth_chain": local_auth,
+ "rejects": {
+ e.event_id: {"reason": reason_map[e.event_id], "proof": None}
+ for e in base_remote_rejected
+ },
+ "missing": [e.event_id for e in missing_locals],
+ }
+ )
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(
- self,
- sender_user_id,
- target_user_id,
- room_id,
- signed,
+ self, sender_user_id, target_user_id, room_id, signed
):
- third_party_invite = {
- "signed": signed,
- }
+ third_party_invite = {"signed": signed}
event_dict = {
"type": EventTypes.Member,
@@ -2550,6 +2455,18 @@ class FederationHandler(BaseHandler):
builder=builder
)
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.info(
+ "Creation of threepid invite %s forbidden by third-party rules",
+ event,
+ )
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
event, context = yield self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
@@ -2572,9 +2489,7 @@ class FederationHandler(BaseHandler):
else:
destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
yield self.federation_client.forward_third_party_invite(
- destinations,
- room_id,
- event_dict,
+ destinations, room_id, event_dict
)
@defer.inlineCallbacks
@@ -2595,9 +2510,20 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(room_version, event_dict)
event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ builder=builder
)
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ logger.warning(
+ "Exchange of threepid invite %s forbidden by third-party rules", event
+ )
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
event, context = yield self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
@@ -2613,21 +2539,16 @@ class FederationHandler(BaseHandler):
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
- # XXX we send the invite here, but send_membership_event also sends it,
- # so we end up making two requests. I think this is redundant.
- returned_invite = yield self.send_invite(origin, event)
- # TODO: Make sure the signatures actually are correct.
- event.signatures.update(returned_invite.signatures)
-
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks
- def add_display_name_to_third_party_invite(self, room_version, event_dict,
- event, context):
+ def add_display_name_to_third_party_invite(
+ self, room_version, event_dict, event, context
+ ):
key = (
EventTypes.ThirdPartyInvite,
- event.content["third_party_invite"]["signed"]["token"]
+ event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
prev_state_ids = yield context.get_prev_state_ids(self.store)
@@ -2641,8 +2562,7 @@ class FederationHandler(BaseHandler):
event_dict["content"]["third_party_invite"]["display_name"] = display_name
else:
logger.info(
- "Could not find invite event for third_party_invite: %r",
- event_dict
+ "Could not find invite event for third_party_invite: %r", event_dict
)
# We don't discard here as this is not the appropriate place to do
# auth checks. If we need the invite and don't have it then the
@@ -2651,7 +2571,7 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ builder=builder
)
EventValidator().validate_new(event)
defer.returnValue((event, context))
@@ -2675,9 +2595,7 @@ class FederationHandler(BaseHandler):
token = signed["token"]
prev_state_ids = yield context.get_prev_state_ids(self.store)
- invite_event_id = prev_state_ids.get(
- (EventTypes.ThirdPartyInvite, token,)
- )
+ invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None
if invite_event_id:
@@ -2686,25 +2604,59 @@ class FederationHandler(BaseHandler):
if not invite_event:
raise AuthError(403, "Could not find invite")
+ logger.debug("Checking auth on event %r", event.content)
+
last_exception = None
+ # for each public key in the 3pid invite event
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
try:
+ # for each sig on the third_party_invite block of the actual invite
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
- public_key = public_key_object["public_key"]
- verify_key = decode_verify_key_bytes(
+ logger.debug(
+ "Attempting to verify sig with key %s from %r "
+ "against pubkey %r",
key_name,
- decode_base64(public_key)
+ server,
+ public_key_object,
)
- verify_signed_json(signed, server, verify_key)
- if "key_validity_url" in public_key_object:
- yield self._check_key_revocation(
- public_key,
- public_key_object["key_validity_url"]
+
+ try:
+ public_key = public_key_object["public_key"]
+ verify_key = decode_verify_key_bytes(
+ key_name, decode_base64(public_key)
+ )
+ verify_signed_json(signed, server, verify_key)
+ logger.debug(
+ "Successfully verified sig with key %s from %r "
+ "against pubkey %r",
+ key_name,
+ server,
+ public_key_object,
+ )
+ except Exception:
+ logger.info(
+ "Failed to verify sig with key %s from %r "
+ "against pubkey %r",
+ key_name,
+ server,
+ public_key_object,
+ )
+ raise
+ try:
+ if "key_validity_url" in public_key_object:
+ yield self._check_key_revocation(
+ public_key, public_key_object["key_validity_url"]
+ )
+ except Exception:
+ logger.info(
+ "Failed to query key_validity_url %s",
+ public_key_object["key_validity_url"],
)
+ raise
return
except Exception as e:
last_exception = e
@@ -2725,15 +2677,9 @@ class FederationHandler(BaseHandler):
for revocation.
"""
try:
- response = yield self.http_client.get_json(
- url,
- {"public_key": public_key}
- )
+ response = yield self.http_client.get_json(url, {"public_key": public_key})
except Exception:
- raise SynapseError(
- 502,
- "Third party certificate could not be checked"
- )
+ raise SynapseError(502, "Third party certificate could not be checked")
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")
@@ -2754,12 +2700,11 @@ class FederationHandler(BaseHandler):
yield self._send_events_to_master(
store=self.store,
event_and_contexts=event_and_contexts,
- backfilled=backfilled
+ backfilled=backfilled,
)
else:
max_stream_id = yield self.store.persist_events(
- event_and_contexts,
- backfilled=backfilled,
+ event_and_contexts, backfilled=backfilled
)
if not backfilled: # Never notify for backfilled events
@@ -2793,13 +2738,10 @@ class FederationHandler(BaseHandler):
event_stream_id = event.internal_metadata.stream_ordering
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
+ event, event_stream_id, max_stream_id, extra_users=extra_users
)
- return self.pusher_pool.on_new_notifications(
- event_stream_id, max_stream_id,
- )
+ return self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _clean_room_for_join(self, room_id):
"""Called to clean up any data in DB for a given room, ready for the
@@ -2818,9 +2760,7 @@ class FederationHandler(BaseHandler):
"""
if self.config.worker_app:
return self._notify_user_membership_change(
- room_id=room_id,
- user_id=user.to_string(),
- change="joined",
+ room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
return user_joined_room(self.distributor, user, room_id)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 02c508acec..7da63bb643 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -30,6 +30,7 @@ def _create_rerouter(func_name):
"""Returns a function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""
+
def f(self, group_id, *args, **kwargs):
if self.is_mine_id(group_id):
return getattr(self.groups_server_handler, func_name)(
@@ -49,9 +50,7 @@ def _create_rerouter(func_name):
def http_response_errback(failure):
failure.trap(HttpResponseException)
e = failure.value
- if e.code == 403:
- raise e.to_synapse_error()
- return failure
+ raise e.to_synapse_error()
def request_failed_errback(failure):
failure.trap(RequestSendFailed)
@@ -60,6 +59,7 @@ def _create_rerouter(func_name):
d.addErrback(http_response_errback)
d.addErrback(request_failed_errback)
return d
+
return f
@@ -127,7 +127,7 @@ class GroupsLocalHandler(object):
)
else:
res = yield self.transport_client.get_group_summary(
- get_domain_from_id(group_id), group_id, requester_user_id,
+ get_domain_from_id(group_id), group_id, requester_user_id
)
group_server_name = get_domain_from_id(group_id)
@@ -184,7 +184,7 @@ class GroupsLocalHandler(object):
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
res = yield self.transport_client.create_group(
- get_domain_from_id(group_id), group_id, user_id, content,
+ get_domain_from_id(group_id), group_id, user_id, content
)
remote_attestation = res["attestation"]
@@ -197,16 +197,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="join",
is_admin=True,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
defer.returnValue(res)
@@ -223,7 +222,7 @@ class GroupsLocalHandler(object):
group_server_name = get_domain_from_id(group_id)
res = yield self.transport_client.get_users_in_group(
- get_domain_from_id(group_id), group_id, requester_user_id,
+ get_domain_from_id(group_id), group_id, requester_user_id
)
chunk = res["chunk"]
@@ -252,9 +251,7 @@ class GroupsLocalHandler(object):
"""Request to join a group
"""
if self.is_mine_id(group_id):
- yield self.groups_server_handler.join_group(
- group_id, user_id, content
- )
+ yield self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@@ -262,7 +259,7 @@ class GroupsLocalHandler(object):
content["attestation"] = local_attestation
res = yield self.transport_client.join_group(
- get_domain_from_id(group_id), group_id, user_id, content,
+ get_domain_from_id(group_id), group_id, user_id, content
)
remote_attestation = res["attestation"]
@@ -278,16 +275,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
defer.returnValue({})
@@ -296,9 +292,7 @@ class GroupsLocalHandler(object):
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
- yield self.groups_server_handler.accept_invite(
- group_id, user_id, content
- )
+ yield self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
@@ -306,7 +300,7 @@ class GroupsLocalHandler(object):
content["attestation"] = local_attestation
res = yield self.transport_client.accept_group_invite(
- get_domain_from_id(group_id), group_id, user_id, content,
+ get_domain_from_id(group_id), group_id, user_id, content
)
remote_attestation = res["attestation"]
@@ -322,16 +316,15 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
defer.returnValue({})
@@ -339,17 +332,17 @@ class GroupsLocalHandler(object):
def invite(self, group_id, user_id, requester_user_id, config):
"""Invite a user to a group
"""
- content = {
- "requester_user_id": requester_user_id,
- "config": config,
- }
+ content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.invite_to_group(
- group_id, user_id, requester_user_id, content,
+ group_id, user_id, requester_user_id, content
)
else:
res = yield self.transport_client.invite_to_group(
- get_domain_from_id(group_id), group_id, user_id, requester_user_id,
+ get_domain_from_id(group_id),
+ group_id,
+ user_id,
+ requester_user_id,
content,
)
@@ -372,13 +365,12 @@ class GroupsLocalHandler(object):
local_profile["avatar_url"] = content["profile"]["avatar_url"]
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="invite",
content={"profile": local_profile, "inviter": content["inviter"]},
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
try:
user_profile = yield self.profile_handler.get_profile(user_id)
except Exception as e:
@@ -393,25 +385,25 @@ class GroupsLocalHandler(object):
"""
if user_id == requester_user_id:
token = yield self.store.register_user_group_membership(
- group_id, user_id,
- membership="leave",
- )
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
+ group_id, user_id, membership="leave"
)
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
# TODO: Should probably remember that we tried to leave so that we can
# retry if the group server is currently down.
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content,
+ group_id, user_id, requester_user_id, content
)
else:
content["requester_user_id"] = requester_user_id
res = yield self.transport_client.remove_user_from_group(
- get_domain_from_id(group_id), group_id, requester_user_id,
- user_id, content,
+ get_domain_from_id(group_id),
+ group_id,
+ requester_user_id,
+ user_id,
+ content,
)
defer.returnValue(res)
@@ -422,12 +414,9 @@ class GroupsLocalHandler(object):
"""
# TODO: Check if user in group
token = yield self.store.register_user_group_membership(
- group_id, user_id,
- membership="leave",
- )
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
+ group_id, user_id, membership="leave"
)
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
@defer.inlineCallbacks
def get_joined_groups(self, user_id):
@@ -447,7 +436,7 @@ class GroupsLocalHandler(object):
defer.returnValue({"groups": result})
else:
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
- get_domain_from_id(user_id), [user_id],
+ get_domain_from_id(user_id), [user_id]
)
result = bulk_result.get("users", {}).get(user_id)
# TODO: Verify attestations
@@ -462,9 +451,7 @@ class GroupsLocalHandler(object):
if self.hs.is_mine_id(user_id):
local_users.add(user_id)
else:
- destinations.setdefault(
- get_domain_from_id(user_id), set()
- ).add(user_id)
+ destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id)
if not proxy and destinations:
raise SynapseError(400, "Some user_ids are not local")
@@ -474,16 +461,14 @@ class GroupsLocalHandler(object):
for destination, dest_user_ids in iteritems(destinations):
try:
r = yield self.transport_client.bulk_get_publicised_groups(
- destination, list(dest_user_ids),
+ destination, list(dest_user_ids)
)
results.update(r["users"])
except Exception:
failed_results.extend(dest_user_ids)
for uid in local_users:
- results[uid] = yield self.store.get_publicised_groups_for_user(
- uid
- )
+ results[uid] = yield self.store.get_publicised_groups_for_user(uid)
# Check AS associated groups for this user - this depends on the
# RegExps in the AS registration file (under `users`)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 04caf65793..c82b1933f2 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -36,7 +36,6 @@ logger = logging.getLogger(__name__)
class IdentityHandler(BaseHandler):
-
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
@@ -64,40 +63,38 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
- if 'id_server' in creds:
- id_server = creds['id_server']
- elif 'idServer' in creds:
- id_server = creds['idServer']
+ if "id_server" in creds:
+ id_server = creds["id_server"]
+ elif "idServer" in creds:
+ id_server = creds["idServer"]
else:
raise SynapseError(400, "No id_server in creds")
- if 'client_secret' in creds:
- client_secret = creds['client_secret']
- elif 'clientSecret' in creds:
- client_secret = creds['clientSecret']
+ if "client_secret" in creds:
+ client_secret = creds["client_secret"]
+ elif "clientSecret" in creds:
+ client_secret = creds["clientSecret"]
else:
raise SynapseError(400, "No client_secret in creds")
if not self._should_trust_id_server(id_server):
logger.warn(
- '%s is not a trusted ID server: rejecting 3pid ' +
- 'credentials', id_server
+ "%s is not a trusted ID server: rejecting 3pid " + "credentials",
+ id_server,
)
defer.returnValue(None)
try:
data = yield self.http_client.get_json(
- "https://%s%s" % (
- id_server,
- "/_matrix/identity/api/v1/3pid/getValidated3pid"
- ),
- {'sid': creds['sid'], 'client_secret': client_secret}
+ "https://%s%s"
+ % (id_server, "/_matrix/identity/api/v1/3pid/getValidated3pid"),
+ {"sid": creds["sid"], "client_secret": client_secret},
)
except HttpResponseException as e:
logger.info("getValidated3pid failed with Matrix error: %r", e)
raise e.to_synapse_error()
- if 'medium' in data:
+ if "medium" in data:
defer.returnValue(data)
defer.returnValue(None)
@@ -106,30 +103,24 @@ class IdentityHandler(BaseHandler):
logger.debug("binding threepid %r to %s", creds, mxid)
data = None
- if 'id_server' in creds:
- id_server = creds['id_server']
- elif 'idServer' in creds:
- id_server = creds['idServer']
+ if "id_server" in creds:
+ id_server = creds["id_server"]
+ elif "idServer" in creds:
+ id_server = creds["idServer"]
else:
raise SynapseError(400, "No id_server in creds")
- if 'client_secret' in creds:
- client_secret = creds['client_secret']
- elif 'clientSecret' in creds:
- client_secret = creds['clientSecret']
+ if "client_secret" in creds:
+ client_secret = creds["client_secret"]
+ elif "clientSecret" in creds:
+ client_secret = creds["clientSecret"]
else:
raise SynapseError(400, "No client_secret in creds")
try:
data = yield self.http_client.post_urlencoded_get_json(
- "https://%s%s" % (
- id_server, "/_matrix/identity/api/v1/3pid/bind"
- ),
- {
- 'sid': creds['sid'],
- 'client_secret': client_secret,
- 'mxid': mxid,
- }
+ "https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"),
+ {"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid},
)
logger.debug("bound threepid %r to %s", creds, mxid)
@@ -165,9 +156,7 @@ class IdentityHandler(BaseHandler):
id_servers = [threepid["id_server"]]
else:
id_servers = yield self.store.get_id_servers_user_bound(
- user_id=mxid,
- medium=threepid["medium"],
- address=threepid["address"],
+ user_id=mxid, medium=threepid["medium"], address=threepid["address"]
)
# We don't know where to unbind, so we don't have a choice but to return
@@ -177,7 +166,7 @@ class IdentityHandler(BaseHandler):
changed = True
for id_server in id_servers:
changed &= yield self.try_unbind_threepid_with_id_server(
- mxid, threepid, id_server,
+ mxid, threepid, id_server
)
defer.returnValue(changed)
@@ -201,10 +190,7 @@ class IdentityHandler(BaseHandler):
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
content = {
"mxid": mxid,
- "threepid": {
- "medium": threepid["medium"],
- "address": threepid["address"],
- },
+ "threepid": {"medium": threepid["medium"], "address": threepid["address"]},
}
# we abuse the federation http client to sign the request, but we have to send it
@@ -212,25 +198,19 @@ class IdentityHandler(BaseHandler):
# 'browser-like' HTTPS.
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
- method='POST',
- url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
+ method="POST",
+ url_bytes="/_matrix/identity/api/v1/3pid/unbind".encode("ascii"),
content=content,
destination_is=id_server,
)
- headers = {
- b"Authorization": auth_headers,
- }
+ headers = {b"Authorization": auth_headers}
try:
- yield self.http_client.post_json_get_json(
- url,
- content,
- headers,
- )
+ yield self.http_client.post_json_get_json(url, content, headers)
changed = True
except HttpResponseException as e:
changed = False
- if e.code in (400, 404, 501,):
+ if e.code in (400, 404, 501):
# The remote server probably doesn't support unbinding (yet)
logger.warn("Received %d response while unbinding threepid", e.code)
else:
@@ -248,35 +228,27 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def requestEmailToken(
- self,
- id_server,
- email,
- client_secret,
- send_attempt,
- next_link=None,
+ self, id_server, email, client_secret, send_attempt, next_link=None
):
if not self._should_trust_id_server(id_server):
raise SynapseError(
- 400, "Untrusted ID server '%s'" % id_server,
- Codes.SERVER_NOT_TRUSTED
+ 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
)
params = {
- 'email': email,
- 'client_secret': client_secret,
- 'send_attempt': send_attempt,
+ "email": email,
+ "client_secret": client_secret,
+ "send_attempt": send_attempt,
}
if next_link:
- params.update({'next_link': next_link})
+ params.update({"next_link": next_link})
try:
data = yield self.http_client.post_json_get_json(
- "https://%s%s" % (
- id_server,
- "/_matrix/identity/api/v1/validate/email/requestToken"
- ),
- params
+ "https://%s%s"
+ % (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"),
+ params,
)
defer.returnValue(data)
except HttpResponseException as e:
@@ -285,30 +257,26 @@ class IdentityHandler(BaseHandler):
@defer.inlineCallbacks
def requestMsisdnToken(
- self, id_server, country, phone_number,
- client_secret, send_attempt, **kwargs
+ self, id_server, country, phone_number, client_secret, send_attempt, **kwargs
):
if not self._should_trust_id_server(id_server):
raise SynapseError(
- 400, "Untrusted ID server '%s'" % id_server,
- Codes.SERVER_NOT_TRUSTED
+ 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
)
params = {
- 'country': country,
- 'phone_number': phone_number,
- 'client_secret': client_secret,
- 'send_attempt': send_attempt,
+ "country": country,
+ "phone_number": phone_number,
+ "client_secret": client_secret,
+ "send_attempt": send_attempt,
}
params.update(kwargs)
try:
data = yield self.http_client.post_json_get_json(
- "https://%s%s" % (
- id_server,
- "/_matrix/identity/api/v1/validate/msisdn/requestToken"
- ),
- params
+ "https://%s%s"
+ % (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"),
+ params,
)
defer.returnValue(data)
except HttpResponseException as e:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index aaee5db0b7..a1fe9d116f 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -44,8 +44,13 @@ class InitialSyncHandler(BaseHandler):
self.snapshot_cache = SnapshotCache()
self._event_serializer = hs.get_event_client_serializer()
- def snapshot_all_rooms(self, user_id=None, pagin_config=None,
- as_client_event=True, include_archived=False):
+ def snapshot_all_rooms(
+ self,
+ user_id=None,
+ pagin_config=None,
+ as_client_event=True,
+ include_archived=False,
+ ):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
@@ -77,13 +82,22 @@ class InitialSyncHandler(BaseHandler):
if result is not None:
return result
- return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
- user_id, pagin_config, as_client_event, include_archived
- ))
+ return self.snapshot_cache.set(
+ now_ms,
+ key,
+ self._snapshot_all_rooms(
+ user_id, pagin_config, as_client_event, include_archived
+ ),
+ )
@defer.inlineCallbacks
- def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
- as_client_event=True, include_archived=False):
+ def _snapshot_all_rooms(
+ self,
+ user_id=None,
+ pagin_config=None,
+ as_client_event=True,
+ include_archived=False,
+ ):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
@@ -128,8 +142,7 @@ class InitialSyncHandler(BaseHandler):
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
- "public" if event.room_id in public_room_ids
- else "private"
+ "public" if event.room_id in public_room_ids else "private"
),
}
@@ -139,7 +152,7 @@ class InitialSyncHandler(BaseHandler):
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = yield self._event_serializer.serialize_event(
- invite_event, time_now, as_client_event,
+ invite_event, time_now, as_client_event
)
rooms_ret.append(d)
@@ -151,14 +164,12 @@ class InitialSyncHandler(BaseHandler):
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = run_in_background(
- self.state_handler.get_current_state,
- event.room_id,
+ self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background(
- self.store.get_state_for_events,
- [event.event_id],
+ self.store.get_state_for_events, [event.event_id]
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
@@ -178,9 +189,7 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
- messages = yield filter_events_for_client(
- self.store, user_id, messages
- )
+ messages = yield filter_events_for_client(self.store, user_id, messages)
start_token = now_token.copy_and_replace("room_key", token)
end_token = now_token.copy_and_replace("room_key", room_end_token)
@@ -189,8 +198,7 @@ class InitialSyncHandler(BaseHandler):
d["messages"] = {
"chunk": (
yield self._event_serializer.serialize_events(
- messages, time_now=time_now,
- as_client_event=as_client_event,
+ messages, time_now=time_now, as_client_event=as_client_event
)
),
"start": start_token.to_string(),
@@ -200,23 +208,21 @@ class InitialSyncHandler(BaseHandler):
d["state"] = yield self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
- as_client_event=as_client_event
+ as_client_event=as_client_event,
)
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
+ account_data_events.append(
+ {"type": "m.tag", "content": {"tags": tags}}
+ )
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append(
+ {"type": account_data_type, "content": content}
+ )
d["account_data"] = account_data_events
except Exception:
@@ -226,10 +232,7 @@ class InitialSyncHandler(BaseHandler):
account_data_events = []
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append({"type": account_data_type, "content": content})
now = self.clock.time_msec()
@@ -274,7 +277,7 @@ class InitialSyncHandler(BaseHandler):
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id,
+ room_id, user_id
)
is_peeking = member_event_id is None
@@ -290,28 +293,21 @@ class InitialSyncHandler(BaseHandler):
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
if tags:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
+ account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append({"type": account_data_type, "content": content})
result["account_data"] = account_data_events
defer.returnValue(result)
@defer.inlineCallbacks
- def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
- membership, member_event_id, is_peeking):
- room_state = yield self.store.get_state_for_events(
- [member_event_id],
- )
+ def _room_initial_sync_parted(
+ self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
+ ):
+ room_state = yield self.store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@@ -319,14 +315,10 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
- stream_token = yield self.store.get_stream_token_for_event(
- member_event_id
- )
+ stream_token = yield self.store.get_stream_token_for_event(member_event_id)
messages, token = yield self.store.get_recent_events_for_room(
- room_id,
- limit=limit,
- end_token=stream_token
+ room_id, limit=limit, end_token=stream_token
)
messages = yield filter_events_for_client(
@@ -338,34 +330,39 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
- defer.returnValue({
- "membership": membership,
- "room_id": room_id,
- "messages": {
- "chunk": (yield self._event_serializer.serialize_events(
- messages, time_now,
- )),
- "start": start_token.to_string(),
- "end": end_token.to_string(),
- },
- "state": (yield self._event_serializer.serialize_events(
- room_state.values(), time_now,
- )),
- "presence": [],
- "receipts": [],
- })
+ defer.returnValue(
+ {
+ "membership": membership,
+ "room_id": room_id,
+ "messages": {
+ "chunk": (
+ yield self._event_serializer.serialize_events(
+ messages, time_now
+ )
+ ),
+ "start": start_token.to_string(),
+ "end": end_token.to_string(),
+ },
+ "state": (
+ yield self._event_serializer.serialize_events(
+ room_state.values(), time_now
+ )
+ ),
+ "presence": [],
+ "receipts": [],
+ }
+ )
@defer.inlineCallbacks
- def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
- membership, is_peeking):
- current_state = yield self.state.get_current_state(
- room_id=room_id,
- )
+ def _room_initial_sync_joined(
+ self, user_id, room_id, pagin_config, membership, is_peeking
+ ):
+ current_state = yield self.state.get_current_state(room_id=room_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = yield self._event_serializer.serialize_events(
- current_state.values(), time_now,
+ current_state.values(), time_now
)
now_token = yield self.hs.get_event_sources().get_current_token()
@@ -375,7 +372,8 @@ class InitialSyncHandler(BaseHandler):
limit = 10
room_members = [
- m for m in current_state.values()
+ m
+ for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
@@ -389,8 +387,7 @@ class InitialSyncHandler(BaseHandler):
defer.returnValue([])
states = yield presence_handler.get_states(
- [m.user_id for m in room_members],
- as_event=True,
+ [m.user_id for m in room_members], as_event=True
)
defer.returnValue(states)
@@ -398,8 +395,7 @@ class InitialSyncHandler(BaseHandler):
@defer.inlineCallbacks
def get_receipts():
receipts = yield self.store.get_linearized_receipts_for_room(
- room_id,
- to_key=now_token.receipt_key,
+ room_id, to_key=now_token.receipt_key
)
if not receipts:
receipts = []
@@ -415,14 +411,14 @@ class InitialSyncHandler(BaseHandler):
room_id,
limit=limit,
end_token=now_token.room_key,
- )
+ ),
],
consumeErrors=True,
- ).addErrback(unwrapFirstError),
+ ).addErrback(unwrapFirstError)
)
messages = yield filter_events_for_client(
- self.store, user_id, messages, is_peeking=is_peeking,
+ self.store, user_id, messages, is_peeking=is_peeking
)
start_token = now_token.copy_and_replace("room_key", token)
@@ -433,9 +429,9 @@ class InitialSyncHandler(BaseHandler):
ret = {
"room_id": room_id,
"messages": {
- "chunk": (yield self._event_serializer.serialize_events(
- messages, time_now,
- )),
+ "chunk": (
+ yield self._event_serializer.serialize_events(messages, time_now)
+ ),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
@@ -464,8 +460,8 @@ class InitialSyncHandler(BaseHandler):
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
- visibility and
- visibility.content["history_visibility"] == "world_readable"
+ visibility
+ and visibility.content["history_visibility"] == "world_readable"
):
defer.returnValue((Membership.JOIN, None))
return
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 0b02469ceb..683da6bf32 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2017 - 2018 New Vector Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,9 +34,10 @@ from synapse.api.errors import (
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events.validator import EventValidator
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, UserID
+from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import run_in_background
@@ -59,8 +61,9 @@ class MessageHandler(object):
self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
- def get_room_data(self, user_id=None, room_id=None,
- event_type=None, state_key="", is_guest=False):
+ def get_room_data(
+ self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False
+ ):
""" Get data from a room.
Args:
@@ -75,9 +78,7 @@ class MessageHandler(object):
)
if membership == Membership.JOIN:
- data = yield self.state.get_current_state(
- room_id, event_type, state_key
- )
+ data = yield self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
@@ -89,8 +90,12 @@ class MessageHandler(object):
@defer.inlineCallbacks
def get_state_events(
- self, user_id, room_id, state_filter=StateFilter.all(),
- at_token=None, is_guest=False,
+ self,
+ user_id,
+ room_id,
+ state_filter=StateFilter.all(),
+ at_token=None,
+ is_guest=False,
):
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
@@ -122,50 +127,48 @@ class MessageHandler(object):
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room(
- room_id, end_token=at_token.room_key, limit=1,
+ room_id, end_token=at_token.room_key, limit=1
)
if not last_events:
- raise NotFoundError("Can't find event for token %s" % (at_token, ))
+ raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
- self.store, user_id, last_events,
+ self.store, user_id, last_events
)
event = last_events[0]
if visible_events:
room_state = yield self.store.get_state_for_events(
- [event.event_id], state_filter=state_filter,
+ [event.event_id], state_filter=state_filter
)
room_state = room_state[event.event_id]
else:
raise AuthError(
403,
- "User %s not allowed to view events in room %s at token %s" % (
- user_id, room_id, at_token,
- )
+ "User %s not allowed to view events in room %s at token %s"
+ % (user_id, room_id, at_token),
)
else:
membership, membership_event_id = (
- yield self.auth.check_in_room_or_world_readable(
- room_id, user_id,
- )
+ yield self.auth.check_in_room_or_world_readable(room_id, user_id)
)
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
- room_id, state_filter=state_filter,
+ room_id, state_filter=state_filter
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
- [membership_event_id], state_filter=state_filter,
+ [membership_event_id], state_filter=state_filter
)
room_state = room_state[membership_event_id]
now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(
- room_state.values(), now,
+ room_state.values(),
+ now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_aggregations=False,
@@ -209,13 +212,15 @@ class MessageHandler(object):
# Loop fell through, AS has no interested users in room
raise AuthError(403, "Appservice not in room")
- defer.returnValue({
- user_id: {
- "avatar_url": profile.avatar_url,
- "display_name": profile.display_name,
+ defer.returnValue(
+ {
+ user_id: {
+ "avatar_url": profile.avatar_url,
+ "display_name": profile.display_name,
+ }
+ for user_id, profile in iteritems(users_with_profile)
}
- for user_id, profile in iteritems(users_with_profile)
- })
+ )
class EventCreationHandler(object):
@@ -248,6 +253,7 @@ class EventCreationHandler(object):
self.action_generator = hs.get_action_generator()
self.spam_checker = hs.get_spam_checker()
+ self.third_party_event_rules = hs.get_third_party_event_rules()
self._block_events_without_consent_error = (
self.config.block_events_without_consent_error
@@ -259,9 +265,28 @@ class EventCreationHandler(object):
if self._block_events_without_consent_error:
self._consent_uri_builder = ConsentURIBuilder(self.config)
+ if (
+ not self.config.worker_app
+ and self.config.cleanup_extremities_with_dummy_events
+ ):
+ self.clock.looping_call(
+ lambda: run_as_background_process(
+ "send_dummy_events_to_fill_extremities",
+ self._send_dummy_events_to_fill_extremities,
+ ),
+ 5 * 60 * 1000,
+ )
+
@defer.inlineCallbacks
- def create_event(self, requester, event_dict, token_id=None, txn_id=None,
- prev_events_and_hashes=None, require_consent=True):
+ def create_event(
+ self,
+ requester,
+ event_dict,
+ token_id=None,
+ txn_id=None,
+ prev_events_and_hashes=None,
+ require_consent=True,
+ ):
"""
Given a dict from a client, create a new event.
@@ -321,8 +346,7 @@ class EventCreationHandler(object):
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
- "Failed to get profile information for %r: %s",
- target, e
+ "Failed to get profile information for %r: %s", target, e
)
is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
@@ -358,16 +382,17 @@ class EventCreationHandler(object):
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event or prev_event.membership != Membership.JOIN:
logger.warning(
- ("Attempt to send `m.room.aliases` in room %s by user %s but"
- " membership is %s"),
+ (
+ "Attempt to send `m.room.aliases` in room %s by user %s but"
+ " membership is %s"
+ ),
event.room_id,
event.sender,
prev_event.membership if prev_event else None,
)
raise AuthError(
- 403,
- "You must be in the room to create an alias for it",
+ 403, "You must be in the room to create an alias for it"
)
self.validator.validate_new(event)
@@ -434,8 +459,8 @@ class EventCreationHandler(object):
# exempt the system notices user
if (
- self.config.server_notices_mxid is not None and
- user_id == self.config.server_notices_mxid
+ self.config.server_notices_mxid is not None
+ and user_id == self.config.server_notices_mxid
):
return
@@ -448,15 +473,10 @@ class EventCreationHandler(object):
return
consent_uri = self._consent_uri_builder.build_user_consent_uri(
- requester.user.localpart,
- )
- msg = self._block_events_without_consent_error % {
- 'consent_uri': consent_uri,
- }
- raise ConsentNotGivenError(
- msg=msg,
- consent_uri=consent_uri,
+ requester.user.localpart
)
+ msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
+ raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
@@ -471,8 +491,7 @@ class EventCreationHandler(object):
"""
if event.type == EventTypes.Member:
raise SynapseError(
- 500,
- "Tried to send member event through non-member codepath"
+ 500, "Tried to send member event through non-member codepath"
)
user = UserID.from_string(event.sender)
@@ -484,15 +503,13 @@ class EventCreationHandler(object):
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
- event.event_id, prev_state.event_id,
+ event.event_id,
+ prev_state.event_id,
)
defer.returnValue(prev_state)
yield self.handle_new_client_event(
- requester=requester,
- event=event,
- context=context,
- ratelimit=ratelimit,
+ requester=requester, event=event, context=context, ratelimit=ratelimit
)
@defer.inlineCallbacks
@@ -518,11 +535,7 @@ class EventCreationHandler(object):
@defer.inlineCallbacks
def create_and_send_nonmember_event(
- self,
- requester,
- event_dict,
- ratelimit=True,
- txn_id=None
+ self, requester, event_dict, ratelimit=True, txn_id=None
):
"""
Creates an event, then sends it.
@@ -537,32 +550,25 @@ class EventCreationHandler(object):
# taking longer.
with (yield self.limiter.queue(event_dict["room_id"])):
event, context = yield self.create_event(
- requester,
- event_dict,
- token_id=requester.access_token_id,
- txn_id=txn_id
+ requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, string_types):
spam_error = "Spam is not permitted here"
- raise SynapseError(
- 403, spam_error, Codes.FORBIDDEN
- )
+ raise SynapseError(403, spam_error, Codes.FORBIDDEN)
yield self.send_nonmember_event(
- requester,
- event,
- context,
- ratelimit=ratelimit,
+ requester, event, context, ratelimit=ratelimit
)
defer.returnValue(event)
@measure_func("create_new_client_event")
@defer.inlineCallbacks
- def create_new_client_event(self, builder, requester=None,
- prev_events_and_hashes=None):
+ def create_new_client_event(
+ self, builder, requester=None, prev_events_and_hashes=None
+ ):
"""Create a new event for a local client
Args:
@@ -582,22 +588,21 @@ class EventCreationHandler(object):
"""
if prev_events_and_hashes is not None:
- assert len(prev_events_and_hashes) <= 10, \
- "Attempting to create an event with %i prev_events" % (
- len(prev_events_and_hashes),
+ assert len(prev_events_and_hashes) <= 10, (
+ "Attempting to create an event with %i prev_events"
+ % (len(prev_events_and_hashes),)
)
else:
- prev_events_and_hashes = \
- yield self.store.get_prev_events_for_room(builder.room_id)
+ prev_events_and_hashes = yield self.store.get_prev_events_for_room(
+ builder.room_id
+ )
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in prev_events_and_hashes
]
- event = yield builder.build(
- prev_event_ids=[p for p, _ in prev_events],
- )
+ event = yield builder.build(prev_event_ids=[p for p, _ in prev_events])
context = yield self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
@@ -613,29 +618,19 @@ class EventCreationHandler(object):
aggregation_key = relation["key"]
already_exists = yield self.store.has_user_annotated_event(
- relates_to, event.type, aggregation_key, event.sender,
+ relates_to, event.type, aggregation_key, event.sender
)
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")
- logger.debug(
- "Created event %s",
- event.event_id,
- )
+ logger.debug("Created event %s", event.event_id)
- defer.returnValue(
- (event, context,)
- )
+ defer.returnValue((event, context))
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
- self,
- requester,
- event,
- context,
- ratelimit=True,
- extra_users=[],
+ self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
@@ -651,13 +646,22 @@ class EventCreationHandler(object):
extra_users (list(UserID)): Any extra users to notify about event
"""
- if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
- room_version = event.content.get(
- "room_version", RoomVersions.V1.identifier
- )
+ if event.is_state() and (event.type, event.state_key) == (
+ EventTypes.Create,
+ "",
+ ):
+ room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else:
room_version = yield self.store.get_room_version(event.room_id)
+ event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event, context
+ )
+ if not event_allowed:
+ raise SynapseError(
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
+ )
+
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as err:
@@ -672,9 +676,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content)
raise
- yield self.action_generator.handle_push_actions_for_event(
- event, context
- )
+ yield self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
@@ -695,11 +697,7 @@ class EventCreationHandler(object):
return
yield self.persist_and_notify_client_event(
- requester,
- event,
- context,
- ratelimit=ratelimit,
- extra_users=extra_users,
+ requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
success = True
@@ -708,18 +706,12 @@ class EventCreationHandler(object):
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
run_in_background(
- self.store.remove_push_actions_from_staging,
- event.event_id,
+ self.store.remove_push_actions_from_staging, event.event_id
)
@defer.inlineCallbacks
def persist_and_notify_client_event(
- self,
- requester,
- event,
- context,
- ratelimit=True,
- extra_users=[],
+ self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
@@ -744,20 +736,16 @@ class EventCreationHandler(object):
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
- "Room alias %s does not point to the room" % (
- room_alias_str,
- )
+ "Room alias %s does not point to the room" % (room_alias_str,),
)
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
+
def is_inviter_member_event(e):
- return (
- e.type == EventTypes.Member and
- e.sender == event.sender
- )
+ return e.type == EventTypes.Member and e.sender == event.sender
current_state_ids = yield context.get_current_state_ids(self.store)
@@ -787,26 +775,21 @@ class EventCreationHandler(object):
# to get them to sign the event.
returned_invite = yield federation_handler.send_invite(
- invitee.domain,
- event,
+ invitee.domain, event
)
event.unsigned.pop("room_state", None)
# TODO: Make sure the signatures actually are correct.
- event.signatures.update(
- returned_invite.signatures
- )
+ event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {
- (e.type, e.state_key): e for e in auth_events.values()
- }
+ auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version = yield self.store.get_room_version(event.room_id)
if self.auth.check_redaction(room_version, event, auth_events=auth_events):
original_event = yield self.store.get_event(
@@ -814,13 +797,10 @@ class EventCreationHandler(object):
check_redacted=False,
get_prev_content=False,
allow_rejected=False,
- allow_none=False
+ allow_none=False,
)
if event.user_id != original_event.user_id:
- raise AuthError(
- 403,
- "You don't have permission to redact events"
- )
+ raise AuthError(403, "You don't have permission to redact events")
# We've already checked.
event.internal_metadata.recheck_redaction = False
@@ -828,24 +808,18 @@ class EventCreationHandler(object):
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids(self.store)
if prev_state_ids:
- raise AuthError(
- 403,
- "Changing the room create event is forbidden",
- )
+ raise AuthError(403, "Changing the room create event is forbidden")
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
- yield self.pusher_pool.on_new_notifications(
- event_stream_id, max_stream_id,
- )
+ yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():
try:
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
+ event, event_stream_id, max_stream_id, extra_users=extra_users
)
except Exception:
logger.exception("Error notifying about new room event")
@@ -864,3 +838,54 @@ class EventCreationHandler(object):
yield presence.bump_presence_active_time(user)
except Exception:
logger.exception("Error bumping presence active time")
+
+ @defer.inlineCallbacks
+ def _send_dummy_events_to_fill_extremities(self):
+ """Background task to send dummy events into rooms that have a large
+ number of extremities
+ """
+
+ room_ids = yield self.store.get_rooms_with_many_extremities(
+ min_count=10, limit=5
+ )
+
+ for room_id in room_ids:
+ # For each room we need to find a joined member we can use to send
+ # the dummy event with.
+
+ prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)
+
+ latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)
+
+ members = yield self.state.get_current_users_in_room(
+ room_id, latest_event_ids=latest_event_ids
+ )
+
+ user_id = None
+ for member in members:
+ if self.hs.is_mine_id(member):
+ user_id = member
+ break
+
+ if not user_id:
+ # We don't have a joined user.
+ # TODO: We should do something here to stop the room from
+ # appearing next time.
+ continue
+
+ requester = create_requester(user_id)
+
+ event, context = yield self.create_event(
+ requester,
+ {
+ "type": "org.matrix.dummy_event",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_events_and_hashes=prev_events_and_hashes,
+ )
+
+ event.internal_metadata.proactively_send = False
+
+ yield self.send_nonmember_event(requester, event, context, ratelimit=False)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8f811e24fe..76ee97ddd3 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -55,9 +55,7 @@ class PurgeStatus(object):
self.status = PurgeStatus.STATUS_ACTIVE
def asdict(self):
- return {
- "status": PurgeStatus.STATUS_TEXT[self.status]
- }
+ return {"status": PurgeStatus.STATUS_TEXT[self.status]}
class PaginationHandler(object):
@@ -79,8 +77,7 @@ class PaginationHandler(object):
self._purges_by_id = {}
self._event_serializer = hs.get_event_client_serializer()
- def start_purge_history(self, room_id, token,
- delete_local_events=False):
+ def start_purge_history(self, room_id, token, delete_local_events=False):
"""Start off a history purge on a room.
Args:
@@ -95,8 +92,7 @@ class PaginationHandler(object):
"""
if room_id in self._purges_in_progress_by_room:
raise SynapseError(
- 400,
- "History purge already in progress for %s" % (room_id, ),
+ 400, "History purge already in progress for %s" % (room_id,)
)
purge_id = random_string(16)
@@ -107,14 +103,12 @@ class PaginationHandler(object):
self._purges_by_id[purge_id] = PurgeStatus()
run_in_background(
- self._purge_history,
- purge_id, room_id, token, delete_local_events,
+ self._purge_history, purge_id, room_id, token, delete_local_events
)
return purge_id
@defer.inlineCallbacks
- def _purge_history(self, purge_id, room_id, token,
- delete_local_events):
+ def _purge_history(self, purge_id, room_id, token, delete_local_events):
"""Carry out a history purge on a room.
Args:
@@ -130,16 +124,13 @@ class PaginationHandler(object):
self._purges_in_progress_by_room.add(room_id)
try:
with (yield self.pagination_lock.write(room_id)):
- yield self.store.purge_history(
- room_id, token, delete_local_events,
- )
+ yield self.store.purge_history(room_id, token, delete_local_events)
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
except Exception:
f = Failure()
logger.error(
- "[purge] failed",
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
)
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
@@ -148,6 +139,7 @@ class PaginationHandler(object):
# remove the purge from the list 24 hours after it completes
def clear_purge():
del self._purges_by_id[purge_id]
+
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id):
@@ -162,8 +154,14 @@ class PaginationHandler(object):
return self._purges_by_id.get(purge_id)
@defer.inlineCallbacks
- def get_messages(self, requester, room_id=None, pagin_config=None,
- as_client_event=True, event_filter=None):
+ def get_messages(
+ self,
+ requester,
+ room_id=None,
+ pagin_config=None,
+ as_client_event=True,
+ event_filter=None,
+ ):
"""Get messages in a room.
Args:
@@ -182,9 +180,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
- yield self.hs.get_event_sources().get_current_token_for_room(
- room_id=room_id
- )
+ yield self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
@@ -201,7 +197,7 @@ class PaginationHandler(object):
room_id, user_id
)
- if source_config.direction == 'b':
+ if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
@@ -235,27 +231,24 @@ class PaginationHandler(object):
event_filter=event_filter,
)
- next_token = pagin_config.from_token.copy_and_replace(
- "room_key", next_key
- )
+ next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)
if events:
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client(
- self.store,
- user_id,
- events,
- is_peeking=(member_event_id is None),
+ self.store, user_id, events, is_peeking=(member_event_id is None)
)
if not events:
- defer.returnValue({
- "chunk": [],
- "start": pagin_config.from_token.to_string(),
- "end": next_token.to_string(),
- })
+ defer.returnValue(
+ {
+ "chunk": [],
+ "start": pagin_config.from_token.to_string(),
+ "end": next_token.to_string(),
+ }
+ )
state = None
if event_filter and event_filter.lazy_load_members() and len(events) > 0:
@@ -263,12 +256,11 @@ class PaginationHandler(object):
# FIXME: we also care about invite targets etc.
state_filter = StateFilter.from_types(
- (EventTypes.Member, event.sender)
- for event in events
+ (EventTypes.Member, event.sender) for event in events
)
state_ids = yield self.store.get_state_ids_for_event(
- events[0].event_id, state_filter=state_filter,
+ events[0].event_id, state_filter=state_filter
)
if state_ids:
@@ -280,8 +272,7 @@ class PaginationHandler(object):
chunk = {
"chunk": (
yield self._event_serializer.serialize_events(
- events, time_now,
- as_client_event=as_client_event,
+ events, time_now, as_client_event=as_client_event
)
),
"start": pagin_config.from_token.to_string(),
@@ -291,8 +282,7 @@ class PaginationHandler(object):
if state:
chunk["state"] = (
yield self._event_serializer.serialize_events(
- state, time_now,
- as_client_event=as_client_event,
+ state, time_now, as_client_event=as_client_event
)
)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 557fb5f83d..5204073a38 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -50,16 +50,20 @@ logger = logging.getLogger(__name__)
notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "")
federation_presence_out_counter = Counter(
- "synapse_handler_presence_federation_presence_out", "")
+ "synapse_handler_presence_federation_presence_out", ""
+)
presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "")
timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "")
-federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "")
+federation_presence_counter = Counter(
+ "synapse_handler_presence_federation_presence", ""
+)
bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "")
get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"])
notify_reason_counter = Counter(
- "synapse_handler_presence_notify_reason", "", ["reason"])
+ "synapse_handler_presence_notify_reason", "", ["reason"]
+)
state_transition_counter = Counter(
"synapse_handler_presence_state_transition", "", ["from", "to"]
)
@@ -90,7 +94,6 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object):
-
def __init__(self, hs):
"""
@@ -110,31 +113,26 @@ class PresenceHandler(object):
federation_registry = hs.get_federation_registry()
- federation_registry.register_edu_handler(
- "m.presence", self.incoming_presence
- )
+ federation_registry.register_edu_handler("m.presence", self.incoming_presence)
active_presence = self.store.take_presence_startup_info()
# A dictionary of the current state of users. This is prefilled with
# non-offline presence from the DB. We should fetch from the DB if
# we can't find a users presence in here.
- self.user_to_current_state = {
- state.user_id: state
- for state in active_presence
- }
+ self.user_to_current_state = {state.user_id: state for state in active_presence}
LaterGauge(
- "synapse_handlers_presence_user_to_current_state_size", "", [],
- lambda: len(self.user_to_current_state)
+ "synapse_handlers_presence_user_to_current_state_size",
+ "",
+ [],
+ lambda: len(self.user_to_current_state),
)
now = self.clock.time_msec()
for state in active_presence:
self.wheel_timer.insert(
- now=now,
- obj=state.user_id,
- then=state.last_active_ts + IDLE_TIMER,
+ now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
)
self.wheel_timer.insert(
now=now,
@@ -193,27 +191,21 @@ class PresenceHandler(object):
"handle_presence_timeouts", self._handle_timeouts
)
- self.clock.call_later(
- 30,
- self.clock.looping_call,
- run_timeout_handler,
- 5000,
- )
+ self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000)
def run_persister():
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
- self.clock.call_later(
- 60,
- self.clock.looping_call,
- run_persister,
- 60 * 1000,
- )
+ self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
- LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
- lambda: len(self.wheel_timer))
+ LaterGauge(
+ "synapse_handlers_presence_wheel_timer_size",
+ "",
+ [],
+ lambda: len(self.wheel_timer),
+ )
# Used to handle sending of presence to newly joined users/servers
if hs.config.use_presence:
@@ -241,15 +233,17 @@ class PresenceHandler(object):
logger.info(
"Performing _on_shutdown. Persisting %d unpersisted changes",
- len(self.user_to_current_state)
+ len(self.user_to_current_state),
)
if self.unpersisted_users_changes:
- yield self.store.update_presence([
- self.user_to_current_state[user_id]
- for user_id in self.unpersisted_users_changes
- ])
+ yield self.store.update_presence(
+ [
+ self.user_to_current_state[user_id]
+ for user_id in self.unpersisted_users_changes
+ ]
+ )
logger.info("Finished _on_shutdown")
@defer.inlineCallbacks
@@ -261,13 +255,10 @@ class PresenceHandler(object):
self.unpersisted_users_changes = set()
if unpersisted:
- logger.info(
- "Persisting %d upersisted presence updates", len(unpersisted)
+ logger.info("Persisting %d upersisted presence updates", len(unpersisted))
+ yield self.store.update_presence(
+ [self.user_to_current_state[user_id] for user_id in unpersisted]
)
- yield self.store.update_presence([
- self.user_to_current_state[user_id]
- for user_id in unpersisted
- ])
@defer.inlineCallbacks
def _update_states(self, new_states):
@@ -303,10 +294,11 @@ class PresenceHandler(object):
)
new_state, should_notify, should_ping = handle_update(
- prev_state, new_state,
+ prev_state,
+ new_state,
is_mine=self.is_mine_id(user_id),
wheel_timer=self.wheel_timer,
- now=now
+ now=now,
)
self.user_to_current_state[user_id] = new_state
@@ -328,7 +320,8 @@ class PresenceHandler(object):
self.unpersisted_users_changes -= set(to_notify.keys())
to_federation_ping = {
- user_id: state for user_id, state in to_federation_ping.items()
+ user_id: state
+ for user_id, state in to_federation_ping.items()
if user_id not in to_notify
}
if to_federation_ping:
@@ -351,8 +344,8 @@ class PresenceHandler(object):
# Check whether the lists of syncing processes from an external
# process have expired.
expired_process_ids = [
- process_id for process_id, last_update
- in self.external_process_last_updated_ms.items()
+ process_id
+ for process_id, last_update in self.external_process_last_updated_ms.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
@@ -362,9 +355,7 @@ class PresenceHandler(object):
self.external_process_last_update.pop(process_id)
states = [
- self.user_to_current_state.get(
- user_id, UserPresenceState.default(user_id)
- )
+ self.user_to_current_state.get(user_id, UserPresenceState.default(user_id))
for user_id in users_to_check
]
@@ -394,9 +385,7 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id)
- new_fields = {
- "last_active_ts": self.clock.time_msec(),
- }
+ new_fields = {"last_active_ts": self.clock.time_msec()}
if prev_state.state == PresenceState.UNAVAILABLE:
new_fields["state"] = PresenceState.ONLINE
@@ -430,15 +419,23 @@ class PresenceHandler(object):
if prev_state.state == PresenceState.OFFLINE:
# If they're currently offline then bring them online, otherwise
# just update the last sync times.
- yield self._update_states([prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=self.clock.time_msec(),
- last_user_sync_ts=self.clock.time_msec(),
- )])
+ yield self._update_states(
+ [
+ prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=self.clock.time_msec(),
+ last_user_sync_ts=self.clock.time_msec(),
+ )
+ ]
+ )
else:
- yield self._update_states([prev_state.copy_and_replace(
- last_user_sync_ts=self.clock.time_msec(),
- )])
+ yield self._update_states(
+ [
+ prev_state.copy_and_replace(
+ last_user_sync_ts=self.clock.time_msec()
+ )
+ ]
+ )
@defer.inlineCallbacks
def _end():
@@ -446,9 +443,13 @@ class PresenceHandler(object):
self.user_to_num_current_syncs[user_id] -= 1
prev_state = yield self.current_state_for_user(user_id)
- yield self._update_states([prev_state.copy_and_replace(
- last_user_sync_ts=self.clock.time_msec(),
- )])
+ yield self._update_states(
+ [
+ prev_state.copy_and_replace(
+ last_user_sync_ts=self.clock.time_msec()
+ )
+ ]
+ )
except Exception:
logger.exception("Error updating presence after sync")
@@ -469,7 +470,8 @@ class PresenceHandler(object):
"""
if self.hs.config.use_presence:
syncing_user_ids = {
- user_id for user_id, count in self.user_to_num_current_syncs.items()
+ user_id
+ for user_id, count in self.user_to_num_current_syncs.items()
if count
}
for user_ids in self.external_process_to_current_syncs.values():
@@ -479,7 +481,9 @@ class PresenceHandler(object):
return set()
@defer.inlineCallbacks
- def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
+ def update_external_syncs_row(
+ self, process_id, user_id, is_syncing, sync_time_msec
+ ):
"""Update the syncing users for an external process as a delta.
Args:
@@ -500,20 +504,22 @@ class PresenceHandler(object):
updates = []
if is_syncing and user_id not in process_presence:
if prev_state.state == PresenceState.OFFLINE:
- updates.append(prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=sync_time_msec,
- last_user_sync_ts=sync_time_msec,
- ))
+ updates.append(
+ prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=sync_time_msec,
+ last_user_sync_ts=sync_time_msec,
+ )
+ )
else:
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=sync_time_msec,
- ))
+ updates.append(
+ prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec)
+ )
process_presence.add(user_id)
elif user_id in process_presence:
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=sync_time_msec,
- ))
+ updates.append(
+ prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec)
+ )
if not is_syncing:
process_presence.discard(user_id)
@@ -537,12 +543,12 @@ class PresenceHandler(object):
prev_states = yield self.current_state_for_users(process_presence)
time_now_ms = self.clock.time_msec()
- yield self._update_states([
- prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- )
- for prev_state in itervalues(prev_states)
- ])
+ yield self._update_states(
+ [
+ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
+ for prev_state in itervalues(prev_states)
+ ]
+ )
self.external_process_last_updated_ms.pop(process_id, None)
@defer.inlineCallbacks
@@ -574,8 +580,7 @@ class PresenceHandler(object):
missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
new = {
- user_id: UserPresenceState.default(user_id)
- for user_id in missing
+ user_id: UserPresenceState.default(user_id) for user_id in missing
}
states.update(new)
self.user_to_current_state.update(new)
@@ -593,8 +598,10 @@ class PresenceHandler(object):
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key", stream_id, rooms=room_ids_to_states.keys(),
- users=[UserID.from_string(u) for u in users_to_states]
+ "presence_key",
+ stream_id,
+ rooms=room_ids_to_states.keys(),
+ users=[UserID.from_string(u) for u in users_to_states],
)
self._push_to_remotes(states)
@@ -605,8 +612,10 @@ class PresenceHandler(object):
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key", stream_id, rooms=room_ids_to_states.keys(),
- users=[UserID.from_string(u) for u in users_to_states]
+ "presence_key",
+ stream_id,
+ rooms=room_ids_to_states.keys(),
+ users=[UserID.from_string(u) for u in users_to_states],
)
def _push_to_remotes(self, states):
@@ -631,15 +640,15 @@ class PresenceHandler(object):
user_id = push.get("user_id", None)
if not user_id:
logger.info(
- "Got presence update from %r with no 'user_id': %r",
- origin, push,
+ "Got presence update from %r with no 'user_id': %r", origin, push
)
continue
if get_domain_from_id(user_id) != origin:
logger.info(
"Got presence update from %r with bad 'user_id': %r",
- origin, user_id,
+ origin,
+ user_id,
)
continue
@@ -647,14 +656,12 @@ class PresenceHandler(object):
if not presence_state:
logger.info(
"Got presence update from %r with no 'presence_state': %r",
- origin, push,
+ origin,
+ push,
)
continue
- new_fields = {
- "state": presence_state,
- "last_federation_update_ts": now,
- }
+ new_fields = {"state": presence_state, "last_federation_update_ts": now}
last_active_ago = push.get("last_active_ago", None)
if last_active_ago is not None:
@@ -672,10 +679,7 @@ class PresenceHandler(object):
@defer.inlineCallbacks
def get_state(self, target_user, as_event=False):
- results = yield self.get_states(
- [target_user.to_string()],
- as_event=as_event,
- )
+ results = yield self.get_states([target_user.to_string()], as_event=as_event)
defer.returnValue(results[0])
@@ -699,13 +703,15 @@ class PresenceHandler(object):
now = self.clock.time_msec()
if as_event:
- defer.returnValue([
- {
- "type": "m.presence",
- "content": format_user_presence_state(state, now),
- }
- for state in updates
- ])
+ defer.returnValue(
+ [
+ {
+ "type": "m.presence",
+ "content": format_user_presence_state(state, now),
+ }
+ for state in updates
+ ]
+ )
else:
defer.returnValue(updates)
@@ -717,7 +723,9 @@ class PresenceHandler(object):
presence = state["presence"]
valid_presence = (
- PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
)
if presence not in valid_presence:
raise SynapseError(400, "Invalid presence state")
@@ -726,9 +734,7 @@ class PresenceHandler(object):
prev_state = yield self.current_state_for_user(user_id)
- new_fields = {
- "state": presence
- }
+ new_fields = {"state": presence}
if not ignore_status_msg:
msg = status_msg if presence != PresenceState.OFFLINE else None
@@ -877,8 +883,7 @@ class PresenceHandler(object):
hosts = set(host for host in hosts if host != self.server_name)
self.federation.send_presence_to_destinations(
- states=[state],
- destinations=hosts,
+ states=[state], destinations=hosts
)
else:
# A remote user has joined the room, so we need to:
@@ -904,7 +909,8 @@ class PresenceHandler(object):
# default state.
now = self.clock.time_msec()
states = [
- state for state in states.values()
+ state
+ for state in states.values()
if state.state != PresenceState.OFFLINE
or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
or state.status_msg is not None
@@ -912,8 +918,7 @@ class PresenceHandler(object):
if states:
self.federation.send_presence_to_destinations(
- states=states,
- destinations=[get_domain_from_id(user_id)],
+ states=states, destinations=[get_domain_from_id(user_id)]
)
@@ -937,7 +942,10 @@ def should_notify(old_state, new_state):
notify_reason_counter.labels("current_active_change").inc()
return True
- if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
+ if (
+ new_state.last_active_ts - old_state.last_active_ts
+ > LAST_ACTIVE_GRANULARITY
+ ):
# Only notify about last active bumps if we're not currently acive
if not new_state.currently_active:
notify_reason_counter.labels("last_active_change_online").inc()
@@ -958,9 +966,7 @@ def format_user_presence_state(state, now, include_user_id=True):
The "user_id" is optional so that this function can be used to format presence
updates for client /sync responses and for federation /send requests.
"""
- content = {
- "presence": state.state,
- }
+ content = {"presence": state.state}
if include_user_id:
content["user_id"] = state.user_id
if state.last_active_ts:
@@ -986,8 +992,15 @@ class PresenceEventSource(object):
@defer.inlineCallbacks
@log_function
- def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
- explicit_room_id=None, **kwargs):
+ def get_new_events(
+ self,
+ user,
+ from_key,
+ room_ids=None,
+ include_offline=True,
+ explicit_room_id=None,
+ **kwargs
+ ):
# The process for getting presence events are:
# 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms.
@@ -1030,7 +1043,7 @@ class PresenceEventSource(object):
if from_key:
user_ids_changed = stream_change_cache.get_entities_changed(
- users_interested_in, from_key,
+ users_interested_in, from_key
)
else:
user_ids_changed = users_interested_in
@@ -1040,10 +1053,16 @@ class PresenceEventSource(object):
if include_offline:
defer.returnValue((list(updates.values()), max_token))
else:
- defer.returnValue(([
- s for s in itervalues(updates)
- if s.state != PresenceState.OFFLINE
- ], max_token))
+ defer.returnValue(
+ (
+ [
+ s
+ for s in itervalues(updates)
+ if s.state != PresenceState.OFFLINE
+ ],
+ max_token,
+ )
+ )
def get_current_key(self):
return self.store.get_current_presence_token()
@@ -1061,13 +1080,13 @@ class PresenceEventSource(object):
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id, on_invalidate=cache_context.invalidate,
+ user_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(
- explicit_room_id, on_invalidate=cache_context.invalidate,
+ explicit_room_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(user_ids)
@@ -1123,9 +1142,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
if now - state.last_active_ts > IDLE_TIMER:
# Currently online, but last activity ages ago so auto
# idle
- state = state.copy_and_replace(
- state=PresenceState.UNAVAILABLE,
- )
+ state = state.copy_and_replace(state=PresenceState.UNAVAILABLE)
changed = True
elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# So that we send down a notification that we've
@@ -1145,8 +1162,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace(
- state=PresenceState.OFFLINE,
- status_msg=None,
+ state=PresenceState.OFFLINE, status_msg=None
)
changed = True
else:
@@ -1155,10 +1171,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
# no one gets stuck online forever.
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
# The other side seems to have disappeared.
- state = state.copy_and_replace(
- state=PresenceState.OFFLINE,
- status_msg=None,
- )
+ state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None)
changed = True
return state if changed else None
@@ -1193,21 +1206,17 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
if new_state.state == PresenceState.ONLINE:
# Idle timer
wheel_timer.insert(
- now=now,
- obj=user_id,
- then=new_state.last_active_ts + IDLE_TIMER
+ now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER
)
active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY
- new_state = new_state.copy_and_replace(
- currently_active=active,
- )
+ new_state = new_state.copy_and_replace(currently_active=active)
if active:
wheel_timer.insert(
now=now,
obj=user_id,
- then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
+ then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
)
if new_state.state != PresenceState.OFFLINE:
@@ -1215,29 +1224,25 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
wheel_timer.insert(
now=now,
obj=user_id,
- then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
+ then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
)
last_federate = new_state.last_federation_update_ts
if now - last_federate > FEDERATION_PING_INTERVAL:
# Been a while since we've poked remote servers
- new_state = new_state.copy_and_replace(
- last_federation_update_ts=now,
- )
+ new_state = new_state.copy_and_replace(last_federation_update_ts=now)
federation_ping = True
else:
wheel_timer.insert(
now=now,
obj=user_id,
- then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT
+ then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
)
# Check whether the change was something worth notifying about
if should_notify(prev_state, new_state):
- new_state = new_state.copy_and_replace(
- last_federation_update_ts=now,
- )
+ new_state = new_state.copy_and_replace(last_federation_update_ts=now)
persist_and_notify = True
return new_state, persist_and_notify, federation_ping
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index a5fc6c5dbf..d8462b75ec 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -15,12 +15,15 @@
import logging
+from six import raise_from
+
from twisted.internet import defer
from synapse.api.errors import (
AuthError,
- CodeMessageException,
Codes,
+ HttpResponseException,
+ RequestSendFailed,
StoreError,
SynapseError,
)
@@ -70,25 +73,20 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue({
- "displayname": displayname,
- "avatar_url": avatar_url,
- })
+ defer.returnValue({"displayname": displayname, "avatar_url": avatar_url})
else:
try:
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
- args={
- "user_id": user_id,
- },
+ args={"user_id": user_id},
ignore_backoff=True,
)
defer.returnValue(result)
- except CodeMessageException as e:
- if e.code != 404:
- logger.exception("Failed to get displayname")
- raise
+ except RequestSendFailed as e:
+ raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
@defer.inlineCallbacks
def get_profile_from_cache(self, user_id):
@@ -110,10 +108,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue({
- "displayname": displayname,
- "avatar_url": avatar_url,
- })
+ defer.returnValue({"displayname": displayname, "avatar_url": avatar_url})
else:
profile = yield self.store.get_from_remote_profile_cache(user_id)
defer.returnValue(profile or {})
@@ -136,16 +131,13 @@ class BaseProfileHandler(BaseHandler):
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
- args={
- "user_id": target_user.to_string(),
- "field": "displayname",
- },
+ args={"user_id": target_user.to_string(), "field": "displayname"},
ignore_backoff=True,
)
- except CodeMessageException as e:
- if e.code != 404:
- logger.exception("Failed to get displayname")
- raise
+ except RequestSendFailed as e:
+ raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
defer.returnValue(result["displayname"])
@@ -167,15 +159,13 @@ class BaseProfileHandler(BaseHandler):
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
- 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ),
+ 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)
- if new_displayname == '':
+ if new_displayname == "":
new_displayname = None
- yield self.store.set_profile_displayname(
- target_user.localpart, new_displayname
- )
+ yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
@@ -202,16 +192,13 @@ class BaseProfileHandler(BaseHandler):
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
- args={
- "user_id": target_user.to_string(),
- "field": "avatar_url",
- },
+ args={"user_id": target_user.to_string(), "field": "avatar_url"},
ignore_backoff=True,
)
- except CodeMessageException as e:
- if e.code != 404:
- logger.exception("Failed to get avatar_url")
- raise
+ except RequestSendFailed as e:
+ raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
defer.returnValue(result["avatar_url"])
@@ -227,12 +214,10 @@ class BaseProfileHandler(BaseHandler):
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
- 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ),
+ 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
- yield self.store.set_profile_avatar_url(
- target_user.localpart, new_avatar_url
- )
+ yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
@@ -275,9 +260,7 @@ class BaseProfileHandler(BaseHandler):
yield self.ratelimit(requester)
- room_ids = yield self.store.get_rooms_for_user(
- target_user.to_string(),
- )
+ room_ids = yield self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
handler = self.hs.get_room_member_handler()
@@ -293,8 +276,7 @@ class BaseProfileHandler(BaseHandler):
)
except Exception as e:
logger.warn(
- "Failed to update join event for room %s - %s",
- room_id, str(e)
+ "Failed to update join event for room %s - %s", room_id, str(e)
)
@defer.inlineCallbacks
@@ -322,11 +304,9 @@ class BaseProfileHandler(BaseHandler):
return
try:
- requester_rooms = yield self.store.get_rooms_for_user(
- requester.to_string()
- )
+ requester_rooms = yield self.store.get_rooms_for_user(requester.to_string())
target_user_rooms = yield self.store.get_rooms_for_user(
- target_user.to_string(),
+ target_user.to_string()
)
# Check if the room lists have no elements in common.
@@ -350,12 +330,12 @@ class MasterProfileHandler(BaseProfileHandler):
assert hs.config.worker_app is None
self.clock.looping_call(
- self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
def _start_update_remote_profile_cache(self):
return run_as_background_process(
- "Update remote profile", self._update_remote_profile_cache,
+ "Update remote profile", self._update_remote_profile_cache
)
@defer.inlineCallbacks
@@ -369,7 +349,7 @@ class MasterProfileHandler(BaseProfileHandler):
for user_id, displayname, avatar_url in entries:
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
- user_id,
+ user_id
)
if not is_subscribed:
yield self.store.maybe_delete_remote_profile_cache(user_id)
@@ -379,9 +359,7 @@ class MasterProfileHandler(BaseProfileHandler):
profile = yield self.federation.make_query(
destination=get_domain_from_id(user_id),
query_type="profile",
- args={
- "user_id": user_id,
- },
+ args={"user_id": user_id},
ignore_backoff=True,
)
except Exception:
@@ -396,6 +374,4 @@ class MasterProfileHandler(BaseProfileHandler):
new_avatar = profile.get("avatar_url")
# We always hit update to update the last_check timestamp
- yield self.store.update_remote_profile_cache(
- user_id, new_name, new_avatar
- )
+ yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index 32108568c6..3e4d8c93a4 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -43,7 +43,7 @@ class ReadMarkerHandler(BaseHandler):
with (yield self.read_marker_linearizer.queue((room_id, user_id))):
existing_read_marker = yield self.store.get_account_data_for_room_and_type(
- user_id, room_id, "m.fully_read",
+ user_id, room_id, "m.fully_read"
)
should_update = True
@@ -51,14 +51,11 @@ class ReadMarkerHandler(BaseHandler):
if existing_read_marker:
# Only update if the new marker is ahead in the stream
should_update = yield self.store.is_event_after(
- event_id,
- existing_read_marker['event_id']
+ event_id, existing_read_marker["event_id"]
)
if should_update:
- content = {
- "event_id": event_id
- }
+ content = {"event_id": event_id}
max_id = yield self.store.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 274d2946ad..a85dd8cdee 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -88,19 +88,16 @@ class ReceiptsHandler(BaseHandler):
affected_room_ids = list(set([r.room_id for r in receipts]))
- self.notifier.on_new_event(
- "receipt_key", max_batch_id, rooms=affected_room_ids
- )
+ self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
yield self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids,
+ min_batch_id, max_batch_id, affected_room_ids
)
defer.returnValue(True)
@defer.inlineCallbacks
- def received_client_receipt(self, room_id, receipt_type, user_id,
- event_id):
+ def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
@@ -109,9 +106,7 @@ class ReceiptsHandler(BaseHandler):
receipt_type=receipt_type,
user_id=user_id,
event_ids=[event_id],
- data={
- "ts": int(self.clock.time_msec()),
- },
+ data={"ts": int(self.clock.time_msec())},
)
is_new = yield self._handle_new_receipts([receipt])
@@ -125,8 +120,7 @@ class ReceiptsHandler(BaseHandler):
"""Gets all receipts for a room, upto the given key.
"""
result = yield self.store.get_linearized_receipts_for_room(
- room_id,
- to_key=to_key,
+ room_id, to_key=to_key
)
if not result:
@@ -148,14 +142,12 @@ class ReceiptEventSource(object):
defer.returnValue(([], to_key))
events = yield self.store.get_linearized_receipts_for_rooms(
- room_ids,
- from_key=from_key,
- to_key=to_key,
+ room_ids, from_key=from_key, to_key=to_key
)
defer.returnValue((events, to_key))
- def get_current_key(self, direction='f'):
+ def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
@defer.inlineCallbacks
@@ -169,9 +161,7 @@ class ReceiptEventSource(object):
room_ids = yield self.store.get_rooms_for_user(user.to_string())
events = yield self.store.get_linearized_receipts_for_rooms(
- room_ids,
- from_key=from_key,
- to_key=to_key,
+ room_ids, from_key=from_key, to_key=to_key
)
defer.returnValue((events, to_key))
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 9a388ea013..e487b90c08 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -47,7 +47,6 @@ logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
-
def __init__(self, hs):
"""
@@ -69,44 +68,37 @@ class RegistrationHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator()
self._generate_user_id_linearizer = Linearizer(
- name="_generate_user_id_linearizer",
+ name="_generate_user_id_linearizer"
)
self._server_notices_mxid = hs.config.server_notices_mxid
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
- self._register_device_client = (
- RegisterDeviceReplicationServlet.make_client(hs)
+ self._register_device_client = RegisterDeviceReplicationServlet.make_client(
+ hs
)
- self._post_registration_client = (
- ReplicationPostRegisterActionsServlet.make_client(hs)
+ self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client(
+ hs
)
else:
self.device_handler = hs.get_device_handler()
self.pusher_pool = hs.get_pusherpool()
@defer.inlineCallbacks
- def check_username(self, localpart, guest_access_token=None,
- assigned_user_id=None):
+ def check_username(self, localpart, guest_access_token=None, assigned_user_id=None):
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
"User ID can only contain characters a-z, 0-9, or '=_-./'",
- Codes.INVALID_USERNAME
+ Codes.INVALID_USERNAME,
)
if not localpart:
- raise SynapseError(
- 400,
- "User ID cannot be empty",
- Codes.INVALID_USERNAME
- )
+ raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME)
- if localpart[0] == '_':
+ if localpart[0] == "_":
raise SynapseError(
- 400,
- "User ID may not begin with _",
- Codes.INVALID_USERNAME
+ 400, "User ID may not begin with _", Codes.INVALID_USERNAME
)
user = UserID(localpart, self.hs.hostname)
@@ -126,19 +118,15 @@ class RegistrationHandler(BaseHandler):
if len(user_id) > MAX_USERID_LENGTH:
raise SynapseError(
400,
- "User ID may not be longer than %s characters" % (
- MAX_USERID_LENGTH,
- ),
- Codes.INVALID_USERNAME
+ "User ID may not be longer than %s characters" % (MAX_USERID_LENGTH,),
+ Codes.INVALID_USERNAME,
)
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
if not guest_access_token:
raise SynapseError(
- 400,
- "User ID already taken.",
- errcode=Codes.USER_IN_USE,
+ 400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
user_data = yield self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
@@ -203,8 +191,7 @@ class RegistrationHandler(BaseHandler):
try:
int(localpart)
raise RegistrationError(
- 400,
- "Numeric user IDs are reserved for guest users."
+ 400, "Numeric user IDs are reserved for guest users."
)
except ValueError:
pass
@@ -283,9 +270,7 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- yield self._register_email_threepid(
- user_id, threepid_dict, None, False,
- )
+ yield self._register_email_threepid(user_id, threepid_dict, None, False)
defer.returnValue((user_id, token))
@@ -318,8 +303,8 @@ class RegistrationHandler(BaseHandler):
room_alias = RoomAlias.from_string(r)
if self.hs.hostname != room_alias.domain:
logger.warning(
- 'Cannot create room alias %s, '
- 'it does not match server domain',
+ "Cannot create room alias %s, "
+ "it does not match server domain",
r,
)
else:
@@ -332,7 +317,7 @@ class RegistrationHandler(BaseHandler):
fake_requester,
config={
"preset": "public_chat",
- "room_alias_name": room_alias_localpart
+ "room_alias_name": room_alias_localpart,
},
ratelimit=False,
)
@@ -364,8 +349,9 @@ class RegistrationHandler(BaseHandler):
raise AuthError(403, "Invalid application service token.")
if not service.is_interested_in_user(user_id):
raise SynapseError(
- 400, "Invalid user localpart for this application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "Invalid user localpart for this application service.",
+ errcode=Codes.EXCLUSIVE,
)
service_id = service.id if service.is_exclusive_user(user_id) else None
@@ -391,17 +377,15 @@ class RegistrationHandler(BaseHandler):
"""
captcha_response = yield self._validate_captcha(
- ip,
- private_key,
- challenge,
- response
+ ip, private_key, challenge, response
)
if not captcha_response["valid"]:
- logger.info("Invalid captcha entered from %s. Error: %s",
- ip, captcha_response["error_url"])
- raise InvalidCaptchaError(
- error_url=captcha_response["error_url"]
+ logger.info(
+ "Invalid captcha entered from %s. Error: %s",
+ ip,
+ captcha_response["error_url"],
)
+ raise InvalidCaptchaError(error_url=captcha_response["error_url"])
else:
logger.info("Valid captcha entered from %s", ip)
@@ -414,8 +398,11 @@ class RegistrationHandler(BaseHandler):
"""
for c in threepidCreds:
- logger.info("validating threepidcred sid %s on id server %s",
- c['sid'], c['idServer'])
+ logger.info(
+ "validating threepidcred sid %s on id server %s",
+ c["sid"],
+ c["idServer"],
+ )
try:
threepid = yield self.identity_handler.threepid_from_creds(c)
except Exception:
@@ -424,13 +411,14 @@ class RegistrationHandler(BaseHandler):
if not threepid:
raise RegistrationError(400, "Couldn't validate 3pid")
- logger.info("got threepid with medium '%s' and address '%s'",
- threepid['medium'], threepid['address'])
+ logger.info(
+ "got threepid with medium '%s' and address '%s'",
+ threepid["medium"],
+ threepid["address"],
+ )
- if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
- raise RegistrationError(
- 403, "Third party identifier is not allowed"
- )
+ if not check_3pid_allowed(self.hs, threepid["medium"], threepid["address"]):
+ raise RegistrationError(403, "Third party identifier is not allowed")
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
@@ -449,23 +437,23 @@ class RegistrationHandler(BaseHandler):
if self._server_notices_mxid is not None:
if user_id == self._server_notices_mxid:
raise SynapseError(
- 400, "This user ID is reserved.",
- errcode=Codes.EXCLUSIVE
+ 400, "This user ID is reserved.", errcode=Codes.EXCLUSIVE
)
# valid user IDs must not clash with any user ID namespaces claimed by
# application services.
services = self.store.get_app_services()
interested_services = [
- s for s in services
- if s.is_interested_in_user(user_id)
- and s != allowed_appservice
+ s
+ for s in services
+ if s.is_interested_in_user(user_id) and s != allowed_appservice
]
for service in interested_services:
if service.is_exclusive_user(user_id):
raise SynapseError(
- 400, "This user ID is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "This user ID is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE,
)
@defer.inlineCallbacks
@@ -491,14 +479,13 @@ class RegistrationHandler(BaseHandler):
dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.
"""
- response = yield self._submit_captcha(ip_addr, private_key, challenge,
- response)
+ response = yield self._submit_captcha(ip_addr, private_key, challenge, response)
# parse Google's response. Lovely format..
- lines = response.split('\n')
+ lines = response.split("\n")
json = {
- "valid": lines[0] == 'true',
- "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" +
- "error=%s" % lines[1]
+ "valid": lines[0] == "true",
+ "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?"
+ + "error=%s" % lines[1],
}
defer.returnValue(json)
@@ -510,17 +497,16 @@ class RegistrationHandler(BaseHandler):
data = yield self.captcha_client.post_urlencoded_get_raw(
"http://www.recaptcha.net:80/recaptcha/api/verify",
args={
- 'privatekey': private_key,
- 'remoteip': ip_addr,
- 'challenge': challenge,
- 'response': response
- }
+ "privatekey": private_key,
+ "remoteip": ip_addr,
+ "challenge": challenge,
+ "response": response,
+ },
)
defer.returnValue(data)
@defer.inlineCallbacks
- def get_or_create_user(self, requester, localpart, displayname,
- password_hash=None):
+ def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
"""Creates a new user if the user does not exist,
else revokes all previous access tokens and generates a new one.
@@ -565,7 +551,7 @@ class RegistrationHandler(BaseHandler):
if displayname is not None:
logger.info("setting user display name: %s -> %s", user_id, displayname)
yield self.profile_handler.set_displayname(
- user, requester, displayname, by_admin=True,
+ user, requester, displayname, by_admin=True
)
defer.returnValue((user_id, token))
@@ -587,15 +573,12 @@ class RegistrationHandler(BaseHandler):
"""
access_token = yield self.store.get_3pid_guest_access_token(medium, address)
if access_token:
- user_info = yield self.auth.get_user_by_access_token(
- access_token
- )
+ user_info = yield self.auth.get_user_by_access_token(access_token)
defer.returnValue((user_info["user"].to_string(), access_token))
user_id, access_token = yield self.register(
- generate_token=True,
- make_guest=True
+ generate_token=True, make_guest=True
)
access_token = yield self.store.save_or_get_3pid_guest_access_token(
medium, address, access_token, inviter_user_id
@@ -616,9 +599,9 @@ class RegistrationHandler(BaseHandler):
)
room_id = room_id.to_string()
else:
- raise SynapseError(400, "%s was not legal room ID or room alias" % (
- room_identifier,
- ))
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
yield room_member_handler.update_membership(
requester=requester,
@@ -629,10 +612,19 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
- def register_with_store(self, user_id, token=None, password_hash=None,
- was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_displayname=None, admin=False,
- user_type=None, address=None):
+ def register_with_store(
+ self,
+ user_id,
+ token=None,
+ password_hash=None,
+ was_guest=False,
+ make_guest=False,
+ appservice_id=None,
+ create_profile_with_displayname=None,
+ admin=False,
+ user_type=None,
+ address=None,
+ ):
"""Register user in the datastore.
Args:
@@ -661,14 +653,15 @@ class RegistrationHandler(BaseHandler):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.can_do_action(
- address, time_now_s=time_now,
+ address,
+ time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
)
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now)),
+ retry_after_ms=int(1000 * (time_allowed - time_now))
)
if self.hs.config.worker_app:
@@ -698,8 +691,7 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def register_device(self, user_id, device_id, initial_display_name,
- is_guest=False):
+ def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
"""Register a device for a user and generate an access token.
Args:
@@ -732,14 +724,15 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = yield self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id,
+ user_id, device_id=device_id
)
defer.returnValue((device_id, access_token))
@defer.inlineCallbacks
- def post_registration_actions(self, user_id, auth_result, access_token,
- bind_email, bind_msisdn):
+ def post_registration_actions(
+ self, user_id, auth_result, access_token, bind_email, bind_msisdn
+ ):
"""A user has completed registration
Args:
@@ -773,20 +766,15 @@ class RegistrationHandler(BaseHandler):
yield self.store.upsert_monthly_active_user(user_id)
yield self._register_email_threepid(
- user_id, threepid, access_token,
- bind_email,
+ user_id, threepid, access_token, bind_email
)
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
- yield self._register_msisdn_threepid(
- user_id, threepid, bind_msisdn,
- )
+ yield self._register_msisdn_threepid(user_id, threepid, bind_msisdn)
if auth_result and LoginType.TERMS in auth_result:
- yield self._on_user_consented(
- user_id, self.hs.config.user_consent_version,
- )
+ yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
@defer.inlineCallbacks
def _on_user_consented(self, user_id, consent_version):
@@ -798,9 +786,7 @@ class RegistrationHandler(BaseHandler):
consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
- yield self.store.user_set_consent_version(
- user_id, consent_version,
- )
+ yield self.store.user_set_consent_version(user_id, consent_version)
yield self.post_consent_actions(user_id)
@defer.inlineCallbacks
@@ -824,33 +810,30 @@ class RegistrationHandler(BaseHandler):
Returns:
defer.Deferred:
"""
- reqd = ('medium', 'address', 'validated_at')
+ reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
return
yield self._auth_handler.add_threepid(
- user_id,
- threepid['medium'],
- threepid['address'],
- threepid['validated_at'],
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
)
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
# notifs are set up on a home server)
- if (self.hs.config.email_enable_notifs and
- self.hs.config.email_notif_for_new_users
- and token):
+ if (
+ self.hs.config.email_enable_notifs
+ and self.hs.config.email_notif_for_new_users
+ and token
+ ):
# Pull the ID of the access token back out of the db
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
- user_tuple = yield self.store.get_user_by_access_token(
- token
- )
+ user_tuple = yield self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"]
yield self.pusher_pool.add_pusher(
@@ -867,11 +850,9 @@ class RegistrationHandler(BaseHandler):
if bind_email:
logger.info("bind_email specified: binding")
- logger.debug("Binding emails %s to %s" % (
- threepid, user_id
- ))
+ logger.debug("Binding emails %s to %s" % (threepid, user_id))
yield self.identity_handler.bind_threepid(
- threepid['threepid_creds'], user_id
+ threepid["threepid_creds"], user_id
)
else:
logger.info("bind_email not specified: not binding email")
@@ -894,7 +875,7 @@ class RegistrationHandler(BaseHandler):
defer.Deferred:
"""
try:
- assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
+ assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
except SynapseError as ex:
if ex.errcode == Codes.MISSING_PARAM:
# This will only happen if the ID server returns a malformed response
@@ -903,17 +884,14 @@ class RegistrationHandler(BaseHandler):
raise
yield self._auth_handler.add_threepid(
- user_id,
- threepid['medium'],
- threepid['address'],
- threepid['validated_at'],
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
)
if bind_msisdn:
logger.info("bind_msisdn specified: binding")
logger.debug("Binding msisdn %s to %s", threepid, user_id)
yield self.identity_handler.bind_threepid(
- threepid['threepid_creds'], user_id
+ threepid["threepid_creds"], user_id
)
else:
logger.info("bind_msisdn not specified: not binding msisdn")
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 4a17911a87..db3f8cb76b 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -32,6 +32,7 @@ from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -40,6 +41,8 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://"
+FIVE_MINUTES_IN_MS = 5 * 60 * 1000
+
class RoomCreationHandler(BaseHandler):
@@ -75,6 +78,16 @@ class RoomCreationHandler(BaseHandler):
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
+ # If a user tries to update the same room multiple times in quick
+ # succession, only process the first attempt and return its result to
+ # subsequent requests
+ self._upgrade_response_cache = ResponseCache(
+ hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
+ )
+ self._server_notices_mxid = hs.config.server_notices_mxid
+
+ self.third_party_event_rules = hs.get_third_party_event_rules()
+
@defer.inlineCallbacks
def upgrade_room(self, requester, old_room_id, new_version):
"""Replace a room with a new room with a different version
@@ -91,70 +104,100 @@ class RoomCreationHandler(BaseHandler):
user_id = requester.user.to_string()
- with (yield self._upgrade_linearizer.queue(old_room_id)):
- # start by allocating a new room id
- r = yield self.store.get_room(old_room_id)
- if r is None:
- raise NotFoundError("Unknown room id %s" % (old_room_id,))
- new_room_id = yield self._generate_room_id(
- creator_id=user_id, is_public=r["is_public"],
- )
+ # Check if this room is already being upgraded by another person
+ for key in self._upgrade_response_cache.pending_result_cache:
+ if key[0] == old_room_id and key[1] != user_id:
+ # Two different people are trying to upgrade the same room.
+ # Send the second an error.
+ #
+ # Note that this of course only gets caught if both users are
+ # on the same homeserver.
+ raise SynapseError(
+ 400, "An upgrade for this room is currently in progress"
+ )
- logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
+ # Upgrade the room
+ #
+ # If this user has sent multiple upgrade requests for the same room
+ # and one of them is not complete yet, cache the response and
+ # return it to all subsequent requests
+ ret = yield self._upgrade_response_cache.wrap(
+ (old_room_id, user_id),
+ self._upgrade_room,
+ requester,
+ old_room_id,
+ new_version, # args for _upgrade_room
+ )
+ defer.returnValue(ret)
- # we create and auth the tombstone event before properly creating the new
- # room, to check our user has perms in the old room.
- tombstone_event, tombstone_context = (
- yield self.event_creation_handler.create_event(
- requester, {
- "type": EventTypes.Tombstone,
- "state_key": "",
- "room_id": old_room_id,
- "sender": user_id,
- "content": {
- "body": "This room has been replaced",
- "replacement_room": new_room_id,
- }
- },
- token_id=requester.access_token_id,
- )
- )
- old_room_version = yield self.store.get_room_version(old_room_id)
- yield self.auth.check_from_context(
- old_room_version, tombstone_event, tombstone_context,
- )
+ @defer.inlineCallbacks
+ def _upgrade_room(self, requester, old_room_id, new_version):
+ user_id = requester.user.to_string()
- yield self.clone_existing_room(
+ # start by allocating a new room id
+ r = yield self.store.get_room(old_room_id)
+ if r is None:
+ raise NotFoundError("Unknown room id %s" % (old_room_id,))
+ new_room_id = yield self._generate_room_id(
+ creator_id=user_id, is_public=r["is_public"]
+ )
+
+ logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
+
+ # we create and auth the tombstone event before properly creating the new
+ # room, to check our user has perms in the old room.
+ tombstone_event, tombstone_context = (
+ yield self.event_creation_handler.create_event(
requester,
- old_room_id=old_room_id,
- new_room_id=new_room_id,
- new_room_version=new_version,
- tombstone_event_id=tombstone_event.event_id,
+ {
+ "type": EventTypes.Tombstone,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": user_id,
+ "content": {
+ "body": "This room has been replaced",
+ "replacement_room": new_room_id,
+ },
+ },
+ token_id=requester.access_token_id,
)
+ )
+ old_room_version = yield self.store.get_room_version(old_room_id)
+ yield self.auth.check_from_context(
+ old_room_version, tombstone_event, tombstone_context
+ )
- # now send the tombstone
- yield self.event_creation_handler.send_nonmember_event(
- requester, tombstone_event, tombstone_context,
- )
+ yield self.clone_existing_room(
+ requester,
+ old_room_id=old_room_id,
+ new_room_id=new_room_id,
+ new_room_version=new_version,
+ tombstone_event_id=tombstone_event.event_id,
+ )
- old_room_state = yield tombstone_context.get_current_state_ids(self.store)
+ # now send the tombstone
+ yield self.event_creation_handler.send_nonmember_event(
+ requester, tombstone_event, tombstone_context
+ )
- # update any aliases
- yield self._move_aliases_to_new_room(
- requester, old_room_id, new_room_id, old_room_state,
- )
+ old_room_state = yield tombstone_context.get_current_state_ids(self.store)
- # and finally, shut down the PLs in the old room, and update them in the new
- # room.
- yield self._update_upgraded_room_pls(
- requester, old_room_id, new_room_id, old_room_state,
- )
+ # update any aliases
+ yield self._move_aliases_to_new_room(
+ requester, old_room_id, new_room_id, old_room_state
+ )
- defer.returnValue(new_room_id)
+ # and finally, shut down the PLs in the old room, and update them in the new
+ # room.
+ yield self._update_upgraded_room_pls(
+ requester, old_room_id, new_room_id, old_room_state
+ )
+
+ defer.returnValue(new_room_id)
@defer.inlineCallbacks
def _update_upgraded_room_pls(
- self, requester, old_room_id, new_room_id, old_room_state,
+ self, requester, old_room_id, new_room_id, old_room_state
):
"""Send updated power levels in both rooms after an upgrade
@@ -172,7 +215,7 @@ class RoomCreationHandler(BaseHandler):
if old_room_pl_event_id is None:
logger.warning(
"Not supported: upgrading a room with no PL event. Not setting PLs "
- "in old room.",
+ "in old room."
)
return
@@ -193,45 +236,48 @@ class RoomCreationHandler(BaseHandler):
if current < restricted_level:
logger.info(
"Setting level for %s in %s to %i (was %i)",
- v, old_room_id, restricted_level, current,
+ v,
+ old_room_id,
+ restricted_level,
+ current,
)
pl_content[v] = restricted_level
updated = True
else:
- logger.info(
- "Not setting level for %s (already %i)",
- v, current,
- )
+ logger.info("Not setting level for %s (already %i)", v, current)
if updated:
try:
yield self.event_creation_handler.create_and_send_nonmember_event(
- requester, {
+ requester,
+ {
"type": EventTypes.PowerLevels,
- "state_key": '',
+ "state_key": "",
"room_id": old_room_id,
"sender": requester.user.to_string(),
"content": pl_content,
- }, ratelimit=False,
+ },
+ ratelimit=False,
)
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
logger.info("Setting correct PLs in new room")
yield self.event_creation_handler.create_and_send_nonmember_event(
- requester, {
+ requester,
+ {
"type": EventTypes.PowerLevels,
- "state_key": '',
+ "state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
"content": old_room_pl_state.content,
- }, ratelimit=False,
+ },
+ ratelimit=False,
)
@defer.inlineCallbacks
def clone_existing_room(
- self, requester, old_room_id, new_room_id, new_room_version,
- tombstone_event_id,
+ self, requester, old_room_id, new_room_id, new_room_version, tombstone_event_id
):
"""Populate a new room based on an old room
@@ -253,10 +299,7 @@ class RoomCreationHandler(BaseHandler):
creation_content = {
"room_version": new_room_version,
- "predecessor": {
- "room_id": old_room_id,
- "event_id": tombstone_event_id,
- }
+ "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
}
# Check if old room was non-federatable
@@ -285,7 +328,7 @@ class RoomCreationHandler(BaseHandler):
)
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
- old_room_id, StateFilter.from_types(types_to_copy),
+ old_room_id, StateFilter.from_types(types_to_copy)
)
# map from event_id to BaseEvent
old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
@@ -298,11 +341,9 @@ class RoomCreationHandler(BaseHandler):
yield self._send_events_for_new_room(
requester,
new_room_id,
-
# we expect to override all the presets with initial_state, so this is
# somewhat arbitrary.
preset_config=RoomCreationPreset.PRIVATE_CHAT,
-
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
@@ -310,20 +351,22 @@ class RoomCreationHandler(BaseHandler):
# Transfer membership events
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
- old_room_id, StateFilter.from_types([(EventTypes.Member, None)]),
+ old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
)
# map from event_id to BaseEvent
old_room_member_state_events = yield self.store.get_events(
- old_room_member_state_ids.values(),
+ old_room_member_state_ids.values()
)
for k, old_event in iteritems(old_room_member_state_events):
# Only transfer ban events
- if ("membership" in old_event.content and
- old_event.content["membership"] == "ban"):
+ if (
+ "membership" in old_event.content
+ and old_event.content["membership"] == "ban"
+ ):
yield self.room_member_handler.update_membership(
requester,
- UserID.from_string(old_event['state_key']),
+ UserID.from_string(old_event["state_key"]),
new_room_id,
"ban",
ratelimit=False,
@@ -335,7 +378,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def _move_aliases_to_new_room(
- self, requester, old_room_id, new_room_id, old_room_state,
+ self, requester, old_room_id, new_room_id, old_room_state
):
directory_handler = self.hs.get_handlers().directory_handler
@@ -366,14 +409,11 @@ class RoomCreationHandler(BaseHandler):
alias = RoomAlias.from_string(alias_str)
try:
yield directory_handler.delete_association(
- requester, alias, send_event=False,
+ requester, alias, send_event=False
)
removed_aliases.append(alias_str)
except SynapseError as e:
- logger.warning(
- "Unable to remove alias %s from old room: %s",
- alias, e,
- )
+ logger.warning("Unable to remove alias %s from old room: %s", alias, e)
# if we didn't find any aliases, or couldn't remove anyway, we can skip the rest
# of this.
@@ -389,30 +429,26 @@ class RoomCreationHandler(BaseHandler):
# as when you remove an alias from the directory normally - it just means that
# the aliases event gets out of sync with the directory
# (cf https://github.com/vector-im/riot-web/issues/2369)
- yield directory_handler.send_room_alias_update_event(
- requester, old_room_id,
- )
+ yield directory_handler.send_room_alias_update_event(requester, old_room_id)
except AuthError as e:
- logger.warning(
- "Failed to send updated alias event on old room: %s", e,
- )
+ logger.warning("Failed to send updated alias event on old room: %s", e)
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
yield directory_handler.create_association(
- requester, RoomAlias.from_string(alias),
- new_room_id, servers=(self.hs.hostname, ),
- send_event=False, check_membership=False,
+ requester,
+ RoomAlias.from_string(alias),
+ new_room_id,
+ servers=(self.hs.hostname,),
+ send_event=False,
+ check_membership=False,
)
logger.info("Moved alias %s to new room", alias)
except SynapseError as e:
# I'm not really expecting this to happen, but it could if the spam
# checking module decides it shouldn't, or similar.
- logger.error(
- "Error adding alias %s to new room: %s",
- alias, e,
- )
+ logger.error("Error adding alias %s to new room: %s", alias, e)
try:
if canonical_alias and (canonical_alias in removed_aliases):
@@ -423,24 +459,19 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
- "content": {"alias": canonical_alias, },
+ "content": {"alias": canonical_alias},
},
- ratelimit=False
+ ratelimit=False,
)
- yield directory_handler.send_room_alias_update_event(
- requester, new_room_id,
- )
+ yield directory_handler.send_room_alias_update_event(requester, new_room_id)
except SynapseError as e:
# again I'm not really expecting this to fail, but if it does, I'd rather
# we returned the new room to the client at this point.
- logger.error(
- "Unable to send updated alias events in new room: %s", e,
- )
+ logger.error("Unable to send updated alias events in new room: %s", e)
@defer.inlineCallbacks
- def create_room(self, requester, config, ratelimit=True,
- creator_join_profile=None):
+ def create_room(self, requester, config, ratelimit=True, creator_join_profile=None):
""" Creates a new room.
Args:
@@ -470,23 +501,35 @@ class RoomCreationHandler(BaseHandler):
yield self.auth.check_auth_blocking(user_id)
- if not self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
+
+ # Check whether the third party rules allows/changes the room create
+ # request.
+ yield self.third_party_event_rules.on_create_room(
+ requester, config, is_requester_admin=is_requester_admin
+ )
+
+ if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ user_id
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit:
yield self.ratelimit(requester)
room_version = config.get(
- "room_version",
- self.config.default_room_version.identifier,
+ "room_version", self.config.default_room_version.identifier
)
if not isinstance(room_version, string_types):
- raise SynapseError(
- 400,
- "room_version must be a string",
- Codes.BAD_JSON,
- )
+ raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON)
if room_version not in KNOWN_ROOM_VERSIONS:
raise SynapseError(
@@ -500,20 +543,11 @@ class RoomCreationHandler(BaseHandler):
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
- room_alias = RoomAlias(
- config["room_alias_name"],
- self.hs.hostname,
- )
- mapping = yield self.store.get_association_from_room_alias(
- room_alias
- )
+ room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
+ mapping = yield self.store.get_association_from_room_alias(room_alias)
if mapping:
- raise SynapseError(
- 400,
- "Room alias already taken",
- Codes.ROOM_IN_USE
- )
+ raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
else:
room_alias = None
@@ -524,9 +558,7 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
- yield self.event_creation_handler.assert_accepted_privacy_policy(
- requester,
- )
+ yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
invite_3pid_list = config.get("invite_3pid", [])
@@ -550,7 +582,7 @@ class RoomCreationHandler(BaseHandler):
"preset",
RoomCreationPreset.PRIVATE_CHAT
if visibility == "private"
- else RoomCreationPreset.PUBLIC_CHAT
+ else RoomCreationPreset.PUBLIC_CHAT,
)
raw_initial_state = config.get("initial_state", [])
@@ -587,7 +619,8 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"content": {"name": name},
},
- ratelimit=False)
+ ratelimit=False,
+ )
if "topic" in config:
topic = config["topic"]
@@ -600,7 +633,8 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"content": {"topic": topic},
},
- ratelimit=False)
+ ratelimit=False,
+ )
for invitee in invite_list:
content = {}
@@ -635,30 +669,25 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
result["room_alias"] = room_alias.to_string()
- yield directory_handler.send_room_alias_update_event(
- requester, room_id
- )
+ yield directory_handler.send_room_alias_update_event(requester, room_id)
defer.returnValue(result)
@defer.inlineCallbacks
def _send_events_for_new_room(
- self,
- creator, # A Requester object.
- room_id,
- preset_config,
- invite_list,
- initial_state,
- creation_content,
- room_alias=None,
- power_level_content_override=None,
- creator_join_profile=None,
+ self,
+ creator, # A Requester object.
+ room_id,
+ preset_config,
+ invite_list,
+ initial_state,
+ creation_content,
+ room_alias=None,
+ power_level_content_override=None,
+ creator_join_profile=None,
):
def create(etype, content, **kwargs):
- e = {
- "type": etype,
- "content": content,
- }
+ e = {"type": etype, "content": content}
e.update(event_keys)
e.update(kwargs)
@@ -670,26 +699,17 @@ class RoomCreationHandler(BaseHandler):
event = create(etype, content, **kwargs)
logger.info("Sending %s in new room", etype)
yield self.event_creation_handler.create_and_send_nonmember_event(
- creator,
- event,
- ratelimit=False
+ creator, event, ratelimit=False
)
config = RoomCreationHandler.PRESETS_DICT[preset_config]
creator_id = creator.user.to_string()
- event_keys = {
- "room_id": room_id,
- "sender": creator_id,
- "state_key": "",
- }
+ event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
- yield send(
- etype=EventTypes.Create,
- content=creation_content,
- )
+ yield send(etype=EventTypes.Create, content=creation_content)
logger.info("Sending %s in new room", EventTypes.Member)
yield self.room_member_handler.update_membership(
@@ -703,17 +723,12 @@ class RoomCreationHandler(BaseHandler):
# We treat the power levels override specially as this needs to be one
# of the first events that get sent into a room.
- pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None)
+ pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
- yield send(
- etype=EventTypes.PowerLevels,
- content=pl_content,
- )
+ yield send(etype=EventTypes.PowerLevels, content=pl_content)
else:
power_level_content = {
- "users": {
- creator_id: 100,
- },
+ "users": {creator_id: 100},
"users_default": 0,
"events": {
EventTypes.Name: 50,
@@ -737,42 +752,33 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override:
power_level_content.update(power_level_content_override)
- yield send(
- etype=EventTypes.PowerLevels,
- content=power_level_content,
- )
+ yield send(etype=EventTypes.PowerLevels, content=power_level_content)
- if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
+ if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
yield send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
- if (EventTypes.JoinRules, '') not in initial_state:
+ if (EventTypes.JoinRules, "") not in initial_state:
yield send(
- etype=EventTypes.JoinRules,
- content={"join_rule": config["join_rules"]},
+ etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
- if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
+ if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
yield send(
etype=EventTypes.RoomHistoryVisibility,
- content={"history_visibility": config["history_visibility"]}
+ content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
- if (EventTypes.GuestAccess, '') not in initial_state:
+ if (EventTypes.GuestAccess, "") not in initial_state:
yield send(
- etype=EventTypes.GuestAccess,
- content={"guest_access": "can_join"}
+ etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
- yield send(
- etype=etype,
- state_key=state_key,
- content=content,
- )
+ yield send(etype=etype, state_key=state_key, content=content)
@defer.inlineCallbacks
def _generate_room_id(self, creator_id, is_public):
@@ -782,12 +788,9 @@ class RoomCreationHandler(BaseHandler):
while attempts < 5:
try:
random_string = stringutils.random_string(18)
- gen_room_id = RoomID(
- random_string,
- self.hs.hostname,
- ).to_string()
+ gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
if isinstance(gen_room_id, bytes):
- gen_room_id = gen_room_id.decode('utf-8')
+ gen_room_id = gen_room_id.decode("utf-8")
yield self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
@@ -821,7 +824,7 @@ class RoomContextHandler(object):
Returns:
dict, or None if the event isn't found
"""
- before_limit = math.floor(limit / 2.)
+ before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
users = yield self.store.get_users_in_room(room_id)
@@ -829,24 +832,19 @@ class RoomContextHandler(object):
def filter_evts(events):
return filter_events_for_client(
- self.store,
- user.to_string(),
- events,
- is_peeking=is_peeking
+ self.store, user.to_string(), events, is_peeking=is_peeking
)
- event = yield self.store.get_event(event_id, get_prev_content=True,
- allow_none=True)
+ event = yield self.store.get_event(
+ event_id, get_prev_content=True, allow_none=True
+ )
if not event:
defer.returnValue(None)
return
- filtered = yield(filter_evts([event]))
+ filtered = yield (filter_evts([event]))
if not filtered:
- raise AuthError(
- 403,
- "You don't have permission to access that event."
- )
+ raise AuthError(403, "You don't have permission to access that event.")
results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
@@ -878,7 +876,7 @@ class RoomContextHandler(object):
# https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events(
- [last_event_id], state_filter=state_filter,
+ [last_event_id], state_filter=state_filter
)
results["state"] = list(state[last_event_id].values())
@@ -890,9 +888,7 @@ class RoomContextHandler(object):
"room_key", results["start"]
).to_string()
- results["end"] = token.copy_and_replace(
- "room_key", results["end"]
- ).to_string()
+ results["end"] = token.copy_and_replace("room_key", results["end"]).to_string()
defer.returnValue(results)
@@ -903,13 +899,7 @@ class RoomEventSource(object):
@defer.inlineCallbacks
def get_new_events(
- self,
- user,
- from_key,
- limit,
- room_ids,
- is_guest,
- explicit_room_id=None,
+ self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
@@ -920,9 +910,7 @@ class RoomEventSource(object):
logger.warn("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
- app_service = self.store.get_app_service_by_user_id(
- user.to_string()
- )
+ app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
@@ -937,7 +925,7 @@ class RoomEventSource(object):
from_key=from_key,
to_key=to_key,
limit=limit or 10,
- order='ASC',
+ order="ASC",
)
events = list(room_events)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 617d1c9ef8..aae696a7e8 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -46,13 +46,18 @@ class RoomListHandler(BaseHandler):
super(RoomListHandler, self).__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(hs, "room_list")
- self.remote_response_cache = ResponseCache(hs, "remote_room_list",
- timeout_ms=30 * 1000)
+ self.remote_response_cache = ResponseCache(
+ hs, "remote_room_list", timeout_ms=30 * 1000
+ )
- def get_local_public_room_list(self, limit=None, since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False):
+ def get_local_public_room_list(
+ self,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ network_tuple=EMPTY_THIRD_PARTY_ID,
+ from_federation=False,
+ ):
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
@@ -68,14 +73,14 @@ class RoomListHandler(BaseHandler):
Setting to None returns all public rooms across all lists.
"""
if not self.enable_room_list_search:
- return defer.succeed({
- "chunk": [],
- "total_room_count_estimate": 0,
- })
+ return defer.succeed({"chunk": [], "total_room_count_estimate": 0})
logger.info(
"Getting public room list: limit=%r, since=%r, search=%r, network=%r",
- limit, since_token, bool(search_filter), network_tuple,
+ limit,
+ since_token,
+ bool(search_filter),
+ network_tuple,
)
if search_filter:
@@ -88,24 +93,33 @@ class RoomListHandler(BaseHandler):
# solution at some point
timeout = self.clock.time() + 60
return self._get_public_room_list(
- limit, since_token, search_filter,
- network_tuple=network_tuple, timeout=timeout,
+ limit,
+ since_token,
+ search_filter,
+ network_tuple=network_tuple,
+ timeout=timeout,
)
key = (limit, since_token, network_tuple)
return self.response_cache.wrap(
key,
self._get_public_room_list,
- limit, since_token,
- network_tuple=network_tuple, from_federation=from_federation,
+ limit,
+ since_token,
+ network_tuple=network_tuple,
+ from_federation=from_federation,
)
@defer.inlineCallbacks
- def _get_public_room_list(self, limit=None, since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False,
- timeout=None,):
+ def _get_public_room_list(
+ self,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ network_tuple=EMPTY_THIRD_PARTY_ID,
+ from_federation=False,
+ timeout=None,
+ ):
"""Generate a public room list.
Args:
limit (int|None): Maximum amount of rooms to return.
@@ -135,15 +149,14 @@ class RoomListHandler(BaseHandler):
current_public_id = yield self.store.get_current_public_room_stream_id()
public_room_stream_id = since_token.public_room_stream_id
newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
- public_room_stream_id, current_public_id,
- network_tuple=network_tuple,
+ public_room_stream_id, current_public_id, network_tuple=network_tuple
)
else:
stream_token = yield self.store.get_room_max_stream_ordering()
public_room_stream_id = yield self.store.get_current_public_room_stream_id()
room_ids = yield self.store.get_public_room_ids_at_stream_id(
- public_room_stream_id, network_tuple=network_tuple,
+ public_room_stream_id, network_tuple=network_tuple
)
# We want to return rooms in a particular order: the number of joined
@@ -168,7 +181,7 @@ class RoomListHandler(BaseHandler):
return
joined_users = yield self.state_handler.get_current_users_in_room(
- room_id, latest_event_ids,
+ room_id, latest_event_ids
)
num_joined_users = len(joined_users)
@@ -180,8 +193,9 @@ class RoomListHandler(BaseHandler):
# We want larger rooms to be first, hence negating num_joined_users
rooms_to_order_value[room_id] = (-num_joined_users, room_id)
- logger.info("Getting ordering for %i rooms since %s",
- len(room_ids), stream_token)
+ logger.info(
+ "Getting ordering for %i rooms since %s", len(room_ids), stream_token
+ )
yield concurrently_execute(get_order_for_room, room_ids, 10)
sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
@@ -193,7 +207,8 @@ class RoomListHandler(BaseHandler):
# Filter out rooms that we don't want to return
rooms_to_scan = [
- r for r in sorted_rooms
+ r
+ for r in sorted_rooms
if r not in newly_unpublished and rooms_to_num_joined[r] > 0
]
@@ -204,13 +219,12 @@ class RoomListHandler(BaseHandler):
# `since_token.current_limit` is the index of the last room we
# sent down, so we exclude it and everything before/after it.
if since_token.direction_is_forward:
- rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:]
+ rooms_to_scan = rooms_to_scan[since_token.current_limit + 1 :]
else:
- rooms_to_scan = rooms_to_scan[:since_token.current_limit]
+ rooms_to_scan = rooms_to_scan[: since_token.current_limit]
rooms_to_scan.reverse()
- logger.info("After sorting and filtering, %i rooms remain",
- len(rooms_to_scan))
+ logger.info("After sorting and filtering, %i rooms remain", len(rooms_to_scan))
# _append_room_entry_to_chunk will append to chunk but will stop if
# len(chunk) > limit
@@ -237,15 +251,19 @@ class RoomListHandler(BaseHandler):
if timeout and self.clock.time() > timeout:
raise Exception("Timed out searching room directory")
- batch = rooms_to_scan[i:i + step]
+ batch = rooms_to_scan[i : i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
lambda r: self._append_room_entry_to_chunk(
- r, rooms_to_num_joined[r],
- chunk, limit, search_filter,
+ r,
+ rooms_to_num_joined[r],
+ chunk,
+ limit,
+ search_filter,
from_federation=from_federation,
),
- batch, 5,
+ batch,
+ 5,
)
logger.info("Now %i rooms in result", len(chunk))
if len(chunk) >= limit + 1:
@@ -273,10 +291,7 @@ class RoomListHandler(BaseHandler):
new_limit = sorted_rooms.index(last_room_id)
- results = {
- "chunk": chunk,
- "total_room_count_estimate": total_room_count,
- }
+ results = {"chunk": chunk, "total_room_count_estimate": total_room_count}
if since_token:
results["new_rooms"] = bool(newly_visible)
@@ -313,8 +328,15 @@ class RoomListHandler(BaseHandler):
defer.returnValue(results)
@defer.inlineCallbacks
- def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit,
- search_filter, from_federation=False):
+ def _append_room_entry_to_chunk(
+ self,
+ room_id,
+ num_joined_users,
+ chunk,
+ limit,
+ search_filter,
+ from_federation=False,
+ ):
"""Generate the entry for a room in the public room list and append it
to the `chunk` if it matches the search filter
@@ -345,8 +367,14 @@ class RoomListHandler(BaseHandler):
chunk.append(result)
@cachedInlineCallbacks(num_args=1, cache_context=True)
- def generate_room_entry(self, room_id, num_joined_users, cache_context,
- with_alias=True, allow_private=False):
+ def generate_room_entry(
+ self,
+ room_id,
+ num_joined_users,
+ cache_context,
+ with_alias=True,
+ allow_private=False,
+ ):
"""Returns the entry for a room
Args:
@@ -360,33 +388,31 @@ class RoomListHandler(BaseHandler):
Deferred[dict|None]: Returns a room entry as a dictionary, or None if this
room was determined not to be shown publicly.
"""
- result = {
- "room_id": room_id,
- "num_joined_members": num_joined_users,
- }
+ result = {"room_id": room_id, "num_joined_members": num_joined_users}
current_state_ids = yield self.store.get_current_state_ids(
- room_id, on_invalidate=cache_context.invalidate,
+ room_id, on_invalidate=cache_context.invalidate
)
- event_map = yield self.store.get_events([
- event_id for key, event_id in iteritems(current_state_ids)
- if key[0] in (
- EventTypes.Create,
- EventTypes.JoinRules,
- EventTypes.Name,
- EventTypes.Topic,
- EventTypes.CanonicalAlias,
- EventTypes.RoomHistoryVisibility,
- EventTypes.GuestAccess,
- "m.room.avatar",
- )
- ])
+ event_map = yield self.store.get_events(
+ [
+ event_id
+ for key, event_id in iteritems(current_state_ids)
+ if key[0]
+ in (
+ EventTypes.Create,
+ EventTypes.JoinRules,
+ EventTypes.Name,
+ EventTypes.Topic,
+ EventTypes.CanonicalAlias,
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.GuestAccess,
+ "m.room.avatar",
+ )
+ ]
+ )
- current_state = {
- (ev.type, ev.state_key): ev
- for ev in event_map.values()
- }
+ current_state = {(ev.type, ev.state_key): ev for ev in event_map.values()}
# Double check that this is actually a public room.
@@ -446,14 +472,17 @@ class RoomListHandler(BaseHandler):
defer.returnValue(result)
@defer.inlineCallbacks
- def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
- search_filter=None, include_all_networks=False,
- third_party_instance_id=None,):
+ def get_remote_public_room_list(
+ self,
+ server_name,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ):
if not self.enable_room_list_search:
- defer.returnValue({
- "chunk": [],
- "total_room_count_estimate": 0,
- })
+ defer.returnValue({"chunk": [], "total_room_count_estimate": 0})
if search_filter:
# We currently don't support searching across federation, so we have
@@ -462,52 +491,75 @@ class RoomListHandler(BaseHandler):
since_token = None
res = yield self._get_remote_list_cached(
- server_name, limit=limit, since_token=since_token,
+ server_name,
+ limit=limit,
+ since_token=since_token,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
if search_filter:
- res = {"chunk": [
- entry
- for entry in list(res.get("chunk", []))
- if _matches_room_entry(entry, search_filter)
- ]}
+ res = {
+ "chunk": [
+ entry
+ for entry in list(res.get("chunk", []))
+ if _matches_room_entry(entry, search_filter)
+ ]
+ }
defer.returnValue(res)
- def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
- search_filter=None, include_all_networks=False,
- third_party_instance_id=None,):
+ def _get_remote_list_cached(
+ self,
+ server_name,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ):
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
- server_name, limit=limit, since_token=since_token,
- search_filter=search_filter, include_all_networks=include_all_networks,
+ server_name,
+ limit=limit,
+ since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
key = (
- server_name, limit, since_token, include_all_networks,
+ server_name,
+ limit,
+ since_token,
+ include_all_networks,
third_party_instance_id,
)
return self.remote_response_cache.wrap(
key,
repl_layer.get_public_rooms,
- server_name, limit=limit, since_token=since_token,
+ server_name,
+ limit=limit,
+ since_token=since_token,
search_filter=search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
-class RoomListNextBatch(namedtuple("RoomListNextBatch", (
- "stream_ordering", # stream_ordering of the first public room list
- "public_room_stream_id", # public room stream id for first public room list
- "current_limit", # The number of previous rooms returned
- "direction_is_forward", # Bool if this is a next_batch, false if prev_batch
-))):
+class RoomListNextBatch(
+ namedtuple(
+ "RoomListNextBatch",
+ (
+ "stream_ordering", # stream_ordering of the first public room list
+ "public_room_stream_id", # public room stream id for first public room list
+ "current_limit", # The number of previous rooms returned
+ "direction_is_forward", # Bool if this is a next_batch, false if prev_batch
+ ),
+ )
+):
KEY_DICT = {
"stream_ordering": "s",
@@ -527,21 +579,19 @@ class RoomListNextBatch(namedtuple("RoomListNextBatch", (
decoded = msgpack.loads(decode_base64(token), raw=False)
else:
decoded = msgpack.loads(decode_base64(token))
- return RoomListNextBatch(**{
- cls.REVERSE_KEY_DICT[key]: val
- for key, val in decoded.items()
- })
+ return RoomListNextBatch(
+ **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
+ )
def to_token(self):
- return encode_base64(msgpack.dumps({
- self.KEY_DICT[key]: val
- for key, val in self._asdict().items()
- }))
+ return encode_base64(
+ msgpack.dumps(
+ {self.KEY_DICT[key]: val for key, val in self._asdict().items()}
+ )
+ )
def copy_and_replace(self, **kwds):
- return self._replace(
- **kwds
- )
+ return self._replace(**kwds)
def _matches_room_entry(room_entry, search_filter):
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 93ac986c86..4d6e883802 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -72,6 +72,7 @@ class RoomMemberHandler(object):
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
+ self.third_party_event_rules = hs.get_third_party_event_rules()
self._server_notices_mxid = self.config.server_notices_mxid
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
@@ -165,7 +166,11 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _local_membership_update(
- self, requester, target, room_id, membership,
+ self,
+ requester,
+ target,
+ room_id,
+ membership,
prev_events_and_hashes,
txn_id=None,
ratelimit=True,
@@ -189,7 +194,6 @@ class RoomMemberHandler(object):
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": user_id,
-
# For backwards compatibility:
"membership": membership,
},
@@ -201,26 +205,19 @@ class RoomMemberHandler(object):
# Check if this event matches the previous membership event for the user.
duplicate = yield self.event_creation_handler.deduplicate_state_event(
- event, context,
+ event, context
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate)
yield self.event_creation_handler.handle_new_client_event(
- requester,
- event,
- context,
- extra_users=[target],
- ratelimit=ratelimit,
+ requester, event, context, extra_users=[target], ratelimit=ratelimit
)
prev_state_ids = yield context.get_prev_state_ids(self.store)
- prev_member_event_id = prev_state_ids.get(
- (EventTypes.Member, user_id),
- None
- )
+ prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has actually joined the
@@ -242,11 +239,11 @@ class RoomMemberHandler(object):
if predecessor:
# It is an upgraded room. Copy over old tags
self.copy_room_tags_and_direct_to_room(
- predecessor["room_id"], room_id, user_id,
+ predecessor["room_id"], room_id, user_id
)
# Move over old push rules
self.store.move_push_rules_from_room_to_room_for_user(
- predecessor["room_id"], room_id, user_id,
+ predecessor["room_id"], room_id, user_id
)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
@@ -257,12 +254,7 @@ class RoomMemberHandler(object):
defer.returnValue(event)
@defer.inlineCallbacks
- def copy_room_tags_and_direct_to_room(
- self,
- old_room_id,
- new_room_id,
- user_id,
- ):
+ def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id):
"""Copies the tags and direct room state from one room to another.
Args:
@@ -274,9 +266,7 @@ class RoomMemberHandler(object):
Deferred[None]
"""
# Retrieve user account data for predecessor room
- user_account_data, _ = yield self.store.get_account_data_for_user(
- user_id,
- )
+ user_account_data, _ = yield self.store.get_account_data_for_user(user_id)
# Copy direct message state if applicable
direct_rooms = user_account_data.get("m.direct", {})
@@ -290,34 +280,30 @@ class RoomMemberHandler(object):
# Save back to user's m.direct account data
yield self.store.add_account_data_for_user(
- user_id, "m.direct", direct_rooms,
+ user_id, "m.direct", direct_rooms
)
break
# Copy room tags if applicable
- room_tags = yield self.store.get_tags_for_room(
- user_id, old_room_id,
- )
+ room_tags = yield self.store.get_tags_for_room(user_id, old_room_id)
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
- yield self.store.add_tag_to_room(
- user_id, new_room_id, tag, tag_content
- )
+ yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
@defer.inlineCallbacks
def update_membership(
- self,
- requester,
- target,
- room_id,
- action,
- txn_id=None,
- remote_room_hosts=None,
- third_party_signed=None,
- ratelimit=True,
- content=None,
- require_consent=True,
+ self,
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=None,
+ remote_room_hosts=None,
+ third_party_signed=None,
+ ratelimit=True,
+ content=None,
+ require_consent=True,
):
key = (room_id,)
@@ -339,17 +325,17 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _update_membership(
- self,
- requester,
- target,
- room_id,
- action,
- txn_id=None,
- remote_room_hosts=None,
- third_party_signed=None,
- ratelimit=True,
- content=None,
- require_consent=True,
+ self,
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=None,
+ remote_room_hosts=None,
+ third_party_signed=None,
+ ratelimit=True,
+ content=None,
+ require_consent=True,
):
content_specified = bool(content)
if content is None:
@@ -383,7 +369,7 @@ class RoomMemberHandler(object):
if not remote_room_hosts:
remote_room_hosts = []
- if effective_membership_state not in ("leave", "ban",):
+ if effective_membership_state not in ("leave", "ban"):
is_blocked = yield self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
@@ -391,22 +377,19 @@ class RoomMemberHandler(object):
if effective_membership_state == Membership.INVITE:
# block any attempts to invite the server notices mxid
if target.to_string() == self._server_notices_mxid:
- raise SynapseError(
- http_client.FORBIDDEN,
- "Cannot invite this user",
- )
+ raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
block_invite = False
- if (self._server_notices_mxid is not None and
- requester.user.to_string() == self._server_notices_mxid):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
# allow the server notices mxid to send invites
is_requester_admin = True
else:
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
if not is_requester_admin:
if self.config.block_non_admin_invites:
@@ -417,25 +400,19 @@ class RoomMemberHandler(object):
block_invite = True
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id,
+ requester.user.to_string(), target.to_string(), room_id
):
logger.info("Blocking invite due to spam checker")
block_invite = True
if block_invite:
- raise SynapseError(
- 403, "Invites have been disabled on this server",
- )
+ raise SynapseError(403, "Invites have been disabled on this server")
- prev_events_and_hashes = yield self.store.get_prev_events_for_room(
- room_id,
- )
- latest_event_ids = (
- event_id for (event_id, _, _) in prev_events_and_hashes
- )
+ prev_events_and_hashes = yield self.store.get_prev_events_for_room(room_id)
+ latest_event_ids = (event_id for (event_id, _, _) in prev_events_and_hashes)
current_state_ids = yield self.state_handler.get_current_state_ids(
- room_id, latest_event_ids=latest_event_ids,
+ room_id, latest_event_ids=latest_event_ids
)
# TODO: Refactor into dictionary of explicitly allowed transitions
@@ -450,13 +427,13 @@ class RoomMemberHandler(object):
403,
"Cannot unban user who was not banned"
" (membership=%s)" % old_membership,
- errcode=Codes.BAD_STATE
+ errcode=Codes.BAD_STATE,
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was banned" % (action,),
- errcode=Codes.BAD_STATE
+ errcode=Codes.BAD_STATE,
)
if old_state:
@@ -472,8 +449,8 @@ class RoomMemberHandler(object):
# we don't allow people to reject invites to the server notice
# room, but they can leave it once they are joined.
if (
- old_membership == Membership.INVITE and
- effective_membership_state == Membership.LEAVE
+ old_membership == Membership.INVITE
+ and effective_membership_state == Membership.LEAVE
):
is_blocked = yield self._is_server_notice_room(room_id)
if is_blocked:
@@ -534,7 +511,7 @@ class RoomMemberHandler(object):
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
res = yield self._remote_reject_invite(
- requester, remote_room_hosts, room_id, target,
+ requester, remote_room_hosts, room_id, target
)
defer.returnValue(res)
@@ -553,12 +530,7 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def send_membership_event(
- self,
- requester,
- event,
- context,
- remote_room_hosts=None,
- ratelimit=True,
+ self, requester, event, context, remote_room_hosts=None, ratelimit=True
):
"""
Change the membership status of a user in a room.
@@ -584,16 +556,15 @@ class RoomMemberHandler(object):
if requester is not None:
sender = UserID.from_string(event.sender)
- assert sender == requester.user, (
- "Sender (%s) must be same as requester (%s)" %
- (sender, requester.user)
- )
+ assert (
+ sender == requester.user
+ ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = synapse.types.create_requester(target_user)
prev_event = yield self.event_creation_handler.deduplicate_state_event(
- event, context,
+ event, context
)
if prev_event is not None:
return
@@ -613,16 +584,11 @@ class RoomMemberHandler(object):
raise SynapseError(403, "This room has been blocked on this server")
yield self.event_creation_handler.handle_new_client_event(
- requester,
- event,
- context,
- extra_users=[target_user],
- ratelimit=ratelimit,
+ requester, event, context, extra_users=[target_user], ratelimit=ratelimit
)
prev_member_event_id = prev_state_ids.get(
- (EventTypes.Member, event.state_key),
- None
+ (EventTypes.Member, event.state_key), None
)
if event.membership == Membership.JOIN:
@@ -692,58 +658,45 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
- user_id=user_id,
- room_id=room_id,
+ user_id=user_id, room_id=room_id
)
if invite:
defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def do_3pid_invite(
- self,
- room_id,
- inviter,
- medium,
- address,
- id_server,
- requester,
- txn_id
+ self, room_id, inviter, medium, address, id_server, requester, txn_id
):
if self.config.block_non_admin_invites:
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
if not is_requester_admin:
raise SynapseError(
- 403, "Invites have been disabled on this server",
- Codes.FORBIDDEN,
+ 403, "Invites have been disabled on this server", Codes.FORBIDDEN
)
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
yield self.base_handler.ratelimit(requester)
- invitee = yield self._lookup_3pid(
- id_server, medium, address
+ can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
)
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
+
+ invitee = yield self._lookup_3pid(id_server, medium, address)
if invitee:
yield self.update_membership(
- requester,
- UserID.from_string(invitee),
- room_id,
- "invite",
- txn_id=txn_id,
+ requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
yield self._make_and_store_3pid_invite(
- requester,
- id_server,
- medium,
- address,
- room_id,
- inviter,
- txn_id=txn_id
+ requester, id_server, medium, address, room_id, inviter, txn_id=txn_id
)
@defer.inlineCallbacks
@@ -761,15 +714,12 @@ class RoomMemberHandler(object):
"""
if not self._enable_lookup:
raise SynapseError(
- 403, "Looking up third-party identifiers is denied from this server",
+ 403, "Looking up third-party identifiers is denied from this server"
)
try:
data = yield self.simple_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
- {
- "medium": medium,
- "address": address,
- }
+ "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
+ {"medium": medium, "address": address},
)
if "mxid" in data:
@@ -788,29 +738,25 @@ class RoomMemberHandler(object):
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.simple_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/pubkey/%s" %
- (id_server_scheme, server_hostname, key_name,),
+ "%s%s/_matrix/identity/api/v1/pubkey/%s"
+ % (id_server_scheme, server_hostname, key_name)
)
if "public_key" not in key_data:
- raise AuthError(401, "No public key named %s from %s" %
- (key_name, server_hostname,))
+ raise AuthError(
+ 401, "No public key named %s from %s" % (key_name, server_hostname)
+ )
verify_signed_json(
data,
server_hostname,
- decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
+ decode_verify_key_bytes(
+ key_name, decode_base64(key_data["public_key"])
+ ),
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
- self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- user,
- txn_id
+ self, requester, id_server, medium, address, room_id, user, txn_id
):
room_state = yield self.state_handler.get_current_state(room_id)
@@ -858,7 +804,7 @@ class RoomMemberHandler(object):
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
- inviter_avatar_url=inviter_avatar_url
+ inviter_avatar_url=inviter_avatar_url,
)
)
@@ -869,7 +815,6 @@ class RoomMemberHandler(object):
"content": {
"display_name": display_name,
"public_keys": public_keys,
-
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
@@ -883,19 +828,19 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
- self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- inviter_user_id,
- room_alias,
- room_avatar_url,
- room_join_rules,
- room_name,
- inviter_display_name,
- inviter_avatar_url
+ self,
+ requester,
+ id_server,
+ medium,
+ address,
+ room_id,
+ inviter_user_id,
+ room_alias,
+ room_avatar_url,
+ room_join_rules,
+ room_name,
+ inviter_display_name,
+ inviter_avatar_url,
):
"""
Asks an identity server for a third party invite.
@@ -927,7 +872,8 @@ class RoomMemberHandler(object):
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
- id_server_scheme, id_server,
+ id_server_scheme,
+ id_server,
)
invite_config = {
@@ -951,14 +897,15 @@ class RoomMemberHandler(object):
inviter_user_id=inviter_user_id,
)
- invite_config.update({
- "guest_access_token": guest_access_token,
- "guest_user_id": guest_user_id,
- })
+ invite_config.update(
+ {
+ "guest_access_token": guest_access_token,
+ "guest_user_id": guest_user_id,
+ }
+ )
data = yield self.simple_http_client.post_urlencoded_get_json(
- is_url,
- invite_config
+ is_url, invite_config
)
# TODO: Check for success
token = data["token"]
@@ -966,9 +913,8 @@ class RoomMemberHandler(object):
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
- "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
- id_server_scheme, id_server,
- ),
+ "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid"
+ % (id_server_scheme, id_server),
}
else:
fallback_public_key = public_keys[0]
@@ -1037,10 +983,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield self.federation_handler.do_invite_join(
- remote_room_hosts,
- room_id,
- user.to_string(),
- content,
+ remote_room_hosts, room_id, user.to_string(), content
)
yield self._user_joined_room(user, room_id)
@@ -1051,9 +994,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
fed_handler = self.federation_handler
try:
ret = yield fed_handler.do_remotely_reject_invite(
- remote_room_hosts,
- room_id,
- target.to_string(),
+ remote_room_hosts, room_id, target.to_string()
)
defer.returnValue(ret)
except Exception as e:
@@ -1065,9 +1006,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warn("Failed to reject invite: %s", e)
- yield self.store.locally_reject_invite(
- target.to_string(), room_id
- )
+ yield self.store.locally_reject_invite(target.to_string(), room_id)
defer.returnValue({})
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
@@ -1091,18 +1030,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
- room_id=room_id,
- event_type=EventTypes.Member,
- state_key=user_id
+ room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership not in [
- Membership.LEAVE, Membership.BAN
+ Membership.LEAVE,
+ Membership.BAN,
]:
- raise SynapseError(400, "User %s in room %s" % (
- user_id, room_id
- ))
+ raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
if membership:
yield self.store.forget(user_id, room_id)
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index acc6eb8099..da501f38c0 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -71,18 +71,14 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
"""Implements RoomMemberHandler._user_joined_room
"""
return self._notify_change_client(
- user_id=target.to_string(),
- room_id=room_id,
- change="joined",
+ user_id=target.to_string(), room_id=room_id, change="joined"
)
def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room
"""
return self._notify_change_client(
- user_id=target.to_string(),
- room_id=room_id,
- change="left",
+ user_id=target.to_string(), room_id=room_id, change="left"
)
def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 9bba74d6c9..ddc4430d03 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -32,7 +32,6 @@ logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
-
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
@@ -93,7 +92,7 @@ class SearchHandler(BaseHandler):
batch_token = None
if batch:
try:
- b = decode_base64(batch).decode('ascii')
+ b = decode_base64(batch).decode("ascii")
batch_group, batch_group_key, batch_token = b.split("\n")
assert batch_group is not None
@@ -104,7 +103,9 @@ class SearchHandler(BaseHandler):
logger.info(
"Search batch properties: %r, %r, %r",
- batch_group, batch_group_key, batch_token,
+ batch_group,
+ batch_group_key,
+ batch_token,
)
logger.info("Search content: %s", content)
@@ -116,9 +117,9 @@ class SearchHandler(BaseHandler):
search_term = room_cat["search_term"]
# Which "keys" to search over in FTS query
- keys = room_cat.get("keys", [
- "content.body", "content.name", "content.topic",
- ])
+ keys = room_cat.get(
+ "keys", ["content.body", "content.name", "content.topic"]
+ )
# Filter to apply to results
filter_dict = room_cat.get("filter", {})
@@ -130,9 +131,7 @@ class SearchHandler(BaseHandler):
include_state = room_cat.get("include_state", False)
# Include context around each event?
- event_context = room_cat.get(
- "event_context", None
- )
+ event_context = room_cat.get("event_context", None)
# Group results together? May allow clients to paginate within a
# group
@@ -140,12 +139,8 @@ class SearchHandler(BaseHandler):
group_keys = [g["key"] for g in group_by]
if event_context is not None:
- before_limit = int(event_context.get(
- "before_limit", 5
- ))
- after_limit = int(event_context.get(
- "after_limit", 5
- ))
+ before_limit = int(event_context.get("before_limit", 5))
+ after_limit = int(event_context.get("after_limit", 5))
# Return the historic display name and avatar for the senders
# of the events?
@@ -159,7 +154,8 @@ class SearchHandler(BaseHandler):
if set(group_keys) - {"room_id", "sender"}:
raise SynapseError(
400,
- "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
+ "Invalid group by keys: %r"
+ % (set(group_keys) - {"room_id", "sender"},),
)
search_filter = Filter(filter_dict)
@@ -190,15 +186,13 @@ class SearchHandler(BaseHandler):
room_ids.intersection_update({batch_group_key})
if not room_ids:
- defer.returnValue({
- "search_categories": {
- "room_events": {
- "results": [],
- "count": 0,
- "highlights": [],
+ defer.returnValue(
+ {
+ "search_categories": {
+ "room_events": {"results": [], "count": 0, "highlights": []}
}
}
- })
+ )
rank_map = {} # event_id -> rank of event
allowed_events = []
@@ -213,9 +207,7 @@ class SearchHandler(BaseHandler):
count = None
if order_by == "rank":
- search_result = yield self.store.search_msgs(
- room_ids, search_term, keys
- )
+ search_result = yield self.store.search_msgs(room_ids, search_term, keys)
count = search_result["count"]
@@ -235,19 +227,17 @@ class SearchHandler(BaseHandler):
)
events.sort(key=lambda e: -rank_map[e.event_id])
- allowed_events = events[:search_filter.limit()]
+ allowed_events = events[: search_filter.limit()]
for e in allowed_events:
- rm = room_groups.setdefault(e.room_id, {
- "results": [],
- "order": rank_map[e.event_id],
- })
+ rm = room_groups.setdefault(
+ e.room_id, {"results": [], "order": rank_map[e.event_id]}
+ )
rm["results"].append(e.event_id)
- s = sender_group.setdefault(e.sender, {
- "results": [],
- "order": rank_map[e.event_id],
- })
+ s = sender_group.setdefault(
+ e.sender, {"results": [], "order": rank_map[e.event_id]}
+ )
s["results"].append(e.event_id)
elif order_by == "recent":
@@ -262,7 +252,10 @@ class SearchHandler(BaseHandler):
while len(room_events) < search_filter.limit() and i < 5:
i += 1
search_result = yield self.store.search_rooms(
- room_ids, search_term, keys, search_filter.limit() * 2,
+ room_ids,
+ search_term,
+ keys,
+ search_filter.limit() * 2,
pagination_token=pagination_token,
)
@@ -277,16 +270,14 @@ class SearchHandler(BaseHandler):
rank_map.update({r["event"].event_id: r["rank"] for r in results})
- filtered_events = search_filter.filter([
- r["event"] for r in results
- ])
+ filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
self.store, user.to_string(), filtered_events
)
room_events.extend(events)
- room_events = room_events[:search_filter.limit()]
+ room_events = room_events[: search_filter.limit()]
if len(results) < search_filter.limit() * 2:
pagination_token = None
@@ -295,9 +286,7 @@ class SearchHandler(BaseHandler):
pagination_token = results[-1]["pagination_token"]
for event in room_events:
- group = room_groups.setdefault(event.room_id, {
- "results": [],
- })
+ group = room_groups.setdefault(event.room_id, {"results": []})
group["results"].append(event.event_id)
if room_events and len(room_events) >= search_filter.limit():
@@ -309,18 +298,23 @@ class SearchHandler(BaseHandler):
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
- global_next_batch = encode_base64(("%s\n%s\n%s" % (
- batch_group, batch_group_key, pagination_token
- )).encode('ascii'))
+ global_next_batch = encode_base64(
+ (
+ "%s\n%s\n%s"
+ % (batch_group, batch_group_key, pagination_token)
+ ).encode("ascii")
+ )
else:
- global_next_batch = encode_base64(("%s\n%s\n%s" % (
- "all", "", pagination_token
- )).encode('ascii'))
+ global_next_batch = encode_base64(
+ ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
+ )
for room_id, group in room_groups.items():
- group["next_batch"] = encode_base64(("%s\n%s\n%s" % (
- "room_id", room_id, pagination_token
- )).encode('ascii'))
+ group["next_batch"] = encode_base64(
+ ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
+ "ascii"
+ )
+ )
allowed_events.extend(room_events)
@@ -338,12 +332,13 @@ class SearchHandler(BaseHandler):
contexts = {}
for event in allowed_events:
res = yield self.store.get_events_around(
- event.room_id, event.event_id, before_limit, after_limit,
+ event.room_id, event.event_id, before_limit, after_limit
)
logger.info(
"Context for search returned %d and %d events",
- len(res["events_before"]), len(res["events_after"]),
+ len(res["events_before"]),
+ len(res["events_after"]),
)
res["events_before"] = yield filter_events_for_client(
@@ -403,12 +398,12 @@ class SearchHandler(BaseHandler):
for context in contexts.values():
context["events_before"] = (
yield self._event_serializer.serialize_events(
- context["events_before"], time_now,
+ context["events_before"], time_now
)
)
context["events_after"] = (
yield self._event_serializer.serialize_events(
- context["events_after"], time_now,
+ context["events_after"], time_now
)
)
@@ -426,11 +421,15 @@ class SearchHandler(BaseHandler):
results = []
for e in allowed_events:
- results.append({
- "rank": rank_map[e.event_id],
- "result": (yield self._event_serializer.serialize_event(e, time_now)),
- "context": contexts.get(e.event_id, {}),
- })
+ results.append(
+ {
+ "rank": rank_map[e.event_id],
+ "result": (
+ yield self._event_serializer.serialize_event(e, time_now)
+ ),
+ "context": contexts.get(e.event_id, {}),
+ }
+ )
rooms_cat_res = {
"results": results,
@@ -442,7 +441,7 @@ class SearchHandler(BaseHandler):
s = {}
for room_id, state in state_results.items():
s[room_id] = yield self._event_serializer.serialize_events(
- state, time_now,
+ state, time_now
)
rooms_cat_res["state"] = s
@@ -456,8 +455,4 @@ class SearchHandler(BaseHandler):
if global_next_batch:
rooms_cat_res["next_batch"] = global_next_batch
- defer.returnValue({
- "search_categories": {
- "room_events": rooms_cat_res
- }
- })
+ defer.returnValue({"search_categories": {"room_events": rooms_cat_res}})
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 7ecdede4dc..5a0995d4fe 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -25,6 +25,7 @@ logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
+
def __init__(self, hs):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
@@ -47,11 +48,11 @@ class SetPasswordHandler(BaseHandler):
# we want to log out all of the user's other sessions. First delete
# all his other devices.
yield self._device_handler.delete_all_devices_for_user(
- user_id, except_device_id=except_device_id,
+ user_id, except_device_id=except_device_id
)
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
yield self._auth_handler.delete_access_tokens_for_user(
- user_id, except_token_id=except_access_token_id,
+ user_id, except_token_id=except_access_token_id
)
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index b268bbcb2c..6b364befd5 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -21,7 +21,6 @@ logger = logging.getLogger(__name__)
class StateDeltasHandler(object):
-
def __init__(self, hs):
self.store = hs.get_datastore()
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 7ad16c8566..a0ee8db988 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -156,7 +156,7 @@ class StatsHandler(StateDeltasHandler):
prev_event_content = {}
if prev_event_id is not None:
prev_event = yield self.store.get_event(
- prev_event_id, allow_none=True,
+ prev_event_id, allow_none=True
)
if prev_event:
prev_event_content = prev_event.content
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 62fda0c664..c5188a1f8e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -64,20 +64,14 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000
LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
-SyncConfig = collections.namedtuple("SyncConfig", [
- "user",
- "filter_collection",
- "is_guest",
- "request_key",
- "device_id",
-])
-
-
-class TimelineBatch(collections.namedtuple("TimelineBatch", [
- "prev_batch",
- "events",
- "limited",
-])):
+SyncConfig = collections.namedtuple(
+ "SyncConfig", ["user", "filter_collection", "is_guest", "request_key", "device_id"]
+)
+
+
+class TimelineBatch(
+ collections.namedtuple("TimelineBatch", ["prev_batch", "events", "limited"])
+):
__slots__ = []
def __nonzero__(self):
@@ -85,18 +79,24 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [
to tell if room needs to be part of the sync result.
"""
return bool(self.events)
+
__bool__ = __nonzero__ # python3
-class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
- "room_id", # str
- "timeline", # TimelineBatch
- "state", # dict[(str, str), FrozenEvent]
- "ephemeral",
- "account_data",
- "unread_notifications",
- "summary",
-])):
+class JoinedSyncResult(
+ collections.namedtuple(
+ "JoinedSyncResult",
+ [
+ "room_id", # str
+ "timeline", # TimelineBatch
+ "state", # dict[(str, str), FrozenEvent]
+ "ephemeral",
+ "account_data",
+ "unread_notifications",
+ "summary",
+ ],
+ )
+):
__slots__ = []
def __nonzero__(self):
@@ -111,77 +111,93 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
# nb the notification count does not, er, count: if there's nothing
# else in the result, we don't need to send it.
)
+
__bool__ = __nonzero__ # python3
-class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
- "room_id", # str
- "timeline", # TimelineBatch
- "state", # dict[(str, str), FrozenEvent]
- "account_data",
-])):
+class ArchivedSyncResult(
+ collections.namedtuple(
+ "ArchivedSyncResult",
+ [
+ "room_id", # str
+ "timeline", # TimelineBatch
+ "state", # dict[(str, str), FrozenEvent]
+ "account_data",
+ ],
+ )
+):
__slots__ = []
def __nonzero__(self):
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
- return bool(
- self.timeline
- or self.state
- or self.account_data
- )
+ return bool(self.timeline or self.state or self.account_data)
+
__bool__ = __nonzero__ # python3
-class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
- "room_id", # str
- "invite", # FrozenEvent: the invite event
-])):
+class InvitedSyncResult(
+ collections.namedtuple(
+ "InvitedSyncResult",
+ ["room_id", "invite"], # str # FrozenEvent: the invite event
+ )
+):
__slots__ = []
def __nonzero__(self):
"""Invited rooms should always be reported to the client"""
return True
+
__bool__ = __nonzero__ # python3
-class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
- "join",
- "invite",
- "leave",
-])):
+class GroupsSyncResult(
+ collections.namedtuple("GroupsSyncResult", ["join", "invite", "leave"])
+):
__slots__ = []
def __nonzero__(self):
return bool(self.join or self.invite or self.leave)
+
__bool__ = __nonzero__ # python3
-class DeviceLists(collections.namedtuple("DeviceLists", [
- "changed", # list of user_ids whose devices may have changed
- "left", # list of user_ids whose devices we no longer track
-])):
+class DeviceLists(
+ collections.namedtuple(
+ "DeviceLists",
+ [
+ "changed", # list of user_ids whose devices may have changed
+ "left", # list of user_ids whose devices we no longer track
+ ],
+ )
+):
__slots__ = []
def __nonzero__(self):
return bool(self.changed or self.left)
+
__bool__ = __nonzero__ # python3
-class SyncResult(collections.namedtuple("SyncResult", [
- "next_batch", # Token for the next sync
- "presence", # List of presence events for the user.
- "account_data", # List of account_data events for the user.
- "joined", # JoinedSyncResult for each joined room.
- "invited", # InvitedSyncResult for each invited room.
- "archived", # ArchivedSyncResult for each archived room.
- "to_device", # List of direct messages for the device.
- "device_lists", # List of user_ids whose devices have changed
- "device_one_time_keys_count", # Dict of algorithm to count for one time keys
- # for this device
- "groups",
-])):
+class SyncResult(
+ collections.namedtuple(
+ "SyncResult",
+ [
+ "next_batch", # Token for the next sync
+ "presence", # List of presence events for the user.
+ "account_data", # List of account_data events for the user.
+ "joined", # JoinedSyncResult for each joined room.
+ "invited", # InvitedSyncResult for each invited room.
+ "archived", # ArchivedSyncResult for each archived room.
+ "to_device", # List of direct messages for the device.
+ "device_lists", # List of user_ids whose devices have changed
+ "device_one_time_keys_count", # Dict of algorithm to count for one time keys
+ # for this device
+ "groups",
+ ],
+ )
+):
__slots__ = []
def __nonzero__(self):
@@ -190,20 +206,20 @@ class SyncResult(collections.namedtuple("SyncResult", [
events.
"""
return bool(
- self.presence or
- self.joined or
- self.invited or
- self.archived or
- self.account_data or
- self.to_device or
- self.device_lists or
- self.groups
+ self.presence
+ or self.joined
+ or self.invited
+ or self.archived
+ or self.account_data
+ or self.to_device
+ or self.device_lists
+ or self.groups
)
+
__bool__ = __nonzero__ # python3
class SyncHandler(object):
-
def __init__(self, hs):
self.hs_config = hs.config
self.store = hs.get_datastore()
@@ -217,13 +233,16 @@ class SyncHandler(object):
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
- "lazy_loaded_members_cache", self.clock,
- max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
+ "lazy_loaded_members_cache",
+ self.clock,
+ max_len=0,
+ expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
@defer.inlineCallbacks
- def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
- full_state=False):
+ def wait_for_sync_for_user(
+ self, sync_config, since_token=None, timeout=0, full_state=False
+ ):
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
@@ -239,13 +258,15 @@ class SyncHandler(object):
res = yield self.response_cache.wrap(
sync_config.request_key,
self._wait_for_sync_for_user,
- sync_config, since_token, timeout, full_state,
+ sync_config,
+ since_token,
+ timeout,
+ full_state,
)
defer.returnValue(res)
@defer.inlineCallbacks
- def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
- full_state):
+ def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state):
if since_token is None:
sync_type = "initial_sync"
elif full_state:
@@ -261,14 +282,17 @@ class SyncHandler(object):
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
result = yield self.current_sync_for_user(
- sync_config, since_token, full_state=full_state,
+ sync_config, since_token, full_state=full_state
)
else:
+
def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token)
result = yield self.notifier.wait_for_events(
- sync_config.user.to_string(), timeout, current_sync_callback,
+ sync_config.user.to_string(),
+ timeout,
+ current_sync_callback,
from_token=since_token,
)
@@ -281,8 +305,7 @@ class SyncHandler(object):
defer.returnValue(result)
- def current_sync_for_user(self, sync_config, since_token=None,
- full_state=False):
+ def current_sync_for_user(self, sync_config, since_token=None, full_state=False):
"""Get the sync for client needed to match what the server has now.
Returns:
A Deferred SyncResult.
@@ -334,8 +357,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
- event_copy = {k: v for (k, v) in iteritems(event)
- if k != "room_id"}
+ event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0"
@@ -353,22 +375,30 @@ class SyncHandler(object):
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
- event_copy = {k: v for (k, v) in iteritems(event)
- if k != "room_id"}
+ event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
defer.returnValue((now_token, ephemeral_by_room))
@defer.inlineCallbacks
- def _load_filtered_recents(self, room_id, sync_config, now_token,
- since_token=None, recents=None, newly_joined_room=False):
+ def _load_filtered_recents(
+ self,
+ room_id,
+ sync_config,
+ now_token,
+ since_token=None,
+ recents=None,
+ newly_joined_room=False,
+ ):
"""
Returns:
a Deferred TimelineBatch
"""
with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit()
- block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline()
+ block_all_timeline = (
+ sync_config.filter_collection.blocks_all_room_timeline()
+ )
if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True
@@ -396,11 +426,9 @@ class SyncHandler(object):
recents = []
if not limited or block_all_timeline:
- defer.returnValue(TimelineBatch(
- events=recents,
- prev_batch=now_token,
- limited=False
- ))
+ defer.returnValue(
+ TimelineBatch(events=recents, prev_batch=now_token, limited=False)
+ )
filtering_factor = 2
load_limit = max(timeline_limit * filtering_factor, 10)
@@ -427,9 +455,7 @@ class SyncHandler(object):
)
else:
events, end_key = yield self.store.get_recent_events_for_room(
- room_id,
- limit=load_limit + 1,
- end_token=end_key,
+ room_id, limit=load_limit + 1, end_token=end_key
)
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
@@ -462,15 +488,15 @@ class SyncHandler(object):
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
- prev_batch_token = now_token.copy_and_replace(
- "room_key", room_key
- )
+ prev_batch_token = now_token.copy_and_replace("room_key", room_key)
- defer.returnValue(TimelineBatch(
- events=recents,
- prev_batch=prev_batch_token,
- limited=limited or newly_joined_room
- ))
+ defer.returnValue(
+ TimelineBatch(
+ events=recents,
+ prev_batch=prev_batch_token,
+ limited=limited or newly_joined_room,
+ )
+ )
@defer.inlineCallbacks
def get_state_after_event(self, event, state_filter=StateFilter.all()):
@@ -486,7 +512,7 @@ class SyncHandler(object):
A Deferred map from ((type, state_key)->Event)
"""
state_ids = yield self.store.get_state_ids_for_event(
- event.event_id, state_filter=state_filter,
+ event.event_id, state_filter=state_filter
)
if event.is_state():
state_ids = state_ids.copy()
@@ -511,13 +537,13 @@ class SyncHandler(object):
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room(
- room_id, end_token=stream_position.room_key, limit=1,
+ room_id, end_token=stream_position.room_key, limit=1
)
if last_events:
last_event = last_events[-1]
state = yield self.get_state_after_event(
- last_event, state_filter=state_filter,
+ last_event, state_filter=state_filter
)
else:
@@ -549,7 +575,7 @@ class SyncHandler(object):
# FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
last_events, _ = yield self.store.get_recent_event_ids_for_room(
- room_id, end_token=now_token.room_key, limit=1,
+ room_id, end_token=now_token.room_key, limit=1
)
if not last_events:
@@ -559,28 +585,25 @@ class SyncHandler(object):
last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event(
last_event.event_id,
- state_filter=StateFilter.from_types([
- (EventTypes.Name, ''),
- (EventTypes.CanonicalAlias, ''),
- ]),
+ state_filter=StateFilter.from_types(
+ [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
+ ),
)
# this is heavily cached, thus: fast.
details = yield self.store.get_room_summary(room_id)
- name_id = state_ids.get((EventTypes.Name, ''))
- canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ''))
+ name_id = state_ids.get((EventTypes.Name, ""))
+ canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
summary = {}
empty_ms = MemberSummary([], 0)
# TODO: only send these when they change.
- summary["m.joined_member_count"] = (
- details.get(Membership.JOIN, empty_ms).count
- )
- summary["m.invited_member_count"] = (
- details.get(Membership.INVITE, empty_ms).count
- )
+ summary["m.joined_member_count"] = details.get(Membership.JOIN, empty_ms).count
+ summary["m.invited_member_count"] = details.get(
+ Membership.INVITE, empty_ms
+ ).count
# if the room has a name or canonical_alias set, we can skip
# calculating heroes. Empty strings are falsey, so we check
@@ -592,7 +615,7 @@ class SyncHandler(object):
if canonical_alias_id:
canonical_alias = yield self.store.get_event(
- canonical_alias_id, allow_none=True,
+ canonical_alias_id, allow_none=True
)
if canonical_alias and canonical_alias.content.get("alias"):
defer.returnValue(summary)
@@ -600,26 +623,14 @@ class SyncHandler(object):
me = sync_config.user.to_string()
joined_user_ids = [
- r[0]
- for r in details.get(Membership.JOIN, empty_ms).members
- if r[0] != me
+ r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me
]
invited_user_ids = [
- r[0]
- for r in details.get(Membership.INVITE, empty_ms).members
- if r[0] != me
+ r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me
]
- gone_user_ids = (
- [
- r[0]
- for r in details.get(Membership.LEAVE, empty_ms).members
- if r[0] != me
- ] + [
- r[0]
- for r in details.get(Membership.BAN, empty_ms).members
- if r[0] != me
- ]
- )
+ gone_user_ids = [
+ r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
+ ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
# FIXME: only build up a member_ids list for our heroes
member_ids = {}
@@ -627,20 +638,18 @@ class SyncHandler(object):
Membership.JOIN,
Membership.INVITE,
Membership.LEAVE,
- Membership.BAN
+ Membership.BAN,
):
for user_id, event_id in details.get(membership, empty_ms).members:
member_ids[user_id] = event_id
# FIXME: order by stream ordering rather than as returned by SQL
- if (joined_user_ids or invited_user_ids):
- summary['m.heroes'] = sorted(
+ if joined_user_ids or invited_user_ids:
+ summary["m.heroes"] = sorted(
[user_id for user_id in (joined_user_ids + invited_user_ids)]
)[0:5]
else:
- summary['m.heroes'] = sorted(
- [user_id for user_id in gone_user_ids]
- )[0:5]
+ summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5]
if not sync_config.filter_collection.lazy_load_members():
defer.returnValue(summary)
@@ -652,8 +661,7 @@ class SyncHandler(object):
# track which members the client should already know about via LL:
# Ones which are already in state...
existing_members = set(
- user_id for (typ, user_id) in state.keys()
- if typ == EventTypes.Member
+ user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member
)
# ...or ones which are in the timeline...
@@ -664,10 +672,10 @@ class SyncHandler(object):
# ...and then ensure any missing ones get included in state.
missing_hero_event_ids = [
member_ids[hero_id]
- for hero_id in summary['m.heroes']
+ for hero_id in summary["m.heroes"]
if (
- cache.get(hero_id) != member_ids[hero_id] and
- hero_id not in existing_members
+ cache.get(hero_id) != member_ids[hero_id]
+ and hero_id not in existing_members
)
]
@@ -691,8 +699,9 @@ class SyncHandler(object):
return cache
@defer.inlineCallbacks
- def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
- full_state):
+ def compute_state_delta(
+ self, room_id, batch, sync_config, since_token, now_token, full_state
+ ):
""" Works out the difference in state between the start of the timeline
and the previous sync.
@@ -745,23 +754,23 @@ class SyncHandler(object):
timeline_state = {
(event.type, event.state_key): event.event_id
- for event in batch.events if event.is_state()
+ for event in batch.events
+ if event.is_state()
}
if full_state:
if batch:
current_state_ids = yield self.store.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter,
+ batch.events[-1].event_id, state_filter=state_filter
)
state_ids = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter,
+ batch.events[0].event_id, state_filter=state_filter
)
else:
current_state_ids = yield self.get_state_at(
- room_id, stream_position=now_token,
- state_filter=state_filter,
+ room_id, stream_position=now_token, state_filter=state_filter
)
state_ids = current_state_ids
@@ -775,7 +784,7 @@ class SyncHandler(object):
)
elif batch.limited:
state_at_timeline_start = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter,
+ batch.events[0].event_id, state_filter=state_filter
)
# for now, we disable LL for gappy syncs - see
@@ -793,12 +802,11 @@ class SyncHandler(object):
state_filter = StateFilter.all()
state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token,
- state_filter=state_filter,
+ room_id, stream_position=since_token, state_filter=state_filter
)
current_state_ids = yield self.store.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter,
+ batch.events[-1].event_id, state_filter=state_filter
)
state_ids = _calculate_state(
@@ -854,8 +862,7 @@ class SyncHandler(object):
# add any member IDs we are about to send into our LruCache
for t, event_id in itertools.chain(
- state_ids.items(),
- timeline_state.items(),
+ state_ids.items(), timeline_state.items()
):
if t[0] == EventTypes.Member:
cache.set(t[1], event_id)
@@ -864,10 +871,14 @@ class SyncHandler(object):
if state_ids:
state = yield self.store.get_events(list(state_ids.values()))
- defer.returnValue({
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(list(state.values()))
- })
+ defer.returnValue(
+ {
+ (e.type, e.state_key): e
+ for e in sync_config.filter_collection.filter_room_state(
+ list(state.values())
+ )
+ }
+ )
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config):
@@ -875,7 +886,7 @@ class SyncHandler(object):
last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
- receipt_type="m.read"
+ receipt_type="m.read",
)
notifs = []
@@ -909,7 +920,9 @@ class SyncHandler(object):
logger.info(
"Calculating sync response for %r between %s and %s",
- sync_config.user, since_token, now_token,
+ sync_config.user,
+ since_token,
+ now_token,
)
user_id = sync_config.user.to_string()
@@ -920,11 +933,12 @@ class SyncHandler(object):
raise NotImplementedError()
else:
joined_room_ids = yield self.get_rooms_for_user_at(
- user_id, now_token.room_stream_id,
+ user_id, now_token.room_stream_id
)
sync_result_builder = SyncResultBuilder(
- sync_config, full_state,
+ sync_config,
+ full_state,
since_token=since_token,
now_token=now_token,
joined_room_ids=joined_room_ids,
@@ -941,8 +955,7 @@ class SyncHandler(object):
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
- since_token is None and
- sync_config.filter_collection.blocks_all_presence()
+ since_token is None and sync_config.filter_collection.blocks_all_presence()
)
if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence(
@@ -973,22 +986,23 @@ class SyncHandler(object):
room_id = joined_room.room_id
if room_id in newly_joined_rooms:
issue4422_logger.debug(
- "Sync result for newly joined room %s: %r",
- room_id, joined_room,
+ "Sync result for newly joined room %s: %r", room_id, joined_room
)
- defer.returnValue(SyncResult(
- presence=sync_result_builder.presence,
- account_data=sync_result_builder.account_data,
- joined=sync_result_builder.joined,
- invited=sync_result_builder.invited,
- archived=sync_result_builder.archived,
- to_device=sync_result_builder.to_device,
- device_lists=device_lists,
- groups=sync_result_builder.groups,
- device_one_time_keys_count=one_time_key_counts,
- next_batch=sync_result_builder.now_token,
- ))
+ defer.returnValue(
+ SyncResult(
+ presence=sync_result_builder.presence,
+ account_data=sync_result_builder.account_data,
+ joined=sync_result_builder.joined,
+ invited=sync_result_builder.invited,
+ archived=sync_result_builder.archived,
+ to_device=sync_result_builder.to_device,
+ device_lists=device_lists,
+ groups=sync_result_builder.groups,
+ device_one_time_keys_count=one_time_key_counts,
+ next_batch=sync_result_builder.now_token,
+ )
+ )
@measure_func("_generate_sync_entry_for_groups")
@defer.inlineCallbacks
@@ -999,11 +1013,11 @@ class SyncHandler(object):
if since_token and since_token.groups_key:
results = yield self.store.get_groups_changes_for_user(
- user_id, since_token.groups_key, now_token.groups_key,
+ user_id, since_token.groups_key, now_token.groups_key
)
else:
results = yield self.store.get_all_groups_for_user(
- user_id, now_token.groups_key,
+ user_id, now_token.groups_key
)
invited = {}
@@ -1031,17 +1045,19 @@ class SyncHandler(object):
left[group_id] = content["content"]
sync_result_builder.groups = GroupsSyncResult(
- join=joined,
- invite=invited,
- leave=left,
+ join=joined, invite=invited, leave=left
)
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
- def _generate_sync_entry_for_device_list(self, sync_result_builder,
- newly_joined_rooms,
- newly_joined_or_invited_users,
- newly_left_rooms, newly_left_users):
+ def _generate_sync_entry_for_device_list(
+ self,
+ sync_result_builder,
+ newly_joined_rooms,
+ newly_joined_or_invited_users,
+ newly_left_rooms,
+ newly_left_users,
+ ):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -1065,24 +1081,20 @@ class SyncHandler(object):
changed.update(newly_joined_or_invited_users)
if not changed and not newly_left_users:
- defer.returnValue(DeviceLists(
- changed=[],
- left=newly_left_users,
- ))
+ defer.returnValue(DeviceLists(changed=[], left=newly_left_users))
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
- defer.returnValue(DeviceLists(
- changed=users_who_share_room & changed,
- left=set(newly_left_users) - users_who_share_room,
- ))
+ defer.returnValue(
+ DeviceLists(
+ changed=users_who_share_room & changed,
+ left=set(newly_left_users) - users_who_share_room,
+ )
+ )
else:
- defer.returnValue(DeviceLists(
- changed=[],
- left=[],
- ))
+ defer.returnValue(DeviceLists(changed=[], left=[]))
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
@@ -1109,8 +1121,9 @@ class SyncHandler(object):
deleted = yield self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
- logger.debug("Deleted %d to-device messages up to %d",
- deleted, since_stream_id)
+ logger.debug(
+ "Deleted %d to-device messages up to %d", deleted, since_stream_id
+ )
messages, stream_id = yield self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
@@ -1118,7 +1131,10 @@ class SyncHandler(object):
logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d)",
- len(messages), since_stream_id, stream_id, now_token.to_device_key
+ len(messages),
+ since_stream_id,
+ stream_id,
+ now_token.to_device_key,
)
sync_result_builder.now_token = now_token.copy_and_replace(
"to_device_key", stream_id
@@ -1145,8 +1161,7 @@ class SyncHandler(object):
if since_token and not sync_result_builder.full_state:
account_data, account_data_by_room = (
yield self.store.get_updated_account_data_for_user(
- user_id,
- since_token.account_data_key,
+ user_id, since_token.account_data_key
)
)
@@ -1160,27 +1175,28 @@ class SyncHandler(object):
)
else:
account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(
- sync_config.user.to_string()
- )
+ yield self.store.get_account_data_for_user(sync_config.user.to_string())
)
- account_data['m.push_rules'] = yield self.push_rules_for_user(
+ account_data["m.push_rules"] = yield self.push_rules_for_user(
sync_config.user
)
- account_data_for_user = sync_config.filter_collection.filter_account_data([
- {"type": account_data_type, "content": content}
- for account_data_type, content in account_data.items()
- ])
+ account_data_for_user = sync_config.filter_collection.filter_account_data(
+ [
+ {"type": account_data_type, "content": content}
+ for account_data_type, content in account_data.items()
+ ]
+ )
sync_result_builder.account_data = account_data_for_user
defer.returnValue(account_data_by_room)
@defer.inlineCallbacks
- def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms,
- newly_joined_or_invited_users):
+ def _generate_sync_entry_for_presence(
+ self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
+ ):
"""Generates the presence portion of the sync response. Populates the
`sync_result_builder` with the result.
@@ -1223,17 +1239,13 @@ class SyncHandler(object):
extra_users_ids.discard(user.to_string())
if extra_users_ids:
- states = yield self.presence_handler.get_states(
- extra_users_ids,
- )
+ states = yield self.presence_handler.get_states(extra_users_ids)
presence.extend(states)
# Deduplicate the presence entries so that there's at most one per user
presence = list({p.user_id: p for p in presence}.values())
- presence = sync_config.filter_collection.filter_presence(
- presence
- )
+ presence = sync_config.filter_collection.filter_presence(presence)
sync_result_builder.presence = presence
@@ -1253,8 +1265,8 @@ class SyncHandler(object):
"""
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
- sync_result_builder.since_token is None and
- sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
+ sync_result_builder.since_token is None
+ and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
if block_all_room_ephemeral:
@@ -1275,15 +1287,14 @@ class SyncHandler(object):
have_changed = yield self._have_rooms_changed(sync_result_builder)
if not have_changed:
tags_by_room = yield self.store.get_updated_tags(
- user_id,
- since_token.account_data_key,
+ user_id, since_token.account_data_key
)
if not tags_by_room:
logger.debug("no-oping sync")
defer.returnValue(([], [], [], []))
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", user_id=user_id,
+ "m.ignored_user_list", user_id=user_id
)
if ignored_account_data:
@@ -1296,7 +1307,7 @@ class SyncHandler(object):
room_entries, invited, newly_joined_rooms, newly_left_rooms = res
tags_by_room = yield self.store.get_updated_tags(
- user_id, since_token.account_data_key,
+ user_id, since_token.account_data_key
)
else:
res = yield self._get_all_rooms(sync_result_builder, ignored_users)
@@ -1331,8 +1342,8 @@ class SyncHandler(object):
for event in it:
if event.type == EventTypes.Member:
if (
- event.membership == Membership.JOIN or
- event.membership == Membership.INVITE
+ event.membership == Membership.JOIN
+ or event.membership == Membership.INVITE
):
newly_joined_or_invited_users.add(event.state_key)
else:
@@ -1343,12 +1354,14 @@ class SyncHandler(object):
newly_left_users -= newly_joined_or_invited_users
- defer.returnValue((
- newly_joined_rooms,
- newly_joined_or_invited_users,
- newly_left_rooms,
- newly_left_users,
- ))
+ defer.returnValue(
+ (
+ newly_joined_rooms,
+ newly_joined_or_invited_users,
+ newly_left_rooms,
+ newly_left_users,
+ )
+ )
@defer.inlineCallbacks
def _have_rooms_changed(self, sync_result_builder):
@@ -1454,7 +1467,9 @@ class SyncHandler(object):
prev_membership = old_mem_ev.membership
issue4422_logger.debug(
"Previous membership for room %s with join: %s (event %s)",
- room_id, prev_membership, old_mem_ev_id,
+ room_id,
+ prev_membership,
+ old_mem_ev_id,
)
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
@@ -1476,8 +1491,7 @@ class SyncHandler(object):
if not old_state_ids:
old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get(
- (EventTypes.Member, user_id),
- None,
+ (EventTypes.Member, user_id), None
)
old_mem_ev = None
if old_mem_ev_id:
@@ -1498,7 +1512,8 @@ class SyncHandler(object):
# Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch?
leave_events = [
- e for e in non_joins
+ e
+ for e in non_joins
if e.membership in (Membership.LEAVE, Membership.BAN)
]
@@ -1526,15 +1541,17 @@ class SyncHandler(object):
else:
batch_events = None
- room_entries.append(RoomSyncResultBuilder(
- room_id=room_id,
- rtype="archived",
- events=batch_events,
- newly_joined=room_id in newly_joined_rooms,
- full_state=False,
- since_token=since_token,
- upto_token=leave_token,
- ))
+ room_entries.append(
+ RoomSyncResultBuilder(
+ room_id=room_id,
+ rtype="archived",
+ events=batch_events,
+ newly_joined=room_id in newly_joined_rooms,
+ full_state=False,
+ since_token=since_token,
+ upto_token=leave_token,
+ )
+ )
timeline_limit = sync_config.filter_collection.timeline_limit()
@@ -1581,7 +1598,8 @@ class SyncHandler(object):
# debugging for https://github.com/matrix-org/synapse/issues/4422
issue4422_logger.debug(
"RoomSyncResultBuilder events for newly joined room %s: %r",
- room_id, entry.events,
+ room_id,
+ entry.events,
)
room_entries.append(entry)
@@ -1606,12 +1624,14 @@ class SyncHandler(object):
sync_config = sync_result_builder.sync_config
membership_list = (
- Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN
+ Membership.INVITE,
+ Membership.JOIN,
+ Membership.LEAVE,
+ Membership.BAN,
)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
- user_id=user_id,
- membership_list=membership_list
+ user_id=user_id, membership_list=membership_list
)
room_entries = []
@@ -1619,23 +1639,22 @@ class SyncHandler(object):
for event in room_list:
if event.membership == Membership.JOIN:
- room_entries.append(RoomSyncResultBuilder(
- room_id=event.room_id,
- rtype="joined",
- events=None,
- newly_joined=False,
- full_state=True,
- since_token=since_token,
- upto_token=now_token,
- ))
+ room_entries.append(
+ RoomSyncResultBuilder(
+ room_id=event.room_id,
+ rtype="joined",
+ events=None,
+ newly_joined=False,
+ full_state=True,
+ since_token=since_token,
+ upto_token=now_token,
+ )
+ )
elif event.membership == Membership.INVITE:
if event.sender in ignored_users:
continue
invite = yield self.store.get_event(event.event_id)
- invited.append(InvitedSyncResult(
- room_id=event.room_id,
- invite=invite,
- ))
+ invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave:
@@ -1646,22 +1665,31 @@ class SyncHandler(object):
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
- room_entries.append(RoomSyncResultBuilder(
- room_id=event.room_id,
- rtype="archived",
- events=None,
- newly_joined=False,
- full_state=True,
- since_token=since_token,
- upto_token=leave_token,
- ))
+ room_entries.append(
+ RoomSyncResultBuilder(
+ room_id=event.room_id,
+ rtype="archived",
+ events=None,
+ newly_joined=False,
+ full_state=True,
+ since_token=since_token,
+ upto_token=leave_token,
+ )
+ )
defer.returnValue((room_entries, invited, []))
@defer.inlineCallbacks
- def _generate_room_entry(self, sync_result_builder, ignored_users,
- room_builder, ephemeral, tags, account_data,
- always_include=False):
+ def _generate_room_entry(
+ self,
+ sync_result_builder,
+ ignored_users,
+ room_builder,
+ ephemeral,
+ tags,
+ account_data,
+ always_include=False,
+ ):
"""Populates the `joined` and `archived` section of `sync_result_builder`
based on the `room_builder`.
@@ -1678,9 +1706,7 @@ class SyncHandler(object):
"""
newly_joined = room_builder.newly_joined
full_state = (
- room_builder.full_state
- or newly_joined
- or sync_result_builder.full_state
+ room_builder.full_state or newly_joined or sync_result_builder.full_state
)
events = room_builder.events
@@ -1697,7 +1723,8 @@ class SyncHandler(object):
upto_token = room_builder.upto_token
batch = yield self._load_filtered_recents(
- room_id, sync_config,
+ room_id,
+ sync_config,
now_token=upto_token,
since_token=since_token,
recents=events,
@@ -1708,7 +1735,8 @@ class SyncHandler(object):
# debug for https://github.com/matrix-org/synapse/issues/4422
issue4422_logger.debug(
"Timeline events after filtering in newly-joined room %s: %r",
- room_id, batch,
+ room_id,
+ batch,
)
# When we join the room (or the client requests full_state), we should
@@ -1726,16 +1754,10 @@ class SyncHandler(object):
account_data_events = []
if tags is not None:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
+ account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append({"type": account_data_type, "content": content})
account_data_events = sync_config.filter_collection.filter_room_account_data(
account_data_events
@@ -1743,16 +1765,13 @@ class SyncHandler(object):
ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral)
- if not (always_include
- or batch
- or account_data_events
- or ephemeral
- or full_state):
+ if not (
+ always_include or batch or account_data_events or ephemeral or full_state
+ ):
return
state = yield self.compute_state_delta(
- room_id, batch, sync_config, since_token, now_token,
- full_state=full_state
+ room_id, batch, sync_config, since_token, now_token, full_state=full_state
)
summary = {}
@@ -1760,22 +1779,19 @@ class SyncHandler(object):
# we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form
# the name itself).
- if (
- sync_config.filter_collection.lazy_load_members() and
- (
- # we recalulate the summary:
- # if there are membership changes in the timeline, or
- # if membership has changed during a gappy sync, or
- # if this is an initial sync.
- any(ev.type == EventTypes.Member for ev in batch.events) or
- (
- # XXX: this may include false positives in the form of LL
- # members which have snuck into state
- batch.limited and
- any(t == EventTypes.Member for (t, k) in state)
- ) or
- since_token is None
+ if sync_config.filter_collection.lazy_load_members() and (
+ # we recalulate the summary:
+ # if there are membership changes in the timeline, or
+ # if membership has changed during a gappy sync, or
+ # if this is an initial sync.
+ any(ev.type == EventTypes.Member for ev in batch.events)
+ or (
+ # XXX: this may include false positives in the form of LL
+ # members which have snuck into state
+ batch.limited
+ and any(t == EventTypes.Member for (t, k) in state)
)
+ or since_token is None
):
summary = yield self.compute_summary(
room_id, sync_config, batch, state, now_token
@@ -1794,9 +1810,7 @@ class SyncHandler(object):
)
if room_sync or always_include:
- notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config
- )
+ notifs = yield self.unread_notifs_for_room_id(room_id, sync_config)
if notifs is not None:
unread_notifications["notification_count"] = notifs["notify_count"]
@@ -1807,11 +1821,8 @@ class SyncHandler(object):
if batch.limited and since_token:
user_id = sync_result_builder.sync_config.user.to_string()
logger.info(
- "Incremental gappy sync of %s for user %s with %d state events" % (
- room_id,
- user_id,
- len(state),
- )
+ "Incremental gappy sync of %s for user %s with %d state events"
+ % (room_id, user_id, len(state))
)
elif room_builder.rtype == "archived":
room_sync = ArchivedSyncResult(
@@ -1841,9 +1852,7 @@ class SyncHandler(object):
Deferred[frozenset[str]]: Set of room_ids the user is in at given
stream_ordering.
"""
- joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
- user_id,
- )
+ joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(user_id)
joined_room_ids = set()
@@ -1862,11 +1871,9 @@ class SyncHandler(object):
logger.info("User joined room after current token: %s", room_id)
extrems = yield self.store.get_forward_extremeties_for_room(
- room_id, stream_ordering,
- )
- users_in_room = yield self.state.get_current_users_in_room(
- room_id, extrems,
+ room_id, stream_ordering
)
+ users_in_room = yield self.state.get_current_users_in_room(room_id, extrems)
if user_id in users_in_room:
joined_room_ids.add(room_id)
@@ -1886,7 +1893,7 @@ def _action_has_highlight(actions):
def _calculate_state(
- timeline_contains, timeline_start, previous, current, lazy_load_members,
+ timeline_contains, timeline_start, previous, current, lazy_load_members
):
"""Works out what state to include in a sync response.
@@ -1930,15 +1937,12 @@ def _calculate_state(
if lazy_load_members:
p_ids.difference_update(
- e for t, e in iteritems(timeline_start)
- if t[0] == EventTypes.Member
+ e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member
)
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
- return {
- event_id_to_key[e]: e for e in state_ids
- }
+ return {event_id_to_key[e]: e for e in state_ids}
class SyncResultBuilder(object):
@@ -1961,8 +1965,10 @@ class SyncResultBuilder(object):
groups (GroupsSyncResult|None)
to_device (list)
"""
- def __init__(self, sync_config, full_state, since_token, now_token,
- joined_room_ids):
+
+ def __init__(
+ self, sync_config, full_state, since_token, now_token, joined_room_ids
+ ):
"""
Args:
sync_config (SyncConfig)
@@ -1991,8 +1997,10 @@ class RoomSyncResultBuilder(object):
"""Stores information needed to create either a `JoinedSyncResult` or
`ArchivedSyncResult`.
"""
- def __init__(self, room_id, rtype, events, newly_joined, full_state,
- since_token, upto_token):
+
+ def __init__(
+ self, room_id, rtype, events, newly_joined, full_state, since_token, upto_token
+ ):
"""
Args:
room_id(str)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 972662eb48..f8062c8671 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -68,13 +68,10 @@ class TypingHandler(object):
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
- "TypingStreamChangeCache", self._latest_room_serial,
+ "TypingStreamChangeCache", self._latest_room_serial
)
- self.clock.looping_call(
- self._handle_timeouts,
- 5000,
- )
+ self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self):
"""
@@ -108,19 +105,11 @@ class TypingHandler(object):
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
- run_in_background(
- self._push_remote,
- member=member,
- typing=True
- )
+ run_in_background(self._push_remote, member=member, typing=True)
# Add a paranoia timer to ensure that we always have a timer for
# each person typing.
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + 60 * 1000,
- )
+ self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
@@ -138,9 +127,7 @@ class TypingHandler(object):
yield self.auth.check_joined_room(room_id, target_user_id)
- logger.debug(
- "%s has started typing in %s", target_user_id, room_id
- )
+ logger.debug("%s has started typing in %s", target_user_id, room_id)
member = RoomMember(room_id=room_id, user_id=target_user_id)
@@ -149,20 +136,13 @@ class TypingHandler(object):
now = self.clock.time_msec()
self._member_typing_until[member] = now + timeout
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + timeout,
- )
+ self.wheel_timer.insert(now=now, obj=member, then=now + timeout)
if was_present:
# No point sending another notification
defer.returnValue(None)
- self._push_update(
- member=member,
- typing=True,
- )
+ self._push_update(member=member, typing=True)
@defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id):
@@ -177,9 +157,7 @@ class TypingHandler(object):
yield self.auth.check_joined_room(room_id, target_user_id)
- logger.debug(
- "%s has stopped typing in %s", target_user_id, room_id
- )
+ logger.debug("%s has stopped typing in %s", target_user_id, room_id)
member = RoomMember(room_id=room_id, user_id=target_user_id)
@@ -200,20 +178,14 @@ class TypingHandler(object):
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
- self._push_update(
- member=member,
- typing=False,
- )
+ self._push_update(member=member, typing=False)
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_in_background(self._push_remote, member, typing)
- self._push_update_local(
- member=member,
- typing=typing
- )
+ self._push_update_local(member=member, typing=typing)
@defer.inlineCallbacks
def _push_remote(self, member, typing):
@@ -223,9 +195,7 @@ class TypingHandler(object):
now = self.clock.time_msec()
self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_PING_INTERVAL,
+ now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)
for domain in set(get_domain_from_id(u) for u in users):
@@ -256,8 +226,7 @@ class TypingHandler(object):
if user.domain != origin:
logger.info(
- "Got typing update from %r with bad 'user_id': %r",
- origin, user_id,
+ "Got typing update from %r with bad 'user_id': %r", origin, user_id
)
return
@@ -268,15 +237,8 @@ class TypingHandler(object):
logger.info("Got typing update from %s: %r", user_id, content)
now = self.clock.time_msec()
self._member_typing_until[member] = now + FEDERATION_TIMEOUT
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_TIMEOUT,
- )
- self._push_update_local(
- member=member,
- typing=content["typing"]
- )
+ self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
+ self._push_update_local(member=member, typing=content["typing"])
def _push_update_local(self, member, typing):
room_set = self._room_typing.setdefault(member.room_id, set())
@@ -288,7 +250,7 @@ class TypingHandler(object):
self._latest_room_serial += 1
self._room_serials[member.room_id] = self._latest_room_serial
self._typing_stream_change_cache.entity_has_changed(
- member.room_id, self._latest_room_serial,
+ member.room_id, self._latest_room_serial
)
self.notifier.on_new_event(
@@ -300,7 +262,7 @@ class TypingHandler(object):
return []
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
- last_id,
+ last_id
)
if changed_rooms is None:
@@ -334,9 +296,7 @@ class TypingNotificationEventSource(object):
return {
"type": "m.typing",
"room_id": room_id,
- "content": {
- "user_ids": list(typing),
- },
+ "content": {"user_ids": list(typing)},
}
def get_new_events(self, from_key, room_ids, **kwargs):
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index d36bcd6336..3acf772cd1 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request"""
+
def __init__(self):
super(RequestTimedOutError, self).__init__(504, "Timed out")
@@ -40,15 +41,12 @@ def cancelled_to_request_timed_out_error(value, timeout):
return value
-ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
def redact_uri(uri):
"""Strips access tokens from the uri replaces with <redacted>"""
- return ACCESS_TOKEN_RE.sub(
- r'\1<redacted>\3',
- uri
- )
+ return ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
class QuieterFileBodyProducer(FileBodyProducer):
@@ -57,6 +55,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
Workaround for https://github.com/matrix-org/synapse/issues/4003 /
https://twistedmatrix.com/trac/ticket/6528
"""
+
def stopProducing(self):
try:
FileBodyProducer.stopProducing(self)
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 0e10e3f8f7..096619a8c2 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -28,6 +28,7 @@ class AdditionalResource(Resource):
This class is also where we wrap the request handler with logging, metrics,
and exception handling.
"""
+
def __init__(self, hs, handler):
"""Initialise AdditionalResource
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 77fe68818b..9bc7035c8d 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -17,7 +17,7 @@
import logging
from io import BytesIO
-from six import text_type
+from six import raise_from, text_type
from six.moves import urllib
import treq
@@ -103,8 +103,8 @@ class IPBlacklistingResolver(object):
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
- "Dropped %s from DNS resolution to %s due to blacklist" %
- (ip_address, hostname)
+ "Dropped %s from DNS resolution to %s due to blacklist"
+ % (ip_address, hostname)
)
has_bad_ip = True
@@ -156,7 +156,7 @@ class BlacklistingAgentWrapper(Agent):
self._ip_blacklist = ip_blacklist
def request(self, method, uri, headers=None, bodyProducer=None):
- h = urllib.parse.urlparse(uri.decode('ascii'))
+ h = urllib.parse.urlparse(uri.decode("ascii"))
try:
ip_address = IPAddress(h.hostname)
@@ -164,10 +164,7 @@ class BlacklistingAgentWrapper(Agent):
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
- logger.info(
- "Blocking access to %s due to blacklist" %
- (ip_address,)
- )
+ logger.info("Blocking access to %s due to blacklist" % (ip_address,))
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
except Exception:
@@ -206,7 +203,7 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
- self.user_agent = self.user_agent.encode('ascii')
+ self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
real_reactor = hs.get_reactor()
@@ -520,8 +517,8 @@ class SimpleHttpClient(object):
resp_headers = dict(response.headers.getAllRawHeaders())
if (
- b'Content-Length' in resp_headers
- and int(resp_headers[b'Content-Length'][0]) > max_size
+ b"Content-Length" in resp_headers
+ and int(resp_headers[b"Content-Length"][0]) > max_size
):
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
@@ -542,17 +539,17 @@ class SimpleHttpClient(object):
length = yield make_deferred_yieldable(
_readBodyToFile(response, output_stream, max_size)
)
+ except SynapseError:
+ # This can happen e.g. because the body is too large.
+ raise
except Exception as e:
- logger.exception("Failed to download body")
- raise SynapseError(
- 502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN
- )
+ raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e)
defer.returnValue(
(
length,
resp_headers,
- response.request.absoluteURI.decode('ascii'),
+ response.request.absoluteURI.decode("ascii"),
response.code,
)
)
@@ -642,7 +639,7 @@ def encode_urlencode_args(args):
def encode_urlencode_arg(arg):
if isinstance(arg, text_type):
- return arg.encode('utf-8')
+ return arg.encode("utf-8")
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index cd79ebab62..92a5b606c8 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -31,7 +31,7 @@ def parse_server_name(server_name):
ValueError if the server name could not be parsed.
"""
try:
- if server_name[-1] == ']':
+ if server_name[-1] == "]":
# ipv6 literal, hopefully
return server_name, None
@@ -43,9 +43,7 @@ def parse_server_name(server_name):
raise ValueError("Invalid server name '%s'" % server_name)
-VALID_HOST_REGEX = re.compile(
- "\\A[0-9a-zA-Z.-]+\\Z",
-)
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
def parse_and_validate_server_name(server_name):
@@ -67,17 +65,15 @@ def parse_and_validate_server_name(server_name):
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
- if host[0] == '[':
- if host[-1] != ']':
- raise ValueError("Mismatched [...] in server name '%s'" % (
- server_name,
- ))
+ if host[0] == "[":
+ if host[-1] != "]":
+ raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
return host, port
# otherwise it should only be alphanumerics.
if not VALID_HOST_REGEX.match(host):
- raise ValueError("Server name '%s' contains invalid characters" % (
- server_name,
- ))
+ raise ValueError(
+ "Server name '%s' contains invalid characters" % (server_name,)
+ )
return host, port
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index b4cbe97b41..414cde0777 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -48,7 +48,7 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
logger = logging.getLogger(__name__)
-well_known_cache = TTLCache('well-known')
+well_known_cache = TTLCache("well-known")
@implementer(IAgent)
@@ -78,7 +78,9 @@ class MatrixFederationAgent(object):
"""
def __init__(
- self, reactor, tls_client_options_factory,
+ self,
+ reactor,
+ tls_client_options_factory,
_well_known_tls_policy=None,
_srv_resolver=None,
_well_known_cache=well_known_cache,
@@ -100,9 +102,9 @@ class MatrixFederationAgent(object):
if _well_known_tls_policy is not None:
# the param is called 'contextFactory', but actually passing a
# contextfactory is deprecated, and it expects an IPolicyForHTTPS.
- agent_args['contextFactory'] = _well_known_tls_policy
+ agent_args["contextFactory"] = _well_known_tls_policy
_well_known_agent = RedirectAgent(
- Agent(self._reactor, pool=self._pool, **agent_args),
+ Agent(self._reactor, pool=self._pool, **agent_args)
)
self._well_known_agent = _well_known_agent
@@ -149,7 +151,7 @@ class MatrixFederationAgent(object):
tls_options = None
else:
tls_options = self._tls_client_options_factory.get_options(
- res.tls_server_name.decode("ascii"),
+ res.tls_server_name.decode("ascii")
)
# make sure that the Host header is set correctly
@@ -158,14 +160,14 @@ class MatrixFederationAgent(object):
else:
headers = headers.copy()
- if not headers.hasHeader(b'host'):
- headers.addRawHeader(b'host', res.host_header)
+ if not headers.hasHeader(b"host"):
+ headers.addRawHeader(b"host", res.host_header)
class EndpointFactory(object):
@staticmethod
def endpointForURI(_uri):
ep = LoggingHostnameEndpoint(
- self._reactor, res.target_host, res.target_port,
+ self._reactor, res.target_host, res.target_port
)
if tls_options is not None:
ep = wrapClientTLS(tls_options, ep)
@@ -203,21 +205,25 @@ class MatrixFederationAgent(object):
port = parsed_uri.port
if port == -1:
port = 8448
- defer.returnValue(_RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=port,
- ))
+ defer.returnValue(
+ _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=port,
+ )
+ )
if parsed_uri.port != -1:
# there is an explicit port
- defer.returnValue(_RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=parsed_uri.port,
- ))
+ defer.returnValue(
+ _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=parsed_uri.port,
+ )
+ )
if lookup_well_known:
# try a .well-known lookup
@@ -229,8 +235,8 @@ class MatrixFederationAgent(object):
# parse the server name in the .well-known response into host/port.
# (This code is lifted from twisted.web.client.URI.fromBytes).
- if b':' in well_known_server:
- well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
+ if b":" in well_known_server:
+ well_known_host, well_known_port = well_known_server.rsplit(b":", 1)
try:
well_known_port = int(well_known_port)
except ValueError:
@@ -264,21 +270,27 @@ class MatrixFederationAgent(object):
port = 8448
logger.debug(
"No SRV record for %s, using %s:%i",
- parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
+ parsed_uri.host.decode("ascii"),
+ target_host.decode("ascii"),
+ port,
)
else:
target_host, port = pick_server_from_list(server_list)
logger.debug(
"Picked %s:%i from SRV records for %s",
- target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
+ target_host.decode("ascii"),
+ port,
+ parsed_uri.host.decode("ascii"),
)
- defer.returnValue(_RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=target_host,
- target_port=port,
- ))
+ defer.returnValue(
+ _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=target_host,
+ target_port=port,
+ )
+ )
@defer.inlineCallbacks
def _get_well_known(self, server_name):
@@ -318,18 +330,18 @@ class MatrixFederationAgent(object):
- None if there was no .well-known file.
- INVALID_WELL_KNOWN if the .well-known was invalid
"""
- uri = b"https://%s/.well-known/matrix/server" % (server_name, )
+ uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii")
logger.info("Fetching %s", uri_str)
try:
response = yield make_deferred_yieldable(
- self._well_known_agent.request(b"GET", uri),
+ self._well_known_agent.request(b"GET", uri)
)
body = yield make_deferred_yieldable(readBody(response))
if response.code != 200:
- raise Exception("Non-200 response %s" % (response.code, ))
+ raise Exception("Non-200 response %s" % (response.code,))
- parsed_body = json.loads(body.decode('utf-8'))
+ parsed_body = json.loads(body.decode("utf-8"))
logger.info("Response from .well-known: %s", parsed_body)
if not isinstance(parsed_body, dict):
raise Exception("not a dict")
@@ -347,8 +359,7 @@ class MatrixFederationAgent(object):
result = parsed_body["m.server"].encode("ascii")
cache_period = _cache_period_from_headers(
- response.headers,
- time_now=self._reactor.seconds,
+ response.headers, time_now=self._reactor.seconds
)
if cache_period is None:
cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
@@ -364,6 +375,7 @@ class MatrixFederationAgent(object):
@implementer(IStreamClientEndpoint)
class LoggingHostnameEndpoint(object):
"""A wrapper for HostnameEndpint which logs when it connects"""
+
def __init__(self, reactor, host, port, *args, **kwargs):
self.host = host
self.port = port
@@ -377,17 +389,17 @@ class LoggingHostnameEndpoint(object):
def _cache_period_from_headers(headers, time_now=time.time):
cache_controls = _parse_cache_control(headers)
- if b'no-store' in cache_controls:
+ if b"no-store" in cache_controls:
return 0
- if b'max-age' in cache_controls:
+ if b"max-age" in cache_controls:
try:
- max_age = int(cache_controls[b'max-age'])
+ max_age = int(cache_controls[b"max-age"])
return max_age
except ValueError:
pass
- expires = headers.getRawHeaders(b'expires')
+ expires = headers.getRawHeaders(b"expires")
if expires is not None:
try:
expires_date = stringToDatetime(expires[-1])
@@ -403,9 +415,9 @@ def _cache_period_from_headers(headers, time_now=time.time):
def _parse_cache_control(headers):
cache_controls = {}
- for hdr in headers.getRawHeaders(b'cache-control', []):
- for directive in hdr.split(b','):
- splits = [x.strip() for x in directive.split(b'=', 1)]
+ for hdr in headers.getRawHeaders(b"cache-control", []):
+ for directive in hdr.split(b","):
+ splits = [x.strip() for x in directive.split(b"=", 1)]
k = splits[0].lower()
v = splits[1] if len(splits) > 1 else None
cache_controls[k] = v
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 71830c549d..1f22f78a75 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -45,6 +45,7 @@ class Server(object):
expires (int): when the cache should expire this record - in *seconds* since
the epoch
"""
+
host = attr.ib()
port = attr.ib()
priority = attr.ib(default=0)
@@ -79,9 +80,7 @@ def pick_server_from_list(server_list):
return s.host, s.port
# this should be impossible.
- raise RuntimeError(
- "pick_server_from_list got to end of eligible server list.",
- )
+ raise RuntimeError("pick_server_from_list got to end of eligible server list.")
class SrvResolver(object):
@@ -95,6 +94,7 @@ class SrvResolver(object):
cache (dict): cache object
get_time (callable): clock implementation. Should return seconds since the epoch
"""
+
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
self._dns_client = dns_client
self._cache = cache
@@ -124,7 +124,7 @@ class SrvResolver(object):
try:
answers, _, _ = yield make_deferred_yieldable(
- self._dns_client.lookupService(service_name),
+ self._dns_client.lookupService(service_name)
)
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
@@ -136,17 +136,18 @@ class SrvResolver(object):
cache_entry = self._cache.get(service_name, None)
if cache_entry:
logger.warn(
- "Failed to resolve %r, falling back to cache. %r",
- service_name, e
+ "Failed to resolve %r, falling back to cache. %r", service_name, e
)
defer.returnValue(list(cache_entry))
else:
raise e
- if (len(answers) == 1
- and answers[0].type == dns.SRV
- and answers[0].payload
- and answers[0].payload.target == dns.Name(b'.')):
+ if (
+ len(answers) == 1
+ and answers[0].type == dns.SRV
+ and answers[0].payload
+ and answers[0].payload.target == dns.Name(b".")
+ ):
raise ConnectError("Service %s unavailable" % service_name)
servers = []
@@ -157,13 +158,15 @@ class SrvResolver(object):
payload = answer.payload
- servers.append(Server(
- host=payload.target.name,
- port=payload.port,
- priority=payload.priority,
- weight=payload.weight,
- expires=now + answer.ttl,
- ))
+ servers.append(
+ Server(
+ host=payload.target.name,
+ port=payload.port,
+ priority=payload.priority,
+ weight=payload.weight,
+ expires=now + answer.ttl,
+ )
+ )
self._cache[service_name] = list(servers)
defer.returnValue(servers)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 663ea72a7a..5ef8bb60a3 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -54,10 +54,12 @@ from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
-outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
- "", ["method"])
-incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses",
- "", ["method", "code"])
+outgoing_requests_counter = Counter(
+ "synapse_http_matrixfederationclient_requests", "", ["method"]
+)
+incoming_responses_counter = Counter(
+ "synapse_http_matrixfederationclient_responses", "", ["method", "code"]
+)
MAX_LONG_RETRIES = 10
@@ -137,11 +139,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
check_content_type_is_json(response.headers)
d = treq.json_content(response)
- d = timeout_deferred(
- d,
- timeout=timeout_sec,
- reactor=reactor,
- )
+ d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = yield make_deferred_yieldable(d)
except Exception as e:
@@ -157,7 +155,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
request.txn_id,
request.destination,
response.code,
- response.phrase.decode('ascii', errors='replace'),
+ response.phrase.decode("ascii", errors="replace"),
)
defer.returnValue(body)
@@ -181,7 +179,7 @@ class MatrixFederationHttpClient(object):
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
nameResolver = IPBlacklistingResolver(
- real_reactor, None, hs.config.federation_ip_range_blacklist,
+ real_reactor, None, hs.config.federation_ip_range_blacklist
)
@implementer(IReactorPluggableNameResolver)
@@ -194,21 +192,19 @@ class MatrixFederationHttpClient(object):
self.reactor = Reactor()
- self.agent = MatrixFederationAgent(
- self.reactor,
- tls_client_options_factory,
- )
+ self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent, self.reactor,
+ self.agent,
+ self.reactor,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
- self.version_string_bytes = hs.version_string.encode('ascii')
+ self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout = 60
def schedule(x):
@@ -218,10 +214,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def _send_request_with_optional_trailing_slash(
- self,
- request,
- try_trailing_slash_on_400=False,
- **send_request_args
+ self, request, try_trailing_slash_on_400=False, **send_request_args
):
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
@@ -244,9 +237,7 @@ class MatrixFederationHttpClient(object):
Deferred[Dict]: Parsed JSON response body.
"""
try:
- response = yield self._send_request(
- request, **send_request_args
- )
+ response = yield self._send_request(request, **send_request_args)
except HttpResponseException as e:
# Received an HTTP error > 300. Check if it meets the requirements
# to retry with a trailing slash
@@ -262,9 +253,7 @@ class MatrixFederationHttpClient(object):
logger.info("Retrying request with trailing slash")
request.path += "/"
- response = yield self._send_request(
- request, **send_request_args
- )
+ response = yield self._send_request(request, **send_request_args)
defer.returnValue(response)
@@ -329,8 +318,8 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
if (
- self.hs.config.federation_domain_whitelist is not None and
- request.destination not in self.hs.config.federation_domain_whitelist
+ self.hs.config.federation_domain_whitelist is not None
+ and request.destination not in self.hs.config.federation_domain_whitelist
):
raise FederationDeniedError(request.destination)
@@ -350,9 +339,7 @@ class MatrixFederationHttpClient(object):
else:
query_bytes = b""
- headers_dict = {
- b"User-Agent": [self.version_string_bytes],
- }
+ headers_dict = {b"User-Agent": [self.version_string_bytes]}
with limiter:
# XXX: Would be much nicer to retry only at the transaction-layer
@@ -362,16 +349,14 @@ class MatrixFederationHttpClient(object):
else:
retries_left = MAX_SHORT_RETRIES
- url_bytes = urllib.parse.urlunparse((
- b"matrix", destination_bytes,
- path_bytes, None, query_bytes, b"",
- ))
- url_str = url_bytes.decode('ascii')
+ url_bytes = urllib.parse.urlunparse(
+ (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
+ )
+ url_str = url_bytes.decode("ascii")
- url_to_sign_bytes = urllib.parse.urlunparse((
- b"", b"",
- path_bytes, None, query_bytes, b"",
- ))
+ url_to_sign_bytes = urllib.parse.urlunparse(
+ (b"", b"", path_bytes, None, query_bytes, b"")
+ )
while True:
try:
@@ -379,26 +364,27 @@ class MatrixFederationHttpClient(object):
if json:
headers_dict[b"Content-Type"] = [b"application/json"]
auth_headers = self.build_auth_headers(
- destination_bytes, method_bytes, url_to_sign_bytes,
- json,
+ destination_bytes, method_bytes, url_to_sign_bytes, json
)
data = encode_canonical_json(json)
producer = QuieterFileBodyProducer(
- BytesIO(data),
- cooperator=self._cooperator,
+ BytesIO(data), cooperator=self._cooperator
)
else:
producer = None
auth_headers = self.build_auth_headers(
- destination_bytes, method_bytes, url_to_sign_bytes,
+ destination_bytes, method_bytes, url_to_sign_bytes
)
headers_dict[b"Authorization"] = auth_headers
logger.info(
"{%s} [%s] Sending request: %s %s; timeout %fs",
- request.txn_id, request.destination, request.method,
- url_str, _sec_timeout,
+ request.txn_id,
+ request.destination,
+ request.method,
+ url_str,
+ _sec_timeout,
)
try:
@@ -430,7 +416,7 @@ class MatrixFederationHttpClient(object):
request.txn_id,
request.destination,
response.code,
- response.phrase.decode('ascii', errors='replace'),
+ response.phrase.decode("ascii", errors="replace"),
)
if 200 <= response.code < 300:
@@ -440,9 +426,7 @@ class MatrixFederationHttpClient(object):
# Update transactions table?
d = treq.content(response)
d = timeout_deferred(
- d,
- timeout=_sec_timeout,
- reactor=self.reactor,
+ d, timeout=_sec_timeout, reactor=self.reactor
)
try:
@@ -460,9 +444,7 @@ class MatrixFederationHttpClient(object):
)
body = None
- e = HttpResponseException(
- response.code, response.phrase, body
- )
+ e = HttpResponseException(response.code, response.phrase, body)
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
@@ -521,7 +503,7 @@ class MatrixFederationHttpClient(object):
defer.returnValue(response)
def build_auth_headers(
- self, destination, method, url_bytes, content=None, destination_is=None,
+ self, destination, method, url_bytes, content=None, destination_is=None
):
"""
Builds the Authorization headers for a federation request
@@ -538,11 +520,7 @@ class MatrixFederationHttpClient(object):
Returns:
list[bytes]: a list of headers to be added as "Authorization:" headers
"""
- request = {
- "method": method,
- "uri": url_bytes,
- "origin": self.server_name,
- }
+ request = {"method": method, "uri": url_bytes, "origin": self.server_name}
if destination is not None:
request["destination"] = destination
@@ -558,20 +536,28 @@ class MatrixFederationHttpClient(object):
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
- auth_headers.append((
- "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
- self.server_name, key, sig,
- )).encode('ascii')
+ auth_headers.append(
+ (
+ 'X-Matrix origin=%s,key="%s",sig="%s"'
+ % (self.server_name, key, sig)
+ ).encode("ascii")
)
return auth_headers
@defer.inlineCallbacks
- def put_json(self, destination, path, args={}, data={},
- json_data_callback=None,
- long_retries=False, timeout=None,
- ignore_backoff=False,
- backoff_on_404=False,
- try_trailing_slash_on_400=False):
+ def put_json(
+ self,
+ destination,
+ path,
+ args={},
+ data={},
+ json_data_callback=None,
+ long_retries=False,
+ timeout=None,
+ ignore_backoff=False,
+ backoff_on_404=False,
+ try_trailing_slash_on_400=False,
+ ):
""" Sends the specifed json data using PUT
Args:
@@ -635,14 +621,22 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.reactor, self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
- def post_json(self, destination, path, data={}, long_retries=False,
- timeout=None, ignore_backoff=False, args={}):
+ def post_json(
+ self,
+ destination,
+ path,
+ data={},
+ long_retries=False,
+ timeout=None,
+ ignore_backoff=False,
+ args={},
+ ):
""" Sends the specifed json data using POST
Args:
@@ -681,11 +675,7 @@ class MatrixFederationHttpClient(object):
"""
request = MatrixFederationRequest(
- method="POST",
- destination=destination,
- path=path,
- query=args,
- json=data,
+ method="POST", destination=destination, path=path, query=args, json=data
)
response = yield self._send_request(
@@ -701,14 +691,21 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
body = yield _handle_json_response(
- self.reactor, _sec_timeout, request, response,
+ self.reactor, _sec_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
- def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
- timeout=None, ignore_backoff=False,
- try_trailing_slash_on_400=False):
+ def get_json(
+ self,
+ destination,
+ path,
+ args=None,
+ retry_on_dns_fail=True,
+ timeout=None,
+ ignore_backoff=False,
+ try_trailing_slash_on_400=False,
+ ):
""" GETs some json from the given host homeserver and path
Args:
@@ -745,10 +742,7 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
- method="GET",
- destination=destination,
- path=path,
- query=args,
+ method="GET", destination=destination, path=path, query=args
)
response = yield self._send_request_with_optional_trailing_slash(
@@ -761,14 +755,21 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.reactor, self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
- def delete_json(self, destination, path, long_retries=False,
- timeout=None, ignore_backoff=False, args={}):
+ def delete_json(
+ self,
+ destination,
+ path,
+ long_retries=False,
+ timeout=None,
+ ignore_backoff=False,
+ args={},
+ ):
"""Send a DELETE request to the remote expecting some json response
Args:
@@ -802,10 +803,7 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
- method="DELETE",
- destination=destination,
- path=path,
- query=args,
+ method="DELETE", destination=destination, path=path, query=args
)
response = yield self._send_request(
@@ -816,14 +814,21 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.reactor, self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response
)
defer.returnValue(body)
@defer.inlineCallbacks
- def get_file(self, destination, path, output_stream, args={},
- retry_on_dns_fail=True, max_size=None,
- ignore_backoff=False):
+ def get_file(
+ self,
+ destination,
+ path,
+ output_stream,
+ args={},
+ retry_on_dns_fail=True,
+ max_size=None,
+ ignore_backoff=False,
+ ):
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
@@ -848,16 +853,11 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
- method="GET",
- destination=destination,
- path=path,
- query=args,
+ method="GET", destination=destination, path=path, query=args
)
response = yield self._send_request(
- request,
- retry_on_dns_fail=retry_on_dns_fail,
- ignore_backoff=ignore_backoff,
+ request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff
)
headers = dict(response.headers.getAllRawHeaders())
@@ -879,7 +879,7 @@ class MatrixFederationHttpClient(object):
request.txn_id,
request.destination,
response.code,
- response.phrase.decode('ascii', errors='replace'),
+ response.phrase.decode("ascii", errors="replace"),
length,
)
defer.returnValue((length, headers))
@@ -896,11 +896,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- ))
+ self.deferred.errback(
+ SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (self.max_size,),
+ Codes.TOO_LARGE,
+ )
+ )
self.deferred = defer.Deferred()
self.transport.loseConnection()
@@ -920,8 +922,7 @@ def _readBodyToFile(response, stream, max_size):
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
- _flatten_response_never_received(f.value)
- for f in e.reasons
+ _flatten_response_never_received(f.value) for f in e.reasons
)
return "%s:[%s]" % (type(e).__name__, reasons)
@@ -943,16 +944,15 @@ def check_content_type_is_json(headers):
"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
- raise RequestSendFailed(RuntimeError(
- "No Content-Type header"
- ), can_retry=False)
+ raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False)
- c_type = c_type[0].decode('ascii') # only the first header
+ c_type = c_type[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
- raise RequestSendFailed(RuntimeError(
- "Content-Type not application/json: was '%s'" % c_type
- ), can_retry=False)
+ raise RequestSendFailed(
+ RuntimeError("Content-Type not application/json: was '%s'" % c_type),
+ can_retry=False,
+ )
def encode_query_args(args):
@@ -967,4 +967,4 @@ def encode_query_args(args):
query_bytes = urllib.parse.urlencode(encoded_args, True)
- return query_bytes.encode('utf8')
+ return query_bytes.encode("utf8")
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 16fb7935da..6fd13e87d1 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -81,9 +81,7 @@ def wrap_json_request_handler(h):
yield h(self, request)
except SynapseError as e:
code = e.code
- logger.info(
- "%s SynapseError: %s - %s", request, code, e.msg
- )
+ logger.info("%s SynapseError: %s - %s", request, code, e.msg)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
@@ -96,7 +94,10 @@ def wrap_json_request_handler(h):
pass
else:
respond_with_json(
- request, code, e.error_dict(), send_cors=True,
+ request,
+ code,
+ e.error_dict(),
+ send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
@@ -124,10 +125,7 @@ def wrap_json_request_handler(h):
respond_with_json(
request,
500,
- {
- "error": "Internal server error",
- "errcode": Codes.UNKNOWN,
- },
+ {"error": "Internal server error", "errcode": Codes.UNKNOWN},
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
@@ -143,6 +141,7 @@ def wrap_html_request_handler(h):
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
"""
+
def wrapped_request_handler(self, request):
d = defer.maybeDeferred(h, self, request)
d.addErrback(_return_html_error, request)
@@ -164,9 +163,7 @@ def _return_html_error(f, request):
msg = cme.msg
if isinstance(cme, SynapseError):
- logger.info(
- "%s SynapseError: %s - %s", request, code, msg
- )
+ logger.info("%s SynapseError: %s - %s", request, code, msg)
else:
logger.error(
"Failed handle request %r",
@@ -183,9 +180,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
- body = HTML_ERROR_TEMPLATE.format(
- code=code, msg=cgi.escape(msg),
- ).encode("utf-8")
+ body = HTML_ERROR_TEMPLATE.format(code=code, msg=cgi.escape(msg)).encode("utf-8")
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % (len(body),))
@@ -205,6 +200,7 @@ def wrap_async_request_handler(h):
The handler may return a deferred, in which case the completion of the request isn't
logged until the deferred completes.
"""
+
@defer.inlineCallbacks
def wrapped_async_request_handler(self, request):
with request.processing():
@@ -306,12 +302,14 @@ class JsonResource(HttpServer, resource.Resource):
# URL again (as it was decoded by _get_handler_for request), as
# ASCII because it's a URL, and then decode it to get the UTF-8
# characters that were quoted.
- return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
+ return urllib.parse.unquote(s.encode("ascii")).decode("utf8")
- kwargs = intern_dict({
- name: _unquote(value) if value else value
- for name, value in group_dict.items()
- })
+ kwargs = intern_dict(
+ {
+ name: _unquote(value) if value else value
+ for name, value in group_dict.items()
+ }
+ )
callback_return = yield callback(request, **kwargs)
if callback_return is not None:
@@ -339,7 +337,7 @@ class JsonResource(HttpServer, resource.Resource):
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
- m = path_entry.pattern.match(request.path.decode('ascii'))
+ m = path_entry.pattern.match(request.path.decode("ascii"))
if m:
# We found a match!
return path_entry.callback, m.groupdict()
@@ -347,11 +345,14 @@ class JsonResource(HttpServer, resource.Resource):
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, {}
- def _send_response(self, request, code, response_json_object,
- response_code_message=None):
+ def _send_response(
+ self, request, code, response_json_object, response_code_message=None
+ ):
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
- request, code, response_json_object,
+ request,
+ code,
+ response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
@@ -395,7 +396,7 @@ class RootRedirect(resource.Resource):
self.url = path
def render_GET(self, request):
- return redirectTo(self.url.encode('ascii'), request)
+ return redirectTo(self.url.encode("ascii"), request)
def getChild(self, name, request):
if len(name) == 0:
@@ -403,16 +404,22 @@ class RootRedirect(resource.Resource):
return resource.Resource.getChild(self, name, request)
-def respond_with_json(request, code, json_object, send_cors=False,
- response_code_message=None, pretty_print=False,
- canonical_json=True):
+def respond_with_json(
+ request,
+ code,
+ json_object,
+ send_cors=False,
+ response_code_message=None,
+ pretty_print=False,
+ canonical_json=True,
+):
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
logger.warn(
- "Not sending response to request %s, already disconnected.",
- request)
+ "Not sending response to request %s, already disconnected.", request
+ )
return
if pretty_print:
@@ -425,14 +432,17 @@ def respond_with_json(request, code, json_object, send_cors=False,
json_bytes = json.dumps(json_object).encode("utf-8")
return respond_with_json_bytes(
- request, code, json_bytes,
+ request,
+ code,
+ json_bytes,
send_cors=send_cors,
response_code_message=response_code_message,
)
-def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
- response_code_message=None):
+def respond_with_json_bytes(
+ request, code, json_bytes, send_cors=False, response_code_message=None
+):
"""Sends encoded JSON in response to the given request.
Args:
@@ -474,7 +484,7 @@ def set_cors_headers(request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
- b"Origin, X-Requested-With, Content-Type, Accept, Authorization"
+ b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
)
@@ -498,9 +508,7 @@ def finish_request(request):
def _request_user_agent_is_curl(request):
- user_agents = request.requestHeaders.getRawHeaders(
- b"User-Agent", default=[]
- )
+ user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
for user_agent in user_agents:
if b"curl" in user_agent:
return True
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 197c652850..cd8415acd5 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -48,7 +48,7 @@ def parse_integer(request, name, default=None, required=False):
def parse_integer_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
- name = name.encode('ascii')
+ name = name.encode("ascii")
if name in args:
try:
@@ -89,18 +89,14 @@ def parse_boolean(request, name, default=None, required=False):
def parse_boolean_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
- name = name.encode('ascii')
+ name = name.encode("ascii")
if name in args:
try:
- return {
- b"true": True,
- b"false": False,
- }[args[name][0]]
+ return {b"true": True, b"false": False}[args[name][0]]
except Exception:
message = (
- "Boolean query parameter %r must be one of"
- " ['true', 'false']"
+ "Boolean query parameter %r must be one of" " ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
@@ -111,8 +107,15 @@ def parse_boolean_from_args(args, name, default=None, required=False):
return default
-def parse_string(request, name, default=None, required=False,
- allowed_values=None, param_type="string", encoding='ascii'):
+def parse_string(
+ request,
+ name,
+ default=None,
+ required=False,
+ allowed_values=None,
+ param_type="string",
+ encoding="ascii",
+):
"""
Parse a string parameter from the request query string.
@@ -145,11 +148,18 @@ def parse_string(request, name, default=None, required=False,
)
-def parse_string_from_args(args, name, default=None, required=False,
- allowed_values=None, param_type="string", encoding='ascii'):
+def parse_string_from_args(
+ args,
+ name,
+ default=None,
+ required=False,
+ allowed_values=None,
+ param_type="string",
+ encoding="ascii",
+):
if not isinstance(name, bytes):
- name = name.encode('ascii')
+ name = name.encode("ascii")
if name in args:
value = args[name][0]
@@ -159,7 +169,8 @@ def parse_string_from_args(args, name, default=None, required=False,
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
- name, ", ".join(repr(v) for v in allowed_values)
+ name,
+ ", ".join(repr(v) for v in allowed_values),
)
raise SynapseError(400, message)
else:
@@ -201,7 +212,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
# Decode to Unicode so that simplejson will return Unicode strings on
# Python 2
try:
- content_unicode = content_bytes.decode('utf8')
+ content_unicode = content_bytes.decode("utf8")
except UnicodeDecodeError:
logger.warn("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
@@ -227,9 +238,7 @@ def parse_json_object_from_request(request, allow_empty_body=False):
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
- content = parse_json_value_from_request(
- request, allow_empty_body=allow_empty_body,
- )
+ content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body)
if allow_empty_body and content is None:
return {}
diff --git a/synapse/http/site.py b/synapse/http/site.py
index e508c0bd4f..93f679ea48 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -46,10 +46,11 @@ class SynapseRequest(Request):
Attributes:
logcontext(LoggingContext) : the log context for this request
"""
+
def __init__(self, site, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
self.site = site
- self._channel = channel # this is used by the tests
+ self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0
@@ -72,12 +73,12 @@ class SynapseRequest(Request):
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
- return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
+ return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
id(self),
self.get_method(),
self.get_redacted_uri(),
- self.clientproto.decode('ascii', errors='replace'),
+ self.clientproto.decode("ascii", errors="replace"),
self.site.site_tag,
)
@@ -87,7 +88,7 @@ class SynapseRequest(Request):
def get_redacted_uri(self):
uri = self.uri
if isinstance(uri, bytes):
- uri = self.uri.decode('ascii')
+ uri = self.uri.decode("ascii")
return redact_uri(uri)
def get_method(self):
@@ -102,7 +103,7 @@ class SynapseRequest(Request):
"""
method = self.method
if isinstance(method, bytes):
- method = self.method.decode('ascii')
+ method = self.method.decode("ascii")
return method
def get_user_agent(self):
@@ -134,8 +135,7 @@ class SynapseRequest(Request):
# dispatching to the handler, so that the handler
# can update the servlet name in the request
# metrics
- requests_counter.labels(self.get_method(),
- self.request_metrics.name).inc()
+ requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager
def processing(self):
@@ -200,7 +200,7 @@ class SynapseRequest(Request):
# the client disconnects.
with PreserveLoggingContext(self.logcontext):
logger.warn(
- "Error processing request %r: %s %s", self, reason.type, reason.value,
+ "Error processing request %r: %s %s", self, reason.type, reason.value
)
if not self._is_processing:
@@ -222,7 +222,7 @@ class SynapseRequest(Request):
self.start_time = time.time()
self.request_metrics = RequestMetrics()
self.request_metrics.start(
- self.start_time, name=servlet_name, method=self.get_method(),
+ self.start_time, name=servlet_name, method=self.get_method()
)
self.site.access_logger.info(
@@ -230,7 +230,7 @@ class SynapseRequest(Request):
self.getClientIP(),
self.site.site_tag,
self.get_method(),
- self.get_redacted_uri()
+ self.get_redacted_uri(),
)
def _finished_processing(self):
@@ -282,7 +282,7 @@ class SynapseRequest(Request):
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
- " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
+ ' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
self.site.site_tag,
authenticated_entity,
@@ -297,7 +297,7 @@ class SynapseRequest(Request):
code,
self.get_method(),
self.get_redacted_uri(),
- self.clientproto.decode('ascii', errors='replace'),
+ self.clientproto.decode("ascii", errors="replace"),
user_agent,
usage.evt_db_fetch_count,
)
@@ -316,14 +316,19 @@ class XForwardedForRequest(SynapseRequest):
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
+
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
- return self.requestHeaders.getRawHeaders(
- b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii')
+ return (
+ self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
+ .split(b",")[0]
+ .strip()
+ .decode("ascii")
+ )
class SynapseRequestFactory(object):
@@ -343,8 +348,17 @@ class SynapseSite(Site):
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
- def __init__(self, logger_name, site_tag, config, resource,
- server_version_string, *args, **kwargs):
+
+ def __init__(
+ self,
+ logger_name,
+ site_tag,
+ config,
+ resource,
+ server_version_string,
+ *args,
+ **kwargs
+ ):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
@@ -352,7 +366,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
- self.server_version_string = server_version_string.encode('ascii')
+ self.server_version_string = server_version_string.encode("ascii")
def log(self, request):
pass
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index ef48984fdd..1f30179b51 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -25,7 +25,7 @@ import six
import attr
from prometheus_client import Counter, Gauge, Histogram
-from prometheus_client.core import REGISTRY, GaugeMetricFamily
+from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricFamily
from twisted.internet import reactor
@@ -40,7 +40,6 @@ HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
class RegistryProxy(object):
-
@staticmethod
def collect():
for metric in REGISTRY.collect():
@@ -63,10 +62,7 @@ class LaterGauge(object):
try:
calls = self.caller()
except Exception:
- logger.exception(
- "Exception running callback for LaterGauge(%s)",
- self.name,
- )
+ logger.exception("Exception running callback for LaterGauge(%s)", self.name)
yield g
return
@@ -116,9 +112,7 @@ class InFlightGauge(object):
# Create a class which have the sub_metrics values as attributes, which
# default to 0 on initialization. Used to pass to registered callbacks.
self._metrics_class = attr.make_class(
- "_MetricsEntry",
- attrs={x: attr.ib(0) for x in sub_metrics},
- slots=True,
+ "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
)
# Counts number of in flight blocks for a given set of label values
@@ -157,7 +151,9 @@ class InFlightGauge(object):
Note: may be called by a separate thread.
"""
- in_flight = GaugeMetricFamily(self.name + "_total", self.desc, labels=self.labels)
+ in_flight = GaugeMetricFamily(
+ self.name + "_total", self.desc, labels=self.labels
+ )
metrics_by_key = {}
@@ -179,7 +175,9 @@ class InFlightGauge(object):
yield in_flight
for name in self.sub_metrics:
- gauge = GaugeMetricFamily("_".join([self.name, name]), "", labels=self.labels)
+ gauge = GaugeMetricFamily(
+ "_".join([self.name, name]), "", labels=self.labels
+ )
for key, metrics in six.iteritems(metrics_by_key):
gauge.add_metric(key, getattr(metrics, name))
yield gauge
@@ -193,17 +191,76 @@ class InFlightGauge(object):
all_gauges[self.name] = self
+@attr.s(hash=True)
+class BucketCollector(object):
+ """
+ Like a Histogram, but allows buckets to be point-in-time instead of
+ incrementally added to.
+
+ Args:
+ name (str): Base name of metric to be exported to Prometheus.
+ data_collector (callable -> dict): A synchronous callable that
+ returns a dict mapping bucket to number of items in the
+ bucket. If these buckets are not the same as the buckets
+ given to this class, they will be remapped into them.
+ buckets (list[float]): List of floats/ints of the buckets to
+ give to Prometheus. +Inf is ignored, if given.
+
+ """
+
+ name = attr.ib()
+ data_collector = attr.ib()
+ buckets = attr.ib()
+
+ def collect(self):
+
+ # Fetch the data -- this must be synchronous!
+ data = self.data_collector()
+
+ buckets = {}
+
+ res = []
+ for x in data.keys():
+ for i, bound in enumerate(self.buckets):
+ if x <= bound:
+ buckets[bound] = buckets.get(bound, 0) + data[x]
+
+ for i in self.buckets:
+ res.append([str(i), buckets.get(i, 0)])
+
+ res.append(["+Inf", sum(data.values())])
+
+ metric = HistogramMetricFamily(
+ self.name, "", buckets=res, sum_value=sum([x * y for x, y in data.items()])
+ )
+ yield metric
+
+ def __attrs_post_init__(self):
+ self.buckets = [float(x) for x in self.buckets if x != "+Inf"]
+ if self.buckets != sorted(self.buckets):
+ raise ValueError("Buckets not sorted")
+
+ self.buckets = tuple(self.buckets)
+
+ if self.name in all_gauges.keys():
+ logger.warning("%s already registered, reregistering" % (self.name,))
+ REGISTRY.unregister(all_gauges.pop(self.name))
+
+ REGISTRY.register(self)
+ all_gauges[self.name] = self
+
+
#
# Detailed CPU metrics
#
-class CPUMetrics(object):
+class CPUMetrics(object):
def __init__(self):
ticks_per_sec = 100
try:
# Try and get the system config
- ticks_per_sec = os.sysconf('SC_CLK_TCK')
+ ticks_per_sec = os.sysconf("SC_CLK_TCK")
except (ValueError, TypeError, AttributeError):
pass
@@ -237,13 +294,28 @@ gc_time = Histogram(
"python_gc_time",
"Time taken to GC (sec)",
["gen"],
- buckets=[0.0025, 0.005, 0.01, 0.025, 0.05, 0.10, 0.25, 0.50, 1.00, 2.50,
- 5.00, 7.50, 15.00, 30.00, 45.00, 60.00],
+ buckets=[
+ 0.0025,
+ 0.005,
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.10,
+ 0.25,
+ 0.50,
+ 1.00,
+ 2.50,
+ 5.00,
+ 7.50,
+ 15.00,
+ 30.00,
+ 45.00,
+ 60.00,
+ ],
)
class GCCounts(object):
-
def collect(self):
cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
for n, m in enumerate(gc.get_count()):
@@ -279,9 +351,7 @@ sent_transactions_counter = Counter("synapse_federation_client_sent_transactions
events_processed_counter = Counter("synapse_federation_client_events_processed", "")
event_processing_loop_counter = Counter(
- "synapse_event_processing_loop_count",
- "Event processing loop iterations",
- ["name"],
+ "synapse_event_processing_loop_count", "Event processing loop iterations", ["name"]
)
event_processing_loop_room_count = Counter(
@@ -311,7 +381,6 @@ last_ticked = time.time()
class ReactorLastSeenMetric(object):
-
def collect(self):
cm = GaugeMetricFamily(
"python_twisted_reactor_last_seen",
@@ -325,7 +394,6 @@ REGISTRY.register(ReactorLastSeenMetric())
def runUntilCurrentTimer(func):
-
@functools.wraps(func)
def f(*args, **kwargs):
now = reactor.seconds()
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 037f1c490e..167e2c068a 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -60,8 +60,10 @@ _background_process_db_txn_count = Counter(
_background_process_db_txn_duration = Counter(
"synapse_background_process_db_txn_duration_seconds",
- ("Seconds spent by background processes waiting for database "
- "transactions, excluding scheduling time"),
+ (
+ "Seconds spent by background processes waiting for database "
+ "transactions, excluding scheduling time"
+ ),
["name"],
registry=None,
)
@@ -94,6 +96,7 @@ class _Collector(object):
Ensures that all of the metrics are up-to-date with any in-flight processes
before they are returned.
"""
+
def collect(self):
background_process_in_flight_count = GaugeMetricFamily(
"synapse_background_process_in_flight_count",
@@ -105,14 +108,11 @@ class _Collector(object):
# We also copy the process lists as that can also change
with _bg_metrics_lock:
_background_processes_copy = {
- k: list(v)
- for k, v in six.iteritems(_background_processes)
+ k: list(v) for k, v in six.iteritems(_background_processes)
}
for desc, processes in six.iteritems(_background_processes_copy):
- background_process_in_flight_count.add_metric(
- (desc,), len(processes),
- )
+ background_process_in_flight_count.add_metric((desc,), len(processes))
for process in processes:
process.update_metrics()
@@ -121,11 +121,11 @@ class _Collector(object):
# now we need to run collect() over each of the static Counters, and
# yield each metric they return.
for m in (
- _background_process_ru_utime,
- _background_process_ru_stime,
- _background_process_db_txn_count,
- _background_process_db_txn_duration,
- _background_process_db_sched_duration,
+ _background_process_ru_utime,
+ _background_process_ru_stime,
+ _background_process_db_txn_count,
+ _background_process_db_txn_duration,
+ _background_process_db_sched_duration,
):
for r in m.collect():
yield r
@@ -151,14 +151,12 @@ class _BackgroundProcess(object):
_background_process_ru_utime.labels(self.desc).inc(diff.ru_utime)
_background_process_ru_stime.labels(self.desc).inc(diff.ru_stime)
- _background_process_db_txn_count.labels(self.desc).inc(
- diff.db_txn_count,
- )
+ _background_process_db_txn_count.labels(self.desc).inc(diff.db_txn_count)
_background_process_db_txn_duration.labels(self.desc).inc(
- diff.db_txn_duration_sec,
+ diff.db_txn_duration_sec
)
_background_process_db_sched_duration.labels(self.desc).inc(
- diff.db_sched_duration_sec,
+ diff.db_sched_duration_sec
)
@@ -182,6 +180,7 @@ def run_as_background_process(desc, func, *args, **kwargs):
Returns: Deferred which returns the result of func, but note that it does not
follow the synapse logcontext rules.
"""
+
@defer.inlineCallbacks
def run():
with _bg_metrics_lock:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index b3abd1b3c6..bf43ca09be 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -21,6 +21,7 @@ class ModuleApi(object):
"""A proxy object that gets passed to password auth providers so they
can register new users etc if necessary.
"""
+
def __init__(self, hs, auth_handler):
self.hs = hs
@@ -57,7 +58,7 @@ class ModuleApi(object):
Returns:
str: qualified @user:id
"""
- if username.startswith('@'):
+ if username.startswith("@"):
return username
return UserID(username, self.hs.hostname).to_string()
@@ -89,8 +90,7 @@ class ModuleApi(object):
# Register the user
reg = self.hs.get_registration_handler()
user_id, access_token = yield reg.register(
- localpart=localpart, default_display_name=displayname,
- bind_emails=emails,
+ localpart=localpart, default_display_name=displayname, bind_emails=emails
)
defer.returnValue((user_id, access_token))
diff --git a/synapse/notifier.py b/synapse/notifier.py
index ff589660da..d398078eed 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -37,7 +37,8 @@ logger = logging.getLogger(__name__)
notified_events_counter = Counter("synapse_notifier_notified_events", "")
users_woken_by_stream_counter = Counter(
- "synapse_notifier_users_woken_by_stream", "", ["stream"])
+ "synapse_notifier_users_woken_by_stream", "", ["stream"]
+)
# TODO(paul): Should be shared somewhere
@@ -55,6 +56,7 @@ class _NotificationListener(object):
The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred.
"""
+
__slots__ = ["deferred"]
def __init__(self, deferred):
@@ -95,9 +97,7 @@ class _NotifierUserStream(object):
stream_id(str): The new id for the stream the event came from.
time_now_ms(int): The current time in milliseconds.
"""
- self.current_token = self.current_token.copy_and_advance(
- stream_key, stream_id
- )
+ self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
self.last_notified_token = self.current_token
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
@@ -141,6 +141,7 @@ class _NotifierUserStream(object):
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
def __nonzero__(self):
return bool(self.events)
+
__bool__ = __nonzero__ # python3
@@ -190,15 +191,17 @@ class Notifier(object):
all_user_streams.add(x)
return sum(stream.count_listeners() for stream in all_user_streams)
+
LaterGauge("synapse_notifier_listeners", "", [], count_listeners)
LaterGauge(
- "synapse_notifier_rooms", "", [],
+ "synapse_notifier_rooms",
+ "",
+ [],
lambda: count(bool, list(self.room_to_user_streams.values())),
)
LaterGauge(
- "synapse_notifier_users", "", [],
- lambda: len(self.user_to_user_stream),
+ "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream)
)
def add_replication_callback(self, cb):
@@ -209,8 +212,9 @@ class Notifier(object):
"""
self.replication_callbacks.append(cb)
- def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
- extra_users=[]):
+ def on_new_room_event(
+ self, event, room_stream_id, max_room_stream_id, extra_users=[]
+ ):
""" Used by handlers to inform the notifier something has happened
in the room, room event wise.
@@ -222,9 +226,7 @@ class Notifier(object):
until all previous events have been persisted before notifying
the client streams.
"""
- self.pending_new_room_events.append((
- room_stream_id, event, extra_users
- ))
+ self.pending_new_room_events.append((room_stream_id, event, extra_users))
self._notify_pending_new_room_events(max_room_stream_id)
self.notify_replication()
@@ -240,9 +242,9 @@ class Notifier(object):
self.pending_new_room_events = []
for room_stream_id, event, extra_users in pending:
if room_stream_id > max_room_stream_id:
- self.pending_new_room_events.append((
- room_stream_id, event, extra_users
- ))
+ self.pending_new_room_events.append(
+ (room_stream_id, event, extra_users)
+ )
else:
self._on_new_room_event(event, room_stream_id, extra_users)
@@ -250,8 +252,7 @@ class Notifier(object):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
run_as_background_process(
- "notify_app_services",
- self._notify_app_services, room_stream_id,
+ "notify_app_services", self._notify_app_services, room_stream_id
)
if self.federation_sender:
@@ -261,9 +262,7 @@ class Notifier(object):
self._user_joined_room(event.state_key, event.room_id)
self.on_new_event(
- "room_key", room_stream_id,
- users=extra_users,
- rooms=[event.room_id],
+ "room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
)
@defer.inlineCallbacks
@@ -305,8 +304,9 @@ class Notifier(object):
self.notify_replication()
@defer.inlineCallbacks
- def wait_for_events(self, user_id, timeout, callback, room_ids=None,
- from_token=StreamToken.START):
+ def wait_for_events(
+ self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
+ ):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
@@ -339,7 +339,7 @@ class Notifier(object):
listener = user_stream.new_listener(prev_token)
listener.deferred = timeout_deferred(
listener.deferred,
- (end_time - now) / 1000.,
+ (end_time - now) / 1000.0,
self.hs.get_reactor(),
)
with PreserveLoggingContext():
@@ -368,9 +368,15 @@ class Notifier(object):
defer.returnValue(result)
@defer.inlineCallbacks
- def get_events_for(self, user, pagination_config, timeout,
- only_keys=None,
- is_guest=False, explicit_room_id=None):
+ def get_events_for(
+ self,
+ user,
+ pagination_config,
+ timeout,
+ only_keys=None,
+ is_guest=False,
+ explicit_room_id=None,
+ ):
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.
@@ -419,10 +425,7 @@ class Notifier(object):
if name == "room":
new_events = yield filter_events_for_client(
- self.store,
- user.to_string(),
- new_events,
- is_peeking=is_peeking,
+ self.store, user.to_string(), new_events, is_peeking=is_peeking
)
elif name == "presence":
now = self.clock.time_msec()
@@ -450,7 +453,8 @@ class Notifier(object):
#
# I am sorry for what I have done.
user_id_for_stream = "_PEEKING_%s_%s" % (
- explicit_room_id, user_id_for_stream
+ explicit_room_id,
+ user_id_for_stream,
)
result = yield self.wait_for_events(
@@ -477,9 +481,7 @@ class Notifier(object):
@defer.inlineCallbacks
def _is_world_readable(self, room_id):
state = yield self.state_handler.get_current_state(
- room_id,
- EventTypes.RoomHistoryVisibility,
- "",
+ room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
defer.returnValue(state.content["history_visibility"] == "world_readable")
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index a5de75c48a..1ffd5e2df3 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,6 +40,4 @@ class ActionGenerator(object):
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "action_for_event_by_user"):
- yield self.bulk_evaluator.action_for_event_by_user(
- event, context
- )
+ yield self.bulk_evaluator.action_for_event_by_user(event, context)
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 3523a40108..96d087de22 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -31,48 +31,54 @@ def list_with_base_rules(rawrules):
# Grab the base rules that the user has modified.
# The modified base rules have a priority_class of -1.
- modified_base_rules = {
- r['rule_id']: r for r in rawrules if r['priority_class'] < 0
- }
+ modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0}
# Remove the modified base rules from the list, They'll be added back
# in the default postions in the list.
- rawrules = [r for r in rawrules if r['priority_class'] >= 0]
+ rawrules = [r for r in rawrules if r["priority_class"] >= 0]
# shove the server default rules for each kind onto the end of each
current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1]
- ruleslist.extend(make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
- ))
+ ruleslist.extend(
+ make_base_prepend_rules(
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ )
+ )
for r in rawrules:
- if r['priority_class'] < current_prio_class:
- while r['priority_class'] < current_prio_class:
- ruleslist.extend(make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- ))
- current_prio_class -= 1
- if current_prio_class > 0:
- ruleslist.extend(make_base_prepend_rules(
+ if r["priority_class"] < current_prio_class:
+ while r["priority_class"] < current_prio_class:
+ ruleslist.extend(
+ make_base_append_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
- ))
+ )
+ )
+ current_prio_class -= 1
+ if current_prio_class > 0:
+ ruleslist.extend(
+ make_base_prepend_rules(
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+ modified_base_rules,
+ )
+ )
ruleslist.append(r)
while current_prio_class > 0:
- ruleslist.extend(make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- ))
+ ruleslist.extend(
+ make_base_append_rules(
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ )
+ )
current_prio_class -= 1
if current_prio_class > 0:
- ruleslist.extend(make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- ))
+ ruleslist.extend(
+ make_base_prepend_rules(
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ )
+ )
return ruleslist
@@ -80,20 +86,20 @@ def list_with_base_rules(rawrules):
def make_base_append_rules(kind, modified_base_rules):
rules = []
- if kind == 'override':
+ if kind == "override":
rules = BASE_APPEND_OVERRIDE_RULES
- elif kind == 'underride':
+ elif kind == "underride":
rules = BASE_APPEND_UNDERRIDE_RULES
- elif kind == 'content':
+ elif kind == "content":
rules = BASE_APPEND_CONTENT_RULES
# Copy the rules before modifying them
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
- modified = modified_base_rules.get(r['rule_id'])
+ modified = modified_base_rules.get(r["rule_id"])
if modified:
- r['actions'] = modified['actions']
+ r["actions"] = modified["actions"]
return rules
@@ -101,103 +107,86 @@ def make_base_append_rules(kind, modified_base_rules):
def make_base_prepend_rules(kind, modified_base_rules):
rules = []
- if kind == 'override':
+ if kind == "override":
rules = BASE_PREPEND_OVERRIDE_RULES
# Copy the rules before modifying them
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
- modified = modified_base_rules.get(r['rule_id'])
+ modified = modified_base_rules.get(r["rule_id"])
if modified:
- r['actions'] = modified['actions']
+ r["actions"] = modified["actions"]
return rules
BASE_APPEND_CONTENT_RULES = [
{
- 'rule_id': 'global/content/.m.rule.contains_user_name',
- 'conditions': [
+ "rule_id": "global/content/.m.rule.contains_user_name",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'content.body',
- 'pattern_type': 'user_localpart'
+ "kind": "event_match",
+ "key": "content.body",
+ "pattern_type": "user_localpart",
}
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default',
- }, {
- 'set_tweak': 'highlight'
- }
- ]
- },
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight"},
+ ],
+ }
]
BASE_PREPEND_OVERRIDE_RULES = [
{
- 'rule_id': 'global/override/.m.rule.master',
- 'enabled': False,
- 'conditions': [],
- 'actions': [
- "dont_notify"
- ]
+ "rule_id": "global/override/.m.rule.master",
+ "enabled": False,
+ "conditions": [],
+ "actions": ["dont_notify"],
}
]
BASE_APPEND_OVERRIDE_RULES = [
{
- 'rule_id': 'global/override/.m.rule.suppress_notices',
- 'conditions': [
+ "rule_id": "global/override/.m.rule.suppress_notices",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'content.msgtype',
- 'pattern': 'm.notice',
- '_id': '_suppress_notices',
+ "kind": "event_match",
+ "key": "content.msgtype",
+ "pattern": "m.notice",
+ "_id": "_suppress_notices",
}
],
- 'actions': [
- 'dont_notify',
- ]
+ "actions": ["dont_notify"],
},
# NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event
# otherwise invites will be matched by .m.rule.member_event
{
- 'rule_id': 'global/override/.m.rule.invite_for_me',
- 'conditions': [
+ "rule_id": "global/override/.m.rule.invite_for_me",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.member',
- '_id': '_member',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.member",
+ "_id": "_member",
},
{
- 'kind': 'event_match',
- 'key': 'content.membership',
- 'pattern': 'invite',
- '_id': '_invite_member',
- },
- {
- 'kind': 'event_match',
- 'key': 'state_key',
- 'pattern_type': 'user_id'
+ "kind": "event_match",
+ "key": "content.membership",
+ "pattern": "invite",
+ "_id": "_invite_member",
},
+ {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
+ ],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
},
# Will we sometimes want to know about people joining and leaving?
# Perhaps: if so, this could be expanded upon. Seems the most usual case
@@ -206,217 +195,164 @@ BASE_APPEND_OVERRIDE_RULES = [
# join/leave/avatar/displayname events.
# See also: https://matrix.org/jira/browse/SYN-607
{
- 'rule_id': 'global/override/.m.rule.member_event',
- 'conditions': [
+ "rule_id": "global/override/.m.rule.member_event",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.member',
- '_id': '_member',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.member",
+ "_id": "_member",
}
],
- 'actions': [
- 'dont_notify'
- ]
+ "actions": ["dont_notify"],
},
# This was changed from underride to override so it's closer in priority
# to the content rules where the user name highlight rule lives. This
# way a room rule is lower priority than both but a custom override rule
# is higher priority than both.
{
- 'rule_id': 'global/override/.m.rule.contains_display_name',
- 'conditions': [
- {
- 'kind': 'contains_display_name'
- }
+ "rule_id": "global/override/.m.rule.contains_display_name",
+ "conditions": [{"kind": "contains_display_name"}],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight"},
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight'
- }
- ]
},
{
- 'rule_id': 'global/override/.m.rule.roomnotif',
- 'conditions': [
+ "rule_id": "global/override/.m.rule.roomnotif",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'content.body',
- 'pattern': '@room',
- '_id': '_roomnotif_content',
+ "kind": "event_match",
+ "key": "content.body",
+ "pattern": "@room",
+ "_id": "_roomnotif_content",
},
{
- 'kind': 'sender_notification_permission',
- 'key': 'room',
- '_id': '_roomnotif_pl',
+ "kind": "sender_notification_permission",
+ "key": "room",
+ "_id": "_roomnotif_pl",
},
],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': True,
- }
- ]
+ "actions": ["notify", {"set_tweak": "highlight", "value": True}],
},
{
- 'rule_id': 'global/override/.m.rule.tombstone',
- 'conditions': [
+ "rule_id": "global/override/.m.rule.tombstone",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.tombstone',
- '_id': '_tombstone',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.tombstone",
+ "_id": "_tombstone",
}
],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': True,
- }
- ]
- }
+ "actions": ["notify", {"set_tweak": "highlight", "value": True}],
+ },
]
BASE_APPEND_UNDERRIDE_RULES = [
{
- 'rule_id': 'global/underride/.m.rule.call',
- 'conditions': [
+ "rule_id": "global/underride/.m.rule.call",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.call.invite',
- '_id': '_call',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.call.invite",
+ "_id": "_call",
}
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'ring'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "ring"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
# XXX: once m.direct is standardised everywhere, we should use it to detect
# a DM from the user's perspective rather than this heuristic.
{
- 'rule_id': 'global/underride/.m.rule.room_one_to_one',
- 'conditions': [
+ "rule_id": "global/underride/.m.rule.room_one_to_one",
+ "conditions": [
+ {"kind": "room_member_count", "is": "2", "_id": "member_count"},
{
- 'kind': 'room_member_count',
- 'is': '2',
- '_id': 'member_count',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.message",
+ "_id": "_message",
},
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.message',
- '_id': '_message',
- }
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
{
- 'rule_id': 'global/underride/.m.rule.encrypted_room_one_to_one',
- 'conditions': [
+ "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one",
+ "conditions": [
+ {"kind": "room_member_count", "is": "2", "_id": "member_count"},
{
- 'kind': 'room_member_count',
- 'is': '2',
- '_id': 'member_count',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.encrypted",
+ "_id": "_encrypted",
},
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.encrypted',
- '_id': '_encrypted',
- }
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
{
- 'rule_id': 'global/underride/.m.rule.message',
- 'conditions': [
+ "rule_id": "global/underride/.m.rule.message",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.message',
- '_id': '_message',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.message",
+ "_id": "_message",
}
],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
+ "actions": ["notify", {"set_tweak": "highlight", "value": False}],
},
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
{
- 'rule_id': 'global/underride/.m.rule.encrypted',
- 'conditions': [
+ "rule_id": "global/underride/.m.rule.encrypted",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.encrypted',
- '_id': '_encrypted',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.encrypted",
+ "_id": "_encrypted",
}
],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- }
+ "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ },
]
BASE_RULE_IDS = set()
for r in BASE_APPEND_CONTENT_RULES:
- r['priority_class'] = PRIORITY_CLASS_MAP['content']
- r['default'] = True
- BASE_RULE_IDS.add(r['rule_id'])
+ r["priority_class"] = PRIORITY_CLASS_MAP["content"]
+ r["default"] = True
+ BASE_RULE_IDS.add(r["rule_id"])
for r in BASE_PREPEND_OVERRIDE_RULES:
- r['priority_class'] = PRIORITY_CLASS_MAP['override']
- r['default'] = True
- BASE_RULE_IDS.add(r['rule_id'])
+ r["priority_class"] = PRIORITY_CLASS_MAP["override"]
+ r["default"] = True
+ BASE_RULE_IDS.add(r["rule_id"])
for r in BASE_APPEND_OVERRIDE_RULES:
- r['priority_class'] = PRIORITY_CLASS_MAP['override']
- r['default'] = True
- BASE_RULE_IDS.add(r['rule_id'])
+ r["priority_class"] = PRIORITY_CLASS_MAP["override"]
+ r["default"] = True
+ BASE_RULE_IDS.add(r["rule_id"])
for r in BASE_APPEND_UNDERRIDE_RULES:
- r['priority_class'] = PRIORITY_CLASS_MAP['underride']
- r['default'] = True
- BASE_RULE_IDS.add(r['rule_id'])
+ r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
+ r["default"] = True
+ BASE_RULE_IDS.add(r["rule_id"])
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 8f9a76147f..c8a5b381da 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -39,9 +39,11 @@ rules_by_room = {}
push_rules_invalidation_counter = Counter(
- "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "")
+ "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
+)
push_rules_state_size_counter = Counter(
- "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "")
+ "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", ""
+)
# Measures whether we use the fast path of using state deltas, or if we have to
# recalculate from scratch
@@ -83,7 +85,7 @@ class BulkPushRuleEvaluator(object):
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
- if event.type == 'm.room.member' and event.content['membership'] == 'invite':
+ if event.type == "m.room.member" and event.content["membership"] == "invite":
invited = event.state_key
if invited and self.hs.is_mine_id(invited):
has_pusher = yield self.store.user_has_pusher(invited)
@@ -106,7 +108,9 @@ class BulkPushRuleEvaluator(object):
# before any lookup methods get called on it as otherwise there may be
# a race if invalidate_all gets called (which assumes its in the cache)
return RulesForRoom(
- self.hs, room_id, self._get_rules_for_room.cache,
+ self.hs,
+ room_id,
+ self._get_rules_for_room.cache,
self.room_push_rule_cache_metrics,
)
@@ -121,12 +125,10 @@ class BulkPushRuleEvaluator(object):
auth_events = {POWER_KEY: pl_event}
else:
auth_events_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids, for_verification=False,
+ event, prev_state_ids, for_verification=False
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {
- (e.type, e.state_key): e for e in itervalues(auth_events)
- }
+ auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -145,16 +147,14 @@ class BulkPushRuleEvaluator(object):
rules_by_user = yield self._get_rules_for_event(event, context)
actions_by_user = {}
- room_members = yield self.store.get_joined_users_from_context(
- event, context
- )
+ room_members = yield self.store.get_joined_users_from_context(event, context)
(power_levels, sender_power_level) = (
yield self._get_power_levels_and_sender_level(event, context)
)
evaluator = PushRuleEvaluatorForEvent(
- event, len(room_members), sender_power_level, power_levels,
+ event, len(room_members), sender_power_level, power_levels
)
condition_cache = {}
@@ -180,15 +180,15 @@ class BulkPushRuleEvaluator(object):
display_name = event.content.get("displayname", None)
for rule in rules:
- if 'enabled' in rule and not rule['enabled']:
+ if "enabled" in rule and not rule["enabled"]:
continue
matches = _condition_checker(
- evaluator, rule['conditions'], uid, display_name, condition_cache
+ evaluator, rule["conditions"], uid, display_name, condition_cache
)
if matches:
- actions = [x for x in rule['actions'] if x != 'dont_notify']
- if actions and 'notify' in actions:
+ actions = [x for x in rule["actions"] if x != "dont_notify"]
+ if actions and "notify" in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions
break
@@ -196,9 +196,7 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
- yield self.store.add_push_actions_to_staging(
- event.event_id, actions_by_user,
- )
+ yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -361,19 +359,19 @@ class RulesForRoom(object):
self.sequence,
members={}, # There were no membership changes
rules_by_user=ret_rules_by_user,
- state_group=state_group
+ state_group=state_group,
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
- "Returning push rules for %r %r",
- self.room_id, ret_rules_by_user.keys(),
+ "Returning push rules for %r %r", self.room_id, ret_rules_by_user.keys()
)
defer.returnValue(ret_rules_by_user)
@defer.inlineCallbacks
- def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
- state_group, event):
+ def _update_rules_with_member_event_ids(
+ self, ret_rules_by_user, member_event_ids, state_group, event
+ ):
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
@@ -391,16 +389,13 @@ class RulesForRoom(object):
table="room_memberships",
column="event_id",
iterable=member_event_ids.values(),
- retcols=('user_id', 'membership', 'event_id'),
+ retcols=("user_id", "membership", "event_id"),
keyvalues={},
batch_size=500,
desc="_get_rules_for_member_event_ids",
)
- members = {
- row["event_id"]: (row["user_id"], row["membership"])
- for row in rows
- }
+ members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
# If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it.
@@ -413,15 +408,15 @@ class RulesForRoom(object):
logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = set(
- user_id for user_id, membership in itervalues(members)
+ user_id
+ for user_id, membership in itervalues(members)
if membership == Membership.JOIN
)
logger.debug("Joined: %r", interested_in_user_ids)
if_users_with_pushers = yield self.store.get_if_users_have_pushers(
- interested_in_user_ids,
- on_invalidate=self.invalidate_all_cb,
+ interested_in_user_ids, on_invalidate=self.invalidate_all_cb
)
user_ids = set(
@@ -431,7 +426,7 @@ class RulesForRoom(object):
logger.debug("With pushers: %r", user_ids)
users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
- self.room_id, on_invalidate=self.invalidate_all_cb,
+ self.room_id, on_invalidate=self.invalidate_all_cb
)
logger.debug("With receipts: %r", users_with_receipts)
@@ -442,7 +437,7 @@ class RulesForRoom(object):
user_ids.add(uid)
rules_by_user = yield self.store.bulk_get_push_rules(
- user_ids, on_invalidate=self.invalidate_all_cb,
+ user_ids, on_invalidate=self.invalidate_all_cb
)
ret_rules_by_user.update(
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 8bd96b1178..a59b639f15 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -25,14 +25,14 @@ def format_push_rules_for_user(user, ruleslist):
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist)
- rules = {'global': {}, 'device': {}}
+ rules = {"global": {}, "device": {}}
- rules['global'] = _add_empty_priority_class_arrays(rules['global'])
+ rules["global"] = _add_empty_priority_class_arrays(rules["global"])
for r in ruleslist:
rulearray = None
- template_name = _priority_class_to_template_name(r['priority_class'])
+ template_name = _priority_class_to_template_name(r["priority_class"])
# Remove internal stuff.
for c in r["conditions"]:
@@ -44,14 +44,14 @@ def format_push_rules_for_user(user, ruleslist):
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
- rulearray = rules['global'][template_name]
+ rulearray = rules["global"][template_name]
template_rule = _rule_to_template(r)
if template_rule:
- if 'enabled' in r:
- template_rule['enabled'] = r['enabled']
+ if "enabled" in r:
+ template_rule["enabled"] = r["enabled"]
else:
- template_rule['enabled'] = True
+ template_rule["enabled"] = True
rulearray.append(template_rule)
return rules
@@ -65,33 +65,33 @@ def _add_empty_priority_class_arrays(d):
def _rule_to_template(rule):
unscoped_rule_id = None
- if 'rule_id' in rule:
- unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
+ if "rule_id" in rule:
+ unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
- template_name = _priority_class_to_template_name(rule['priority_class'])
- if template_name in ['override', 'underride']:
+ template_name = _priority_class_to_template_name(rule["priority_class"])
+ if template_name in ["override", "underride"]:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
- templaterule = {'actions': rule['actions']}
- unscoped_rule_id = rule['conditions'][0]['pattern']
- elif template_name == 'content':
+ templaterule = {"actions": rule["actions"]}
+ unscoped_rule_id = rule["conditions"][0]["pattern"]
+ elif template_name == "content":
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
- templaterule = {'actions': rule['actions']}
+ templaterule = {"actions": rule["actions"]}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
- templaterule['rule_id'] = unscoped_rule_id
- if 'default' in rule:
- templaterule['default'] = rule['default']
+ templaterule["rule_id"] = unscoped_rule_id
+ if "default" in rule:
+ templaterule["default"] = rule["default"]
return templaterule
def _rule_id_from_namespaced(in_rule_id):
- return in_rule_id.split('/')[-1]
+ return in_rule_id.split("/")[-1]
def _priority_class_to_template_name(pc):
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index e8ee67401f..424ffa8b68 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -32,13 +32,13 @@ DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000
THROTTLE_START_MS = 10 * 60 * 1000
THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h
# THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours
-THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day
+THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day
# If no event triggers a notification for this long after the previous,
# the throttle is released.
# 12 hours - a gap of 12 hours in conversation is surely enough to merit a new
# notification when things get going again...
-THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000)
+THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
# does each email include all unread notifs, or just the ones which have happened
# since the last mail?
@@ -53,17 +53,18 @@ class EmailPusher(object):
This shares quite a bit of code with httpusher: it would be good to
factor out the common parts
"""
+
def __init__(self, hs, pusherdict, mailer):
self.hs = hs
self.mailer = mailer
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
- self.pusher_id = pusherdict['id']
- self.user_id = pusherdict['user_name']
- self.app_id = pusherdict['app_id']
- self.email = pusherdict['pushkey']
- self.last_stream_ordering = pusherdict['last_stream_ordering']
+ self.pusher_id = pusherdict["id"]
+ self.user_id = pusherdict["user_name"]
+ self.app_id = pusherdict["app_id"]
+ self.email = pusherdict["pushkey"]
+ self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.timed_call = None
self.throttle_params = None
@@ -93,7 +94,9 @@ class EmailPusher(object):
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
if self.max_stream_ordering:
- self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+ self.max_stream_ordering = max(
+ max_stream_ordering, self.max_stream_ordering
+ )
else:
self.max_stream_ordering = max_stream_ordering
self._start_processing()
@@ -114,6 +117,21 @@ class EmailPusher(object):
run_as_background_process("emailpush.process", self._process)
+ def _pause_processing(self):
+ """Used by tests to temporarily pause processing of events.
+
+ Asserts that its not currently processing.
+ """
+ assert not self._is_processing
+ self._is_processing = True
+
+ def _resume_processing(self):
+ """Used by tests to resume processing of events after pausing.
+ """
+ assert self._is_processing
+ self._is_processing = False
+ self._start_processing()
+
@defer.inlineCallbacks
def _process(self):
# we should never get here if we are already processing
@@ -159,14 +177,12 @@ class EmailPusher(object):
return
for push_action in unprocessed:
- received_at = push_action['received_ts']
+ received_at = push_action["received_ts"]
if received_at is None:
received_at = 0
notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS
- room_ready_at = self.room_ready_to_notify_at(
- push_action['room_id']
- )
+ room_ready_at = self.room_ready_to_notify_at(push_action["room_id"])
should_notify_at = max(notif_ready_at, room_ready_at)
@@ -177,25 +193,23 @@ class EmailPusher(object):
# to be delivered.
reason = {
- 'room_id': push_action['room_id'],
- 'now': self.clock.time_msec(),
- 'received_at': received_at,
- 'delay_before_mail_ms': DELAY_BEFORE_MAIL_MS,
- 'last_sent_ts': self.get_room_last_sent_ts(push_action['room_id']),
- 'throttle_ms': self.get_room_throttle_ms(push_action['room_id']),
+ "room_id": push_action["room_id"],
+ "now": self.clock.time_msec(),
+ "received_at": received_at,
+ "delay_before_mail_ms": DELAY_BEFORE_MAIL_MS,
+ "last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]),
+ "throttle_ms": self.get_room_throttle_ms(push_action["room_id"]),
}
yield self.send_notification(unprocessed, reason)
- yield self.save_last_stream_ordering_and_success(max([
- ea['stream_ordering'] for ea in unprocessed
- ]))
+ yield self.save_last_stream_ordering_and_success(
+ max([ea["stream_ordering"] for ea in unprocessed])
+ )
# we update the throttle on all the possible unprocessed push actions
for ea in unprocessed:
- yield self.sent_notif_update_throttle(
- ea['room_id'], ea
- )
+ yield self.sent_notif_update_throttle(ea["room_id"], ea)
break
else:
if soonest_due_at is None or should_notify_at < soonest_due_at:
@@ -215,10 +229,17 @@ class EmailPusher(object):
@defer.inlineCallbacks
def save_last_stream_ordering_and_success(self, last_stream_ordering):
+ if last_stream_ordering is None:
+ # This happens if we haven't yet processed anything
+ return
+
self.last_stream_ordering = last_stream_ordering
yield self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id, self.email, self.user_id,
- last_stream_ordering, self.clock.time_msec()
+ self.app_id,
+ self.email,
+ self.user_id,
+ last_stream_ordering,
+ self.clock.time_msec(),
)
def seconds_until(self, ts_msec):
@@ -257,10 +278,10 @@ class EmailPusher(object):
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
# notif, we release the throttle. Otherwise, the throttle is increased.
time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before(
- notified_push_action['stream_ordering']
+ notified_push_action["stream_ordering"]
)
- time_of_this_notifs = notified_push_action['received_ts']
+ time_of_this_notifs = notified_push_action["received_ts"]
if time_of_previous_notifs is not None and time_of_this_notifs is not None:
gap = time_of_this_notifs - time_of_previous_notifs
@@ -279,12 +300,11 @@ class EmailPusher(object):
new_throttle_ms = THROTTLE_START_MS
else:
new_throttle_ms = min(
- current_throttle_ms * THROTTLE_MULTIPLIER,
- THROTTLE_MAX_MS
+ current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
)
self.throttle_params[room_id] = {
"last_sent_ts": self.clock.time_msec(),
- "throttle_ms": new_throttle_ms
+ "throttle_ms": new_throttle_ms,
}
yield self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index fac05aa44c..4e7b6a5531 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -65,16 +65,16 @@ class HttpPusher(object):
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.state_handler = self.hs.get_state_handler()
- self.user_id = pusherdict['user_name']
- self.app_id = pusherdict['app_id']
- self.app_display_name = pusherdict['app_display_name']
- self.device_display_name = pusherdict['device_display_name']
- self.pushkey = pusherdict['pushkey']
- self.pushkey_ts = pusherdict['ts']
- self.data = pusherdict['data']
- self.last_stream_ordering = pusherdict['last_stream_ordering']
+ self.user_id = pusherdict["user_name"]
+ self.app_id = pusherdict["app_id"]
+ self.app_display_name = pusherdict["app_display_name"]
+ self.device_display_name = pusherdict["device_display_name"]
+ self.pushkey = pusherdict["pushkey"]
+ self.pushkey_ts = pusherdict["ts"]
+ self.data = pusherdict["data"]
+ self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.failing_since = pusherdict['failing_since']
+ self.failing_since = pusherdict["failing_since"]
self.timed_call = None
self._is_processing = False
@@ -85,32 +85,26 @@ class HttpPusher(object):
# off as None though as we don't know any better.
self.max_stream_ordering = None
- if 'data' not in pusherdict:
- raise PusherConfigException(
- "No 'data' key for HTTP pusher"
- )
- self.data = pusherdict['data']
+ if "data" not in pusherdict:
+ raise PusherConfigException("No 'data' key for HTTP pusher")
+ self.data = pusherdict["data"]
self.name = "%s/%s/%s" % (
- pusherdict['user_name'],
- pusherdict['app_id'],
- pusherdict['pushkey'],
+ pusherdict["user_name"],
+ pusherdict["app_id"],
+ pusherdict["pushkey"],
)
if self.data is None:
- raise PusherConfigException(
- "data can not be null for HTTP pusher"
- )
+ raise PusherConfigException("data can not be null for HTTP pusher")
- if 'url' not in self.data:
- raise PusherConfigException(
- "'url' required in data for HTTP pusher"
- )
- self.url = self.data['url']
+ if "url" not in self.data:
+ raise PusherConfigException("'url' required in data for HTTP pusher")
+ self.url = self.data["url"]
self.http_client = hs.get_simple_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
- del self.data_minus_url['url']
+ del self.data_minus_url["url"]
def on_started(self, should_check_for_notifs):
"""Called when this pusher has been started.
@@ -124,7 +118,9 @@ class HttpPusher(object):
self._start_processing()
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
- self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0)
+ self.max_stream_ordering = max(
+ max_stream_ordering, self.max_stream_ordering or 0
+ )
self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id):
@@ -192,7 +188,9 @@ class HttpPusher(object):
logger.info(
"Processing %i unprocessed push actions for %s starting at "
"stream_ordering %s",
- len(unprocessed), self.name, self.last_stream_ordering,
+ len(unprocessed),
+ self.name,
+ self.last_stream_ordering,
)
for push_action in unprocessed:
@@ -200,71 +198,72 @@ class HttpPusher(object):
if processed:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.last_stream_ordering = push_action['stream_ordering']
+ self.last_stream_ordering = push_action["stream_ordering"]
yield self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id, self.pushkey, self.user_id,
+ self.app_id,
+ self.pushkey,
+ self.user_id,
self.last_stream_ordering,
- self.clock.time_msec()
+ self.clock.time_msec(),
)
if self.failing_since:
self.failing_since = None
yield self.store.update_pusher_failing_since(
- self.app_id, self.pushkey, self.user_id,
- self.failing_since
+ self.app_id, self.pushkey, self.user_id, self.failing_since
)
else:
http_push_failed_counter.inc()
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
- self.app_id, self.pushkey, self.user_id,
- self.failing_since
+ self.app_id, self.pushkey, self.user_id, self.failing_since
)
if (
- self.failing_since and
- self.failing_since <
- self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS
+ self.failing_since
+ and self.failing_since
+ < self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS
):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
- logger.warn("Giving up on a notification to user %s, "
- "pushkey %s",
- self.user_id, self.pushkey)
+ logger.warn(
+ "Giving up on a notification to user %s, " "pushkey %s",
+ self.user_id,
+ self.pushkey,
+ )
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.last_stream_ordering = push_action['stream_ordering']
+ self.last_stream_ordering = push_action["stream_ordering"]
yield self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
self.user_id,
- self.last_stream_ordering
+ self.last_stream_ordering,
)
self.failing_since = None
yield self.store.update_pusher_failing_since(
- self.app_id,
- self.pushkey,
- self.user_id,
- self.failing_since
+ self.app_id, self.pushkey, self.user_id, self.failing_since
)
else:
logger.info("Push failed: delaying for %ds", self.backoff_delay)
self.timed_call = self.hs.get_reactor().callLater(
self.backoff_delay, self.on_timer
)
- self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC)
+ self.backoff_delay = min(
+ self.backoff_delay * 2, self.MAX_BACKOFF_SEC
+ )
break
@defer.inlineCallbacks
def _process_one(self, push_action):
- if 'notify' not in push_action['actions']:
+ if "notify" not in push_action["actions"]:
defer.returnValue(True)
- tweaks = push_rule_evaluator.tweaks_for_actions(push_action['actions'])
+ tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
- event = yield self.store.get_event(push_action['event_id'], allow_none=True)
+ event = yield self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:
defer.returnValue(True) # It's been redacted
rejected = yield self.dispatch_push(event, tweaks, badge)
@@ -277,37 +276,30 @@ class HttpPusher(object):
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
- ("Ignoring rejected pushkey %s because we"
- " didn't send it"), pk
+ ("Ignoring rejected pushkey %s because we" " didn't send it"),
+ pk,
)
else:
- logger.info(
- "Pushkey %s was rejected: removing",
- pk
- )
- yield self.hs.remove_pusher(
- self.app_id, pk, self.user_id
- )
+ logger.info("Pushkey %s was rejected: removing", pk)
+ yield self.hs.remove_pusher(self.app_id, pk, self.user_id)
defer.returnValue(True)
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
- if self.data.get('format') == 'event_id_only':
+ if self.data.get("format") == "event_id_only":
d = {
- 'notification': {
- 'event_id': event.event_id,
- 'room_id': event.room_id,
- 'counts': {
- 'unread': badge,
- },
- 'devices': [
+ "notification": {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "counts": {"unread": badge},
+ "devices": [
{
- 'app_id': self.app_id,
- 'pushkey': self.pushkey,
- 'pushkey_ts': long(self.pushkey_ts / 1000),
- 'data': self.data_minus_url,
+ "app_id": self.app_id,
+ "pushkey": self.pushkey,
+ "pushkey_ts": long(self.pushkey_ts / 1000),
+ "data": self.data_minus_url,
}
- ]
+ ],
}
}
defer.returnValue(d)
@@ -317,41 +309,41 @@ class HttpPusher(object):
)
d = {
- 'notification': {
- 'id': event.event_id, # deprecated: remove soon
- 'event_id': event.event_id,
- 'room_id': event.room_id,
- 'type': event.type,
- 'sender': event.user_id,
- 'counts': { # -- we don't mark messages as read yet so
- # we have no way of knowing
+ "notification": {
+ "id": event.event_id, # deprecated: remove soon
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "sender": event.user_id,
+ "counts": { # -- we don't mark messages as read yet so
+ # we have no way of knowing
# Just set the badge to 1 until we have read receipts
- 'unread': badge,
+ "unread": badge,
# 'missed_calls': 2
},
- 'devices': [
+ "devices": [
{
- 'app_id': self.app_id,
- 'pushkey': self.pushkey,
- 'pushkey_ts': long(self.pushkey_ts / 1000),
- 'data': self.data_minus_url,
- 'tweaks': tweaks
+ "app_id": self.app_id,
+ "pushkey": self.pushkey,
+ "pushkey_ts": long(self.pushkey_ts / 1000),
+ "data": self.data_minus_url,
+ "tweaks": tweaks,
}
- ]
+ ],
}
}
- if event.type == 'm.room.member' and event.is_state():
- d['notification']['membership'] = event.content['membership']
- d['notification']['user_is_target'] = event.state_key == self.user_id
+ if event.type == "m.room.member" and event.is_state():
+ d["notification"]["membership"] = event.content["membership"]
+ d["notification"]["user_is_target"] = event.state_key == self.user_id
if self.hs.config.push_include_content and event.content:
- d['notification']['content'] = event.content
+ d["notification"]["content"] = event.content
# We no longer send aliases separately, instead, we send the human
# readable name of the room, which may be an alias.
- if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
- d['notification']['sender_display_name'] = ctx['sender_display_name']
- if 'name' in ctx and len(ctx['name']) > 0:
- d['notification']['room_name'] = ctx['name']
+ if "sender_display_name" in ctx and len(ctx["sender_display_name"]) > 0:
+ d["notification"]["sender_display_name"] = ctx["sender_display_name"]
+ if "name" in ctx and len(ctx["name"]) > 0:
+ d["notification"]["room_name"] = ctx["name"]
defer.returnValue(d)
@@ -361,16 +353,21 @@ class HttpPusher(object):
if not notification_dict:
defer.returnValue([])
try:
- resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
+ resp = yield self.http_client.post_json_get_json(
+ self.url, notification_dict
+ )
except Exception as e:
logger.warning(
"Failed to push event %s to %s: %s %s",
- event.event_id, self.name, type(e), e,
+ event.event_id,
+ self.name,
+ type(e),
+ e,
)
defer.returnValue(False)
rejected = []
- if 'rejected' in resp:
- rejected = resp['rejected']
+ if "rejected" in resp:
+ rejected = resp["rejected"]
defer.returnValue(rejected)
@defer.inlineCallbacks
@@ -381,21 +378,19 @@ class HttpPusher(object):
"""
logger.info("Sending updated badge count %d to %s", badge, self.name)
d = {
- 'notification': {
- 'id': '',
- 'type': None,
- 'sender': '',
- 'counts': {
- 'unread': badge
- },
- 'devices': [
+ "notification": {
+ "id": "",
+ "type": None,
+ "sender": "",
+ "counts": {"unread": badge},
+ "devices": [
{
- 'app_id': self.app_id,
- 'pushkey': self.pushkey,
- 'pushkey_ts': long(self.pushkey_ts / 1000),
- 'data': self.data_minus_url,
+ "app_id": self.app_id,
+ "pushkey": self.pushkey,
+ "pushkey_ts": long(self.pushkey_ts / 1000),
+ "data": self.data_minus_url,
}
- ]
+ ],
}
}
try:
@@ -403,7 +398,6 @@ class HttpPusher(object):
http_badges_processed_counter.inc()
except Exception as e:
logger.warning(
- "Failed to send badge count to %s: %s %s",
- self.name, type(e), e,
+ "Failed to send badge count to %s: %s %s", self.name, type(e), e
)
http_badges_failed_counter.inc()
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 4bc9eb7313..809199fe88 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -42,17 +42,21 @@ from synapse.visibility import filter_events_for_client
logger = logging.getLogger(__name__)
-MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
- "in the %(room)s room..."
+MESSAGE_FROM_PERSON_IN_ROOM = (
+ "You have a message on %(app)s from %(person)s " "in the %(room)s room..."
+)
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
-MESSAGES_IN_ROOM_AND_OTHERS = \
+MESSAGES_IN_ROOM_AND_OTHERS = (
"You have messages on %(app)s in the %(room)s room and others..."
-MESSAGES_FROM_PERSON_AND_OTHERS = \
+)
+MESSAGES_FROM_PERSON_AND_OTHERS = (
"You have messages on %(app)s from %(person)s and others..."
-INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \
- "%(room)s room on %(app)s..."
+)
+INVITE_FROM_PERSON_TO_ROOM = (
+ "%(person)s has invited you to join the " "%(room)s room on %(app)s..."
+)
INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..."
CONTEXT_BEFORE = 1
@@ -60,12 +64,38 @@ CONTEXT_AFTER = 1
# From https://github.com/matrix-org/matrix-react-sdk/blob/master/src/HtmlUtils.js
ALLOWED_TAGS = [
- 'font', # custom to matrix for IRC-style font coloring
- 'del', # for markdown
+ "font", # custom to matrix for IRC-style font coloring
+ "del", # for markdown
# deliberately no h1/h2 to stop people shouting.
- 'h3', 'h4', 'h5', 'h6', 'blockquote', 'p', 'a', 'ul', 'ol',
- 'nl', 'li', 'b', 'i', 'u', 'strong', 'em', 'strike', 'code', 'hr', 'br', 'div',
- 'table', 'thead', 'caption', 'tbody', 'tr', 'th', 'td', 'pre'
+ "h3",
+ "h4",
+ "h5",
+ "h6",
+ "blockquote",
+ "p",
+ "a",
+ "ul",
+ "ol",
+ "nl",
+ "li",
+ "b",
+ "i",
+ "u",
+ "strong",
+ "em",
+ "strike",
+ "code",
+ "hr",
+ "br",
+ "div",
+ "table",
+ "thead",
+ "caption",
+ "tbody",
+ "tr",
+ "th",
+ "td",
+ "pre",
]
ALLOWED_ATTRS = {
# custom ones first:
@@ -94,13 +124,7 @@ class Mailer(object):
logger.info("Created Mailer for app_name %s" % app_name)
@defer.inlineCallbacks
- def send_password_reset_mail(
- self,
- email_address,
- token,
- client_secret,
- sid,
- ):
+ def send_password_reset_mail(self, email_address, token, client_secret, sid):
"""Send an email with a password reset link to a user
Args:
@@ -112,19 +136,16 @@ class Mailer(object):
group together multiple email sending attempts
sid (str): The generated session ID
"""
- if email.utils.parseaddr(email_address)[1] == '':
+ if email.utils.parseaddr(email_address)[1] == "":
raise RuntimeError("Invalid 'to' email address")
link = (
- self.hs.config.public_baseurl +
- "_synapse/password_reset/email/submit_token"
- "?token=%s&client_secret=%s&sid=%s" %
- (token, client_secret, sid)
+ self.hs.config.public_baseurl
+ + "_matrix/client/unstable/password_reset/email/submit_token"
+ "?token=%s&client_secret=%s&sid=%s" % (token, client_secret, sid)
)
- template_vars = {
- "link": link,
- }
+ template_vars = {"link": link}
yield self.send_email(
email_address,
@@ -133,15 +154,14 @@ class Mailer(object):
)
@defer.inlineCallbacks
- def send_notification_mail(self, app_id, user_id, email_address,
- push_actions, reason):
+ def send_notification_mail(
+ self, app_id, user_id, email_address, push_actions, reason
+ ):
"""Send email regarding a user's room notifications"""
- rooms_in_order = deduped_ordered_list(
- [pa['room_id'] for pa in push_actions]
- )
+ rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
notif_events = yield self.store.get_events(
- [pa['event_id'] for pa in push_actions]
+ [pa["event_id"] for pa in push_actions]
)
notifs_by_room = {}
@@ -171,9 +191,7 @@ class Mailer(object):
yield concurrently_execute(_fetch_room_state, rooms_in_order, 3)
# actually sort our so-called rooms_in_order list, most recent room first
- rooms_in_order.sort(
- key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0)
- )
+ rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
rooms = []
@@ -183,9 +201,11 @@ class Mailer(object):
)
rooms.append(roomvars)
- reason['room_name'] = yield calculate_room_name(
- self.store, state_by_room[reason['room_id']], user_id,
- fallback_to_members=True
+ reason["room_name"] = yield calculate_room_name(
+ self.store,
+ state_by_room[reason["room_id"]],
+ user_id,
+ fallback_to_members=True,
)
summary_text = yield self.make_summary_text(
@@ -204,25 +224,21 @@ class Mailer(object):
}
yield self.send_email(
- email_address,
- "[%s] %s" % (self.app_name, summary_text),
- template_vars,
+ email_address, "[%s] %s" % (self.app_name, summary_text), template_vars
)
@defer.inlineCallbacks
def send_email(self, email_address, subject, template_vars):
"""Send an email with the given information and template text"""
try:
- from_string = self.hs.config.email_notif_from % {
- "app": self.app_name
- }
+ from_string = self.hs.config.email_notif_from % {"app": self.app_name}
except TypeError:
from_string = self.hs.config.email_notif_from
raw_from = email.utils.parseaddr(from_string)[1]
raw_to = email.utils.parseaddr(email_address)[1]
- if raw_to == '':
+ if raw_to == "":
raise RuntimeError("Invalid 'to' address")
html_text = self.template_html.render(**template_vars)
@@ -231,27 +247,31 @@ class Mailer(object):
plain_text = self.template_text.render(**template_vars)
text_part = MIMEText(plain_text, "plain", "utf8")
- multipart_msg = MIMEMultipart('alternative')
- multipart_msg['Subject'] = subject
- multipart_msg['From'] = from_string
- multipart_msg['To'] = email_address
- multipart_msg['Date'] = email.utils.formatdate()
- multipart_msg['Message-ID'] = email.utils.make_msgid()
+ multipart_msg = MIMEMultipart("alternative")
+ multipart_msg["Subject"] = subject
+ multipart_msg["From"] = from_string
+ multipart_msg["To"] = email_address
+ multipart_msg["Date"] = email.utils.formatdate()
+ multipart_msg["Message-ID"] = email.utils.make_msgid()
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)
- logger.info("Sending email push notification to %s" % email_address)
-
- yield make_deferred_yieldable(self.sendmail(
- self.hs.config.email_smtp_host,
- raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
- reactor=self.hs.get_reactor(),
- port=self.hs.config.email_smtp_port,
- requireAuthentication=self.hs.config.email_smtp_user is not None,
- username=self.hs.config.email_smtp_user,
- password=self.hs.config.email_smtp_pass,
- requireTransportSecurity=self.hs.config.require_transport_security
- ))
+ logger.info("Sending email notification to %s" % email_address)
+
+ yield make_deferred_yieldable(
+ self.sendmail(
+ self.hs.config.email_smtp_host,
+ raw_from,
+ raw_to,
+ multipart_msg.as_string().encode("utf8"),
+ reactor=self.hs.get_reactor(),
+ port=self.hs.config.email_smtp_port,
+ requireAuthentication=self.hs.config.email_smtp_user is not None,
+ username=self.hs.config.email_smtp_user,
+ password=self.hs.config.email_smtp_pass,
+ requireTransportSecurity=self.hs.config.require_transport_security,
+ )
+ )
@defer.inlineCallbacks
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
@@ -272,17 +292,18 @@ class Mailer(object):
if not is_invite:
for n in notifs:
notifvars = yield self.get_notif_vars(
- n, user_id, notif_events[n['event_id']], room_state_ids
+ n, user_id, notif_events[n["event_id"]], room_state_ids
)
# merge overlapping notifs together.
# relies on the notifs being in chronological order.
merge = False
- if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]:
- prev_messages = room_vars['notifs'][-1]['messages']
- for message in notifvars['messages']:
- pm = list(filter(lambda pm: pm['id'] == message['id'],
- prev_messages))
+ if room_vars["notifs"] and "messages" in room_vars["notifs"][-1]:
+ prev_messages = room_vars["notifs"][-1]["messages"]
+ for message in notifvars["messages"]:
+ pm = list(
+ filter(lambda pm: pm["id"] == message["id"], prev_messages)
+ )
if pm:
if not message["is_historical"]:
pm[0]["is_historical"] = False
@@ -293,20 +314,22 @@ class Mailer(object):
prev_messages.append(message)
if not merge:
- room_vars['notifs'].append(notifvars)
+ room_vars["notifs"].append(notifvars)
defer.returnValue(room_vars)
@defer.inlineCallbacks
def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
results = yield self.store.get_events_around(
- notif['room_id'], notif['event_id'],
- before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
+ notif["room_id"],
+ notif["event_id"],
+ before_limit=CONTEXT_BEFORE,
+ after_limit=CONTEXT_AFTER,
)
ret = {
"link": self.make_notif_link(notif),
- "ts": notif['received_ts'],
+ "ts": notif["received_ts"],
"messages": [],
}
@@ -318,7 +341,7 @@ class Mailer(object):
for event in the_events:
messagevars = yield self.get_message_vars(notif, event, room_state_ids)
if messagevars is not None:
- ret['messages'].append(messagevars)
+ ret["messages"].append(messagevars)
defer.returnValue(ret)
@@ -340,7 +363,7 @@ class Mailer(object):
ret = {
"msgtype": msgtype,
- "is_historical": event.event_id != notif['event_id'],
+ "is_historical": event.event_id != notif["event_id"],
"id": event.event_id,
"ts": event.origin_server_ts,
"sender_name": sender_name,
@@ -379,8 +402,9 @@ class Mailer(object):
return messagevars
@defer.inlineCallbacks
- def make_summary_text(self, notifs_by_room, room_state_ids,
- notif_events, user_id, reason):
+ def make_summary_text(
+ self, notifs_by_room, room_state_ids, notif_events, user_id, reason
+ ):
if len(notifs_by_room) == 1:
# Only one room has new stuff
room_id = list(notifs_by_room.keys())[0]
@@ -404,16 +428,19 @@ class Mailer(object):
inviter_name = name_from_member_event(inviter_member_event)
if room_name is None:
- defer.returnValue(INVITE_FROM_PERSON % {
- "person": inviter_name,
- "app": self.app_name
- })
+ defer.returnValue(
+ INVITE_FROM_PERSON
+ % {"person": inviter_name, "app": self.app_name}
+ )
else:
- defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % {
- "person": inviter_name,
- "room": room_name,
- "app": self.app_name,
- })
+ defer.returnValue(
+ INVITE_FROM_PERSON_TO_ROOM
+ % {
+ "person": inviter_name,
+ "room": room_name,
+ "app": self.app_name,
+ }
+ )
sender_name = None
if len(notifs_by_room[room_id]) == 1:
@@ -427,67 +454,86 @@ class Mailer(object):
sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None:
- defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % {
- "person": sender_name,
- "room": room_name,
- "app": self.app_name,
- })
+ defer.returnValue(
+ MESSAGE_FROM_PERSON_IN_ROOM
+ % {
+ "person": sender_name,
+ "room": room_name,
+ "app": self.app_name,
+ }
+ )
elif sender_name is not None:
- defer.returnValue(MESSAGE_FROM_PERSON % {
- "person": sender_name,
- "app": self.app_name,
- })
+ defer.returnValue(
+ MESSAGE_FROM_PERSON
+ % {"person": sender_name, "app": self.app_name}
+ )
else:
# There's more than one notification for this room, so just
# say there are several
if room_name is not None:
- defer.returnValue(MESSAGES_IN_ROOM % {
- "room": room_name,
- "app": self.app_name,
- })
+ defer.returnValue(
+ MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name}
+ )
else:
# If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
- sender_ids = list(set([
- notif_events[n['event_id']].sender
- for n in notifs_by_room[room_id]
- ]))
-
- member_events = yield self.store.get_events([
- room_state_ids[room_id][("m.room.member", s)]
- for s in sender_ids
- ])
-
- defer.returnValue(MESSAGES_FROM_PERSON % {
- "person": descriptor_from_member_events(member_events.values()),
- "app": self.app_name,
- })
+ sender_ids = list(
+ set(
+ [
+ notif_events[n["event_id"]].sender
+ for n in notifs_by_room[room_id]
+ ]
+ )
+ )
+
+ member_events = yield self.store.get_events(
+ [
+ room_state_ids[room_id][("m.room.member", s)]
+ for s in sender_ids
+ ]
+ )
+
+ defer.returnValue(
+ MESSAGES_FROM_PERSON
+ % {
+ "person": descriptor_from_member_events(
+ member_events.values()
+ ),
+ "app": self.app_name,
+ }
+ )
else:
# Stuff's happened in multiple different rooms
# ...but we still refer to the 'reason' room which triggered the mail
- if reason['room_name'] is not None:
- defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % {
- "room": reason['room_name'],
- "app": self.app_name,
- })
+ if reason["room_name"] is not None:
+ defer.returnValue(
+ MESSAGES_IN_ROOM_AND_OTHERS
+ % {"room": reason["room_name"], "app": self.app_name}
+ )
else:
# If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
- sender_ids = list(set([
- notif_events[n['event_id']].sender
- for n in notifs_by_room[reason['room_id']]
- ]))
+ sender_ids = list(
+ set(
+ [
+ notif_events[n["event_id"]].sender
+ for n in notifs_by_room[reason["room_id"]]
+ ]
+ )
+ )
- member_events = yield self.store.get_events([
- room_state_ids[room_id][("m.room.member", s)]
- for s in sender_ids
- ])
+ member_events = yield self.store.get_events(
+ [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
+ )
- defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
- "person": descriptor_from_member_events(member_events.values()),
- "app": self.app_name,
- })
+ defer.returnValue(
+ MESSAGES_FROM_PERSON_AND_OTHERS
+ % {
+ "person": descriptor_from_member_events(member_events.values()),
+ "app": self.app_name,
+ }
+ )
def make_room_link(self, room_id):
if self.hs.config.email_riot_base_url:
@@ -503,17 +549,17 @@ class Mailer(object):
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
- notif['room_id'], notif['event_id']
+ notif["room_id"],
+ notif["event_id"],
)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
return "https://vector.im/beta/#/room/%s/%s" % (
- notif['room_id'], notif['event_id']
+ notif["room_id"],
+ notif["event_id"],
)
else:
- return "https://matrix.to/#/%s/%s" % (
- notif['room_id'], notif['event_id']
- )
+ return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
def make_unsubscribe_link(self, user_id, app_id, email_address):
params = {
@@ -530,12 +576,18 @@ class Mailer(object):
def safe_markup(raw_html):
- return jinja2.Markup(bleach.linkify(bleach.clean(
- raw_html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRS,
- # bleach master has this, but it isn't released yet
- # protocols=ALLOWED_SCHEMES,
- strip=True
- )))
+ return jinja2.Markup(
+ bleach.linkify(
+ bleach.clean(
+ raw_html,
+ tags=ALLOWED_TAGS,
+ attributes=ALLOWED_ATTRS,
+ # bleach master has this, but it isn't released yet
+ # protocols=ALLOWED_SCHEMES,
+ strip=True,
+ )
+ )
+ )
def safe_text(raw_text):
@@ -543,10 +595,9 @@ def safe_text(raw_text):
Process text: treat it as HTML but escape any tags (ie. just escape the
HTML) then linkify it.
"""
- return jinja2.Markup(bleach.linkify(bleach.clean(
- raw_text, tags=[], attributes={},
- strip=False
- )))
+ return jinja2.Markup(
+ bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
+ )
def deduped_ordered_list(l):
@@ -595,15 +646,11 @@ def _create_mxc_to_http_filter(config):
serverAndMediaId = value[6:]
fragment = None
- if '#' in serverAndMediaId:
- (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
+ if "#" in serverAndMediaId:
+ (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1)
fragment = "#" + fragment
- params = {
- "width": width,
- "height": height,
- "method": resize_method,
- }
+ params = {"width": width, "height": height, "method": resize_method}
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
config.public_baseurl,
serverAndMediaId,
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index eef6e18c2e..06056fbf4f 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -28,8 +28,13 @@ ALL_ALONE = "Empty Room"
@defer.inlineCallbacks
-def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True,
- fallback_to_single_member=True):
+def calculate_room_name(
+ store,
+ room_state_ids,
+ user_id,
+ fallback_to_members=True,
+ fallback_to_single_member=True,
+):
"""
Works out a user-facing name for the given room as per Matrix
spec recommendations.
@@ -58,8 +63,10 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
room_state_ids[("m.room.canonical_alias", "")], allow_none=True
)
if (
- canon_alias and canon_alias.content and canon_alias.content["alias"] and
- _looks_like_an_alias(canon_alias.content["alias"])
+ canon_alias
+ and canon_alias.content
+ and canon_alias.content["alias"]
+ and _looks_like_an_alias(canon_alias.content["alias"])
):
defer.returnValue(canon_alias.content["alias"])
@@ -71,9 +78,7 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
if "m.room.aliases" in room_state_bytype_ids:
m_room_aliases = room_state_bytype_ids["m.room.aliases"]
for alias_id in m_room_aliases.values():
- alias_event = yield store.get_event(
- alias_id, allow_none=True
- )
+ alias_event = yield store.get_event(alias_id, allow_none=True)
if alias_event and alias_event.content.get("aliases"):
the_aliases = alias_event.content["aliases"]
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
@@ -89,8 +94,8 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
)
if (
- my_member_event is not None and
- my_member_event.content['membership'] == "invite"
+ my_member_event is not None
+ and my_member_event.content["membership"] == "invite"
):
if ("m.room.member", my_member_event.sender) in room_state_ids:
inviter_member_event = yield store.get_event(
@@ -100,9 +105,8 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
if inviter_member_event:
if fallback_to_single_member:
defer.returnValue(
- "Invite from %s" % (
- name_from_member_event(inviter_member_event),
- )
+ "Invite from %s"
+ % (name_from_member_event(inviter_member_event),)
)
else:
return
@@ -116,8 +120,10 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
list(room_state_bytype_ids["m.room.member"].values())
)
all_members = [
- ev for ev in member_events.values()
- if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
+ ev
+ for ev in member_events.values()
+ if ev.content["membership"] == "join"
+ or ev.content["membership"] == "invite"
]
# Sort the member events oldest-first so the we name people in the
# order the joined (it should at least be deterministic rather than
@@ -134,9 +140,9 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
# or inbound invite, or outbound 3PID invite.
if all_members[0].sender == user_id:
if "m.room.third_party_invite" in room_state_bytype_ids:
- third_party_invites = (
- room_state_bytype_ids["m.room.third_party_invite"].values()
- )
+ third_party_invites = room_state_bytype_ids[
+ "m.room.third_party_invite"
+ ].values()
if len(third_party_invites) > 0:
# technically third party invite events are not member
@@ -162,6 +168,17 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
def descriptor_from_member_events(member_events):
+ """Get a description of the room based on the member events.
+
+ Args:
+ member_events (Iterable[FrozenEvent])
+
+ Returns:
+ str
+ """
+
+ member_events = list(member_events)
+
if len(member_events) == 0:
return "nobody"
elif len(member_events) == 1:
@@ -180,8 +197,9 @@ def descriptor_from_member_events(member_events):
def name_from_member_event(member_event):
if (
- member_event.content and "displayname" in member_event.content and
- member_event.content["displayname"]
+ member_event.content
+ and "displayname" in member_event.content
+ and member_event.content["displayname"]
):
return member_event.content["displayname"]
return member_event.state_key
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index cf6c8b875e..5ed9147de4 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -26,8 +26,8 @@ from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
-GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]')
-IS_GLOB = re.compile(r'[\?\*\[\]]')
+GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]")
+IS_GLOB = re.compile(r"[\?\*\[\]]")
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@@ -36,20 +36,20 @@ def _room_member_count(ev, condition, room_member_count):
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
- notif_level_key = condition.get('key')
+ notif_level_key = condition.get("key")
if notif_level_key is None:
return False
- notif_levels = power_levels.get('notifications', {})
+ notif_levels = power_levels.get("notifications", {})
room_notif_level = notif_levels.get(notif_level_key, 50)
return sender_power_level >= room_notif_level
def _test_ineq_condition(condition, number):
- if 'is' not in condition:
+ if "is" not in condition:
return False
- m = INEQUALITY_EXPR.match(condition['is'])
+ m = INEQUALITY_EXPR.match(condition["is"])
if not m:
return False
ineq = m.group(1)
@@ -58,15 +58,15 @@ def _test_ineq_condition(condition, number):
return False
rhs = int(rhs)
- if ineq == '' or ineq == '==':
+ if ineq == "" or ineq == "==":
return number == rhs
- elif ineq == '<':
+ elif ineq == "<":
return number < rhs
- elif ineq == '>':
+ elif ineq == ">":
return number > rhs
- elif ineq == '>=':
+ elif ineq == ">=":
return number >= rhs
- elif ineq == '<=':
+ elif ineq == "<=":
return number <= rhs
else:
return False
@@ -77,8 +77,8 @@ def tweaks_for_actions(actions):
for a in actions:
if not isinstance(a, dict):
continue
- if 'set_tweak' in a and 'value' in a:
- tweaks[a['set_tweak']] = a['value']
+ if "set_tweak" in a and "value" in a:
+ tweaks[a["set_tweak"]] = a["value"]
return tweaks
@@ -93,26 +93,24 @@ class PushRuleEvaluatorForEvent(object):
self._value_cache = _flatten_dict(event)
def matches(self, condition, user_id, display_name):
- if condition['kind'] == 'event_match':
+ if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
- elif condition['kind'] == 'contains_display_name':
+ elif condition["kind"] == "contains_display_name":
return self._contains_display_name(display_name)
- elif condition['kind'] == 'room_member_count':
- return _room_member_count(
- self._event, condition, self._room_member_count
- )
- elif condition['kind'] == 'sender_notification_permission':
+ elif condition["kind"] == "room_member_count":
+ return _room_member_count(self._event, condition, self._room_member_count)
+ elif condition["kind"] == "sender_notification_permission":
return _sender_notification_permission(
- self._event, condition, self._sender_power_level, self._power_levels,
+ self._event, condition, self._sender_power_level, self._power_levels
)
else:
return True
def _event_match(self, condition, user_id):
- pattern = condition.get('pattern', None)
+ pattern = condition.get("pattern", None)
if not pattern:
- pattern_type = condition.get('pattern_type', None)
+ pattern_type = condition.get("pattern_type", None)
if pattern_type == "user_id":
pattern = user_id
elif pattern_type == "user_localpart":
@@ -123,14 +121,14 @@ class PushRuleEvaluatorForEvent(object):
return False
# XXX: optimisation: cache our pattern regexps
- if condition['key'] == 'content.body':
+ if condition["key"] == "content.body":
body = self._event.content.get("body", None)
if not body:
return False
return _glob_matches(pattern, body, word_boundary=True)
else:
- haystack = self._get_value(condition['key'])
+ haystack = self._get_value(condition["key"])
if haystack is None:
return False
@@ -193,16 +191,13 @@ def _glob_to_re(glob, word_boundary):
if IS_GLOB.search(glob):
r = re.escape(glob)
- r = r.replace(r'\*', '.*?')
- r = r.replace(r'\?', '.')
+ r = r.replace(r"\*", ".*?")
+ r = r.replace(r"\?", ".")
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
- '[%s%s]' % (
- x.group(1) and '^' or '',
- x.group(2).replace(r'\\\-', '-')
- )
+ "[%s%s]" % (x.group(1) and "^" or "", x.group(2).replace(r"\\\-", "-"))
),
r,
)
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 8049c298c2..e37269cdb9 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -23,9 +23,7 @@ def get_badge_count(store, user_id):
invites = yield store.get_invited_rooms_for_user(user_id)
joins = yield store.get_rooms_for_user(user_id)
- my_receipts_by_room = yield store.get_receipts_for_user(
- user_id, "m.read",
- )
+ my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")
badge = len(invites)
@@ -57,10 +55,10 @@ def get_context_for_event(store, state_handler, ev, user_id):
store, room_state_ids, user_id, fallback_to_single_member=False
)
if name:
- ctx['name'] = name
+ ctx["name"] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
sender_state_event = yield store.get_event(sender_state_event_id)
- ctx['sender_display_name'] = name_from_member_event(sender_state_event)
+ ctx["sender_display_name"] = name_from_member_event(sender_state_event)
defer.returnValue(ctx)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index aff85daeb5..a9c64a9c54 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -36,9 +36,7 @@ class PusherFactory(object):
def __init__(self, hs):
self.hs = hs
- self.pusher_types = {
- "http": HttpPusher,
- }
+ self.pusher_types = {"http": HttpPusher}
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
@@ -56,7 +54,7 @@ class PusherFactory(object):
logger.info("defined email pusher type")
def create_pusher(self, pusherdict):
- kind = pusherdict['kind']
+ kind = pusherdict["kind"]
f = self.pusher_types.get(kind, None)
if not f:
return None
@@ -77,8 +75,8 @@ class PusherFactory(object):
return EmailPusher(self.hs, pusherdict, mailer)
def _app_name_from_pusherdict(self, pusherdict):
- if 'data' in pusherdict and 'brand' in pusherdict['data']:
- app_name = pusherdict['data']['brand']
+ if "data" in pusherdict and "brand" in pusherdict["data"]:
+ app_name = pusherdict["data"]["brand"]
else:
app_name = self.hs.config.email_app_name
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 40a7709c09..df6f670740 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -40,6 +40,7 @@ class PusherPool:
notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and
Pusher.on_new_receipts are not expected to return deferreds.
"""
+
def __init__(self, _hs):
self.hs = _hs
self.pusher_factory = PusherFactory(_hs)
@@ -57,30 +58,47 @@ class PusherPool:
run_as_background_process("start_pushers", self._start_pushers)
@defer.inlineCallbacks
- def add_pusher(self, user_id, access_token, kind, app_id,
- app_display_name, device_display_name, pushkey, lang, data,
- profile_tag=""):
+ def add_pusher(
+ self,
+ user_id,
+ access_token,
+ kind,
+ app_id,
+ app_display_name,
+ device_display_name,
+ pushkey,
+ lang,
+ data,
+ profile_tag="",
+ ):
+ """Creates a new pusher and adds it to the pool
+
+ Returns:
+ Deferred[EmailPusher|HttpPusher]
+ """
time_now_msec = self.clock.time_msec()
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
- self.pusher_factory.create_pusher({
- "id": None,
- "user_name": user_id,
- "kind": kind,
- "app_id": app_id,
- "app_display_name": app_display_name,
- "device_display_name": device_display_name,
- "pushkey": pushkey,
- "ts": time_now_msec,
- "lang": lang,
- "data": data,
- "last_stream_ordering": None,
- "last_success": None,
- "failing_since": None
- })
+ self.pusher_factory.create_pusher(
+ {
+ "id": None,
+ "user_name": user_id,
+ "kind": kind,
+ "app_id": app_id,
+ "app_display_name": app_display_name,
+ "device_display_name": device_display_name,
+ "pushkey": pushkey,
+ "ts": time_now_msec,
+ "lang": lang,
+ "data": data,
+ "last_stream_ordering": None,
+ "last_success": None,
+ "failing_since": None,
+ }
+ )
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
@@ -103,21 +121,24 @@ class PusherPool:
last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag,
)
- yield self.start_pusher_by_id(app_id, pushkey, user_id)
+ pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id)
+
+ defer.returnValue(pusher)
@defer.inlineCallbacks
- def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
- not_user_id):
- to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(
- app_id, pushkey
- )
+ def remove_pushers_by_app_id_and_pushkey_not_user(
+ self, app_id, pushkey, not_user_id
+ ):
+ to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove:
- if p['user_name'] != not_user_id:
+ if p["user_name"] != not_user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
- app_id, pushkey, p['user_name']
+ app_id,
+ pushkey,
+ p["user_name"],
)
- yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+ yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks
def remove_pushers_by_access_token(self, user_id, access_tokens):
@@ -131,14 +152,14 @@ class PusherPool:
"""
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
- if p['access_token'] in tokens:
+ if p["access_token"] in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
- p['app_id'], p['pushkey'], p['user_name']
- )
- yield self.remove_pusher(
- p['app_id'], p['pushkey'], p['user_name'],
+ p["app_id"],
+ p["pushkey"],
+ p["user_name"],
)
+ yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
@@ -184,21 +205,26 @@ class PusherPool:
@defer.inlineCallbacks
def start_pusher_by_id(self, app_id, pushkey, user_id):
- """Look up the details for the given pusher, and start it"""
+ """Look up the details for the given pusher, and start it
+
+ Returns:
+ Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any
+ """
if not self._should_start_pushers:
return
- resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
- app_id, pushkey
- )
+ resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
- p = None
+ pusher_dict = None
for r in resultlist:
- if r['user_name'] == user_id:
- p = r
+ if r["user_name"] == user_id:
+ pusher_dict = r
- if p:
- yield self._start_pusher(p)
+ pusher = None
+ if pusher_dict:
+ pusher = yield self._start_pusher(pusher_dict)
+
+ defer.returnValue(pusher)
@defer.inlineCallbacks
def _start_pushers(self):
@@ -224,16 +250,16 @@ class PusherPool:
pusherdict (dict):
Returns:
- None
+ Deferred[EmailPusher|HttpPusher]
"""
try:
p = self.pusher_factory.create_pusher(pusherdict)
except PusherConfigException as e:
logger.warning(
"Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s",
- pusherdict.get('user_name'),
- pusherdict.get('app_id'),
- pusherdict.get('pushkey'),
+ pusherdict.get("user_name"),
+ pusherdict.get("app_id"),
+ pusherdict.get("pushkey"),
e,
)
return
@@ -244,11 +270,8 @@ class PusherPool:
if not p:
return
- appid_pushkey = "%s:%s" % (
- pusherdict['app_id'],
- pusherdict['pushkey'],
- )
- byuser = self.pushers.setdefault(pusherdict['user_name'], {})
+ appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
+ byuser = self.pushers.setdefault(pusherdict["user_name"], {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
@@ -261,7 +284,7 @@ class PusherPool:
last_stream_ordering = pusherdict["last_stream_ordering"]
if last_stream_ordering:
have_notifs = yield self.store.get_if_maybe_push_in_range_for_user(
- user_id, last_stream_ordering,
+ user_id, last_stream_ordering
)
else:
# We always want to default to starting up the pusher rather than
@@ -270,6 +293,8 @@ class PusherPool:
p.on_started(have_notifs)
+ defer.returnValue(p)
+
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id):
appid_pushkey = "%s:%s" % (app_id, pushkey)
diff --git a/synapse/push/rulekinds.py b/synapse/push/rulekinds.py
index 4cae48ac07..ce7cc1b4ee 100644
--- a/synapse/push/rulekinds.py
+++ b/synapse/push/rulekinds.py
@@ -13,10 +13,10 @@
# limitations under the License.
PRIORITY_CLASS_MAP = {
- 'underride': 1,
- 'sender': 2,
- 'room': 3,
- 'content': 4,
- 'override': 5,
+ "underride": 1,
+ "sender": 2,
+ "room": 3,
+ "content": 4,
+ "override": 5,
}
PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()}
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 6efd81f204..13698d9638 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -44,15 +44,12 @@ REQUIREMENTS = [
"canonicaljson>=1.1.3",
"signedjson>=1.0.0",
"pynacl>=1.2.1",
- "idna>=2",
-
+ "idna>=2.5",
# validating SSL certs for IP addresses requires service_identity 18.1.
"service_identity>=18.1.0",
-
# our logcontext handling relies on the ability to cancel inlineCallbacks
# (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7.
"Twisted>=18.7.0",
-
"treq>=15.1",
# Twisted has required pyopenssl 16.0 since about Twisted 16.6.
"pyopenssl>=16.0.0",
@@ -65,40 +62,34 @@ REQUIREMENTS = [
"sortedcontainers>=1.4.4",
"psutil>=2.0.0",
"pymacaroons>=0.13.0",
- "msgpack>=0.5.0",
+ "msgpack>=0.5.2",
"phonenumbers>=8.2.0",
"six>=1.10",
# prometheus_client 0.4.0 changed the format of counter metrics
# (cf https://github.com/matrix-org/synapse/issues/4001)
"prometheus_client>=0.0.18,<0.4.0",
-
# we use attr.s(slots), which arrived in 16.0.0
# Twisted 18.7.0 requires attrs>=17.4.0
"attrs>=17.4.0",
-
"netaddr>=0.7.18",
+ "Jinja2>=2.9",
+ "bleach>=1.4.3",
]
CONDITIONAL_REQUIREMENTS = {
- "email": ["Jinja2>=2.9", "bleach>=1.4.2"],
"matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
-
# we use execute_batch, which arrived in psycopg 2.7.
"postgres": ["psycopg2>=2.7"],
-
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
"resources.consent": ["Jinja2>=2.9"],
-
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
"acme": [
"txacme>=0.9.2",
-
# txacme depends on eliot. Eliot 1.8.0 is incompatible with
# python 3.5.2, as per https://github.com/itamarst/eliot/issues/418
'eliot<1.8.0;python_version<"3.5.3"',
],
-
"saml2": ["pysaml2>=4.5.0"],
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
@@ -121,12 +112,14 @@ def list_requirements():
class DependencyException(Exception):
@property
def message(self):
- return "\n".join([
- "Missing Requirements: %s" % (", ".join(self.dependencies),),
- "To install run:",
- " pip install --upgrade --force %s" % (" ".join(self.dependencies),),
- "",
- ])
+ return "\n".join(
+ [
+ "Missing Requirements: %s" % (", ".join(self.dependencies),),
+ "To install run:",
+ " pip install --upgrade --force %s" % (" ".join(self.dependencies),),
+ "",
+ ]
+ )
@property
def dependencies(self):
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index e81456ab2b..fe482e279f 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -17,11 +17,17 @@ import abc
import logging
import re
+from six import raise_from
from six.moves import urllib
from twisted.internet import defer
-from synapse.api.errors import CodeMessageException, HttpResponseException
+from synapse.api.errors import (
+ CodeMessageException,
+ HttpResponseException,
+ RequestSendFailed,
+ SynapseError,
+)
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -77,8 +83,7 @@ class ReplicationEndpoint(object):
def __init__(self, hs):
if self.CACHE:
self.response_cache = ResponseCache(
- hs, "repl." + self.NAME,
- timeout_ms=30 * 60 * 1000,
+ hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
)
assert self.METHOD in ("PUT", "POST", "GET")
@@ -128,8 +133,7 @@ class ReplicationEndpoint(object):
data = yield cls._serialize_payload(**kwargs)
url_args = [
- urllib.parse.quote(kwargs[name], safe='')
- for name in cls.PATH_ARGS
+ urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
if cls.CACHE:
@@ -150,7 +154,10 @@ class ReplicationEndpoint(object):
)
uri = "http://%s:%s/_synapse/replication/%s/%s" % (
- host, port, cls.NAME, "/".join(url_args)
+ host,
+ port,
+ cls.NAME,
+ "/".join(url_args),
)
try:
@@ -175,6 +182,8 @@ class ReplicationEndpoint(object):
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
+ except RequestSendFailed as e:
+ raise_from(SynapseError(502, "Failed to talk to master"), e)
defer.returnValue(result)
@@ -194,10 +203,7 @@ class ReplicationEndpoint(object):
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
- pattern = re.compile("^/_synapse/replication/%s/%s$" % (
- self.NAME,
- args
- ))
+ pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths(method, [pattern], handler)
@@ -211,8 +217,4 @@ class ReplicationEndpoint(object):
assert self.CACHE
- return self.response_cache.wrap(
- txn_id,
- self._handle_request,
- request, **kwargs
- )
+ return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 0f0a07c422..61eafbe708 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -68,18 +68,17 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
for event, context in event_and_contexts:
serialized_context = yield context.serialize(event, store)
- event_payloads.append({
- "event": event.get_pdu_json(),
- "event_format_version": event.format_version,
- "internal_metadata": event.internal_metadata.get_dict(),
- "rejected_reason": event.rejected_reason,
- "context": serialized_context,
- })
-
- payload = {
- "events": event_payloads,
- "backfilled": backfilled,
- }
+ event_payloads.append(
+ {
+ "event": event.get_pdu_json(),
+ "event_format_version": event.format_version,
+ "internal_metadata": event.internal_metadata.get_dict(),
+ "rejected_reason": event.rejected_reason,
+ "context": serialized_context,
+ }
+ )
+
+ payload = {"events": event_payloads, "backfilled": backfilled}
defer.returnValue(payload)
@@ -103,18 +102,15 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event = EventType(event_dict, internal_metadata, rejected_reason)
context = yield EventContext.deserialize(
- self.store, event_payload["context"],
+ self.store, event_payload["context"]
)
event_and_contexts.append((event, context))
- logger.info(
- "Got %d events from federation",
- len(event_and_contexts),
- )
+ logger.info("Got %d events from federation", len(event_and_contexts))
yield self.federation_handler.persist_events_and_notify(
- event_and_contexts, backfilled,
+ event_and_contexts, backfilled
)
defer.returnValue((200, {}))
@@ -146,10 +142,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
@staticmethod
def _serialize_payload(edu_type, origin, content):
- return {
- "origin": origin,
- "content": content,
- }
+ return {"origin": origin, "content": content}
@defer.inlineCallbacks
def _handle_request(self, request, edu_type):
@@ -159,10 +152,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
origin = content["origin"]
edu_content = content["content"]
- logger.info(
- "Got %r edu from %s",
- edu_type, origin,
- )
+ logger.info("Got %r edu from %s", edu_type, origin)
result = yield self.registry.on_edu(edu_type, origin, edu_content)
@@ -201,9 +191,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
query_type (str)
args (dict): The arguments received for the given query type
"""
- return {
- "args": args,
- }
+ return {"args": args}
@defer.inlineCallbacks
def _handle_request(self, request, query_type):
@@ -212,10 +200,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
args = content["args"]
- logger.info(
- "Got %r query",
- query_type,
- )
+ logger.info("Got %r query", query_type)
result = yield self.registry.on_query(query_type, args)
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 63bc0405ea..7c1197e5dd 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -61,13 +61,10 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
is_guest = content["is_guest"]
device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest,
+ user_id, device_id, initial_display_name, is_guest
)
- defer.returnValue((200, {
- "device_id": device_id,
- "access_token": access_token,
- }))
+ defer.returnValue((200, {"device_id": device_id, "access_token": access_token}))
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 81a2b204c7..0a76a3762f 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -40,7 +40,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"""
NAME = "remote_join"
- PATH_ARGS = ("room_id", "user_id",)
+ PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs):
super(ReplicationRemoteJoinRestServlet, self).__init__(hs)
@@ -50,8 +50,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts,
- content):
+ def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
"""
Args:
requester(Requester)
@@ -78,16 +77,10 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
if requester.user:
request.authenticated_entity = requester.user.to_string()
- logger.info(
- "remote_join: %s into room: %s",
- user_id, room_id,
- )
+ logger.info("remote_join: %s into room: %s", user_id, room_id)
yield self.federation_handler.do_invite_join(
- remote_room_hosts,
- room_id,
- user_id,
- event_content,
+ remote_room_hosts, room_id, user_id, event_content
)
defer.returnValue((200, {}))
@@ -107,7 +100,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"""
NAME = "remote_reject_invite"
- PATH_ARGS = ("room_id", "user_id",)
+ PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs):
super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
@@ -141,16 +134,11 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
if requester.user:
request.authenticated_entity = requester.user.to_string()
- logger.info(
- "remote_reject_invite: %s out of room: %s",
- user_id, room_id,
- )
+ logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try:
event = yield self.federation_handler.do_remotely_reject_invite(
- remote_room_hosts,
- room_id,
- user_id,
+ remote_room_hosts, room_id, user_id
)
ret = event.get_pdu_json()
except Exception as e:
@@ -162,9 +150,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
#
logger.warn("Failed to reject invite: %s", e)
- yield self.store.locally_reject_invite(
- user_id, room_id
- )
+ yield self.store.locally_reject_invite(user_id, room_id)
ret = {}
defer.returnValue((200, ret))
@@ -228,7 +214,7 @@ class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint):
logger.info("get_or_register_3pid_guest: %r", content)
ret = yield self.registeration_handler.get_or_register_3pid_guest(
- medium, address, inviter_user_id,
+ medium, address, inviter_user_id
)
defer.returnValue((200, ret))
@@ -264,7 +250,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
user_id (str)
change (str): Either "joined" or "left"
"""
- assert change in ("joined", "left",)
+ assert change in ("joined", "left")
return {}
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 912a5ac341..f81a0f1b8f 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -37,8 +37,16 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
@staticmethod
def _serialize_payload(
- user_id, token, password_hash, was_guest, make_guest, appservice_id,
- create_profile_with_displayname, admin, user_type, address,
+ user_id,
+ token,
+ password_hash,
+ was_guest,
+ make_guest,
+ appservice_id,
+ create_profile_with_displayname,
+ admin,
+ user_type,
+ address,
):
"""
Args:
@@ -85,7 +93,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
create_profile_with_displayname=content["create_profile_with_displayname"],
admin=content["admin"],
user_type=content["user_type"],
- address=content["address"]
+ address=content["address"],
)
defer.returnValue((200, {}))
@@ -104,8 +112,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(user_id, auth_result, access_token, bind_email,
- bind_msisdn):
+ def _serialize_payload(user_id, auth_result, access_token, bind_email, bind_msisdn):
"""
Args:
user_id (str): The user ID that consented
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 3635015eda..034763fe99 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -45,6 +45,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"extra_users": [],
}
"""
+
NAME = "send_event"
PATH_ARGS = ("event_id",)
@@ -57,8 +58,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
@staticmethod
@defer.inlineCallbacks
- def _serialize_payload(event_id, store, event, context, requester,
- ratelimit, extra_users):
+ def _serialize_payload(
+ event_id, store, event, context, requester, ratelimit, extra_users
+ ):
"""
Args:
event_id (str)
@@ -108,14 +110,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
request.authenticated_entity = requester.user.to_string()
logger.info(
- "Got event to send with ID: %s into room: %s",
- event.event_id, event.room_id,
+ "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
yield self.event_creation_handler.persist_and_notify_client_event(
- requester, event, context,
- ratelimit=ratelimit,
- extra_users=extra_users,
+ requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
defer.returnValue((200, {}))
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 817d1f67f9..182cb2a1d8 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -37,7 +37,7 @@ class BaseSlavedStore(SQLBaseStore):
super(BaseSlavedStore, self).__init__(db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
- db_conn, "cache_invalidation_stream", "stream_id",
+ db_conn, "cache_invalidation_stream", "stream_id"
)
else:
self._cache_id_gen = None
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index d9ba6d69b1..3c44d1d48d 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -21,10 +21,9 @@ from synapse.storage.tags import TagsWorkerStore
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-
def __init__(self, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
- db_conn, "account_data_max_stream_id", "stream_id",
+ db_conn, "account_data_max_stream_id", "stream_id"
)
super(SlavedAccountDataStore, self).__init__(db_conn, hs)
@@ -45,24 +44,20 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
self._account_data_id_gen.advance(token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
- self._account_data_stream_cache.entity_has_changed(
- row.user_id, token
- )
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == "account_data":
self._account_data_id_gen.advance(token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
- (row.data_type, row.user_id,)
+ (row.data_type, row.user_id)
)
self.get_account_data_for_user.invalidate((row.user_id,))
- self.get_account_data_for_room.invalidate((row.user_id, row.room_id,))
+ self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
- (row.user_id, row.room_id, row.data_type,),
- )
- self._account_data_stream_cache.entity_has_changed(
- row.user_id, token
+ (row.user_id, row.room_id, row.data_type)
)
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedAccountDataStore, self).process_replication_rows(
stream_name, token, rows
)
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index b53a4c6bd1..cda12ea70d 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -20,6 +20,7 @@ from synapse.storage.appservice import (
)
-class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore,
- ApplicationServiceWorkerStore):
+class SlavedApplicationServiceStore(
+ ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore
+):
pass
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 5b8521c770..14ced32333 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -25,9 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore):
super(SlavedClientIpStore, self).__init__(db_conn, hs)
self.client_ip_last_seen = Cache(
- name="client_ip_last_seen",
- keylen=4,
- max_entries=50000 * CACHE_SIZE_FACTOR,
+ name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 4d59778863..284fd30d89 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -24,15 +24,15 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
- db_conn, "device_max_stream_id", "stream_id",
+ db_conn, "device_max_stream_id", "stream_id"
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
- self._device_inbox_id_gen.get_current_token()
+ self._device_inbox_id_gen.get_current_token(),
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache",
- self._device_inbox_id_gen.get_current_token()
+ self._device_inbox_id_gen.get_current_token(),
)
self._last_device_delete_cache = ExpiringCache(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 16c9a162c5..d9300fce33 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -27,14 +27,14 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
- db_conn, "device_lists_stream", "stream_id",
+ db_conn, "device_lists_stream", "stream_id"
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache", device_list_max,
+ "DeviceListStreamChangeCache", device_list_max
)
self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max,
+ "DeviceListFederationStreamChangeCache", device_list_max
)
def stream_positions(self):
@@ -46,17 +46,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
if stream_name == "device_lists":
self._device_list_id_gen.advance(token)
for row in rows:
- self._invalidate_caches_for_devices(
- token, row.user_id, row.destination,
- )
+ self._invalidate_caches_for_devices(token, row.user_id, row.destination)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
def _invalidate_caches_for_devices(self, token, user_id, destination):
- self._device_list_stream_cache.entity_has_changed(
- user_id, token
- )
+ self._device_list_stream_cache.entity_has_changed(user_id, token)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index a3952506c1..ab5937e638 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -45,21 +45,20 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class.
-class SlavedEventStore(EventFederationWorkerStore,
- RoomMemberWorkerStore,
- EventPushActionsWorkerStore,
- StreamWorkerStore,
- StateGroupWorkerStore,
- EventsWorkerStore,
- SignatureWorkerStore,
- UserErasureWorkerStore,
- RelationsWorkerStore,
- BaseSlavedStore):
-
+class SlavedEventStore(
+ EventFederationWorkerStore,
+ RoomMemberWorkerStore,
+ EventPushActionsWorkerStore,
+ StreamWorkerStore,
+ StateGroupWorkerStore,
+ EventsWorkerStore,
+ SignatureWorkerStore,
+ UserErasureWorkerStore,
+ RelationsWorkerStore,
+ BaseSlavedStore,
+):
def __init__(self, db_conn, hs):
- self._stream_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering",
- )
+ self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
@@ -90,8 +89,13 @@ class SlavedEventStore(EventFederationWorkerStore,
self._backfill_id_gen.advance(-token)
for row in rows:
self.invalidate_caches_for_event(
- -token, row.event_id, row.room_id, row.type, row.state_key,
- row.redacts, row.relates_to,
+ -token,
+ row.event_id,
+ row.room_id,
+ row.type,
+ row.state_key,
+ row.redacts,
+ row.relates_to,
backfilled=True,
)
return super(SlavedEventStore, self).process_replication_rows(
@@ -103,41 +107,48 @@ class SlavedEventStore(EventFederationWorkerStore,
if row.type == EventsStreamEventRow.TypeId:
self.invalidate_caches_for_event(
- token, data.event_id, data.room_id, data.type, data.state_key,
- data.redacts, data.relates_to,
+ token,
+ data.event_id,
+ data.room_id,
+ data.type,
+ data.state_key,
+ data.redacts,
+ data.relates_to,
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
- (data.state_key, ),
+ (data.state_key,)
)
else:
- raise Exception("Unknown events stream row type %s" % (row.type, ))
-
- def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
- etype, state_key, redacts, relates_to,
- backfilled):
+ raise Exception("Unknown events stream row type %s" % (row.type,))
+
+ def invalidate_caches_for_event(
+ self,
+ stream_ordering,
+ event_id,
+ room_id,
+ etype,
+ state_key,
+ redacts,
+ relates_to,
+ backfilled,
+ ):
self._invalidate_get_event_cache(event_id)
self.get_latest_event_ids_in_room.invalidate((room_id,))
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
- (room_id,)
- )
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
- self._events_stream_cache.entity_has_changed(
- room_id, stream_ordering
- )
+ self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
if redacts:
self._invalidate_get_event_cache(redacts)
if etype == EventTypes.Member:
- self._membership_stream_cache.entity_has_changed(
- state_key, stream_ordering
- )
+ self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self.get_invited_rooms_for_user.invalidate((state_key,))
if relates_to:
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index e933b170bb..28a46edd28 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -27,10 +27,11 @@ class SlavedGroupServerStore(BaseSlavedStore):
self.hs = hs
self._group_updates_id_gen = SlavedIdTracker(
- db_conn, "local_group_updates", "stream_id",
+ db_conn, "local_group_updates", "stream_id"
)
self._group_updates_stream_cache = StreamChangeCache(
- "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
+ "_group_updates_stream_cache",
+ self._group_updates_id_gen.get_current_token(),
)
get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user)
@@ -46,9 +47,7 @@ class SlavedGroupServerStore(BaseSlavedStore):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
for row in rows:
- self._group_updates_stream_cache.entity_has_changed(
- row.user_id, token
- )
+ self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedGroupServerStore, self).process_replication_rows(
stream_name, token, rows
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 0ec1db25ce..82d808af4c 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -24,9 +24,7 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPresenceStore, self).__init__(db_conn, hs)
- self._presence_id_gen = SlavedIdTracker(
- db_conn, "presence_stream", "stream_id",
- )
+ self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
self._presence_on_startup = self._get_active_presence(db_conn)
@@ -55,9 +53,7 @@ class SlavedPresenceStore(BaseSlavedStore):
if stream_name == "presence":
self._presence_id_gen.advance(token)
for row in rows:
- self.presence_stream_cache.entity_has_changed(
- row.user_id, token
- )
+ self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
return super(SlavedPresenceStore, self).process_replication_rows(
stream_name, token, rows
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 45fc913c52..af7012702e 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -23,7 +23,7 @@ from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def __init__(self, db_conn, hs):
self._push_rules_stream_id_gen = SlavedIdTracker(
- db_conn, "push_rules_stream", "stream_id",
+ db_conn, "push_rules_stream", "stream_id"
)
super(SlavedPushRuleStore, self).__init__(db_conn, hs)
@@ -47,9 +47,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
- self.push_rules_stream_cache.entity_has_changed(
- row.user_id, token
- )
+ self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedPushRuleStore, self).process_replication_rows(
stream_name, token, rows
)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 3b2213c0d4..8eeb267d61 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -21,12 +21,10 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-
def __init__(self, db_conn, hs):
super(SlavedPusherStore, self).__init__(db_conn, hs)
self._pushers_id_gen = SlavedIdTracker(
- db_conn, "pushers", "id",
- extra_tables=[("deleted_pushers", "stream_id")],
+ db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
def stream_positions(self):
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ed12342f40..91afa5a72b 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -29,7 +29,6 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-
def __init__(self, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 0cb474928c..f68b3378e3 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -38,6 +38,4 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)
- return super(RoomStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 206dc3b397..a44ceb00e7 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -39,6 +39,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
Accepts a handler that will be called when new data is available or data
is required.
"""
+
maxDelay = 30 # Try at least once every N seconds
def __init__(self, hs, client_name, handler):
@@ -64,9 +65,7 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def clientConnectionFailed(self, connector, reason):
logger.error("Failed to connect to replication: %r", reason)
- ReconnectingClientFactory.clientConnectionFailed(
- self, connector, reason
- )
+ ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
class ReplicationClientHandler(object):
@@ -74,6 +73,7 @@ class ReplicationClientHandler(object):
By default proxies incoming replication data to the SlaveStore.
"""
+
def __init__(self, store):
self.store = store
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 2098c32a77..0ff2a7199f 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -23,9 +23,11 @@ import platform
if platform.python_implementation() == "PyPy":
import json
+
_json_encoder = json.JSONEncoder()
else:
import simplejson as json
+
_json_encoder = json.JSONEncoder(namedtuple_as_object=False)
logger = logging.getLogger(__name__)
@@ -41,6 +43,7 @@ class Command(object):
The default implementation creates a command of form `<NAME> <data>`
"""
+
NAME = None
def __init__(self, data):
@@ -73,6 +76,7 @@ class ServerCommand(Command):
SERVER <server_name>
"""
+
NAME = "SERVER"
@@ -99,6 +103,7 @@ class RdataCommand(Command):
RDATA presence batch ["@bar:example.com", "online", ...]
RDATA presence 59 ["@baz:example.com", "online", ...]
"""
+
NAME = "RDATA"
def __init__(self, stream_name, token, row):
@@ -110,17 +115,17 @@ class RdataCommand(Command):
def from_line(cls, line):
stream_name, token, row_json = line.split(" ", 2)
return cls(
- stream_name,
- None if token == "batch" else int(token),
- json.loads(row_json)
+ stream_name, None if token == "batch" else int(token), json.loads(row_json)
)
def to_line(self):
- return " ".join((
- self.stream_name,
- str(self.token) if self.token is not None else "batch",
- _json_encoder.encode(self.row),
- ))
+ return " ".join(
+ (
+ self.stream_name,
+ str(self.token) if self.token is not None else "batch",
+ _json_encoder.encode(self.row),
+ )
+ )
def get_logcontext_id(self):
return "RDATA-" + self.stream_name
@@ -133,6 +138,7 @@ class PositionCommand(Command):
Sent to the client after all missing updates for a stream have been sent
to the client and they're now up to date.
"""
+
NAME = "POSITION"
def __init__(self, stream_name, token):
@@ -145,19 +151,21 @@ class PositionCommand(Command):
return cls(stream_name, int(token))
def to_line(self):
- return " ".join((self.stream_name, str(self.token),))
+ return " ".join((self.stream_name, str(self.token)))
class ErrorCommand(Command):
"""Sent by either side if there was an ERROR. The data is a string describing
the error.
"""
+
NAME = "ERROR"
class PingCommand(Command):
"""Sent by either side as a keep alive. The data is arbitary (often timestamp)
"""
+
NAME = "PING"
@@ -165,6 +173,7 @@ class NameCommand(Command):
"""Sent by client to inform the server of the client's identity. The data
is the name
"""
+
NAME = "NAME"
@@ -184,6 +193,7 @@ class ReplicateCommand(Command):
REPLICATE ALL NOW
"""
+
NAME = "REPLICATE"
def __init__(self, stream_name, token):
@@ -200,7 +210,7 @@ class ReplicateCommand(Command):
return cls(stream_name, token)
def to_line(self):
- return " ".join((self.stream_name, str(self.token),))
+ return " ".join((self.stream_name, str(self.token)))
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
@@ -218,6 +228,7 @@ class UserSyncCommand(Command):
Where <state> is either "start" or "stop"
"""
+
NAME = "USER_SYNC"
def __init__(self, user_id, is_syncing, last_sync_ms):
@@ -235,9 +246,13 @@ class UserSyncCommand(Command):
return cls(user_id, state == "start", int(last_sync_ms))
def to_line(self):
- return " ".join((
- self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms),
- ))
+ return " ".join(
+ (
+ self.user_id,
+ "start" if self.is_syncing else "end",
+ str(self.last_sync_ms),
+ )
+ )
class FederationAckCommand(Command):
@@ -251,6 +266,7 @@ class FederationAckCommand(Command):
FEDERATION_ACK <token>
"""
+
NAME = "FEDERATION_ACK"
def __init__(self, token):
@@ -268,6 +284,7 @@ class SyncCommand(Command):
"""Used for testing. The client protocol implementation allows waiting
on a SYNC command with a specified data.
"""
+
NAME = "SYNC"
@@ -278,6 +295,7 @@ class RemovePusherCommand(Command):
REMOVE_PUSHER <app_id> <push_key> <user_id>
"""
+
NAME = "REMOVE_PUSHER"
def __init__(self, app_id, push_key, user_id):
@@ -309,6 +327,7 @@ class InvalidateCacheCommand(Command):
Where <keys_json> is a json list.
"""
+
NAME = "INVALIDATE_CACHE"
def __init__(self, cache_func, keys):
@@ -322,9 +341,7 @@ class InvalidateCacheCommand(Command):
return cls(cache_func, json.loads(keys_json))
def to_line(self):
- return " ".join((
- self.cache_func, _json_encoder.encode(self.keys),
- ))
+ return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
class UserIpCommand(Command):
@@ -334,6 +351,7 @@ class UserIpCommand(Command):
USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
"""
+
NAME = "USER_IP"
def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
@@ -350,15 +368,22 @@ class UserIpCommand(Command):
access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
- return cls(
- user_id, access_token, ip, user_agent, device_id, last_seen
- )
+ return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
def to_line(self):
- return self.user_id + " " + _json_encoder.encode((
- self.access_token, self.ip, self.user_agent, self.device_id,
- self.last_seen,
- ))
+ return (
+ self.user_id
+ + " "
+ + _json_encoder.encode(
+ (
+ self.access_token,
+ self.ip,
+ self.user_agent,
+ self.device_id,
+ self.last_seen,
+ )
+ )
+ )
# Map of command name to command type.
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b51590cf8f..97efb835ad 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -84,7 +84,8 @@ from .commands import (
from .streams import STREAMS_MAP
connection_close_counter = Counter(
- "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
+ "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
+)
# A list of all connected protocols. This allows us to send metrics about the
# connections.
@@ -119,7 +120,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
"""
- delimiter = b'\n'
+
+ delimiter = b"\n"
VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send
@@ -183,10 +185,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if now - self.last_sent_command >= PING_TIME:
self.send_command(PingCommand(now))
- if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS:
+ if (
+ self.received_ping
+ and now - self.last_received_command > PING_TIMEOUT_MS
+ ):
logger.info(
"[%s] Connection hasn't received command in %r ms. Closing.",
- self.id(), now - self.last_received_command
+ self.id(),
+ now - self.last_received_command,
)
self.send_error("ping timeout")
@@ -208,7 +214,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
self.inbound_commands_counter[cmd_name] = (
- self.inbound_commands_counter[cmd_name] + 1)
+ self.inbound_commands_counter[cmd_name] + 1
+ )
cmd_cls = COMMAND_MAP[cmd_name]
try:
@@ -224,9 +231,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function
run_as_background_process(
- "replication-" + cmd.get_logcontext_id(),
- self.handle_command,
- cmd,
+ "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
def handle_command(self, cmd):
@@ -274,8 +279,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return
self.outbound_commands_counter[cmd.NAME] = (
- self.outbound_commands_counter[cmd.NAME] + 1)
- string = "%s %s" % (cmd.NAME, cmd.to_line(),)
+ self.outbound_commands_counter[cmd.NAME] + 1
+ )
+ string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -283,10 +289,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if len(encoded_string) > self.MAX_LENGTH:
raise Exception(
- "Failed to send command %s as too long (%d > %d)" % (
- cmd.NAME,
- len(encoded_string), self.MAX_LENGTH,
- )
+ "Failed to send command %s as too long (%d > %d)"
+ % (cmd.NAME, len(encoded_string), self.MAX_LENGTH)
)
self.sendLine(encoded_string)
@@ -379,7 +383,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if self.transport:
addr = str(self.transport.getPeer())
return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % (
- self.name, self.conn_id, addr,
+ self.name,
+ self.conn_id,
+ addr,
)
def id(self):
@@ -422,7 +428,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def on_USER_SYNC(self, cmd):
return self.streamer.on_user_sync(
- self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
+ self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
def on_REPLICATE(self, cmd):
@@ -432,10 +438,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
deferreds = [
- run_in_background(
- self.subscribe_to_stream,
- stream, token,
- )
+ run_in_background(self.subscribe_to_stream, stream, token)
for stream in iterkeys(self.streamer.streams_by_name)
]
@@ -449,16 +452,18 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
return self.streamer.federation_ack(cmd.token)
def on_REMOVE_PUSHER(self, cmd):
- return self.streamer.on_remove_pusher(
- cmd.app_id, cmd.push_key, cmd.user_id,
- )
+ return self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
def on_INVALIDATE_CACHE(self, cmd):
return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
def on_USER_IP(self, cmd):
return self.streamer.on_user_ip(
- cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+ cmd.user_id,
+ cmd.access_token,
+ cmd.ip,
+ cmd.user_agent,
+ cmd.device_id,
cmd.last_seen,
)
@@ -476,7 +481,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
try:
# Get missing updates
updates, current_token = yield self.streamer.get_stream_updates(
- stream_name, token,
+ stream_name, token
)
# Send all the missing updates
@@ -608,8 +613,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception(
- "[%s] Failed to parse RDATA: %r %r",
- self.id(), stream_name, cmd.row
+ "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
)
raise
@@ -643,7 +647,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
logger.info(
"[%s] Subscribing to replication stream: %r from %r",
- self.id(), stream_name, token
+ self.id(),
+ stream_name,
+ token,
)
self.streams_connecting.add(stream_name)
@@ -661,9 +667,7 @@ pending_commands = LaterGauge(
"synapse_replication_tcp_protocol_pending_commands",
"",
["name"],
- lambda: {
- (p.name,): len(p.pending_commands) for p in connected_connections
- },
+ lambda: {(p.name,): len(p.pending_commands) for p in connected_connections},
)
@@ -678,9 +682,7 @@ transport_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_send_buffer",
"",
["name"],
- lambda: {
- (p.name,): transport_buffer_size(p) for p in connected_connections
- },
+ lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections},
)
@@ -694,7 +696,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
op = SIOCINQ
else:
op = SIOCOUTQ
- size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0]
+ size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0]
return size
return 0
@@ -726,7 +728,7 @@ tcp_inbound_commands = LaterGauge(
"",
["command", "name"],
lambda: {
- (k, p.name,): count
+ (k, p.name): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
},
@@ -737,7 +739,7 @@ tcp_outbound_commands = LaterGauge(
"",
["command", "name"],
lambda: {
- (k, p.name,): count
+ (k, p.name): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
},
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index f6a38f5140..d1e98428bc 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP
from .streams.federation import FederationStream
-stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
- "", ["stream_name"])
+stream_updates_counter = Counter(
+ "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
+)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
- "")
+invalidate_cache_counter = Counter(
+ "synapse_replication_tcp_resource_invalidate_cache", ""
+)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -48,6 +50,7 @@ logger = logging.getLogger(__name__)
class ReplicationStreamProtocolFactory(Factory):
"""Factory for new replication connections.
"""
+
def __init__(self, hs):
self.streamer = ReplicationStreamer(hs)
self.clock = hs.get_clock()
@@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory):
def buildProtocol(self, addr):
return ServerReplicationStreamProtocol(
- self.server_name,
- self.clock,
- self.streamer,
+ self.server_name, self.clock, self.streamer
)
@@ -80,29 +81,39 @@ class ReplicationStreamer(object):
# Current connections.
self.connections = []
- LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
- lambda: len(self.connections))
+ LaterGauge(
+ "synapse_replication_tcp_resource_total_connections",
+ "",
+ [],
+ lambda: len(self.connections),
+ )
# List of streams that clients can subscribe to.
# We only support federation stream if federation sending hase been
# disabled on the master.
self.streams = [
- stream(hs) for stream in itervalues(STREAMS_MAP)
+ stream(hs)
+ for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
LaterGauge(
- "synapse_replication_tcp_resource_connections_per_stream", "",
+ "synapse_replication_tcp_resource_connections_per_stream",
+ "",
["stream_name"],
lambda: {
- (stream_name,): len([
- conn for conn in self.connections
- if stream_name in conn.replication_streams
- ])
+ (stream_name,): len(
+ [
+ conn
+ for conn in self.connections
+ if stream_name in conn.replication_streams
+ ]
+ )
for stream_name in self.streams_by_name
- })
+ },
+ )
self.federation_sender = None
if not hs.config.send_federation:
@@ -179,7 +190,9 @@ class ReplicationStreamer(object):
logger.debug(
"Getting stream: %s: %s -> %s",
- stream.NAME, stream.last_token, stream.upto_token
+ stream.NAME,
+ stream.last_token,
+ stream.upto_token,
)
try:
updates, current_token = yield stream.get_updates()
@@ -189,7 +202,8 @@ class ReplicationStreamer(object):
logger.debug(
"Sending %d updates to %d connections",
- len(updates), len(self.connections),
+ len(updates),
+ len(self.connections),
)
if updates:
@@ -243,7 +257,7 @@ class ReplicationStreamer(object):
"""
user_sync_counter.inc()
yield self.presence_handler.update_external_syncs_row(
- conn_id, user_id, is_syncing, last_sync_ms,
+ conn_id, user_id, is_syncing, last_sync_ms
)
@measure_func("repl.on_remove_pusher")
@@ -272,7 +286,7 @@ class ReplicationStreamer(object):
"""
user_ip_cache_counter.inc()
yield self.store.insert_client_ip(
- user_id, access_token, ip, user_agent, device_id, last_seen,
+ user_id, access_token, ip, user_agent, device_id, last_seen
)
yield self._server_notices_sender.on_user_ip(user_id)
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b6ce7a7bee..7ef67a5a73 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -26,78 +26,75 @@ logger = logging.getLogger(__name__)
MAX_EVENTS_BEHIND = 10000
-BackfillStreamRow = namedtuple("BackfillStreamRow", (
- "event_id", # str
- "room_id", # str
- "type", # str
- "state_key", # str, optional
- "redacts", # str, optional
- "relates_to", # str, optional
-))
-PresenceStreamRow = namedtuple("PresenceStreamRow", (
- "user_id", # str
- "state", # str
- "last_active_ts", # int
- "last_federation_update_ts", # int
- "last_user_sync_ts", # int
- "status_msg", # str
- "currently_active", # bool
-))
-TypingStreamRow = namedtuple("TypingStreamRow", (
- "room_id", # str
- "user_ids", # list(str)
-))
-ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", (
- "room_id", # str
- "receipt_type", # str
- "user_id", # str
- "event_id", # str
- "data", # dict
-))
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", (
- "user_id", # str
-))
-PushersStreamRow = namedtuple("PushersStreamRow", (
- "user_id", # str
- "app_id", # str
- "pushkey", # str
- "deleted", # bool
-))
-CachesStreamRow = namedtuple("CachesStreamRow", (
- "cache_func", # str
- "keys", # list(str)
- "invalidation_ts", # int
-))
-PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", (
- "room_id", # str
- "visibility", # str
- "appservice_id", # str, optional
- "network_id", # str, optional
-))
-DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
- "user_id", # str
- "destination", # str
-))
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
- "entity", # str
-))
-TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
- "user_id", # str
- "room_id", # str
- "data", # dict
-))
-AccountDataStreamRow = namedtuple("AccountDataStream", (
- "user_id", # str
- "room_id", # str
- "data_type", # str
- "data", # dict
-))
-GroupsStreamRow = namedtuple("GroupsStreamRow", (
- "group_id", # str
- "user_id", # str
- "type", # str
- "content", # dict
-))
+BackfillStreamRow = namedtuple(
+ "BackfillStreamRow",
+ (
+ "event_id", # str
+ "room_id", # str
+ "type", # str
+ "state_key", # str, optional
+ "redacts", # str, optional
+ "relates_to", # str, optional
+ ),
+)
+PresenceStreamRow = namedtuple(
+ "PresenceStreamRow",
+ (
+ "user_id", # str
+ "state", # str
+ "last_active_ts", # int
+ "last_federation_update_ts", # int
+ "last_user_sync_ts", # int
+ "status_msg", # str
+ "currently_active", # bool
+ ),
+)
+TypingStreamRow = namedtuple(
+ "TypingStreamRow", ("room_id", "user_ids") # str # list(str)
+)
+ReceiptsStreamRow = namedtuple(
+ "ReceiptsStreamRow",
+ (
+ "room_id", # str
+ "receipt_type", # str
+ "user_id", # str
+ "event_id", # str
+ "data", # dict
+ ),
+)
+PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
+PushersStreamRow = namedtuple(
+ "PushersStreamRow",
+ ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
+)
+CachesStreamRow = namedtuple(
+ "CachesStreamRow",
+ ("cache_func", "keys", "invalidation_ts"), # str # list(str) # int
+)
+PublicRoomsStreamRow = namedtuple(
+ "PublicRoomsStreamRow",
+ (
+ "room_id", # str
+ "visibility", # str
+ "appservice_id", # str, optional
+ "network_id", # str, optional
+ ),
+)
+DeviceListsStreamRow = namedtuple(
+ "DeviceListsStreamRow", ("user_id", "destination") # str # str
+)
+ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
+TagAccountDataStreamRow = namedtuple(
+ "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
+)
+AccountDataStreamRow = namedtuple(
+ "AccountDataStream",
+ ("user_id", "room_id", "data_type", "data"), # str # str # str # dict
+)
+GroupsStreamRow = namedtuple(
+ "GroupsStreamRow",
+ ("group_id", "user_id", "type", "content"), # str # str # str # dict
+)
class Stream(object):
@@ -106,6 +103,7 @@ class Stream(object):
Provides a `get_updates()` function that returns new updates since the last
time it was called up until the point `advance_current_token` was called.
"""
+
NAME = None # The name of the stream
ROW_TYPE = None # The type of the row. Used by the default impl of parse_row.
_LIMITED = True # Whether the update function takes a limit
@@ -185,16 +183,13 @@ class Stream(object):
if self._LIMITED:
rows = yield self.update_function(
- from_token, current_token,
- limit=MAX_EVENTS_BEHIND + 1,
+ from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
- rows = yield self.update_function(
- from_token, current_token,
- )
+ rows = yield self.update_function(from_token, current_token)
updates = [(row[0], row[1:]) for row in rows]
@@ -230,6 +225,7 @@ class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
"""
+
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
@@ -286,6 +282,7 @@ class ReceiptsStream(Stream):
class PushRulesStream(Stream):
"""A user has changed their push rules
"""
+
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
@@ -306,6 +303,7 @@ class PushRulesStream(Stream):
class PushersStream(Stream):
"""A user has added/changed/removed a pusher
"""
+
NAME = "pushers"
ROW_TYPE = PushersStreamRow
@@ -322,6 +320,7 @@ class CachesStream(Stream):
"""A cache was invalidated on the master and no other stream would invalidate
the cache on the workers
"""
+
NAME = "caches"
ROW_TYPE = CachesStreamRow
@@ -337,6 +336,7 @@ class CachesStream(Stream):
class PublicRoomsStream(Stream):
"""The public rooms list changed
"""
+
NAME = "public_rooms"
ROW_TYPE = PublicRoomsStreamRow
@@ -352,6 +352,7 @@ class PublicRoomsStream(Stream):
class DeviceListsStream(Stream):
"""Someone added/changed/removed a device
"""
+
NAME = "device_lists"
_LIMITED = False
ROW_TYPE = DeviceListsStreamRow
@@ -368,6 +369,7 @@ class DeviceListsStream(Stream):
class ToDeviceStream(Stream):
"""New to_device messages for a client
"""
+
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
@@ -383,6 +385,7 @@ class ToDeviceStream(Stream):
class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room
"""
+
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
@@ -398,6 +401,7 @@ class TagAccountDataStream(Stream):
class AccountDataStream(Stream):
"""Global or per room account data was changed
"""
+
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
@@ -416,7 +420,7 @@ class AccountDataStream(Stream):
results = list(room_results)
results.extend(
- (stream_id, user_id, None, account_data_type, content,)
+ (stream_id, user_id, None, account_data_type, content)
for stream_id, user_id, account_data_type, content in global_results
)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f1290d022a..3d0694bb11 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -52,6 +52,7 @@ data part are:
@attr.s(slots=True, frozen=True)
class EventsStreamRow(object):
"""A parsed row from the events replication stream"""
+
type = attr.ib() # str: the TypeId of one of the *EventsStreamRows
data = attr.ib() # BaseEventsStreamRow
@@ -80,11 +81,11 @@ class BaseEventsStreamRow(object):
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib() # str
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str, optional
- redacts = attr.ib() # str, optional
+ event_id = attr.ib() # str
+ room_id = attr.ib() # str
+ type = attr.ib() # str
+ state_key = attr.ib() # str, optional
+ redacts = attr.ib() # str, optional
relates_to = attr.ib() # str, optional
@@ -92,24 +93,21 @@ class EventsStreamEventRow(BaseEventsStreamRow):
class EventsStreamCurrentStateRow(BaseEventsStreamRow):
TypeId = "state"
- room_id = attr.ib() # str
- type = attr.ib() # str
+ room_id = attr.ib() # str
+ type = attr.ib() # str
state_key = attr.ib() # str
- event_id = attr.ib() # str, optional
+ event_id = attr.ib() # str, optional
TypeToRow = {
- Row.TypeId: Row
- for Row in (
- EventsStreamEventRow,
- EventsStreamCurrentStateRow,
- )
+ Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
}
class EventsStream(Stream):
"""We received a new event, or an event went from being an outlier to not
"""
+
NAME = "events"
def __init__(self, hs):
@@ -121,19 +119,17 @@ class EventsStream(Stream):
@defer.inlineCallbacks
def update_function(self, from_token, current_token, limit=None):
event_rows = yield self._store.get_all_new_forward_event_rows(
- from_token, current_token, limit,
+ from_token, current_token, limit
)
event_updates = (
- (row[0], EventsStreamEventRow.TypeId, row[1:])
- for row in event_rows
+ (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
)
state_rows = yield self._store.get_all_updated_current_state_deltas(
from_token, current_token, limit
)
state_updates = (
- (row[0], EventsStreamCurrentStateRow.TypeId, row[1:])
- for row in state_rows
+ (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
)
all_updates = heapq.merge(event_updates, state_updates)
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 9aa43aa8d2..dc2484109d 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -17,16 +17,20 @@ from collections import namedtuple
from ._base import Stream
-FederationStreamRow = namedtuple("FederationStreamRow", (
- "type", # str, the type of data as defined in the BaseFederationRows
- "data", # dict, serialization of a federation.send_queue.BaseFederationRow
-))
+FederationStreamRow = namedtuple(
+ "FederationStreamRow",
+ (
+ "type", # str, the type of data as defined in the BaseFederationRows
+ "data", # dict, serialization of a federation.send_queue.BaseFederationRow
+ ),
+)
class FederationStream(Stream):
"""Data to be sent over federation. Only available when master has federation
sending disabled.
"""
+
NAME = "federation"
ROW_TYPE = FederationStreamRow
diff --git a/synapse/res/templates/password_reset_success.html b/synapse/res/templates/password_reset_success.html
index 7b6fa5e6f0..7324d66d1e 100644
--- a/synapse/res/templates/password_reset_success.html
+++ b/synapse/res/templates/password_reset_success.html
@@ -1,6 +1,6 @@
<html>
<head></head>
<body>
-<p>Your password was successfully reset. You may now close this window.</p>
+<p>Your email has now been validated, please return to your client to reset your password. You may now close this window.</p>
</body>
</html>
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index e6110ad9b1..1d20b96d03 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -66,6 +66,7 @@ class ClientRestResource(JsonResource):
* /_matrix/client/unstable
* etc
"""
+
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index d6c4dcdb18..9843a902c6 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -61,7 +61,7 @@ def historical_admin_path_patterns(path_regex):
"^/_synapse/admin/v1",
"^/_matrix/client/api/v1/admin",
"^/_matrix/client/unstable/admin",
- "^/_matrix/client/r0/admin"
+ "^/_matrix/client/r0/admin",
)
)
@@ -88,12 +88,12 @@ class UsersRestServlet(RestServlet):
class VersionServlet(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"), )
+ PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
def __init__(self, hs):
self.res = {
- 'server_version': get_version_string(synapse),
- 'python_version': platform.python_version(),
+ "server_version": get_version_string(synapse),
+ "python_version": platform.python_version(),
}
def on_GET(self, request):
@@ -107,6 +107,7 @@ class UserRegisterServlet(RestServlet):
nonces (dict[str, int]): The nonces that we will accept. A dict of
nonce to the time it was generated, in int seconds.
"""
+
PATTERNS = historical_admin_path_patterns("/register")
NONCE_TIMEOUT = 60
@@ -146,28 +147,24 @@ class UserRegisterServlet(RestServlet):
body = parse_json_object_from_request(request)
if "nonce" not in body:
- raise SynapseError(
- 400, "nonce must be specified", errcode=Codes.BAD_JSON,
- )
+ raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
nonce = body["nonce"]
if nonce not in self.nonces:
- raise SynapseError(
- 400, "unrecognised nonce",
- )
+ raise SynapseError(400, "unrecognised nonce")
# Delete the nonce, so it can't be reused, even if it's invalid
del self.nonces[nonce]
if "username" not in body:
raise SynapseError(
- 400, "username must be specified", errcode=Codes.BAD_JSON,
+ 400, "username must be specified", errcode=Codes.BAD_JSON
)
else:
if (
- not isinstance(body['username'], text_type)
- or len(body['username']) > 512
+ not isinstance(body["username"], text_type)
+ or len(body["username"]) > 512
):
raise SynapseError(400, "Invalid username")
@@ -177,12 +174,12 @@ class UserRegisterServlet(RestServlet):
if "password" not in body:
raise SynapseError(
- 400, "password must be specified", errcode=Codes.BAD_JSON,
+ 400, "password must be specified", errcode=Codes.BAD_JSON
)
else:
if (
- not isinstance(body['password'], text_type)
- or len(body['password']) > 512
+ not isinstance(body["password"], text_type)
+ or len(body["password"]) > 512
):
raise SynapseError(400, "Invalid password")
@@ -202,7 +199,7 @@ class UserRegisterServlet(RestServlet):
key=self.hs.config.registration_shared_secret.encode(),
digestmod=hashlib.sha1,
)
- want_mac.update(nonce.encode('utf8'))
+ want_mac.update(nonce.encode("utf8"))
want_mac.update(b"\x00")
want_mac.update(username)
want_mac.update(b"\x00")
@@ -211,13 +208,10 @@ class UserRegisterServlet(RestServlet):
want_mac.update(b"admin" if admin else b"notadmin")
if user_type:
want_mac.update(b"\x00")
- want_mac.update(user_type.encode('utf8'))
+ want_mac.update(user_type.encode("utf8"))
want_mac = want_mac.hexdigest()
- if not hmac.compare_digest(
- want_mac.encode('ascii'),
- got_mac.encode('ascii')
- ):
+ if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
raise SynapseError(403, "HMAC incorrect")
# Reuse the parts of RegisterRestServlet to reduce code duplication
@@ -226,7 +220,7 @@ class UserRegisterServlet(RestServlet):
register = RegisterRestServlet(self.hs)
(user_id, _) = yield register.registration_handler.register(
- localpart=body['username'].lower(),
+ localpart=body["username"].lower(),
password=body["password"],
admin=bool(admin),
generate_token=False,
@@ -308,7 +302,7 @@ class PurgeHistoryRestServlet(RestServlet):
# user can provide an event_id in the URL or the request body, or can
# provide a timestamp in the request body.
if event_id is None:
- event_id = body.get('purge_up_to_event_id')
+ event_id = body.get("purge_up_to_event_id")
if event_id is not None:
event = yield self.store.get_event(event_id)
@@ -318,44 +312,39 @@ class PurgeHistoryRestServlet(RestServlet):
token = yield self.store.get_topological_token_for_event(event_id)
- logger.info(
- "[purge] purging up to token %s (event_id %s)",
- token, event_id,
- )
- elif 'purge_up_to_ts' in body:
- ts = body['purge_up_to_ts']
+ logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
+ elif "purge_up_to_ts" in body:
+ ts = body["purge_up_to_ts"]
if not isinstance(ts, int):
raise SynapseError(
- 400, "purge_up_to_ts must be an int",
- errcode=Codes.BAD_JSON,
+ 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
)
- stream_ordering = (
- yield self.store.find_first_stream_ordering_after_ts(ts)
- )
+ stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts))
r = (
yield self.store.get_room_event_after_stream_ordering(
- room_id, stream_ordering,
+ room_id, stream_ordering
)
)
if not r:
logger.warn(
"[purge] purging events not possible: No event found "
"(received_ts %i => stream_ordering %i)",
- ts, stream_ordering,
+ ts,
+ stream_ordering,
)
raise SynapseError(
- 404,
- "there is no event to be purged",
- errcode=Codes.NOT_FOUND,
+ 404, "there is no event to be purged", errcode=Codes.NOT_FOUND
)
(stream, topo, _event_id) = r
token = "t%d-%d" % (topo, stream)
logger.info(
"[purge] purging up to token %s (received_ts %i => "
"stream_ordering %i)",
- token, ts, stream_ordering,
+ token,
+ ts,
+ stream_ordering,
)
else:
raise SynapseError(
@@ -365,13 +354,10 @@ class PurgeHistoryRestServlet(RestServlet):
)
purge_id = yield self.pagination_handler.start_purge_history(
- room_id, token,
- delete_local_events=delete_local_events,
+ room_id, token, delete_local_events=delete_local_events
)
- defer.returnValue((200, {
- "purge_id": purge_id,
- }))
+ defer.returnValue((200, {"purge_id": purge_id}))
class PurgeHistoryStatusRestServlet(RestServlet):
@@ -421,16 +407,14 @@ class DeactivateAccountRestServlet(RestServlet):
UserID.from_string(target_user_id)
result = yield self._deactivate_account_handler.deactivate_account(
- target_user_id, erase,
+ target_user_id, erase
)
if result:
id_server_unbind_result = "success"
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {
- "id_server_unbind_result": id_server_unbind_result,
- }))
+ defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
class ShutdownRoomRestServlet(RestServlet):
@@ -439,6 +423,7 @@ class ShutdownRoomRestServlet(RestServlet):
to a new room created by `new_room_user_id` and kicked users will be auto
joined to the new room.
"""
+
PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
DEFAULT_MESSAGE = (
@@ -474,9 +459,7 @@ class ShutdownRoomRestServlet(RestServlet):
config={
"preset": "public_chat",
"name": room_name,
- "power_level_content_override": {
- "users_default": -10,
- },
+ "power_level_content_override": {"users_default": -10},
},
ratelimit=False,
)
@@ -485,8 +468,7 @@ class ShutdownRoomRestServlet(RestServlet):
requester_user_id = requester.user.to_string()
logger.info(
- "Shutting down room %r, joining to new room: %r",
- room_id, new_room_id,
+ "Shutting down room %r, joining to new room: %r", room_id, new_room_id
)
# This will work even if the room is already blocked, but that is
@@ -529,7 +511,7 @@ class ShutdownRoomRestServlet(RestServlet):
kicked_users.append(user_id)
except Exception:
logger.exception(
- "Failed to leave old room and join new room for %r", user_id,
+ "Failed to leave old room and join new room for %r", user_id
)
failed_to_kick_users.append(user_id)
@@ -550,18 +532,24 @@ class ShutdownRoomRestServlet(RestServlet):
room_id, new_room_id, requester_user_id
)
- defer.returnValue((200, {
- "kicked_users": kicked_users,
- "failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
- "new_room_id": new_room_id,
- }))
+ defer.returnValue(
+ (
+ 200,
+ {
+ "kicked_users": kicked_users,
+ "failed_to_kick_users": failed_to_kick_users,
+ "local_aliases": aliases_for_room,
+ "new_room_id": new_room_id,
+ },
+ )
+ )
class QuarantineMediaInRoom(RestServlet):
"""Quarantines all media in a room so that no one can download it via
this server.
"""
+
PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
def __init__(self, hs):
@@ -574,7 +562,7 @@ class QuarantineMediaInRoom(RestServlet):
yield assert_user_is_admin(self.auth, requester.user)
num_quarantined = yield self.store.quarantine_media_ids_in_room(
- room_id, requester.user.to_string(),
+ room_id, requester.user.to_string()
)
defer.returnValue((200, {"num_quarantined": num_quarantined}))
@@ -583,6 +571,7 @@ class QuarantineMediaInRoom(RestServlet):
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
+
PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
@@ -613,7 +602,10 @@ class ResetPasswordRestServlet(RestServlet):
Returns:
200 OK with empty object if success otherwise an error.
"""
- PATTERNS = historical_admin_path_patterns("/reset_password/(?P<target_user_id>[^/]*)")
+
+ PATTERNS = historical_admin_path_patterns(
+ "/reset_password/(?P<target_user_id>[^/]*)"
+ )
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -633,7 +625,7 @@ class ResetPasswordRestServlet(RestServlet):
params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["new_password"])
- new_password = params['new_password']
+ new_password = params["new_password"]
yield self._set_password_handler.set_password(
target_user_id, new_password, requester
@@ -650,7 +642,10 @@ class GetUsersPaginatedRestServlet(RestServlet):
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
- PATTERNS = historical_admin_path_patterns("/users_paginate/(?P<target_user_id>[^/]*)")
+
+ PATTERNS = historical_admin_path_patterns(
+ "/users_paginate/(?P<target_user_id>[^/]*)"
+ )
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -676,9 +671,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
logger.info("limit: %s, start: %s", limit, start)
- ret = yield self.handlers.admin_handler.get_users_paginate(
- order, start, limit
- )
+ ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
defer.returnValue((200, ret))
@defer.inlineCallbacks
@@ -702,13 +695,11 @@ class GetUsersPaginatedRestServlet(RestServlet):
order = "name" # order by name in user table
params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["limit", "start"])
- limit = params['limit']
- start = params['start']
+ limit = params["limit"]
+ start = params["start"]
logger.info("limit: %s, start: %s", limit, start)
- ret = yield self.handlers.admin_handler.get_users_paginate(
- order, start, limit
- )
+ ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
defer.returnValue((200, ret))
@@ -722,6 +713,7 @@ class SearchUsersRestServlet(RestServlet):
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
+
PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
@@ -750,15 +742,14 @@ class SearchUsersRestServlet(RestServlet):
term = parse_string(request, "term", required=True)
logger.info("term: %s ", term)
- ret = yield self.handlers.admin_handler.search_users(
- term
- )
+ ret = yield self.handlers.admin_handler.search_users(term)
defer.returnValue((200, ret))
class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups
"""
+
PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
def __init__(self, hs):
@@ -800,15 +791,15 @@ class AccountValidityRenewServlet(RestServlet):
raise SynapseError(400, "Missing property 'user_id' in the request body")
expiration_ts = yield self.account_activity_handler.renew_account_for_user(
- body["user_id"], body.get("expiration_ts"),
+ body["user_id"],
+ body.get("expiration_ts"),
not body.get("enable_renewal_emails", True),
)
- res = {
- "expiration_ts": expiration_ts,
- }
+ res = {"expiration_ts": expiration_ts}
defer.returnValue((200, res))
+
########################################################################################
#
# please don't add more servlets here: this file is already long and unwieldy. Put
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index ae5aca9dac..ee66838a0d 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -46,6 +46,7 @@ class SendServerNoticeServlet(RestServlet):
"event_id": "$1895723857jgskldgujpious"
}
"""
+
def __init__(self, hs):
"""
Args:
@@ -58,15 +59,9 @@ class SendServerNoticeServlet(RestServlet):
def register(self, json_resource):
PATTERN = "^/_synapse/admin/v1/send_server_notice"
+ json_resource.register_paths("POST", (re.compile(PATTERN + "$"),), self.on_POST)
json_resource.register_paths(
- "POST",
- (re.compile(PATTERN + "$"), ),
- self.on_POST,
- )
- json_resource.register_paths(
- "PUT",
- (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$",), ),
- self.on_PUT,
+ "PUT", (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),), self.on_PUT
)
@defer.inlineCallbacks
@@ -96,5 +91,5 @@ class SendServerNoticeServlet(RestServlet):
def on_PUT(self, request, txn_id):
return self.txns.fetch_or_execute_request(
- request, self.on_POST, request, txn_id,
+ request, self.on_POST, request, txn_id
)
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 48c17f1b6d..36404b797d 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -26,7 +26,6 @@ CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
-
def __init__(self, hs):
self.hs = hs
self.auth = self.hs.get_auth()
@@ -53,7 +52,7 @@ class HttpTransactionCache(object):
str: A transaction key
"""
token = self.auth.get_access_token_from_request(request)
- return request.path.decode('utf8') + "/" + token
+ return request.path.decode("utf8") + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 0035182bb9..dd0d38ea5c 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -56,8 +56,9 @@ class ClientDirectoryServer(RestServlet):
content = parse_json_object_from_request(request)
if "room_id" not in content:
- raise SynapseError(400, 'Missing params: ["room_id"]',
- errcode=Codes.BAD_JSON)
+ raise SynapseError(
+ 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON
+ )
logger.debug("Got content: %s", content)
logger.debug("Got room name: %s", room_alias.to_string())
@@ -89,13 +90,11 @@ class ClientDirectoryServer(RestServlet):
try:
service = yield self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
- yield dir_handler.delete_appservice_association(
- service, room_alias
- )
+ yield dir_handler.delete_appservice_association(service, room_alias)
logger.info(
"Application service at %s deleted alias %s",
service.url,
- room_alias.to_string()
+ room_alias.to_string(),
)
defer.returnValue((200, {}))
except AuthError:
@@ -107,14 +106,10 @@ class ClientDirectoryServer(RestServlet):
room_alias = RoomAlias.from_string(room_alias)
- yield dir_handler.delete_association(
- requester, room_alias
- )
+ yield dir_handler.delete_association(requester, room_alias)
logger.info(
- "User %s deleted alias %s",
- user.to_string(),
- room_alias.to_string()
+ "User %s deleted alias %s", user.to_string(), room_alias.to_string()
)
defer.returnValue((200, {}))
@@ -135,9 +130,9 @@ class ClientDirectoryListServer(RestServlet):
if room is None:
raise NotFoundError("Unknown room")
- defer.returnValue((200, {
- "visibility": "public" if room["is_public"] else "private"
- }))
+ defer.returnValue(
+ (200, {"visibility": "public" if room["is_public"] else "private"})
+ )
@defer.inlineCallbacks
def on_PUT(self, request, room_id):
@@ -147,7 +142,7 @@ class ClientDirectoryListServer(RestServlet):
visibility = content.get("visibility", "public")
yield self.handlers.directory_handler.edit_published_room_list(
- requester, room_id, visibility,
+ requester, room_id, visibility
)
defer.returnValue((200, {}))
@@ -157,7 +152,7 @@ class ClientDirectoryListServer(RestServlet):
requester = yield self.auth.get_user_by_req(request)
yield self.handlers.directory_handler.edit_published_room_list(
- requester, room_id, "private",
+ requester, room_id, "private"
)
defer.returnValue((200, {}))
@@ -191,7 +186,7 @@ class ClientAppserviceDirectoryListServer(RestServlet):
)
yield self.handlers.directory_handler.edit_published_appservice_room_list(
- requester.app_service.id, network_id, room_id, visibility,
+ requester.app_service.id, network_id, room_id, visibility
)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 84ca36270b..d6de2b7360 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -38,17 +38,14 @@ class EventStreamRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(
- request,
- allow_guest=True,
- )
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest
room_id = None
if is_guest:
if b"room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param")
if b"room_id" in request.args:
- room_id = request.args[b"room_id"][0].decode('ascii')
+ room_id = request.args[b"room_id"][0].decode("ascii")
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 1a886cbbbf..a31d277935 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission):
to a typed object.
"""
if "user" in submission:
- submission["identifier"] = {
- "type": "m.id.user",
- "user": submission["user"],
- }
+ submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
del submission["user"]
if "medium" in submission and "address" in submission:
@@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier):
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
- return {
- "type": "m.id.thirdparty",
- "medium": "msisdn",
- "address": msisdn,
- }
+ return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
class LoginRestServlet(RestServlet):
@@ -124,9 +117,9 @@ class LoginRestServlet(RestServlet):
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- flows.extend((
- {"type": t} for t in self.auth_handler.get_supported_login_types()
- ))
+ flows.extend(
+ ({"type": t} for t in self.auth_handler.get_supported_login_types())
+ )
return (200, {"flows": flows})
@@ -136,7 +129,8 @@ class LoginRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
self._address_ratelimiter.ratelimit(
- request.getClientIP(), time_now_s=self.hs.clock.time(),
+ request.getClientIP(),
+ time_now_s=self.hs.clock.time(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
update=True,
@@ -144,8 +138,9 @@ class LoginRestServlet(RestServlet):
login_submission = parse_json_object_from_request(request)
try:
- if self.jwt_enabled and (login_submission["type"] ==
- LoginRestServlet.JWT_TYPE):
+ if self.jwt_enabled and (
+ login_submission["type"] == LoginRestServlet.JWT_TYPE
+ ):
result = yield self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
@@ -174,10 +169,10 @@ class LoginRestServlet(RestServlet):
# field)
logger.info(
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
- login_submission.get('identifier'),
- login_submission.get('medium'),
- login_submission.get('address'),
- login_submission.get('user'),
+ login_submission.get("identifier"),
+ login_submission.get("medium"),
+ login_submission.get("address"),
+ login_submission.get("user"),
)
login_submission_legacy_convert(login_submission)
@@ -194,13 +189,13 @@ class LoginRestServlet(RestServlet):
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
- address = identifier.get('address')
- medium = identifier.get('medium')
+ address = identifier.get("address")
+ medium = identifier.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
- if medium == 'email':
+ if medium == "email":
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
@@ -209,34 +204,28 @@ class LoginRestServlet(RestServlet):
# Check for login providers that support 3pid login types
canonical_user_id, callback_3pid = (
yield self.auth_handler.check_password_provider_3pid(
- medium,
- address,
- login_submission["password"],
+ medium, address, login_submission["password"]
)
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
result = yield self._register_device_with_callback(
- canonical_user_id, login_submission, callback_3pid,
+ canonical_user_id, login_submission, callback_3pid
)
defer.returnValue(result)
# No password providers were able to handle this 3pid
# Check local store
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- medium, address,
+ medium, address
)
if not user_id:
logger.warn(
- "unknown 3pid identifier medium %s, address %r",
- medium, address,
+ "unknown 3pid identifier medium %s, address %r", medium, address
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
- identifier = {
- "type": "m.id.user",
- "user": user_id,
- }
+ identifier = {"type": "m.id.user", "user": user_id}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
@@ -246,22 +235,16 @@ class LoginRestServlet(RestServlet):
raise SynapseError(400, "User identifier is missing 'user' key")
canonical_user_id, callback = yield self.auth_handler.validate_login(
- identifier["user"],
- login_submission,
+ identifier["user"], login_submission
)
result = yield self._register_device_with_callback(
- canonical_user_id, login_submission, callback,
+ canonical_user_id, login_submission, callback
)
defer.returnValue(result)
@defer.inlineCallbacks
- def _register_device_with_callback(
- self,
- user_id,
- login_submission,
- callback=None,
- ):
+ def _register_device_with_callback(self, user_id, login_submission, callback=None):
""" Registers a device with a given user_id. Optionally run a callback
function after registration has completed.
@@ -277,7 +260,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name,
+ user_id, device_id, initial_display_name
)
result = {
@@ -294,7 +277,7 @@ class LoginRestServlet(RestServlet):
@defer.inlineCallbacks
def do_token_login(self, login_submission):
- token = login_submission['token']
+ token = login_submission["token"]
auth_handler = self.auth_handler
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
@@ -303,7 +286,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name,
+ user_id, device_id, initial_display_name
)
result = {
@@ -320,15 +303,16 @@ class LoginRestServlet(RestServlet):
token = login_submission.get("token", None)
if token is None:
raise LoginError(
- 401, "Token field for JWT is missing",
- errcode=Codes.UNAUTHORIZED
+ 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
)
import jwt
from jwt.exceptions import InvalidTokenError
try:
- payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
+ payload = jwt.decode(
+ token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+ )
except jwt.ExpiredSignatureError:
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
except InvalidTokenError:
@@ -346,7 +330,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- registered_user_id, device_id, initial_display_name,
+ registered_user_id, device_id, initial_display_name
)
result = {
@@ -362,7 +346,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- registered_user_id, device_id, initial_display_name,
+ registered_user_id, device_id, initial_display_name
)
result = {
@@ -376,6 +360,7 @@ class LoginRestServlet(RestServlet):
class BaseSsoRedirectServlet(RestServlet):
"""Common base class for /login/sso/redirect impls"""
+
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def on_GET(self, request):
@@ -401,21 +386,20 @@ class BaseSsoRedirectServlet(RestServlet):
raise NotImplementedError()
-class CasRedirectServlet(RestServlet):
+class CasRedirectServlet(BaseSsoRedirectServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url.encode('ascii')
- self.cas_service_url = hs.config.cas_service_url.encode('ascii')
+ self.cas_server_url = hs.config.cas_server_url.encode("ascii")
+ self.cas_service_url = hs.config.cas_service_url.encode("ascii")
def get_sso_url(self, client_redirect_url):
- client_redirect_url_param = urllib.parse.urlencode({
- b"redirectUrl": client_redirect_url
- }).encode('ascii')
- hs_redirect_url = (self.cas_service_url +
- b"/_matrix/client/r0/login/cas/ticket")
- service_param = urllib.parse.urlencode({
- b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
- }).encode('ascii')
+ client_redirect_url_param = urllib.parse.urlencode(
+ {b"redirectUrl": client_redirect_url}
+ ).encode("ascii")
+ hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
+ service_param = urllib.parse.urlencode(
+ {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
+ ).encode("ascii")
return b"%s/login?%s" % (self.cas_server_url, service_param)
@@ -436,7 +420,7 @@ class CasTicketServlet(RestServlet):
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
- "service": self.cas_service_url
+ "service": self.cas_service_url,
}
try:
body = yield self._http_client.get_raw(uri, args)
@@ -463,7 +447,7 @@ class CasTicketServlet(RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth(
- user, request, client_redirect_url,
+ user, request, client_redirect_url
)
def parse_cas_response(self, cas_response_body):
@@ -473,7 +457,7 @@ class CasTicketServlet(RestServlet):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
- success = (root[0].tag.endswith("authenticationSuccess"))
+ success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@@ -491,11 +475,11 @@ class CasTicketServlet(RestServlet):
raise Exception("CAS response does not contain user")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
- raise LoginError(401, "Invalid CAS response",
- errcode=Codes.UNAUTHORIZED)
+ raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
- raise LoginError(401, "Unsuccessful CAS response",
- errcode=Codes.UNAUTHORIZED)
+ raise LoginError(
+ 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+ )
return user, attributes
@@ -507,11 +491,11 @@ class SAMLRedirectServlet(BaseSsoRedirectServlet):
def get_sso_url(self, client_redirect_url):
reqid, info = self._saml_client.prepare_for_authenticate(
- relay_state=client_redirect_url,
+ relay_state=client_redirect_url
)
- for key, value in info['headers']:
- if key == 'Location':
+ for key, value in info["headers"]:
+ if key == "Location":
return value
# this shouldn't happen!
@@ -526,6 +510,7 @@ class SSOAuthHandler(object):
Args:
hs (synapse.server.HomeServer)
"""
+
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
@@ -534,8 +519,7 @@ class SSOAuthHandler(object):
@defer.inlineCallbacks
def on_successful_auth(
- self, username, request, client_redirect_url,
- user_display_name=None,
+ self, username, request, client_redirect_url, user_display_name=None
):
"""Called once the user has successfully authenticated with the SSO.
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index b8064f261e..cd711be519 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -46,7 +46,8 @@ class LogoutRestServlet(RestServlet):
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(
- requester.user.to_string(), requester.device_id)
+ requester.user.to_string(), requester.device_id
+ )
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index e263da3cb7..3e87f0fdb3 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -47,7 +47,7 @@ class PresenceStatusRestServlet(RestServlet):
if requester.user != user:
allowed = yield self.presence_handler.is_visible(
- observed_user=user, observer_user=requester.user,
+ observed_user=user, observer_user=requester.user
)
if not allowed:
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index e15d9d82a6..4d8ab1f47e 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -63,8 +63,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
except Exception:
defer.returnValue((400, "Unable to parse name"))
- yield self.profile_handler.set_displayname(
- user, requester, new_name, is_admin)
+ yield self.profile_handler.set_displayname(user, requester, new_name, is_admin)
defer.returnValue((200, {}))
@@ -113,8 +112,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
except Exception:
defer.returnValue((400, "Unable to parse name"))
- yield self.profile_handler.set_avatar_url(
- user, requester, new_name, is_admin)
+ yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 3d6326fe2f..e635efb420 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -21,7 +21,11 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
-from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string
+from synapse.http.servlet import (
+ RestServlet,
+ parse_json_value_from_request,
+ parse_string,
+)
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
@@ -32,7 +36,8 @@ from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundExc
class PushRuleRestServlet(RestServlet):
PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
- "Unrecognised request: You probably wanted a trailing slash")
+ "Unrecognised request: You probably wanted a trailing slash"
+ )
def __init__(self, hs):
super(PushRuleRestServlet, self).__init__()
@@ -54,27 +59,25 @@ class PushRuleRestServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request)
- if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
+ if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
raise SynapseError(400, "rule_id may not contain slashes")
content = parse_json_value_from_request(request)
user_id = requester.user.to_string()
- if 'attr' in spec:
+ if "attr" in spec:
yield self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
defer.returnValue((200, {}))
- if spec['rule_id'].startswith('.'):
+ if spec["rule_id"].startswith("."):
# Rule ids starting with '.' are reserved for server default rules.
raise SynapseError(400, "cannot add new rule_ids that start with '.'")
try:
(conditions, actions) = _rule_tuple_from_request_object(
- spec['template'],
- spec['rule_id'],
- content,
+ spec["template"], spec["rule_id"], content
)
except InvalidRuleException as e:
raise SynapseError(400, str(e))
@@ -95,7 +98,7 @@ class PushRuleRestServlet(RestServlet):
conditions=conditions,
actions=actions,
before=before,
- after=after
+ after=after,
)
self.notify_user(user_id)
except InconsistentRuleException as e:
@@ -118,9 +121,7 @@ class PushRuleRestServlet(RestServlet):
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try:
- yield self.store.delete_push_rule(
- user_id, namespaced_rule_id
- )
+ yield self.store.delete_push_rule(user_id, namespaced_rule_id)
self.notify_user(user_id)
defer.returnValue((200, {}))
except StoreError as e:
@@ -149,10 +150,10 @@ class PushRuleRestServlet(RestServlet):
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
- if path[0] == '':
+ if path[0] == "":
defer.returnValue((200, rules))
- elif path[0] == 'global':
- result = _filter_ruleset_with_path(rules['global'], path[1:])
+ elif path[0] == "global":
+ result = _filter_ruleset_with_path(rules["global"], path[1:])
defer.returnValue((200, result))
else:
raise UnrecognizedRequestError()
@@ -162,12 +163,10 @@ class PushRuleRestServlet(RestServlet):
def notify_user(self, user_id):
stream_id, _ = self.store.get_push_rules_stream_token()
- self.notifier.on_new_event(
- "push_rules_key", stream_id, users=[user_id]
- )
+ self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
def set_rule_attr(self, user_id, spec, val):
- if spec['attr'] == 'enabled':
+ if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
if not isinstance(val, bool):
@@ -176,14 +175,12 @@ class PushRuleRestServlet(RestServlet):
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- return self.store.set_push_rule_enabled(
- user_id, namespaced_rule_id, val
- )
- elif spec['attr'] == 'actions':
- actions = val.get('actions')
+ return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
+ elif spec["attr"] == "actions":
+ actions = val.get("actions")
_check_actions(actions)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- rule_id = spec['rule_id']
+ rule_id = spec["rule_id"]
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
@@ -210,12 +207,12 @@ def _rule_spec_from_path(path):
"""
if len(path) < 2:
raise UnrecognizedRequestError()
- if path[0] != 'pushrules':
+ if path[0] != "pushrules":
raise UnrecognizedRequestError()
scope = path[1]
path = path[2:]
- if scope != 'global':
+ if scope != "global":
raise UnrecognizedRequestError()
if len(path) == 0:
@@ -229,56 +226,40 @@ def _rule_spec_from_path(path):
rule_id = path[0]
- spec = {
- 'scope': scope,
- 'template': template,
- 'rule_id': rule_id
- }
+ spec = {"scope": scope, "template": template, "rule_id": rule_id}
path = path[1:]
if len(path) > 0 and len(path[0]) > 0:
- spec['attr'] = path[0]
+ spec["attr"] = path[0]
return spec
def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
- if rule_template in ['override', 'underride']:
- if 'conditions' not in req_obj:
+ if rule_template in ["override", "underride"]:
+ if "conditions" not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
- conditions = req_obj['conditions']
+ conditions = req_obj["conditions"]
for c in conditions:
- if 'kind' not in c:
+ if "kind" not in c:
raise InvalidRuleException("Condition without 'kind'")
- elif rule_template == 'room':
- conditions = [{
- 'kind': 'event_match',
- 'key': 'room_id',
- 'pattern': rule_id
- }]
- elif rule_template == 'sender':
- conditions = [{
- 'kind': 'event_match',
- 'key': 'user_id',
- 'pattern': rule_id
- }]
- elif rule_template == 'content':
- if 'pattern' not in req_obj:
+ elif rule_template == "room":
+ conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}]
+ elif rule_template == "sender":
+ conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}]
+ elif rule_template == "content":
+ if "pattern" not in req_obj:
raise InvalidRuleException("Content rule missing 'pattern'")
- pat = req_obj['pattern']
+ pat = req_obj["pattern"]
- conditions = [{
- 'kind': 'event_match',
- 'key': 'content.body',
- 'pattern': pat
- }]
+ conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}]
else:
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
- if 'actions' not in req_obj:
+ if "actions" not in req_obj:
raise InvalidRuleException("No actions found")
- actions = req_obj['actions']
+ actions = req_obj["actions"]
_check_actions(actions)
@@ -290,9 +271,9 @@ def _check_actions(actions):
raise InvalidRuleException("No actions found")
for a in actions:
- if a in ['notify', 'dont_notify', 'coalesce']:
+ if a in ["notify", "dont_notify", "coalesce"]:
pass
- elif isinstance(a, dict) and 'set_tweak' in a:
+ elif isinstance(a, dict) and "set_tweak" in a:
pass
else:
raise InvalidRuleException("Unrecognised action")
@@ -304,7 +285,7 @@ def _filter_ruleset_with_path(ruleset, path):
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
- if path[0] == '':
+ if path[0] == "":
return ruleset
template_kind = path[0]
if template_kind not in ruleset:
@@ -314,13 +295,13 @@ def _filter_ruleset_with_path(ruleset, path):
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
- if path[0] == '':
+ if path[0] == "":
return ruleset[template_kind]
rule_id = path[0]
the_rule = None
for r in ruleset[template_kind]:
- if r['rule_id'] == rule_id:
+ if r["rule_id"] == rule_id:
the_rule = r
if the_rule is None:
raise NotFoundError
@@ -339,19 +320,19 @@ def _filter_ruleset_with_path(ruleset, path):
def _priority_class_from_spec(spec):
- if spec['template'] not in PRIORITY_CLASS_MAP.keys():
- raise InvalidRuleException("Unknown template: %s" % (spec['template']))
- pc = PRIORITY_CLASS_MAP[spec['template']]
+ if spec["template"] not in PRIORITY_CLASS_MAP.keys():
+ raise InvalidRuleException("Unknown template: %s" % (spec["template"]))
+ pc = PRIORITY_CLASS_MAP[spec["template"]]
return pc
def _namespaced_rule_id_from_spec(spec):
- return _namespaced_rule_id(spec, spec['rule_id'])
+ return _namespaced_rule_id(spec, spec["rule_id"])
def _namespaced_rule_id(spec, rule_id):
- return "global/%s/%s" % (spec['template'], rule_id)
+ return "global/%s/%s" % (spec["template"], rule_id)
class InvalidRuleException(Exception):
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 15d860db37..e9246018df 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -44,9 +44,7 @@ class PushersRestServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request)
user = requester.user
- pushers = yield self.hs.get_datastore().get_pushers_by_user_id(
- user.to_string()
- )
+ pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
allowed_keys = [
"app_display_name",
@@ -87,50 +85,61 @@ class PushersSetRestServlet(RestServlet):
content = parse_json_object_from_request(request)
- if ('pushkey' in content and 'app_id' in content
- and 'kind' in content and
- content['kind'] is None):
+ if (
+ "pushkey" in content
+ and "app_id" in content
+ and "kind" in content
+ and content["kind"] is None
+ ):
yield self.pusher_pool.remove_pusher(
- content['app_id'], content['pushkey'], user_id=user.to_string()
+ content["app_id"], content["pushkey"], user_id=user.to_string()
)
defer.returnValue((200, {}))
assert_params_in_dict(
content,
- ['kind', 'app_id', 'app_display_name',
- 'device_display_name', 'pushkey', 'lang', 'data']
+ [
+ "kind",
+ "app_id",
+ "app_display_name",
+ "device_display_name",
+ "pushkey",
+ "lang",
+ "data",
+ ],
)
- logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
+ logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"])
logger.debug("Got pushers request with body: %r", content)
append = False
- if 'append' in content:
- append = content['append']
+ if "append" in content:
+ append = content["append"]
if not append:
yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
- app_id=content['app_id'],
- pushkey=content['pushkey'],
- not_user_id=user.to_string()
+ app_id=content["app_id"],
+ pushkey=content["pushkey"],
+ not_user_id=user.to_string(),
)
try:
yield self.pusher_pool.add_pusher(
user_id=user.to_string(),
access_token=requester.access_token_id,
- kind=content['kind'],
- app_id=content['app_id'],
- app_display_name=content['app_display_name'],
- device_display_name=content['device_display_name'],
- pushkey=content['pushkey'],
- lang=content['lang'],
- data=content['data'],
- profile_tag=content.get('profile_tag', ""),
+ kind=content["kind"],
+ app_id=content["app_id"],
+ app_display_name=content["app_display_name"],
+ device_display_name=content["device_display_name"],
+ pushkey=content["pushkey"],
+ lang=content["lang"],
+ data=content["data"],
+ profile_tag=content.get("profile_tag", ""),
)
except PusherConfigException as pce:
- raise SynapseError(400, "Config Error: " + str(pce),
- errcode=Codes.MISSING_PARAM)
+ raise SynapseError(
+ 400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM
+ )
self.notifier.on_new_replication_data()
@@ -144,6 +153,7 @@ class PushersRemoveRestServlet(RestServlet):
"""
To allow pusher to be delete by clicking a link (ie. GET request)
"""
+
PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
@@ -164,9 +174,7 @@ class PushersRemoveRestServlet(RestServlet):
try:
yield self.pusher_pool.remove_pusher(
- app_id=app_id,
- pushkey=pushkey,
- user_id=user.to_string(),
+ app_id=app_id, pushkey=pushkey, user_id=user.to_string()
)
except StoreError as se:
if se.code != 404:
@@ -177,9 +185,9 @@ class PushersRemoveRestServlet(RestServlet):
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (
- len(PushersRemoveRestServlet.SUCCESS_HTML),
- ))
+ request.setHeader(
+ b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),)
+ )
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
finish_request(request)
defer.returnValue(None)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index e8f672c4ba..cca7e45ddb 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -61,18 +61,16 @@ class RoomCreateRestServlet(TransactionRestServlet):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
- http_server.register_paths("OPTIONS",
- client_patterns("/rooms(?:/.*)?$", v1=True),
- self.on_OPTIONS)
+ http_server.register_paths(
+ "OPTIONS", client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS
+ )
# define CORS for /createRoom[/txnid]
- http_server.register_paths("OPTIONS",
- client_patterns("/createRoom(?:/.*)?$", v1=True),
- self.on_OPTIONS)
+ http_server.register_paths(
+ "OPTIONS", client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS
+ )
def on_PUT(self, request, txn_id):
- return self.txns.fetch_or_execute_request(
- request, self.on_POST, request
- )
+ return self.txns.fetch_or_execute_request(request, self.on_POST, request)
@defer.inlineCallbacks
def on_POST(self, request):
@@ -107,21 +105,23 @@ class RoomStateEventRestServlet(TransactionRestServlet):
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
# /room/$roomid/state/$eventtype/$statekey
- state_key = ("/rooms/(?P<room_id>[^/]*)/state/"
- "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
-
- http_server.register_paths("GET",
- client_patterns(state_key, v1=True),
- self.on_GET)
- http_server.register_paths("PUT",
- client_patterns(state_key, v1=True),
- self.on_PUT)
- http_server.register_paths("GET",
- client_patterns(no_state_key, v1=True),
- self.on_GET_no_state_key)
- http_server.register_paths("PUT",
- client_patterns(no_state_key, v1=True),
- self.on_PUT_no_state_key)
+ state_key = (
+ "/rooms/(?P<room_id>[^/]*)/state/"
+ "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$"
+ )
+
+ http_server.register_paths(
+ "GET", client_patterns(state_key, v1=True), self.on_GET
+ )
+ http_server.register_paths(
+ "PUT", client_patterns(state_key, v1=True), self.on_PUT
+ )
+ http_server.register_paths(
+ "GET", client_patterns(no_state_key, v1=True), self.on_GET_no_state_key
+ )
+ http_server.register_paths(
+ "PUT", client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key
+ )
def on_GET_no_state_key(self, request, room_id, event_type):
return self.on_GET(request, room_id, event_type, "")
@@ -132,8 +132,9 @@ class RoomStateEventRestServlet(TransactionRestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_type, state_key):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- format = parse_string(request, "format", default="content",
- allowed_values=["content", "event"])
+ format = parse_string(
+ request, "format", default="content", allowed_values=["content", "event"]
+ )
msg_handler = self.message_handler
data = yield msg_handler.get_room_data(
@@ -145,9 +146,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
)
if not data:
- raise SynapseError(
- 404, "Event not found.", errcode=Codes.NOT_FOUND
- )
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
if format == "event":
event = format_event_for_client_v2(data.get_dict())
@@ -182,9 +181,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
)
else:
event = yield self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- event_dict,
- txn_id=txn_id,
+ requester, event_dict, txn_id=txn_id
)
ret = {}
@@ -195,7 +192,6 @@ class RoomStateEventRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
-
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
@@ -203,7 +199,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
- PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
+ PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
@defer.inlineCallbacks
@@ -218,13 +214,11 @@ class RoomSendEventRestServlet(TransactionRestServlet):
"sender": requester.user.to_string(),
}
- if b'ts' in request.args and requester.app_service:
- event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+ if b"ts" in request.args and requester.app_service:
+ event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
event = yield self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- event_dict,
- txn_id=txn_id,
+ requester, event_dict, txn_id=txn_id
)
defer.returnValue((200, {"event_id": event.event_id}))
@@ -247,15 +241,12 @@ class JoinRoomAliasServlet(TransactionRestServlet):
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
- PATTERNS = ("/join/(?P<room_identifier>[^/]*)")
+ PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_identifier, txn_id=None):
- requester = yield self.auth.get_user_by_req(
- request,
- allow_guest=True,
- )
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
try:
content = parse_json_object_from_request(request)
@@ -268,7 +259,7 @@ class JoinRoomAliasServlet(TransactionRestServlet):
room_id = room_identifier
try:
remote_room_hosts = [
- x.decode('ascii') for x in request.args[b"server_name"]
+ x.decode("ascii") for x in request.args[b"server_name"]
]
except Exception:
remote_room_hosts = None
@@ -278,9 +269,9 @@ class JoinRoomAliasServlet(TransactionRestServlet):
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
- raise SynapseError(400, "%s was not legal room ID or room alias" % (
- room_identifier,
- ))
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
yield self.room_member_handler.update_membership(
requester=requester,
@@ -320,7 +311,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
# Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private
# federations.
- if self.hs.config.restrict_public_rooms_to_local_users:
+ if not self.hs.config.allow_public_rooms_without_auth:
raise
# We allow people to not be authed if they're just looking at our
@@ -339,14 +330,11 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server:
data = yield handler.get_remote_public_room_list(
- server,
- limit=limit,
- since_token=since_token,
+ server, limit=limit, since_token=since_token
)
else:
data = yield handler.get_local_public_room_list(
- limit=limit,
- since_token=since_token,
+ limit=limit, since_token=since_token
)
defer.returnValue((200, data))
@@ -439,16 +427,13 @@ class RoomMemberListRestServlet(RestServlet):
chunk = []
for event in events:
- if (
- (membership and event['content'].get("membership") != membership) or
- (not_membership and event['content'].get("membership") == not_membership)
+ if (membership and event["content"].get("membership") != membership) or (
+ not_membership and event["content"].get("membership") == not_membership
):
continue
chunk.append(event)
- defer.returnValue((200, {
- "chunk": chunk
- }))
+ defer.returnValue((200, {"chunk": chunk}))
# deprecated in favour of /members?membership=join?
@@ -466,12 +451,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request)
users_with_profile = yield self.message_handler.get_joined_members(
- requester, room_id,
+ requester, room_id
)
- defer.returnValue((200, {
- "joined": users_with_profile,
- }))
+ defer.returnValue((200, {"joined": users_with_profile}))
# TODO: Needs better unit testing
@@ -486,9 +469,7 @@ class RoomMessageListRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = PaginationConfig.from_request(
- request, default_limit=10,
- )
+ pagination_config = PaginationConfig.from_request(request, default_limit=10)
as_client_event = b"raw" not in request.args
filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes:
@@ -544,9 +525,7 @@ class RoomInitialSyncRestServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request)
content = yield self.initial_sync_handler.room_initial_sync(
- room_id=room_id,
- requester=requester,
- pagin_config=pagination_config,
+ room_id=room_id, requester=requester, pagin_config=pagination_config
)
defer.returnValue((200, content))
@@ -603,30 +582,24 @@ class RoomEventContextServlet(RestServlet):
event_filter = None
results = yield self.room_context_handler.get_event_context(
- requester.user,
- room_id,
- event_id,
- limit,
- event_filter,
+ requester.user, room_id, event_id, limit, event_filter
)
if not results:
- raise SynapseError(
- 404, "Event not found.", errcode=Codes.NOT_FOUND
- )
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
results["events_before"] = yield self._event_serializer.serialize_events(
- results["events_before"], time_now,
+ results["events_before"], time_now
)
results["event"] = yield self._event_serializer.serialize_event(
- results["event"], time_now,
+ results["event"], time_now
)
results["events_after"] = yield self._event_serializer.serialize_events(
- results["events_after"], time_now,
+ results["events_after"], time_now
)
results["state"] = yield self._event_serializer.serialize_events(
- results["state"], time_now,
+ results["state"], time_now
)
defer.returnValue((200, results))
@@ -639,20 +612,14 @@ class RoomForgetRestServlet(TransactionRestServlet):
self.auth = hs.get_auth()
def register(self, http_server):
- PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
+ PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, txn_id=None):
- requester = yield self.auth.get_user_by_req(
- request,
- allow_guest=False,
- )
+ requester = yield self.auth.get_user_by_req(request, allow_guest=False)
- yield self.room_member_handler.forget(
- user=requester.user,
- room_id=room_id,
- )
+ yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
defer.returnValue((200, {}))
@@ -664,7 +631,6 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
-
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
@@ -672,20 +638,19 @@ class RoomMembershipRestServlet(TransactionRestServlet):
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
- PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
- "(?P<membership_action>join|invite|leave|ban|unban|kick)")
+ PATTERNS = (
+ "/rooms/(?P<room_id>[^/]*)/"
+ "(?P<membership_action>join|invite|leave|ban|unban|kick)"
+ )
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, membership_action, txn_id=None):
- requester = yield self.auth.get_user_by_req(
- request,
- allow_guest=True,
- )
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
Membership.JOIN,
- Membership.LEAVE
+ Membership.LEAVE,
}:
raise AuthError(403, "Guest access not allowed")
@@ -704,7 +669,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["address"],
content["id_server"],
requester,
- txn_id
+ txn_id,
)
defer.returnValue((200, {}))
return
@@ -715,8 +680,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
target = UserID.from_string(content["user_id"])
event_content = None
- if 'reason' in content and membership_action in ['kick', 'ban']:
- event_content = {'reason': content['reason']}
+ if "reason" in content and membership_action in ["kick", "ban"]:
+ event_content = {"reason": content["reason"]}
yield self.room_member_handler.update_membership(
requester=requester,
@@ -755,7 +720,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
self.auth = hs.get_auth()
def register(self, http_server):
- PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
+ PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
@@ -817,9 +782,7 @@ class RoomTypingRestServlet(RestServlet):
)
else:
yield self.typing_handler.stopped_typing(
- target_user=target_user,
- auth_user=requester.user,
- room_id=room_id,
+ target_user=target_user, auth_user=requester.user, room_id=room_id
)
defer.returnValue((200, {}))
@@ -841,9 +804,7 @@ class SearchRestServlet(RestServlet):
batch = parse_string(request, "next_batch")
results = yield self.handlers.search_handler.search(
- requester.user,
- content,
- batch,
+ requester.user, content, batch
)
defer.returnValue((200, results))
@@ -879,20 +840,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
with_get: True to also register respective GET paths for the PUTs.
"""
http_server.register_paths(
- "POST",
- client_patterns(regex_string + "$", v1=True),
- servlet.on_POST
+ "POST", client_patterns(regex_string + "$", v1=True), servlet.on_POST
)
http_server.register_paths(
"PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_PUT
+ servlet.on_PUT,
)
if with_get:
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_GET
+ servlet.on_GET,
)
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 6381049210..41b3171ac8 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -34,8 +34,7 @@ class VoipRestServlet(RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(
- request,
- self.hs.config.turn_allow_guests
+ request, self.hs.config.turn_allow_guests
)
turnUris = self.hs.config.turn_uris
@@ -49,9 +48,7 @@ class VoipRestServlet(RestServlet):
username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(
- turnSecret.encode(),
- msg=username.encode(),
- digestmod=hashlib.sha1
+ turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1
)
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
@@ -65,12 +62,17 @@ class VoipRestServlet(RestServlet):
else:
defer.returnValue((200, {}))
- defer.returnValue((200, {
- 'username': username,
- 'password': password,
- 'ttl': userLifetime / 1000,
- 'uris': turnUris,
- }))
+ defer.returnValue(
+ (
+ 200,
+ {
+ "username": username,
+ "password": password,
+ "ttl": userLifetime / 1000,
+ "uris": turnUris,
+ },
+ )
+ )
def on_OPTIONS(self, request):
return (200, {})
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 5236d5d566..e3d59ac3ac 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -52,11 +52,11 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
def set_timeline_upper_limit(filter_json, filter_timeline_limit):
if filter_timeline_limit < 0:
return # no upper limits
- timeline = filter_json.get('room', {}).get('timeline', {})
- if 'limit' in timeline:
- filter_json['room']['timeline']["limit"] = min(
- filter_json['room']['timeline']['limit'],
- filter_timeline_limit)
+ timeline = filter_json.get("room", {}).get("timeline", {})
+ if "limit" in timeline:
+ filter_json["room"]["timeline"]["limit"] = min(
+ filter_json["room"]["timeline"]["limit"], filter_timeline_limit
+ )
def interactive_auth_handler(orig):
@@ -74,10 +74,12 @@ def interactive_auth_handler(orig):
# ...
yield self.auth_handler.check_auth
"""
+
def wrapped(*args, **kwargs):
res = defer.maybeDeferred(orig, *args, **kwargs)
res.addErrback(_catch_incomplete_interactive_auth)
return res
+
return wrapped
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index e4c63b69b9..f143d8b85c 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -15,7 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import re
from six.moves import http_client
@@ -53,6 +52,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if self.config.email_password_reset_behaviour == "local":
from synapse.push.mailer import Mailer, load_jinja2_templates
+
templates = load_jinja2_templates(
config=hs.config,
template_html_name=hs.config.email_password_reset_template_html,
@@ -68,13 +68,17 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
if self.config.email_password_reset_behaviour == "off":
- raise SynapseError(400, "Password resets have been disabled on this server")
+ if self.config.password_resets_were_disabled_due_to_email_config:
+ logger.warn(
+ "User password resets have been disabled due to lack of email config"
+ )
+ raise SynapseError(
+ 400, "Email-based password resets have been disabled on this server"
+ )
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'client_secret', 'email', 'send_attempt'
- ])
+ assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
# Extract params from body
client_secret = body["client_secret"]
@@ -90,24 +94,24 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
- 'email', email,
+ "email", email
)
if existingUid is None:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
if self.config.email_password_reset_behaviour == "remote":
- if 'id_server' not in body:
+ if "id_server" not in body:
raise SynapseError(400, "Missing 'id_server' param in body")
# Have the identity server handle the password reset flow
ret = yield self.identity_handler.requestEmailToken(
- body["id_server"], email, client_secret, send_attempt, next_link,
+ body["id_server"], email, client_secret, send_attempt, next_link
)
else:
# Send password reset emails from Synapse
sid = yield self.send_password_reset(
- email, client_secret, send_attempt, next_link,
+ email, client_secret, send_attempt, next_link
)
# Wrap the session id in a JSON object
@@ -116,13 +120,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
defer.returnValue((200, ret))
@defer.inlineCallbacks
- def send_password_reset(
- self,
- email,
- client_secret,
- send_attempt,
- next_link=None,
- ):
+ def send_password_reset(self, email, client_secret, send_attempt, next_link=None):
"""Send a password reset email
Args:
@@ -139,14 +137,14 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Check that this email/client_secret/send_attempt combo is new or
# greater than what we've seen previously
session = yield self.datastore.get_threepid_validation_session(
- "email", client_secret, address=email, validated=False,
+ "email", client_secret, address=email, validated=False
)
# Check to see if a session already exists and that it is not yet
# marked as validated
if session and session.get("validated_at") is None:
- session_id = session['session_id']
- last_send_attempt = session['last_send_attempt']
+ session_id = session["session_id"]
+ last_send_attempt = session["last_send_attempt"]
# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
@@ -164,22 +162,27 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# and session_id
try:
yield self.mailer.send_password_reset_mail(
- email, token, client_secret, session_id,
+ email, token, client_secret, session_id
)
except Exception:
- logger.exception(
- "Error sending a password reset email to %s", email,
- )
+ logger.exception("Error sending a password reset email to %s", email)
raise SynapseError(
500, "An error was encountered when sending the password reset email"
)
- token_expires = (self.hs.clock.time_msec() +
- self.config.email_validation_token_lifetime)
+ token_expires = (
+ self.hs.clock.time_msec() + self.config.email_validation_token_lifetime
+ )
yield self.datastore.start_or_continue_validation_session(
- "email", email, session_id, client_secret, send_attempt,
- next_link, token, token_expires,
+ "email",
+ email,
+ session_id,
+ client_secret,
+ send_attempt,
+ next_link,
+ token,
+ token_expires,
)
defer.returnValue(session_id)
@@ -196,17 +199,14 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- if not self.config.email_password_reset_behaviour == "off":
- raise SynapseError(400, "Password resets have been disabled on this server")
-
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'id_server', 'client_secret',
- 'country', 'phone_number', 'send_attempt',
- ])
+ assert_params_in_dict(
+ body,
+ ["id_server", "client_secret", "country", "phone_number", "send_attempt"],
+ )
- msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
@@ -215,9 +215,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existingUid = yield self.datastore.get_user_id_by_threepid(
- 'msisdn', msisdn
- )
+ existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn)
if existingUid is None:
raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
@@ -228,9 +226,10 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
class PasswordResetSubmitTokenServlet(RestServlet):
"""Handles 3PID validation token submission"""
- PATTERNS = [
- re.compile("^/_synapse/password_reset/(?P<medium>[^/]*)/submit_token/*$"),
- ]
+
+ PATTERNS = client_patterns(
+ "/password_reset/(?P<medium>[^/]*)/submit_token/*$", releases=(), unstable=True
+ )
def __init__(self, hs):
"""
@@ -248,8 +247,15 @@ class PasswordResetSubmitTokenServlet(RestServlet):
def on_GET(self, request, medium):
if medium != "email":
raise SynapseError(
- 400,
- "This medium is currently not supported for password resets",
+ 400, "This medium is currently not supported for password resets"
+ )
+ if self.config.email_password_reset_behaviour == "off":
+ if self.config.password_resets_were_disabled_due_to_email_config:
+ logger.warn(
+ "User password resets have been disabled due to lack of email config"
+ )
+ raise SynapseError(
+ 400, "Email-based password resets have been disabled on this server"
)
sid = parse_string(request, "sid")
@@ -260,10 +266,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
try:
# Mark the session as valid
next_link = yield self.datastore.validate_threepid_session(
- sid,
- client_secret,
- token,
- self.clock.time_msec(),
+ sid, client_secret, token, self.clock.time_msec()
)
# Perform a 302 redirect if next_link is set
@@ -286,13 +289,11 @@ class PasswordResetSubmitTokenServlet(RestServlet):
html = self.load_jinja2_template(
self.config.email_template_dir,
self.config.email_password_reset_failure_template,
- template_vars={
- "failure_reason": e.msg,
- }
+ template_vars={"failure_reason": e.msg},
)
request.setResponseCode(e.code)
- request.write(html.encode('utf-8'))
+ request.write(html.encode("utf-8"))
finish_request(request)
defer.returnValue(None)
@@ -318,20 +319,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
def on_POST(self, request, medium):
if medium != "email":
raise SynapseError(
- 400,
- "This medium is currently not supported for password resets",
+ 400, "This medium is currently not supported for password resets"
)
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'sid', 'client_secret', 'token',
- ])
+ assert_params_in_dict(body, ["sid", "client_secret", "token"])
valid, _ = yield self.datastore.validate_threepid_validation_token(
- body['sid'],
- body['client_secret'],
- body['token'],
- self.clock.time_msec(),
+ body["sid"], body["client_secret"], body["token"], self.clock.time_msec()
)
response_code = 200 if valid else 400
@@ -367,29 +362,30 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request),
+ requester, body, self.hs.get_ip_from_request(request)
)
user_id = requester.user.to_string()
else:
requester = None
result, params, _ = yield self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
- body, self.hs.get_ip_from_request(request),
+ body,
+ self.hs.get_ip_from_request(request),
password_servlet=True,
)
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
- if 'medium' not in threepid or 'address' not in threepid:
+ if "medium" not in threepid or "address" not in threepid:
raise SynapseError(500, "Malformed threepid")
- if threepid['medium'] == 'email':
+ if threepid["medium"] == "email":
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
- threepid['address'] = threepid['address'].lower()
+ threepid["address"] = threepid["address"].lower()
# if using email, we must know about the email they're authing with!
threepid_user_id = yield self.datastore.get_user_id_by_threepid(
- threepid['medium'], threepid['address']
+ threepid["medium"], threepid["address"]
)
if not threepid_user_id:
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
@@ -399,11 +395,9 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(500, "", Codes.UNKNOWN)
assert_params_in_dict(params, ["new_password"])
- new_password = params['new_password']
+ new_password = params["new_password"]
- yield self._set_password_handler.set_password(
- user_id, new_password, requester
- )
+ yield self._set_password_handler.set_password(user_id, new_password, requester)
defer.returnValue((200, {}))
@@ -438,25 +432,22 @@ class DeactivateAccountRestServlet(RestServlet):
# allow ASes to dectivate their own users
if requester.app_service:
yield self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase,
+ requester.user.to_string(), erase
)
defer.returnValue((200, {}))
yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request),
+ requester, body, self.hs.get_ip_from_request(request)
)
result = yield self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase,
- id_server=body.get("id_server"),
+ requester.user.to_string(), erase, id_server=body.get("id_server")
)
if result:
id_server_unbind_result = "success"
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {
- "id_server_unbind_result": id_server_unbind_result,
- }))
+ defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
class EmailThreepidRequestTokenRestServlet(RestServlet):
@@ -472,11 +463,10 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_dict(
- body,
- ['id_server', 'client_secret', 'email', 'send_attempt'],
+ body, ["id_server", "client_secret", "email", "send_attempt"]
)
- if not check_3pid_allowed(self.hs, "email", body['email']):
+ if not check_3pid_allowed(self.hs, "email", body["email"]):
raise SynapseError(
403,
"Your email domain is not authorized on this server",
@@ -484,7 +474,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
)
existingUid = yield self.datastore.get_user_id_by_threepid(
- 'email', body['email']
+ "email", body["email"]
)
if existingUid is not None:
@@ -506,12 +496,12 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'id_server', 'client_secret',
- 'country', 'phone_number', 'send_attempt',
- ])
+ assert_params_in_dict(
+ body,
+ ["id_server", "client_secret", "country", "phone_number", "send_attempt"],
+ )
- msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
@@ -520,9 +510,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existingUid = yield self.datastore.get_user_id_by_threepid(
- 'msisdn', msisdn
- )
+ existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn)
if existingUid is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
@@ -546,18 +534,16 @@ class ThreepidRestServlet(RestServlet):
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
- threepids = yield self.datastore.user_get_threepids(
- requester.user.to_string()
- )
+ threepids = yield self.datastore.user_get_threepids(requester.user.to_string())
- defer.returnValue((200, {'threepids': threepids}))
+ defer.returnValue((200, {"threepids": threepids}))
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
- threePidCreds = body.get('threePidCreds')
- threePidCreds = body.get('three_pid_creds', threePidCreds)
+ threePidCreds = body.get("threePidCreds")
+ threePidCreds = body.get("three_pid_creds", threePidCreds)
if threePidCreds is None:
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
@@ -567,30 +553,20 @@ class ThreepidRestServlet(RestServlet):
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
if not threepid:
- raise SynapseError(
- 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
- )
+ raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED)
- for reqd in ['medium', 'address', 'validated_at']:
+ for reqd in ["medium", "address", "validated_at"]:
if reqd not in threepid:
logger.warn("Couldn't add 3pid: invalid response from ID server")
raise SynapseError(500, "Invalid response from ID Server")
yield self.auth_handler.add_threepid(
- user_id,
- threepid['medium'],
- threepid['address'],
- threepid['validated_at'],
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
)
- if 'bind' in body and body['bind']:
- logger.debug(
- "Binding threepid %s to %s",
- threepid, user_id
- )
- yield self.identity_handler.bind_threepid(
- threePidCreds, user_id
- )
+ if "bind" in body and body["bind"]:
+ logger.debug("Binding threepid %s to %s", threepid, user_id)
+ yield self.identity_handler.bind_threepid(threePidCreds, user_id)
defer.returnValue((200, {}))
@@ -606,14 +582,14 @@ class ThreepidDeleteRestServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ['medium', 'address'])
+ assert_params_in_dict(body, ["medium", "address"])
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
try:
ret = yield self.auth_handler.delete_threepid(
- user_id, body['medium'], body['address'], body.get("id_server"),
+ user_id, body["medium"], body["address"], body.get("id_server")
)
except Exception:
# NB. This endpoint should succeed if there is nothing to
@@ -627,9 +603,7 @@ class ThreepidDeleteRestServlet(RestServlet):
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {
- "id_server_unbind_result": id_server_unbind_result,
- }))
+ defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
class WhoamiRestServlet(RestServlet):
@@ -643,7 +617,7 @@ class WhoamiRestServlet(RestServlet):
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
- defer.returnValue((200, {'user_id': requester.user.to_string()}))
+ defer.returnValue((200, {"user_id": requester.user.to_string()}))
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 574a6298ce..f155c26259 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -30,6 +30,7 @@ class AccountDataServlet(RestServlet):
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
"""
+
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
@@ -52,9 +53,7 @@ class AccountDataServlet(RestServlet):
user_id, account_data_type, body
)
- self.notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
- )
+ self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
defer.returnValue((200, {}))
@@ -65,7 +64,7 @@ class AccountDataServlet(RestServlet):
raise AuthError(403, "Cannot get account data for other users.")
event = yield self.store.get_global_account_data_by_type_for_user(
- account_data_type, user_id,
+ account_data_type, user_id
)
if event is None:
@@ -79,6 +78,7 @@ class RoomAccountDataServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
"""
+
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)"
@@ -103,16 +103,14 @@ class RoomAccountDataServlet(RestServlet):
raise SynapseError(
405,
"Cannot set m.fully_read through this API."
- " Use /rooms/!roomId:server.name/read_markers"
+ " Use /rooms/!roomId:server.name/read_markers",
)
max_id = yield self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
- self.notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
- )
+ self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
defer.returnValue((200, {}))
@@ -123,7 +121,7 @@ class RoomAccountDataServlet(RestServlet):
raise AuthError(403, "Cannot get account data for other users.")
event = yield self.store.get_account_data_for_room_and_type(
- user_id, room_id, account_data_type,
+ user_id, room_id, account_data_type
)
if event is None:
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index 63bdc33564..d29c10b83d 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -28,7 +28,9 @@ logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/renew$")
- SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
+ SUCCESS_HTML = (
+ b"<html><body>Your account has been successfully renewed.</body><html>"
+ )
def __init__(self, hs):
"""
@@ -47,13 +49,13 @@ class AccountValidityRenewServlet(RestServlet):
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
- yield self.account_activity_handler.renew_account(renewal_token.decode('utf8'))
+ yield self.account_activity_handler.renew_account(renewal_token.decode("utf8"))
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (
- len(AccountValidityRenewServlet.SUCCESS_HTML),
- ))
+ request.setHeader(
+ b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),)
+ )
request.write(AccountValidityRenewServlet.SUCCESS_HTML)
finish_request(request)
defer.returnValue(None)
@@ -77,7 +79,9 @@ class AccountValiditySendMailServlet(RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
if not self.account_validity.renew_by_email_enabled:
- raise AuthError(403, "Account renewal via email is disabled on this server.")
+ raise AuthError(
+ 403, "Account renewal via email is disabled on this server."
+ )
requester = yield self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8dfe5cba02..bebc2951e7 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -122,6 +122,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth.
"""
+
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
@@ -138,11 +139,10 @@ class AuthRestServlet(RestServlet):
if stagetype == LoginType.RECAPTCHA:
html = RECAPTCHA_TEMPLATE % {
- 'session': session,
- 'myurl': "%s/r0/auth/%s/fallback/web" % (
- CLIENT_API_PREFIX, LoginType.RECAPTCHA
- ),
- 'sitekey': self.hs.config.recaptcha_public_key,
+ "session": session,
+ "myurl": "%s/r0/auth/%s/fallback/web"
+ % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
+ "sitekey": self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
@@ -154,14 +154,11 @@ class AuthRestServlet(RestServlet):
return None
elif stagetype == LoginType.TERMS:
html = TERMS_TEMPLATE % {
- 'session': session,
- 'terms_url': "%s_matrix/consent?v=%s" % (
- self.hs.config.public_baseurl,
- self.hs.config.user_consent_version,
- ),
- 'myurl': "%s/r0/auth/%s/fallback/web" % (
- CLIENT_API_PREFIX, LoginType.TERMS
- ),
+ "session": session,
+ "terms_url": "%s_matrix/consent?v=%s"
+ % (self.hs.config.public_baseurl, self.hs.config.user_consent_version),
+ "myurl": "%s/r0/auth/%s/fallback/web"
+ % (CLIENT_API_PREFIX, LoginType.TERMS),
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
@@ -187,26 +184,20 @@ class AuthRestServlet(RestServlet):
if not response:
raise SynapseError(400, "No captcha response supplied")
- authdict = {
- 'response': response,
- 'session': session,
- }
+ authdict = {"response": response, "session": session}
success = yield self.auth_handler.add_oob_auth(
- LoginType.RECAPTCHA,
- authdict,
- self.hs.get_ip_from_request(request)
+ LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
)
if success:
html = SUCCESS_TEMPLATE
else:
html = RECAPTCHA_TEMPLATE % {
- 'session': session,
- 'myurl': "%s/r0/auth/%s/fallback/web" % (
- CLIENT_API_PREFIX, LoginType.RECAPTCHA
- ),
- 'sitekey': self.hs.config.recaptcha_public_key,
+ "session": session,
+ "myurl": "%s/r0/auth/%s/fallback/web"
+ % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
+ "sitekey": self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
@@ -218,31 +209,28 @@ class AuthRestServlet(RestServlet):
defer.returnValue(None)
elif stagetype == LoginType.TERMS:
- if ('session' not in request.args or
- len(request.args['session'])) == 0:
+ if ("session" not in request.args or len(request.args["session"])) == 0:
raise SynapseError(400, "No session supplied")
- session = request.args['session'][0]
- authdict = {'session': session}
+ session = request.args["session"][0]
+ authdict = {"session": session}
success = yield self.auth_handler.add_oob_auth(
- LoginType.TERMS,
- authdict,
- self.hs.get_ip_from_request(request)
+ LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
)
if success:
html = SUCCESS_TEMPLATE
else:
html = TERMS_TEMPLATE % {
- 'session': session,
- 'terms_url': "%s_matrix/consent?v=%s" % (
+ "session": session,
+ "terms_url": "%s_matrix/consent?v=%s"
+ % (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
- 'myurl': "%s/r0/auth/%s/fallback/web" % (
- CLIENT_API_PREFIX, LoginType.TERMS
- ),
+ "myurl": "%s/r0/auth/%s/fallback/web"
+ % (CLIENT_API_PREFIX, LoginType.TERMS),
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 78665304a5..d279229d74 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -56,6 +56,7 @@ class DeleteDevicesRestServlet(RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
+
PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs):
@@ -84,12 +85,11 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request),
+ requester, body, self.hs.get_ip_from_request(request)
)
yield self.device_handler.delete_devices(
- requester.user.to_string(),
- body['devices'],
+ requester.user.to_string(), body["devices"]
)
defer.returnValue((200, {}))
@@ -112,8 +112,7 @@ class DeviceRestServlet(RestServlet):
def on_GET(self, request, device_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
device = yield self.device_handler.get_device(
- requester.user.to_string(),
- device_id,
+ requester.user.to_string(), device_id
)
defer.returnValue((200, device))
@@ -134,12 +133,10 @@ class DeviceRestServlet(RestServlet):
raise
yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request),
+ requester, body, self.hs.get_ip_from_request(request)
)
- yield self.device_handler.delete_device(
- requester.user.to_string(), device_id,
- )
+ yield self.device_handler.delete_device(requester.user.to_string(), device_id)
defer.returnValue((200, {}))
@defer.inlineCallbacks
@@ -148,9 +145,7 @@ class DeviceRestServlet(RestServlet):
body = parse_json_object_from_request(request)
yield self.device_handler.update_device(
- requester.user.to_string(),
- device_id,
- body
+ requester.user.to_string(), device_id, body
)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 65db48c3cc..3f0adf4a21 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -53,8 +53,7 @@ class GetFilterRestServlet(RestServlet):
try:
filter = yield self.filtering.get_user_filter(
- user_localpart=target_user.localpart,
- filter_id=filter_id,
+ user_localpart=target_user.localpart, filter_id=filter_id
)
defer.returnValue((200, filter.get_filter_json()))
@@ -84,14 +83,10 @@ class CreateFilterRestServlet(RestServlet):
raise AuthError(403, "Can only create filters for local users")
content = parse_json_object_from_request(request)
- set_timeline_upper_limit(
- content,
- self.hs.config.filter_timeline_limit
- )
+ set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
filter_id = yield self.filtering.add_user_filter(
- user_localpart=target_user.localpart,
- user_filter=content,
+ user_localpart=target_user.localpart, user_filter=content
)
defer.returnValue((200, {"filter_id": str(filter_id)}))
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index d082385ec7..a312dd2593 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -29,6 +29,7 @@ logger = logging.getLogger(__name__)
class GroupServlet(RestServlet):
"""Get the group profile
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs):
@@ -43,8 +44,7 @@ class GroupServlet(RestServlet):
requester_user_id = requester.user.to_string()
group_description = yield self.groups_handler.get_group_profile(
- group_id,
- requester_user_id,
+ group_id, requester_user_id
)
defer.returnValue((200, group_description))
@@ -56,7 +56,7 @@ class GroupServlet(RestServlet):
content = parse_json_object_from_request(request)
yield self.groups_handler.update_group_profile(
- group_id, requester_user_id, content,
+ group_id, requester_user_id, content
)
defer.returnValue((200, {}))
@@ -65,6 +65,7 @@ class GroupServlet(RestServlet):
class GroupSummaryServlet(RestServlet):
"""Get the full group summary
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs):
@@ -79,8 +80,7 @@ class GroupSummaryServlet(RestServlet):
requester_user_id = requester.user.to_string()
get_group_summary = yield self.groups_handler.get_group_summary(
- group_id,
- requester_user_id,
+ group_id, requester_user_id
)
defer.returnValue((200, get_group_summary))
@@ -93,6 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
- /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?"
@@ -112,7 +113,8 @@ class GroupSummaryRoomsCatServlet(RestServlet):
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_room(
- group_id, requester_user_id,
+ group_id,
+ requester_user_id,
room_id=room_id,
category_id=category_id,
content=content,
@@ -126,9 +128,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_room(
- group_id, requester_user_id,
- room_id=room_id,
- category_id=category_id,
+ group_id, requester_user_id, room_id=room_id, category_id=category_id
)
defer.returnValue((200, resp))
@@ -137,6 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
class GroupCategoryServlet(RestServlet):
"""Get/add/update/delete a group category
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
@@ -153,8 +154,7 @@ class GroupCategoryServlet(RestServlet):
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_category(
- group_id, requester_user_id,
- category_id=category_id,
+ group_id, requester_user_id, category_id=category_id
)
defer.returnValue((200, category))
@@ -166,9 +166,7 @@ class GroupCategoryServlet(RestServlet):
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_category(
- group_id, requester_user_id,
- category_id=category_id,
- content=content,
+ group_id, requester_user_id, category_id=category_id, content=content
)
defer.returnValue((200, resp))
@@ -179,8 +177,7 @@ class GroupCategoryServlet(RestServlet):
requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_category(
- group_id, requester_user_id,
- category_id=category_id,
+ group_id, requester_user_id, category_id=category_id
)
defer.returnValue((200, resp))
@@ -189,9 +186,8 @@ class GroupCategoryServlet(RestServlet):
class GroupCategoriesServlet(RestServlet):
"""Get all group categories
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/categories/$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
def __init__(self, hs):
super(GroupCategoriesServlet, self).__init__()
@@ -205,7 +201,7 @@ class GroupCategoriesServlet(RestServlet):
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_categories(
- group_id, requester_user_id,
+ group_id, requester_user_id
)
defer.returnValue((200, category))
@@ -214,9 +210,8 @@ class GroupCategoriesServlet(RestServlet):
class GroupRoleServlet(RestServlet):
"""Get/add/update/delete a group role
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
def __init__(self, hs):
super(GroupRoleServlet, self).__init__()
@@ -230,8 +225,7 @@ class GroupRoleServlet(RestServlet):
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_role(
- group_id, requester_user_id,
- role_id=role_id,
+ group_id, requester_user_id, role_id=role_id
)
defer.returnValue((200, category))
@@ -243,9 +237,7 @@ class GroupRoleServlet(RestServlet):
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_role(
- group_id, requester_user_id,
- role_id=role_id,
- content=content,
+ group_id, requester_user_id, role_id=role_id, content=content
)
defer.returnValue((200, resp))
@@ -256,8 +248,7 @@ class GroupRoleServlet(RestServlet):
requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_role(
- group_id, requester_user_id,
- role_id=role_id,
+ group_id, requester_user_id, role_id=role_id
)
defer.returnValue((200, resp))
@@ -266,9 +257,8 @@ class GroupRoleServlet(RestServlet):
class GroupRolesServlet(RestServlet):
"""Get all group roles
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/roles/$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
def __init__(self, hs):
super(GroupRolesServlet, self).__init__()
@@ -282,7 +272,7 @@ class GroupRolesServlet(RestServlet):
requester_user_id = requester.user.to_string()
category = yield self.groups_handler.get_group_roles(
- group_id, requester_user_id,
+ group_id, requester_user_id
)
defer.returnValue((200, category))
@@ -295,6 +285,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
- /groups/:group/summary/users/:room_id
- /groups/:group/summary/roles/:role/users/:user_id
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?"
@@ -314,7 +305,8 @@ class GroupSummaryUsersRoleServlet(RestServlet):
content = parse_json_object_from_request(request)
resp = yield self.groups_handler.update_group_summary_user(
- group_id, requester_user_id,
+ group_id,
+ requester_user_id,
user_id=user_id,
role_id=role_id,
content=content,
@@ -328,9 +320,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
requester_user_id = requester.user.to_string()
resp = yield self.groups_handler.delete_group_summary_user(
- group_id, requester_user_id,
- user_id=user_id,
- role_id=role_id,
+ group_id, requester_user_id, user_id=user_id, role_id=role_id
)
defer.returnValue((200, resp))
@@ -339,6 +329,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
class GroupRoomServlet(RestServlet):
"""Get all rooms in a group
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs):
@@ -352,7 +343,9 @@ class GroupRoomServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
+ result = yield self.groups_handler.get_rooms_in_group(
+ group_id, requester_user_id
+ )
defer.returnValue((200, result))
@@ -360,6 +353,7 @@ class GroupRoomServlet(RestServlet):
class GroupUsersServlet(RestServlet):
"""Get all users in a group
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs):
@@ -373,7 +367,9 @@ class GroupUsersServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
+ result = yield self.groups_handler.get_users_in_group(
+ group_id, requester_user_id
+ )
defer.returnValue((200, result))
@@ -381,6 +377,7 @@ class GroupUsersServlet(RestServlet):
class GroupInvitedUsersServlet(RestServlet):
"""Get users invited to a group
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs):
@@ -395,8 +392,7 @@ class GroupInvitedUsersServlet(RestServlet):
requester_user_id = requester.user.to_string()
result = yield self.groups_handler.get_invited_users_in_group(
- group_id,
- requester_user_id,
+ group_id, requester_user_id
)
defer.returnValue((200, result))
@@ -405,6 +401,7 @@ class GroupInvitedUsersServlet(RestServlet):
class GroupSettingJoinPolicyServlet(RestServlet):
"""Set group join policy
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs):
@@ -420,9 +417,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
content = parse_json_object_from_request(request)
result = yield self.groups_handler.set_group_join_policy(
- group_id,
- requester_user_id,
- content,
+ group_id, requester_user_id, content
)
defer.returnValue((200, result))
@@ -431,6 +426,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
class GroupCreateServlet(RestServlet):
"""Create a group
"""
+
PATTERNS = client_patterns("/create_group$")
def __init__(self, hs):
@@ -451,9 +447,7 @@ class GroupCreateServlet(RestServlet):
group_id = GroupID(localpart, self.server_name).to_string()
result = yield self.groups_handler.create_group(
- group_id,
- requester_user_id,
- content,
+ group_id, requester_user_id, content
)
defer.returnValue((200, result))
@@ -462,6 +456,7 @@ class GroupCreateServlet(RestServlet):
class GroupAdminRoomsServlet(RestServlet):
"""Add a room to the group
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
@@ -479,7 +474,7 @@ class GroupAdminRoomsServlet(RestServlet):
content = parse_json_object_from_request(request)
result = yield self.groups_handler.add_room_to_group(
- group_id, requester_user_id, room_id, content,
+ group_id, requester_user_id, room_id, content
)
defer.returnValue((200, result))
@@ -490,7 +485,7 @@ class GroupAdminRoomsServlet(RestServlet):
requester_user_id = requester.user.to_string()
result = yield self.groups_handler.remove_room_from_group(
- group_id, requester_user_id, room_id,
+ group_id, requester_user_id, room_id
)
defer.returnValue((200, result))
@@ -499,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):
class GroupAdminRoomsConfigServlet(RestServlet):
"""Update the config of a room in a group
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$"
@@ -517,7 +513,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
content = parse_json_object_from_request(request)
result = yield self.groups_handler.update_room_in_group(
- group_id, requester_user_id, room_id, config_key, content,
+ group_id, requester_user_id, room_id, config_key, content
)
defer.returnValue((200, result))
@@ -526,6 +522,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
class GroupAdminUsersInviteServlet(RestServlet):
"""Invite a user to the group
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
@@ -546,7 +543,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
content = parse_json_object_from_request(request)
config = content.get("config", {})
result = yield self.groups_handler.invite(
- group_id, user_id, requester_user_id, config,
+ group_id, user_id, requester_user_id, config
)
defer.returnValue((200, result))
@@ -555,6 +552,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
class GroupAdminUsersKickServlet(RestServlet):
"""Kick a user from the group
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
@@ -572,7 +570,7 @@ class GroupAdminUsersKickServlet(RestServlet):
content = parse_json_object_from_request(request)
result = yield self.groups_handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content,
+ group_id, user_id, requester_user_id, content
)
defer.returnValue((200, result))
@@ -581,9 +579,8 @@ class GroupAdminUsersKickServlet(RestServlet):
class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/self/leave$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
def __init__(self, hs):
super(GroupSelfLeaveServlet, self).__init__()
@@ -598,7 +595,7 @@ class GroupSelfLeaveServlet(RestServlet):
content = parse_json_object_from_request(request)
result = yield self.groups_handler.remove_user_from_group(
- group_id, requester_user_id, requester_user_id, content,
+ group_id, requester_user_id, requester_user_id, content
)
defer.returnValue((200, result))
@@ -607,9 +604,8 @@ class GroupSelfLeaveServlet(RestServlet):
class GroupSelfJoinServlet(RestServlet):
"""Attempt to join a group, or knock
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/self/join$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
def __init__(self, hs):
super(GroupSelfJoinServlet, self).__init__()
@@ -624,7 +620,7 @@ class GroupSelfJoinServlet(RestServlet):
content = parse_json_object_from_request(request)
result = yield self.groups_handler.join_group(
- group_id, requester_user_id, content,
+ group_id, requester_user_id, content
)
defer.returnValue((200, result))
@@ -633,9 +629,8 @@ class GroupSelfJoinServlet(RestServlet):
class GroupSelfAcceptInviteServlet(RestServlet):
"""Accept a group invite
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/self/accept_invite$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
def __init__(self, hs):
super(GroupSelfAcceptInviteServlet, self).__init__()
@@ -650,7 +645,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
content = parse_json_object_from_request(request)
result = yield self.groups_handler.accept_invite(
- group_id, requester_user_id, content,
+ group_id, requester_user_id, content
)
defer.returnValue((200, result))
@@ -659,9 +654,8 @@ class GroupSelfAcceptInviteServlet(RestServlet):
class GroupSelfUpdatePublicityServlet(RestServlet):
"""Update whether we publicise a users membership of a group
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/self/update_publicity$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
def __init__(self, hs):
super(GroupSelfUpdatePublicityServlet, self).__init__()
@@ -676,9 +670,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
content = parse_json_object_from_request(request)
publicise = content["publicise"]
- yield self.store.update_group_publicity(
- group_id, requester_user_id, publicise,
- )
+ yield self.store.update_group_publicity(group_id, requester_user_id, publicise)
defer.returnValue((200, {}))
@@ -686,9 +678,8 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
class PublicisedGroupsForUserServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
- PATTERNS = client_patterns(
- "/publicised_groups/(?P<user_id>[^/]*)$"
- )
+
+ PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
def __init__(self, hs):
super(PublicisedGroupsForUserServlet, self).__init__()
@@ -701,9 +692,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
def on_GET(self, request, user_id):
yield self.auth.get_user_by_req(request, allow_guest=True)
- result = yield self.groups_handler.get_publicised_groups_for_user(
- user_id
- )
+ result = yield self.groups_handler.get_publicised_groups_for_user(user_id)
defer.returnValue((200, result))
@@ -711,9 +700,8 @@ class PublicisedGroupsForUserServlet(RestServlet):
class PublicisedGroupsForUsersServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
- PATTERNS = client_patterns(
- "/publicised_groups$"
- )
+
+ PATTERNS = client_patterns("/publicised_groups$")
def __init__(self, hs):
super(PublicisedGroupsForUsersServlet, self).__init__()
@@ -729,9 +717,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
content = parse_json_object_from_request(request)
user_ids = content["user_ids"]
- result = yield self.groups_handler.bulk_get_publicised_groups(
- user_ids
- )
+ result = yield self.groups_handler.bulk_get_publicised_groups(user_ids)
defer.returnValue((200, result))
@@ -739,9 +725,8 @@ class PublicisedGroupsForUsersServlet(RestServlet):
class GroupsForUserServlet(RestServlet):
"""Get all groups the logged in user is joined to
"""
- PATTERNS = client_patterns(
- "/joined_groups$"
- )
+
+ PATTERNS = client_patterns("/joined_groups$")
def __init__(self, hs):
super(GroupsForUserServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 4cbfbf5631..45c9928b65 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -56,6 +56,7 @@ class KeyUploadServlet(RestServlet):
},
}
"""
+
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
@@ -76,18 +77,19 @@ class KeyUploadServlet(RestServlet):
if device_id is not None:
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
- if (requester.device_id is not None and
- device_id != requester.device_id):
- logger.warning("Client uploading keys for a different device "
- "(logged in as %s, uploading for %s)",
- requester.device_id, device_id)
+ if requester.device_id is not None and device_id != requester.device_id:
+ logger.warning(
+ "Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id,
+ device_id,
+ )
else:
device_id = requester.device_id
if device_id is None:
raise SynapseError(
- 400,
- "To upload keys, you must pass device_id when authenticating"
+ 400, "To upload keys, you must pass device_id when authenticating"
)
result = yield self.e2e_keys_handler.upload_keys_for_user(
@@ -159,6 +161,7 @@ class KeyChangesServlet(RestServlet):
200 OK
{ "changed": ["@foo:example.com"] }
"""
+
PATTERNS = client_patterns("/keys/changes$")
def __init__(self, hs):
@@ -184,9 +187,7 @@ class KeyChangesServlet(RestServlet):
user_id = requester.user.to_string()
- results = yield self.device_handler.get_user_ids_changed(
- user_id, from_token,
- )
+ results = yield self.device_handler.get_user_ids_changed(user_id, from_token)
defer.returnValue((200, results))
@@ -209,6 +210,7 @@ class OneTimeKeyServlet(RestServlet):
} } } }
"""
+
PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs):
@@ -221,10 +223,7 @@ class OneTimeKeyServlet(RestServlet):
yield self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.e2e_keys_handler.claim_one_time_keys(
- body,
- timeout,
- )
+ result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout)
defer.returnValue((200, result))
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index 53e666989b..728a52328f 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -51,7 +51,7 @@ class NotificationsServlet(RestServlet):
)
receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
- user_id, 'm.read'
+ user_id, "m.read"
)
notif_event_ids = [pa["event_id"] for pa in push_actions]
@@ -67,11 +67,13 @@ class NotificationsServlet(RestServlet):
"profile_tag": pa["profile_tag"],
"actions": pa["actions"],
"ts": pa["received_ts"],
- "event": (yield self._event_serializer.serialize_event(
- notif_events[pa["event_id"]],
- self.clock.time_msec(),
- event_format=format_event_for_client_v2_without_room_id,
- )),
+ "event": (
+ yield self._event_serializer.serialize_event(
+ notif_events[pa["event_id"]],
+ self.clock.time_msec(),
+ event_format=format_event_for_client_v2_without_room_id,
+ )
+ ),
}
if pa["room_id"] not in receipts_by_room:
@@ -80,17 +82,15 @@ class NotificationsServlet(RestServlet):
receipt = receipts_by_room[pa["room_id"]]
returned_pa["read"] = (
- receipt["topological_ordering"], receipt["stream_ordering"]
- ) >= (
- pa["topological_ordering"], pa["stream_ordering"]
- )
+ receipt["topological_ordering"],
+ receipt["stream_ordering"],
+ ) >= (pa["topological_ordering"], pa["stream_ordering"])
returned_push_actions.append(returned_pa)
next_token = str(pa["stream_ordering"])
- defer.returnValue((200, {
- "notifications": returned_push_actions,
- "next_token": next_token,
- }))
+ defer.returnValue(
+ (200, {"notifications": returned_push_actions, "next_token": next_token})
+ )
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index bb927d9f9d..b1b5385b09 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -56,9 +56,8 @@ class IdTokenServlet(RestServlet):
"expires_in": 3600,
}
"""
- PATTERNS = client_patterns(
- "/user/(?P<user_id>[^/]*)/openid/request_token"
- )
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/openid/request_token")
EXPIRES_MS = 3600 * 1000
@@ -84,12 +83,17 @@ class IdTokenServlet(RestServlet):
yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
- defer.returnValue((200, {
- "access_token": token,
- "token_type": "Bearer",
- "matrix_server_name": self.server_name,
- "expires_in": self.EXPIRES_MS / 1000,
- }))
+ defer.returnValue(
+ (
+ 200,
+ {
+ "access_token": token,
+ "token_type": "Bearer",
+ "matrix_server_name": self.server_name,
+ "expires_in": self.EXPIRES_MS / 1000,
+ },
+ )
+ )
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index f4bd0d077f..e75664279b 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -48,7 +48,7 @@ class ReadMarkerRestServlet(RestServlet):
room_id,
"m.read",
user_id=requester.user.to_string(),
- event_id=read_event_id
+ event_id=read_event_id,
)
read_marker_event_id = body.get("m.fully_read", None)
@@ -56,7 +56,7 @@ class ReadMarkerRestServlet(RestServlet):
yield self.read_marker_handler.received_client_read_marker(
room_id,
user_id=requester.user.to_string(),
- event_id=read_marker_event_id
+ event_id=read_marker_event_id,
)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index fa12ac3e4d..488905626a 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -49,10 +49,7 @@ class ReceiptRestServlet(RestServlet):
yield self.presence_handler.bump_presence_active_time(requester.user)
yield self.receipts_handler.received_client_receipt(
- room_id,
- receipt_type,
- user_id=requester.user.to_string(),
- event_id=event_id
+ room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 79c085408b..5c120e4dd5 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -52,6 +52,7 @@ from ._base import client_patterns, interactive_auth_handler
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
+
def compare_digest(a, b):
return a == b
@@ -75,11 +76,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'id_server', 'client_secret', 'email', 'send_attempt'
- ])
+ assert_params_in_dict(
+ body, ["id_server", "client_secret", "email", "send_attempt"]
+ )
- if not check_3pid_allowed(self.hs, "email", body['email']):
+ if not check_3pid_allowed(self.hs, "email", body["email"]):
raise SynapseError(
403,
"Your email domain is not authorized to register on this server",
@@ -87,7 +88,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
- 'email', body['email']
+ "email", body["email"]
)
if existingUid is not None:
@@ -113,13 +114,12 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'id_server', 'client_secret',
- 'country', 'phone_number',
- 'send_attempt',
- ])
+ assert_params_in_dict(
+ body,
+ ["id_server", "client_secret", "country", "phone_number", "send_attempt"],
+ )
- msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
@@ -129,7 +129,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
- 'msisdn', msisdn
+ "msisdn", msisdn
)
if existingUid is not None:
@@ -165,7 +165,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
reject_limit=1,
# Allow 1 request at a time
concurrent_requests=1,
- )
+ ),
)
@defer.inlineCallbacks
@@ -212,7 +212,8 @@ class RegisterRestServlet(RestServlet):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.can_do_action(
- client_addr, time_now_s=time_now,
+ client_addr,
+ time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
update=False,
@@ -220,7 +221,7 @@ class RegisterRestServlet(RestServlet):
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now)),
+ retry_after_ms=int(1000 * (time_allowed - time_now))
)
kind = b"user"
@@ -239,18 +240,22 @@ class RegisterRestServlet(RestServlet):
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
desired_password = None
- if 'password' in body:
- if (not isinstance(body['password'], string_types) or
- len(body['password']) > 512):
+ if "password" in body:
+ if (
+ not isinstance(body["password"], string_types)
+ or len(body["password"]) > 512
+ ):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
- if 'username' in body:
- if (not isinstance(body['username'], string_types) or
- len(body['username']) > 512):
+ if "username" in body:
+ if (
+ not isinstance(body["username"], string_types)
+ or len(body["username"]) > 512
+ ):
raise SynapseError(400, "Invalid username")
- desired_username = body['username']
+ desired_username = body["username"]
appservice = None
if self.auth.has_access_token(request):
@@ -290,7 +295,7 @@ class RegisterRestServlet(RestServlet):
desired_username = desired_username.lower()
# == Shared Secret Registration == (e.g. create new user scripts)
- if 'mac' in body:
+ if "mac" in body:
# FIXME: Should we really be determining if this is shared secret
# auth based purely on the 'mac' key?
result = yield self._do_shared_secret_registration(
@@ -305,16 +310,13 @@ class RegisterRestServlet(RestServlet):
guest_access_token = body.get("guest_access_token", None)
- if (
- 'initial_device_display_name' in body and
- 'password' not in body
- ):
+ if "initial_device_display_name" in body and "password" not in body:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
# the original registration params
logger.warn("Ignoring initial_device_display_name without password")
- del body['initial_device_display_name']
+ del body["initial_device_display_name"]
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
@@ -336,8 +338,8 @@ class RegisterRestServlet(RestServlet):
# FIXME: need a better error than "no auth flow found" for scenarios
# where we required 3PID for registration but the user didn't give one
- require_email = 'email' in self.hs.config.registrations_require_3pid
- require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+ require_email = "email" in self.hs.config.registrations_require_3pid
+ require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid
show_msisdn = True
if self.hs.config.disable_msisdn_registration:
@@ -362,9 +364,9 @@ class RegisterRestServlet(RestServlet):
if not require_email:
flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
# always let users provide both MSISDN & email
- flows.extend([
- [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
- ])
+ flows.extend(
+ [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]
+ )
else:
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
@@ -378,9 +380,7 @@ class RegisterRestServlet(RestServlet):
if not require_email or require_msisdn:
flows.extend([[LoginType.MSISDN]])
# always let users provide both MSISDN & email
- flows.extend([
- [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
- ])
+ flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]])
# Append m.login.terms to all flows if we're requiring consent
if self.hs.config.user_consent_at_registration:
@@ -410,21 +410,20 @@ class RegisterRestServlet(RestServlet):
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
- medium = auth_result[login_type]['medium']
- address = auth_result[login_type]['address']
+ medium = auth_result[login_type]["medium"]
+ address = auth_result[login_type]["address"]
if not check_3pid_allowed(self.hs, medium, address):
raise SynapseError(
403,
- "Third party identifiers (email/phone numbers)" +
- " are not authorized on this server",
+ "Third party identifiers (email/phone numbers)"
+ + " are not authorized on this server",
Codes.THREEPID_DENIED,
)
if registered_user_id is not None:
logger.info(
- "Already registered user ID %r for this session",
- registered_user_id
+ "Already registered user ID %r for this session", registered_user_id
)
# don't re-register the threepids
registered = False
@@ -451,11 +450,11 @@ class RegisterRestServlet(RestServlet):
# the two activation emails, they would register the same 3pid twice.
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
- medium = auth_result[login_type]['medium']
- address = auth_result[login_type]['address']
+ medium = auth_result[login_type]["medium"]
+ address = auth_result[login_type]["address"]
existingUid = yield self.store.get_user_id_by_threepid(
- medium, address,
+ medium, address
)
if existingUid is not None:
@@ -520,7 +519,7 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Shared secret registration is not enabled")
if not username:
raise SynapseError(
- 400, "username must be specified", errcode=Codes.BAD_JSON,
+ 400, "username must be specified", errcode=Codes.BAD_JSON
)
# use the username from the original request rather than the
@@ -541,12 +540,10 @@ class RegisterRestServlet(RestServlet):
).hexdigest()
if not compare_digest(want_mac, got_mac):
- raise SynapseError(
- 403, "HMAC incorrect",
- )
+ raise SynapseError(403, "HMAC incorrect")
(user_id, _) = yield self.registration_handler.register(
- localpart=username, password=password, generate_token=False,
+ localpart=username, password=password, generate_token=False
)
result = yield self._create_registration_details(user_id, body)
@@ -565,21 +562,15 @@ class RegisterRestServlet(RestServlet):
Returns:
defer.Deferred: (object) dictionary for response from /register
"""
- result = {
- "user_id": user_id,
- "home_server": self.hs.hostname,
- }
+ result = {"user_id": user_id, "home_server": self.hs.hostname}
if not params.get("inhibit_login", False):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest=False,
+ user_id, device_id, initial_display_name, is_guest=False
)
- result.update({
- "access_token": access_token,
- "device_id": device_id,
- })
+ result.update({"access_token": access_token, "device_id": device_id})
defer.returnValue(result)
@defer.inlineCallbacks
@@ -587,9 +578,7 @@ class RegisterRestServlet(RestServlet):
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
user_id, _ = yield self.registration_handler.register(
- generate_token=False,
- make_guest=True,
- address=address,
+ generate_token=False, make_guest=True, address=address
)
# we don't allow guests to specify their own device_id, because
@@ -597,15 +586,20 @@ class RegisterRestServlet(RestServlet):
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest=True,
+ user_id, device_id, initial_display_name, is_guest=True
)
- defer.returnValue((200, {
- "user_id": user_id,
- "device_id": device_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- }))
+ defer.returnValue(
+ (
+ 200,
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "access_token": access_token,
+ "home_server": self.hs.hostname,
+ },
+ )
+ )
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index f8f8742bdc..8e362782cc 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -32,7 +32,10 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
+from synapse.storage.relations import (
+ AggregationPaginationToken,
+ RelationPaginationToken,
+)
from ._base import client_patterns
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 10198662a9..e7578af804 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -33,9 +33,7 @@ logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet):
- PATTERNS = client_patterns(
- "/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
- )
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
def __init__(self, hs):
super(ReportEventRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 87779645f9..8d1b810565 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -129,22 +129,12 @@ class RoomKeysServlet(RestServlet):
version = parse_string(request, "version")
if session_id:
- body = {
- "sessions": {
- session_id: body
- }
- }
+ body = {"sessions": {session_id: body}}
if room_id:
- body = {
- "rooms": {
- room_id: body
- }
- }
+ body = {"rooms": {room_id: body}}
- yield self.e2e_room_keys_handler.upload_room_keys(
- user_id, version, body
- )
+ yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
defer.returnValue((200, {}))
@defer.inlineCallbacks
@@ -212,10 +202,10 @@ class RoomKeysServlet(RestServlet):
if session_id:
# If the client requests a specific session, but that session was
# not backed up, then return an M_NOT_FOUND.
- if room_keys['rooms'] == {}:
+ if room_keys["rooms"] == {}:
raise NotFoundError("No room_keys found")
else:
- room_keys = room_keys['rooms'][room_id]['sessions'][session_id]
+ room_keys = room_keys["rooms"][room_id]["sessions"][session_id]
elif room_id:
# If the client requests all sessions from a room, but no sessions
# are found, then return an empty result rather than an error, so
@@ -223,10 +213,10 @@ class RoomKeysServlet(RestServlet):
# empty result is valid. (Similarly if the client requests all
# sessions from the backup, but in that case, room_keys is already
# in the right format, so we don't need to do anything about it.)
- if room_keys['rooms'] == {}:
- room_keys = {'sessions': {}}
+ if room_keys["rooms"] == {}:
+ room_keys = {"sessions": {}}
else:
- room_keys = room_keys['rooms'][room_id]
+ room_keys = room_keys["rooms"][room_id]
defer.returnValue((200, room_keys))
@@ -256,9 +246,7 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet):
- PATTERNS = client_patterns(
- "/room_keys/version$"
- )
+ PATTERNS = client_patterns("/room_keys/version$")
def __init__(self, hs):
"""
@@ -304,9 +292,7 @@ class RoomKeysNewVersionServlet(RestServlet):
user_id = requester.user.to_string()
info = parse_json_object_from_request(request)
- new_version = yield self.e2e_room_keys_handler.create_version(
- user_id, info
- )
+ new_version = yield self.e2e_room_keys_handler.create_version(user_id, info)
defer.returnValue((200, {"version": new_version}))
# we deliberately don't have a PUT /version, as these things really should
@@ -314,9 +300,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet):
- PATTERNS = client_patterns(
- "/room_keys/version(/(?P<version>[^/]+))?$"
- )
+ PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
def __init__(self, hs):
"""
@@ -350,9 +334,7 @@ class RoomKeysVersionServlet(RestServlet):
user_id = requester.user.to_string()
try:
- info = yield self.e2e_room_keys_handler.get_version_info(
- user_id, version
- )
+ info = yield self.e2e_room_keys_handler.get_version_info(user_id, version)
except SynapseError as e:
if e.code == 404:
raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
@@ -375,9 +357,7 @@ class RoomKeysVersionServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
- yield self.e2e_room_keys_handler.delete_version(
- user_id, version
- )
+ yield self.e2e_room_keys_handler.delete_version(user_id, version)
defer.returnValue((200, {}))
@defer.inlineCallbacks
@@ -407,11 +387,11 @@ class RoomKeysVersionServlet(RestServlet):
info = parse_json_object_from_request(request)
if version is None:
- raise SynapseError(400, "No version specified to update", Codes.MISSING_PARAM)
+ raise SynapseError(
+ 400, "No version specified to update", Codes.MISSING_PARAM
+ )
- yield self.e2e_room_keys_handler.update_version(
- user_id, version, info
- )
+ yield self.e2e_room_keys_handler.update_version(user_id, version, info)
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index c621a90fba..d7f7faa029 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -47,9 +47,10 @@ class RoomUpgradeRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
+
PATTERNS = client_patterns(
# /rooms/$roomid/upgrade
- "/rooms/(?P<room_id>[^/]*)/upgrade$",
+ "/rooms/(?P<room_id>[^/]*)/upgrade$"
)
def __init__(self, hs):
@@ -63,7 +64,7 @@ class RoomUpgradeRestServlet(RestServlet):
requester = yield self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- assert_params_in_dict(content, ("new_version", ))
+ assert_params_in_dict(content, ("new_version",))
new_version = content["new_version"]
if new_version not in KNOWN_ROOM_VERSIONS:
@@ -77,9 +78,7 @@ class RoomUpgradeRestServlet(RestServlet):
requester, room_id, new_version
)
- ret = {
- "replacement_room": new_room_id,
- }
+ ret = {"replacement_room": new_room_id}
defer.returnValue((200, ret))
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 120a713361..78075b8fc0 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_patterns(
- "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
+ "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
)
def __init__(self, hs):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 148fc6c985..02d56dee6c 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -96,44 +96,42 @@ class SyncRestServlet(RestServlet):
400, "'from' is not a valid query parameter. Did you mean 'since'?"
)
- requester = yield self.auth.get_user_by_req(
- request, allow_guest=True
- )
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
user = requester.user
device_id = requester.device_id
timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since")
set_presence = parse_string(
- request, "set_presence", default="online",
- allowed_values=self.ALLOWED_PRESENCE
+ request,
+ "set_presence",
+ default="online",
+ allowed_values=self.ALLOWED_PRESENCE,
)
filter_id = parse_string(request, "filter", default=None)
full_state = parse_boolean(request, "full_state", default=False)
logger.debug(
"/sync: user=%r, timeout=%r, since=%r,"
- " set_presence=%r, filter_id=%r, device_id=%r" % (
- user, timeout, since, set_presence, filter_id, device_id
- )
+ " set_presence=%r, filter_id=%r, device_id=%r"
+ % (user, timeout, since, set_presence, filter_id, device_id)
)
request_key = (user, timeout, since, filter_id, full_state, device_id)
if filter_id:
- if filter_id.startswith('{'):
+ if filter_id.startswith("{"):
try:
filter_object = json.loads(filter_id)
- set_timeline_upper_limit(filter_object,
- self.hs.config.filter_timeline_limit)
+ set_timeline_upper_limit(
+ filter_object, self.hs.config.filter_timeline_limit
+ )
except Exception:
raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object)
filter = FilterCollection(filter_object)
else:
- filter = yield self.filtering.get_user_filter(
- user.localpart, filter_id
- )
+ filter = yield self.filtering.get_user_filter(user.localpart, filter_id)
else:
filter = DEFAULT_FILTER_COLLECTION
@@ -156,15 +154,19 @@ class SyncRestServlet(RestServlet):
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
- yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
+ yield self.presence_handler.set_state(
+ user, {"presence": set_presence}, True
+ )
context = yield self.presence_handler.user_syncing(
- user.to_string(), affect_presence=affect_presence,
+ user.to_string(), affect_presence=affect_presence
)
with context:
sync_result = yield self.sync_handler.wait_for_sync_for_user(
- sync_config, since_token=since_token, timeout=timeout,
- full_state=full_state
+ sync_config,
+ since_token=since_token,
+ timeout=timeout,
+ full_state=full_state,
)
time_now = self.clock.time_msec()
@@ -176,53 +178,54 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks
def encode_response(self, time_now, sync_result, access_token_id, filter):
- if filter.event_format == 'client':
+ if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
- elif filter.event_format == 'federation':
+ elif filter.event_format == "federation":
event_formatter = format_event_raw
else:
- raise Exception("Unknown event format %s" % (filter.event_format, ))
+ raise Exception("Unknown event format %s" % (filter.event_format,))
joined = yield self.encode_joined(
- sync_result.joined, time_now, access_token_id,
+ sync_result.joined,
+ time_now,
+ access_token_id,
filter.event_fields,
event_formatter,
)
invited = yield self.encode_invited(
- sync_result.invited, time_now, access_token_id,
- event_formatter,
+ sync_result.invited, time_now, access_token_id, event_formatter
)
archived = yield self.encode_archived(
- sync_result.archived, time_now, access_token_id,
+ sync_result.archived,
+ time_now,
+ access_token_id,
filter.event_fields,
event_formatter,
)
- defer.returnValue({
- "account_data": {"events": sync_result.account_data},
- "to_device": {"events": sync_result.to_device},
- "device_lists": {
- "changed": list(sync_result.device_lists.changed),
- "left": list(sync_result.device_lists.left),
- },
- "presence": SyncRestServlet.encode_presence(
- sync_result.presence, time_now
- ),
- "rooms": {
- "join": joined,
- "invite": invited,
- "leave": archived,
- },
- "groups": {
- "join": sync_result.groups.join,
- "invite": sync_result.groups.invite,
- "leave": sync_result.groups.leave,
- },
- "device_one_time_keys_count": sync_result.device_one_time_keys_count,
- "next_batch": sync_result.next_batch.to_string(),
- })
+ defer.returnValue(
+ {
+ "account_data": {"events": sync_result.account_data},
+ "to_device": {"events": sync_result.to_device},
+ "device_lists": {
+ "changed": list(sync_result.device_lists.changed),
+ "left": list(sync_result.device_lists.left),
+ },
+ "presence": SyncRestServlet.encode_presence(
+ sync_result.presence, time_now
+ ),
+ "rooms": {"join": joined, "invite": invited, "leave": archived},
+ "groups": {
+ "join": sync_result.groups.join,
+ "invite": sync_result.groups.invite,
+ "leave": sync_result.groups.leave,
+ },
+ "device_one_time_keys_count": sync_result.device_one_time_keys_count,
+ "next_batch": sync_result.next_batch.to_string(),
+ }
+ )
@staticmethod
def encode_presence(events, time_now):
@@ -262,7 +265,11 @@ class SyncRestServlet(RestServlet):
joined = {}
for room in rooms:
joined[room.room_id] = yield self.encode_room(
- room, time_now, token_id, joined=True, only_fields=event_fields,
+ room,
+ time_now,
+ token_id,
+ joined=True,
+ only_fields=event_fields,
event_formatter=event_formatter,
)
@@ -290,7 +297,9 @@ class SyncRestServlet(RestServlet):
invited = {}
for room in rooms:
invite = yield self._event_serializer.serialize_event(
- room.invite, time_now, token_id=token_id,
+ room.invite,
+ time_now,
+ token_id=token_id,
event_format=event_formatter,
is_invite=True,
)
@@ -298,9 +307,7 @@ class SyncRestServlet(RestServlet):
invite["unsigned"] = unsigned
invited_state = list(unsigned.pop("invite_room_state", []))
invited_state.append(invite)
- invited[room.room_id] = {
- "invite_state": {"events": invited_state}
- }
+ invited[room.room_id] = {"invite_state": {"events": invited_state}}
defer.returnValue(invited)
@@ -327,7 +334,10 @@ class SyncRestServlet(RestServlet):
joined = {}
for room in rooms:
joined[room.room_id] = yield self.encode_room(
- room, time_now, token_id, joined=False,
+ room,
+ time_now,
+ token_id,
+ joined=False,
only_fields=event_fields,
event_formatter=event_formatter,
)
@@ -336,8 +346,7 @@ class SyncRestServlet(RestServlet):
@defer.inlineCallbacks
def encode_room(
- self, room, time_now, token_id, joined,
- only_fields, event_formatter,
+ self, room, time_now, token_id, joined, only_fields, event_formatter
):
"""
Args:
@@ -355,9 +364,11 @@ class SyncRestServlet(RestServlet):
Returns:
dict[str, object]: the room, encoded in our response format
"""
+
def serialize(events):
return self._event_serializer.serialize_events(
- events, time_now=time_now,
+ events,
+ time_now=time_now,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_aggregations=False,
@@ -377,7 +388,9 @@ class SyncRestServlet(RestServlet):
if event.room_id != room.room_id:
logger.warn(
"Event %r is under room %r instead of %r",
- event.event_id, room.room_id, event.room_id,
+ event.event_id,
+ room.room_id,
+ event.room_id,
)
serialized_state = yield serialize(state_events)
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index ebff7cff45..07b6ede603 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -29,9 +29,8 @@ class TagListServlet(RestServlet):
"""
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
- PATTERNS = client_patterns(
- "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
- )
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
def __init__(self, hs):
super(TagListServlet, self).__init__()
@@ -54,6 +53,7 @@ class TagServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
"""
+
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
@@ -74,9 +74,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
- self.notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
- )
+ self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
defer.returnValue((200, {}))
@@ -88,9 +86,7 @@ class TagServlet(RestServlet):
max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
- self.notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
- )
+ self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index e7a987466a..1e66662a05 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -57,7 +57,7 @@ class ThirdPartyProtocolServlet(RestServlet):
yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols(
- only_protocol=protocol,
+ only_protocol=protocol
)
if protocol in protocols:
defer.returnValue((200, protocols[protocol]))
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 6c366142e1..2da0f55811 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -26,6 +26,7 @@ class TokenRefreshRestServlet(RestServlet):
Exchanges refresh tokens for a pair of an access token and a new refresh
token.
"""
+
PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs):
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index 69e4efc47a..e19fb6d583 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -60,10 +60,7 @@ class UserDirectorySearchRestServlet(RestServlet):
user_id = requester.user.to_string()
if not self.hs.config.user_directory_search_enabled:
- defer.returnValue((200, {
- "limited": False,
- "results": [],
- }))
+ defer.returnValue((200, {"limited": False, "results": []}))
body = parse_json_object_from_request(request)
@@ -76,7 +73,7 @@ class UserDirectorySearchRestServlet(RestServlet):
raise SynapseError(400, "`search_term` is required field")
results = yield self.user_directory_handler.search_users(
- user_id, search_term, limit,
+ user_id, search_term, limit
)
defer.returnValue((200, results))
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index babbf6a23c..0e09191632 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -25,27 +25,28 @@ class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
def on_GET(self, request):
- return (200, {
- "versions": [
- # XXX: at some point we need to decide whether we need to include
- # the previous version numbers, given we've defined r0.3.0 to be
- # backwards compatible with r0.2.0. But need to check how
- # conscientious we've been in compatibility, and decide whether the
- # middle number is the major revision when at 0.X.Y (as opposed to
- # X.Y.Z). And we need to decide whether it's fair to make clients
- # parse the version string to figure out what's going on.
- "r0.0.1",
- "r0.1.0",
- "r0.2.0",
- "r0.3.0",
- "r0.4.0",
- "r0.5.0",
- ],
- # as per MSC1497:
- "unstable_features": {
- "m.lazy_load_members": True,
- }
- })
+ return (
+ 200,
+ {
+ "versions": [
+ # XXX: at some point we need to decide whether we need to include
+ # the previous version numbers, given we've defined r0.3.0 to be
+ # backwards compatible with r0.2.0. But need to check how
+ # conscientious we've been in compatibility, and decide whether the
+ # middle number is the major revision when at 0.X.Y (as opposed to
+ # X.Y.Z). And we need to decide whether it's fair to make clients
+ # parse the version string to figure out what's going on.
+ "r0.0.1",
+ "r0.1.0",
+ "r0.2.0",
+ "r0.3.0",
+ "r0.4.0",
+ "r0.5.0",
+ ],
+ # as per MSC1497:
+ "unstable_features": {"m.lazy_load_members": True},
+ },
+ )
def register_servlets(http_server):
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 6b371bfa2f..9a32892d8b 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -42,6 +42,7 @@ logger = logging.getLogger(__name__)
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
+
def compare_digest(a, b):
return a == b
@@ -80,6 +81,7 @@ class ConsentResource(Resource):
For POST: required; gives the value to be recorded in the database
against the user.
"""
+
def __init__(self, hs):
"""
Args:
@@ -98,21 +100,20 @@ class ConsentResource(Resource):
if self._default_consent_version is None:
raise ConfigError(
"Consent resource is enabled but user_consent section is "
- "missing in config file.",
+ "missing in config file."
)
consent_template_directory = hs.config.user_consent_template_dir
loader = jinja2.FileSystemLoader(consent_template_directory)
self._jinja_env = jinja2.Environment(
- loader=loader,
- autoescape=jinja2.select_autoescape(['html', 'htm', 'xml']),
+ loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
)
if hs.config.form_secret is None:
raise ConfigError(
"Consent resource is enabled but form_secret is not set in "
- "config file. It should be set to an arbitrary secret string.",
+ "config file. It should be set to an arbitrary secret string."
)
self._hmac_secret = hs.config.form_secret.encode("utf-8")
@@ -139,7 +140,7 @@ class ConsentResource(Resource):
self._check_hash(username, userhmac_bytes)
- if username.startswith('@'):
+ if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
@@ -153,7 +154,8 @@ class ConsentResource(Resource):
try:
self._render_template(
- request, "%s.html" % (version,),
+ request,
+ "%s.html" % (version,),
user=username,
userhmac=userhmac,
version=version,
@@ -180,7 +182,7 @@ class ConsentResource(Resource):
self._check_hash(username, userhmac)
- if username.startswith('@'):
+ if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
@@ -221,11 +223,13 @@ class ConsentResource(Resource):
SynapseError if the hash doesn't match
"""
- want_mac = hmac.new(
- key=self._hmac_secret,
- msg=userid.encode('utf-8'),
- digestmod=sha256,
- ).hexdigest().encode('ascii')
+ want_mac = (
+ hmac.new(
+ key=self._hmac_secret, msg=userid.encode("utf-8"), digestmod=sha256
+ )
+ .hexdigest()
+ .encode("ascii")
+ )
if not compare_digest(want_mac, userhmac):
raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect")
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index ec0ec7b431..c16280f668 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -80,33 +80,27 @@ class LocalKey(Resource):
for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode()
key_id = "%s:%s" % (key.alg, key.version)
- verify_keys[key_id] = {
- u"key": encode_base64(verify_key_bytes)
- }
+ verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)}
old_verify_keys = {}
for key_id, key in self.config.old_signing_keys.items():
verify_key_bytes = key.encode()
old_verify_keys[key_id] = {
- u"key": encode_base64(verify_key_bytes),
- u"expired_ts": key.expired_ts,
+ "key": encode_base64(verify_key_bytes),
+ "expired_ts": key.expired_ts,
}
tls_fingerprints = self.config.tls_fingerprints
json_object = {
- u"valid_until_ts": self.valid_until_ts,
- u"server_name": self.config.server_name,
- u"verify_keys": verify_keys,
- u"old_verify_keys": old_verify_keys,
- u"tls_fingerprints": tls_fingerprints,
+ "valid_until_ts": self.valid_until_ts,
+ "server_name": self.config.server_name,
+ "verify_keys": verify_keys,
+ "old_verify_keys": old_verify_keys,
+ "tls_fingerprints": tls_fingerprints,
}
for key in self.config.signing_key:
- json_object = sign_json(
- json_object,
- self.config.server_name,
- key,
- )
+ json_object = sign_json(json_object, self.config.server_name, key)
return json_object
def render_GET(self, request):
@@ -114,6 +108,4 @@ class LocalKey(Resource):
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
self.update_response_body(time_now)
- return respond_with_json_bytes(
- request, 200, self.response_body,
- )
+ return respond_with_json_bytes(request, 200, self.response_body)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 8a730bbc35..ec8b9d7269 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -103,20 +103,16 @@ class RemoteKey(Resource):
def async_render_GET(self, request):
if len(request.postpath) == 1:
server, = request.postpath
- query = {server.decode('ascii'): {}}
+ query = {server.decode("ascii"): {}}
elif len(request.postpath) == 2:
server, key_id = request.postpath
- minimum_valid_until_ts = parse_integer(
- request, "minimum_valid_until_ts"
- )
+ minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
- query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}}
+ query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}}
else:
- raise SynapseError(
- 404, "Not found %r" % request.postpath, Codes.NOT_FOUND
- )
+ raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
@@ -140,8 +136,8 @@ class RemoteKey(Resource):
store_queries = []
for server_name, key_ids in query.items():
if (
- self.federation_domain_whitelist is not None and
- server_name not in self.federation_domain_whitelist
+ self.federation_domain_whitelist is not None
+ and server_name not in self.federation_domain_whitelist
):
logger.debug("Federation denied with %s", server_name)
continue
@@ -159,9 +155,7 @@ class RemoteKey(Resource):
cache_misses = dict()
for (server_name, key_id, from_server), results in cached.items():
- results = [
- (result["ts_added_ms"], result) for result in results
- ]
+ results = [(result["ts_added_ms"], result) for result in results]
if not results and key_id is not None:
cache_misses.setdefault(server_name, set()).add(key_id)
@@ -178,23 +172,30 @@ class RemoteKey(Resource):
logger.debug(
"Cached response for %r/%r is older than requested"
": valid_until (%r) < minimum_valid_until (%r)",
- server_name, key_id,
- ts_valid_until_ms, req_valid_until
+ server_name,
+ key_id,
+ ts_valid_until_ms,
+ req_valid_until,
)
miss = True
else:
logger.debug(
"Cached response for %r/%r is newer than requested"
": valid_until (%r) >= minimum_valid_until (%r)",
- server_name, key_id,
- ts_valid_until_ms, req_valid_until
+ server_name,
+ key_id,
+ ts_valid_until_ms,
+ req_valid_until,
)
elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
logger.debug(
"Cached response for %r/%r is too old"
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
- server_name, key_id,
- ts_added_ms, ts_valid_until_ms, time_now_ms
+ server_name,
+ key_id,
+ ts_added_ms,
+ ts_valid_until_ms,
+ time_now_ms,
)
# We more than half way through the lifetime of the
# response. We should fetch a fresh copy.
@@ -203,8 +204,11 @@ class RemoteKey(Resource):
logger.debug(
"Cached response for %r/%r is still valid"
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
- server_name, key_id,
- ts_added_ms, ts_valid_until_ms, time_now_ms
+ server_name,
+ key_id,
+ ts_added_ms,
+ ts_valid_until_ms,
+ time_now_ms,
)
if miss:
@@ -216,12 +220,10 @@ class RemoteKey(Resource):
if cache_misses and query_remote_on_cache_miss:
yield self.fetcher.get_keys(cache_misses)
- yield self.query_keys(
- request, query, query_remote_on_cache_miss=False
- )
+ yield self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
result_io = BytesIO()
- result_io.write(b"{\"server_keys\":")
+ result_io.write(b'{"server_keys":')
sep = b"["
for json_bytes in json_results:
result_io.write(sep)
@@ -231,6 +233,4 @@ class RemoteKey(Resource):
result_io.write(sep)
result_io.write(b"]}")
- respond_with_json_bytes(
- request, 200, result_io.getvalue(),
- )
+ respond_with_json_bytes(request, 200, result_io.getvalue())
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
index 5a426ff2f6..86884c0ef4 100644
--- a/synapse/rest/media/v0/content_repository.py
+++ b/synapse/rest/media/v0/content_repository.py
@@ -44,6 +44,7 @@ class ContentRepoResource(resource.Resource):
- Content type base64d (so we can return it when clients GET it)
"""
+
isLeaf = True
def __init__(self, hs, directory):
@@ -56,7 +57,7 @@ class ContentRepoResource(resource.Resource):
# servers.
# TODO: A little crude here, we could do this better.
- filename = request.path.decode('ascii').split('/')[-1]
+ filename = request.path.decode("ascii").split("/")[-1]
# be paranoid
filename = re.sub("[^0-9A-z.-_]", "", filename)
@@ -69,17 +70,15 @@ class ContentRepoResource(resource.Resource):
base64_contentype = filename.split(".")[1]
content_type = base64.urlsafe_b64decode(base64_contentype)
logger.info("Sending file %s", file_path)
- f = open(file_path, 'rb')
- request.setHeader('Content-Type', content_type)
+ f = open(file_path, "rb")
+ request.setHeader("Content-Type", content_type)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our matrix
# clients are smart enough to be happy with Cache-Control (right?)
- request.setHeader(
- b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
- )
+ request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
d = FileSender().beginFileTransfer(f, request)
@@ -87,13 +86,15 @@ class ContentRepoResource(resource.Resource):
def cbFinished(ignored):
f.close()
finish_request(request)
+
d.addCallback(cbFinished)
else:
respond_with_json_bytes(
request,
404,
json.dumps(cs_error("Not found", code=Codes.NOT_FOUND)),
- send_cors=True)
+ send_cors=True,
+ )
return server.NOT_DONE_YET
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 2dcc8f74d6..3318638d3e 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -38,8 +38,8 @@ def parse_media_id(request):
server_name, media_id = request.postpath[:2]
if isinstance(server_name, bytes):
- server_name = server_name.decode('utf-8')
- media_id = media_id.decode('utf8')
+ server_name = server_name.decode("utf-8")
+ media_id = media_id.decode("utf8")
file_name = None
if len(request.postpath) > 2:
@@ -120,11 +120,11 @@ def add_file_headers(request, media_type, file_size, upload_name):
# correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we
# may as well just do the filename* version.
if _can_encode_filename_as_token(upload_name):
- disposition = 'inline; filename=%s' % (upload_name, )
+ disposition = "inline; filename=%s" % (upload_name,)
else:
- disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name), )
+ disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),)
- request.setHeader(b"Content-Disposition", disposition.encode('ascii'))
+ request.setHeader(b"Content-Disposition", disposition.encode("ascii"))
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
@@ -137,10 +137,27 @@ def add_file_headers(request, media_type, file_size, upload_name):
# separators as defined in RFC2616. SP and HT are handled separately.
# see _can_encode_filename_as_token.
-_FILENAME_SEPARATOR_CHARS = set((
- "(", ")", "<", ">", "@", ",", ";", ":", "\\", '"',
- "/", "[", "]", "?", "=", "{", "}",
-))
+_FILENAME_SEPARATOR_CHARS = set(
+ (
+ "(",
+ ")",
+ "<",
+ ">",
+ "@",
+ ",",
+ ";",
+ ":",
+ "\\",
+ '"',
+ "/",
+ "[",
+ "]",
+ "?",
+ "=",
+ "{",
+ "}",
+ )
+)
def _can_encode_filename_as_token(x):
@@ -271,7 +288,7 @@ def get_filename_from_headers(headers):
Returns:
A Unicode string of the filename, or None.
"""
- content_disposition = headers.get(b"Content-Disposition", [b''])
+ content_disposition = headers.get(b"Content-Disposition", [b""])
# No header, bail out.
if not content_disposition[0]:
@@ -293,7 +310,7 @@ def get_filename_from_headers(headers):
# Once it is decoded, we can then unquote the %-encoded
# parts strictly into a unicode string.
upload_name = urllib.parse.unquote(
- upload_name_utf8.decode('ascii'), errors="strict"
+ upload_name_utf8.decode("ascii"), errors="strict"
)
except UnicodeDecodeError:
# Incorrect UTF-8.
@@ -302,7 +319,7 @@ def get_filename_from_headers(headers):
# On Python 2, we first unquote the %-encoded parts and then
# decode it strictly using UTF-8.
try:
- upload_name = urllib.parse.unquote(upload_name_utf8).decode('utf8')
+ upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8")
except UnicodeDecodeError:
pass
@@ -310,7 +327,7 @@ def get_filename_from_headers(headers):
if not upload_name:
upload_name_ascii = params.get(b"filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
- upload_name = upload_name_ascii.decode('ascii')
+ upload_name = upload_name_ascii.decode("ascii")
# This may be None here, indicating we did not find a matching name.
return upload_name
@@ -328,19 +345,19 @@ def _parse_header(line):
Tuple[bytes, dict[bytes, bytes]]:
the main content-type, followed by the parameter dictionary
"""
- parts = _parseparam(b';' + line)
+ parts = _parseparam(b";" + line)
key = next(parts)
pdict = {}
for p in parts:
- i = p.find(b'=')
+ i = p.find(b"=")
if i >= 0:
name = p[:i].strip().lower()
- value = p[i + 1:].strip()
+ value = p[i + 1 :].strip()
# strip double-quotes
if len(value) >= 2 and value[0:1] == value[-1:] == b'"':
value = value[1:-1]
- value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"')
+ value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"')
pdict[name] = value
return key, pdict
@@ -357,16 +374,16 @@ def _parseparam(s):
Returns:
Iterable[bytes]: the split input
"""
- while s[:1] == b';':
+ while s[:1] == b";":
s = s[1:]
# look for the next ;
- end = s.find(b';')
+ end = s.find(b";")
# if there is an odd number of " marks between here and the next ;, skip to the
# next ; instead
while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2:
- end = s.find(b';', end + 1)
+ end = s.find(b";", end + 1)
if end < 0:
end = len(s)
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 77316033f7..fa3d6680fc 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -29,9 +29,7 @@ class MediaConfigResource(Resource):
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
- self.limits_dict = {
- "m.upload.size": config.max_upload_size,
- }
+ self.limits_dict = {"m.upload.size": config.max_upload_size}
def render_GET(self, request):
self._async_render_GET(request)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index bdc5daecc1..a21a35f843 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -54,18 +54,20 @@ class DownloadResource(Resource):
b" plugin-types application/pdf;"
b" style-src 'unsafe-inline';"
b" media-src 'self';"
- b" object-src 'self';"
+ b" object-src 'self';",
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
yield self.media_repo.get_local_media(request, media_id, name)
else:
allow_remote = synapse.http.servlet.parse_boolean(
- request, "allow_remote", default=True)
+ request, "allow_remote", default=True
+ )
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
- server_name, media_id,
+ server_name,
+ media_id,
)
respond_404(request)
return
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index c8586fa280..e25c382c9c 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -24,6 +24,7 @@ def _wrap_in_base_path(func):
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
+
@functools.wraps(func)
def _wrapped(self, *args, **kwargs):
path = func(self, *args, **kwargs)
@@ -43,125 +44,102 @@ class MediaFilePaths(object):
def __init__(self, primary_base_path):
self.base_path = primary_base_path
- def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
- height, content_type, method):
+ def default_thumbnail_rel(
+ self, default_top_level, default_sub_type, width, height, content_type, method
+ ):
top_level_type, sub_type = content_type.split("/")
- file_name = "%i-%i-%s-%s-%s" % (
- width, height, top_level_type, sub_type, method
- )
+ file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
- "default_thumbnails", default_top_level,
- default_sub_type, file_name
+ "default_thumbnails", default_top_level, default_sub_type, file_name
)
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
def local_media_filepath_rel(self, media_id):
- return os.path.join(
- "local_content",
- media_id[0:2], media_id[2:4], media_id[4:]
- )
+ return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
- def local_media_thumbnail_rel(self, media_id, width, height, content_type,
- method):
+ def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
top_level_type, sub_type = content_type.split("/")
- file_name = "%i-%i-%s-%s-%s" % (
- width, height, top_level_type, sub_type, method
- )
+ file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
- "local_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
- file_name
+ "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name
)
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
def remote_media_filepath_rel(self, server_name, file_id):
return os.path.join(
- "remote_content", server_name,
- file_id[0:2], file_id[2:4], file_id[4:]
+ "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
)
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
- def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
- content_type, method):
+ def remote_media_thumbnail_rel(
+ self, server_name, file_id, width, height, content_type, method
+ ):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
- "remote_thumbnail", server_name,
- file_id[0:2], file_id[2:4], file_id[4:],
- file_name
+ "remote_thumbnail",
+ server_name,
+ file_id[0:2],
+ file_id[2:4],
+ file_id[4:],
+ file_name,
)
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
- self.base_path, "remote_thumbnail", server_name,
- file_id[0:2], file_id[2:4], file_id[4:],
+ self.base_path,
+ "remote_thumbnail",
+ server_name,
+ file_id[0:2],
+ file_id[2:4],
+ file_id[4:],
)
def url_cache_filepath_rel(self, media_id):
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
- return os.path.join(
- "url_cache",
- media_id[:10], media_id[11:]
- )
+ return os.path.join("url_cache", media_id[:10], media_id[11:])
else:
- return os.path.join(
- "url_cache",
- media_id[0:2], media_id[2:4], media_id[4:],
- )
+ return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:])
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
def url_cache_filepath_dirs_to_delete(self, media_id):
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
- return [
- os.path.join(
- self.base_path, "url_cache",
- media_id[:10],
- ),
- ]
+ return [os.path.join(self.base_path, "url_cache", media_id[:10])]
else:
return [
- os.path.join(
- self.base_path, "url_cache",
- media_id[0:2], media_id[2:4],
- ),
- os.path.join(
- self.base_path, "url_cache",
- media_id[0:2],
- ),
+ os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]),
+ os.path.join(self.base_path, "url_cache", media_id[0:2]),
]
- def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
- method):
+ def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
top_level_type, sub_type = content_type.split("/")
- file_name = "%i-%i-%s-%s-%s" % (
- width, height, top_level_type, sub_type, method
- )
+ file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
- "url_cache_thumbnails",
- media_id[:10], media_id[11:],
- file_name
+ "url_cache_thumbnails", media_id[:10], media_id[11:], file_name
)
else:
return os.path.join(
"url_cache_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
- file_name
+ media_id[0:2],
+ media_id[2:4],
+ media_id[4:],
+ file_name,
)
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
@@ -172,13 +150,15 @@ class MediaFilePaths(object):
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[:10], media_id[11:],
+ self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
)
else:
return os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
+ self.base_path,
+ "url_cache_thumbnails",
+ media_id[0:2],
+ media_id[2:4],
+ media_id[4:],
)
def url_cache_thumbnail_dirs_to_delete(self, media_id):
@@ -188,26 +168,21 @@ class MediaFilePaths(object):
if NEW_FORMAT_ID_RE.match(media_id):
return [
os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[:10], media_id[11:],
- ),
- os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[:10],
+ self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
),
+ os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]),
]
else:
return [
os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
- ),
- os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2], media_id[2:4],
+ self.base_path,
+ "url_cache_thumbnails",
+ media_id[0:2],
+ media_id[2:4],
+ media_id[4:],
),
os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2],
+ self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4]
),
+ os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]),
]
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 8569677355..df3d985a38 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -100,17 +100,16 @@ class MediaRepository(object):
storage_providers.append(provider)
self.media_storage = MediaStorage(
- self.hs, self.primary_base_path, self.filepaths, storage_providers,
+ self.hs, self.primary_base_path, self.filepaths, storage_providers
)
self.clock.looping_call(
- self._start_update_recently_accessed,
- UPDATE_RECENTLY_ACCESSED_TS,
+ self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
)
def _start_update_recently_accessed(self):
return run_as_background_process(
- "update_recently_accessed_media", self._update_recently_accessed,
+ "update_recently_accessed_media", self._update_recently_accessed
)
@defer.inlineCallbacks
@@ -138,8 +137,9 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
- def create_content(self, media_type, upload_name, content, content_length,
- auth_user):
+ def create_content(
+ self, media_type, upload_name, content, content_length, auth_user
+ ):
"""Store uploaded content for a local user and return the mxc URL
Args:
@@ -154,10 +154,7 @@ class MediaRepository(object):
"""
media_id = random_string(24)
- file_info = FileInfo(
- server_name=None,
- file_id=media_id,
- )
+ file_info = FileInfo(server_name=None, file_id=media_id)
fname = yield self.media_storage.store_file(content, file_info)
@@ -172,9 +169,7 @@ class MediaRepository(object):
user_id=auth_user,
)
- yield self._generate_thumbnails(
- None, media_id, media_id, media_type,
- )
+ yield self._generate_thumbnails(None, media_id, media_id, media_type)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@@ -205,14 +200,11 @@ class MediaRepository(object):
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
- file_info = FileInfo(
- None, media_id,
- url_cache=url_cache,
- )
+ file_info = FileInfo(None, media_id, url_cache=url_cache)
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
- request, responder, media_type, media_length, upload_name,
+ request, responder, media_type, media_length, upload_name
)
@defer.inlineCallbacks
@@ -232,8 +224,8 @@ class MediaRepository(object):
to request
"""
if (
- self.federation_domain_whitelist is not None and
- server_name not in self.federation_domain_whitelist
+ self.federation_domain_whitelist is not None
+ and server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(server_name)
@@ -244,7 +236,7 @@ class MediaRepository(object):
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
- server_name, media_id,
+ server_name, media_id
)
# We deliberately stream the file outside the lock
@@ -253,7 +245,7 @@ class MediaRepository(object):
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
yield respond_with_responder(
- request, responder, media_type, media_length, upload_name,
+ request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)
@@ -272,8 +264,8 @@ class MediaRepository(object):
Deferred[dict]: The media_info of the file
"""
if (
- self.federation_domain_whitelist is not None and
- server_name not in self.federation_domain_whitelist
+ self.federation_domain_whitelist is not None
+ and server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(server_name)
@@ -282,7 +274,7 @@ class MediaRepository(object):
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
- server_name, media_id,
+ server_name, media_id
)
# Ensure we actually use the responder so that it releases resources
@@ -305,9 +297,7 @@ class MediaRepository(object):
Returns:
Deferred[(Responder, media_info)]
"""
- media_info = yield self.store.get_cached_remote_media(
- server_name, media_id
- )
+ media_info = yield self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise genereate a new
@@ -331,9 +321,7 @@ class MediaRepository(object):
# Failed to find the file anywhere, lets download it.
- media_info = yield self._download_remote_file(
- server_name, media_id, file_id
- )
+ media_info = yield self._download_remote_file(server_name, media_id, file_id)
responder = yield self.media_storage.fetch_media(file_info)
defer.returnValue((responder, media_info))
@@ -354,52 +342,60 @@ class MediaRepository(object):
Deferred[MediaInfo]
"""
- file_info = FileInfo(
- server_name=server_name,
- file_id=file_id,
- )
+ file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
- request_path = "/".join((
- "/_matrix/media/v1/download", server_name, media_id,
- ))
+ request_path = "/".join(
+ ("/_matrix/media/v1/download", server_name, media_id)
+ )
try:
length, headers = yield self.client.get_file(
- server_name, request_path, output_stream=f,
- max_size=self.max_upload_size, args={
+ server_name,
+ request_path,
+ output_stream=f,
+ max_size=self.max_upload_size,
+ args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
- "allow_remote": "false",
- }
+ "allow_remote": "false"
+ },
)
except RequestSendFailed as e:
- logger.warn("Request failed fetching remote media %s/%s: %r",
- server_name, media_id, e)
+ logger.warn(
+ "Request failed fetching remote media %s/%s: %r",
+ server_name,
+ media_id,
+ e,
+ )
raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e:
- logger.warn("HTTP error fetching remote media %s/%s: %s",
- server_name, media_id, e.response)
+ logger.warn(
+ "HTTP error fetching remote media %s/%s: %s",
+ server_name,
+ media_id,
+ e.response,
+ )
if e.code == twisted.web.http.NOT_FOUND:
raise e.to_synapse_error()
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
+ logger.warn("Failed to fetch remote media %s/%s", server_name, media_id)
raise
except NotRetryingDestination:
logger.warn("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
+ logger.exception(
+ "Failed to fetch remote media %s/%s", server_name, media_id
+ )
raise SynapseError(502, "Failed to fetch remote media")
yield finish()
- media_type = headers[b"Content-Type"][0].decode('ascii')
+ media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
@@ -423,24 +419,23 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
- yield self._generate_thumbnails(
- server_name, media_id, file_id, media_type,
- )
+ yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
defer.returnValue(media_info)
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
- def _generate_thumbnail(self, thumbnailer, t_width, t_height,
- t_method, t_type):
+ def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
- m_width, m_height, self.max_image_pixels
+ m_width,
+ m_height,
+ self.max_image_pixels,
)
return
@@ -460,17 +455,22 @@ class MediaRepository(object):
return t_byte_source
@defer.inlineCallbacks
- def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
- t_method, t_type, url_cache):
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
- None, media_id, url_cache=url_cache,
- ))
+ def generate_local_exact_thumbnail(
+ self, media_id, t_width, t_height, t_method, t_type, url_cache
+ ):
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ FileInfo(None, media_id, url_cache=url_cache)
+ )
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield logcontext.defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
- thumbnailer, t_width, t_height, t_method, t_type
+ thumbnailer,
+ t_width,
+ t_height,
+ t_method,
+ t_type,
)
if t_byte_source:
@@ -487,7 +487,7 @@ class MediaRepository(object):
)
output_path = yield self.media_storage.store_file(
- t_byte_source, file_info,
+ t_byte_source, file_info
)
finally:
t_byte_source.close()
@@ -503,17 +503,22 @@ class MediaRepository(object):
defer.returnValue(output_path)
@defer.inlineCallbacks
- def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
- t_width, t_height, t_method, t_type):
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
- server_name, file_id, url_cache=False,
- ))
+ def generate_remote_exact_thumbnail(
+ self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
+ ):
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ FileInfo(server_name, file_id, url_cache=False)
+ )
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield logcontext.defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
- thumbnailer, t_width, t_height, t_method, t_type
+ thumbnailer,
+ t_width,
+ t_height,
+ t_method,
+ t_type,
)
if t_byte_source:
@@ -529,7 +534,7 @@ class MediaRepository(object):
)
output_path = yield self.media_storage.store_file(
- t_byte_source, file_info,
+ t_byte_source, file_info
)
finally:
t_byte_source.close()
@@ -539,15 +544,22 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
yield self.store.store_remote_media_thumbnail(
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
+ server_name,
+ media_id,
+ file_id,
+ t_width,
+ t_height,
+ t_type,
+ t_method,
+ t_len,
)
defer.returnValue(output_path)
@defer.inlineCallbacks
- def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
- url_cache=False):
+ def _generate_thumbnails(
+ self, server_name, media_id, file_id, media_type, url_cache=False
+ ):
"""Generate and store thumbnails for an image.
Args:
@@ -566,9 +578,9 @@ class MediaRepository(object):
if not requirements:
return
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
- server_name, file_id, url_cache=url_cache,
- ))
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ FileInfo(server_name, file_id, url_cache=url_cache)
+ )
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
@@ -577,14 +589,15 @@ class MediaRepository(object):
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
- m_width, m_height, self.max_image_pixels
+ m_width,
+ m_height,
+ self.max_image_pixels,
)
return
if thumbnailer.transpose_method is not None:
m_width, m_height = yield logcontext.defer_to_thread(
- self.hs.get_reactor(),
- thumbnailer.transpose
+ self.hs.get_reactor(), thumbnailer.transpose
)
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
@@ -604,15 +617,11 @@ class MediaRepository(object):
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield logcontext.defer_to_thread(
- self.hs.get_reactor(),
- thumbnailer.crop,
- t_width, t_height, t_type,
+ self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
t_byte_source = yield logcontext.defer_to_thread(
- self.hs.get_reactor(),
- thumbnailer.scale,
- t_width, t_height, t_type,
+ self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
logger.error("Unrecognized method: %r", t_method)
@@ -634,7 +643,7 @@ class MediaRepository(object):
)
output_path = yield self.media_storage.store_file(
- t_byte_source, file_info,
+ t_byte_source, file_info
)
finally:
t_byte_source.close()
@@ -644,18 +653,21 @@ class MediaRepository(object):
# Write to database
if server_name:
yield self.store.store_remote_media_thumbnail(
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
+ server_name,
+ media_id,
+ file_id,
+ t_width,
+ t_height,
+ t_type,
+ t_method,
+ t_len,
)
else:
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue({
- "width": m_width,
- "height": m_height,
- })
+ defer.returnValue({"width": m_width, "height": m_height})
@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
@@ -747,11 +759,12 @@ class MediaRepositoryResource(Resource):
self.putChild(b"upload", UploadResource(hs, media_repo))
self.putChild(b"download", DownloadResource(hs, media_repo))
- self.putChild(b"thumbnail", ThumbnailResource(
- hs, media_repo, media_repo.media_storage,
- ))
+ self.putChild(
+ b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage)
+ )
if hs.config.url_preview_enabled:
- self.putChild(b"preview_url", PreviewUrlResource(
- hs, media_repo, media_repo.media_storage,
- ))
+ self.putChild(
+ b"preview_url",
+ PreviewUrlResource(hs, media_repo, media_repo.media_storage),
+ )
self.putChild(b"config", MediaConfigResource(hs))
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 896078fe76..eff86836fb 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -66,8 +66,7 @@ class MediaStorage(object):
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
yield logcontext.defer_to_thread(
- self.hs.get_reactor(),
- _write_file_synchronously, source, f,
+ self.hs.get_reactor(), _write_file_synchronously, source, f
)
yield finish_cb()
@@ -179,7 +178,8 @@ class MediaStorage(object):
if res:
with res:
consumer = BackgroundFileConsumer(
- open(local_path, "wb"), self.hs.get_reactor())
+ open(local_path, "wb"), self.hs.get_reactor()
+ )
yield res.write_to_consumer(consumer)
yield consumer.wait()
defer.returnValue(local_path)
@@ -217,10 +217,10 @@ class MediaStorage(object):
width=file_info.thumbnail_width,
height=file_info.thumbnail_height,
content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method
+ method=file_info.thumbnail_method,
)
return self.filepaths.remote_media_filepath_rel(
- file_info.server_name, file_info.file_id,
+ file_info.server_name, file_info.file_id
)
if file_info.thumbnail:
@@ -229,11 +229,9 @@ class MediaStorage(object):
width=file_info.thumbnail_width,
height=file_info.thumbnail_height,
content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method
+ method=file_info.thumbnail_method,
)
- return self.filepaths.local_media_filepath_rel(
- file_info.file_id,
- )
+ return self.filepaths.local_media_filepath_rel(file_info.file_id)
def _write_file_synchronously(source, dest):
@@ -255,6 +253,7 @@ class FileResponder(Responder):
open_file (file): A file like object to be streamed ot the client,
is closed when finished streaming.
"""
+
def __init__(self, open_file):
self.open_file = open_file
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index acf87709f2..de6f292ffb 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -92,7 +92,7 @@ class PreviewUrlResource(Resource):
)
self._cleaner_loop = self.clock.looping_call(
- self._start_expire_url_cache_data, 10 * 1000,
+ self._start_expire_url_cache_data, 10 * 1000
)
def render_OPTIONS(self, request):
@@ -121,16 +121,16 @@ class PreviewUrlResource(Resource):
for attrib in entry:
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
- logger.debug((
- "Matching attrib '%s' with value '%s' against"
- " pattern '%s'"
- ) % (attrib, value, pattern))
+ logger.debug(
+ ("Matching attrib '%s' with value '%s' against" " pattern '%s'")
+ % (attrib, value, pattern)
+ )
if value is None:
match = False
continue
- if pattern.startswith('^'):
+ if pattern.startswith("^"):
if not re.match(pattern, getattr(url_tuple, attrib)):
match = False
continue
@@ -139,12 +139,9 @@ class PreviewUrlResource(Resource):
match = False
continue
if match:
- logger.warn(
- "URL %s blocked by url_blacklist entry %s", url, entry
- )
+ logger.warn("URL %s blocked by url_blacklist entry %s", url, entry)
raise SynapseError(
- 403, "URL blocked by url pattern blacklist entry",
- Codes.UNKNOWN
+ 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN
)
# the in-memory cache:
@@ -156,14 +153,8 @@ class PreviewUrlResource(Resource):
observable = self._cache.get(url)
if not observable:
- download = run_in_background(
- self._do_preview,
- url, requester.user, ts,
- )
- observable = ObservableDeferred(
- download,
- consumeErrors=True
- )
+ download = run_in_background(self._do_preview, url, requester.user, ts)
+ observable = ObservableDeferred(download, consumeErrors=True)
self._cache[url] = observable
else:
logger.info("Returning cached response")
@@ -187,15 +178,15 @@ class PreviewUrlResource(Resource):
# historical previews, if we have any)
cache_result = yield self.store.get_url_cache(url, ts)
if (
- cache_result and
- cache_result["expires_ts"] > ts and
- cache_result["response_code"] / 100 == 2
+ cache_result
+ and cache_result["expires_ts"] > ts
+ and cache_result["response_code"] / 100 == 2
):
# It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on.
og = cache_result["og"]
if isinstance(og, six.text_type):
- og = og.encode('utf8')
+ og = og.encode("utf8")
defer.returnValue(og)
return
@@ -203,33 +194,31 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info)
- if _is_media(media_info['media_type']):
- file_id = media_info['filesystem_id']
+ if _is_media(media_info["media_type"]):
+ file_id = media_info["filesystem_id"]
dims = yield self.media_repo._generate_thumbnails(
- None, file_id, file_id, media_info["media_type"],
- url_cache=True,
+ None, file_id, file_id, media_info["media_type"], url_cache=True
)
og = {
- "og:description": media_info['download_name'],
- "og:image": "mxc://%s/%s" % (
- self.server_name, media_info['filesystem_id']
- ),
- "og:image:type": media_info['media_type'],
- "matrix:image:size": media_info['media_length'],
+ "og:description": media_info["download_name"],
+ "og:image": "mxc://%s/%s"
+ % (self.server_name, media_info["filesystem_id"]),
+ "og:image:type": media_info["media_type"],
+ "matrix:image:size": media_info["media_length"],
}
if dims:
- og["og:image:width"] = dims['width']
- og["og:image:height"] = dims['height']
+ og["og:image:width"] = dims["width"]
+ og["og:image:height"] = dims["height"]
else:
logger.warn("Couldn't get dims for %s" % url)
# define our OG response for this media
- elif _is_html(media_info['media_type']):
+ elif _is_html(media_info["media_type"]):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
- with open(media_info['filename'], 'rb') as file:
+ with open(media_info["filename"], "rb") as file:
body = file.read()
encoding = None
@@ -242,45 +231,43 @@ class PreviewUrlResource(Resource):
# If we find a match, it should take precedence over the
# Content-Type header, so set it here.
if match:
- encoding = match.group(1).decode('ascii')
+ encoding = match.group(1).decode("ascii")
# If we don't find a match, we'll look at the HTTP Content-Type, and
# if that doesn't exist, we'll fall back to UTF-8.
if not encoding:
- match = _content_type_match.match(
- media_info['media_type']
- )
+ match = _content_type_match.match(media_info["media_type"])
encoding = match.group(1) if match else "utf-8"
- og = decode_and_calc_og(body, media_info['uri'], encoding)
+ og = decode_and_calc_og(body, media_info["uri"], encoding)
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
- if 'og:image' in og and og['og:image']:
+ if "og:image" in og and og["og:image"]:
image_info = yield self._download_url(
- _rebase_url(og['og:image'], media_info['uri']), user
+ _rebase_url(og["og:image"], media_info["uri"]), user
)
- if _is_media(image_info['media_type']):
+ if _is_media(image_info["media_type"]):
# TODO: make sure we don't choke on white-on-transparent images
- file_id = image_info['filesystem_id']
+ file_id = image_info["filesystem_id"]
dims = yield self.media_repo._generate_thumbnails(
- None, file_id, file_id, image_info["media_type"],
- url_cache=True,
+ None, file_id, file_id, image_info["media_type"], url_cache=True
)
if dims:
- og["og:image:width"] = dims['width']
- og["og:image:height"] = dims['height']
+ og["og:image:width"] = dims["width"]
+ og["og:image:height"] = dims["height"]
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
- self.server_name, image_info['filesystem_id']
+ self.server_name,
+ image_info["filesystem_id"],
)
- og["og:image:type"] = image_info['media_type']
- og["matrix:image:size"] = image_info['media_length']
+ og["og:image:type"] = image_info["media_type"]
+ og["matrix:image:size"] = image_info["media_length"]
else:
del og["og:image"]
else:
@@ -289,7 +276,7 @@ class PreviewUrlResource(Resource):
logger.debug("Calculated OG for %s as %s" % (url, og))
- jsonog = json.dumps(og).encode('utf8')
+ jsonog = json.dumps(og).encode("utf8")
# store OG in history-aware DB cache
yield self.store.store_url_cache(
@@ -310,19 +297,15 @@ class PreviewUrlResource(Resource):
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
- file_id = datetime.date.today().isoformat() + '_' + random_string(16)
+ file_id = datetime.date.today().isoformat() + "_" + random_string(16)
- file_info = FileInfo(
- server_name=None,
- file_id=file_id,
- url_cache=True,
- )
+ file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file(
- url, output_stream=f, max_size=self.max_spider_size,
+ url, output_stream=f, max_size=self.max_spider_size
)
except SynapseError:
# Pass SynapseErrors through directly, so that the servlet
@@ -334,24 +317,25 @@ class PreviewUrlResource(Resource):
# Note: This will also be the case if one of the resolved IP
# addresses is blacklisted
raise SynapseError(
- 502, "DNS resolution failure during URL preview generation",
- Codes.UNKNOWN
+ 502,
+ "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN,
)
except Exception as e:
# FIXME: pass through 404s and other error messages nicely
logger.warn("Error downloading %s: %r", url, e)
raise SynapseError(
- 500, "Failed to download content: %s" % (
- traceback.format_exception_only(sys.exc_info()[0], e),
- ),
+ 500,
+ "Failed to download content: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
Codes.UNKNOWN,
)
yield finish()
try:
if b"Content-Type" in headers:
- media_type = headers[b"Content-Type"][0].decode('ascii')
+ media_type = headers[b"Content-Type"][0].decode("ascii")
else:
media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
@@ -375,24 +359,26 @@ class PreviewUrlResource(Resource):
# therefore not expire it.
raise
- defer.returnValue({
- "media_type": media_type,
- "media_length": length,
- "download_name": download_name,
- "created_ts": time_now_ms,
- "filesystem_id": file_id,
- "filename": fname,
- "uri": uri,
- "response_code": code,
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- "expires": 60 * 60 * 1000,
- "etag": headers["ETag"][0] if "ETag" in headers else None,
- })
+ defer.returnValue(
+ {
+ "media_type": media_type,
+ "media_length": length,
+ "download_name": download_name,
+ "created_ts": time_now_ms,
+ "filesystem_id": file_id,
+ "filename": fname,
+ "uri": uri,
+ "response_code": code,
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ "expires": 60 * 60 * 1000,
+ "etag": headers["ETag"][0] if "ETag" in headers else None,
+ }
+ )
def _start_expire_url_cache_data(self):
return run_as_background_process(
- "expire_url_cache_data", self._expire_url_cache_data,
+ "expire_url_cache_data", self._expire_url_cache_data
)
@defer.inlineCallbacks
@@ -496,7 +482,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None):
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
+ tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
og = _calc_og(tree, media_uri)
return og
@@ -523,8 +509,8 @@ def _calc_og(tree, media_uri):
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
- if 'content' in tag.attrib:
- og[tag.attrib['property']] = tag.attrib['content']
+ if "content" in tag.attrib:
+ og[tag.attrib["property"]] = tag.attrib["content"]
# TODO: grab article: meta tags too, e.g.:
@@ -535,39 +521,43 @@ def _calc_og(tree, media_uri):
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
- if 'og:title' not in og:
+ if "og:title" not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
if title and title[0].text is not None:
- og['og:title'] = title[0].text.strip()
+ og["og:title"] = title[0].text.strip()
else:
- og['og:title'] = None
+ og["og:title"] = None
- if 'og:image' not in og:
+ if "og:image" not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
- og['og:image'] = _rebase_url(meta_image[0], media_uri)
+ og["og:image"] = _rebase_url(meta_image[0], media_uri)
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
- images = sorted(images, key=lambda i: (
- -1 * float(i.attrib['width']) * float(i.attrib['height'])
- ))
+ images = sorted(
+ images,
+ key=lambda i: (
+ -1 * float(i.attrib["width"]) * float(i.attrib["height"])
+ ),
+ )
if not images:
images = tree.xpath("//img[@src]")
if images:
- og['og:image'] = images[0].attrib['src']
+ og["og:image"] = images[0].attrib["src"]
- if 'og:description' not in og:
+ if "og:description" not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
- "/@content")
+ "/@content"
+ )
if meta_description:
- og['og:description'] = meta_description[0]
+ og["og:description"] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
@@ -588,18 +578,18 @@ def _calc_og(tree, media_uri):
"script",
"noscript",
"style",
- etree.Comment
+ etree.Comment,
)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
- re.sub(r'\s+', '\n', el).strip()
+ re.sub(r"\s+", "\n", el).strip()
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
- og['og:description'] = summarize_paragraphs(text_nodes)
+ og["og:description"] = summarize_paragraphs(text_nodes)
else:
- og['og:description'] = summarize_paragraphs([og['og:description']])
+ og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
@@ -636,7 +626,7 @@ def _iterate_over_text(tree, *tags_to_ignore):
[child, child.tail] if child.tail else [child]
for child in el.iterchildren()
),
- elements
+ elements,
)
@@ -647,8 +637,8 @@ def _rebase_url(url, base):
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
- if not url[2].startswith('/'):
- url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
+ if not url[2].startswith("/"):
+ url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
return urlparse.urlunparse(url)
@@ -659,9 +649,8 @@ def _is_media(content_type):
def _is_html(content_type):
content_type = content_type.lower()
- if (
- content_type.startswith("text/html") or
- content_type.startswith("application/xhtml")
+ if content_type.startswith("text/html") or content_type.startswith(
+ "application/xhtml"
):
return True
@@ -671,19 +660,19 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
# first paragraph and then word boundaries.
# TODO: Respect sentences?
- description = ''
+ description = ""
# Keep adding paragraphs until we get to the MIN_SIZE.
for text_node in text_nodes:
if len(description) < min_size:
- text_node = re.sub(r'[\t \r\n]+', ' ', text_node)
- description += text_node + '\n\n'
+ text_node = re.sub(r"[\t \r\n]+", " ", text_node)
+ description += text_node + "\n\n"
else:
break
description = description.strip()
- description = re.sub(r'[\t ]+', ' ', description)
- description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description)
+ description = re.sub(r"[\t ]+", " ", description)
+ description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
# If the concatenation of paragraphs to get above MIN_SIZE
# took us over MAX_SIZE, then we need to truncate mid paragraph
@@ -715,5 +704,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
# We always add an ellipsis because at the very least
# we chopped mid paragraph.
- description = new_desc.strip() + u"…"
+ description = new_desc.strip() + "…"
return description if description else None
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index d90cbfb56a..359b45ebfc 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -32,6 +32,7 @@ class StorageProvider(object):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
+
def store_file(self, path, file_info):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
@@ -70,6 +71,7 @@ class StorageProviderWrapper(StorageProvider):
uploaded, or todo the upload in the backgroud.
store_remote (bool): Whether remote media should be uploaded
"""
+
def __init__(self, backend, store_local, store_synchronous, store_remote):
self.backend = backend
self.store_local = store_local
@@ -92,6 +94,7 @@ class StorageProviderWrapper(StorageProvider):
return self.backend.store_file(path, file_info)
except Exception:
logger.exception("Error storing file")
+
run_in_background(store)
return defer.succeed(None)
@@ -123,8 +126,7 @@ class FileStorageProviderBackend(StorageProvider):
os.makedirs(dirname)
return logcontext.defer_to_thread(
- self.hs.get_reactor(),
- shutil.copyfile, primary_fname, backup_fname,
+ self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
def fetch(self, path, file_info):
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 35a750923b..ca84c9f139 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -74,19 +74,18 @@ class ThumbnailResource(Resource):
else:
if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail(
- request, server_name, media_id,
- width, height, method, m_type
+ request, server_name, media_id, width, height, method, m_type
)
else:
yield self._respond_remote_thumbnail(
- request, server_name, media_id,
- width, height, method, m_type
+ request, server_name, media_id, width, height, method, m_type
)
self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
- def _respond_local_thumbnail(self, request, media_id, width, height,
- method, m_type):
+ def _respond_local_thumbnail(
+ self, request, media_id, width, height, method, m_type
+ ):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
@@ -105,7 +104,8 @@ class ThumbnailResource(Resource):
)
file_info = FileInfo(
- server_name=None, file_id=media_id,
+ server_name=None,
+ file_id=media_id,
url_cache=media_info["url_cache"],
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
@@ -124,9 +124,15 @@ class ThumbnailResource(Resource):
respond_404(request)
@defer.inlineCallbacks
- def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
- desired_height, desired_method,
- desired_type):
+ def _select_or_generate_local_thumbnail(
+ self,
+ request,
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ ):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
@@ -146,7 +152,8 @@ class ThumbnailResource(Resource):
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
- server_name=None, file_id=media_id,
+ server_name=None,
+ file_id=media_id,
url_cache=media_info["url_cache"],
thumbnail=True,
thumbnail_width=info["thumbnail_width"],
@@ -167,7 +174,11 @@ class ThumbnailResource(Resource):
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
- media_id, desired_width, desired_height, desired_method, desired_type,
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
url_cache=media_info["url_cache"],
)
@@ -178,13 +189,20 @@ class ThumbnailResource(Resource):
respond_404(request)
@defer.inlineCallbacks
- def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
- desired_width, desired_height,
- desired_method, desired_type):
+ def _select_or_generate_remote_thumbnail(
+ self,
+ request,
+ server_name,
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ ):
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
- server_name, media_id,
+ server_name, media_id
)
file_id = media_info["filesystem_id"]
@@ -197,7 +215,8 @@ class ThumbnailResource(Resource):
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
- server_name=server_name, file_id=media_info["filesystem_id"],
+ server_name=server_name,
+ file_id=media_info["filesystem_id"],
thumbnail=True,
thumbnail_width=info["thumbnail_width"],
thumbnail_height=info["thumbnail_height"],
@@ -217,8 +236,13 @@ class ThumbnailResource(Resource):
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
- server_name, file_id, media_id, desired_width,
- desired_height, desired_method, desired_type
+ server_name,
+ file_id,
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
)
if file_path:
@@ -228,15 +252,16 @@ class ThumbnailResource(Resource):
respond_404(request)
@defer.inlineCallbacks
- def _respond_remote_thumbnail(self, request, server_name, media_id, width,
- height, method, m_type):
+ def _respond_remote_thumbnail(
+ self, request, server_name, media_id, width, height, method, m_type
+ ):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
- server_name, media_id,
+ server_name, media_id
)
if thumbnail_infos:
@@ -244,7 +269,8 @@ class ThumbnailResource(Resource):
width, height, method, m_type, thumbnail_infos
)
file_info = FileInfo(
- server_name=server_name, file_id=media_info["filesystem_id"],
+ server_name=server_name,
+ file_id=media_info["filesystem_id"],
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
@@ -261,8 +287,14 @@ class ThumbnailResource(Resource):
logger.info("Failed to find any generated thumbnails")
respond_404(request)
- def _select_thumbnail(self, desired_width, desired_height, desired_method,
- desired_type, thumbnail_infos):
+ def _select_thumbnail(
+ self,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ thumbnail_infos,
+ ):
d_w = desired_width
d_h = desired_height
@@ -280,15 +312,27 @@ class ThumbnailResource(Resource):
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
- info_list.append((
- aspect_quality, min_quality, size_quality, type_quality,
- length_quality, info
- ))
+ info_list.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
+ )
+ )
else:
- info_list2.append((
- aspect_quality, min_quality, size_quality, type_quality,
- length_quality, info
- ))
+ info_list2.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
+ )
+ )
if info_list:
return min(info_list)[-1]
else:
@@ -304,13 +348,11 @@ class ThumbnailResource(Resource):
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
- info_list.append((
- size_quality, type_quality, length_quality, info
- ))
+ info_list.append((size_quality, type_quality, length_quality, info))
elif t_method == "scale":
- info_list2.append((
- size_quality, type_quality, length_quality, info
- ))
+ info_list2.append(
+ (size_quality, type_quality, length_quality, info)
+ )
if info_list:
return min(info_list)[-1]
else:
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 3efd0d80fc..90d8e6bffe 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -28,16 +28,13 @@ EXIF_TRANSPOSE_MAPPINGS = {
5: Image.TRANSPOSE,
6: Image.ROTATE_270,
7: Image.TRANSVERSE,
- 8: Image.ROTATE_90
+ 8: Image.ROTATE_90,
}
class Thumbnailer(object):
- FORMATS = {
- "image/jpeg": "JPEG",
- "image/png": "PNG",
- }
+ FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
def __init__(self, input_path):
self.image = Image.open(input_path)
@@ -110,17 +107,13 @@ class Thumbnailer(object):
"""
if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width
- scaled_image = self.image.resize(
- (width, scaled_height), Image.ANTIALIAS
- )
+ scaled_image = self.image.resize((width, scaled_height), Image.ANTIALIAS)
crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
else:
scaled_width = (height * self.width) // self.height
- scaled_image = self.image.resize(
- (scaled_width, height), Image.ANTIALIAS
- )
+ scaled_image = self.image.resize((scaled_width, height), Image.ANTIALIAS)
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index c1240e1963..d1d7e959f0 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -55,48 +55,36 @@ class UploadResource(Resource):
requester = yield self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
- content_length = request.getHeader(b"Content-Length").decode('ascii')
+ content_length = request.getHeader(b"Content-Length").decode("ascii")
if content_length is None:
- raise SynapseError(
- msg="Request must specify a Content-Length", code=400
- )
+ raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:
- raise SynapseError(
- msg="Upload request body is too large",
- code=413,
- )
+ raise SynapseError(msg="Upload request body is too large", code=413)
upload_name = parse_string(request, b"filename", encoding=None)
if upload_name:
try:
- upload_name = upload_name.decode('utf8')
+ upload_name = upload_name.decode("utf8")
except UnicodeDecodeError:
raise SynapseError(
- msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
- code=400,
+ msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
)
headers = request.requestHeaders
if headers.hasHeader(b"Content-Type"):
- media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
+ media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii")
else:
- raise SynapseError(
- msg="Upload request missing 'Content-Type'",
- code=400,
- )
+ raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
# if headers.hasHeader(b"Content-Disposition"):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
content_uri = yield self.media_repo.create_content(
- media_type, upload_name, request.content,
- content_length, requester.user
+ media_type, upload_name, request.content, content_length, requester.user
)
logger.info("Uploaded content with URI %r", content_uri)
- respond_with_json(
- request, 200, {"content_uri": content_uri}, send_cors=True
- )
+ respond_with_json(request, 200, {"content_uri": content_uri}, send_cors=True)
diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/saml2/metadata_resource.py
index e8c680aeb4..1e8526e22e 100644
--- a/synapse/rest/saml2/metadata_resource.py
+++ b/synapse/rest/saml2/metadata_resource.py
@@ -30,7 +30,7 @@ class SAML2MetadataResource(Resource):
def render_GET(self, request):
metadata_xml = saml2.metadata.create_metadata_string(
- configfile=None, config=self.sp_config,
+ configfile=None, config=self.sp_config
)
request.setHeader(b"Content-Type", b"text/xml; charset=utf-8")
return metadata_xml
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index 36ca1333a8..9ec56d6adb 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -44,18 +44,16 @@ class SAML2ResponseResource(Resource):
@wrap_html_request_handler
def _async_render_POST(self, request):
- resp_bytes = parse_string(request, 'SAMLResponse', required=True)
- relay_state = parse_string(request, 'RelayState', required=True)
+ resp_bytes = parse_string(request, "SAMLResponse", required=True)
+ relay_state = parse_string(request, "RelayState", required=True)
try:
saml2_auth = self._saml_client.parse_authn_request_response(
- resp_bytes, saml2.BINDING_HTTP_POST,
+ resp_bytes, saml2.BINDING_HTTP_POST
)
except Exception as e:
logger.warning("Exception parsing SAML2 response", exc_info=1)
- raise CodeMessageException(
- 400, "Unable to parse SAML2 response: %s" % (e,),
- )
+ raise CodeMessageException(400, "Unable to parse SAML2 response: %s" % (e,))
if saml2_auth.not_signed:
raise CodeMessageException(400, "SAML2 response was not signed")
@@ -67,6 +65,5 @@ class SAML2ResponseResource(Resource):
displayName = saml2_auth.ava.get("displayName", [None])[0]
return self._sso_auth_handler.on_successful_auth(
- username, request, relay_state,
- user_display_name=displayName,
+ username, request, relay_state, user_display_name=displayName
)
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index a7fa4f39af..5e8fda4b65 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -29,6 +29,7 @@ class WellKnownBuilder(object):
Args:
hs (synapse.server.HomeServer):
"""
+
def __init__(self, hs):
self._config = hs.config
@@ -37,15 +38,11 @@ class WellKnownBuilder(object):
if self._config.public_baseurl is None:
return None
- result = {
- "m.homeserver": {
- "base_url": self._config.public_baseurl,
- },
- }
+ result = {"m.homeserver": {"base_url": self._config.public_baseurl}}
if self._config.default_identity_server:
result["m.identity_server"] = {
- "base_url": self._config.default_identity_server,
+ "base_url": self._config.default_identity_server
}
return result
@@ -66,7 +63,7 @@ class WellKnownResource(Resource):
if not r:
request.setResponseCode(404)
request.setHeader(b"Content-Type", b"text/plain")
- return b'.well-known not available'
+ return b".well-known not available"
logger.debug("returning: %s", r)
request.setHeader(b"Content-Type", b"application/json")
diff --git a/synapse/secrets.py b/synapse/secrets.py
index f6280f951c..0b327a0f82 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -29,6 +29,7 @@ if sys.version_info[0:2] >= (3, 6):
def Secrets():
return secrets
+
else:
import os
import binascii
@@ -38,4 +39,4 @@ else:
return os.urandom(nbytes)
def token_hex(self, nbytes=32):
- return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii')
+ return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii")
diff --git a/synapse/server.py b/synapse/server.py
index 0eb8968674..dbb35c7227 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -35,6 +37,7 @@ from synapse.crypto import context_factory
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker
+from synapse.events.third_party_rules import ThirdPartyEventRules
from synapse.events.utils import EventClientSerializer
from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import (
@@ -88,7 +91,9 @@ from synapse.rest.media.v1.media_repository import (
from synapse.secrets import Secrets
from synapse.server_notices.server_notices_manager import ServerNoticesManager
from synapse.server_notices.server_notices_sender import ServerNoticesSender
-from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender
+from synapse.server_notices.worker_server_notices_sender import (
+ WorkerServerNoticesSender,
+)
from synapse.state import StateHandler, StateResolutionHandler
from synapse.streams.events import EventSources
from synapse.util import Clock
@@ -124,79 +129,77 @@ class HomeServer(object):
__metaclass__ = abc.ABCMeta
DEPENDENCIES = [
- 'http_client',
- 'db_pool',
- 'federation_client',
- 'federation_server',
- 'handlers',
- 'auth',
- 'room_creation_handler',
- 'state_handler',
- 'state_resolution_handler',
- 'presence_handler',
- 'sync_handler',
- 'typing_handler',
- 'room_list_handler',
- 'acme_handler',
- 'auth_handler',
- 'device_handler',
- 'stats_handler',
- 'e2e_keys_handler',
- 'e2e_room_keys_handler',
- 'event_handler',
- 'event_stream_handler',
- 'initial_sync_handler',
- 'application_service_api',
- 'application_service_scheduler',
- 'application_service_handler',
- 'device_message_handler',
- 'profile_handler',
- 'event_creation_handler',
- 'deactivate_account_handler',
- 'set_password_handler',
- 'notifier',
- 'event_sources',
- 'keyring',
- 'pusherpool',
- 'event_builder_factory',
- 'filtering',
- 'http_client_context_factory',
- 'simple_http_client',
- 'media_repository',
- 'media_repository_resource',
- 'federation_transport_client',
- 'federation_sender',
- 'receipts_handler',
- 'macaroon_generator',
- 'tcp_replication',
- 'read_marker_handler',
- 'action_generator',
- 'user_directory_handler',
- 'groups_local_handler',
- 'groups_server_handler',
- 'groups_attestation_signing',
- 'groups_attestation_renewer',
- 'secrets',
- 'spam_checker',
- 'room_member_handler',
- 'federation_registry',
- 'server_notices_manager',
- 'server_notices_sender',
- 'message_handler',
- 'pagination_handler',
- 'room_context_handler',
- 'sendmail',
- 'registration_handler',
- 'account_validity_handler',
- 'event_client_serializer',
- 'saml_client',
- ]
-
- REQUIRED_ON_MASTER_STARTUP = [
+ "http_client",
+ "db_pool",
+ "federation_client",
+ "federation_server",
+ "handlers",
+ "auth",
+ "room_creation_handler",
+ "state_handler",
+ "state_resolution_handler",
+ "presence_handler",
+ "sync_handler",
+ "typing_handler",
+ "room_list_handler",
+ "acme_handler",
+ "auth_handler",
+ "device_handler",
+ "stats_handler",
+ "e2e_keys_handler",
+ "e2e_room_keys_handler",
+ "event_handler",
+ "event_stream_handler",
+ "initial_sync_handler",
+ "application_service_api",
+ "application_service_scheduler",
+ "application_service_handler",
+ "device_message_handler",
+ "profile_handler",
+ "event_creation_handler",
+ "deactivate_account_handler",
+ "set_password_handler",
+ "notifier",
+ "event_sources",
+ "keyring",
+ "pusherpool",
+ "event_builder_factory",
+ "filtering",
+ "http_client_context_factory",
+ "simple_http_client",
+ "media_repository",
+ "media_repository_resource",
+ "federation_transport_client",
+ "federation_sender",
+ "receipts_handler",
+ "macaroon_generator",
+ "tcp_replication",
+ "read_marker_handler",
+ "action_generator",
"user_directory_handler",
- "stats_handler"
+ "groups_local_handler",
+ "groups_server_handler",
+ "groups_attestation_signing",
+ "groups_attestation_renewer",
+ "secrets",
+ "spam_checker",
+ "third_party_event_rules",
+ "room_member_handler",
+ "federation_registry",
+ "server_notices_manager",
+ "server_notices_sender",
+ "message_handler",
+ "pagination_handler",
+ "room_context_handler",
+ "sendmail",
+ "registration_handler",
+ "account_validity_handler",
+ "event_client_serializer",
+ "saml_client",
]
+ REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
+
# This is overridden in derived application classes
# (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
# instantiated during setup() for future return by get_datastore()
@@ -407,9 +410,7 @@ class HomeServer(object):
name = self.db_config["name"]
return adbapi.ConnectionPool(
- name,
- cp_reactor=self.get_reactor(),
- **self.db_config.get("args", {})
+ name, cp_reactor=self.get_reactor(), **self.db_config.get("args", {})
)
def get_db_conn(self, run_new_connection=True):
@@ -421,7 +422,8 @@ class HomeServer(object):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
+ k: v
+ for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
@@ -484,6 +486,9 @@ class HomeServer(object):
def build_spam_checker(self):
return SpamChecker(self)
+ def build_third_party_event_rules(self):
+ return ThirdPartyEventRules(self)
+
def build_room_member_handler(self):
if self.config.worker_app:
return RoomMemberWorkerHandler(self)
@@ -525,6 +530,7 @@ class HomeServer(object):
def build_saml_client(self):
from saml2.client import Saml2Client
+
return Saml2Client(self.config.saml2_sp_config)
def remove_pusher(self, app_id, push_key, user_id):
@@ -553,9 +559,7 @@ def _make_dependency_method(depname):
if builder:
# Prevent cyclic dependencies from deadlocking
if depname in hs._building:
- raise ValueError("Cyclic dependency while building %s" % (
- depname,
- ))
+ raise ValueError("Cyclic dependency while building %s" % (depname,))
hs._building[depname] = 1
dep = builder()
@@ -566,9 +570,7 @@ def _make_dependency_method(depname):
return dep
raise NotImplementedError(
- "%s has no %s nor a builder for it" % (
- type(hs).__name__, depname,
- )
+ "%s has no %s nor a builder for it" % (type(hs).__name__, depname)
)
setattr(HomeServer, "get_%s" % (depname), _get)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 9583e82d52..16f8f6b573 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -22,60 +22,57 @@ class HomeServer(object):
@property
def config(self) -> synapse.config.homeserver.HomeServerConfig:
pass
-
def get_auth(self) -> synapse.api.auth.Auth:
pass
-
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
pass
-
def get_datastore(self) -> synapse.storage.DataStore:
pass
-
def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
pass
-
def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler:
pass
-
def get_handlers(self) -> synapse.handlers.Handlers:
pass
-
def get_state_handler(self) -> synapse.state.StateHandler:
pass
-
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
-
- def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
+ def get_deactivate_account_handler(
+ self
+ ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
-
def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
pass
-
def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
pass
-
- def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler:
+ def get_event_creation_handler(
+ self
+ ) -> synapse.handlers.message.EventCreationHandler:
pass
-
- def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
+ def get_set_password_handler(
+ self
+ ) -> synapse.handlers.set_password.SetPasswordHandler:
pass
-
def get_federation_sender(self) -> synapse.federation.sender.FederationSender:
pass
-
- def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
+ def get_federation_transport_client(
+ self
+ ) -> synapse.federation.transport.client.TransportLayerClient:
pass
-
- def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
+ def get_media_repository_resource(
+ self
+ ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
pass
-
- def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
+ def get_media_repository(
+ self
+ ) -> synapse.rest.media.v1.media_repository.MediaRepository:
pass
-
- def get_server_notices_manager(self) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
+ def get_server_notices_manager(
+ self
+ ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
pass
-
- def get_server_notices_sender(self) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
+ def get_server_notices_sender(
+ self
+ ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
pass
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 5e3044d164..415e9c17d8 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -30,6 +30,7 @@ class ConsentServerNotices(object):
"""Keeps track of whether we need to send users server_notices about
privacy policy consent, and sends one if we do.
"""
+
def __init__(self, hs):
"""
@@ -49,12 +50,11 @@ class ConsentServerNotices(object):
if not self._server_notices_manager.is_enabled():
raise ConfigError(
"user_consent configuration requires server notices, but "
- "server notices are not enabled.",
+ "server notices are not enabled."
)
- if 'body' not in self._server_notice_content:
+ if "body" not in self._server_notice_content:
raise ConfigError(
- "user_consent server_notice_consent must contain a 'body' "
- "key.",
+ "user_consent server_notice_consent must contain a 'body' " "key."
)
self._consent_uri_builder = ConsentURIBuilder(hs.config)
@@ -95,18 +95,14 @@ class ConsentServerNotices(object):
# need to send a message.
try:
consent_uri = self._consent_uri_builder.build_user_consent_uri(
- get_localpart_from_id(user_id),
+ get_localpart_from_id(user_id)
)
content = copy_with_str_subst(
- self._server_notice_content, {
- 'consent_uri': consent_uri,
- },
- )
- yield self._server_notices_manager.send_notice(
- user_id, content,
+ self._server_notice_content, {"consent_uri": consent_uri}
)
+ yield self._server_notices_manager.send_notice(user_id, content)
yield self._store.user_set_consent_server_notice_sent(
- user_id, self._current_consent_version,
+ user_id, self._current_consent_version
)
except SynapseError as e:
logger.error("Error sending server notice about user consent: %s", e)
@@ -128,9 +124,7 @@ def copy_with_str_subst(x, substitutions):
if isinstance(x, string_types):
return x % substitutions
if isinstance(x, dict):
- return {
- k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)
- }
+ return {k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)}
if isinstance(x, (list, tuple)):
return [copy_with_str_subst(y) for y in x]
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index af15cba0ee..f183743f31 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -33,6 +33,7 @@ class ResourceLimitsServerNotices(object):
""" Keeps track of whether the server has reached it's resource limit and
ensures that the client is kept up to date.
"""
+
def __init__(self, hs):
"""
Args:
@@ -104,34 +105,28 @@ class ResourceLimitsServerNotices(object):
if currently_blocked and not is_auth_blocking:
# Room is notifying of a block, when it ought not to be.
# Remove block notification
- content = {
- "pinned": ref_events
- }
+ content = {"pinned": ref_events}
yield self._server_notices_manager.send_notice(
- user_id, content, EventTypes.Pinned, '',
+ user_id, content, EventTypes.Pinned, ""
)
elif not currently_blocked and is_auth_blocking:
# Room is not notifying of a block, when it ought to be.
# Add block notification
content = {
- 'body': event_content,
- 'msgtype': ServerNoticeMsgType,
- 'server_notice_type': ServerNoticeLimitReached,
- 'admin_contact': self._config.admin_contact,
- 'limit_type': event_limit_type
+ "body": event_content,
+ "msgtype": ServerNoticeMsgType,
+ "server_notice_type": ServerNoticeLimitReached,
+ "admin_contact": self._config.admin_contact,
+ "limit_type": event_limit_type,
}
event = yield self._server_notices_manager.send_notice(
- user_id, content, EventTypes.Message,
+ user_id, content, EventTypes.Message
)
- content = {
- "pinned": [
- event.event_id,
- ]
- }
+ content = {"pinned": [event.event_id]}
yield self._server_notices_manager.send_notice(
- user_id, content, EventTypes.Pinned, '',
+ user_id, content, EventTypes.Pinned, ""
)
except SynapseError as e:
@@ -156,9 +151,7 @@ class ResourceLimitsServerNotices(object):
max_id = yield self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
- self._notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
- )
+ self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
@defer.inlineCallbacks
def _is_room_currently_blocked(self, room_id):
@@ -188,7 +181,7 @@ class ResourceLimitsServerNotices(object):
referenced_events = []
if pinned_state_event is not None:
- referenced_events = list(pinned_state_event.content.get('pinned', []))
+ referenced_events = list(pinned_state_event.content.get("pinned", []))
events = yield self._store.get_events(referenced_events)
for event_id, event in iteritems(events):
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index c5cc6d728e..71e7e75320 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -51,8 +51,7 @@ class ServerNoticesManager(object):
@defer.inlineCallbacks
def send_notice(
- self, user_id, event_content,
- type=EventTypes.Message, state_key=None
+ self, user_id, event_content, type=EventTypes.Message, state_key=None
):
"""Send a notice to the given user
@@ -82,10 +81,10 @@ class ServerNoticesManager(object):
}
if state_key is not None:
- event_dict['state_key'] = state_key
+ event_dict["state_key"] = state_key
res = yield self._event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, ratelimit=False,
+ requester, event_dict, ratelimit=False
)
defer.returnValue(res)
@@ -104,11 +103,10 @@ class ServerNoticesManager(object):
if not self.is_enabled():
raise Exception("Server notices not enabled")
- assert self._is_mine_id(user_id), \
- "Cannot send server notices to remote users"
+ assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
rooms = yield self._store.get_rooms_for_user_where_membership_is(
- user_id, [Membership.INVITE, Membership.JOIN],
+ user_id, [Membership.INVITE, Membership.JOIN]
)
system_mxid = self._config.server_notices_mxid
for room in rooms:
@@ -132,8 +130,8 @@ class ServerNoticesManager(object):
# avatar, we have to use both.
join_profile = None
if (
- self._config.server_notices_mxid_display_name is not None or
- self._config.server_notices_mxid_avatar_url is not None
+ self._config.server_notices_mxid_display_name is not None
+ or self._config.server_notices_mxid_avatar_url is not None
):
join_profile = {
"displayname": self._config.server_notices_mxid_display_name,
@@ -146,22 +144,18 @@ class ServerNoticesManager(object):
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,
"name": self._config.server_notices_room_name,
- "power_level_content_override": {
- "users_default": -10,
- },
- "invite": (user_id,)
+ "power_level_content_override": {"users_default": -10},
+ "invite": (user_id,),
},
ratelimit=False,
creator_join_profile=join_profile,
)
- room_id = info['room_id']
+ room_id = info["room_id"]
max_id = yield self._store.add_tag_to_room(
- user_id, room_id, SERVER_NOTICE_ROOM_TAG, {},
- )
- self._notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
+ user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
+ self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
logger.info("Created server notices room %s for %s", room_id, user_id)
defer.returnValue(room_id)
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
index 6121b2f267..652bab58e3 100644
--- a/synapse/server_notices/server_notices_sender.py
+++ b/synapse/server_notices/server_notices_sender.py
@@ -24,6 +24,7 @@ class ServerNoticesSender(object):
"""A centralised place which sends server notices automatically when
Certain Events take place
"""
+
def __init__(self, hs):
"""
@@ -32,7 +33,7 @@ class ServerNoticesSender(object):
"""
self._server_notices = (
ConsentServerNotices(hs),
- ResourceLimitsServerNotices(hs)
+ ResourceLimitsServerNotices(hs),
)
@defer.inlineCallbacks
@@ -43,9 +44,7 @@ class ServerNoticesSender(object):
user_id (str): mxid of user who synced
"""
for sn in self._server_notices:
- yield sn.maybe_send_server_notice_to_user(
- user_id,
- )
+ yield sn.maybe_send_server_notice_to_user(user_id)
@defer.inlineCallbacks
def on_user_ip(self, user_id):
@@ -58,6 +57,4 @@ class ServerNoticesSender(object):
# we check for notices to send to the user in on_user_ip as well as
# in on_user_syncing
for sn in self._server_notices:
- yield sn.maybe_send_server_notice_to_user(
- user_id,
- )
+ yield sn.maybe_send_server_notice_to_user(user_id)
diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py
index 4a133026c3..245ec7c64f 100644
--- a/synapse/server_notices/worker_server_notices_sender.py
+++ b/synapse/server_notices/worker_server_notices_sender.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
class WorkerServerNoticesSender(object):
"""Stub impl of ServerNoticesSender which does nothing"""
+
def __init__(self, hs):
"""
Args:
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 36684ef9f6..1b454a56a1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -21,6 +21,7 @@ from six import iteritems, itervalues
import attr
from frozendict import frozendict
+from prometheus_client import Histogram
from twisted.internet import defer
@@ -37,6 +38,14 @@ from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
+# Metrics for number of state groups involved in a resolution.
+state_groups_histogram = Histogram(
+ "synapse_state_number_state_groups_in_resolution",
+ "Number of state groups used when performing a state resolution",
+ buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
@@ -98,8 +107,9 @@ class StateHandler(object):
self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
- def get_current_state(self, room_id, event_type=None, state_key="",
- latest_event_ids=None):
+ def get_current_state(
+ self, room_id, event_type=None, state_key="", latest_event_ids=None
+ ):
""" Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
@@ -128,8 +138,9 @@ class StateHandler(object):
defer.returnValue(event)
return
- state_map = yield self.store.get_events(list(state.values()),
- get_prev_content=False)
+ state_map = yield self.store.get_events(
+ list(state.values()), get_prev_content=False
+ )
state = {
key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
}
@@ -211,9 +222,7 @@ class StateHandler(object):
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
if old_state:
- prev_state_ids = {
- (s.type, s.state_key): s.event_id for s in old_state
- }
+ prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
if event.is_state():
current_state_ids = dict(prev_state_ids)
key = (event.type, event.state_key)
@@ -239,9 +248,7 @@ class StateHandler(object):
# Let's just correctly fill out the context and create a
# new state group for it.
- prev_state_ids = {
- (s.type, s.state_key): s.event_id for s in old_state
- }
+ prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
if event.is_state():
key = (event.type, event.state_key)
@@ -273,7 +280,7 @@ class StateHandler(object):
logger.debug("calling resolve_state_groups from compute_event_context")
entry = yield self.resolve_state_groups_for_events(
- event.room_id, event.prev_event_ids(),
+ event.room_id, event.prev_event_ids()
)
prev_state_ids = entry.state
@@ -296,9 +303,7 @@ class StateHandler(object):
# If the state at the event has a state group assigned then
# we can use that as the prev group
prev_group = entry.state_group
- delta_ids = {
- key: event.event_id
- }
+ delta_ids = {key: event.event_id}
elif entry.prev_group:
# If the state at the event only has a prev group, then we can
# use that as a prev group too.
@@ -360,31 +365,31 @@ class StateHandler(object):
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
- state_groups_ids = yield self.store.get_state_groups_ids(
- room_id, event_ids
- )
+ state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids)
if len(state_groups_ids) == 0:
- defer.returnValue(_StateCacheEntry(
- state={},
- state_group=None,
- ))
+ defer.returnValue(_StateCacheEntry(state={}, state_group=None))
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
- defer.returnValue(_StateCacheEntry(
- state=state_list,
- state_group=name,
- prev_group=prev_group,
- delta_ids=delta_ids,
- ))
+ defer.returnValue(
+ _StateCacheEntry(
+ state=state_list,
+ state_group=name,
+ prev_group=prev_group,
+ delta_ids=delta_ids,
+ )
+ )
room_version = yield self.store.get_room_version(room_id)
result = yield self._state_resolution_handler.resolve_state_groups(
- room_id, room_version, state_groups_ids, None,
+ room_id,
+ room_version,
+ state_groups_ids,
+ None,
state_res_store=StateResolutionStore(self.store),
)
defer.returnValue(result)
@@ -394,27 +399,21 @@ class StateHandler(object):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
- state_set_ids = [{
- (ev.type, ev.state_key): ev.event_id
- for ev in st
- } for st in state_sets]
-
- state_map = {
- ev.event_id: ev
- for st in state_sets
- for ev in st
- }
+ state_set_ids = [
+ {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
+ ]
+
+ state_map = {ev.event_id: ev for st in state_sets for ev in st}
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
- room_version, state_set_ids,
+ room_version,
+ state_set_ids,
event_map=state_map,
state_res_store=StateResolutionStore(self.store),
)
- new_state = {
- key: state_map[ev_id] for key, ev_id in iteritems(new_state)
- }
+ new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)}
defer.returnValue(new_state)
@@ -425,6 +424,7 @@ class StateResolutionHandler(object):
Note that the storage layer depends on this handler, so all functions must
be storage-independent.
"""
+
def __init__(self, hs):
self.clock = hs.get_clock()
@@ -444,7 +444,7 @@ class StateResolutionHandler(object):
@defer.inlineCallbacks
@log_function
def resolve_state_groups(
- self, room_id, room_version, state_groups_ids, event_map, state_res_store,
+ self, room_id, room_version, state_groups_ids, event_map, state_res_store
):
"""Resolves conflicts between a set of state groups
@@ -471,10 +471,7 @@ class StateResolutionHandler(object):
Returns:
Deferred[_StateCacheEntry]: resolved state
"""
- logger.debug(
- "resolve_state_groups state_groups %s",
- state_groups_ids.keys()
- )
+ logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
group_names = frozenset(state_groups_ids.keys())
@@ -488,6 +485,8 @@ class StateResolutionHandler(object):
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
+ state_groups_histogram.observe(len(state_groups_ids))
+
# start by assuming we won't have any conflicted state, and build up the new
# state map by iterating through the state groups. If we discover a conflict,
# we give up and instead use `resolve_events_with_store`.
@@ -529,10 +528,7 @@ class StateResolutionHandler(object):
defer.returnValue(cache)
-def _make_state_cache_entry(
- new_state,
- state_groups_ids,
-):
+def _make_state_cache_entry(new_state, state_groups_ids):
"""Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed
_StateCacheEntry.
@@ -562,10 +558,7 @@ def _make_state_cache_entry(
old_state_event_ids = set(itervalues(state))
if new_state_event_ids == old_state_event_ids:
# got an exact match.
- return _StateCacheEntry(
- state=new_state,
- state_group=sg,
- )
+ return _StateCacheEntry(state=new_state, state_group=sg)
# TODO: We want to create a state group for this set of events, to
# increase cache hits, but we need to make sure that it doesn't
@@ -576,20 +569,13 @@ def _make_state_cache_entry(
delta_ids = None
for old_group, old_state in iteritems(state_groups_ids):
- n_delta_ids = {
- k: v
- for k, v in iteritems(new_state)
- if old_state.get(k) != v
- }
+ n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group
delta_ids = n_delta_ids
return _StateCacheEntry(
- state=new_state,
- state_group=None,
- prev_group=prev_group,
- delta_ids=delta_ids,
+ state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
)
@@ -618,11 +604,11 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1:
return v1.resolve_events_with_store(
- state_sets, event_map, state_res_store.get_events,
+ state_sets, event_map, state_res_store.get_events
)
else:
return v2.resolve_events_with_store(
- room_version, state_sets, event_map, state_res_store,
+ room_version, state_sets, event_map, state_res_store
)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 29b4e86cfd..88acd4817e 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -57,23 +57,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
if len(state_sets) == 1:
defer.returnValue(state_sets[0])
- unconflicted_state, conflicted_state = _seperate(
- state_sets,
- )
+ unconflicted_state, conflicted_state = _seperate(state_sets)
needed_events = set(
- event_id
- for event_ids in itervalues(conflicted_state)
- for event_id in event_ids
+ event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
)
needed_event_count = len(needed_events)
if event_map is not None:
needed_events -= set(iterkeys(event_map))
logger.info(
- "Asking for %d/%d conflicted events",
- len(needed_events),
- needed_event_count,
+ "Asking for %d/%d conflicted events", len(needed_events), needed_event_count
)
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
@@ -97,17 +91,17 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
new_needed_events -= set(iterkeys(event_map))
logger.info(
- "Asking for %d/%d auth events",
- len(new_needed_events),
- new_needed_event_count,
+ "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
)
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
- defer.returnValue(_resolve_with_state(
- unconflicted_state, conflicted_state, auth_events, state_map
- ))
+ defer.returnValue(
+ _resolve_with_state(
+ unconflicted_state, conflicted_state, auth_events, state_map
+ )
+ )
def _seperate(state_sets):
@@ -173,8 +167,9 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
return auth_events
-def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
- state_map):
+def _resolve_with_state(
+ unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
+):
conflicted_state = {}
for key, event_ids in iteritems(conflicted_state_ids):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
@@ -190,9 +185,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event
}
try:
- resolved_state = _resolve_state_events(
- conflicted_state, auth_events
- )
+ resolved_state = _resolve_state_events(conflicted_state, auth_events)
except Exception:
logger.exception("Failed to resolve state")
raise
@@ -218,37 +211,28 @@ def _resolve_state_events(conflicted_state, auth_events):
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
- resolved_state[POWER_KEY] = _resolve_auth_events(
- events, auth_events)
+ resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
- resolved_state[key] = _resolve_auth_events(
- events,
- auth_events
- )
+ resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
- resolved_state[key] = _resolve_auth_events(
- events,
- auth_events
- )
+ resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
- resolved_state[key] = _resolve_normal_events(
- events, auth_events
- )
+ resolved_state[key] = _resolve_normal_events(events, auth_events)
return resolved_state
@@ -257,9 +241,7 @@ def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))]
auth_keys = set(
- key
- for event in events
- for key in event_auth.auth_types_for_event(event)
+ key for event in events for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
@@ -313,6 +295,6 @@ def _ordered_events(events):
def key_func(e):
# we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/
- return -int(e.depth), hashlib.sha1(e.event_id.encode('utf-8')).hexdigest()
+ return -int(e.depth), hashlib.sha1(e.event_id.encode("utf-8")).hexdigest()
return sorted(events, key=key_func)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 650995c92c..db969e8997 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -70,19 +70,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
- auth_diff = yield _get_auth_chain_difference(
- state_sets, event_map, state_res_store,
- )
+ auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
- full_conflicted_set = set(itertools.chain(
- itertools.chain.from_iterable(itervalues(conflicted_state)),
- auth_diff,
- ))
+ full_conflicted_set = set(
+ itertools.chain(
+ itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff
+ )
+ )
- events = yield state_res_store.get_events([
- eid for eid in full_conflicted_set
- if eid not in event_map
- ], allow_rejected=True)
+ events = yield state_res_store.get_events(
+ [eid for eid in full_conflicted_set if eid not in event_map],
+ allow_rejected=True,
+ )
event_map.update(events)
full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
@@ -91,22 +90,21 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
# Get and sort all the power events (kicks/bans/etc)
power_events = (
- eid for eid in full_conflicted_set
- if _is_power_event(event_map[eid])
+ eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
)
sorted_power_events = yield _reverse_topological_power_sort(
- power_events,
- event_map,
- state_res_store,
- full_conflicted_set,
+ power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
resolved_state = yield _iterative_auth_checks(
- room_version, sorted_power_events, unconflicted_state, event_map,
+ room_version,
+ sorted_power_events,
+ unconflicted_state,
+ event_map,
state_res_store,
)
@@ -116,23 +114,20 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
# events using the mainline of the resolved power level.
leftover_events = [
- ev_id
- for ev_id in full_conflicted_set
- if ev_id not in sorted_power_events
+ ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events
]
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
leftover_events = yield _mainline_sort(
- leftover_events, pl, event_map, state_res_store,
+ leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks(
- room_version, leftover_events, resolved_state, event_map,
- state_res_store,
+ room_version, leftover_events, resolved_state, event_map, state_res_store
)
logger.debug("resolved")
@@ -209,14 +204,16 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
auth_ids = set(
eid
for key, eid in iteritems(state_set)
- if (key[0] in (
- EventTypes.Member,
- EventTypes.ThirdPartyInvite,
- ) or key in (
- (EventTypes.PowerLevels, ''),
- (EventTypes.Create, ''),
- (EventTypes.JoinRules, ''),
- )) and eid not in common
+ if (
+ key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
+ or key
+ in (
+ (EventTypes.PowerLevels, ""),
+ (EventTypes.Create, ""),
+ (EventTypes.JoinRules, ""),
+ )
+ )
+ and eid not in common
)
auth_chain = yield state_res_store.get_auth_chain(auth_ids)
@@ -274,15 +271,16 @@ def _is_power_event(event):
return True
if event.type == EventTypes.Member:
- if event.membership in ('leave', 'ban'):
+ if event.membership in ("leave", "ban"):
return event.sender != event.state_key
return False
@defer.inlineCallbacks
-def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
- state_res_store, auth_diff):
+def _add_event_and_auth_chain_to_graph(
+ graph, event_id, event_map, state_res_store, auth_diff
+):
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
@@ -327,7 +325,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
graph = {}
for event_id in event_ids:
yield _add_event_and_auth_chain_to_graph(
- graph, event_id, event_map, state_res_store, auth_diff,
+ graph, event_id, event_map, state_res_store, auth_diff
)
event_to_pl = {}
@@ -342,18 +340,16 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
return -pl, ev.origin_server_ts, event_id
# Note: graph is modified during the sort
- it = lexicographical_topological_sort(
- graph,
- key=_get_power_order,
- )
+ it = lexicographical_topological_sort(graph, key=_get_power_order)
sorted_events = list(it)
defer.returnValue(sorted_events)
@defer.inlineCallbacks
-def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
- state_res_store):
+def _iterative_auth_checks(
+ room_version, event_ids, base_state, event_map, state_res_store
+):
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
@@ -389,9 +385,11 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
try:
event_auth.check(
- room_version, event, auth_events,
+ room_version,
+ event,
+ auth_events,
do_sig_check=False,
- do_size_check=False
+ do_size_check=False,
)
resolved_state[(event.type, event.state_key)] = event_id
@@ -402,8 +400,7 @@ def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
@defer.inlineCallbacks
-def _mainline_sort(event_ids, resolved_power_event_id, event_map,
- state_res_store):
+def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
@@ -436,8 +433,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map,
order_map = {}
for ev_id in event_ids:
depth = yield _get_mainline_depth_for_event(
- event_map[ev_id], mainline_map,
- event_map, state_res_store,
+ event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 71316f7d09..6b0ca80087 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -279,23 +279,35 @@ class DataStore(
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
+ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+ return self.runInteraction("count_daily_users", self._count_users, yesterday)
- def _count_users(txn):
- yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT user_id FROM user_ips
- WHERE last_seen > ?
- GROUP BY user_id
- ) u
- """
-
- txn.execute(sql, (yesterday,))
- count, = txn.fetchone()
- return count
+ def count_monthly_users(self):
+ """
+ Counts the number of users who used this homeserver in the last 30 days.
+ Note this method is intended for phonehome metrics only and is different
+ from the mau figure in synapse.storage.monthly_active_users which,
+ amongst other things, includes a 3 day grace period before a user counts.
+ """
+ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ return self.runInteraction(
+ "count_monthly_users", self._count_users, thirty_days_ago
+ )
- return self.runInteraction("count_users", _count_users)
+ def _count_users(self, txn, time_from):
+ """
+ Returns number of users seen in the past time_from period
+ """
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT user_id FROM user_ips
+ WHERE last_seen > ?
+ GROUP BY user_id
+ ) u
+ """
+ txn.execute(sql, (time_from,))
+ count, = txn.fetchone()
+ return count
def count_r30_users(self):
"""
@@ -347,7 +359,7 @@ class DataStore(
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
for row in txn:
- if row[0] == 'unknown':
+ if row[0] == "unknown":
pass
results[row[0]] = row[1]
@@ -374,7 +386,7 @@ class DataStore(
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
count, = txn.fetchone()
- results['all'] = count
+ results["all"] = count
return results
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index ae891aa332..29589853c6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -38,6 +38,14 @@ from synapse.util.caches.descriptors import Cache
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.stringutils import exception_to_unicode
+# import a function which will return a monotonic time, in seconds
+try:
+ # on python 3, use time.monotonic, since time.clock can go backwards
+ from time import monotonic as monotonic_time
+except ImportError:
+ # ... but python 2 doesn't have it
+ from time import clock as monotonic_time
+
logger = logging.getLogger(__name__)
try:
@@ -167,22 +175,22 @@ class PerformanceCounters(object):
self.current_counters = {}
self.previous_counters = {}
- def update(self, key, start_time, end_time=None):
- if end_time is None:
- end_time = time.time()
- duration = end_time - start_time
+ def update(self, key, duration_secs):
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
- cum_time += duration
+ cum_time += duration_secs
self.current_counters[key] = (count, cum_time)
- return end_time
- def interval(self, interval_duration, limit=3):
+ def interval(self, interval_duration_secs, limit=3):
counters = []
for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append(
- ((cum_time - prev_time) / interval_duration, count - prev_count, name)
+ (
+ (cum_time - prev_time) / interval_duration_secs,
+ count - prev_count,
+ name,
+ )
)
self.previous_counters = dict(self.current_counters)
@@ -213,7 +221,6 @@ class SQLBaseStore(object):
# is running in mainline, and we have some nice monitoring frontends
# to watch it
self._txn_perf_counters = PerformanceCounters()
- self._get_event_counters = PerformanceCounters()
self._get_event_cache = Cache(
"*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
@@ -299,12 +306,12 @@ class SQLBaseStore(object):
def select_users_with_no_expiration_date_txn(txn):
"""Retrieves the list of registered users with no expiration date from the
- database.
+ database, filtering out deactivated users.
"""
sql = (
"SELECT users.name FROM users"
" LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL;"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
)
txn.execute(sql, [])
@@ -312,9 +319,7 @@ class SQLBaseStore(object):
if res:
for user in res:
self.set_expiration_date_for_user_txn(
- txn,
- user["name"],
- use_delta=True,
+ txn, user["name"], use_delta=True
)
yield self.runInteraction(
@@ -352,32 +357,24 @@ class SQLBaseStore(object):
)
def start_profiling(self):
- self._previous_loop_ts = self._clock.time_msec()
+ self._previous_loop_ts = monotonic_time()
def loop():
curr = self._current_txn_total_time
prev = self._previous_txn_total_time
self._previous_txn_total_time = curr
- time_now = self._clock.time_msec()
+ time_now = monotonic_time()
time_then = self._previous_loop_ts
self._previous_loop_ts = time_now
- ratio = (curr - prev) / (time_now - time_then)
+ duration = time_now - time_then
+ ratio = (curr - prev) / duration
- top_three_counters = self._txn_perf_counters.interval(
- time_now - time_then, limit=3
- )
-
- top_3_event_counters = self._get_event_counters.interval(
- time_now - time_then, limit=3
- )
+ top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
perf_logger.info(
- "Total database time: %.3f%% {%s} {%s}",
- ratio * 100,
- top_three_counters,
- top_3_event_counters,
+ "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
)
self._clock.looping_call(loop, 10000)
@@ -385,7 +382,7 @@ class SQLBaseStore(object):
def _new_transaction(
self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
):
- start = time.time()
+ start = monotonic_time()
txn_id = self._TXN_ID
# We don't really need these to be unique, so lets stop it from
@@ -451,7 +448,7 @@ class SQLBaseStore(object):
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
- end = time.time()
+ end = monotonic_time()
duration = end - start
LoggingContext.current_context().add_database_transaction(duration)
@@ -459,7 +456,7 @@ class SQLBaseStore(object):
transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
self._current_txn_total_time += duration
- self._txn_perf_counters.update(desc, start, end)
+ self._txn_perf_counters.update(desc, duration)
sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
@@ -525,11 +522,11 @@ class SQLBaseStore(object):
)
parent_context = None
- start_time = time.time()
+ start_time = monotonic_time()
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection", parent_context) as context:
- sched_duration_sec = time.time() - start_time
+ sched_duration_sec = monotonic_time() - start_time
sql_scheduling_timer.observe(sched_duration_sec)
context.add_database_scheduled(sched_duration_sec)
@@ -1667,7 +1664,7 @@ def db_to_json(db_content):
# Decode it to a Unicode string before feeding it to json.loads, so we
# consistenty get a Unicode-containing object out.
if isinstance(db_content, (bytes, bytearray)):
- db_content = db_content.decode('utf8')
+ db_content = db_content.decode("utf8")
try:
return json.loads(db_content)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index b8b8273f73..50f913a414 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -169,7 +169,7 @@ class BackgroundUpdateStore(SQLBaseStore):
in_flight = set(update["update_name"] for update in updates)
for update in updates:
if update["depends_on"] not in in_flight:
- self._background_update_queue.append(update['update_name'])
+ self._background_update_queue.append(update["update_name"])
if not self._background_update_queue:
# no work left to do
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d102e07372..3413a46675 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -149,9 +149,7 @@ class DeviceWorkerStore(SQLBaseStore):
defer.returnValue((stream_id_cutoff, []))
results = yield self._get_device_update_edus_by_remote(
- destination,
- from_stream_id,
- query_map,
+ destination, from_stream_id, query_map
)
defer.returnValue((now_stream_id, results))
@@ -182,9 +180,7 @@ class DeviceWorkerStore(SQLBaseStore):
return list(txn)
@defer.inlineCallbacks
- def _get_device_update_edus_by_remote(
- self, destination, from_stream_id, query_map,
- ):
+ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
"""Returns a list of device update EDUs as well as E2EE keys
Args:
@@ -210,7 +206,7 @@ class DeviceWorkerStore(SQLBaseStore):
# The prev_id for the first row is always the last row before
# `from_stream_id`
prev_id = yield self._get_last_device_update_for_remote_user(
- destination, user_id, from_stream_id,
+ destination, user_id, from_stream_id
)
for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
@@ -238,7 +234,7 @@ class DeviceWorkerStore(SQLBaseStore):
defer.returnValue(results)
def _get_last_device_update_for_remote_user(
- self, destination, user_id, from_stream_id,
+ self, destination, user_id, from_stream_id
):
def f(txn):
prev_sent_id_sql = """
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py
index 521936e3b0..f40ef2ab64 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/e2e_room_keys.py
@@ -87,10 +87,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
},
values={
"version": version,
- "first_message_index": room_key['first_message_index'],
- "forwarded_count": room_key['forwarded_count'],
- "is_verified": room_key['is_verified'],
- "session_data": json.dumps(room_key['session_data']),
+ "first_message_index": room_key["first_message_index"],
+ "forwarded_count": room_key["forwarded_count"],
+ "is_verified": room_key["is_verified"],
+ "session_data": json.dumps(room_key["session_data"]),
},
lock=False,
)
@@ -118,13 +118,13 @@ class EndToEndRoomKeyStore(SQLBaseStore):
try:
version = int(version)
except ValueError:
- defer.returnValue({'rooms': {}})
+ defer.returnValue({"rooms": {}})
keyvalues = {"user_id": user_id, "version": version}
if room_id:
- keyvalues['room_id'] = room_id
+ keyvalues["room_id"] = room_id
if session_id:
- keyvalues['session_id'] = session_id
+ keyvalues["session_id"] = session_id
rows = yield self._simple_select_list(
table="e2e_room_keys",
@@ -141,10 +141,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="get_e2e_room_keys",
)
- sessions = {'rooms': {}}
+ sessions = {"rooms": {}}
for row in rows:
- room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}})
- room_entry['sessions'][row['session_id']] = {
+ room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
+ room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
"is_verified": row["is_verified"],
@@ -174,9 +174,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues = {"user_id": user_id, "version": int(version)}
if room_id:
- keyvalues['room_id'] = room_id
+ keyvalues["room_id"] = room_id
if session_id:
- keyvalues['session_id'] = session_id
+ keyvalues["session_id"] = session_id
yield self._simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
@@ -191,7 +191,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
row = txn.fetchone()
if not row:
- raise StoreError(404, 'No current backup version')
+ raise StoreError(404, "No current backup version")
return row[0]
def get_e2e_room_keys_version_info(self, user_id, version=None):
@@ -255,7 +255,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
current_version = txn.fetchone()[0]
if current_version is None:
- current_version = '0'
+ current_version = "0"
new_version = str(int(current_version) + 1)
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 1b97ee74e3..289b6bc281 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -45,6 +45,10 @@ class PostgresEngine(object):
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
+ # Are we on a supported PostgreSQL version?
+ if self._version < 90500:
+ raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
+
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
@@ -64,9 +68,9 @@ class PostgresEngine(object):
@property
def can_native_upsert(self):
"""
- Can we use native UPSERTs? This requires PostgreSQL 9.5+.
+ Can we use native UPSERTs?
"""
- return self._version >= 90500
+ return True
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 933bcf42c2..e9b9caa49a 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -85,7 +85,7 @@ class Sqlite3Engine(object):
def _parse_match_info(buf):
bufsize = len(buf)
- return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
+ return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 09e39c2c28..cb4478342f 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -190,6 +190,34 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
room_id,
)
+ def get_rooms_with_many_extremities(self, min_count, limit):
+ """Get the top rooms with at least N extremities.
+
+ Args:
+ min_count (int): The minimum number of extremities
+ limit (int): The maximum number of rooms to return.
+
+ Returns:
+ Deferred[list]: At most `limit` room IDs that have at least
+ `min_count` extremities, sorted by extremity count.
+ """
+
+ def _get_rooms_with_many_extremities_txn(txn):
+ sql = """
+ SELECT room_id FROM event_forward_extremities
+ GROUP BY room_id
+ HAVING count(*) > ?
+ ORDER BY count(*) DESC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (min_count, limit))
+ return [room_id for room_id, in txn]
+
+ return self.runInteraction(
+ "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
+ )
+
@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
return self._simple_select_onecol(
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index a729f3e067..eca77069fd 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -277,7 +277,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by stream_ordering, oldest first.
- notifs.sort(key=lambda r: r['stream_ordering'])
+ notifs.sort(key=lambda r: r["stream_ordering"])
# Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit.
@@ -379,7 +379,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by received_ts (most recent first)
- notifs.sort(key=lambda r: -(r['received_ts'] or 0))
+ notifs.sort(key=lambda r: -(r["received_ts"] or 0))
# Now return the first `limit`
defer.returnValue(notifs[:limit])
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index f9162be9b9..fefba39ea1 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -17,14 +17,14 @@
import itertools
import logging
-from collections import OrderedDict, deque, namedtuple
+from collections import Counter as c_counter, OrderedDict, deque, namedtuple
from functools import wraps
from six import iteritems, text_type
from six.moves import range
from canonicaljson import json
-from prometheus_client import Counter
+from prometheus_client import Counter, Histogram
from twisted.internet import defer
@@ -33,6 +33,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore
from synapse.storage.background_updates import BackgroundUpdateStore
@@ -73,6 +74,21 @@ state_delta_reuse_delta_counter = Counter(
"synapse_storage_events_state_delta_reuse_delta", ""
)
+# The number of forward extremities for each new event.
+forward_extremities_counter = Histogram(
+ "synapse_storage_events_forward_extremities_persisted",
+ "Number of forward extremities for each new event",
+ buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+# The number of stale forward extremities for each new event. Stale extremities
+# are those that were in the previous set of extremities as well as the new.
+stale_forward_extremities_counter = Histogram(
+ "synapse_storage_events_stale_forward_extremities_persisted",
+ "Number of unchanged forward extremities for each new event",
+ buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
def encode_json(json_object):
"""
@@ -220,13 +236,39 @@ class EventsStore(
EventsWorkerStore,
BackgroundUpdateStore,
):
-
def __init__(self, db_conn, hs):
super(EventsStore, self).__init__(db_conn, hs)
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
+ # Collect metrics on the number of forward extremities that exist.
+ # Counter of number of extremities to count
+ self._current_forward_extremities_amount = c_counter()
+
+ BucketCollector(
+ "synapse_forward_extremities",
+ lambda: self._current_forward_extremities_amount,
+ buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
+ )
+
+ # Read the extrems every 60 minutes
+ hs.get_clock().looping_call(self._read_forward_extremities, 60 * 60 * 1000)
+
+ @defer.inlineCallbacks
+ def _read_forward_extremities(self):
+ def fetch(txn):
+ txn.execute(
+ """
+ select count(*) c from event_forward_extremities
+ group by room_id
+ """
+ )
+ return txn.fetchall()
+
+ res = yield self.runInteraction("read_forward_extremities", fetch)
+ self._current_forward_extremities_amount = c_counter(list(x[0] for x in res))
+
@defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False):
"""
@@ -514,6 +556,8 @@ class EventsStore(
and not event.internal_metadata.is_soft_failed()
]
+ latest_event_ids = set(latest_event_ids)
+
# start with the existing forward extremities
result = set(latest_event_ids)
@@ -537,6 +581,13 @@ class EventsStore(
)
result.difference_update(existing_prevs)
+ # We only update metrics for events that change forward extremities
+ # (e.g. we ignore backfill/outliers/etc)
+ if result != latest_event_ids:
+ forward_extremities_counter.observe(len(result))
+ stale = latest_event_ids & result
+ stale_forward_extremities_counter.observe(len(stale))
+
defer.returnValue(result)
@defer.inlineCallbacks
@@ -568,17 +619,11 @@ class EventsStore(
)
txn.execute(sql, batch)
- results.extend(
- r[0]
- for r in txn
- if not json.loads(r[1]).get("soft_failed")
- )
+ results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
yield self.runInteraction(
- "_get_events_which_are_prevs",
- _get_events_which_are_prevs_txn,
- chunk,
+ "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
defer.returnValue(results)
@@ -640,9 +685,7 @@ class EventsStore(
for chunk in batch_iter(event_ids, 100):
yield self.runInteraction(
- "_get_prevs_before_rejected",
- _get_prevs_before_rejected_txn,
- chunk,
+ "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
defer.returnValue(existing_prevs)
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py
index 75c1935bf3..1ce21d190c 100644
--- a/synapse/storage/events_bg_updates.py
+++ b/synapse/storage/events_bg_updates.py
@@ -64,8 +64,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
)
self.register_background_update_handler(
- self.DELETE_SOFT_FAILED_EXTREMITIES,
- self._cleanup_extremities_bg_update,
+ self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
)
@defer.inlineCallbacks
@@ -269,7 +268,8 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
LEFT JOIN events USING (event_id)
LEFT JOIN event_json USING (event_id)
LEFT JOIN rejections USING (event_id)
- """, (batch_size,)
+ """,
+ (batch_size,),
)
for prev_event_id, event_id, metadata, rejected, outlier in txn:
@@ -364,13 +364,12 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
column="event_id",
iterable=to_delete,
keyvalues={},
- retcols=("room_id",)
+ retcols=("room_id",),
)
room_ids = set(row["room_id"] for row in rows)
for room_id in room_ids:
txn.call_after(
- self.get_latest_event_ids_in_room.invalidate,
- (room_id,)
+ self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
self._simple_delete_many_txn(
@@ -384,7 +383,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
return len(original_set)
num_handled = yield self.runInteraction(
- "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn,
+ "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
@@ -394,8 +393,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
txn.execute("DROP TABLE _extremities_to_check")
yield self.runInteraction(
- "_cleanup_extremities_bg_update_drop_table",
- _drop_table_txn,
+ "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
defer.returnValue(num_handled)
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index cc7df5cf14..6d680d405a 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -27,7 +27,6 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import NotFoundError
from synapse.api.room_versions import EventFormatVersions
from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
-# these are only included to make the type annotations work
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -111,8 +110,7 @@ class EventsWorkerStore(SQLBaseStore):
return ts
return self.runInteraction(
- "get_approximate_received_ts",
- _get_approximate_received_ts_txn,
+ "get_approximate_received_ts", _get_approximate_received_ts_txn
)
@defer.inlineCallbacks
@@ -677,7 +675,8 @@ class EventsWorkerStore(SQLBaseStore):
"""
return self.runInteraction(
"get_total_state_event_counts",
- self._get_total_state_event_counts_txn, room_id
+ self._get_total_state_event_counts_txn,
+ room_id,
)
def _get_current_state_event_counts_txn(self, txn, room_id):
@@ -701,7 +700,8 @@ class EventsWorkerStore(SQLBaseStore):
"""
return self.runInteraction(
"get_current_state_event_counts",
- self._get_current_state_event_counts_txn, room_id
+ self._get_current_state_event_counts_txn,
+ room_id,
)
@defer.inlineCallbacks
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index dce6a43ac1..73e6fc6de2 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -1179,11 +1179,7 @@ class GroupServerStore(SQLBaseStore):
for table in tables:
self._simple_delete_txn(
- txn,
- table=table,
- keyvalues={"group_id": group_id},
+ txn, table=table, keyvalues={"group_id": group_id}
)
- return self.runInteraction(
- "delete_group", _delete_group_txn
- )
+ return self.runInteraction("delete_group", _delete_group_txn)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index e3655ad8d7..e72f89e446 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -131,7 +131,7 @@ class KeyStore(SQLBaseStore):
def _invalidate(res):
f = self._get_server_verify_key.invalidate
for i in invalidations:
- f((i, ))
+ f((i,))
return res
return self.runInteraction(
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 3ecf47e7a7..6b1238ce4a 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -22,11 +22,11 @@ class MediaRepositoryStore(BackgroundUpdateStore):
super(MediaRepositoryStore, self).__init__(db_conn, hs)
self.register_background_index_update(
- update_name='local_media_repository_url_idx',
- index_name='local_media_repository_url_idx',
- table='local_media_repository',
- columns=['created_ts'],
- where_clause='url_cache IS NOT NULL',
+ update_name="local_media_repository_url_idx",
+ index_name="local_media_repository_url_idx",
+ table="local_media_repository",
+ columns=["created_ts"],
+ where_clause="url_cache IS NOT NULL",
)
def get_local_media(self, media_id):
@@ -108,12 +108,12 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return dict(
zip(
(
- 'response_code',
- 'etag',
- 'expires_ts',
- 'og',
- 'media_id',
- 'download_ts',
+ "response_code",
+ "etag",
+ "expires_ts",
+ "og",
+ "media_id",
+ "download_ts",
),
row,
)
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 8aa8abc470..081564360f 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -86,11 +86,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if len(self.reserved_users) > 0:
# questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s'
- questionmarks = '?' * len(self.reserved_users)
+ questionmarks = "?" * len(self.reserved_users)
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
- ','.join(questionmarks)
+ ",".join(questionmarks)
)
else:
sql = base_sql
@@ -124,7 +124,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if len(self.reserved_users) > 0:
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
- ','.join(questionmarks)
+ ",".join(questionmarks)
)
else:
sql = base_sql
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index f2c1bed487..7c4e1dc7ec 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -133,7 +133,7 @@ def _setup_new_database(cur, database_engine):
if ver <= SCHEMA_VERSION:
valid_dirs.append((ver, abs_path))
else:
- logger.warn("Unexpected entry in 'full_schemas': %s", filename)
+ logger.debug("Ignoring entry '%s' in 'full_schemas'", filename)
if not valid_dirs:
raise PrepareDatabaseException(
@@ -146,9 +146,10 @@ def _setup_new_database(cur, database_engine):
directory_entries = os.listdir(sql_dir)
- for filename in sorted(fnmatch.filter(directory_entries, "*.sql") + fnmatch.filter(
- directory_entries, "*.sql." + specific
- )):
+ for filename in sorted(
+ fnmatch.filter(directory_entries, "*.sql")
+ + fnmatch.filter(directory_entries, "*.sql." + specific)
+ ):
sql_loc = os.path.join(sql_dir, filename)
logger.debug("Applying schema %s", sql_loc)
executescript(cur, sql_loc)
@@ -313,7 +314,7 @@ def _apply_module_schemas(txn, database_engine, config):
application config
"""
for (mod, _config) in config.password_providers:
- if not hasattr(mod, 'get_db_schema_files'):
+ if not hasattr(mod, "get_db_schema_files"):
continue
modname = ".".join((mod.__module__, mod.__name__))
_apply_module_schema_files(
@@ -343,7 +344,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
continue
root_name, ext = os.path.splitext(name)
- if ext != '.sql':
+ if ext != ".sql":
raise PrepareDatabaseException(
"only .sql files are currently supported for module schemas"
)
@@ -407,7 +408,7 @@ def get_statements(f):
def executescript(txn, schema_path):
- with open(schema_path, 'r') as f:
+ with open(schema_path, "r") as f:
for statement in get_statements(f):
txn.execute(statement)
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index aeec2f57c4..0ff392bdb4 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -41,7 +41,7 @@ class ProfileWorkerStore(SQLBaseStore):
defer.returnValue(
ProfileInfo(
- avatar_url=profile['avatar_url'], display_name=profile['displayname']
+ avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
)
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 9e406baafa..98cec8c82b 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -46,12 +46,12 @@ def _load_rules(rawrules, enabled_map):
rules = list(list_with_base_rules(ruleslist))
for i, rule in enumerate(rules):
- rule_id = rule['rule_id']
+ rule_id = rule["rule_id"]
if rule_id in enabled_map:
- if rule.get('enabled', True) != bool(enabled_map[rule_id]):
+ if rule.get("enabled", True) != bool(enabled_map[rule_id]):
# Rules are cached across users.
rule = dict(rule)
- rule['enabled'] = bool(enabled_map[rule_id])
+ rule["enabled"] = bool(enabled_map[rule_id])
rules[i] = rule
return rules
@@ -126,12 +126,12 @@ class PushRulesWorkerStore(
def get_push_rules_enabled_for_user(self, user_id):
results = yield self._simple_select_list(
table="push_rules_enable",
- keyvalues={'user_name': user_id},
+ keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
desc="get_push_rules_enabled_for_user",
)
defer.returnValue(
- {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
+ {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
)
def have_push_rules_changed_for_user(self, user_id, last_id):
@@ -175,7 +175,7 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
for row in rows:
- results.setdefault(row['user_name'], []).append(row)
+ results.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
@@ -194,7 +194,7 @@ class PushRulesWorkerStore(
rule (Dict): A push rule.
"""
# Create new rule id
- rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1])
+ rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
# Change room id in each condition
@@ -334,8 +334,8 @@ class PushRulesWorkerStore(
desc="bulk_get_push_rules_enabled",
)
for row in rows:
- enabled = bool(row['enabled'])
- results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
+ enabled = bool(row["enabled"])
+ results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
defer.returnValue(results)
@@ -568,7 +568,7 @@ class PushRuleStore(PushRulesWorkerStore):
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
self._simple_delete_one_txn(
- txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id}
+ txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
self._insert_push_rules_update_txn(
@@ -605,9 +605,9 @@ class PushRuleStore(PushRulesWorkerStore):
self._simple_upsert_txn(
txn,
"push_rules_enable",
- {'user_name': user_id, 'rule_id': rule_id},
- {'enabled': 1 if enabled else 0},
- {'id': new_id},
+ {"user_name": user_id, "rule_id": rule_id},
+ {"enabled": 1 if enabled else 0},
+ {"id": new_id},
)
self._insert_push_rules_update_txn(
@@ -645,8 +645,8 @@ class PushRuleStore(PushRulesWorkerStore):
self._simple_update_one_txn(
txn,
"push_rules",
- {'user_name': user_id, 'rule_id': rule_id},
- {'actions': actions_json},
+ {"user_name": user_id, "rule_id": rule_id},
+ {"actions": actions_json},
)
self._insert_push_rules_update_txn(
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 1567e1df48..cfe0a94330 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -37,24 +37,24 @@ else:
class PusherWorkerStore(SQLBaseStore):
def _decode_pushers_rows(self, rows):
for r in rows:
- dataJson = r['data']
- r['data'] = None
+ dataJson = r["data"]
+ r["data"] = None
try:
if isinstance(dataJson, db_binary_type):
dataJson = str(dataJson).decode("UTF8")
- r['data'] = json.loads(dataJson)
+ r["data"] = json.loads(dataJson)
except Exception as e:
logger.warn(
"Invalid JSON in data for pusher %d: %s, %s",
- r['id'],
+ r["id"],
dataJson,
e.args[0],
)
pass
- if isinstance(r['pushkey'], db_binary_type):
- r['pushkey'] = str(r['pushkey']).decode("UTF8")
+ if isinstance(r["pushkey"], db_binary_type):
+ r["pushkey"] = str(r["pushkey"]).decode("UTF8")
return rows
@@ -195,15 +195,15 @@ class PusherWorkerStore(SQLBaseStore):
)
def get_if_users_have_pushers(self, user_ids):
rows = yield self._simple_select_many_batch(
- table='pushers',
- column='user_name',
+ table="pushers",
+ column="user_name",
iterable=user_ids,
- retcols=['user_name'],
- desc='get_if_users_have_pushers',
+ retcols=["user_name"],
+ desc="get_if_users_have_pushers",
)
result = {user_id: False for user_id in user_ids}
- result.update({r['user_name']: True for r in rows})
+ result.update({r["user_name"]: True for r in rows})
defer.returnValue(result)
@@ -299,8 +299,8 @@ class PusherStore(PusherWorkerStore):
):
yield self._simple_update_one(
"pushers",
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
- {'last_stream_ordering': last_stream_ordering},
+ {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ {"last_stream_ordering": last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
@@ -310,10 +310,10 @@ class PusherStore(PusherWorkerStore):
):
yield self._simple_update_one(
"pushers",
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
+ {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{
- 'last_stream_ordering': last_stream_ordering,
- 'last_success': last_success,
+ "last_stream_ordering": last_stream_ordering,
+ "last_success": last_success,
},
desc="update_pusher_last_stream_ordering_and_success",
)
@@ -322,8 +322,8 @@ class PusherStore(PusherWorkerStore):
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
yield self._simple_update_one(
"pushers",
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
- {'failing_since': failing_since},
+ {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ {"failing_since": failing_since},
desc="update_pusher_failing_since",
)
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index a1647e50a1..b477da12b1 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
- defer.returnValue(set(r['user_id'] for r in receipts))
+ defer.returnValue(set(r["user_id"] for r in receipts))
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 1dd1182e82..983ce13291 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import re
from six import iterkeys
@@ -31,6 +32,8 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
+logger = logging.getLogger(__name__)
+
class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, db_conn, hs):
@@ -113,8 +116,9 @@ class RegistrationWorkerStore(SQLBaseStore):
defer.returnValue(res)
@defer.inlineCallbacks
- def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
- renewal_token=None):
+ def set_account_validity_for_user(
+ self, user_id, expiration_ts, email_sent, renewal_token=None
+ ):
"""Updates the account validity properties of the given account, with the
given values.
@@ -128,6 +132,7 @@ class RegistrationWorkerStore(SQLBaseStore):
renewal_token (str): Renewal token the user can use to extend the validity
of their account. Defaults to no token.
"""
+
def set_account_validity_for_user_txn(txn):
self._simple_update_txn(
txn=txn,
@@ -140,12 +145,11 @@ class RegistrationWorkerStore(SQLBaseStore):
},
)
self._invalidate_cache_and_stream(
- txn, self.get_expiration_ts_for_user, (user_id,),
+ txn, self.get_expiration_ts_for_user, (user_id,)
)
yield self.runInteraction(
- "set_account_validity_for_user",
- set_account_validity_for_user_txn,
+ "set_account_validity_for_user", set_account_validity_for_user_txn
)
@defer.inlineCallbacks
@@ -214,6 +218,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
"""
+
def select_users_txn(txn, now_ms, renew_at):
sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity"
@@ -226,7 +231,8 @@ class RegistrationWorkerStore(SQLBaseStore):
res = yield self.runInteraction(
"get_users_expiring_soon",
select_users_txn,
- self.clock.time_msec(), self.config.account_validity.renew_at,
+ self.clock.time_msec(),
+ self.config.account_validity.renew_at,
)
defer.returnValue(res)
@@ -249,6 +255,20 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
+ def delete_account_validity_for_user(self, user_id):
+ """Deletes the entry for the given user in the account validity table, removing
+ their expiration date and renewal token.
+
+ Args:
+ user_id (str): ID of the user to remove from the account validity table.
+ """
+ yield self._simple_delete_one(
+ table="account_validity",
+ keyvalues={"user_id": user_id},
+ desc="delete_account_validity_for_user",
+ )
+
+ @defer.inlineCallbacks
def is_server_admin(self, user):
res = yield self._simple_select_one_onecol(
table="users",
@@ -352,7 +372,7 @@ class RegistrationWorkerStore(SQLBaseStore):
WHERE creation_ts > ?
) AS t GROUP BY user_type
"""
- results = {'native': 0, 'guest': 0, 'bridged': 0}
+ results = {"native": 0, "guest": 0, "bridged": 0}
txn.execute(sql, (yesterday,))
for row in txn:
results[row[0]] = row[1]
@@ -418,7 +438,7 @@ class RegistrationWorkerStore(SQLBaseStore):
{"medium": medium, "address": address},
["guest_access_token"],
True,
- 'get_3pid_guest_access_token',
+ "get_3pid_guest_access_token",
)
if ret:
defer.returnValue(ret["guest_access_token"])
@@ -455,11 +475,11 @@ class RegistrationWorkerStore(SQLBaseStore):
txn,
"user_threepids",
{"medium": medium, "address": address},
- ['user_id'],
+ ["user_id"],
True,
)
if ret:
- return ret['user_id']
+ return ret["user_id"]
return None
@defer.inlineCallbacks
@@ -475,8 +495,8 @@ class RegistrationWorkerStore(SQLBaseStore):
ret = yield self._simple_select_list(
"user_threepids",
{"user_id": user_id},
- ['medium', 'address', 'validated_at', 'added_at'],
- 'user_get_threepids',
+ ["medium", "address", "validated_at", "added_at"],
+ "user_get_threepids",
)
defer.returnValue(ret)
@@ -555,11 +575,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
return self._simple_select_onecol(
table="user_threepid_id_server",
- keyvalues={
- "user_id": user_id,
- "medium": medium,
- "address": address,
- },
+ keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
desc="get_id_servers_user_bound",
)
@@ -595,15 +611,80 @@ class RegistrationStore(
self.register_noop_background_update("refresh_tokens_device_index")
self.register_background_update_handler(
- "user_threepids_grandfather", self._bg_user_threepids_grandfather,
+ "user_threepids_grandfather", self._bg_user_threepids_grandfather
+ )
+
+ self.register_background_update_handler(
+ "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag
)
# Create a background job for culling expired 3PID validity tokens
hs.get_clock().looping_call(
- self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS,
+ self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
)
@defer.inlineCallbacks
+ def _backgroud_update_set_deactivated_flag(self, progress, batch_size):
+ """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
+ for each of them.
+ """
+
+ last_user = progress.get("user_id", "")
+
+ def _backgroud_update_set_deactivated_flag_txn(txn):
+ txn.execute(
+ """
+ SELECT
+ users.name,
+ COUNT(access_tokens.token) AS count_tokens,
+ COUNT(user_threepids.address) AS count_threepids
+ FROM users
+ LEFT JOIN access_tokens ON (access_tokens.user_id = users.name)
+ LEFT JOIN user_threepids ON (user_threepids.user_id = users.name)
+ WHERE (users.password_hash IS NULL OR users.password_hash = '')
+ AND (users.appservice_id IS NULL OR users.appservice_id = '')
+ AND users.is_guest = 0
+ AND users.name > ?
+ GROUP BY users.name
+ ORDER BY users.name ASC
+ LIMIT ?;
+ """,
+ (last_user, batch_size),
+ )
+
+ rows = self.cursor_to_dict(txn)
+
+ if not rows:
+ return True
+
+ rows_processed_nb = 0
+
+ for user in rows:
+ if not user["count_tokens"] and not user["count_threepids"]:
+ self.set_user_deactivated_status_txn(txn, user["name"], True)
+ rows_processed_nb += 1
+
+ logger.info("Marked %d rows as deactivated", rows_processed_nb)
+
+ self._background_update_progress_txn(
+ txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
+ )
+
+ if batch_size > len(rows):
+ return True
+ else:
+ return False
+
+ end = yield self.runInteraction(
+ "users_set_deactivated_flag", _backgroud_update_set_deactivated_flag_txn
+ )
+
+ if end:
+ yield self._end_background_update("users_set_deactivated_flag")
+
+ defer.returnValue(batch_size)
+
+ @defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id=None):
"""Adds an access token for the given user.
@@ -768,7 +849,7 @@ class RegistrationStore(
def user_set_password_hash_txn(txn):
self._simple_update_one_txn(
- txn, 'users', {'name': user_id}, {'password_hash': password_hash}
+ txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
@@ -789,9 +870,9 @@ class RegistrationStore(
def f(txn):
self._simple_update_one_txn(
txn,
- table='users',
- keyvalues={'name': user_id},
- updatevalues={'consent_version': consent_version},
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"consent_version": consent_version},
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
@@ -813,9 +894,9 @@ class RegistrationStore(
def f(txn):
self._simple_update_one_txn(
txn,
- table='users',
- keyvalues={'name': user_id},
- updatevalues={'consent_server_notice_sent': consent_version},
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"consent_server_notice_sent": consent_version},
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
@@ -985,7 +1066,7 @@ class RegistrationStore(
if id_servers:
yield self.runInteraction(
- "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
+ "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
@@ -993,12 +1074,7 @@ class RegistrationStore(
defer.returnValue(1)
def get_threepid_validation_session(
- self,
- medium,
- client_secret,
- address=None,
- sid=None,
- validated=True,
+ self, medium, client_secret, address=None, sid=None, validated=True
):
"""Gets a session_id and last_send_attempt (if available) for a
client_secret/medium/(address|session_id) combo
@@ -1018,23 +1094,22 @@ class RegistrationStore(
latest session_id and send_attempt count for this 3PID.
Otherwise None if there hasn't been a previous attempt
"""
- keyvalues = {
- "medium": medium,
- "client_secret": client_secret,
- }
+ keyvalues = {"medium": medium, "client_secret": client_secret}
if address:
keyvalues["address"] = address
if sid:
keyvalues["session_id"] = sid
- assert(address or sid)
+ assert address or sid
def get_threepid_validation_session_txn(txn):
sql = """
SELECT address, session_id, medium, client_secret,
last_send_attempt, validated_at
FROM threepid_validation_session WHERE %s
- """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),)
+ """ % (
+ " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
+ )
if validated is not None:
sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
@@ -1049,17 +1124,10 @@ class RegistrationStore(
return rows[0]
return self.runInteraction(
- "get_threepid_validation_session",
- get_threepid_validation_session_txn,
+ "get_threepid_validation_session", get_threepid_validation_session_txn
)
- def validate_threepid_session(
- self,
- session_id,
- client_secret,
- token,
- current_ts,
- ):
+ def validate_threepid_session(self, session_id, client_secret, token, current_ts):
"""Attempt to validate a threepid session using a token
Args:
@@ -1091,7 +1159,7 @@ class RegistrationStore(
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id",
+ 400, "This client_secret does not match the provided session_id"
)
row = self._simple_select_one_txn(
@@ -1104,7 +1172,7 @@ class RegistrationStore(
if not row:
raise ThreepidValidationError(
- 400, "Validation token not found or has expired",
+ 400, "Validation token not found or has expired"
)
expires = row["expires"]
next_link = row["next_link"]
@@ -1115,7 +1183,7 @@ class RegistrationStore(
if expires <= current_ts:
raise ThreepidValidationError(
- 400, "This token has expired. Please request a new one",
+ 400, "This token has expired. Please request a new one"
)
# Looks good. Validate the session
@@ -1130,8 +1198,7 @@ class RegistrationStore(
# Return next_link if it exists
return self.runInteraction(
- "validate_threepid_session_txn",
- validate_threepid_session_txn,
+ "validate_threepid_session_txn", validate_threepid_session_txn
)
def upsert_threepid_validation_session(
@@ -1198,6 +1265,7 @@ class RegistrationStore(
token_expires (int): The timestamp for which after the token
will no longer be valid
"""
+
def start_or_continue_validation_session_txn(txn):
# Create or update a validation session
self._simple_upsert_txn(
@@ -1231,6 +1299,7 @@ class RegistrationStore(
def cull_expired_threepid_validation_tokens(self):
"""Remove threepid validation tokens with expiry dates that have passed"""
+
def cull_expired_threepid_validation_tokens_txn(txn, ts):
sql = """
DELETE FROM threepid_validation_token WHERE
@@ -1252,6 +1321,7 @@ class RegistrationStore(
Args:
session_id (str): The ID of the session to delete
"""
+
def delete_threepid_session_txn(txn):
self._simple_delete_txn(
txn,
@@ -1265,6 +1335,53 @@ class RegistrationStore(
)
return self.runInteraction(
- "delete_threepid_session",
- delete_threepid_session_txn,
+ "delete_threepid_session", delete_threepid_session_txn
+ )
+
+ def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
+ self._simple_update_one_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"deactivated": 1 if deactivated else 0},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_deactivated_status, (user_id,)
)
+
+ @defer.inlineCallbacks
+ def set_user_deactivated_status(self, user_id, deactivated):
+ """Set the `deactivated` property for the provided user to the provided value.
+
+ Args:
+ user_id (str): The ID of the user to set the status for.
+ deactivated (bool): The value to set for `deactivated`.
+ """
+
+ yield self.runInteraction(
+ "set_user_deactivated_status",
+ self.set_user_deactivated_status_txn,
+ user_id,
+ deactivated,
+ )
+
+ @cachedInlineCallbacks()
+ def get_user_deactivated_status(self, user_id):
+ """Retrieve the value for the `deactivated` property for the provided user.
+
+ Args:
+ user_id (str): The ID of the user to retrieve the status for.
+
+ Returns:
+ defer.Deferred(bool): The requested value.
+ """
+
+ res = yield self._simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="deactivated",
+ desc="get_user_deactivated_status",
+ )
+
+ # Convert the integer into a boolean.
+ defer.returnValue(res == 1)
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 4c83800cca..1b01934c19 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -468,9 +468,5 @@ class RelationsStore(RelationsWorkerStore):
"""
self._simple_delete_txn(
- txn,
- table="event_relations",
- keyvalues={
- "event_id": redacted_event_id,
- }
+ txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 7617913326..8004aeb909 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -420,7 +420,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
table="room_memberships",
column="event_id",
iterable=missing_member_event_ids,
- retcols=('user_id', 'display_name', 'avatar_url'),
+ retcols=("user_id", "display_name", "avatar_url"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_joined_users_from_context",
@@ -448,7 +448,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cachedInlineCallbacks(max_entries=10000)
def is_host_joined(self, room_id, host):
- if '%' in host or '_' in host:
+ if "%" in host or "_" in host:
raise Exception("Invalid host name")
sql = """
@@ -490,7 +490,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Deferred: Resolves to True if the host is/was in the room, otherwise
False.
"""
- if '%' in host or '_' in host:
+ if "%" in host or "_" in host:
raise Exception("Invalid host name")
sql = """
@@ -723,7 +723,7 @@ class RoomMemberStore(RoomMemberWorkerStore):
room_id = row["room_id"]
try:
event_json = json.loads(row["json"])
- content = event_json['content']
+ content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/schema/delta/20/pushers.py
index 147496a38b..3edfcfd783 100644
--- a/synapse/storage/schema/delta/20/pushers.py
+++ b/synapse/storage/schema/delta/20/pushers.py
@@ -29,7 +29,8 @@ logger = logging.getLogger(__name__)
def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
- cur.execute("""
+ cur.execute(
+ """
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
@@ -48,27 +49,34 @@ def run_create(cur, database_engine, *args, **kwargs):
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
)
- """)
- cur.execute("""SELECT
+ """
+ )
+ cur.execute(
+ """SELECT
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
FROM pushers
- """)
+ """
+ )
count = 0
for row in cur.fetchall():
row = list(row)
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
- cur.execute(database_engine.convert_param_style("""
+ cur.execute(
+ database_engine.convert_param_style(
+ """
INSERT into pushers2 (
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
- ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
- row
+ ) values (%s)"""
+ % (",".join(["?" for _ in range(len(row))]))
+ ),
+ row,
)
count += 1
cur.execute("DROP TABLE pushers")
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
index ef7ec34346..9b95411fb6 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -40,9 +40,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
logger.warning("Could not get app_service_config_files from config")
pass
- appservices = load_appservices(
- config.server_name, config_files
- )
+ appservices = load_appservices(config.server_name, config_files)
owned = {}
@@ -53,20 +51,19 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
if user_id in owned.keys():
logger.error(
"user_id %s was owned by more than one application"
- " service (IDs %s and %s); assigning arbitrarily to %s" %
- (user_id, owned[user_id], appservice.id, owned[user_id])
+ " service (IDs %s and %s); assigning arbitrarily to %s"
+ % (user_id, owned[user_id], appservice.id, owned[user_id])
)
owned.setdefault(appservice.id, []).append(user_id)
for as_id, user_ids in owned.items():
n = 100
- user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n))
+ user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
database_engine.convert_param_style(
- "UPDATE users SET appservice_id = ? WHERE name IN (%s)" % (
- ",".join("?" for _ in chunk),
- )
+ "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+ % (",".join("?" for _ in chunk),)
),
- [as_id] + chunk
+ [as_id] + chunk,
)
diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/schema/delta/31/pushers.py
index 93367fa09e..9bb504aad5 100644
--- a/synapse/storage/schema/delta/31/pushers.py
+++ b/synapse/storage/schema/delta/31/pushers.py
@@ -24,12 +24,13 @@ logger = logging.getLogger(__name__)
def token_to_stream_ordering(token):
- return int(token[1:].split('_')[0])
+ return int(token[1:].split("_")[0])
def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table, delta 31...")
- cur.execute("""
+ cur.execute(
+ """
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
@@ -48,26 +49,33 @@ def run_create(cur, database_engine, *args, **kwargs):
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
)
- """)
- cur.execute("""SELECT
+ """
+ )
+ cur.execute(
+ """SELECT
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
FROM pushers
- """)
+ """
+ )
count = 0
for row in cur.fetchall():
row = list(row)
row[12] = token_to_stream_ordering(row[12])
- cur.execute(database_engine.convert_param_style("""
+ cur.execute(
+ database_engine.convert_param_style(
+ """
INSERT into pushers2 (
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_stream_ordering, last_success,
failing_since
- ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
- row
+ ) values (%s)"""
+ % (",".join(["?" for _ in range(len(row))]))
+ ),
+ row,
)
count += 1
cur.execute("DROP TABLE pushers")
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/schema/delta/33/remote_media_ts.py
index 9754d3ccfb..a26057dfb6 100644
--- a/synapse/storage/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/schema/delta/33/remote_media_ts.py
@@ -26,5 +26,5 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
database_engine.convert_param_style(
"UPDATE remote_media_cache SET last_access_ts = ?"
),
- (int(time.time() * 1000),)
+ (int(time.time() * 1000),),
)
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py
index f6766501d2..9fd1ccf6f7 100644
--- a/synapse/storage/schema/delta/47/state_group_seq.py
+++ b/synapse/storage/schema/delta/47/state_group_seq.py
@@ -27,10 +27,7 @@ def run_create(cur, database_engine, *args, **kwargs):
else:
start_val = row[0] + 1
- cur.execute(
- "CREATE SEQUENCE state_group_id_seq START WITH %s",
- (start_val, ),
- )
+ cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,))
def run_upgrade(*args, **kwargs):
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py
index 2233af87d7..49f5f2c003 100644
--- a/synapse/storage/schema/delta/48/group_unique_indexes.py
+++ b/synapse/storage/schema/delta/48/group_unique_indexes.py
@@ -38,16 +38,22 @@ def run_create(cur, database_engine, *args, **kwargs):
rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
# remove duplicates from group_users & group_invites tables
- cur.execute("""
+ cur.execute(
+ """
DELETE FROM group_users WHERE %s NOT IN (
SELECT min(%s) FROM group_users GROUP BY group_id, user_id
);
- """ % (rowid, rowid))
- cur.execute("""
+ """
+ % (rowid, rowid)
+ )
+ cur.execute(
+ """
DELETE FROM group_invites WHERE %s NOT IN (
SELECT min(%s) FROM group_invites GROUP BY group_id, user_id
);
- """ % (rowid, rowid))
+ """
+ % (rowid, rowid)
+ )
for statement in get_statements(FIX_INDEXES.splitlines()):
cur.execute(statement)
diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/schema/delta/50/make_event_content_nullable.py
index 6dd467b6c5..b1684a8441 100644
--- a/synapse/storage/schema/delta/50/make_event_content_nullable.py
+++ b/synapse/storage/schema/delta/50/make_event_content_nullable.py
@@ -65,14 +65,18 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
- cur.execute("""
+ cur.execute(
+ """
ALTER TABLE events ALTER COLUMN content DROP NOT NULL;
- """)
+ """
+ )
return
# sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html
- cur.execute("SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'")
+ cur.execute(
+ "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'"
+ )
(oldsql,) = cur.fetchone()
sql = oldsql.replace("content TEXT NOT NULL", "content TEXT")
@@ -86,7 +90,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute("PRAGMA writable_schema=ON")
cur.execute(
"UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'",
- (sql, ),
+ (sql,),
)
cur.execute("PRAGMA schema_version=%i" % (oldver + 1,))
cur.execute("PRAGMA writable_schema=OFF")
diff --git a/synapse/storage/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/schema/delta/55/users_alter_deactivated.sql
new file mode 100644
index 0000000000..dabdde489b
--- /dev/null
+++ b/synapse/storage/schema/delta/55/users_alter_deactivated.sql
@@ -0,0 +1,19 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE users ADD deactivated SMALLINT DEFAULT 0 NOT NULL;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('users_set_deactivated_flag', '{}');
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index ff49eaae02..f3b1cec933 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -31,8 +31,8 @@ from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
SearchEntry = namedtuple(
- 'SearchEntry',
- ['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'],
+ "SearchEntry",
+ ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
)
@@ -216,7 +216,7 @@ class SearchStore(BackgroundUpdateStore):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- have_added_index = progress['have_added_indexes']
+ have_added_index = progress["have_added_indexes"]
if not have_added_index:
@@ -341,29 +341,7 @@ class SearchStore(BackgroundUpdateStore):
for entry in entries
)
- # inserts to a GIN index are normally batched up into a pending
- # list, and then all committed together once the list gets to a
- # certain size. The trouble with that is that postgres (pre-9.5)
- # uses work_mem to determine the length of the list, and work_mem
- # is typically very large.
- #
- # We therefore reduce work_mem while we do the insert.
- #
- # (postgres 9.5 uses the separate gin_pending_list_limit setting,
- # so doesn't suffer the same problem, but changing work_mem will
- # be harmless)
- #
- # Note that we don't need to worry about restoring it on
- # exception, because exceptions will cause the transaction to be
- # rolled back, including the effects of the SET command.
- #
- # Also: we use SET rather than SET LOCAL because there's lots of
- # other stuff going on in this transaction, which want to have the
- # normal work_mem setting.
-
- txn.execute("SET work_mem='256kB'")
txn.executemany(sql, args)
- txn.execute("RESET work_mem")
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
index ff266b09b0..1cec84ee2e 100644
--- a/synapse/storage/stats.py
+++ b/synapse/storage/stats.py
@@ -71,7 +71,8 @@ class StatsStore(StateDeltasStore):
# Get all the rooms that we want to process.
def _make_staging_area(txn):
# Create the temporary tables
- stmts = get_statements("""
+ stmts = get_statements(
+ """
-- We just recreate the table, we'll be reinserting the
-- correct entries again later anyway.
DROP TABLE IF EXISTS {temp}_rooms;
@@ -85,7 +86,10 @@ class StatsStore(StateDeltasStore):
ON {temp}_rooms(events);
CREATE INDEX {temp}_rooms_id
ON {temp}_rooms(room_id);
- """.format(temp=TEMP_TABLE).splitlines())
+ """.format(
+ temp=TEMP_TABLE
+ ).splitlines()
+ )
for statement in stmts:
txn.execute(statement)
@@ -105,7 +109,9 @@ class StatsStore(StateDeltasStore):
LEFT JOIN room_stats_earliest_token AS t USING (room_id)
WHERE t.room_id IS NULL
GROUP BY c.room_id
- """ % (TEMP_TABLE,)
+ """ % (
+ TEMP_TABLE,
+ )
txn.execute(sql)
new_pos = yield self.get_max_stream_id_in_current_state_deltas()
@@ -184,7 +190,8 @@ class StatsStore(StateDeltasStore):
logger.info(
"Processing the next %d rooms of %d remaining",
- len(rooms_to_work_on), progress["remaining"],
+ len(rooms_to_work_on),
+ progress["remaining"],
)
# Number of state events we've processed by going through each room
@@ -204,10 +211,17 @@ class StatsStore(StateDeltasStore):
avatar_id = current_state_ids.get((EventTypes.RoomAvatar, ""))
canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, ""))
- state_events = yield self.get_events([
- join_rules_id, history_visibility_id, encryption_id, name_id,
- topic_id, avatar_id, canonical_alias_id,
- ])
+ state_events = yield self.get_events(
+ [
+ join_rules_id,
+ history_visibility_id,
+ encryption_id,
+ name_id,
+ topic_id,
+ avatar_id,
+ canonical_alias_id,
+ ]
+ )
def _get_or_none(event_id, arg):
event = state_events.get(event_id)
@@ -271,7 +285,7 @@ class StatsStore(StateDeltasStore):
# We've finished a room. Delete it from the table.
self._simple_delete_one_txn(
- txn, TEMP_TABLE + "_rooms", {"room_id": room_id},
+ txn, TEMP_TABLE + "_rooms", {"room_id": room_id}
)
yield self.runInteraction("update_room_stats", _fetch_data)
@@ -338,7 +352,7 @@ class StatsStore(StateDeltasStore):
"name",
"topic",
"avatar",
- "canonical_alias"
+ "canonical_alias",
):
field = fields.get(col)
if field and "\0" in field:
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 6f7f65d96b..d9482a3843 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -65,7 +65,7 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
- direction, column_names, from_token, to_token, engine,
+ direction, column_names, from_token, to_token, engine
):
"""Creates an SQL expression to bound the columns by the pagination
tokens.
@@ -153,7 +153,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
str
"""
- assert(bound in (">", "<", ">=", "<="))
+ assert bound in (">", "<", ">=", "<=")
name1, name2 = column_names
val1, val2 = values
@@ -169,11 +169,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
- return "((%d,%d) %s (%s,%s))" % (
- val1, val2,
- bound,
- name1, name2,
- )
+ return "((%d,%d) %s (%s,%s))" % (val1, val2, bound, name1, name2)
# We want to generate queries of e.g. the form:
#
@@ -276,7 +272,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(
- self, room_ids, from_key, to_key, limit=0, order='DESC'
+ self, room_ids, from_key, to_key, limit=0, order="DESC"
):
"""Get new room events in stream ordering since `from_key`.
@@ -346,7 +342,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_stream_for_room(
- self, room_id, from_key, to_key, limit=0, order='DESC'
+ self, room_id, from_key, to_key, limit=0, order="DESC"
):
"""Get new room events in stream ordering since `from_key`.
@@ -395,8 +391,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self.get_events_as_list([
- r.event_id for r in rows], get_prev_content=True,
+ ret = yield self.get_events_as_list(
+ [r.event_id for r in rows], get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=from_id is None)
@@ -446,7 +442,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self.get_events_as_list(
- [r.event_id for r in rows], get_prev_content=True,
+ [r.event_id for r in rows], get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=False)
@@ -725,7 +721,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
before_token,
- direction='b',
+ direction="b",
limit=before_limit,
event_filter=event_filter,
)
@@ -735,7 +731,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
after_token,
- direction='f',
+ direction="f",
limit=after_limit,
event_filter=event_filter,
)
@@ -816,7 +812,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id,
from_token,
to_token=None,
- direction='b',
+ direction="b",
limit=-1,
event_filter=None,
):
@@ -846,7 +842,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
args = [False, room_id]
- if direction == 'b':
+ if direction == "b":
order = "DESC"
else:
order = "ASC"
@@ -882,7 +878,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if rows:
topo = rows[-1].topological_ordering
toke = rows[-1].stream_ordering
- if direction == 'b':
+ if direction == "b":
# Tokens are positions between events.
# This token points *after* the last event in the chunk.
# We need it to point to the event before it in the chunk
@@ -898,7 +894,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def paginate_room_events(
- self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
+ self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
):
"""Returns list of events before or after a given token.
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 451e4fa441..f7f5906a99 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -30,34 +30,34 @@ class SourcePaginationConfig(object):
"""A configuration object which stores pagination parameters for a
specific event source."""
- def __init__(self, from_key=None, to_key=None, direction='f',
- limit=None):
+ def __init__(self, from_key=None, to_key=None, direction="f", limit=None):
self.from_key = from_key
self.to_key = to_key
- self.direction = 'f' if direction == 'f' else 'b'
+ self.direction = "f" if direction == "f" else "b"
self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
def __repr__(self):
- return (
- "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)"
- ) % (self.from_key, self.to_key, self.direction, self.limit)
+ return ("StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)") % (
+ self.from_key,
+ self.to_key,
+ self.direction,
+ self.limit,
+ )
class PaginationConfig(object):
"""A configuration object which stores pagination parameters."""
- def __init__(self, from_token=None, to_token=None, direction='f',
- limit=None):
+ def __init__(self, from_token=None, to_token=None, direction="f", limit=None):
self.from_token = from_token
self.to_token = to_token
- self.direction = 'f' if direction == 'f' else 'b'
+ self.direction = "f" if direction == "f" else "b"
self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
@classmethod
- def from_request(cls, request, raise_invalid_params=True,
- default_limit=None):
- direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
+ def from_request(cls, request, raise_invalid_params=True, default_limit=None):
+ direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
from_tok = parse_string(request, "from")
to_tok = parse_string(request, "to")
@@ -89,8 +89,7 @@ class PaginationConfig(object):
def __repr__(self):
return (
- "PaginationConfig(from_tok=%r, to_tok=%r,"
- " direction=%r, limit=%r)"
+ "PaginationConfig(from_tok=%r, to_tok=%r," " direction=%r, limit=%r)"
) % (self.from_token, self.to_token, self.direction, self.limit)
def get_source_config(self, source_name):
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index e5220132a3..488c49747a 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -34,8 +34,7 @@ class EventSources(object):
def __init__(self, hs):
self.sources = {
- name: cls(hs)
- for name, cls in EventSources.SOURCE_TYPES.items()
+ name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
}
self.store = hs.get_datastore()
@@ -47,21 +46,11 @@ class EventSources(object):
groups_key = self.store.get_group_stream_token()
token = StreamToken(
- room_key=(
- yield self.sources["room"].get_current_key()
- ),
- presence_key=(
- yield self.sources["presence"].get_current_key()
- ),
- typing_key=(
- yield self.sources["typing"].get_current_key()
- ),
- receipt_key=(
- yield self.sources["receipt"].get_current_key()
- ),
- account_data_key=(
- yield self.sources["account_data"].get_current_key()
- ),
+ room_key=(yield self.sources["room"].get_current_key()),
+ presence_key=(yield self.sources["presence"].get_current_key()),
+ typing_key=(yield self.sources["typing"].get_current_key()),
+ receipt_key=(yield self.sources["receipt"].get_current_key()),
+ account_data_key=(yield self.sources["account_data"].get_current_key()),
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
@@ -70,31 +59,25 @@ class EventSources(object):
defer.returnValue(token)
@defer.inlineCallbacks
- def get_current_token_for_room(self, room_id):
- push_rules_key, _ = self.store.get_push_rules_stream_token()
- to_device_key = self.store.get_to_device_stream_token()
- device_list_key = self.store.get_device_stream_token()
- groups_key = self.store.get_group_stream_token()
+ def get_current_token_for_pagination(self):
+ """Get the current token for a given room to be used to paginate
+ events.
+
+ The returned token does not have the current values for fields other
+ than `room`, since they are not used during pagination.
+ Retuns:
+ Deferred[StreamToken]
+ """
token = StreamToken(
- room_key=(
- yield self.sources["room"].get_current_key_for_room(room_id)
- ),
- presence_key=(
- yield self.sources["presence"].get_current_key()
- ),
- typing_key=(
- yield self.sources["typing"].get_current_key()
- ),
- receipt_key=(
- yield self.sources["receipt"].get_current_key()
- ),
- account_data_key=(
- yield self.sources["account_data"].get_current_key()
- ),
- push_rules_key=push_rules_key,
- to_device_key=to_device_key,
- device_list_key=device_list_key,
- groups_key=groups_key,
+ room_key=(yield self.sources["room"].get_current_key()),
+ presence_key=0,
+ typing_key=0,
+ receipt_key=0,
+ account_data_key=0,
+ push_rules_key=0,
+ to_device_key=0,
+ device_list_key=0,
+ groups_key=0,
)
defer.returnValue(token)
diff --git a/synapse/types.py b/synapse/types.py
index 3de94b6335..51eadb6ad4 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -21,9 +21,11 @@ import attr
from synapse.api.errors import SynapseError
-class Requester(namedtuple("Requester", [
- "user", "access_token_id", "is_guest", "device_id", "app_service",
-])):
+class Requester(
+ namedtuple(
+ "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
+ )
+):
"""
Represents the user making a request
@@ -76,8 +78,9 @@ class Requester(namedtuple("Requester", [
)
-def create_requester(user_id, access_token_id=None, is_guest=False,
- device_id=None, app_service=None):
+def create_requester(
+ user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
+):
"""
Create a new ``Requester`` object
@@ -101,7 +104,7 @@ def get_domain_from_id(string):
idx = string.find(":")
if idx == -1:
raise SynapseError(400, "Invalid ID: %r" % (string,))
- return string[idx + 1:]
+ return string[idx + 1 :]
def get_localpart_from_id(string):
@@ -111,9 +114,7 @@ def get_localpart_from_id(string):
return string[1:idx]
-class DomainSpecificString(
- namedtuple("DomainSpecificString", ("localpart", "domain"))
-):
+class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))):
"""Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil.
@@ -141,16 +142,16 @@ class DomainSpecificString(
def from_string(cls, s):
"""Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0:1] != cls.SIGIL:
- raise SynapseError(400, "Expected %s string to start with '%s'" % (
- cls.__name__, cls.SIGIL,
- ))
+ raise SynapseError(
+ 400, "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL)
+ )
- parts = s[1:].split(':', 1)
+ parts = s[1:].split(":", 1)
if len(parts) != 2:
raise SynapseError(
- 400, "Expected %s of the form '%slocalname:domain'" % (
- cls.__name__, cls.SIGIL,
- )
+ 400,
+ "Expected %s of the form '%slocalname:domain'"
+ % (cls.__name__, cls.SIGIL),
)
domain = parts[1]
@@ -176,47 +177,50 @@ class DomainSpecificString(
class UserID(DomainSpecificString):
"""Structure representing a user ID."""
+
SIGIL = "@"
class RoomAlias(DomainSpecificString):
"""Structure representing a room name."""
+
SIGIL = "#"
class RoomID(DomainSpecificString):
"""Structure representing a room id. """
+
SIGIL = "!"
class EventID(DomainSpecificString):
"""Structure representing an event id. """
+
SIGIL = "$"
class GroupID(DomainSpecificString):
"""Structure representing a group ID."""
+
SIGIL = "+"
@classmethod
def from_string(cls, s):
group_id = super(GroupID, cls).from_string(s)
if not group_id.localpart:
- raise SynapseError(
- 400,
- "Group ID cannot be empty",
- )
+ raise SynapseError(400, "Group ID cannot be empty")
if contains_invalid_mxid_characters(group_id.localpart):
raise SynapseError(
- 400,
- "Group ID can only contain characters a-z, 0-9, or '=_-./'",
+ 400, "Group ID can only contain characters a-z, 0-9, or '=_-./'"
)
return group_id
-mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits)
+mxid_localpart_allowed_characters = set(
+ "_-./=" + string.ascii_lowercase + string.digits
+)
def contains_invalid_mxid_characters(localpart):
@@ -245,9 +249,9 @@ UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# bytes rather than strings
#
NON_MXID_CHARACTER_PATTERN = re.compile(
- ("[^%s]" % (
- re.escape("".join(mxid_localpart_allowed_characters - {"="}),),
- )).encode("ascii"),
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode(
+ "ascii"
+ )
)
@@ -266,10 +270,11 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
unicode: string suitable for a mxid localpart
"""
if not isinstance(username, bytes):
- username = username.encode('utf-8')
+ username = username.encode("utf-8")
# first we sort out upper-case characters
if case_sensitive:
+
def f1(m):
return b"_" + m.group().lower()
@@ -289,25 +294,28 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)
# we also do the =-escaping to mxids starting with an underscore.
- username = re.sub(b'^_', b'=5f', username)
+ username = re.sub(b"^_", b"=5f", username)
# we should now only have ascii bytes left, so can decode back to a
# unicode.
- return username.decode('ascii')
+ return username.decode("ascii")
class StreamToken(
- namedtuple("Token", (
- "room_key",
- "presence_key",
- "typing_key",
- "receipt_key",
- "account_data_key",
- "push_rules_key",
- "to_device_key",
- "device_list_key",
- "groups_key",
- ))
+ namedtuple(
+ "Token",
+ (
+ "room_key",
+ "presence_key",
+ "typing_key",
+ "receipt_key",
+ "account_data_key",
+ "push_rules_key",
+ "to_device_key",
+ "device_list_key",
+ "groups_key",
+ ),
+ )
):
_SEPARATOR = "_"
@@ -368,9 +376,7 @@ class StreamToken(
return self._replace(**{key: new_value})
-StreamToken.START = StreamToken(
- *(["s0"] + ["0"] * (len(StreamToken._fields) - 1))
-)
+StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
@@ -395,15 +401,16 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
+
__slots__ = []
@classmethod
def parse(cls, string):
try:
- if string[0] == 's':
+ if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
- if string[0] == 't':
- parts = string[1:].split('-', 1)
+ if string[0] == "t":
+ parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except Exception:
pass
@@ -412,7 +419,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
@classmethod
def parse_stream_token(cls, string):
try:
- if string[0] == 's':
+ if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
except Exception:
pass
@@ -426,7 +433,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
class ThirdPartyInstanceID(
- namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
+ namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
):
# Deny iteration because it will bite you if you try to create a singleton
# set by:
@@ -450,18 +457,19 @@ class ThirdPartyInstanceID(
return cls(appservice_id=bits[0], network_id=bits[1])
def to_string(self):
- return "%s|%s" % (self.appservice_id, self.network_id,)
+ return "%s|%s" % (self.appservice_id, self.network_id)
__str__ = to_string
@classmethod
- def create(cls, appservice_id, network_id,):
+ def create(cls, appservice_id, network_id):
return cls(appservice_id=appservice_id, network_id=network_id)
@attr.s(slots=True)
class ReadReceipt(object):
"""Information about a read-receipt"""
+
room_id = attr.ib()
receipt_type = attr.ib()
user_id = attr.ib()
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 0ae7e2ef3b..dcc747cac1 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -40,6 +40,7 @@ class Clock(object):
Args:
reactor: The Twisted reactor to use.
"""
+
_reactor = attr.ib()
@defer.inlineCallbacks
@@ -70,9 +71,7 @@ class Clock(object):
call = task.LoopingCall(f)
call.clock = self._reactor
d = call.start(msec / 1000.0, now=False)
- d.addErrback(
- log_failure, "Looping call died", consumeErrors=False,
- )
+ d.addErrback(log_failure, "Looping call died", consumeErrors=False)
return call
def call_later(self, delay, callback, *args, **kwargs):
@@ -84,6 +83,7 @@ class Clock(object):
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
+
def wrapped_callback(*args, **kwargs):
with PreserveLoggingContext():
callback(*args, **kwargs)
@@ -129,12 +129,7 @@ def log_failure(failure, msg, consumeErrors=True):
"""
logger.error(
- msg,
- exc_info=(
- failure.type,
- failure.value,
- failure.getTracebackObject()
- )
+ msg, exc_info=(failure.type, failure.value, failure.getTracebackObject())
)
if not consumeErrors:
@@ -152,12 +147,12 @@ def glob_to_regex(glob):
Returns:
re.RegexObject
"""
- res = ''
+ res = ""
for c in glob:
- if c == '*':
- res = res + '.*'
- elif c == '?':
- res = res + '.'
+ if c == "*":
+ res = res + ".*"
+ elif c == "?":
+ res = res + "."
else:
res = res + re.escape(c)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 7253ba120f..7757b8708a 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -95,6 +95,7 @@ class ObservableDeferred(object):
def remove(r):
self._observers.discard(d)
return r
+
d.addBoth(remove)
self._observers.add(d)
@@ -123,7 +124,9 @@ class ObservableDeferred(object):
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
- id(self), self._result, self._deferred,
+ id(self),
+ self._result,
+ self._deferred,
)
@@ -150,10 +153,12 @@ def concurrently_execute(func, args, limit):
except StopIteration:
pass
- return logcontext.make_deferred_yieldable(defer.gatherResults([
- run_in_background(_concurrently_execute_inner)
- for _ in range(limit)
- ], consumeErrors=True)).addErrback(unwrapFirstError)
+ return logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(_concurrently_execute_inner) for _ in range(limit)],
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
def yieldable_gather_results(func, iter, *args, **kwargs):
@@ -169,10 +174,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
Deferred[list]: Resolved when all functions have been invoked, or errors if
one of the function calls fails.
"""
- return logcontext.make_deferred_yieldable(defer.gatherResults([
- run_in_background(func, item, *args, **kwargs)
- for item in iter
- ], consumeErrors=True)).addErrback(unwrapFirstError)
+ return logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(func, item, *args, **kwargs) for item in iter],
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
class Linearizer(object):
@@ -185,6 +192,7 @@ class Linearizer(object):
# do some work.
"""
+
def __init__(self, name=None, max_count=1, clock=None):
"""
Args:
@@ -197,6 +205,7 @@ class Linearizer(object):
if not clock:
from twisted.internet import reactor
+
clock = Clock(reactor)
self._clock = clock
self.max_count = max_count
@@ -221,7 +230,7 @@ class Linearizer(object):
res = self._await_lock(key)
else:
logger.debug(
- "Acquired uncontended linearizer lock %r for key %r", self.name, key,
+ "Acquired uncontended linearizer lock %r for key %r", self.name, key
)
entry[0] += 1
res = defer.succeed(None)
@@ -266,9 +275,7 @@ class Linearizer(object):
"""
entry = self.key_to_defer[key]
- logger.debug(
- "Waiting to acquire linearizer lock %r for key %r", self.name, key,
- )
+ logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
new_defer = make_deferred_yieldable(defer.Deferred())
entry[1][new_defer] = 1
@@ -293,14 +300,14 @@ class Linearizer(object):
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
logger.debug(
- "Cancelling wait for linearizer lock %r for key %r",
- self.name, key,
+ "Cancelling wait for linearizer lock %r for key %r", self.name, key
)
else:
logger.warn(
"Unexpected exception waiting for linearizer lock %r for key %r",
- self.name, key,
+ self.name,
+ key,
)
# we just have to take ourselves back out of the queue.
@@ -438,7 +445,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
try:
deferred.cancel()
- except: # noqa: E722, if we throw any exception it'll break time outs
+ except: # noqa: E722, if we throw any exception it'll break time outs
logger.exception("Canceller failed during timeout")
if not new_d.called:
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index f37d5bec08..8271229015 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -104,8 +104,8 @@ def register_cache(cache_type, cache_name, cache):
KNOWN_KEYS = {
- key: key for key in
- (
+ key: key
+ for key in (
"auth_events",
"content",
"depth",
@@ -150,7 +150,7 @@ def intern_dict(dictionary):
def _intern_known_values(key, value):
- intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",)
+ intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
if key in intern_keys:
return intern_string(value)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 187510576a..d2f25063aa 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -40,9 +40,7 @@ _CacheSentinel = object()
class CacheEntry(object):
- __slots__ = [
- "deferred", "callbacks", "invalidated"
- ]
+ __slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(self, deferred, callbacks):
self.deferred = deferred
@@ -73,7 +71,9 @@ class Cache(object):
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
- max_size=max_entries, keylen=keylen, cache_type=cache_type,
+ max_size=max_entries,
+ keylen=keylen,
+ cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
)
@@ -133,10 +133,7 @@ class Cache(object):
def set(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.check_thread()
- entry = CacheEntry(
- deferred=value,
- callbacks=callbacks,
- )
+ entry = CacheEntry(deferred=value, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
@@ -191,9 +188,7 @@ class Cache(object):
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
- raise TypeError(
- "The cache key must be a tuple not %r" % (type(key),)
- )
+ raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
@@ -244,29 +239,25 @@ class _CacheDescriptorBase(object):
raise Exception(
"Not enough explicit positional arguments to key off for %r: "
"got %i args, but wanted %i. (@cached cannot key off *args or "
- "**kwargs)"
- % (orig.__name__, len(all_args), num_args)
+ "**kwargs)" % (orig.__name__, len(all_args), num_args)
)
self.num_args = num_args
# list of the names of the args used as the cache key
- self.arg_names = all_args[1:num_args + 1]
+ self.arg_names = all_args[1 : num_args + 1]
# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
if arg_spec.defaults:
- self.arg_defaults = dict(zip(
- all_args[-len(arg_spec.defaults):],
- arg_spec.defaults
- ))
+ self.arg_defaults = dict(
+ zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults)
+ )
else:
self.arg_defaults = {}
if "cache_context" in self.arg_names:
- raise Exception(
- "cache_context arg cannot be included among the cache keys"
- )
+ raise Exception("cache_context arg cannot be included among the cache keys")
self.add_cache_context = cache_context
@@ -304,12 +295,24 @@ class CacheDescriptor(_CacheDescriptorBase):
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
"""
- def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
- inlineCallbacks=False, cache_context=False, iterable=False):
+
+ def __init__(
+ self,
+ orig,
+ max_entries=1000,
+ num_args=None,
+ tree=False,
+ inlineCallbacks=False,
+ cache_context=False,
+ iterable=False,
+ ):
super(CacheDescriptor, self).__init__(
- orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
- cache_context=cache_context)
+ orig,
+ num_args=num_args,
+ inlineCallbacks=inlineCallbacks,
+ cache_context=cache_context,
+ )
max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
@@ -356,7 +359,9 @@ class CacheDescriptor(_CacheDescriptorBase):
return args[0]
else:
return self.arg_defaults[nm]
+
else:
+
def get_cache_key(args, kwargs):
return tuple(get_cache_key_gen(args, kwargs))
@@ -383,8 +388,7 @@ class CacheDescriptor(_CacheDescriptorBase):
except KeyError:
ret = defer.maybeDeferred(
- logcontext.preserve_fn(self.function_to_call),
- obj, *args, **kwargs
+ logcontext.preserve_fn(self.function_to_call), obj, *args, **kwargs
)
def onErr(f):
@@ -437,8 +441,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
results.
"""
- def __init__(self, orig, cached_method_name, list_name, num_args=None,
- inlineCallbacks=False):
+ def __init__(
+ self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
+ ):
"""
Args:
orig (function)
@@ -451,7 +456,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
be wrapped by defer.inlineCallbacks
"""
super(CacheListDescriptor, self).__init__(
- orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
+ orig, num_args=num_args, inlineCallbacks=inlineCallbacks
+ )
self.list_name = list_name
@@ -463,7 +469,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
- % (self.list_name, cached_method_name,)
+ % (self.list_name, cached_method_name)
)
def __get__(self, obj, objtype=None):
@@ -494,8 +500,10 @@ class CacheListDescriptor(_CacheDescriptorBase):
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
+
def arg_to_cache_key(arg):
return arg
+
else:
keylist = list(keyargs)
@@ -505,8 +513,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args:
try:
- res = cache.get(arg_to_cache_key(arg),
- callback=invalidate_callback)
+ res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
@@ -554,18 +561,15 @@ class CacheListDescriptor(_CacheDescriptorBase):
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = list(missing)
- cached_defers.append(defer.maybeDeferred(
- logcontext.preserve_fn(self.function_to_call),
- **args_to_call
- ).addCallbacks(complete_all, errback))
+ cached_defers.append(
+ defer.maybeDeferred(
+ logcontext.preserve_fn(self.function_to_call), **args_to_call
+ ).addCallbacks(complete_all, errback)
+ )
if cached_defers:
- d = defer.gatherResults(
- cached_defers,
- consumeErrors=True,
- ).addCallbacks(
- lambda _: results,
- unwrapFirstError
+ d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
+ lambda _: results, unwrapFirstError
)
return logcontext.make_deferred_yieldable(d)
else:
@@ -586,8 +590,9 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key)
-def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
- iterable=False):
+def cached(
+ max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@@ -598,8 +603,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
- cache_context=False, iterable=False):
+def cachedInlineCallbacks(
+ max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 6c0b5a4094..6834e6f3ae 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
there.
value (dict): The full or partial dict value
"""
+
def __len__(self):
return len(self.value)
@@ -84,13 +85,15 @@ class DictionaryCache(object):
self.metrics.inc_hits()
if dict_keys is None:
- return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
+ return DictionaryEntry(
+ entry.full, entry.known_absent, dict(entry.value)
+ )
else:
- return DictionaryEntry(entry.full, entry.known_absent, {
- k: entry.value[k]
- for k in dict_keys
- if k in entry.value
- })
+ return DictionaryEntry(
+ entry.full,
+ entry.known_absent,
+ {k: entry.value[k] for k in dict_keys if k in entry.value},
+ )
self.metrics.inc_misses()
return DictionaryEntry(False, set(), {})
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index f369780277..cddf1ed515 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -28,8 +28,15 @@ SENTINEL = object()
class ExpiringCache(object):
- def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
- reset_expiry_on_get=False, iterable=False):
+ def __init__(
+ self,
+ cache_name,
+ clock,
+ max_len=0,
+ expiry_ms=0,
+ reset_expiry_on_get=False,
+ iterable=False,
+ ):
"""
Args:
cache_name (str): Name of this cache, used for logging.
@@ -67,8 +74,7 @@ class ExpiringCache(object):
def f():
return run_as_background_process(
- "prune_cache_%s" % self._cache_name,
- self._prune_cache,
+ "prune_cache_%s" % self._cache_name, self._prune_cache
)
self._clock.looping_call(f, self._expiry_ms / 2)
@@ -153,7 +159,9 @@ class ExpiringCache(object):
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
- self._cache_name, begin_length, len(self)
+ self._cache_name,
+ begin_length,
+ len(self),
)
def __len__(self):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index b684f24e7b..1536cb64f3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,8 +49,15 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
- def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
- evicted_callback=None):
+
+ def __init__(
+ self,
+ max_size,
+ keylen=1,
+ cache_type=dict,
+ size_callback=None,
+ evicted_callback=None,
+ ):
"""
Args:
max_size (int):
@@ -93,9 +100,12 @@ class LruCache(object):
cached_cache_len = [0]
if size_callback is not None:
+
def cache_len():
return cached_cache_len[0]
+
else:
+
def cache_len():
return len(cache)
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index afb03b2e1b..cbe54d45dd 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -35,12 +35,10 @@ class ResponseCache(object):
self.pending_result_cache = {} # Requests that haven't finished yet.
self.clock = hs.get_clock()
- self.timeout_sec = timeout_ms / 1000.
+ self.timeout_sec = timeout_ms / 1000.0
self._name = name
- self._metrics = register_cache(
- "response_cache", name, self
- )
+ self._metrics = register_cache("response_cache", name, self)
def size(self):
return len(self.pending_result_cache)
@@ -100,8 +98,7 @@ class ResponseCache(object):
def remove(r):
if self.timeout_sec:
self.clock.call_later(
- self.timeout_sec,
- self.pending_result_cache.pop, key, None,
+ self.timeout_sec, self.pending_result_cache.pop, key, None
)
else:
self.pending_result_cache.pop(key, None)
@@ -140,21 +137,22 @@ class ResponseCache(object):
*args: positional parameters to pass to the callback, if it is used
- **kwargs: named paramters to pass to the callback, if it is used
+ **kwargs: named parameters to pass to the callback, if it is used
Returns:
twisted.internet.defer.Deferred: yieldable result
"""
result = self.get(key)
if not result:
- logger.info("[%s]: no cached result for [%s], calculating new one",
- self._name, key)
+ logger.info(
+ "[%s]: no cached result for [%s], calculating new one", self._name, key
+ )
d = run_in_background(callback, *args, **kwargs)
result = self.set(key, d)
elif not isinstance(result, defer.Deferred) or result.called:
- logger.info("[%s]: using completed cached result for [%s]",
- self._name, key)
+ logger.info("[%s]: using completed cached result for [%s]", self._name, key)
else:
- logger.info("[%s]: using incomplete cached result for [%s]",
- self._name, key)
+ logger.info(
+ "[%s]: using incomplete cached result for [%s]", self._name, key
+ )
return make_deferred_yieldable(result)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 625aedc940..235f64049c 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -77,9 +77,8 @@ class StreamChangeCache(object):
if stream_pos >= self._earliest_known_stream_pos:
changed_entities = {
- self._cache[k] for k in self._cache.islice(
- start=self._cache.bisect_right(stream_pos),
- )
+ self._cache[k]
+ for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
}
result = changed_entities.intersection(entities)
@@ -114,8 +113,10 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
- return [self._cache[k] for k in self._cache.islice(
- start=self._cache.bisect_right(stream_pos))]
+ return [
+ self._cache[k]
+ for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
+ ]
else:
return None
@@ -136,7 +137,7 @@ class StreamChangeCache(object):
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(
- k, self._earliest_known_stream_pos,
+ k, self._earliest_known_stream_pos
)
self._entity_to_key.pop(r, None)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index dd4c9e6067..9a72218d85 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -9,6 +9,7 @@ class TreeCache(object):
efficiently.
Keys must be tuples.
"""
+
def __init__(self):
self.size = 0
self.root = {}
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 5ba1862506..2af8ca43b1 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -155,6 +155,7 @@ class TTLCache(object):
@attr.s(frozen=True, slots=True)
class _CacheEntry(object):
"""TTLCache entry"""
+
# expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time = attr.ib()
key = attr.ib()
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index e14c8bdfda..5a79db821c 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -51,9 +51,7 @@ class Distributor(object):
if name in self.signals:
raise KeyError("%r already has a signal named %s" % (self, name))
- self.signals[name] = Signal(
- name,
- )
+ self.signals[name] = Signal(name)
if name in self.pre_registration:
signal = self.signals[name]
@@ -78,11 +76,7 @@ class Distributor(object):
if name not in self.signals:
raise KeyError("%r does not have a signal named %s" % (self, name))
- run_as_background_process(
- name,
- self.signals[name].fire,
- *args, **kwargs
- )
+ run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
class Signal(object):
@@ -118,22 +112,23 @@ class Signal(object):
def eb(failure):
logger.warning(
"%s signal observer %s failed: %r",
- self.name, observer, failure,
+ self.name,
+ observer,
+ failure,
exc_info=(
failure.type,
failure.value,
- failure.getTracebackObject()))
+ failure.getTracebackObject(),
+ ),
+ )
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
- deferreds = [
- run_in_background(do, o)
- for o in self.observers
- ]
+ deferreds = [run_in_background(do, o) for o in self.observers]
- return make_deferred_yieldable(defer.gatherResults(
- deferreds, consumeErrors=True,
- ))
+ return make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
def __repr__(self):
return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 014edea971..635b897d6c 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -60,11 +60,10 @@ def _handle_frozendict(obj):
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
return obj._dict
- raise TypeError('Object of type %s is not JSON serializable' %
- obj.__class__.__name__)
+ raise TypeError(
+ "Object of type %s is not JSON serializable" % obj.__class__.__name__
+ )
# A JSONEncoder which is capable of encoding frozendics without barfing
-frozendict_json_encoder = json.JSONEncoder(
- default=_handle_frozendict,
-)
+frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 2d7ddc1cbe..1a20c596bf 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource):
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
- for path_seg in full_path.split(b'/')[1:-1]:
+ for path_seg in full_path.split(b"/")[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = NoResource()
@@ -60,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
# ===========================
# now attach the actual desired resource
- last_path_seg = full_path.split(b'/')[-1]
+ last_path_seg = full_path.split(b"/")[-1]
# if there is already a resource here, thieve its children and
# replace it
@@ -70,9 +70,7 @@ def create_resource_tree(desired_tree, root_resource):
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
- child_res_id = _resource_id(
- existing_dummy_resource, child_name
- )
+ child_res_id = _resource_id(existing_dummy_resource, child_name)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index d668e5a6b8..6dce03dd3a 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -70,7 +70,8 @@ class JsonEncodedObject(object):
dict
"""
d = {
- k: _encode(v) for (k, v) in self.__dict__.items()
+ k: _encode(v)
+ for (k, v) in self.__dict__.items()
if k in self.valid_keys and k not in self.internal_keys
}
d.update(self.unrecognized_keys)
@@ -78,7 +79,8 @@ class JsonEncodedObject(object):
def get_internal_dict(self):
d = {
- k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
+ k: _encode(v, internal=True)
+ for (k, v) in self.__dict__.items()
if k in self.valid_keys
}
d.update(self.unrecognized_keys)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index fe412355d8..6b0d2deea0 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -42,6 +42,8 @@ try:
def get_thread_resource_usage():
return resource.getrusage(RUSAGE_THREAD)
+
+
except Exception:
# If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
# won't track resource usage by returning None.
@@ -64,8 +66,11 @@ class ContextResourceUsage(object):
"""
__slots__ = [
- "ru_stime", "ru_utime",
- "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
+ "ru_stime",
+ "ru_utime",
+ "db_txn_count",
+ "db_txn_duration_sec",
+ "db_sched_duration_sec",
"evt_db_fetch_count",
]
@@ -91,8 +96,8 @@ class ContextResourceUsage(object):
return ContextResourceUsage(copy_from=self)
def reset(self):
- self.ru_stime = 0.
- self.ru_utime = 0.
+ self.ru_stime = 0.0
+ self.ru_utime = 0.0
self.db_txn_count = 0
self.db_txn_duration_sec = 0
@@ -100,15 +105,18 @@ class ContextResourceUsage(object):
self.evt_db_fetch_count = 0
def __repr__(self):
- return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
- "db_txn_count='%r', db_txn_duration_sec='%r', "
- "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % (
- self.ru_stime,
- self.ru_utime,
- self.db_txn_count,
- self.db_txn_duration_sec,
- self.db_sched_duration_sec,
- self.evt_db_fetch_count,)
+ return (
+ "<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
+ "db_txn_count='%r', db_txn_duration_sec='%r', "
+ "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>"
+ ) % (
+ self.ru_stime,
+ self.ru_utime,
+ self.db_txn_count,
+ self.db_txn_duration_sec,
+ self.db_sched_duration_sec,
+ self.evt_db_fetch_count,
+ )
def __iadd__(self, other):
"""Add another ContextResourceUsage's stats to this one's.
@@ -159,11 +167,15 @@ class LoggingContext(object):
"""
__slots__ = [
- "previous_context", "name", "parent_context",
+ "previous_context",
+ "name",
+ "parent_context",
"_resource_usage",
"usage_start",
- "main_thread", "alive",
- "request", "tag",
+ "main_thread",
+ "alive",
+ "request",
+ "tag",
]
thread_local = threading.local()
@@ -196,6 +208,7 @@ class LoggingContext(object):
def __nonzero__(self):
return False
+
__bool__ = __nonzero__ # python3
sentinel = Sentinel()
@@ -261,7 +274,8 @@ class LoggingContext(object):
if self.previous_context != old_context:
logger.warn(
"Expected previous context %r, found %r",
- self.previous_context, old_context
+ self.previous_context,
+ old_context,
)
self.alive = True
@@ -285,9 +299,8 @@ class LoggingContext(object):
self.alive = False
# if we have a parent, pass our CPU usage stats on
- if (
- self.parent_context is not None
- and hasattr(self.parent_context, '_resource_usage')
+ if self.parent_context is not None and hasattr(
+ self.parent_context, "_resource_usage"
):
self.parent_context._resource_usage += self._resource_usage
@@ -320,15 +333,12 @@ class LoggingContext(object):
# When we stop, let's record the cpu used since we started
if not self.usage_start:
- logger.warning(
- "Called stop on logcontext %s without calling start", self,
- )
+ logger.warning("Called stop on logcontext %s without calling start", self)
return
- usage_end = get_thread_resource_usage()
-
- self._resource_usage.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime
- self._resource_usage.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime
+ utime_delta, stime_delta = self._get_cputime()
+ self._resource_usage.ru_utime += utime_delta
+ self._resource_usage.ru_stime += stime_delta
self.usage_start = None
@@ -346,13 +356,44 @@ class LoggingContext(object):
# can include resource usage so far.
is_main_thread = threading.current_thread() is self.main_thread
if self.alive and self.usage_start and is_main_thread:
- current = get_thread_resource_usage()
- res.ru_utime += current.ru_utime - self.usage_start.ru_utime
- res.ru_stime += current.ru_stime - self.usage_start.ru_stime
+ utime_delta, stime_delta = self._get_cputime()
+ res.ru_utime += utime_delta
+ res.ru_stime += stime_delta
return res
+ def _get_cputime(self):
+ """Get the cpu usage time so far
+
+ Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
+ """
+ current = get_thread_resource_usage()
+
+ utime_delta = current.ru_utime - self.usage_start.ru_utime
+ stime_delta = current.ru_stime - self.usage_start.ru_stime
+
+ # sanity check
+ if utime_delta < 0:
+ logger.error(
+ "utime went backwards! %f < %f",
+ current.ru_utime,
+ self.usage_start.ru_utime,
+ )
+ utime_delta = 0
+
+ if stime_delta < 0:
+ logger.error(
+ "stime went backwards! %f < %f",
+ current.ru_stime,
+ self.usage_start.ru_stime,
+ )
+ stime_delta = 0
+
+ return utime_delta, stime_delta
+
def add_database_transaction(self, duration_sec):
+ if duration_sec < 0:
+ raise ValueError("DB txn time can only be non-negative")
self._resource_usage.db_txn_count += 1
self._resource_usage.db_txn_duration_sec += duration_sec
@@ -363,6 +404,8 @@ class LoggingContext(object):
sched_sec (float): number of seconds it took us to get a
connection
"""
+ if sched_sec < 0:
+ raise ValueError("DB scheduling time can only be non-negative")
self._resource_usage.db_sched_duration_sec += sched_sec
def record_event_fetch(self, event_count):
@@ -381,6 +424,7 @@ class LoggingContextFilter(logging.Filter):
**defaults: Default values to avoid formatters complaining about
missing fields
"""
+
def __init__(self, **defaults):
self.defaults = defaults
@@ -416,17 +460,12 @@ class PreserveLoggingContext(object):
def __enter__(self):
"""Captures the current logging context"""
- self.current_context = LoggingContext.set_current_context(
- self.new_context
- )
+ self.current_context = LoggingContext.set_current_context(self.new_context)
if self.current_context:
self.has_parent = self.current_context.previous_context is not None
if not self.current_context.alive:
- logger.debug(
- "Entering dead context: %s",
- self.current_context,
- )
+ logger.debug("Entering dead context: %s", self.current_context)
def __exit__(self, type, value, traceback):
"""Restores the current logging context"""
@@ -444,10 +483,7 @@ class PreserveLoggingContext(object):
if self.current_context is not LoggingContext.sentinel:
if not self.current_context.alive:
- logger.debug(
- "Restoring dead context: %s",
- self.current_context,
- )
+ logger.debug("Restoring dead context: %s", self.current_context)
def nested_logging_context(suffix, parent_context=None):
@@ -474,15 +510,16 @@ def nested_logging_context(suffix, parent_context=None):
if parent_context is None:
parent_context = LoggingContext.current_context()
return LoggingContext(
- parent_context=parent_context,
- request=parent_context.request + "-" + suffix,
+ parent_context=parent_context, request=parent_context.request + "-" + suffix
)
def preserve_fn(f):
"""Function decorator which wraps the function with run_in_background"""
+
def g(*args, **kwargs):
return run_in_background(f, *args, **kwargs)
+
return g
@@ -502,7 +539,7 @@ def run_in_background(f, *args, **kwargs):
current = LoggingContext.current_context()
try:
res = f(*args, **kwargs)
- except: # noqa: E722
+ except: # noqa: E722
# the assumption here is that the caller doesn't want to be disturbed
# by synchronous exceptions, so let's turn them into Failures.
return defer.fail()
@@ -639,6 +676,4 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
with LoggingContext(parent_context=logcontext):
return f(*args, **kwargs)
- return make_deferred_yieldable(
- threads.deferToThreadPool(reactor, threadpool, g)
- )
+ return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
index a46bc47ce3..fbf570c756 100644
--- a/synapse/util/logformatter.py
+++ b/synapse/util/logformatter.py
@@ -29,6 +29,7 @@ class LogFormatter(logging.Formatter):
(Normally only stack frames between the point the exception was raised and
where it was caught are logged).
"""
+
def __init__(self, *args, **kwargs):
super(LogFormatter, self).__init__(*args, **kwargs)
@@ -40,7 +41,7 @@ class LogFormatter(logging.Formatter):
# check that we actually have an f_back attribute to work around
# https://twistedmatrix.com/trac/ticket/9305
- if tb and hasattr(tb.tb_frame, 'f_back'):
+ if tb and hasattr(tb.tb_frame, "f_back"):
sio.write("Capture point (most recent call last):\n")
traceback.print_stack(tb.tb_frame.f_back, None, sio)
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index ef31458226..7df0fa6087 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -44,7 +44,7 @@ def _log_debug_as_f(f, msg, msg_args):
lineno=lineno,
msg=msg,
args=msg_args,
- exc_info=None
+ exc_info=None,
)
logger.handle(record)
@@ -70,20 +70,11 @@ def log_function(f):
r = r[:50] + "..."
return r
- func_args = [
- "%s=%s" % (k, format(v)) for k, v in bound_args.items()
- ]
+ func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()]
- msg_args = {
- "func_name": func_name,
- "args": ", ".join(func_args)
- }
+ msg_args = {"func_name": func_name, "args": ", ".join(func_args)}
- _log_debug_as_f(
- f,
- "Invoked '%(func_name)s' with args: %(args)s",
- msg_args
- )
+ _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args)
return f(*args, **kwargs)
@@ -103,19 +94,13 @@ def time_function(f):
start = time.clock()
try:
- _log_debug_as_f(
- f,
- "[FUNC START] {%s-%d}",
- (func_name, id),
- )
+ _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id))
r = f(*args, **kwargs)
finally:
end = time.clock()
_log_debug_as_f(
- f,
- "[FUNC END] {%s-%d} %.3f sec",
- (func_name, id, end - start,),
+ f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start)
)
return r
@@ -137,9 +122,8 @@ def trace_function(f):
s = inspect.currentframe().f_back
to_print = [
- "\t%s:%s %s. Args: args=%s, kwargs=%s" % (
- pathname, linenum, func_name, args, kwargs
- )
+ "\t%s:%s %s. Args: args=%s, kwargs=%s"
+ % (pathname, linenum, func_name, args, kwargs)
]
while s:
if True or s.f_globals["__name__"].startswith("synapse"):
@@ -147,9 +131,7 @@ def trace_function(f):
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
to_print.append(
- "\t%s:%d %s. Args: %s" % (
- filename, lineno, function, args_string
- )
+ "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string)
)
s = s.f_back
@@ -163,7 +145,7 @@ def trace_function(f):
lineno=lineno,
msg=msg,
args=None,
- exc_info=None
+ exc_info=None,
)
logger.handle(record)
@@ -182,13 +164,13 @@ def get_previous_frames():
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
- to_return.append("{{ %s:%d %s - Args: %s }}" % (
- filename, lineno, function, args_string
- ))
+ to_return.append(
+ "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string)
+ )
s = s.f_back
- return ", ". join(to_return)
+ return ", ".join(to_return)
def get_previous_frame(ignore=[]):
@@ -201,7 +183,10 @@ def get_previous_frame(ignore=[]):
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
return "{{ %s:%d %s - Args: %s }}" % (
- filename, lineno, function, args_string
+ filename,
+ lineno,
+ function,
+ args_string,
)
s = s.f_back
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 628a2962d9..631654f297 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -74,27 +74,25 @@ def manhole(username, password, globals):
twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
"""
if not isinstance(password, bytes):
- password = password.encode('ascii')
+ password = password.encode("ascii")
- checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
- **{username: password}
- )
+ checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password})
rlm = manhole_ssh.TerminalRealm()
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
- SynapseManhole,
- dict(globals, __name__="__console__")
+ SynapseManhole, dict(globals, __name__="__console__")
)
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
- factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY)
- factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY)
+ factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY)
+ factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY)
return factory
class SynapseManhole(ColoredManhole):
"""Overrides connectionMade to create our own ManholeInterpreter"""
+
def connectionMade(self):
super(SynapseManhole, self).connectionMade()
@@ -127,7 +125,7 @@ class SynapseManholeInterpreter(ManholeInterpreter):
value = SyntaxError(msg, (filename, lineno, offset, line))
sys.last_value = value
lines = traceback.format_exception_only(type, value)
- self.write(''.join(lines))
+ self.write("".join(lines))
def showtraceback(self):
"""Display the exception that just occurred.
@@ -140,6 +138,6 @@ class SynapseManholeInterpreter(ManholeInterpreter):
try:
# We remove the first stack item because it is our own code.
lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
- self.write(''.join(lines))
+ self.write("".join(lines))
finally:
last_tb = ei = None
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4b4ac5f6c7..01284d3cf8 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -30,25 +30,31 @@ block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"])
block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"])
block_ru_utime = Counter(
- "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]
+)
block_ru_stime = Counter(
- "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]
+)
block_db_txn_count = Counter(
- "synapse_util_metrics_block_db_txn_count", "", ["block_name"])
+ "synapse_util_metrics_block_db_txn_count", "", ["block_name"]
+)
# seconds spent waiting for db txns, excluding scheduling time, in this block
block_db_txn_duration = Counter(
- "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]
+)
# seconds spent waiting for a db connection, in this block
block_db_sched_duration = Counter(
- "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
+)
# Tracks the number of blocks currently active
in_flight = InFlightGauge(
- "synapse_util_metrics_block_in_flight", "",
+ "synapse_util_metrics_block_in_flight",
+ "",
labels=["block_name"],
sub_metrics=["real_time_max", "real_time_sum"],
)
@@ -62,13 +68,18 @@ def measure_func(name):
with Measure(self.clock, name):
r = yield func(self, *args, **kwargs)
defer.returnValue(r)
+
return measured_func
+
return wrapper
class Measure(object):
__slots__ = [
- "clock", "name", "start_context", "start",
+ "clock",
+ "name",
+ "start_context",
+ "start",
"created_context",
"start_usage",
]
@@ -108,7 +119,9 @@ class Measure(object):
if context != self.start_context:
logger.warn(
"Context has unexpectedly changed from '%s' to '%s'. (%r)",
- self.start_context, context, self.name
+ self.start_context,
+ context,
+ self.name,
)
return
@@ -126,8 +139,7 @@ class Measure(object):
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
except ValueError:
logger.warn(
- "Failed to save metrics! OLD: %r, NEW: %r",
- self.start_usage, current
+ "Failed to save metrics! OLD: %r, NEW: %r", self.start_usage, current
)
if self.created_context:
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 4288312b8a..522acd5aa8 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -28,15 +28,13 @@ def load_module(provider):
"""
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
- module, clz = provider['module'].rsplit(".", 1)
+ module, clz = provider["module"].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
try:
provider_config = provider_class.parse_config(provider["config"])
except Exception as e:
- raise ConfigError(
- "Failed to parse config for %r: %r" % (provider['module'], e)
- )
+ raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
return provider_class, provider_config
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
index a6c30e5265..c8bcbe297a 100644
--- a/synapse/util/msisdn.py
+++ b/synapse/util/msisdn.py
@@ -36,6 +36,6 @@ def phone_number_to_msisdn(country, number):
phoneNumber = phonenumbers.parse(number, country)
except phonenumbers.NumberParseException:
raise SynapseError(400, "Unable to parse phone number")
- return phonenumbers.format_number(
- phoneNumber, phonenumbers.PhoneNumberFormat.E164
- )[1:]
+ return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[
+ 1:
+ ]
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index b146d137f4..06defa8199 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -56,11 +56,7 @@ class FederationRateLimiter(object):
_PerHostRatelimiter
"""
return self.ratelimiters.setdefault(
- host,
- _PerHostRatelimiter(
- clock=self.clock,
- config=self._config,
- )
+ host, _PerHostRatelimiter(clock=self.clock, config=self._config)
).ratelimit()
@@ -112,8 +108,7 @@ class _PerHostRatelimiter(object):
# remove any entries from request_times which aren't within the window
self.request_times[:] = [
- r for r in self.request_times
- if time_now - r < self.window_size
+ r for r in self.request_times if time_now - r < self.window_size
]
# reject the request if we already have too many queued up (either
@@ -121,9 +116,7 @@ class _PerHostRatelimiter(object):
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
- retry_after_ms=int(
- self.window_size / self.sleep_limit
- ),
+ retry_after_ms=int(self.window_size / self.sleep_limit)
)
self.request_times.append(time_now)
@@ -143,22 +136,18 @@ class _PerHostRatelimiter(object):
logger.debug(
"Ratelimit [%s]: len(self.request_times)=%d",
- id(request_id), len(self.request_times),
+ id(request_id),
+ len(self.request_times),
)
if len(self.request_times) > self.sleep_limit:
- logger.debug(
- "Ratelimiter: sleeping request for %f sec", self.sleep_sec,
- )
+ logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
- logger.debug(
- "Ratelimit [%s]: Finished sleeping",
- id(request_id),
- )
+ logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
@@ -168,10 +157,7 @@ class _PerHostRatelimiter(object):
ret_defer = queue_request()
def on_start(r):
- logger.debug(
- "Ratelimit [%s]: Processing req",
- id(request_id),
- )
+ logger.debug("Ratelimit [%s]: Processing req", id(request_id))
self.current_processing.add(request_id)
return r
@@ -193,10 +179,7 @@ class _PerHostRatelimiter(object):
return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id):
- logger.debug(
- "Ratelimit [%s]: Processed req",
- id(request_id),
- )
+ logger.debug("Ratelimit [%s]: Processed req", id(request_id))
self.current_processing.discard(request_id)
try:
# start processing the next item on the queue.
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 69dffd8244..982c6d81ca 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -20,9 +20,7 @@ import six
from six import PY2, PY3
from six.moves import range
-_string_with_symbols = (
- string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
-)
+_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure
@@ -31,13 +29,11 @@ rand = random.SystemRandom()
def random_string(length):
- return ''.join(rand.choice(string.ascii_letters) for _ in range(length))
+ return "".join(rand.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
- return ''.join(
- rand.choice(_string_with_symbols) for _ in range(length)
- )
+ return "".join(rand.choice(_string_with_symbols) for _ in range(length))
def is_ascii(s):
@@ -45,7 +41,7 @@ def is_ascii(s):
if PY3:
if isinstance(s, bytes):
try:
- s.decode('ascii').encode('ascii')
+ s.decode("ascii").encode("ascii")
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
@@ -104,12 +100,12 @@ def exception_to_unicode(e):
# and instead look at what is in the args member.
if len(e.args) == 0:
- return u""
+ return ""
elif len(e.args) > 1:
return six.text_type(repr(e.args))
msg = e.args[0]
if isinstance(msg, bytes):
- return msg.decode('utf-8', errors='replace')
+ return msg.decode("utf-8", errors="replace")
else:
return msg
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 75efa0117b..3ec1dfb0c2 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -35,11 +35,13 @@ def check_3pid_allowed(hs, medium, address):
for constraint in hs.config.allowed_local_3pids:
logger.debug(
"Checking 3PID %s (%s) against %s (%s)",
- address, medium, constraint['pattern'], constraint['medium'],
+ address,
+ medium,
+ constraint["pattern"],
+ constraint["medium"],
)
- if (
- medium == constraint['medium'] and
- re.match(constraint['pattern'], address)
+ if medium == constraint["medium"] and re.match(
+ constraint["pattern"], address
):
return True
else:
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 3baba3225a..a4d9a462f7 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -23,44 +23,53 @@ logger = logging.getLogger(__name__)
def get_version_string(module):
try:
- null = open(os.devnull, 'w')
+ null = open(os.devnull, "w")
cwd = os.path.dirname(os.path.abspath(module.__file__))
try:
- git_branch = subprocess.check_output(
- ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii')
+ git_branch = (
+ subprocess.check_output(
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ )
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
try:
- git_tag = subprocess.check_output(
- ['git', 'describe', '--exact-match'],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii')
+ git_tag = (
+ subprocess.check_output(
+ ["git", "describe", "--exact-match"], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ )
git_tag = "t=" + git_tag
except subprocess.CalledProcessError:
git_tag = ""
try:
- git_commit = subprocess.check_output(
- ['git', 'rev-parse', '--short', 'HEAD'],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii')
+ git_commit = (
+ subprocess.check_output(
+ ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ )
except subprocess.CalledProcessError:
git_commit = ""
try:
dirty_string = "-this_is_a_dirty_checkout"
- is_dirty = subprocess.check_output(
- ['git', 'describe', '--dirty=' + dirty_string],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii').endswith(dirty_string)
+ is_dirty = (
+ subprocess.check_output(
+ ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ .endswith(dirty_string)
+ )
git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError:
@@ -68,16 +77,10 @@ def get_version_string(module):
if git_branch or git_tag or git_commit or git_dirty:
git_version = ",".join(
- s for s in
- (git_branch, git_tag, git_commit, git_dirty,)
- if s
+ s for s in (git_branch, git_tag, git_commit, git_dirty) if s
)
- return (
- "%s (%s)" % (
- module.__version__, git_version,
- )
- )
+ return "%s (%s)" % (module.__version__, git_version)
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 7a9e45aca9..9bf6a44f75 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -69,9 +69,7 @@ class WheelTimer(object):
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
- self.entries.extend(
- _Entry(key) for key in range(last_key, then_key + 1)
- )
+ self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1))
self.entries[-1].queue.append(obj)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 16c40cd74c..2a11c83596 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -29,12 +29,7 @@ from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__)
-VISIBILITY_PRIORITY = (
- "world_readable",
- "shared",
- "invited",
- "joined",
-)
+VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined")
MEMBERSHIP_PRIORITY = (
@@ -47,8 +42,9 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks
-def filter_events_for_client(store, user_id, events, is_peeking=False,
- always_include_ids=frozenset()):
+def filter_events_for_client(
+ store, user_id, events, is_peeking=False, always_include_ids=frozenset()
+):
"""
Check which events a user is allowed to see
@@ -71,23 +67,21 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
# to clients.
events = list(e for e in events if not e.internal_metadata.is_soft_failed())
- types = (
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, user_id),
- )
+ types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
event_id_to_state = yield store.get_state_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types),
)
ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", user_id,
+ "m.ignored_user_list", user_id
)
# FIXME: This will explode if people upload something incorrect.
ignore_list = frozenset(
ignore_dict_content.get("ignored_users", {}).keys()
- if ignore_dict_content else []
+ if ignore_dict_content
+ else []
)
erased_senders = yield store.are_users_erased((e.sender for e in events))
@@ -185,9 +179,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
elif visibility == "invited":
# user can also see the event if they were *invited* at the time
# of the event.
- return (
- event if membership == Membership.INVITE else None
- )
+ return event if membership == Membership.INVITE else None
elif visibility == "shared" and is_peeking:
# if the visibility is shared, users cannot see the event unless
@@ -220,8 +212,9 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
@defer.inlineCallbacks
-def filter_events_for_server(store, server_name, events, redact=True,
- check_history_visibility_only=False):
+def filter_events_for_server(
+ store, server_name, events, redact=True, check_history_visibility_only=False
+):
"""Filter a list of events based on whether given server is allowed to
see them.
@@ -242,15 +235,12 @@ def filter_events_for_server(store, server_name, events, redact=True,
def is_sender_erased(event, erased_senders):
if erased_senders and erased_senders[event.sender]:
- logger.info(
- "Sender of %s has been erased, redacting",
- event.event_id,
- )
+ logger.info("Sender of %s has been erased, redacting", event.event_id)
return True
return False
def check_event_is_visible(event, state):
- history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+ history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if history:
visibility = history.content.get("history_visibility", "shared")
if visibility in ["invited", "joined"]:
@@ -287,8 +277,8 @@ def filter_events_for_server(store, server_name, events, redact=True,
event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
- types=((EventTypes.RoomHistoryVisibility, ""),),
- )
+ types=((EventTypes.RoomHistoryVisibility, ""),)
+ ),
)
visibility_ids = set()
@@ -309,9 +299,7 @@ def filter_events_for_server(store, server_name, events, redact=True,
)
if not check_history_visibility_only:
- erased_senders = yield store.are_users_erased(
- (e.sender for e in events),
- )
+ erased_senders = yield store.are_users_erased((e.sender for e in events))
else:
# We don't want to check whether users are erased, which is equivalent
# to no users having been erased.
@@ -343,11 +331,8 @@ def filter_events_for_server(store, server_name, events, redact=True,
event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, None),
- ),
- )
+ types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
+ ),
)
# We only want to pull out member events that correspond to the
@@ -371,13 +356,15 @@ def filter_events_for_server(store, server_name, events, redact=True,
idx = state_key.find(":")
if idx == -1:
return False
- return state_key[idx + 1:] == server_name
-
- event_map = yield store.get_events([
- e_id
- for e_id, key in iteritems(event_id_to_state_key)
- if include(key[0], key[1])
- ])
+ return state_key[idx + 1 :] == server_name
+
+ event_map = yield store.get_events(
+ [
+ e_id
+ for e_id, key in iteritems(event_id_to_state_key)
+ if include(key[0], key[1])
+ ]
+ )
event_to_state = {
e_id: {
|