diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 7ce6540bdd..943b5a2c86 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -128,7 +128,7 @@ class Auth(object):
)
self._check_joined_room(member, user_id, room_id)
- defer.returnValue(member)
+ return member
@defer.inlineCallbacks
def check_user_was_in_room(self, room_id, user_id):
@@ -156,13 +156,13 @@ class Auth(object):
if forgot:
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
- defer.returnValue(member)
+ return member
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.is_host_joined(room_id, host)
- defer.returnValue(latest_event_ids)
+ return latest_event_ids
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
@@ -207,6 +207,7 @@ class Auth(object):
access_token = self.get_access_token_from_request(request)
user_id, app_service = yield self._get_appservice_user_id(request)
+
if user_id:
request.authenticated_entity = user_id
@@ -219,9 +220,7 @@ class Auth(object):
device_id="dummy-device", # stubbed
)
- defer.returnValue(
- synapse.types.create_requester(user_id, app_service=app_service)
- )
+ return synapse.types.create_requester(user_id, app_service=app_service)
user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"]
@@ -262,39 +261,41 @@ class Auth(object):
request.authenticated_entity = user.to_string()
- defer.returnValue(
- synapse.types.create_requester(
- user, token_id, is_guest, device_id, app_service=app_service
- )
+ return synapse.types.create_requester(
+ user, token_id, is_guest, device_id, app_service=app_service
)
except KeyError:
raise MissingClientTokenError()
- @defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
+
if app_service is None:
- defer.returnValue((None, None))
+ return (None, None)
if app_service.ip_range_whitelist:
ip_address = IPAddress(self.hs.get_ip_from_request(request))
if ip_address not in app_service.ip_range_whitelist:
- defer.returnValue((None, None))
+ return (None, None)
if b"user_id" not in request.args:
- defer.returnValue((app_service.sender, app_service))
+ return (app_service.sender, app_service)
user_id = request.args[b"user_id"][0].decode("utf8")
if app_service.sender == user_id:
- defer.returnValue((app_service.sender, app_service))
+ return (app_service.sender, app_service)
if not app_service.is_interested_in_user(user_id):
raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (yield self.store.get_user_by_id(user_id)):
- raise AuthError(403, "Application service has not registered this user")
- defer.returnValue((user_id, app_service))
+ # Let ASes manipulate nonexistent users (e.g. to shadow-register them)
+ # if not (yield self.store.get_user_by_id(user_id)):
+ # raise AuthError(
+ # 403,
+ # "Application service has not registered this user"
+ # )
+ return (user_id, app_service)
@defer.inlineCallbacks
def get_user_by_access_token(self, token, rights="access"):
@@ -330,7 +331,7 @@ class Auth(object):
msg="Access token has expired", soft_logout=True
)
- defer.returnValue(r)
+ return r
# otherwise it needs to be a valid macaroon
try:
@@ -378,7 +379,7 @@ class Auth(object):
}
else:
raise RuntimeError("Unknown rights setting %s", rights)
- defer.returnValue(ret)
+ return ret
except (
_InvalidMacaroonException,
pymacaroons.exceptions.MacaroonException,
@@ -506,7 +507,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
- defer.returnValue(None)
+ return None
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
@@ -518,7 +519,7 @@ class Auth(object):
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
- defer.returnValue(user_info)
+ return user_info
def get_appservice_by_req(self, request):
token = self.get_access_token_from_request(request)
@@ -543,7 +544,7 @@ class Auth(object):
@defer.inlineCallbacks
def compute_auth_events(self, event, current_state_ids, for_verification=False):
if event.type == EventTypes.Create:
- defer.returnValue([])
+ return []
auth_ids = []
@@ -604,7 +605,7 @@ class Auth(object):
if member_event.content["membership"] == Membership.JOIN:
auth_ids.append(member_event.event_id)
- defer.returnValue(auth_ids)
+ return auth_ids
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user):
@@ -618,7 +619,7 @@ class Auth(object):
is_admin = yield self.is_server_admin(user)
if is_admin:
- defer.returnValue(True)
+ return True
user_id = user.to_string()
yield self.check_joined_room(room_id, user_id)
@@ -712,7 +713,7 @@ class Auth(object):
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.check_user_was_in_room(room_id, user_id)
- defer.returnValue((member_event.membership, member_event.event_id))
+ return (member_event.membership, member_event.event_id)
except AuthError:
visibility = yield self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
@@ -721,7 +722,7 @@ class Auth(object):
visibility
and visibility.content["history_visibility"] == "world_readable"
):
- defer.returnValue((Membership.JOIN, None))
+ return (Membership.JOIN, None)
return
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 3ffde0d7fc..c7cae9768f 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -85,6 +85,7 @@ class EventTypes(object):
RoomAvatar = "m.room.avatar"
RoomEncryption = "m.room.encryption"
GuestAccess = "m.room.guest_access"
+ Encryption = "m.room.encryption"
# These are used for validation
Message = "m.room.message"
@@ -94,6 +95,8 @@ class EventTypes(object):
ServerACL = "m.room.server_acl"
Pinned = "m.room.pinned_events"
+ Retention = "m.room.retention"
+
class RejectedReason(object):
AUTH_ERROR = "auth_error"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index ad3e262041..c293135b51 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -61,6 +62,13 @@ class Codes(object):
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
+ PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
+ PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
+ PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
+ PASSWORD_NO_LOWERCASE = "M_PASSWORD_NO_LOWERCASE"
+ PASSWORD_NO_SYMBOL = "M_PASSWORD_NO_SYMBOL"
+ PASSWORD_IN_DICTIONARY = "M_PASSWORD_IN_DICTIONARY"
+ WEAK_PASSWORD = "M_WEAK_PASSWORD"
class CodeMessageException(RuntimeError):
@@ -418,6 +426,18 @@ class IncompatibleRoomVersionError(SynapseError):
return cs_error(self.msg, self.errcode, room_version=self._room_version)
+class PasswordRefusedError(SynapseError):
+ """A password has been refused, either during password reset/change or registration.
+ """
+
+ def __init__(
+ self,
+ msg="This password doesn't comply with the server's policy",
+ errcode=Codes.WEAK_PASSWORD,
+ ):
+ super(PasswordRefusedError, self).__init__(code=400, msg=msg, errcode=errcode)
+
+
class RequestSendFailed(RuntimeError):
"""Sending a HTTP request over federation failed due to not being able to
talk to the remote server for some reason.
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 9b3daca29b..9f06556bd2 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -132,7 +132,7 @@ class Filtering(object):
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
result = yield self.store.get_user_filter(user_localpart, filter_id)
- defer.returnValue(FilterCollection(result))
+ return FilterCollection(result)
def add_user_filter(self, user_localpart, user_filter):
self.check_valid_filter(user_filter)
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 540dbd9236..c010e70955 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -15,10 +15,12 @@
import gc
import logging
+import os
import signal
import sys
import traceback
+import sdnotify
from daemonize import Daemonize
from twisted.internet import defer, error, reactor
@@ -242,9 +244,16 @@ def start(hs, listeners=None):
if hasattr(signal, "SIGHUP"):
def handle_sighup(*args, **kwargs):
+ # Tell systemd our state, if we're using it. This will silently fail if
+ # we're not using systemd.
+ sd_channel = sdnotify.SystemdNotifier()
+ sd_channel.notify("RELOADING=1")
+
for i in _sighup_callbacks:
i(hs)
+ sd_channel.notify("READY=1")
+
signal.signal(signal.SIGHUP, handle_sighup)
register_sighup(refresh_certificate)
@@ -260,6 +269,7 @@ def start(hs, listeners=None):
hs.get_datastore().start_profiling()
setup_sentry(hs)
+ setup_sdnotify(hs)
except Exception:
traceback.print_exc(file=sys.stderr)
reactor = hs.get_reactor()
@@ -292,6 +302,25 @@ def setup_sentry(hs):
scope.set_tag("worker_name", name)
+def setup_sdnotify(hs):
+ """Adds process state hooks to tell systemd what we are up to.
+ """
+
+ # Tell systemd our state, if we're using it. This will silently fail if
+ # we're not using systemd.
+ sd_channel = sdnotify.SystemdNotifier()
+
+ hs.get_reactor().addSystemEventTrigger(
+ "after",
+ "startup",
+ lambda: sd_channel.notify("READY=1\nMAINPID=%s" % (os.getpid())),
+ )
+
+ hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", lambda: sd_channel.notify("STOPPING=1")
+ )
+
+
def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
"""Replaces the resolver with one that limits the number of in flight DNS
requests.
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index e01f3e5f3b..54bb114dec 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -168,7 +168,9 @@ def start(config_options):
)
ps.setup()
- reactor.callWhenRunning(_base.start, ps, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ps, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-appservice", config)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 29bddc4823..721bb5b119 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -194,7 +194,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-client-reader", config)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index 042cfd04af..473c8895d0 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -193,7 +193,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-event-creator", config)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 76a97f8f32..5255d9e8cc 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -175,7 +175,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-federation-reader", config)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index fec49d5092..c5a2880e69 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -198,7 +198,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-federation-sender", config)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 1f1f1df78e..e2822ca848 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -70,12 +70,12 @@ class PresenceStatusStubServlet(RestServlet):
except HttpResponseException as e:
raise e.to_synapse_error()
- defer.returnValue((200, result))
+ return (200, result)
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
yield self.auth.get_user_by_req(request)
- defer.returnValue((200, {}))
+ return (200, {})
class KeyUploadServlet(RestServlet):
@@ -126,11 +126,11 @@ class KeyUploadServlet(RestServlet):
self.main_uri + request.uri.decode("ascii"), body, headers=headers
)
- defer.returnValue((200, result))
+ return (200, result)
else:
# Just interested in counts.
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
- defer.returnValue((200, {"one_time_key_counts": result}))
+ return (200, {"one_time_key_counts": result})
class FrontendProxySlavedStore(
@@ -247,7 +247,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-frontend-proxy", config)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 0c075cb3f1..fe4fa20bd9 100755..100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -406,7 +406,7 @@ def setup(config_options):
if provision:
yield acme.provision_certificate()
- defer.returnValue(provision)
+ return provision
@defer.inlineCallbacks
def reprovision_acme():
@@ -447,7 +447,7 @@ def setup(config_options):
reactor.stop()
sys.exit(1)
- reactor.callWhenRunning(start)
+ reactor.addSystemEventTrigger("before", "startup", start)
return hs
@@ -563,7 +563,7 @@ def run(hs):
stats["database_server_version"] = hs.get_datastore().get_server_version()
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
- yield hs.get_simple_http_client().put_json(
+ yield hs.get_proxied_http_client().put_json(
"https://matrix.org/report-usage-stats/push", stats
)
except Exception as e:
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index d70780e9d5..ea26f29acb 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -161,7 +161,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-media-repository", config)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 070de7d0b0..692ffa2f04 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -216,7 +216,7 @@ def start(config_options):
_base.start(ps, config.worker_listeners)
ps.get_pusherpool().start()
- reactor.callWhenRunning(start)
+ reactor.addSystemEventTrigger("before", "startup", start)
_base.start_worker_reactor("synapse-pusher", config)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 315c030694..a1c3b162f7 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -451,7 +451,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-synchrotron", config)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 03ef21bd01..cb29a1afab 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -224,7 +224,9 @@ def start(config_options):
)
ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
_base.start_worker_reactor("synapse-user-dir", config)
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index b26a31dd54..65cbff95b9 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -175,21 +175,21 @@ class ApplicationService(object):
@defer.inlineCallbacks
def _matches_user(self, event, store):
if not event:
- defer.returnValue(False)
+ return False
if self.is_interested_in_user(event.sender):
- defer.returnValue(True)
+ return True
# also check m.room.member state key
if event.type == EventTypes.Member and self.is_interested_in_user(
event.state_key
):
- defer.returnValue(True)
+ return True
if not store:
- defer.returnValue(False)
+ return False
does_match = yield self._matches_user_in_member_list(event.room_id, store)
- defer.returnValue(does_match)
+ return does_match
@cachedInlineCallbacks(num_args=1, cache_context=True)
def _matches_user_in_member_list(self, room_id, store, cache_context):
@@ -200,8 +200,8 @@ class ApplicationService(object):
# check joined member events
for user_id in member_list:
if self.is_interested_in_user(user_id):
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
def _matches_room_id(self, event):
if hasattr(event, "room_id"):
@@ -211,13 +211,13 @@ class ApplicationService(object):
@defer.inlineCallbacks
def _matches_aliases(self, event, store):
if not store or not event:
- defer.returnValue(False)
+ return False
alias_list = yield store.get_aliases_for_room(event.room_id)
for alias in alias_list:
if self.is_interested_in_alias(alias):
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
@defer.inlineCallbacks
def is_interested(self, event, store=None):
@@ -231,15 +231,15 @@ class ApplicationService(object):
"""
# Do cheap checks first
if self._matches_room_id(event):
- defer.returnValue(True)
+ return True
if (yield self._matches_aliases(event, store)):
- defer.returnValue(True)
+ return True
if (yield self._matches_user(event, store)):
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
def is_interested_in_user(self, user_id):
return (
@@ -268,7 +268,7 @@ class ApplicationService(object):
def is_exclusive_room(self, room_id):
return self._is_exclusive(ApplicationService.NS_ROOMS, room_id)
- def get_exlusive_user_regexes(self):
+ def get_exclusive_user_regexes(self):
"""Get the list of regexes used to determine if a user is exclusively
registered by the AS
"""
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 571881775b..007ca75a94 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -97,40 +97,40 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks
def query_user(self, service, user_id):
if service.url is None:
- defer.returnValue(False)
+ return False
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None
try:
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
- defer.returnValue(True)
+ return True
except CodeMessageException as e:
if e.code == 404:
- defer.returnValue(False)
+ return False
return
logger.warning("query_user to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("query_user to %s threw exception %s", uri, ex)
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def query_alias(self, service, alias):
if service.url is None:
- defer.returnValue(False)
+ return False
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try:
response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
- defer.returnValue(True)
+ return True
except CodeMessageException as e:
logger.warning("query_alias to %s received %s", uri, e.code)
if e.code == 404:
- defer.returnValue(False)
+ return False
return
except Exception as ex:
logger.warning("query_alias to %s threw exception %s", uri, ex)
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields):
@@ -141,7 +141,7 @@ class ApplicationServiceApi(SimpleHttpClient):
else:
raise ValueError("Unrecognised 'kind' argument %r to query_3pe()", kind)
if service.url is None:
- defer.returnValue([])
+ return []
uri = "%s%s/thirdparty/%s/%s" % (
service.url,
@@ -155,7 +155,7 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning(
"query_3pe to %s returned an invalid response %r", uri, response
)
- defer.returnValue([])
+ return []
ret = []
for r in response:
@@ -166,14 +166,14 @@ class ApplicationServiceApi(SimpleHttpClient):
"query_3pe to %s returned an invalid result %r", uri, r
)
- defer.returnValue(ret)
+ return ret
except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex)
- defer.returnValue([])
+ return []
def get_3pe_protocol(self, service, protocol):
if service.url is None:
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def _get():
@@ -189,7 +189,7 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning(
"query_3pe_protocol to %s did not return a" " valid result", uri
)
- defer.returnValue(None)
+ return None
for instance in info.get("instances", []):
network_id = instance.get("network_id", None)
@@ -198,10 +198,10 @@ class ApplicationServiceApi(SimpleHttpClient):
service.id, network_id
).to_string()
- defer.returnValue(info)
+ return info
except Exception as ex:
logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex)
- defer.returnValue(None)
+ return None
key = (service.id, protocol)
return self.protocol_meta_cache.wrap(key, _get)
@@ -209,7 +209,7 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
if service.url is None:
- defer.returnValue(True)
+ return True
events = self._serialize(events)
@@ -229,14 +229,14 @@ class ApplicationServiceApi(SimpleHttpClient):
)
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
- defer.returnValue(True)
+ return True
return
except CodeMessageException as e:
logger.warning("push_bulk to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("push_bulk to %s threw exception %s", uri, ex)
failed_transactions_counter.labels(service.id).inc()
- defer.returnValue(False)
+ return False
def _serialize(self, events):
time_now = self.clock.time_msec()
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index e5b36494f5..42a350bff8 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -193,7 +193,7 @@ class _TransactionController(object):
@defer.inlineCallbacks
def _is_service_up(self, service):
state = yield self.store.get_appservice_state(service)
- defer.returnValue(state == ApplicationServiceState.UP or state is None)
+ return state == ApplicationServiceState.UP or state is None
class _Recoverer(object):
@@ -208,7 +208,7 @@ class _Recoverer(object):
r.service.id,
)
r.recover()
- defer.returnValue(recoverers)
+ return recoverers
def __init__(self, clock, store, as_api, service, callback):
self.clock = clock
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 6ce5cd07fb..54230342a1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,6 +18,7 @@
import argparse
import errno
import os
+from io import open as io_open
from textwrap import dedent
from six import integer_types
@@ -133,7 +134,7 @@ class Config(object):
@classmethod
def read_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
- with open(file_path) as file_stream:
+ with io_open(file_path, encoding="utf-8") as file_stream:
return file_stream.read()
def invoke_all(self, name, *args, **kargs):
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 40502a5798..d321d00b80 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
import logging.config
import os
@@ -75,10 +76,8 @@ root:
class LoggingConfig(Config):
def read_config(self, config, **kwargs):
- self.verbosity = config.get("verbose", 0)
- self.no_redirect_stdio = config.get("no_redirect_stdio", False)
self.log_config = self.abspath(config.get("log_config"))
- self.log_file = self.abspath(config.get("log_file"))
+ self.no_redirect_stdio = config.get("no_redirect_stdio", False)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
log_config = os.path.join(config_dir_path, server_name + ".log.config")
@@ -94,39 +93,13 @@ class LoggingConfig(Config):
)
def read_arguments(self, args):
- if args.verbose is not None:
- self.verbosity = args.verbose
if args.no_redirect_stdio is not None:
self.no_redirect_stdio = args.no_redirect_stdio
- if args.log_config is not None:
- self.log_config = args.log_config
- if args.log_file is not None:
- self.log_file = args.log_file
@staticmethod
def add_arguments(parser):
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
- "-v",
- "--verbose",
- dest="verbose",
- action="count",
- help="The verbosity level. Specify multiple times to increase "
- "verbosity. (Ignored if --log-config is specified.)",
- )
- logging_group.add_argument(
- "-f",
- "--log-file",
- dest="log_file",
- help="File to log to. (Ignored if --log-config is specified.)",
- )
- logging_group.add_argument(
- "--log-config",
- dest="log_config",
- default=None,
- help="Python logging config file",
- )
- logging_group.add_argument(
"-n",
"--no-redirect-stdio",
action="store_true",
@@ -153,58 +126,29 @@ def setup_logging(config, use_worker_options=False):
config (LoggingConfig | synapse.config.workers.WorkerConfig):
configuration data
- use_worker_options (bool): True to use 'worker_log_config' and
- 'worker_log_file' options instead of 'log_config' and 'log_file'.
+ use_worker_options (bool): True to use the 'worker_log_config' option
+ instead of 'log_config'.
register_sighup (func | None): Function to call to register a
sighup handler.
"""
log_config = config.worker_log_config if use_worker_options else config.log_config
- log_file = config.worker_log_file if use_worker_options else config.log_file
-
- log_format = (
- "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
- " - %(message)s"
- )
if log_config is None:
- # We don't have a logfile, so fall back to the 'verbosity' param from
- # the config or cmdline. (Note that we generate a log config for new
- # installs, so this will be an unusual case)
- level = logging.INFO
- level_for_storage = logging.INFO
- if config.verbosity:
- level = logging.DEBUG
- if config.verbosity > 1:
- level_for_storage = logging.DEBUG
+ log_format = (
+ "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
+ " - %(message)s"
+ )
logger = logging.getLogger("")
- logger.setLevel(level)
-
- logging.getLogger("synapse.storage.SQL").setLevel(level_for_storage)
+ logger.setLevel(logging.INFO)
+ logging.getLogger("synapse.storage.SQL").setLevel(logging.INFO)
formatter = logging.Formatter(log_format)
- if log_file:
- # TODO: Customisable file size / backup count
- handler = logging.handlers.RotatingFileHandler(
- log_file, maxBytes=(1000 * 1000 * 100), backupCount=3, encoding="utf8"
- )
-
- def sighup(signum, stack):
- logger.info("Closing log file due to SIGHUP")
- handler.doRollover()
- logger.info("Opened new log file due to SIGHUP")
-
- else:
- handler = logging.StreamHandler()
-
- def sighup(*args):
- pass
+ handler = logging.StreamHandler()
handler.setFormatter(formatter)
-
handler.addFilter(LoggingContextFilter(request=""))
-
logger.addHandler(handler)
else:
@@ -218,8 +162,7 @@ def setup_logging(config, use_worker_options=False):
logging.info("Reloaded log config from %s due to SIGHUP", log_config)
load_log_config()
-
- appbase.register_sighup(sighup)
+ appbase.register_sighup(sighup)
# make sure that the first thing we log is a thing we can grep backwards
# for
diff --git a/synapse/config/password.py b/synapse/config/password.py
index d5b5953f2f..47df98f41a 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,6 +31,10 @@ class PasswordConfig(Config):
self.password_localdb_enabled = password_config.get("localdb_enabled", True)
self.password_pepper = password_config.get("pepper", "")
+ # Password policy
+ self.password_policy = password_config.get("policy", {})
+ self.password_policy_enabled = self.password_policy.pop("enabled", False)
+
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
@@ -46,4 +52,34 @@ class PasswordConfig(Config):
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
#
#pepper: "EVEN_MORE_SECRET"
+
+ # Define and enforce a password policy. Each parameter is optional, boolean
+ # parameters default to 'false' and integer parameters default to 0.
+ # This is an early implementation of MSC2000.
+ #
+ #policy:
+ # Whether to enforce the password policy.
+ #
+ #enabled: true
+
+ # Minimum accepted length for a password.
+ #
+ #minimum_length: 15
+
+ # Whether a password must contain at least one digit.
+ #
+ #require_digit: true
+
+ # Whether a password must contain at least one symbol.
+ # A symbol is any character that's not a number or a letter.
+ #
+ #require_symbol: true
+
+ # Whether a password must contain at least one lowercase letter.
+ #
+ #require_lowercase: true
+
+ # Whether a password must contain at least one lowercase letter.
+ #
+ #require_uppercase: true
"""
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 33f31cf213..a1ea4fe02d 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -68,6 +68,9 @@ class RatelimitConfig(Config):
)
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
+ self.rc_third_party_invite = RateLimitConfig(
+ config.get("rc_third_party_invite", {})
+ )
rc_login_config = config.get("rc_login", {})
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
@@ -102,6 +105,8 @@ class RatelimitConfig(Config):
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
+ # - one that ratelimits third-party invites requests based on the account
+ # that's making the requests.
#
# The defaults are as shown below.
#
@@ -123,6 +128,10 @@ class RatelimitConfig(Config):
# failed_attempts:
# per_second: 0.17
# burst_count: 3
+ #
+ #rc_third_party_invite:
+ # per_second: 0.2
+ # burst_count: 10
# Ratelimiting settings for incoming federation
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index c3de7a4e32..3240e30f70 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -13,8 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
from distutils.util import strtobool
+import pkg_resources
+
from synapse.config._base import Config, ConfigError
from synapse.types import RoomAlias
from synapse.util.stringutils import random_string_with_symbols
@@ -41,8 +44,36 @@ class AccountValidityConfig(Config):
self.startup_job_max_delta = self.period * 10.0 / 100.0
- if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
- raise ConfigError("Can't send renewal emails without 'public_baseurl'")
+ if self.renew_by_email_enabled:
+ if "public_baseurl" not in synapse_config:
+ raise ConfigError("Can't send renewal emails without 'public_baseurl'")
+
+ template_dir = config.get("template_dir")
+
+ if not template_dir:
+ template_dir = pkg_resources.resource_filename("synapse", "res/templates")
+
+ if "account_renewed_html_path" in config:
+ file_path = os.path.join(template_dir, config["account_renewed_html_path"])
+
+ self.account_renewed_html_content = self.read_file(
+ file_path, "account_validity.account_renewed_html_path"
+ )
+ else:
+ self.account_renewed_html_content = (
+ "<html><body>Your account has been successfully renewed.</body><html>"
+ )
+
+ if "invalid_token_html_path" in config:
+ file_path = os.path.join(template_dir, config["invalid_token_html_path"])
+
+ self.invalid_token_html_content = self.read_file(
+ file_path, "account_validity.invalid_token_html_path"
+ )
+ else:
+ self.invalid_token_html_content = (
+ "<html><body>Invalid renewal token.</body><html>"
+ )
class RegistrationConfig(Config):
@@ -61,8 +92,19 @@ class RegistrationConfig(Config):
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
+ self.check_is_for_allowed_local_3pids = config.get(
+ "check_is_for_allowed_local_3pids", None
+ )
+ self.allow_invited_3pids = config.get("allow_invited_3pids", False)
+
+ self.disable_3pid_changes = config.get("disable_3pid_changes", False)
+
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
self.registration_shared_secret = config.get("registration_shared_secret")
+ self.register_mxid_from_3pid = config.get("register_mxid_from_3pid")
+ self.register_just_use_email_for_display_name = config.get(
+ "register_just_use_email_for_display_name", False
+ )
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
@@ -80,6 +122,18 @@ class RegistrationConfig(Config):
raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
+ self.disable_set_displayname = config.get("disable_set_displayname", False)
+ self.disable_set_avatar_url = config.get("disable_set_avatar_url", False)
+
+ self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", [])
+ if not isinstance(self.replicate_user_profiles_to, list):
+ self.replicate_user_profiles_to = [self.replicate_user_profiles_to]
+
+ self.shadow_server = config.get("shadow_server", None)
+ self.rewrite_identity_server_urls = config.get(
+ "rewrite_identity_server_urls", {}
+ )
+
self.disable_msisdn_registration = config.get(
"disable_msisdn_registration", False
)
@@ -145,6 +199,16 @@ class RegistrationConfig(Config):
# period: 6w
# renew_at: 1w
# renew_email_subject: "Renew your %%(app)s account"
+ # # Directory in which Synapse will try to find the HTML files to serve to the
+ # # user when trying to renew an account. Optional, defaults to
+ # # synapse/res/templates.
+ # template_dir: "res/templates"
+ # # HTML to be displayed to the user after they successfully renewed their
+ # # account. Optional.
+ # account_renewed_html_path: "account_renewed.html"
+ # # HTML to be displayed when the user tries to renew an account with an invalid
+ # # renewal token. Optional.
+ # invalid_token_html_path: "invalid_token.html"
# Time that a user's session remains valid for, after they log in.
#
@@ -168,9 +232,32 @@ class RegistrationConfig(Config):
#
#disable_msisdn_registration: true
+ # Derive the user's matrix ID from a type of 3PID used when registering.
+ # This overrides any matrix ID the user proposes when calling /register
+ # The 3PID type should be present in registrations_require_3pid to avoid
+ # users failing to register if they don't specify the right kind of 3pid.
+ #
+ #register_mxid_from_3pid: email
+
+ # Uncomment to set the display name of new users to their email address,
+ # rather than using the default heuristic.
+ #
+ #register_just_use_email_for_display_name: true
+
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
+ # Use an Identity Server to establish which 3PIDs are allowed to register?
+ # Overrides allowed_local_3pids below.
+ #
+ #check_is_for_allowed_local_3pids: matrix.org
+ #
+ # If you are using an IS you can also check whether that IS registers
+ # pending invites for the given 3PID (and then allow it to sign up on
+ # the platform):
+ #
+ #allow_invited_3pids: False
+ #
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\\.org'
@@ -179,6 +266,11 @@ class RegistrationConfig(Config):
# - medium: msisdn
# pattern: '\\+44'
+ # If true, stop users from trying to change the 3PIDs associated with
+ # their accounts.
+ #
+ #disable_3pid_changes: False
+
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -220,6 +312,30 @@ class RegistrationConfig(Config):
# - matrix.org
# - vector.im
+ # If enabled, user IDs, display names and avatar URLs will be replicated
+ # to this server whenever they change.
+ # This is an experimental API currently implemented by sydent to support
+ # cross-homeserver user directories.
+ #
+ #replicate_user_profiles_to: example.com
+
+ # If specified, attempt to replay registrations, profile changes & 3pid
+ # bindings on the given target homeserver via the AS API. The HS is authed
+ # via a given AS token.
+ #
+ #shadow_server:
+ # hs_url: https://shadow.example.com
+ # hs: shadow.example.com
+ # as_token: 12u394refgbdhivsia
+
+ # If enabled, don't let users set their own display names/avatars
+ # other than for the very first time (unless they are a server admin).
+ # Useful when provisioning users based on the contents of a 3rd party
+ # directory and to avoid ambiguities.
+ #
+ #disable_set_displayname: False
+ #disable_set_avatar_url: False
+
# Users who register on this homeserver will automatically be joined
# to these rooms
#
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 80a628d9b0..c6b737fb6b 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -91,6 +91,12 @@ class ContentRepositoryConfig(Config):
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
+ self.max_avatar_size = config.get("max_avatar_size")
+ if self.max_avatar_size:
+ self.max_avatar_size = self.parse_size(self.max_avatar_size)
+
+ self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", [])
+
self.media_store_path = self.ensure_directory(
config.get("media_store_path", "media_store")
)
@@ -229,6 +235,30 @@ class ContentRepositoryConfig(Config):
#
#max_upload_size: 10M
+ # The largest allowed size for a user avatar. If not defined, no
+ # restriction will be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #max_avatar_size: 10M
+
+ # Allow mimetypes for a user avatar. If not defined, no restriction will
+ # be imposed.
+ #
+ # Note that this only applies when an avatar is changed globally.
+ # Per-room avatar changes are not affected. See allow_per_room_profiles
+ # for disabling that functionality.
+ #
+ # Note that user avatar changes will not work if this is set without
+ # using Synapse's local media repo.
+ #
+ #allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
+
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 00170f1393..d32a2e4b0a 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -82,6 +82,12 @@ class ServerConfig(Config):
"require_auth_for_profile_requests", False
)
+ # Whether to require sharing a room with a user to retrieve their
+ # profile data
+ self.limit_profile_requests_to_known_users = config.get(
+ "limit_profile_requests_to_known_users", False
+ )
+
if "restrict_public_rooms_to_local_users" in config and (
"allow_public_rooms_without_auth" in config
or "allow_public_rooms_over_federation" in config
@@ -209,6 +215,130 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # True.
+ self.show_users_in_user_directory = config.get(
+ "show_users_in_user_directory", True
+ )
+
+ retention_config = config.get("retention")
+ if retention_config is None:
+ retention_config = {}
+
+ self.retention_enabled = retention_config.get("enabled", False)
+
+ retention_default_policy = retention_config.get("default_policy")
+
+ if retention_default_policy is not None:
+ self.retention_default_min_lifetime = retention_default_policy.get(
+ "min_lifetime"
+ )
+ if self.retention_default_min_lifetime is not None:
+ self.retention_default_min_lifetime = self.parse_duration(
+ self.retention_default_min_lifetime
+ )
+
+ self.retention_default_max_lifetime = retention_default_policy.get(
+ "max_lifetime"
+ )
+ if self.retention_default_max_lifetime is not None:
+ self.retention_default_max_lifetime = self.parse_duration(
+ self.retention_default_max_lifetime
+ )
+
+ if (
+ self.retention_default_min_lifetime is not None
+ and self.retention_default_max_lifetime is not None
+ and (
+ self.retention_default_min_lifetime
+ > self.retention_default_max_lifetime
+ )
+ ):
+ raise ConfigError(
+ "The default retention policy's 'min_lifetime' can not be greater"
+ " than its 'max_lifetime'"
+ )
+ else:
+ self.retention_default_min_lifetime = None
+ self.retention_default_max_lifetime = None
+
+ self.retention_allowed_lifetime_min = retention_config.get(
+ "allowed_lifetime_min"
+ )
+ if self.retention_allowed_lifetime_min is not None:
+ self.retention_allowed_lifetime_min = self.parse_duration(
+ self.retention_allowed_lifetime_min
+ )
+
+ self.retention_allowed_lifetime_max = retention_config.get(
+ "allowed_lifetime_max"
+ )
+ if self.retention_allowed_lifetime_max is not None:
+ self.retention_allowed_lifetime_max = self.parse_duration(
+ self.retention_allowed_lifetime_max
+ )
+
+ if (
+ self.retention_allowed_lifetime_min is not None
+ and self.retention_allowed_lifetime_max is not None
+ and self.retention_allowed_lifetime_min
+ > self.retention_allowed_lifetime_max
+ ):
+ raise ConfigError(
+ "Invalid retention policy limits: 'allowed_lifetime_min' can not be"
+ " greater than 'allowed_lifetime_max'"
+ )
+
+ self.retention_purge_jobs = []
+ for purge_job_config in retention_config.get("purge_jobs", []):
+ interval_config = purge_job_config.get("interval")
+
+ if interval_config is None:
+ raise ConfigError(
+ "A retention policy's purge jobs configuration must have the"
+ " 'interval' key set."
+ )
+
+ interval = self.parse_duration(interval_config)
+
+ shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime")
+
+ if shortest_max_lifetime is not None:
+ shortest_max_lifetime = self.parse_duration(shortest_max_lifetime)
+
+ longest_max_lifetime = purge_job_config.get("longest_max_lifetime")
+
+ if longest_max_lifetime is not None:
+ longest_max_lifetime = self.parse_duration(longest_max_lifetime)
+
+ if (
+ shortest_max_lifetime is not None
+ and longest_max_lifetime is not None
+ and shortest_max_lifetime > longest_max_lifetime
+ ):
+ raise ConfigError(
+ "A retention policy's purge jobs configuration's"
+ " 'shortest_max_lifetime' value can not be greater than its"
+ " 'longest_max_lifetime' value."
+ )
+
+ self.retention_purge_jobs.append(
+ {
+ "interval": interval,
+ "shortest_max_lifetime": shortest_max_lifetime,
+ "longest_max_lifetime": longest_max_lifetime,
+ }
+ )
+
+ if not self.retention_purge_jobs:
+ self.retention_purge_jobs = [
+ {
+ "interval": self.parse_duration("1d"),
+ "shortest_max_lifetime": None,
+ "longest_max_lifetime": None,
+ }
+ ]
+
self.listeners = []
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
@@ -395,6 +525,13 @@ class ServerConfig(Config):
#
#require_auth_for_profile_requests: true
+ # Whether to require a user to share a room with another user in order
+ # to retrieve their profile information. Only checked on Client-Server
+ # requests. Profile requests from other servers should be checked by the
+ # requesting server. Defaults to 'false'.
+ #
+ # limit_profile_requests_to_known_users: true
+
# If set to 'false', requires authentication to access the server's public rooms
# directory through the client API. Defaults to 'true'.
#
@@ -627,6 +764,74 @@ class ServerConfig(Config):
# Defaults to 'true'.
#
#allow_per_room_profiles: false
+
+ # Whether to show the users on this homeserver in the user directory. Defaults to
+ # 'true'.
+ #
+ #show_users_in_user_directory: false
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events which lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # which 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), but not want
+ # that purge to be performed by a job that's iterating over every room it knows,
+ # which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m:
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
"""
% locals()
)
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 4479454415..95e7ccb3a3 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -23,6 +23,12 @@ class TracerConfig(Config):
opentracing_config = {}
self.opentracer_enabled = opentracing_config.get("enabled", False)
+
+ self.jaeger_config = opentracing_config.get(
+ "jaeger_config",
+ {"sampler": {"type": "const", "param": 1}, "logging": False},
+ )
+
if not self.opentracer_enabled:
return
@@ -56,4 +62,20 @@ class TracerConfig(Config):
#
#homeserver_whitelist:
# - ".*"
+
+ # Jaeger can be configured to sample traces at different rates.
+ # All configuration options provided by Jaeger can be set here.
+ # Jaeger's configuration mostly related to trace sampling which
+ # is documented here:
+ # https://www.jaegertracing.io/docs/1.13/sampling/.
+ #
+ #jaeger_config:
+ # sampler:
+ # type: const
+ # param: 1
+
+ # Logging whether spans were started and reported
+ #
+ # logging:
+ # false
"""
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index f6313e17d4..96493a5dcc 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -24,6 +24,7 @@ class UserDirectoryConfig(Config):
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
+ self.user_directory_defer_to_id_server = None
user_directory_config = config.get("user_directory", None)
if user_directory_config:
self.user_directory_search_enabled = user_directory_config.get(
@@ -32,6 +33,9 @@ class UserDirectoryConfig(Config):
self.user_directory_search_all_users = user_directory_config.get(
"search_all_users", False
)
+ self.user_directory_defer_to_id_server = user_directory_config.get(
+ "defer_to_id_server", None
+ )
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
@@ -50,4 +54,9 @@ class UserDirectoryConfig(Config):
#user_directory:
# enabled: true
# search_all_users: false
+ #
+ # # If this is set, user search will be delegated to this ID server instead
+ # # of synapse performing the search itself.
+ # # This is an experimental API.
+ # defer_to_id_server: https://id.example.com
"""
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 3b75471d85..bc0fc165e3 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -31,7 +31,6 @@ class WorkerConfig(Config):
self.worker_listeners = config.get("worker_listeners", [])
self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file")
- self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config")
# The host used to connect to the main synapse
@@ -78,9 +77,5 @@ class WorkerConfig(Config):
if args.daemonize is not None:
self.worker_daemonize = args.daemonize
- if args.log_config is not None:
- self.worker_log_config = args.log_config
- if args.log_file is not None:
- self.worker_log_file = args.log_file
if args.manhole is not None:
self.worker_manhole = args.worker_manhole
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 341c863152..6c3e885e72 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -238,27 +238,9 @@ class Keyring(object):
"""
try:
- # create a deferred for each server we're going to look up the keys
- # for; we'll resolve them once we have completed our lookups.
- # These will be passed into wait_for_previous_lookups to block
- # any other lookups until we have finished.
- # The deferreds are called with no logcontext.
- server_to_deferred = {
- rq.server_name: defer.Deferred() for rq in verify_requests
- }
-
- # We want to wait for any previous lookups to complete before
- # proceeding.
- yield self.wait_for_previous_lookups(server_to_deferred)
+ ctx = LoggingContext.current_context()
- # Actually start fetching keys.
- self._get_server_verify_keys(verify_requests)
-
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- #
- # map from server name to a set of request ids
+ # map from server name to a set of outstanding request ids
server_to_request_ids = {}
for verify_request in verify_requests:
@@ -266,40 +248,61 @@ class Keyring(object):
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
- def remove_deferreds(res, verify_request):
+ # Wait for any previous lookups to complete before proceeding.
+ yield self.wait_for_previous_lookups(server_to_request_ids.keys())
+
+ # take out a lock on each of the servers by sticking a Deferred in
+ # key_downloads
+ for server_name in server_to_request_ids.keys():
+ self.key_downloads[server_name] = defer.Deferred()
+ logger.debug("Got key lookup lock on %s", server_name)
+
+ # When we've finished fetching all the keys for a given server_name,
+ # drop the lock by resolving the deferred in key_downloads.
+ def drop_server_lock(server_name):
+ d = self.key_downloads.pop(server_name)
+ d.callback(None)
+
+ def lookup_done(res, verify_request):
server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids[server_name].discard(request_id)
- if not server_to_request_ids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
+ server_requests = server_to_request_ids[server_name]
+ server_requests.remove(id(verify_request))
+
+ # if there are no more requests for this server, we can drop the lock.
+ if not server_requests:
+ with PreserveLoggingContext(ctx):
+ logger.debug("Releasing key lookup lock on %s", server_name)
+
+ # ... but not immediately, as that can cause stack explosions if
+ # we get a long queue of lookups.
+ self.clock.call_later(0, drop_server_lock, server_name)
+
return res
for verify_request in verify_requests:
- verify_request.key_ready.addBoth(remove_deferreds, verify_request)
+ verify_request.key_ready.addBoth(lookup_done, verify_request)
+
+ # Actually start fetching keys.
+ self._get_server_verify_keys(verify_requests)
except Exception:
logger.exception("Error starting key lookups")
@defer.inlineCallbacks
- def wait_for_previous_lookups(self, server_to_deferred):
+ def wait_for_previous_lookups(self, server_names):
"""Waits for any previous key lookups for the given servers to finish.
Args:
- server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
- resolved once we've finished looking up keys for that server.
- The Deferreds should be regular twisted ones which call their
- callbacks with no logcontext.
-
- Returns: a Deferred which resolves once all key lookups for the given
- servers have completed. Follows the synapse rules of logcontext
- preservation.
+ server_names (Iterable[str]): list of servers which we want to look up
+
+ Returns:
+ Deferred[None]: resolves once all key lookups for the given servers have
+ completed. Follows the synapse rules of logcontext preservation.
"""
loop_count = 1
while True:
wait_on = [
(server_name, self.key_downloads[server_name])
- for server_name in server_to_deferred.keys()
+ for server_name in server_names
if server_name in self.key_downloads
]
if not wait_on:
@@ -314,19 +317,6 @@ class Keyring(object):
loop_count += 1
- ctx = LoggingContext.current_context()
-
- def rm(r, server_name_):
- with PreserveLoggingContext(ctx):
- logger.debug("Releasing key lookup lock on %s", server_name_)
- self.key_downloads.pop(server_name_, None)
- return r
-
- for server_name, deferred in server_to_deferred.items():
- logger.debug("Got key lookup lock on %s", server_name)
- self.key_downloads[server_name] = deferred
- deferred.addBoth(rm, server_name)
-
def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
@@ -472,7 +462,7 @@ class StoreKeyFetcher(KeyFetcher):
keys = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
- defer.returnValue(keys)
+ return keys
class BaseV2KeyFetcher(object):
@@ -576,7 +566,7 @@ class BaseV2KeyFetcher(object):
).addErrback(unwrapFirstError)
)
- defer.returnValue(verify_keys)
+ return verify_keys
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@@ -598,7 +588,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
result = yield self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
- defer.returnValue(result)
+ return result
except KeyLookupError as e:
logger.warning(
"Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,7 +601,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
str(e),
)
- defer.returnValue({})
+ return {}
results = yield make_deferred_yieldable(
defer.gatherResults(
@@ -625,7 +615,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
- defer.returnValue(union_of_keys)
+ return union_of_keys
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
@@ -711,7 +701,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
perspective_name, time_now_ms, added_keys
)
- defer.returnValue(keys)
+ return keys
def _validate_perspectives_response(self, key_server, response):
"""Optionally check the signature on the result of a /key/query request
@@ -853,7 +843,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
)
keys.update(response_keys)
- defer.returnValue(keys)
+ return keys
@defer.inlineCallbacks
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index db011e0407..3997751337 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -144,15 +144,13 @@ class EventBuilder(object):
if self._origin_server_ts is not None:
event_dict["origin_server_ts"] = self._origin_server_ts
- defer.returnValue(
- create_local_event_from_event_dict(
- clock=self._clock,
- hostname=self._hostname,
- signing_key=self._signing_key,
- format_version=self.format_version,
- event_dict=event_dict,
- internal_metadata_dict=self.internal_metadata.get_dict(),
- )
+ return create_local_event_from_event_dict(
+ clock=self._clock,
+ hostname=self._hostname,
+ signing_key=self._signing_key,
+ format_version=self.format_version,
+ event_dict=event_dict,
+ internal_metadata_dict=self.internal_metadata.get_dict(),
)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index a9545e6c1b..acbcbeeced 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -133,19 +133,17 @@ class EventContext(object):
else:
prev_state_id = None
- defer.returnValue(
- {
- "prev_state_id": prev_state_id,
- "event_type": event.type,
- "event_state_key": event.state_key if event.is_state() else None,
- "state_group": self.state_group,
- "rejected": self.rejected,
- "prev_group": self.prev_group,
- "delta_ids": _encode_state_dict(self.delta_ids),
- "prev_state_events": self.prev_state_events,
- "app_service_id": self.app_service.id if self.app_service else None,
- }
- )
+ return {
+ "prev_state_id": prev_state_id,
+ "event_type": event.type,
+ "event_state_key": event.state_key if event.is_state() else None,
+ "state_group": self.state_group,
+ "rejected": self.rejected,
+ "prev_group": self.prev_group,
+ "delta_ids": _encode_state_dict(self.delta_ids),
+ "prev_state_events": self.prev_state_events,
+ "app_service_id": self.app_service.id if self.app_service else None,
+ }
@staticmethod
def deserialize(store, input):
@@ -202,7 +200,7 @@ class EventContext(object):
yield make_deferred_yieldable(self._fetching_state_deferred)
- defer.returnValue(self._current_state_ids)
+ return self._current_state_ids
@defer.inlineCallbacks
def get_prev_state_ids(self, store):
@@ -222,7 +220,7 @@ class EventContext(object):
yield make_deferred_yieldable(self._fetching_state_deferred)
- defer.returnValue(self._prev_state_ids)
+ return self._prev_state_ids
def get_cached_current_state_ids(self):
"""Gets the current state IDs if we have them already cached.
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 129771f183..f0de4d961f 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -46,13 +46,33 @@ class SpamChecker(object):
return self.spam_checker.check_event_for_spam(event)
- def user_may_invite(self, inviter_userid, invitee_userid, room_id):
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
+ ):
"""Checks if a given user may send an invite
If this method returns false, the invite will be rejected.
Args:
- userid (string): The sender's user ID
+ inviter_userid (str)
+ invitee_userid (str|None): The user ID of the invitee. Is None
+ if this is a third party invite and the 3PID is not bound to a
+ user ID.
+ third_party_invite (dict|None): If a third party invite then is a
+ dict containing the medium and address of the invitee.
+ room_id (str)
+ new_room (bool): Whether the user is being invited to the room as
+ part of a room creation, if so the invitee would have been
+ included in the call to `user_may_create_room`.
+ published_room (bool): Whether the room the user is being invited
+ to has been published in the local homeserver's public room
+ directory.
Returns:
bool: True if the user may send an invite, otherwise False
@@ -61,16 +81,29 @@ class SpamChecker(object):
return True
return self.spam_checker.user_may_invite(
- inviter_userid, invitee_userid, room_id
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room,
)
- def user_may_create_room(self, userid):
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
Args:
userid (string): The sender's user ID
+ invite_list (list[str]): List of user IDs that would be invited to
+ the new room.
+ third_party_invite_list (list[dict]): List of third party invites
+ for the new room.
+ cloning (bool): Whether the user is cloning an existing room, e.g.
+ upgrading a room.
Returns:
bool: True if the user may create a room, otherwise False
@@ -78,7 +111,9 @@ class SpamChecker(object):
if self.spam_checker is None:
return True
- return self.spam_checker.user_may_create_room(userid)
+ return self.spam_checker.user_may_create_room(
+ userid, invite_list, third_party_invite_list, cloning
+ )
def user_may_create_room_alias(self, userid, room_alias):
"""Checks if a given user may create a room alias
@@ -113,3 +148,21 @@ class SpamChecker(object):
return True
return self.spam_checker.user_may_publish_room(userid, room_id)
+
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Checks if a given users is allowed to join a room.
+
+ Is not called when the user creates a room.
+
+ Args:
+ userid (str)
+ room_id (str)
+ is_invited (bool): Whether the user is invited into the room
+
+ Returns:
+ bool: Whether the user may join the room
+ """
+ if self.spam_checker is None:
+ return True
+
+ return self.spam_checker.user_may_join_room(userid, room_id, is_invited)
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 8f5d95696b..714a9b1579 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -51,7 +51,7 @@ class ThirdPartyEventRules(object):
defer.Deferred[bool]: True if the event should be allowed, False if not.
"""
if self.third_party_rules is None:
- defer.returnValue(True)
+ return True
prev_state_ids = yield context.get_prev_state_ids(self.store)
@@ -61,7 +61,7 @@ class ThirdPartyEventRules(object):
state_events[key] = yield self.store.get_event(event_id, allow_none=True)
ret = yield self.third_party_rules.check_event_allowed(event, state_events)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def on_create_room(self, requester, config, is_requester_admin):
@@ -98,7 +98,7 @@ class ThirdPartyEventRules(object):
"""
if self.third_party_rules is None:
- defer.returnValue(True)
+ return True
state_ids = yield self.store.get_filtered_current_state_ids(room_id)
room_state_events = yield self.store.get_events(state_ids.values())
@@ -110,4 +110,4 @@ class ThirdPartyEventRules(object):
ret = yield self.third_party_rules.check_threepid_can_be_invited(
medium, address, state_events
)
- defer.returnValue(ret)
+ return ret
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9487a886f5..07d1c5bcf0 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -360,7 +360,7 @@ class EventClientSerializer(object):
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
- defer.returnValue(event)
+ return event
event_id = event.event_id
serialized_event = serialize_event(event, time_now, **kwargs)
@@ -406,7 +406,7 @@ class EventClientSerializer(object):
"sender": edit.sender,
}
- defer.returnValue(serialized_event)
+ return serialized_event
def serialize_events(self, events, time_now, **kwargs):
"""Serializes multiple events.
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index f7ffd1d561..9df9287aa7 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six import string_types
+from six import integer_types, string_types
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@@ -22,11 +22,12 @@ from synapse.types import EventID, RoomID, UserID
class EventValidator(object):
- def validate_new(self, event):
+ def validate_new(self, event, config):
"""Validates the event has roughly the right format
Args:
- event (FrozenEvent)
+ event (FrozenEvent): The event to validate.
+ config (Config): The homeserver's configuration.
"""
self.validate_builder(event)
@@ -67,6 +68,99 @@ class EventValidator(object):
Codes.INVALID_PARAM,
)
+ if event.type == EventTypes.Retention:
+ self._validate_retention(event, config)
+
+ def _validate_retention(self, event, config):
+ """Checks that an event that defines the retention policy for a room respects the
+ boundaries imposed by the server's administrator.
+
+ Args:
+ event (FrozenEvent): The event to validate.
+ config (Config): The homeserver's configuration.
+ """
+ min_lifetime = event.content.get("min_lifetime")
+ max_lifetime = event.content.get("max_lifetime")
+
+ if min_lifetime is not None:
+ if not isinstance(min_lifetime, integer_types):
+ raise SynapseError(
+ code=400,
+ msg="'min_lifetime' must be an integer",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_min is not None
+ and min_lifetime < config.retention_allowed_lifetime_min
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'min_lifetime' can't be lower than the minimum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_max is not None
+ and min_lifetime > config.retention_allowed_lifetime_max
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'min_lifetime' can't be greater than the maximum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if max_lifetime is not None:
+ if not isinstance(max_lifetime, integer_types):
+ raise SynapseError(
+ code=400,
+ msg="'max_lifetime' must be an integer",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_min is not None
+ and max_lifetime < config.retention_allowed_lifetime_min
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'max_lifetime' can't be lower than the minimum allowed value"
+ " enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_max is not None
+ and max_lifetime > config.retention_allowed_lifetime_max
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'max_lifetime' can't be greater than the maximum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ min_lifetime is not None
+ and max_lifetime is not None
+ and min_lifetime > max_lifetime
+ ):
+ raise SynapseError(
+ code=400,
+ msg="'min_lifetime' can't be greater than 'max_lifetime",
+ errcode=Codes.BAD_JSON,
+ )
+
def validate_builder(self, event):
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index f7bb806ae7..5a1e23a145 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -106,7 +106,7 @@ class FederationBase(object):
"Failed to find copy of %s with valid signature", pdu.event_id
)
- defer.returnValue(res)
+ return res
handle = preserve_fn(handle_check_result)
deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
@@ -116,9 +116,9 @@ class FederationBase(object):
).addErrback(unwrapFirstError)
if include_none:
- defer.returnValue(valid_pdus)
+ return valid_pdus
else:
- defer.returnValue([p for p in valid_pdus if p])
+ return [p for p in valid_pdus if p]
def _check_sigs_and_hash(self, room_version, pdu):
return make_deferred_yieldable(
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 3cb4b94420..25ed1257f1 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -213,7 +213,7 @@ class FederationClient(FederationBase):
).addErrback(unwrapFirstError)
)
- defer.returnValue(pdus)
+ return pdus
@defer.inlineCallbacks
@log_function
@@ -245,7 +245,7 @@ class FederationClient(FederationBase):
ev = self._get_pdu_cache.get(event_id)
if ev:
- defer.returnValue(ev)
+ return ev
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
@@ -307,7 +307,7 @@ class FederationClient(FederationBase):
if signed_pdu:
self._get_pdu_cache[event_id] = signed_pdu
- defer.returnValue(signed_pdu)
+ return signed_pdu
@defer.inlineCallbacks
@log_function
@@ -355,7 +355,7 @@ class FederationClient(FederationBase):
auth_chain.sort(key=lambda e: e.depth)
- defer.returnValue((pdus, auth_chain))
+ return (pdus, auth_chain)
except HttpResponseException as e:
if e.code == 400 or e.code == 404:
logger.info("Failed to use get_room_state_ids API, falling back")
@@ -404,7 +404,7 @@ class FederationClient(FederationBase):
signed_auth.sort(key=lambda e: e.depth)
- defer.returnValue((signed_pdus, signed_auth))
+ return (signed_pdus, signed_auth)
@defer.inlineCallbacks
def get_events_from_store_or_dest(self, destination, room_id, event_ids):
@@ -429,7 +429,7 @@ class FederationClient(FederationBase):
missing_events.discard(k)
if not missing_events:
- defer.returnValue((signed_events, failed_to_fetch))
+ return (signed_events, failed_to_fetch)
logger.debug(
"Fetching unknown state/auth events %s for room %s",
@@ -465,7 +465,7 @@ class FederationClient(FederationBase):
# We removed all events we successfully fetched from `batch`
failed_to_fetch.update(batch)
- defer.returnValue((signed_events, failed_to_fetch))
+ return (signed_events, failed_to_fetch)
@defer.inlineCallbacks
@log_function
@@ -485,7 +485,7 @@ class FederationClient(FederationBase):
signed_auth.sort(key=lambda e: e.depth)
- defer.returnValue(signed_auth)
+ return signed_auth
@defer.inlineCallbacks
def _try_destination_list(self, description, destinations, callback):
@@ -521,7 +521,7 @@ class FederationClient(FederationBase):
try:
res = yield callback(destination)
- defer.returnValue(res)
+ return res
except InvalidResponseError as e:
logger.warn("Failed to %s via %s: %s", description, destination, e)
except HttpResponseException as e:
@@ -615,7 +615,7 @@ class FederationClient(FederationBase):
event_dict=pdu_dict,
)
- defer.returnValue((destination, ev, event_format))
+ return (destination, ev, event_format)
return self._try_destination_list(
"make_" + membership, destinations, send_request
@@ -728,13 +728,11 @@ class FederationClient(FederationBase):
check_authchain_validity(signed_auth)
- defer.returnValue(
- {
- "state": signed_state,
- "auth_chain": signed_auth,
- "origin": destination,
- }
- )
+ return {
+ "state": signed_state,
+ "auth_chain": signed_auth,
+ "origin": destination,
+ }
return self._try_destination_list("send_join", destinations, send_request)
@@ -758,7 +756,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
- defer.returnValue(pdu)
+ return pdu
@defer.inlineCallbacks
def _do_send_invite(self, destination, pdu, room_version):
@@ -786,7 +784,7 @@ class FederationClient(FederationBase):
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
- defer.returnValue(content)
+ return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -821,7 +819,7 @@ class FederationClient(FederationBase):
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
- defer.returnValue(content)
+ return content
def send_leave(self, destinations, pdu):
"""Sends a leave event to one of a list of homeservers.
@@ -856,7 +854,7 @@ class FederationClient(FederationBase):
)
logger.debug("Got content: %s", content)
- defer.returnValue(None)
+ return None
return self._try_destination_list("send_leave", destinations, send_request)
@@ -917,7 +915,7 @@ class FederationClient(FederationBase):
"missing": content.get("missing", []),
}
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_missing_events(
@@ -974,7 +972,7 @@ class FederationClient(FederationBase):
# get_missing_events
signed_events = []
- defer.returnValue(signed_events)
+ return signed_events
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
@@ -986,7 +984,7 @@ class FederationClient(FederationBase):
yield self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
- defer.returnValue(None)
+ return None
except CodeMessageException:
raise
except Exception as e:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index ed2b6d5eef..d216c46dfe 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -99,7 +99,7 @@ class FederationServer(FederationBase):
res = self._transaction_from_pdus(pdus).get_dict()
- defer.returnValue((200, res))
+ return (200, res)
@defer.inlineCallbacks
@log_function
@@ -126,7 +126,7 @@ class FederationServer(FederationBase):
origin, transaction, request_time
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _handle_incoming_transaction(self, origin, transaction, request_time):
@@ -147,8 +147,7 @@ class FederationServer(FederationBase):
"[%s] We've already responded to this request",
transaction.transaction_id,
)
- defer.returnValue(response)
- return
+ return response
logger.debug("[%s] Transaction is new", transaction.transaction_id)
@@ -163,7 +162,7 @@ class FederationServer(FederationBase):
yield self.transaction_actions.set_response(
origin, transaction, 400, response
)
- defer.returnValue((400, response))
+ return (400, response)
received_pdus_counter.inc(len(transaction.pdus))
@@ -265,7 +264,7 @@ class FederationServer(FederationBase):
logger.debug("Returning: %s", str(response))
yield self.transaction_actions.set_response(origin, transaction, 200, response)
- defer.returnValue((200, response))
+ return (200, response)
@defer.inlineCallbacks
def received_edu(self, origin, edu_type, content):
@@ -298,7 +297,7 @@ class FederationServer(FederationBase):
event_id,
)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_state_ids_request(self, origin, room_id, event_id):
@@ -315,9 +314,7 @@ class FederationServer(FederationBase):
state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id)
auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
- defer.returnValue(
- (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids})
- )
+ return (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids})
@defer.inlineCallbacks
def _on_context_state_request_compute(self, room_id, event_id):
@@ -336,12 +333,10 @@ class FederationServer(FederationBase):
)
)
- defer.returnValue(
- {
- "pdus": [pdu.get_pdu_json() for pdu in pdus],
- "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
- }
- )
+ return {
+ "pdus": [pdu.get_pdu_json() for pdu in pdus],
+ "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
+ }
@defer.inlineCallbacks
@log_function
@@ -349,15 +344,15 @@ class FederationServer(FederationBase):
pdu = yield self.handler.get_persisted_pdu(origin, event_id)
if pdu:
- defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict()))
+ return (200, self._transaction_from_pdus([pdu]).get_dict())
else:
- defer.returnValue((404, ""))
+ return (404, "")
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
received_queries_counter.labels(query_type).inc()
resp = yield self.registry.on_query(query_type, args)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_make_join_request(self, origin, room_id, user_id, supported_versions):
@@ -371,9 +366,7 @@ class FederationServer(FederationBase):
pdu = yield self.handler.on_make_join_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
- defer.returnValue(
- {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- )
+ return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
@defer.inlineCallbacks
def on_invite_request(self, origin, content, room_version):
@@ -391,7 +384,7 @@ class FederationServer(FederationBase):
yield self.check_server_matches_acl(origin_host, pdu.room_id)
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
time_now = self._clock.time_msec()
- defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)})
+ return {"event": ret_pdu.get_pdu_json(time_now)}
@defer.inlineCallbacks
def on_send_join_request(self, origin, content, room_id):
@@ -407,16 +400,14 @@ class FederationServer(FederationBase):
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
- defer.returnValue(
- (
- 200,
- {
- "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [
- p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
- ],
- },
- )
+ return (
+ 200,
+ {
+ "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
+ "auth_chain": [
+ p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
+ ],
+ },
)
@defer.inlineCallbacks
@@ -428,9 +419,7 @@ class FederationServer(FederationBase):
room_version = yield self.store.get_room_version(room_id)
time_now = self._clock.time_msec()
- defer.returnValue(
- {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- )
+ return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
@defer.inlineCallbacks
def on_send_leave_request(self, origin, content, room_id):
@@ -445,7 +434,7 @@ class FederationServer(FederationBase):
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
yield self.handler.on_send_leave_request(origin, pdu)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_event_auth(self, origin, room_id, event_id):
@@ -456,7 +445,7 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
auth_pdus = yield self.handler.on_event_auth(event_id)
res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
- defer.returnValue((200, res))
+ return (200, res)
@defer.inlineCallbacks
def on_query_auth_request(self, origin, content, room_id, event_id):
@@ -509,7 +498,7 @@ class FederationServer(FederationBase):
"missing": ret.get("missing", []),
}
- defer.returnValue((200, send_content))
+ return (200, send_content)
@log_function
def on_query_client_keys(self, origin, content):
@@ -548,7 +537,7 @@ class FederationServer(FederationBase):
),
)
- defer.returnValue({"one_time_keys": json_result})
+ return {"one_time_keys": json_result}
@defer.inlineCallbacks
@log_function
@@ -580,9 +569,7 @@ class FederationServer(FederationBase):
time_now = self._clock.time_msec()
- defer.returnValue(
- {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
- )
+ return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
@log_function
def on_openid_userinfo(self, token):
@@ -676,14 +663,14 @@ class FederationServer(FederationBase):
ret = yield self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed
)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
ret = yield self.handler.on_exchange_third_party_invite_request(
origin, room_id, event_dict
)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def check_server_matches_acl(self, server_name, room_id):
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index d46f4aaeb1..36f6d470dc 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -49,7 +49,7 @@ sent_pdus_destination_dist_count = Counter(
sent_pdus_destination_dist_total = Counter(
"synapse_federation_client_sent_pdu_destinations:total",
- "" "Total number of PDUs queued for sending across all destinations",
+ "Total number of PDUs queued for sending across all destinations",
)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 9aab12c0d3..fad980b893 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -374,7 +374,7 @@ class PerDestinationQueue(object):
assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
- defer.returnValue((edus, now_stream_id))
+ return (edus, now_stream_id)
@defer.inlineCallbacks
def _get_to_device_message_edus(self, limit):
@@ -393,4 +393,4 @@ class PerDestinationQueue(object):
for content in contents
]
- defer.returnValue((edus, stream_id))
+ return (edus, stream_id)
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 0460a8c4ac..52706302f2 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -133,4 +133,4 @@ class TransactionManager(object):
)
success = False
- defer.returnValue(success)
+ return success
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1aae9ec9e7..2a6709ff48 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -183,7 +183,7 @@ class TransportLayerClient(object):
try_trailing_slash_on_400=True,
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -201,7 +201,7 @@ class TransportLayerClient(object):
ignore_backoff=ignore_backoff,
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -259,7 +259,7 @@ class TransportLayerClient(object):
ignore_backoff=ignore_backoff,
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -270,7 +270,7 @@ class TransportLayerClient(object):
destination=destination, path=path, data=content
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -288,7 +288,7 @@ class TransportLayerClient(object):
ignore_backoff=True,
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -299,7 +299,7 @@ class TransportLayerClient(object):
destination=destination, path=path, data=content, ignore_backoff=True
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -310,7 +310,7 @@ class TransportLayerClient(object):
destination=destination, path=path, data=content, ignore_backoff=True
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -339,7 +339,7 @@ class TransportLayerClient(object):
destination=remote_server, path=path, args=args, ignore_backoff=True
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -350,7 +350,7 @@ class TransportLayerClient(object):
destination=destination, path=path, data=event_dict
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
@@ -359,7 +359,7 @@ class TransportLayerClient(object):
content = yield self.client.get_json(destination=destination, path=path)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -370,7 +370,7 @@ class TransportLayerClient(object):
destination=destination, path=path, data=content
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -402,7 +402,7 @@ class TransportLayerClient(object):
content = yield self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -426,7 +426,7 @@ class TransportLayerClient(object):
content = yield self.client.get_json(
destination=destination, path=path, timeout=timeout
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -460,7 +460,7 @@ class TransportLayerClient(object):
content = yield self.client.post_json(
destination=destination, path=path, data=query_content, timeout=timeout
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -488,7 +488,7 @@ class TransportLayerClient(object):
timeout=timeout,
)
- defer.returnValue(content)
+ return content
@log_function
def get_group_profile(self, destination, group_id, requester_user_id):
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index f497711133..dfd7ae041b 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -157,7 +157,7 @@ class GroupAttestionRenewer(object):
yield self.store.update_remote_attestion(group_id, user_id, attestation)
- defer.returnValue({})
+ return {}
def _start_renew_attestations(self):
return run_as_background_process("renew_attestations", self._renew_attestations)
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 168c9e3f84..d50e691436 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -85,7 +85,7 @@ class GroupsServerHandler(object):
if not is_admin:
raise SynapseError(403, "User is not admin in group")
- defer.returnValue(group)
+ return group
@defer.inlineCallbacks
def get_group_summary(self, group_id, requester_user_id):
@@ -151,22 +151,20 @@ class GroupsServerHandler(object):
group_id, requester_user_id
)
- defer.returnValue(
- {
- "profile": profile,
- "users_section": {
- "users": users,
- "roles": roles,
- "total_user_count_estimate": 0, # TODO
- },
- "rooms_section": {
- "rooms": rooms,
- "categories": categories,
- "total_room_count_estimate": 0, # TODO
- },
- "user": membership_info,
- }
- )
+ return {
+ "profile": profile,
+ "users_section": {
+ "users": users,
+ "roles": roles,
+ "total_user_count_estimate": 0, # TODO
+ },
+ "rooms_section": {
+ "rooms": rooms,
+ "categories": categories,
+ "total_room_count_estimate": 0, # TODO
+ },
+ "user": membership_info,
+ }
@defer.inlineCallbacks
def update_group_summary_room(
@@ -192,7 +190,7 @@ class GroupsServerHandler(object):
is_public=is_public,
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_summary_room(
@@ -208,7 +206,7 @@ class GroupsServerHandler(object):
group_id=group_id, room_id=room_id, category_id=category_id
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def set_group_join_policy(self, group_id, requester_user_id, content):
@@ -228,7 +226,7 @@ class GroupsServerHandler(object):
yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def get_group_categories(self, group_id, requester_user_id):
@@ -237,7 +235,7 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
categories = yield self.store.get_group_categories(group_id=group_id)
- defer.returnValue({"categories": categories})
+ return {"categories": categories}
@defer.inlineCallbacks
def get_group_category(self, group_id, requester_user_id, category_id):
@@ -249,7 +247,7 @@ class GroupsServerHandler(object):
group_id=group_id, category_id=category_id
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def update_group_category(self, group_id, requester_user_id, category_id, content):
@@ -269,7 +267,7 @@ class GroupsServerHandler(object):
profile=profile,
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_category(self, group_id, requester_user_id, category_id):
@@ -283,7 +281,7 @@ class GroupsServerHandler(object):
group_id=group_id, category_id=category_id
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def get_group_roles(self, group_id, requester_user_id):
@@ -292,7 +290,7 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
roles = yield self.store.get_group_roles(group_id=group_id)
- defer.returnValue({"roles": roles})
+ return {"roles": roles}
@defer.inlineCallbacks
def get_group_role(self, group_id, requester_user_id, role_id):
@@ -301,7 +299,7 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
res = yield self.store.get_group_role(group_id=group_id, role_id=role_id)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def update_group_role(self, group_id, requester_user_id, role_id, content):
@@ -319,7 +317,7 @@ class GroupsServerHandler(object):
group_id=group_id, role_id=role_id, is_public=is_public, profile=profile
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_role(self, group_id, requester_user_id, role_id):
@@ -331,7 +329,7 @@ class GroupsServerHandler(object):
yield self.store.remove_group_role(group_id=group_id, role_id=role_id)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def update_group_summary_user(
@@ -355,7 +353,7 @@ class GroupsServerHandler(object):
is_public=is_public,
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
@@ -369,7 +367,7 @@ class GroupsServerHandler(object):
group_id=group_id, user_id=user_id, role_id=role_id
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def get_group_profile(self, group_id, requester_user_id):
@@ -391,7 +389,7 @@ class GroupsServerHandler(object):
group_description = {key: group[key] for key in cols}
group_description["is_openly_joinable"] = group["join_policy"] == "open"
- defer.returnValue(group_description)
+ return group_description
else:
raise SynapseError(404, "Unknown group")
@@ -461,9 +459,7 @@ class GroupsServerHandler(object):
# TODO: If admin add lists of users whose attestations have timed out
- defer.returnValue(
- {"chunk": chunk, "total_user_count_estimate": len(user_results)}
- )
+ return {"chunk": chunk, "total_user_count_estimate": len(user_results)}
@defer.inlineCallbacks
def get_invited_users_in_group(self, group_id, requester_user_id):
@@ -494,9 +490,7 @@ class GroupsServerHandler(object):
logger.warn("Error getting profile for %s: %s", user_id, e)
user_profiles.append(user_profile)
- defer.returnValue(
- {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
- )
+ return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
@defer.inlineCallbacks
def get_rooms_in_group(self, group_id, requester_user_id):
@@ -533,9 +527,7 @@ class GroupsServerHandler(object):
chunk.sort(key=lambda e: -e["num_joined_members"])
- defer.returnValue(
- {"chunk": chunk, "total_room_count_estimate": len(room_results)}
- )
+ return {"chunk": chunk, "total_room_count_estimate": len(room_results)}
@defer.inlineCallbacks
def add_room_to_group(self, group_id, requester_user_id, room_id, content):
@@ -551,7 +543,7 @@ class GroupsServerHandler(object):
yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def update_room_in_group(
@@ -574,7 +566,7 @@ class GroupsServerHandler(object):
else:
raise SynapseError(400, "Uknown config option")
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def remove_room_from_group(self, group_id, requester_user_id, room_id):
@@ -586,7 +578,7 @@ class GroupsServerHandler(object):
yield self.store.remove_room_from_group(group_id, room_id)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def invite_to_group(self, group_id, user_id, requester_user_id, content):
@@ -644,9 +636,9 @@ class GroupsServerHandler(object):
)
elif res["state"] == "invite":
yield self.store.add_group_invite(group_id, user_id)
- defer.returnValue({"state": "invite"})
+ return {"state": "invite"}
elif res["state"] == "reject":
- defer.returnValue({"state": "reject"})
+ return {"state": "reject"}
else:
raise SynapseError(502, "Unknown state returned by HS")
@@ -679,7 +671,7 @@ class GroupsServerHandler(object):
remote_attestation=remote_attestation,
)
- defer.returnValue(local_attestation)
+ return local_attestation
@defer.inlineCallbacks
def accept_invite(self, group_id, requester_user_id, content):
@@ -699,7 +691,7 @@ class GroupsServerHandler(object):
local_attestation = yield self._add_user(group_id, requester_user_id, content)
- defer.returnValue({"state": "join", "attestation": local_attestation})
+ return {"state": "join", "attestation": local_attestation}
@defer.inlineCallbacks
def join_group(self, group_id, requester_user_id, content):
@@ -716,7 +708,7 @@ class GroupsServerHandler(object):
local_attestation = yield self._add_user(group_id, requester_user_id, content)
- defer.returnValue({"state": "join", "attestation": local_attestation})
+ return {"state": "join", "attestation": local_attestation}
@defer.inlineCallbacks
def knock(self, group_id, requester_user_id, content):
@@ -769,7 +761,7 @@ class GroupsServerHandler(object):
if not self.hs.is_mine_id(user_id):
yield self.store.maybe_delete_remote_profile_cache(user_id)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def create_group(self, group_id, requester_user_id, content):
@@ -845,7 +837,7 @@ class GroupsServerHandler(object):
avatar_url=user_profile.get("avatar_url"),
)
- defer.returnValue({"group_id": group_id})
+ return {"group_id": group_id}
@defer.inlineCallbacks
def delete_group(self, group_id, requester_user_id):
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index e62e6cab77..8acd9f9a83 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -51,8 +51,8 @@ class AccountDataEventSource(object):
{"type": account_data_type, "content": content, "room_id": room_id}
)
- defer.returnValue((results, current_stream_id))
+ return (results, current_stream_id)
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
- defer.returnValue(([], config.to_id))
+ return ([], config.to_id)
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 1f1708ba7d..51305b0c90 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -43,6 +43,8 @@ class AccountValidityHandler(object):
self.clock = self.hs.get_clock()
self._account_validity = self.hs.config.account_validity
+ self._show_users_in_user_directory = self.hs.config.show_users_in_user_directory
+ self.profile_handler = self.hs.get_profile_handler()
if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
# Don't do email-specific configuration if renewal by email is disabled.
@@ -77,6 +79,9 @@ class AccountValidityHandler(object):
self.clock.looping_call(send_emails, 30 * 60 * 1000)
+ # Check every hour to remove expired users from the user directory
+ self.clock.looping_call(self._mark_expired_users_as_inactive, 60 * 60 * 1000)
+
@defer.inlineCallbacks
def send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
@@ -193,7 +198,7 @@ class AccountValidityHandler(object):
if threepid["medium"] == "email":
addresses.append(threepid["address"])
- defer.returnValue(addresses)
+ return addresses
@defer.inlineCallbacks
def _get_renewal_token(self, user_id):
@@ -214,7 +219,7 @@ class AccountValidityHandler(object):
try:
renewal_token = stringutils.random_string(32)
yield self.store.set_renewal_token_for_user(user_id, renewal_token)
- defer.returnValue(renewal_token)
+ return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
@@ -226,11 +231,19 @@ class AccountValidityHandler(object):
Args:
renewal_token (str): Token sent with the renewal request.
+ Returns:
+ bool: Whether the provided token is valid.
"""
- user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+ try:
+ user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+ except StoreError:
+ defer.returnValue(False)
+
logger.debug("Renewing an account for user %s", user_id)
yield self.renew_account_for_user(user_id)
+ defer.returnValue(True)
+
@defer.inlineCallbacks
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
"""Renews the account attached to a given user by pushing back the
@@ -254,4 +267,27 @@ class AccountValidityHandler(object):
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
- defer.returnValue(expiration_ts)
+ # Check if renewed users should be reintroduced to the user directory
+ if self._show_users_in_user_directory:
+ # Show the user in the directory again by setting them to active
+ yield self.profile_handler.set_active(
+ UserID.from_string(user_id), True, True
+ )
+
+ return expiration_ts
+
+ @defer.inlineCallbacks
+ def _mark_expired_users_as_inactive(self):
+ """Iterate over expired users. Mark them as inactive in order to hide them from the
+ user directory.
+
+ Returns:
+ Deferred
+ """
+ # Get expired users
+ expired_user_ids = yield self.store.get_expired_users()
+ expired_users = [UserID.from_string(user_id) for user_id in expired_user_ids]
+
+ # Mark each one as non-active
+ for user in expired_users:
+ yield self.profile_handler.set_active(user, False, True)
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index fbef2f3d38..46ac73106d 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -100,4 +100,4 @@ class AcmeHandler(object):
logger.exception("Failed saving!")
raise
- defer.returnValue(True)
+ return True
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index e8a651e231..2f22f56ca4 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -49,7 +49,7 @@ class AdminHandler(BaseHandler):
"devices": {"": {"sessions": [{"connections": connections}]}},
}
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_users(self):
@@ -61,7 +61,7 @@ class AdminHandler(BaseHandler):
"""
ret = yield self.store.get_users()
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_users_paginate(self, order, start, limit):
@@ -78,7 +78,7 @@ class AdminHandler(BaseHandler):
"""
ret = yield self.store.get_users_paginate(order, start, limit)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def search_users(self, term):
@@ -92,7 +92,7 @@ class AdminHandler(BaseHandler):
"""
ret = yield self.store.search_users(term)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def export_user_data(self, user_id, writer):
@@ -225,7 +225,7 @@ class AdminHandler(BaseHandler):
state = yield self.store.get_state_for_event(event_id)
writer.write_state(room_id, event_id, state)
- defer.returnValue(writer.finished())
+ return writer.finished()
class ExfiltrationWriter(object):
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 8f089f0e33..d1a51df6f9 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -167,8 +167,8 @@ class ApplicationServicesHandler(object):
for user_service in user_query_services:
is_known_user = yield self.appservice_api.query_user(user_service, user_id)
if is_known_user:
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
@defer.inlineCallbacks
def query_room_alias_exists(self, room_alias):
@@ -192,7 +192,7 @@ class ApplicationServicesHandler(object):
if is_known_alias:
# the alias exists now so don't query more ASes.
result = yield self.store.get_association_from_room_alias(room_alias)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
@@ -215,7 +215,7 @@ class ApplicationServicesHandler(object):
if success:
ret.extend(result)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_3pe_protocols(self, only_protocol=None):
@@ -254,7 +254,7 @@ class ApplicationServicesHandler(object):
for p in protocols.keys():
protocols[p] = _merge_instances(protocols[p])
- defer.returnValue(protocols)
+ return protocols
@defer.inlineCallbacks
def _get_services_for_event(self, event):
@@ -276,7 +276,7 @@ class ApplicationServicesHandler(object):
if (yield s.is_interested(event, self.store)):
interested_list.append(s)
- defer.returnValue(interested_list)
+ return interested_list
def _get_services_for_user(self, user_id):
services = self.store.get_app_services()
@@ -293,23 +293,23 @@ class ApplicationServicesHandler(object):
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
- defer.returnValue(False)
+ return False
return
user_info = yield self.store.get_user_by_id(user_id)
if user_info:
- defer.returnValue(False)
+ return False
return
# user not found; could be the AS though, so check.
services = self.store.get_app_services()
service_list = [s for s in services if s.sender == user_id]
- defer.returnValue(len(service_list) == 0)
+ return len(service_list) == 0
@defer.inlineCallbacks
def _check_user_exists(self, user_id):
unknown_user = yield self._is_unknown_user(user_id)
if unknown_user:
exists = yield self.query_user_exists(user_id)
- defer.returnValue(exists)
- defer.returnValue(True)
+ return exists
+ return True
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d4d6574975..bf124032f1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -155,7 +155,7 @@ class AuthHandler(BaseHandler):
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
- defer.returnValue(params)
+ return params
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip, password_servlet=False):
@@ -280,7 +280,7 @@ class AuthHandler(BaseHandler):
creds,
list(clientdict),
)
- defer.returnValue((creds, clientdict, session["id"]))
+ return (creds, clientdict, session["id"])
ret = self._auth_dict_for_flows(flows, session)
ret["completed"] = list(creds)
@@ -307,8 +307,8 @@ class AuthHandler(BaseHandler):
if result:
creds[stagetype] = result
self._save_session(sess)
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
def get_session_id(self, clientdict):
"""
@@ -379,7 +379,7 @@ class AuthHandler(BaseHandler):
res = yield checker(
authdict, clientip=clientip, password_servlet=password_servlet
)
- defer.returnValue(res)
+ return res
# build a v1-login-style dict out of the authdict and fall back to the
# v1 code
@@ -389,7 +389,7 @@ class AuthHandler(BaseHandler):
raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
- defer.returnValue(canonical_id)
+ return canonical_id
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip, **kwargs):
@@ -409,7 +409,7 @@ class AuthHandler(BaseHandler):
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
- client = self.hs.get_simple_http_client()
+ client = self.hs.get_proxied_http_client()
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
@@ -433,7 +433,7 @@ class AuthHandler(BaseHandler):
resp_body.get("hostname"),
)
if resp_body["success"]:
- defer.returnValue(True)
+ return True
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
def _check_email_identity(self, authdict, **kwargs):
@@ -502,7 +502,7 @@ class AuthHandler(BaseHandler):
threepid["threepid_creds"] = authdict["threepid_creds"]
- defer.returnValue(threepid)
+ return threepid
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
@@ -606,7 +606,7 @@ class AuthHandler(BaseHandler):
yield self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion")
- defer.returnValue(access_token)
+ return access_token
@defer.inlineCallbacks
def check_user_exists(self, user_id):
@@ -629,8 +629,8 @@ class AuthHandler(BaseHandler):
self.ratelimit_login_per_account(user_id)
res = yield self._find_user_id_and_pwd_hash(user_id)
if res is not None:
- defer.returnValue(res[0])
- defer.returnValue(None)
+ return res[0]
+ return None
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
@@ -661,7 +661,7 @@ class AuthHandler(BaseHandler):
user_id,
user_infos.keys(),
)
- defer.returnValue(result)
+ return result
def get_supported_login_types(self):
"""Get a the login types supported for the /login API
@@ -722,7 +722,7 @@ class AuthHandler(BaseHandler):
known_login_type = True
is_valid = yield provider.check_password(qualified_user_id, password)
if is_valid:
- defer.returnValue((qualified_user_id, None))
+ return (qualified_user_id, None)
if not hasattr(provider, "get_supported_login_types") or not hasattr(
provider, "check_auth"
@@ -756,7 +756,7 @@ class AuthHandler(BaseHandler):
if result:
if isinstance(result, str):
result = (result, None)
- defer.returnValue(result)
+ return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
@@ -766,7 +766,7 @@ class AuthHandler(BaseHandler):
)
if canonical_user_id:
- defer.returnValue((canonical_user_id, None))
+ return (canonical_user_id, None)
if not known_login_type:
raise SynapseError(400, "Unknown login type %s" % login_type)
@@ -814,9 +814,9 @@ class AuthHandler(BaseHandler):
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
- defer.returnValue(result)
+ return result
- defer.returnValue((None, None))
+ return (None, None)
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
@@ -838,7 +838,7 @@ class AuthHandler(BaseHandler):
"""
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
- defer.returnValue(None)
+ return None
(user_id, password_hash) = lookupres
# If the password hash is None, the account has likely been deactivated
@@ -850,8 +850,8 @@ class AuthHandler(BaseHandler):
result = yield self.validate_hash(password, password_hash)
if not result:
logger.warn("Failed password login for user %s", user_id)
- defer.returnValue(None)
- defer.returnValue(user_id)
+ return None
+ return user_id
@defer.inlineCallbacks
def validate_short_term_login_token_and_get_user_id(self, login_token):
@@ -865,7 +865,7 @@ class AuthHandler(BaseHandler):
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
self.ratelimit_login_per_account(user_id)
yield self.auth.check_auth_blocking(user_id)
- defer.returnValue(user_id)
+ return user_id
@defer.inlineCallbacks
def delete_access_token(self, access_token):
@@ -976,7 +976,7 @@ class AuthHandler(BaseHandler):
)
yield self.store.user_delete_threepid(user_id, medium, address)
- defer.returnValue(result)
+ return result
def _save_session(self, session):
# TODO: Persistent storage
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index e8f9da6098..ad00dcecfd 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -35,6 +35,7 @@ class DeactivateAccountHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_handlers().identity_handler
+ self._profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
# Flag that indicates whether the process to part users from rooms is running
@@ -102,6 +103,9 @@ class DeactivateAccountHandler(BaseHandler):
yield self.store.user_set_password_hash(user_id, None)
+ user = UserID.from_string(user_id)
+ yield self._profile_handler.set_active(user, False, False)
+
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
yield self.store.add_user_pending_deactivation(user_id)
@@ -118,6 +122,10 @@ class DeactivateAccountHandler(BaseHandler):
# parts users from rooms (if it isn't already running)
self._start_user_parting()
+ # Reject all pending invites for the user, so that the user doesn't show up in the
+ # "invited" section of rooms' members list.
+ yield self._reject_pending_invites_for_user(user_id)
+
# Remove all information on the user from the account_validity table.
if self._account_validity_enabled:
yield self.store.delete_account_validity_for_user(user_id)
@@ -125,7 +133,40 @@ class DeactivateAccountHandler(BaseHandler):
# Mark the user as deactivated.
yield self.store.set_user_deactivated_status(user_id, True)
- defer.returnValue(identity_server_supports_unbinding)
+ return identity_server_supports_unbinding
+
+ @defer.inlineCallbacks
+ def _reject_pending_invites_for_user(self, user_id):
+ """Reject pending invites addressed to a given user ID.
+
+ Args:
+ user_id (str): The user ID to reject pending invites for.
+ """
+ user = UserID.from_string(user_id)
+ pending_invites = yield self.store.get_invited_rooms_for_user(user_id)
+
+ for room in pending_invites:
+ try:
+ yield self._room_member_handler.update_membership(
+ create_requester(user),
+ user,
+ room.room_id,
+ "leave",
+ ratelimit=False,
+ require_consent=False,
+ )
+ logger.info(
+ "Rejected invite for deactivated user %r in room %r",
+ user_id,
+ room.room_id,
+ )
+ except Exception:
+ logger.exception(
+ "Failed to reject invite for user %r in room %r:"
+ " ignoring and continuing",
+ user_id,
+ room.room_id,
+ )
def _start_user_parting(self):
"""
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 99e8413092..d36dd850fd 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -64,7 +64,7 @@ class DeviceWorkerHandler(BaseHandler):
for device in devices:
_update_device_from_client_ips(device, ips)
- defer.returnValue(devices)
+ return devices
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
@@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler):
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
- defer.returnValue(device)
+ return device
@measure_func("device.get_user_ids_changed")
@defer.inlineCallbacks
@@ -200,9 +200,7 @@ class DeviceWorkerHandler(BaseHandler):
possibly_joined = []
possibly_left = []
- defer.returnValue(
- {"changed": list(possibly_joined), "left": list(possibly_left)}
- )
+ return {"changed": list(possibly_joined), "left": list(possibly_left)}
class DeviceHandler(DeviceWorkerHandler):
@@ -250,7 +248,7 @@ class DeviceHandler(DeviceWorkerHandler):
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
- defer.returnValue(device_id)
+ return device_id
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
@@ -264,7 +262,7 @@ class DeviceHandler(DeviceWorkerHandler):
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
- defer.returnValue(device_id)
+ return device_id
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@@ -411,9 +409,7 @@ class DeviceHandler(DeviceWorkerHandler):
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- defer.returnValue(
- {"user_id": user_id, "stream_id": stream_id, "devices": devices}
- )
+ return {"user_id": user_id, "stream_id": stream_id, "devices": devices}
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@@ -556,6 +552,14 @@ class DeviceListEduUpdater(object):
stream_id = result["stream_id"]
devices = result["devices"]
+ for device in devices:
+ logger.debug(
+ "Handling resync update %r/%r, ID: %r",
+ user_id,
+ device["device_id"],
+ stream_id,
+ )
+
# If the remote server has more than ~1000 devices for this user
# we assume that something is going horribly wrong (e.g. a bot
# that logs in and creates a new device every time it tries to
@@ -623,7 +627,7 @@ class DeviceListEduUpdater(object):
for _, stream_id, prev_ids, _ in updates:
if not prev_ids:
# We always do a resync if there are no previous IDs
- defer.returnValue(True)
+ return True
for prev_id in prev_ids:
if prev_id == extremity:
@@ -633,8 +637,8 @@ class DeviceListEduUpdater(object):
elif prev_id in stream_id_in_updates:
continue
else:
- defer.returnValue(True)
+ return True
stream_id_in_updates.add(stream_id)
- defer.returnValue(False)
+ return False
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 42d5b3db30..0fd423197c 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -210,7 +210,7 @@ class DirectoryHandler(BaseHandler):
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
- defer.returnValue(room_id)
+ return room_id
@defer.inlineCallbacks
def delete_appservice_association(self, service, room_alias):
@@ -229,7 +229,7 @@ class DirectoryHandler(BaseHandler):
room_id = yield self.store.delete_room_alias(room_alias)
- defer.returnValue(room_id)
+ return room_id
@defer.inlineCallbacks
def get_association(self, room_alias):
@@ -277,7 +277,7 @@ class DirectoryHandler(BaseHandler):
else:
servers = list(servers)
- defer.returnValue({"room_id": room_id, "servers": servers})
+ return {"room_id": room_id, "servers": servers}
return
@defer.inlineCallbacks
@@ -289,7 +289,7 @@ class DirectoryHandler(BaseHandler):
result = yield self.get_association_from_room_alias(room_alias)
if result is not None:
- defer.returnValue({"room_id": result.room_id, "servers": result.servers})
+ return {"room_id": result.room_id, "servers": result.servers}
else:
raise SynapseError(
404,
@@ -342,7 +342,7 @@ class DirectoryHandler(BaseHandler):
# Query AS to see if it exists
as_handler = self.appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias)
- defer.returnValue(result)
+ return result
def can_modify_alias(self, alias, user_id=None):
# Any application service "interested" in an alias they are regexing on
@@ -369,10 +369,10 @@ class DirectoryHandler(BaseHandler):
creator = yield self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id:
- defer.returnValue(True)
+ return True
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
- defer.returnValue(is_admin)
+ return is_admin
@defer.inlineCallbacks
def edit_published_room_list(self, requester, room_id, visibility):
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index fdfe8611b6..1300b540e3 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -144,7 +144,7 @@ class E2eKeysHandler(object):
)
)
- defer.returnValue({"device_keys": results, "failures": failures})
+ return {"device_keys": results, "failures": failures}
@defer.inlineCallbacks
def query_local_devices(self, query):
@@ -189,7 +189,7 @@ class E2eKeysHandler(object):
r["unsigned"]["device_display_name"] = display_name
result_dict[user_id][device_id] = r
- defer.returnValue(result_dict)
+ return result_dict
@defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
@@ -197,7 +197,7 @@ class E2eKeysHandler(object):
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
- defer.returnValue({"device_keys": res})
+ return {"device_keys": res}
@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
@@ -259,7 +259,7 @@ class E2eKeysHandler(object):
),
)
- defer.returnValue({"one_time_keys": json_result, "failures": failures})
+ return {"one_time_keys": json_result, "failures": failures}
@defer.inlineCallbacks
def upload_keys_for_user(self, user_id, device_id, keys):
@@ -297,7 +297,7 @@ class E2eKeysHandler(object):
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
- defer.returnValue({"one_time_key_counts": result})
+ return {"one_time_key_counts": result}
@defer.inlineCallbacks
def _upload_one_time_keys_for_user(
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index ebd807bca6..41b871fc59 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -84,7 +84,7 @@ class E2eRoomKeysHandler(object):
user_id, version, room_id, session_id
)
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
@@ -262,7 +262,7 @@ class E2eRoomKeysHandler(object):
new_version = yield self.store.create_e2e_room_keys_version(
user_id, version_info
)
- defer.returnValue(new_version)
+ return new_version
@defer.inlineCallbacks
def get_version_info(self, user_id, version=None):
@@ -292,7 +292,7 @@ class E2eRoomKeysHandler(object):
raise NotFoundError("Unknown backup version")
else:
raise
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def delete_version(self, user_id, version=None):
@@ -350,4 +350,4 @@ class E2eRoomKeysHandler(object):
user_id, version, version_info
)
- defer.returnValue({})
+ return {}
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 6a38328af3..2f1f10a9af 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -143,7 +143,7 @@ class EventStreamHandler(BaseHandler):
"end": tokens[1].to_string(),
}
- defer.returnValue(chunk)
+ return chunk
class EventHandler(BaseHandler):
@@ -166,7 +166,7 @@ class EventHandler(BaseHandler):
event = yield self.store.get_event(event_id, check_room_id=room_id)
if not event:
- defer.returnValue(None)
+ return None
return
users = yield self.store.get_users_in_room(event.room_id)
@@ -179,4 +179,4 @@ class EventHandler(BaseHandler):
if not filtered:
raise AuthError(403, "You don't have permission to access that event.")
- defer.returnValue(event)
+ return event
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 30b69af82c..319ee35d9a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -210,7 +210,7 @@ class FederationHandler(BaseHandler):
event_id,
origin,
)
- defer.returnValue(None)
+ return None
state = None
auth_chain = []
@@ -676,7 +676,7 @@ class FederationHandler(BaseHandler):
events = [e for e in events if e.event_id not in seen_events]
if not events:
- defer.returnValue([])
+ return []
event_map = {e.event_id: e for e in events}
@@ -838,7 +838,7 @@ class FederationHandler(BaseHandler):
# TODO: We can probably do something more clever here.
yield self._handle_new_event(dest, event, backfilled=True)
- defer.returnValue(events)
+ return events
@defer.inlineCallbacks
def maybe_backfill(self, room_id, current_depth):
@@ -894,7 +894,7 @@ class FederationHandler(BaseHandler):
)
if not filtered_extremities:
- defer.returnValue(False)
+ return False
# Check if we reached a point where we should start backfilling.
sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
@@ -965,7 +965,7 @@ class FederationHandler(BaseHandler):
# If this succeeded then we probably already have the
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
- defer.returnValue(True)
+ return True
except SynapseError as e:
logger.info("Failed to backfill from %s because %s", dom, e)
continue
@@ -985,11 +985,11 @@ class FederationHandler(BaseHandler):
logger.exception("Failed to backfill from %s because %s", dom, e)
continue
- defer.returnValue(False)
+ return False
success = yield try_backfill(likely_domains)
if success:
- defer.returnValue(True)
+ return True
# Huh, well *those* domains didn't work out. Lets try some domains
# from the time.
@@ -1031,11 +1031,11 @@ class FederationHandler(BaseHandler):
[dom for dom, _ in likely_domains if dom not in tried_domains]
)
if success:
- defer.returnValue(True)
+ return True
tried_domains.update(dom for dom, _ in likely_domains)
- defer.returnValue(False)
+ return False
def _sanity_check_event(self, ev):
"""
@@ -1082,7 +1082,7 @@ class FederationHandler(BaseHandler):
pdu=event,
)
- defer.returnValue(pdu)
+ return pdu
@defer.inlineCallbacks
def on_event_auth(self, event_id):
@@ -1090,7 +1090,7 @@ class FederationHandler(BaseHandler):
auth = yield self.store.get_auth_chain(
[auth_id for auth_id in event.auth_event_ids()], include_given=True
)
- defer.returnValue([e for e in auth])
+ return [e for e in auth]
@log_function
@defer.inlineCallbacks
@@ -1177,7 +1177,7 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue)
- defer.returnValue(True)
+ return True
@defer.inlineCallbacks
def _handle_queued_pdus(self, room_queue):
@@ -1264,7 +1264,7 @@ class FederationHandler(BaseHandler):
room_version, event, context, do_sig_check=False
)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
@log_function
@@ -1325,7 +1325,7 @@ class FederationHandler(BaseHandler):
state = yield self.store.get_events(list(prev_state_ids.values()))
- defer.returnValue({"state": list(state.values()), "auth_chain": auth_chain})
+ return {"state": list(state.values()), "auth_chain": auth_chain}
@defer.inlineCallbacks
def on_invite_request(self, origin, pdu):
@@ -1345,8 +1345,15 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
+ is_published = yield self.store.is_room_published(event.room_id)
+
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, event.room_id
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
+ published_room=is_published,
):
raise SynapseError(
403, "This user is not permitted to send invites to this server/user"
@@ -1381,7 +1388,7 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event)
yield self.persist_events_and_notify([(event, context)])
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
@@ -1406,7 +1413,7 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event)
yield self.persist_events_and_notify([(event, context)])
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def _make_and_verify_event(
@@ -1424,7 +1431,7 @@ class FederationHandler(BaseHandler):
assert event.user_id == user_id
assert event.state_key == user_id
assert event.room_id == room_id
- defer.returnValue((origin, event, format_ver))
+ return (origin, event, format_ver)
@defer.inlineCallbacks
@log_function
@@ -1484,7 +1491,7 @@ class FederationHandler(BaseHandler):
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
@log_function
@@ -1517,7 +1524,7 @@ class FederationHandler(BaseHandler):
event.signatures,
)
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id):
@@ -1545,9 +1552,9 @@ class FederationHandler(BaseHandler):
del results[(event.type, event.state_key)]
res = list(results.values())
- defer.returnValue(res)
+ return res
else:
- defer.returnValue([])
+ return []
@defer.inlineCallbacks
def get_state_ids_for_pdu(self, room_id, event_id):
@@ -1572,9 +1579,9 @@ class FederationHandler(BaseHandler):
else:
results.pop((event.type, event.state_key), None)
- defer.returnValue(list(results.values()))
+ return list(results.values())
else:
- defer.returnValue([])
+ return []
@defer.inlineCallbacks
@log_function
@@ -1587,7 +1594,7 @@ class FederationHandler(BaseHandler):
events = yield filter_events_for_server(self.store, origin, events)
- defer.returnValue(events)
+ return events
@defer.inlineCallbacks
@log_function
@@ -1617,9 +1624,9 @@ class FederationHandler(BaseHandler):
events = yield filter_events_for_server(self.store, origin, [event])
event = events[0]
- defer.returnValue(event)
+ return event
else:
- defer.returnValue(None)
+ return None
def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context)
@@ -1651,7 +1658,7 @@ class FederationHandler(BaseHandler):
self.store.remove_push_actions_from_staging, event.event_id
)
- defer.returnValue(context)
+ return context
@defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False):
@@ -1674,7 +1681,7 @@ class FederationHandler(BaseHandler):
auth_events=ev_info.get("auth_events"),
backfilled=backfilled,
)
- defer.returnValue(res)
+ return res
contexts = yield make_deferred_yieldable(
defer.gatherResults(
@@ -1833,7 +1840,7 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
yield self.maybe_kick_guest_users(event)
- defer.returnValue(context)
+ return context
@defer.inlineCallbacks
def _check_for_soft_fail(self, event, state, backfilled):
@@ -1952,7 +1959,7 @@ class FederationHandler(BaseHandler):
logger.debug("on_query_auth returning: %s", ret)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def on_get_missing_events(
@@ -1975,7 +1982,7 @@ class FederationHandler(BaseHandler):
self.store, origin, missing_events
)
- defer.returnValue(missing_events)
+ return missing_events
@defer.inlineCallbacks
@log_function
@@ -2451,16 +2458,14 @@ class FederationHandler(BaseHandler):
logger.debug("construct_auth_difference returning")
- defer.returnValue(
- {
- "auth_chain": local_auth,
- "rejects": {
- e.event_id: {"reason": reason_map[e.event_id], "proof": None}
- for e in base_remote_rejected
- },
- "missing": [e.event_id for e in missing_locals],
- }
- )
+ return {
+ "auth_chain": local_auth,
+ "rejects": {
+ e.event_id: {"reason": reason_map[e.event_id], "proof": None}
+ for e in base_remote_rejected
+ },
+ "missing": [e.event_id for e in missing_locals],
+ }
@defer.inlineCallbacks
@log_function
@@ -2505,7 +2510,7 @@ class FederationHandler(BaseHandler):
room_version, event_dict, event, context
)
- EventValidator().validate_new(event)
+ EventValidator().validate_new(event, self.config)
# We need to tell the transaction queue to send this out, even
# though the sender isn't a local user.
@@ -2563,7 +2568,7 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check_from_context(room_version, event, context)
+ yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
@@ -2592,7 +2597,12 @@ class FederationHandler(BaseHandler):
original_invite_id, allow_none=True
)
if original_invite:
- display_name = original_invite.content["display_name"]
+ # If the m.room.third_party_invite event's content is empty, it means the
+ # invite has been revoked. In this case, we don't have to raise an error here
+ # because the auth check will fail on the invite (because it's not able to
+ # fetch public keys from the m.room.third_party_invite event's content, which
+ # is empty).
+ display_name = original_invite.content.get("display_name")
event_dict["content"]["third_party_invite"]["display_name"] = display_name
else:
logger.info(
@@ -2607,8 +2617,8 @@ class FederationHandler(BaseHandler):
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
- EventValidator().validate_new(event)
- defer.returnValue((event, context))
+ EventValidator().validate_new(event, self.config)
+ return (event, context)
@defer.inlineCallbacks
def _check_signature(self, event, context):
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 7da63bb643..7b67c8ae0f 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -162,7 +162,7 @@ class GroupsLocalHandler(object):
res.setdefault("user", {})["is_publicised"] = is_publicised
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def create_group(self, group_id, user_id, content):
@@ -207,7 +207,7 @@ class GroupsLocalHandler(object):
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def get_users_in_group(self, group_id, requester_user_id):
@@ -217,7 +217,7 @@ class GroupsLocalHandler(object):
res = yield self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
- defer.returnValue(res)
+ return res
group_server_name = get_domain_from_id(group_id)
@@ -244,7 +244,7 @@ class GroupsLocalHandler(object):
res["chunk"] = valid_entries
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def join_group(self, group_id, user_id, content):
@@ -285,7 +285,7 @@ class GroupsLocalHandler(object):
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def accept_invite(self, group_id, user_id, content):
@@ -326,7 +326,7 @@ class GroupsLocalHandler(object):
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def invite(self, group_id, user_id, requester_user_id, config):
@@ -346,7 +346,7 @@ class GroupsLocalHandler(object):
content,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def on_invite(self, group_id, user_id, content):
@@ -377,7 +377,7 @@ class GroupsLocalHandler(object):
logger.warn("No profile for user %s: %s", user_id, e)
user_profile = {}
- defer.returnValue({"state": "invite", "user_profile": user_profile})
+ return {"state": "invite", "user_profile": user_profile}
@defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
@@ -406,7 +406,7 @@ class GroupsLocalHandler(object):
content,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def user_removed_from_group(self, group_id, user_id, content):
@@ -421,7 +421,7 @@ class GroupsLocalHandler(object):
@defer.inlineCallbacks
def get_joined_groups(self, user_id):
group_ids = yield self.store.get_joined_groups(user_id)
- defer.returnValue({"groups": group_ids})
+ return {"groups": group_ids}
@defer.inlineCallbacks
def get_publicised_groups_for_user(self, user_id):
@@ -433,14 +433,14 @@ class GroupsLocalHandler(object):
for app_service in self.store.get_app_services():
result.extend(app_service.get_groups_for_user(user_id))
- defer.returnValue({"groups": result})
+ return {"groups": result}
else:
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
get_domain_from_id(user_id), [user_id]
)
result = bulk_result.get("users", {}).get(user_id)
# TODO: Verify attestations
- defer.returnValue({"groups": result})
+ return {"groups": result}
@defer.inlineCallbacks
def bulk_get_publicised_groups(self, user_ids, proxy=True):
@@ -475,4 +475,4 @@ class GroupsLocalHandler(object):
for app_service in self.store.get_app_services():
results[uid].extend(app_service.get_groups_for_user(uid))
- defer.returnValue({"users": results})
+ return {"users": results}
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 546d6169e9..339e0dd04d 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,13 +20,18 @@
import logging
from canonicaljson import json
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import verify_signed_json
+from unpaddedbase64 import decode_base64
from twisted.internet import defer
from synapse.api.errors import (
+ AuthError,
CodeMessageException,
Codes,
HttpResponseException,
+ ProxiedRequestError,
SynapseError,
)
@@ -46,6 +51,8 @@ class IdentityHandler(BaseHandler):
self.trust_any_id_server_just_for_testing_do_not_use = (
hs.config.use_insecure_ssl_client_just_for_testing_do_not_use
)
+ self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
+ self._enable_lookup = hs.config.enable_3pid_lookup
def _should_trust_id_server(self, id_server):
if id_server not in self.trusted_id_servers:
@@ -82,8 +89,11 @@ class IdentityHandler(BaseHandler):
"%s is not a trusted ID server: rejecting 3pid " + "credentials",
id_server,
)
- defer.returnValue(None)
-
+ return None
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
try:
data = yield self.http_client.get_json(
"https://%s%s"
@@ -95,8 +105,8 @@ class IdentityHandler(BaseHandler):
raise e.to_synapse_error()
if "medium" in data:
- defer.returnValue(data)
- defer.returnValue(None)
+ return data
+ return None
@defer.inlineCallbacks
def bind_threepid(self, creds, mxid):
@@ -117,9 +127,17 @@ class IdentityHandler(BaseHandler):
else:
raise SynapseError(400, "No client_secret in creds")
+ # if we have a rewrite rule set for the identity server,
+ # apply it now, but only for sending the request (not
+ # storing in the database).
+ if id_server in self.rewrite_identity_server_urls:
+ id_server_host = self.rewrite_identity_server_urls[id_server]
+ else:
+ id_server_host = id_server
+
try:
data = yield self.http_client.post_json_get_json(
- "https://%s%s" % (id_server, "/_matrix/identity/api/v1/3pid/bind"),
+ "https://%s%s" % (id_server_host, "/_matrix/identity/api/v1/3pid/bind"),
{"sid": creds["sid"], "client_secret": client_secret, "mxid": mxid},
)
logger.debug("bound threepid %r to %s", creds, mxid)
@@ -133,7 +151,7 @@ class IdentityHandler(BaseHandler):
)
except CodeMessageException as e:
data = json.loads(e.msg) # XXX WAT?
- defer.returnValue(data)
+ return data
@defer.inlineCallbacks
def try_unbind_threepid(self, mxid, threepid):
@@ -161,7 +179,7 @@ class IdentityHandler(BaseHandler):
# We don't know where to unbind, so we don't have a choice but to return
if not id_servers:
- defer.returnValue(False)
+ return False
changed = True
for id_server in id_servers:
@@ -169,7 +187,7 @@ class IdentityHandler(BaseHandler):
mxid, threepid, id_server
)
- defer.returnValue(changed)
+ return changed
@defer.inlineCallbacks
def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
@@ -205,6 +223,16 @@ class IdentityHandler(BaseHandler):
)
headers = {b"Authorization": auth_headers}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ #
+ # Note that destination_is has to be the real id_server, not
+ # the server we connect to.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
+
+ url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
+
try:
yield self.http_client.post_json_get_json(url, content, headers)
changed = True
@@ -224,7 +252,7 @@ class IdentityHandler(BaseHandler):
id_server=id_server,
)
- defer.returnValue(changed)
+ return changed
@defer.inlineCallbacks
def requestEmailToken(
@@ -241,6 +269,11 @@ class IdentityHandler(BaseHandler):
"send_attempt": send_attempt,
}
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
+
if next_link:
params.update({"next_link": next_link})
@@ -250,7 +283,7 @@ class IdentityHandler(BaseHandler):
% (id_server, "/_matrix/identity/api/v1/validate/email/requestToken"),
params,
)
- defer.returnValue(data)
+ return data
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
@@ -271,14 +304,134 @@ class IdentityHandler(BaseHandler):
"send_attempt": send_attempt,
}
params.update(kwargs)
-
+ # if we have a rewrite rule set for the identity server,
+ # apply it now.
+ if id_server in self.rewrite_identity_server_urls:
+ id_server = self.rewrite_identity_server_urls[id_server]
try:
data = yield self.http_client.post_json_get_json(
"https://%s%s"
% (id_server, "/_matrix/identity/api/v1/validate/msisdn/requestToken"),
params,
)
- defer.returnValue(data)
+ return data
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
+
+ @defer.inlineCallbacks
+ def lookup_3pid(self, id_server, medium, address):
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ Deferred[dict]: The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._should_trust_id_server(id_server):
+ raise SynapseError(
+ 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
+ )
+
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ target = self.rewrite_identity_server_urls.get(id_server, id_server)
+
+ try:
+ data = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/lookup" % (target,),
+ {"medium": medium, "address": address},
+ )
+
+ if "mxid" in data:
+ if "signatures" not in data:
+ raise AuthError(401, "No signatures on 3pid binding")
+ yield self._verify_any_signature(data, id_server)
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %r: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ defer.returnValue(data)
+
+ @defer.inlineCallbacks
+ def bulk_lookup_3pid(self, id_server, threepids):
+ """Looks up given 3pids in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ threepids ([[str, str]]): The third party identifiers to lookup, as
+ a list of 2-string sized lists ([medium, address]).
+
+ Returns:
+ Deferred[dict]: The result of the lookup. See
+ https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
+ for details
+ """
+ if not self._should_trust_id_server(id_server):
+ raise SynapseError(
+ 400, "Untrusted ID server '%s'" % id_server, Codes.SERVER_NOT_TRUSTED
+ )
+
+ if not self._enable_lookup:
+ raise AuthError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ target = self.rewrite_identity_server_urls.get(id_server, id_server)
+
+ try:
+ data = yield self.http_client.post_json_get_json(
+ "https://%s/_matrix/identity/api/v1/bulk_lookup" % (target,),
+ {"threepids": threepids},
+ )
+
+ except HttpResponseException as e:
+ logger.info("Proxied lookup failed: %r", e)
+ raise e.to_synapse_error()
+ except IOError as e:
+ logger.info("Failed to contact %r: %s", id_server, e)
+ raise ProxiedRequestError(503, "Failed to contact identity server")
+
+ defer.returnValue(data)
+
+ @defer.inlineCallbacks
+ def _verify_any_signature(self, data, server_hostname):
+ if server_hostname not in data["signatures"]:
+ raise AuthError(401, "No signature from server %s" % (server_hostname,))
+
+ for key_name, signature in data["signatures"][server_hostname].items():
+ target = self.rewrite_identity_server_urls.get(
+ server_hostname, server_hostname
+ )
+
+ key_data = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/pubkey/%s" % (target, key_name)
+ )
+ if "public_key" not in key_data:
+ raise AuthError(
+ 401, "No public key named %s from %s" % (key_name, server_hostname)
+ )
+ verify_signed_json(
+ data,
+ server_hostname,
+ decode_verify_key_bytes(
+ key_name, decode_base64(key_data["public_key"])
+ ),
+ )
+ return
+
+ raise AuthError(401, "No signature from server %s" % (server_hostname,))
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 54c966c8a6..42d6650ed9 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -250,7 +250,7 @@ class InitialSyncHandler(BaseHandler):
"end": now_token.to_string(),
}
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def room_initial_sync(self, requester, room_id, pagin_config=None):
@@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler):
result["account_data"] = account_data_events
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _room_initial_sync_parted(
@@ -330,28 +330,24 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
- defer.returnValue(
- {
- "membership": membership,
- "room_id": room_id,
- "messages": {
- "chunk": (
- yield self._event_serializer.serialize_events(
- messages, time_now
- )
- ),
- "start": start_token.to_string(),
- "end": end_token.to_string(),
- },
- "state": (
- yield self._event_serializer.serialize_events(
- room_state.values(), time_now
- )
+ return {
+ "membership": membership,
+ "room_id": room_id,
+ "messages": {
+ "chunk": (
+ yield self._event_serializer.serialize_events(messages, time_now)
),
- "presence": [],
- "receipts": [],
- }
- )
+ "start": start_token.to_string(),
+ "end": end_token.to_string(),
+ },
+ "state": (
+ yield self._event_serializer.serialize_events(
+ room_state.values(), time_now
+ )
+ ),
+ "presence": [],
+ "receipts": [],
+ }
@defer.inlineCallbacks
def _room_initial_sync_joined(
@@ -384,13 +380,13 @@ class InitialSyncHandler(BaseHandler):
def get_presence():
# If presence is disabled, return an empty list
if not self.hs.config.use_presence:
- defer.returnValue([])
+ return []
states = yield presence_handler.get_states(
[m.user_id for m in room_members], as_event=True
)
- defer.returnValue(states)
+ return states
@defer.inlineCallbacks
def get_receipts():
@@ -399,7 +395,7 @@ class InitialSyncHandler(BaseHandler):
)
if not receipts:
receipts = []
- defer.returnValue(receipts)
+ return receipts
presence, receipts, (messages, token) = yield make_deferred_yieldable(
defer.gatherResults(
@@ -442,7 +438,7 @@ class InitialSyncHandler(BaseHandler):
if not is_peeking:
ret["membership"] = membership
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id):
@@ -453,7 +449,7 @@ class InitialSyncHandler(BaseHandler):
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
- defer.returnValue((member_event.membership, member_event.event_id))
+ return (member_event.membership, member_event.event_id)
return
except AuthError:
visibility = yield self.state_handler.get_current_state(
@@ -463,7 +459,7 @@ class InitialSyncHandler(BaseHandler):
visibility
and visibility.content["history_visibility"] == "world_readable"
):
- defer.returnValue((Membership.JOIN, None))
+ return (Membership.JOIN, None)
return
raise AuthError(
403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 6d7a987f13..2e1a989782 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -87,7 +87,7 @@ class MessageHandler(object):
)
data = room_state[membership_event_id].get(key)
- defer.returnValue(data)
+ return data
@defer.inlineCallbacks
def get_state_events(
@@ -135,7 +135,7 @@ class MessageHandler(object):
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
- self.store, user_id, last_events
+ self.store, user_id, last_events, apply_retention_policies=False
)
event = last_events[0]
@@ -174,7 +174,7 @@ class MessageHandler(object):
# events, as clients won't use them.
bundle_aggregations=False,
)
- defer.returnValue(events)
+ return events
@defer.inlineCallbacks
def get_joined_members(self, requester, room_id):
@@ -213,15 +213,13 @@ class MessageHandler(object):
# Loop fell through, AS has no interested users in room
raise AuthError(403, "Appservice not in room")
- defer.returnValue(
- {
- user_id: {
- "avatar_url": profile.avatar_url,
- "display_name": profile.display_name,
- }
- for user_id, profile in iteritems(users_with_profile)
+ return {
+ user_id: {
+ "avatar_url": profile.avatar_url,
+ "display_name": profile.display_name,
}
- )
+ for user_id, profile in iteritems(users_with_profile)
+ }
class EventCreationHandler(object):
@@ -380,7 +378,11 @@ class EventCreationHandler(object):
# tolerate them in event_auth.check().
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ prev_event = (
+ yield self.store.get_event(prev_event_id, allow_none=True)
+ if prev_event_id
+ else None
+ )
if not prev_event or prev_event.membership != Membership.JOIN:
logger.warning(
(
@@ -396,9 +398,9 @@ class EventCreationHandler(object):
403, "You must be in the room to create an alias for it"
)
- self.validator.validate_new(event)
+ self.validator.validate_new(event, self.config)
- defer.returnValue((event, context))
+ return (event, context)
def _is_exempt_from_privacy_policy(self, builder, requester):
""""Determine if an event to be sent is exempt from having to consent
@@ -425,9 +427,9 @@ class EventCreationHandler(object):
@defer.inlineCallbacks
def _is_server_notices_room(self, room_id):
if self.config.server_notices_mxid is None:
- defer.returnValue(False)
+ return False
user_ids = yield self.store.get_users_in_room(room_id)
- defer.returnValue(self.config.server_notices_mxid in user_ids)
+ return self.config.server_notices_mxid in user_ids
@defer.inlineCallbacks
def assert_accepted_privacy_policy(self, requester):
@@ -507,7 +509,7 @@ class EventCreationHandler(object):
event.event_id,
prev_state.event_id,
)
- defer.returnValue(prev_state)
+ return prev_state
yield self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
@@ -523,6 +525,8 @@ class EventCreationHandler(object):
"""
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_event_id = prev_state_ids.get((event.type, event.state_key))
+ if not prev_event_id:
+ return
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
@@ -531,7 +535,7 @@ class EventCreationHandler(object):
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
- defer.returnValue(prev_event)
+ return prev_event
return
@defer.inlineCallbacks
@@ -563,7 +567,7 @@ class EventCreationHandler(object):
yield self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
)
- defer.returnValue(event)
+ return event
@measure_func("create_new_client_event")
@defer.inlineCallbacks
@@ -608,7 +612,7 @@ class EventCreationHandler(object):
if requester:
context.app_service = requester.app_service
- self.validator.validate_new(event)
+ self.validator.validate_new(event, self.config)
# If this event is an annotation then we check that that the sender
# can't annotate the same way twice (e.g. stops users from liking an
@@ -626,7 +630,7 @@ class EventCreationHandler(object):
logger.debug("Created event %s", event.event_id)
- defer.returnValue((event, context))
+ return (event, context)
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 20bcfed334..6711ced51a 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -15,12 +15,15 @@
# limitations under the License.
import logging
+from six import iteritems
+
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
@@ -77,6 +80,111 @@ class PaginationHandler(object):
self._purges_by_id = {}
self._event_serializer = hs.get_event_client_serializer()
+ self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
+
+ if hs.config.retention_enabled:
+ # Run the purge jobs described in the configuration file.
+ for job in hs.config.retention_purge_jobs:
+ self.clock.looping_call(
+ run_as_background_process,
+ job["interval"],
+ "purge_history_for_rooms_in_range",
+ self.purge_history_for_rooms_in_range,
+ job["shortest_max_lifetime"],
+ job["longest_max_lifetime"],
+ )
+
+ @defer.inlineCallbacks
+ def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+ """Purge outdated events from rooms within the given retention range.
+
+ If a default retention policy is defined in the server's configuration and its
+ 'max_lifetime' is within this range, also targets rooms which don't have a
+ retention policy.
+
+ Args:
+ min_ms (int|None): Duration in milliseconds that define the lower limit of
+ the range to handle (exclusive). If None, it means that the range has no
+ lower limit.
+ max_ms (int|None): Duration in milliseconds that define the upper limit of
+ the range to handle (inclusive). If None, it means that the range has no
+ upper limit.
+ """
+ # We want the storage layer to to include rooms with no retention policy in its
+ # return value only if a default retention policy is defined in the server's
+ # configuration and that policy's 'max_lifetime' is either lower (or equal) than
+ # max_ms or higher than min_ms (or both).
+ if self._retention_default_max_lifetime is not None:
+ include_null = True
+
+ if min_ms is not None and min_ms >= self._retention_default_max_lifetime:
+ # The default max_lifetime is lower than (or equal to) min_ms.
+ include_null = False
+
+ if max_ms is not None and max_ms < self._retention_default_max_lifetime:
+ # The default max_lifetime is higher than max_ms.
+ include_null = False
+ else:
+ include_null = False
+
+ rooms = yield self.store.get_rooms_for_retention_period_in_range(
+ min_ms, max_ms, include_null
+ )
+
+ for room_id, retention_policy in iteritems(rooms):
+ if room_id in self._purges_in_progress_by_room:
+ logger.warning(
+ "[purge] not purging room %s as there's an ongoing purge running"
+ " for this room",
+ room_id,
+ )
+ continue
+
+ max_lifetime = retention_policy["max_lifetime"]
+
+ if max_lifetime is None:
+ # If max_lifetime is None, it means that include_null equals True,
+ # therefore we can safely assume that there is a default policy defined
+ # in the server's configuration.
+ max_lifetime = self._retention_default_max_lifetime
+
+ # Figure out what token we should start purging at.
+ ts = self.clock.time_msec() - max_lifetime
+
+ stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts))
+
+ r = (
+ yield self.store.get_room_event_after_stream_ordering(
+ room_id, stream_ordering
+ )
+ )
+ if not r:
+ logger.warning(
+ "[purge] purging events not possible: No event found "
+ "(ts %i => stream_ordering %i)",
+ ts,
+ stream_ordering,
+ )
+ continue
+
+ (stream, topo, _event_id) = r
+ token = "t%d-%d" % (topo, stream)
+
+ purge_id = random_string(16)
+
+ self._purges_by_id[purge_id] = PurgeStatus()
+
+ logger.info(
+ "Starting purging events in room %s (purge_id %s)" % (room_id, purge_id)
+ )
+
+ # We want to purge everything, including local events, and to run the purge in
+ # the background so that it's not blocking any other operation apart from
+ # other purges in the same room.
+ run_as_background_process(
+ "_purge_history", self._purge_history, purge_id, room_id, token, True
+ )
+
def start_purge_history(self, room_id, token, delete_local_events=False):
"""Start off a history purge on a room.
@@ -242,13 +350,11 @@ class PaginationHandler(object):
)
if not events:
- defer.returnValue(
- {
- "chunk": [],
- "start": pagin_config.from_token.to_string(),
- "end": next_token.to_string(),
- }
- )
+ return {
+ "chunk": [],
+ "start": pagin_config.from_token.to_string(),
+ "end": next_token.to_string(),
+ }
state = None
if event_filter and event_filter.lazy_load_members() and len(events) > 0:
@@ -286,4 +392,4 @@ class PaginationHandler(object):
)
)
- defer.returnValue(chunk)
+ return chunk
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
new file mode 100644
index 0000000000..d06b110269
--- /dev/null
+++ b/synapse/handlers/password_policy.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from synapse.api.errors import Codes, PasswordRefusedError
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyHandler(object):
+ def __init__(self, hs):
+ self.policy = hs.config.password_policy
+ self.enabled = hs.config.password_policy_enabled
+
+ # Regexps for the spec'd policy parameters.
+ self.regexp_digit = re.compile("[0-9]")
+ self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
+ self.regexp_uppercase = re.compile("[A-Z]")
+ self.regexp_lowercase = re.compile("[a-z]")
+
+ def validate_password(self, password):
+ """Checks whether a given password complies with the server's policy.
+
+ Args:
+ password (str): The password to check against the server's policy.
+
+ Raises:
+ PasswordRefusedError: The password doesn't comply with the server's policy.
+ """
+
+ if not self.enabled:
+ return
+
+ minimum_accepted_length = self.policy.get("minimum_length", 0)
+ if len(password) < minimum_accepted_length:
+ raise PasswordRefusedError(
+ msg=(
+ "The password must be at least %d characters long"
+ % minimum_accepted_length
+ ),
+ errcode=Codes.PASSWORD_TOO_SHORT,
+ )
+
+ if (
+ self.policy.get("require_digit", False)
+ and self.regexp_digit.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one digit",
+ errcode=Codes.PASSWORD_NO_DIGIT,
+ )
+
+ if (
+ self.policy.get("require_symbol", False)
+ and self.regexp_symbol.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one symbol",
+ errcode=Codes.PASSWORD_NO_SYMBOL,
+ )
+
+ if (
+ self.policy.get("require_uppercase", False)
+ and self.regexp_uppercase.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one uppercase letter",
+ errcode=Codes.PASSWORD_NO_UPPERCASE,
+ )
+
+ if (
+ self.policy.get("require_lowercase", False)
+ and self.regexp_lowercase.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one lowercase letter",
+ errcode=Codes.PASSWORD_NO_LOWERCASE,
+ )
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 6f3537e435..ea54d0b991 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -461,7 +461,7 @@ class PresenceHandler(object):
if affect_presence:
run_in_background(_end)
- defer.returnValue(_user_syncing())
+ return _user_syncing()
def get_currently_syncing_users(self):
"""Get the set of user ids that are currently syncing on this HS.
@@ -556,7 +556,7 @@ class PresenceHandler(object):
"""Get the current presence state for a user.
"""
res = yield self.current_state_for_users([user_id])
- defer.returnValue(res[user_id])
+ return res[user_id]
@defer.inlineCallbacks
def current_state_for_users(self, user_ids):
@@ -585,7 +585,7 @@ class PresenceHandler(object):
states.update(new)
self.user_to_current_state.update(new)
- defer.returnValue(states)
+ return states
@defer.inlineCallbacks
def _persist_and_notify(self, states):
@@ -681,7 +681,7 @@ class PresenceHandler(object):
def get_state(self, target_user, as_event=False):
results = yield self.get_states([target_user.to_string()], as_event=as_event)
- defer.returnValue(results[0])
+ return results[0]
@defer.inlineCallbacks
def get_states(self, target_user_ids, as_event=False):
@@ -703,17 +703,15 @@ class PresenceHandler(object):
now = self.clock.time_msec()
if as_event:
- defer.returnValue(
- [
- {
- "type": "m.presence",
- "content": format_user_presence_state(state, now),
- }
- for state in updates
- ]
- )
+ return [
+ {
+ "type": "m.presence",
+ "content": format_user_presence_state(state, now),
+ }
+ for state in updates
+ ]
else:
- defer.returnValue(updates)
+ return updates
@defer.inlineCallbacks
def set_state(self, target_user, state, ignore_status_msg=False):
@@ -757,9 +755,9 @@ class PresenceHandler(object):
)
if observer_room_ids & observed_room_ids:
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def get_all_presence_updates(self, last_id, current_id):
@@ -778,7 +776,7 @@ class PresenceHandler(object):
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
rows = yield self.store.get_all_presence_updates(last_id, current_id)
- defer.returnValue(rows)
+ return rows
def notify_new_event(self):
"""Called when new events have happened. Handles users and servers
@@ -1034,7 +1032,7 @@ class PresenceEventSource(object):
#
# Hence this guard where we just return nothing so that the sync
# doesn't return. C.f. #5503.
- defer.returnValue(([], max_token))
+ return ([], max_token)
presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache
@@ -1068,17 +1066,11 @@ class PresenceEventSource(object):
updates = yield presence.current_state_for_users(user_ids_changed)
if include_offline:
- defer.returnValue((list(updates.values()), max_token))
+ return (list(updates.values()), max_token)
else:
- defer.returnValue(
- (
- [
- s
- for s in itervalues(updates)
- if s.state != PresenceState.OFFLINE
- ],
- max_token,
- )
+ return (
+ [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE],
+ max_token,
)
def get_current_key(self):
@@ -1107,7 +1099,7 @@ class PresenceEventSource(object):
)
users_interested_in.update(user_ids)
- defer.returnValue(users_interested_in)
+ return users_interested_in
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
@@ -1287,7 +1279,7 @@ def get_interested_parties(store, states):
# Always notify self
users_to_states.setdefault(state.user_id, []).append(state)
- defer.returnValue((room_ids_to_states, users_to_states))
+ return (room_ids_to_states, users_to_states)
@defer.inlineCallbacks
@@ -1321,4 +1313,4 @@ def get_interested_remotes(store, states, state_handler):
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
- defer.returnValue(hosts_and_states)
+ return hosts_and_states
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index a2388a7091..136128b625 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +17,11 @@
import logging
from six import raise_from
+from six.moves import range
-from twisted.internet import defer
+from signedjson.sign import sign_json
+
+from twisted.internet import defer, reactor
from synapse.api.errors import (
AuthError,
@@ -27,6 +31,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, get_domain_from_id
@@ -46,6 +51,8 @@ class BaseProfileHandler(BaseHandler):
subclass MasterProfileHandler
"""
+ PROFILE_REPLICATE_INTERVAL = 2 * 60 * 1000
+
def __init__(self, hs):
super(BaseProfileHandler, self).__init__(hs)
@@ -56,6 +63,87 @@ class BaseProfileHandler(BaseHandler):
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
+
+ self.max_avatar_size = hs.config.max_avatar_size
+ self.allowed_avatar_mimetypes = hs.config.allowed_avatar_mimetypes
+
+ if hs.config.worker_app is None:
+ self.clock.looping_call(
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ reactor.callWhenRunning(self._assign_profile_replication_batches)
+ reactor.callWhenRunning(self._replicate_profiles)
+ # Add a looping call to replicate_profiles: this handles retries
+ # if the replication is unsuccessful when the user updated their
+ # profile.
+ self.clock.looping_call(
+ self._replicate_profiles, self.PROFILE_REPLICATE_INTERVAL
+ )
+
+ @defer.inlineCallbacks
+ def _assign_profile_replication_batches(self):
+ """If no profile replication has been done yet, allocate replication batch
+ numbers to each profile to start the replication process.
+ """
+ logger.info("Assigning profile batch numbers...")
+ total = 0
+ while True:
+ assigned = yield self.store.assign_profile_batch()
+ total += assigned
+ if assigned == 0:
+ break
+ logger.info("Assigned %d profile batch numbers", total)
+
+ @defer.inlineCallbacks
+ def _replicate_profiles(self):
+ """If any profile data has been updated and not pushed to the replication targets,
+ replicate it.
+ """
+ host_batches = yield self.store.get_replication_hosts()
+ latest_batch = yield self.store.get_latest_profile_replication_batch_number()
+ if latest_batch is None:
+ latest_batch = -1
+ for repl_host in self.hs.config.replicate_user_profiles_to:
+ if repl_host not in host_batches:
+ host_batches[repl_host] = -1
+ try:
+ for i in range(host_batches[repl_host] + 1, latest_batch + 1):
+ yield self._replicate_host_profile_batch(repl_host, i)
+ except Exception:
+ logger.exception(
+ "Exception while replicating to %s: aborting for now", repl_host
+ )
+
+ @defer.inlineCallbacks
+ def _replicate_host_profile_batch(self, host, batchnum):
+ logger.info("Replicating profile batch %d to %s", batchnum, host)
+ batch_rows = yield self.store.get_profile_batch(batchnum)
+ batch = {
+ UserID(r["user_id"], self.hs.hostname).to_string(): (
+ {"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
+ if r["active"]
+ else None
+ )
+ for r in batch_rows
+ }
+
+ url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
+ body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
+ signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
+ try:
+ yield self.http_client.post_json_get_json(url, signed_body)
+ yield self.store.update_replication_batch_for_host(host, batchnum)
+ logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
+ except Exception:
+ # This will get retried when the looping call next comes around
+ logger.exception(
+ "Failed to replicate profile batch %d to %s", batchnum, host
+ )
+ raise
+
@defer.inlineCallbacks
def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
@@ -73,7 +161,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue({"displayname": displayname, "avatar_url": avatar_url})
+ return {"displayname": displayname, "avatar_url": avatar_url}
else:
try:
result = yield self.federation.make_query(
@@ -82,7 +170,7 @@ class BaseProfileHandler(BaseHandler):
args={"user_id": user_id},
ignore_backoff=True,
)
- defer.returnValue(result)
+ return result
except RequestSendFailed as e:
raise_from(SynapseError(502, "Failed to fetch profile"), e)
except HttpResponseException as e:
@@ -108,10 +196,10 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue({"displayname": displayname, "avatar_url": avatar_url})
+ return {"displayname": displayname, "avatar_url": avatar_url}
else:
profile = yield self.store.get_from_remote_profile_cache(user_id)
- defer.returnValue(profile or {})
+ return profile or {}
@defer.inlineCallbacks
def get_displayname(self, target_user):
@@ -125,7 +213,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue(displayname)
+ return displayname
else:
try:
result = yield self.federation.make_query(
@@ -139,7 +227,7 @@ class BaseProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- defer.returnValue(result["displayname"])
+ return result["displayname"]
@defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
@@ -154,9 +242,16 @@ class BaseProfileHandler(BaseHandler):
if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this Home Server")
- if not by_admin and target_user != requester.user:
+ if not by_admin and requester and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
+ if not by_admin and self.hs.config.disable_set_displayname:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ if profile.display_name:
+ raise SynapseError(
+ 400, "Changing displayname is disabled on this server"
+ )
+
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -165,7 +260,17 @@ class BaseProfileHandler(BaseHandler):
if new_displayname == "":
new_displayname = None
- yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
+ yield self.store.set_profile_displayname(
+ target_user.localpart, new_displayname, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
@@ -173,7 +278,39 @@ class BaseProfileHandler(BaseHandler):
target_user.to_string(), profile
)
- yield self._update_join_states(requester, target_user)
+ if requester:
+ yield self._update_join_states(requester, target_user)
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ @defer.inlineCallbacks
+ def set_active(self, target_user, active, hide):
+ """
+ Sets the 'active' flag on a user profile. If set to false, the user
+ account is considered deactivated or hidden.
+
+ If 'hide' is true, then we interpret active=False as a request to try to
+ hide the user rather than deactivating it. This means withholding the
+ profile from replication (and mark it as inactive) rather than clearing
+ the profile from the HS DB. Note that unlike set_displayname and
+ set_avatar_url, this does *not* perform authorization checks! This is
+ because the only place it's used currently is in account deactivation
+ where we've already done these checks anyway.
+ """
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+ yield self.store.set_profile_active(
+ target_user.localpart, active, hide, new_batchnum
+ )
+
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@@ -186,7 +323,7 @@ class BaseProfileHandler(BaseHandler):
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue(avatar_url)
+ return avatar_url
else:
try:
result = yield self.federation.make_query(
@@ -200,7 +337,7 @@ class BaseProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
- defer.returnValue(result["avatar_url"])
+ return result["avatar_url"]
@defer.inlineCallbacks
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
@@ -212,12 +349,59 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
+ if not by_admin and self.hs.config.disable_set_avatar_url:
+ profile = yield self.store.get_profileinfo(target_user.localpart)
+ if profile.avatar_url:
+ raise SynapseError(
+ 400, "Changing avatar url is disabled on this server"
+ )
+
+ if len(self.hs.config.replicate_user_profiles_to) > 0:
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
+ new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
+ else:
+ new_batchnum = None
+
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
- yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ # Enforce a max avatar size if one is defined
+ if self.max_avatar_size or self.allowed_avatar_mimetypes:
+ media_id = self._validate_and_parse_media_id_from_avatar_url(new_avatar_url)
+
+ # Check that this media exists locally
+ media_info = yield self.store.get_local_media(media_id)
+ if not media_info:
+ raise SynapseError(
+ 400, "Unknown media id supplied", errcode=Codes.NOT_FOUND
+ )
+
+ # Ensure avatar does not exceed max allowed avatar size
+ media_size = media_info["media_length"]
+ if self.max_avatar_size and media_size > self.max_avatar_size:
+ raise SynapseError(
+ 400,
+ "Avatars must be less than %s bytes in size"
+ % (self.max_avatar_size,),
+ errcode=Codes.TOO_LARGE,
+ )
+
+ # Ensure the avatar's file type is allowed
+ if (
+ self.allowed_avatar_mimetypes
+ and media_info["media_type"] not in self.allowed_avatar_mimetypes
+ ):
+ raise SynapseError(
+ 400, "Avatar file type '%s' not allowed" % media_info["media_type"]
+ )
+
+ yield self.store.set_profile_avatar_url(
+ target_user.localpart, new_avatar_url, new_batchnum
+ )
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(target_user.localpart)
@@ -227,6 +411,23 @@ class BaseProfileHandler(BaseHandler):
yield self._update_join_states(requester, target_user)
+ # start a profile replication push
+ run_in_background(self._replicate_profiles)
+
+ def _validate_and_parse_media_id_from_avatar_url(self, mxc):
+ """Validate and parse a provided avatar url and return the local media id
+
+ Args:
+ mxc (str): A mxc URL
+
+ Returns:
+ str: The ID of the media
+ """
+ avatar_pieces = mxc.split("/")
+ if len(avatar_pieces) != 4 or avatar_pieces[0] != "mxc:":
+ raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
+ return avatar_pieces[-1]
+
@defer.inlineCallbacks
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
@@ -251,7 +452,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
def _update_join_states(self, requester, target_user):
@@ -282,7 +483,7 @@ class BaseProfileHandler(BaseHandler):
@defer.inlineCallbacks
def check_profile_query_allowed(self, target_user, requester=None):
"""Checks whether a profile query is allowed. If the
- 'require_auth_for_profile_requests' config flag is set to True and a
+ 'limit_profile_requests_to_known_users' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
share a room.
@@ -300,7 +501,11 @@ class BaseProfileHandler(BaseHandler):
# be None when this function is called outside of a profile query, e.g.
# when building a membership event. In this case, we must allow the
# lookup.
- if not self.hs.config.require_auth_for_profile_requests or not requester:
+ if not self.hs.config.limit_profile_requests_to_known_users or not requester:
+ return
+
+ # Always allow the user to query their own profile.
+ if target_user.to_string() == requester.to_string():
return
# Always allow the user to query their own profile.
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index e58bf7e360..73973502a4 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -93,7 +93,7 @@ class ReceiptsHandler(BaseHandler):
if min_batch_id is None:
# no new receipts
- defer.returnValue(False)
+ return False
affected_room_ids = list(set([r.room_id for r in receipts]))
@@ -103,7 +103,7 @@ class ReceiptsHandler(BaseHandler):
min_batch_id, max_batch_id, affected_room_ids
)
- defer.returnValue(True)
+ return True
@defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
@@ -133,9 +133,9 @@ class ReceiptsHandler(BaseHandler):
)
if not result:
- defer.returnValue([])
+ return []
- defer.returnValue(result)
+ return result
class ReceiptEventSource(object):
@@ -148,13 +148,13 @@ class ReceiptEventSource(object):
to_key = yield self.get_current_key()
if from_key == to_key:
- defer.returnValue(([], to_key))
+ return ([], to_key)
events = yield self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)
- defer.returnValue((events, to_key))
+ return (events, to_key)
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
@@ -173,4 +173,4 @@ class ReceiptEventSource(object):
room_ids, from_key=from_key, to_key=to_key
)
- defer.returnValue((events, to_key))
+ return (events, to_key)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index bb7cfd71b9..0daf193945 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -60,6 +60,7 @@ class RegistrationHandler(BaseHandler):
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
+ self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
@@ -72,6 +73,8 @@ class RegistrationHandler(BaseHandler):
)
self._server_notices_mxid = hs.config.server_notices_mxid
+ self._show_in_user_directory = self.hs.config.show_users_in_user_directory
+
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
self._register_device_client = RegisterDeviceReplicationServlet.make_client(
@@ -213,6 +216,11 @@ class RegistrationHandler(BaseHandler):
address=address,
)
+ if default_display_name:
+ yield self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
if self.hs.config.user_directory_search_all_users:
profile = yield self.store.get_profileinfo(localpart)
yield self.user_directory_handler.handle_local_profile_change(
@@ -238,6 +246,11 @@ class RegistrationHandler(BaseHandler):
create_profile_with_displayname=default_display_name,
address=address,
)
+
+ yield self.profile_handler.set_displayname(
+ user, None, default_display_name, by_admin=True
+ )
+
except SynapseError:
# if user id is taken, just generate another
user = None
@@ -265,7 +278,15 @@ class RegistrationHandler(BaseHandler):
# Bind email to new account
yield self._register_email_threepid(user_id, threepid_dict, None, False)
- defer.returnValue(user_id)
+ # Prevent the new user from showing up in the user directory if the server
+ # mandates it.
+ if not self._show_in_user_directory:
+ yield self.store.add_account_data_for_user(
+ user_id, "im.vector.hide_profile", {"hide_profile": True}
+ )
+ yield self.profile_handler.set_active(user, False, True)
+
+ return user_id
@defer.inlineCallbacks
def _auto_join_rooms(self, user_id):
@@ -335,7 +356,9 @@ class RegistrationHandler(BaseHandler):
yield self._auto_join_rooms(user_id)
@defer.inlineCallbacks
- def appservice_register(self, user_localpart, as_token):
+ def appservice_register(self, user_localpart, as_token, password, display_name):
+ # FIXME: this should be factored out and merged with normal register()
+
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -354,13 +377,30 @@ class RegistrationHandler(BaseHandler):
user_id, allowed_appservice=service
)
+ password_hash = ""
+ if password:
+ password_hash = yield self.auth_handler().hash(password)
+
+ display_name = display_name or user.localpart
+
yield self.register_with_store(
user_id=user_id,
- password_hash="",
+ password_hash=password_hash,
appservice_id=service_id,
- create_profile_with_displayname=user.localpart,
+ create_profile_with_displayname=display_name,
)
- defer.returnValue(user_id)
+
+ yield self.profile_handler.set_displayname(
+ user, None, display_name, by_admin=True
+ )
+
+ if self.hs.config.user_directory_search_all_users:
+ profile = yield self.store.get_profileinfo(user_localpart)
+ yield self.user_directory_handler.handle_local_profile_change(
+ user_id, profile
+ )
+
+ return user_id
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
@@ -384,6 +424,38 @@ class RegistrationHandler(BaseHandler):
logger.info("Valid captcha entered from %s", ip)
@defer.inlineCallbacks
+ def register_saml2(self, localpart):
+ """
+ Registers email_id as SAML2 Based Auth.
+ """
+ if types.contains_invalid_mxid_characters(localpart):
+ raise SynapseError(
+ 400, "User ID can only contain characters a-z, 0-9, or '=_-./'"
+ )
+ yield self.auth.check_auth_blocking()
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+
+ yield self.check_user_id_not_appservice_exclusive(user_id)
+ token = self.macaroon_gen.generate_access_token(user_id)
+ try:
+ yield self.register_with_store(
+ user_id=user_id,
+ token=token,
+ password_hash=None,
+ create_profile_with_displayname=user.localpart,
+ )
+
+ yield self.profile_handler.set_displayname(
+ user, None, user.localpart, by_admin=True
+ )
+ except Exception as e:
+ yield self.store.add_access_token_to_user(user_id, token)
+ # Ignore Registration errors
+ logger.exception(e)
+ defer.returnValue((user_id, token))
+
+ @defer.inlineCallbacks
def register_email(self, threepidCreds):
"""
Registers emails with an identity server.
@@ -451,6 +523,39 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
+ def shadow_register(self, localpart, display_name, auth_result, params):
+ """Invokes the current registration on another server, using
+ shared secret registration, passing in any auth_results from
+ other registration UI auth flows (e.g. validated 3pids)
+ Useful for setting up shadow/backup accounts on a parallel deployment.
+ """
+
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
+ {
+ # XXX: auth_result is an unspecified extension for shadow registration
+ "auth_result": auth_result,
+ # XXX: another unspecified extension for shadow registration to ensure
+ # that the displayname is correctly set by the masters erver
+ "display_name": display_name,
+ "username": localpart,
+ "password": params.get("password"),
+ "bind_email": params.get("bind_email"),
+ "bind_msisdn": params.get("bind_msisdn"),
+ "device_id": params.get("device_id"),
+ "initial_device_display_name": params.get(
+ "initial_device_display_name"
+ ),
+ "inhibit_login": False,
+ "access_token": as_token,
+ },
+ )
+
+ @defer.inlineCallbacks
def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())):
@@ -461,7 +566,7 @@ class RegistrationHandler(BaseHandler):
id = self._next_generated_user_id
self._next_generated_user_id += 1
- defer.returnValue(str(id))
+ return str(id)
@defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response):
@@ -481,7 +586,7 @@ class RegistrationHandler(BaseHandler):
"error_url": "http://www.recaptcha.net/recaptcha/api/challenge?"
+ "error=%s" % lines[1],
}
- defer.returnValue(json)
+ return json
@defer.inlineCallbacks
def _submit_captcha(self, ip_addr, private_key, challenge, response):
@@ -497,7 +602,56 @@ class RegistrationHandler(BaseHandler):
"response": response,
},
)
- defer.returnValue(data)
+ return data
+
+ @defer.inlineCallbacks
+ def get_or_create_user(self, requester, localpart, displayname, password_hash=None):
+ """Creates a new user if the user does not exist,
+ else revokes all previous access tokens and generates a new one.
+
+ Args:
+ localpart : The local part of the user ID to register. If None,
+ one will be randomly generated.
+ Returns:
+ A tuple of (user_id, access_token).
+ Raises:
+ RegistrationError if there was a problem registering.
+
+ NB this is only used in tests. TODO: move it to the test package!
+ """
+ if localpart is None:
+ raise SynapseError(400, "Request must include user id")
+ yield self.auth.check_auth_blocking()
+ need_register = True
+
+ try:
+ yield self.check_username(localpart)
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ need_register = False
+ else:
+ raise
+
+ user = UserID(localpart, self.hs.hostname)
+ user_id = user.to_string()
+ token = self.macaroon_gen.generate_access_token(user_id)
+
+ if need_register:
+ yield self.register_with_store(
+ user_id=user_id,
+ token=token,
+ password_hash=password_hash,
+ create_profile_with_displayname=displayname or user.localpart,
+ )
+ if displayname is not None:
+ yield self.profile_handler.set_displayname(
+ user, None, displayname or user.localpart, by_admin=True
+ )
+ else:
+ yield self._auth_handler.delete_access_tokens_for_user(user_id)
+ yield self.store.add_access_token_to_user(user_id=user_id, token=token)
+
+ defer.returnValue((user_id, token))
@defer.inlineCallbacks
def _join_user_to_room(self, requester, room_identifier):
@@ -622,7 +776,7 @@ class RegistrationHandler(BaseHandler):
initial_display_name=initial_display_name,
is_guest=is_guest,
)
- defer.returnValue((r["device_id"], r["access_token"]))
+ return (r["device_id"], r["access_token"])
valid_until_ms = None
if self.session_lifetime is not None:
@@ -645,7 +799,7 @@ class RegistrationHandler(BaseHandler):
user_id, device_id=device_id, valid_until_ms=valid_until_ms
)
- defer.returnValue((device_id, access_token))
+ return (device_id, access_token)
@defer.inlineCallbacks
def post_registration_actions(
@@ -798,7 +952,7 @@ class RegistrationHandler(BaseHandler):
if ex.errcode == Codes.MISSING_PARAM:
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
- defer.returnValue(None)
+ return None
raise
yield self._auth_handler.add_threepid(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index db3f8cb76b..af7cfa7888 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -52,12 +52,14 @@ class RoomCreationHandler(BaseHandler):
"history_visibility": "shared",
"original_invitees_have_ops": False,
"guest_can_join": True,
+ "encryption_alg": "m.megolm.v1.aes-sha2",
},
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": "shared",
"original_invitees_have_ops": True,
"guest_can_join": True,
+ "encryption_alg": "m.megolm.v1.aes-sha2",
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
@@ -128,7 +130,7 @@ class RoomCreationHandler(BaseHandler):
old_room_id,
new_version, # args for _upgrade_room
)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def _upgrade_room(self, requester, old_room_id, new_version):
@@ -193,7 +195,7 @@ class RoomCreationHandler(BaseHandler):
requester, old_room_id, new_room_id, old_room_state
)
- defer.returnValue(new_room_id)
+ return new_room_id
@defer.inlineCallbacks
def _update_upgraded_room_pls(
@@ -294,7 +296,19 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to create rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ user_id, invite_list=[], third_party_invite_list=[], cloning=True
+ ):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -516,8 +530,14 @@ class RoomCreationHandler(BaseHandler):
requester, config, is_requester_admin=is_requester_admin
)
+ invite_list = config.get("invite", [])
+ invite_3pid_list = config.get("invite_3pid", [])
+
if not is_requester_admin and not self.spam_checker.user_may_create_room(
- user_id
+ user_id,
+ invite_list=invite_list,
+ third_party_invite_list=invite_3pid_list,
+ cloning=False,
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -551,7 +571,6 @@ class RoomCreationHandler(BaseHandler):
else:
room_alias = None
- invite_list = config.get("invite", [])
for i in invite_list:
try:
UserID.from_string(i)
@@ -560,8 +579,6 @@ class RoomCreationHandler(BaseHandler):
yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
- invite_3pid_list = config.get("invite_3pid", [])
-
visibility = config.get("visibility", None)
is_public = visibility == "public"
@@ -649,6 +666,7 @@ class RoomCreationHandler(BaseHandler):
"invite",
ratelimit=False,
content=content,
+ new_room=True,
)
for invite_3pid in invite_3pid_list:
@@ -663,6 +681,7 @@ class RoomCreationHandler(BaseHandler):
id_server,
requester,
txn_id=None,
+ new_room=True,
)
result = {"room_id": room_id}
@@ -671,7 +690,7 @@ class RoomCreationHandler(BaseHandler):
result["room_alias"] = room_alias.to_string()
yield directory_handler.send_room_alias_update_event(requester, room_id)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _send_events_for_new_room(
@@ -719,6 +738,7 @@ class RoomCreationHandler(BaseHandler):
"join",
ratelimit=False,
content=creator_join_profile,
+ new_room=True,
)
# We treat the power levels override specially as this needs to be one
@@ -780,6 +800,13 @@ class RoomCreationHandler(BaseHandler):
for (etype, state_key), content in initial_state.items():
yield send(etype=etype, state_key=state_key, content=content)
+ if "encryption_alg" in config:
+ yield send(
+ etype=EventTypes.Encryption,
+ state_key="",
+ content={"algorithm": config["encryption_alg"]},
+ )
+
@defer.inlineCallbacks
def _generate_room_id(self, creator_id, is_public):
# autogen room IDs and try to create it. We may clash, so just
@@ -796,7 +823,7 @@ class RoomCreationHandler(BaseHandler):
room_creator_user_id=creator_id,
is_public=is_public,
)
- defer.returnValue(gen_room_id)
+ return gen_room_id
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a room ID.")
@@ -839,7 +866,7 @@ class RoomContextHandler(object):
event_id, get_prev_content=True, allow_none=True
)
if not event:
- defer.returnValue(None)
+ return None
return
filtered = yield (filter_evts([event]))
@@ -890,7 +917,7 @@ class RoomContextHandler(object):
results["end"] = token.copy_and_replace("room_key", results["end"]).to_string()
- defer.returnValue(results)
+ return results
class RoomEventSource(object):
@@ -941,7 +968,7 @@ class RoomEventSource(object):
else:
end_key = to_key
- defer.returnValue((events, end_key))
+ return (events, end_key)
def get_current_key(self):
return self.store.get_room_events_max_id()
@@ -959,4 +986,4 @@ class RoomEventSource(object):
limit=config.limit,
)
- defer.returnValue((events, next_key))
+ return (events, next_key)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index aae696a7e8..e9094ad02b 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -325,7 +325,7 @@ class RoomListHandler(BaseHandler):
current_limit=since_token.current_limit - 1,
).to_token()
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def _append_room_entry_to_chunk(
@@ -420,7 +420,7 @@ class RoomListHandler(BaseHandler):
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
- defer.returnValue(None)
+ return None
# Return whether this room is open to federation users or not
create_event = current_state.get((EventTypes.Create, ""))
@@ -469,7 +469,7 @@ class RoomListHandler(BaseHandler):
if avatar_url:
result["avatar_url"] = avatar_url
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def get_remote_public_room_list(
@@ -482,7 +482,7 @@ class RoomListHandler(BaseHandler):
third_party_instance_id=None,
):
if not self.enable_room_list_search:
- defer.returnValue({"chunk": [], "total_room_count_estimate": 0})
+ return {"chunk": [], "total_room_count_estimate": 0}
if search_filter:
# We currently don't support searching across federation, so we have
@@ -507,7 +507,7 @@ class RoomListHandler(BaseHandler):
]
}
- defer.returnValue(res)
+ return res
def _get_remote_list_cached(
self,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e0196ef83e..e2ac2637c5 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -20,22 +20,23 @@ import logging
from six.moves import http_client
-from signedjson.key import decode_verify_key_bytes
-from signedjson.sign import verify_signed_json
-from unpaddedbase64 import decode_base64
-
from twisted.internet import defer
import synapse.server
import synapse.types
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, HttpResponseException, SynapseError
+from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ ProxiedRequestError,
+ HttpResponseException,
+ SynapseError,
+)
from synapse.types import RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
-from ._base import BaseHandler
-
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
@@ -67,6 +68,7 @@ class RoomMemberHandler(object):
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.identity_handler = hs.get_handlers().identity_handler
self.member_linearizer = Linearizer(name="member")
@@ -74,13 +76,10 @@ class RoomMemberHandler(object):
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = hs.get_third_party_event_rules()
self._server_notices_mxid = self.config.server_notices_mxid
+ self.rewrite_identity_server_urls = self.config.rewrite_identity_server_urls
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
-
- # This is only used to get at ratelimit function, and
- # maybe_kick_guest_users. It's fine there are multiple of these as
- # it doesn't store state.
- self.base_handler = BaseHandler(hs)
+ self.ratelimiter = Ratelimiter()
@abc.abstractmethod
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
@@ -191,7 +190,7 @@ class RoomMemberHandler(object):
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
- defer.returnValue(duplicate)
+ return duplicate
yield self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit
@@ -233,7 +232,7 @@ class RoomMemberHandler(object):
if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target, room_id)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id):
@@ -285,8 +284,31 @@ class RoomMemberHandler(object):
third_party_signed=None,
ratelimit=True,
content=None,
+ new_room=False,
require_consent=True,
):
+ """Update a users membership in a room
+
+ Args:
+ requester (Requester)
+ target (UserID)
+ room_id (str)
+ action (str): The "action" the requester is performing against the
+ target. One of join/leave/kick/ban/invite/unban.
+ txn_id (str|None): The transaction ID associated with the request,
+ or None not provided.
+ remote_room_hosts (list[str]|None): List of remote servers to try
+ and join via if server isn't already in the room.
+ third_party_signed (dict|None): The signed object for third party
+ invites.
+ ratelimit (bool): Whether to apply ratelimiting to this request.
+ content (dict|None): Fields to include in the new events content.
+ new_room (bool): Whether these membership changes are happening
+ as part of a room creation (e.g. initial joins and invites)
+
+ Returns:
+ Deferred[FrozenEvent]
+ """
key = (room_id,)
with (yield self.member_linearizer.queue(key)):
@@ -300,10 +322,11 @@ class RoomMemberHandler(object):
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
+ new_room=new_room,
require_consent=require_consent,
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _update_membership(
@@ -317,6 +340,7 @@ class RoomMemberHandler(object):
third_party_signed=None,
ratelimit=True,
content=None,
+ new_room=False,
require_consent=True,
):
content_specified = bool(content)
@@ -381,8 +405,15 @@ class RoomMemberHandler(object):
)
block_invite = True
+ is_published = yield self.store.is_room_published(room_id)
+
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(), room_id
+ requester.user.to_string(),
+ target.to_string(),
+ third_party_invite=None,
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
):
logger.info("Blocking invite due to spam checker")
block_invite = True
@@ -423,7 +454,7 @@ class RoomMemberHandler(object):
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
- defer.returnValue(old_state)
+ return old_state
if old_membership in ["ban", "leave"] and action == "kick":
raise AuthError(403, "The target user is not in the room")
@@ -455,8 +486,26 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
+ # allow the server notices mxid to join rooms
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
+
+ inviter = yield self._get_inviter(target.to_string(), room_id)
+ if not is_requester_admin:
+ # We assume that if the spam checker allowed the user to create
+ # a room then they're allowed to join it.
+ if not new_room and not self.spam_checker.user_may_join_room(
+ target.to_string(), room_id, is_invited=inviter is not None
+ ):
+ raise SynapseError(403, "Not allowed to join this room")
+
if not is_host_in_room:
- inviter = yield self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@@ -473,7 +522,7 @@ class RoomMemberHandler(object):
ret = yield self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
- defer.returnValue(ret)
+ return ret
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
@@ -495,7 +544,7 @@ class RoomMemberHandler(object):
res = yield self._remote_reject_invite(
requester, remote_room_hosts, room_id, target
)
- defer.returnValue(res)
+ return res
res = yield self._local_membership_update(
requester=requester,
@@ -508,7 +557,7 @@ class RoomMemberHandler(object):
content=content,
require_consent=require_consent,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def send_membership_event(
@@ -596,11 +645,11 @@ class RoomMemberHandler(object):
"""
guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
if not guest_access_id:
- defer.returnValue(False)
+ return False
guest_access = yield self.store.get_event(guest_access_id)
- defer.returnValue(
+ return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
@@ -635,7 +684,7 @@ class RoomMemberHandler(object):
servers.remove(room_alias.domain)
servers.insert(0, room_alias.domain)
- defer.returnValue((RoomID.from_string(room_id), servers))
+ return (RoomID.from_string(room_id), servers)
@defer.inlineCallbacks
def _get_inviter(self, user_id, room_id):
@@ -643,11 +692,19 @@ class RoomMemberHandler(object):
user_id=user_id, room_id=room_id
)
if invite:
- defer.returnValue(UserID.from_string(invite.sender))
+ return UserID.from_string(invite.sender)
@defer.inlineCallbacks
def do_3pid_invite(
- self, room_id, inviter, medium, address, id_server, requester, txn_id
+ self,
+ room_id,
+ inviter,
+ medium,
+ address,
+ id_server,
+ requester,
+ txn_id,
+ new_room=False,
):
if self.config.block_non_admin_invites:
is_requester_admin = yield self.auth.is_server_admin(requester.user)
@@ -658,7 +715,23 @@ class RoomMemberHandler(object):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
- yield self.base_handler.ratelimit(requester)
+ self.ratelimiter.ratelimit(
+ requester.user.to_string(),
+ time_now_s=self.hs.clock.time(),
+ rate_hz=self.hs.config.rc_third_party_invite.per_second,
+ burst_count=self.hs.config.rc_third_party_invite.burst_count,
+ update=True,
+ )
+
+ can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
+ )
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
@@ -672,6 +745,19 @@ class RoomMemberHandler(object):
invitee = yield self._lookup_3pid(id_server, medium, address)
+ is_published = yield self.store.is_room_published(room_id)
+
+ if not self.spam_checker.user_may_invite(
+ requester.user.to_string(),
+ invitee,
+ third_party_invite={"medium": medium, "address": address},
+ room_id=room_id,
+ new_room=new_room,
+ published_room=is_published,
+ ):
+ logger.info("Blocking invite due to spam checker")
+ raise SynapseError(403, "Invites have been disabled on this server")
+
if invitee:
yield self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
@@ -681,6 +767,20 @@ class RoomMemberHandler(object):
requester, id_server, medium, address, room_id, inviter, txn_id=txn_id
)
+ def _get_id_server_target(self, id_server):
+ """Looks up an id_server's actual http endpoint
+
+ Args:
+ id_server (str): the server name to lookup.
+
+ Returns:
+ the http endpoint to connect to.
+ """
+ if id_server in self.rewrite_identity_server_urls:
+ return self.rewrite_identity_server_urls[id_server]
+
+ return id_server
+
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
@@ -694,47 +794,12 @@ class RoomMemberHandler(object):
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
- if not self._enable_lookup:
- raise SynapseError(
- 403, "Looking up third-party identifiers is denied from this server"
- )
try:
- data = yield self.simple_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
- {"medium": medium, "address": address},
- )
-
- if "mxid" in data:
- if "signatures" not in data:
- raise AuthError(401, "No signatures on 3pid binding")
- yield self._verify_any_signature(data, id_server)
- defer.returnValue(data["mxid"])
-
- except IOError as e:
+ data = yield self.identity_handler.lookup_3pid(id_server, medium, address)
+ return data.get("mxid")
+ except ProxiedRequestError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
- defer.returnValue(None)
-
- @defer.inlineCallbacks
- def _verify_any_signature(self, data, server_hostname):
- if server_hostname not in data["signatures"]:
- raise AuthError(401, "No signature from server %s" % (server_hostname,))
- for key_name, signature in data["signatures"][server_hostname].items():
- key_data = yield self.simple_http_client.get_json(
- "%s%s/_matrix/identity/api/v1/pubkey/%s"
- % (id_server_scheme, server_hostname, key_name)
- )
- if "public_key" not in key_data:
- raise AuthError(
- 401, "No public key named %s from %s" % (key_name, server_hostname)
- )
- verify_signed_json(
- data,
- server_hostname,
- decode_verify_key_bytes(
- key_name, decode_base64(key_data["public_key"])
- ),
- )
- return
+ return None
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
@@ -854,9 +919,10 @@ class RoomMemberHandler(object):
user.
"""
+ target = self._get_id_server_target(id_server)
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme,
- id_server,
+ target,
)
invite_config = {
@@ -896,7 +962,7 @@ class RoomMemberHandler(object):
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid"
- % (id_server_scheme, id_server),
+ % (id_server_scheme, target),
}
else:
fallback_public_key = public_keys[0]
@@ -904,7 +970,7 @@ class RoomMemberHandler(object):
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
- defer.returnValue((token, public_keys, fallback_public_key, display_name))
+ return (token, public_keys, fallback_public_key, display_name)
@defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
@@ -913,7 +979,7 @@ class RoomMemberHandler(object):
create_event_id = current_state_ids.get(("m.room.create", ""))
if len(current_state_ids) == 1 and create_event_id:
# We can only get here if we're in the process of creating the room
- defer.returnValue(True)
+ return True
for etype, state_key in current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
@@ -925,16 +991,16 @@ class RoomMemberHandler(object):
continue
if event.membership == Membership.JOIN:
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def _is_server_notice_room(self, room_id):
if self._server_notices_mxid is None:
- defer.returnValue(False)
+ return False
user_ids = yield self.store.get_users_in_room(room_id)
- defer.returnValue(self._server_notices_mxid in user_ids)
+ return self._server_notices_mxid in user_ids
class RoomMemberMasterHandler(RoomMemberHandler):
@@ -978,7 +1044,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
ret = yield fed_handler.do_remotely_reject_invite(
remote_room_hosts, room_id, target.to_string()
)
- defer.returnValue(ret)
+ return ret
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
@@ -989,7 +1055,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(target.to_string(), room_id)
- defer.returnValue({})
+ return {}
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index fc873a3ba6..75e96ae1a2 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -53,7 +53,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
yield self._user_joined_room(user, room_id)
- defer.returnValue(ret)
+ return ret
def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
"""Implements RoomMemberHandler._remote_reject_invite
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ddc4430d03..cd5e90bacb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -69,7 +69,7 @@ class SearchHandler(BaseHandler):
# Scan through the old room for further predecessors
room_id = predecessor["room_id"]
- defer.returnValue(historical_room_ids)
+ return historical_room_ids
@defer.inlineCallbacks
def search(self, user, content, batch=None):
@@ -186,13 +186,11 @@ class SearchHandler(BaseHandler):
room_ids.intersection_update({batch_group_key})
if not room_ids:
- defer.returnValue(
- {
- "search_categories": {
- "room_events": {"results": [], "count": 0, "highlights": []}
- }
+ return {
+ "search_categories": {
+ "room_events": {"results": [], "count": 0, "highlights": []}
}
- )
+ }
rank_map = {} # event_id -> rank of event
allowed_events = []
@@ -455,4 +453,4 @@ class SearchHandler(BaseHandler):
if global_next_batch:
rooms_cat_res["next_batch"] = global_next_batch
- defer.returnValue({"search_categories": {"room_events": rooms_cat_res}})
+ return {"search_categories": {"room_events": rooms_cat_res}}
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index d90c9e0108..3f50d6de47 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -30,12 +31,15 @@ class SetPasswordHandler(BaseHandler):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
+ self._password_policy_handler = hs.get_password_policy_handler()
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, requester=None):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
+ self._password_policy_handler.validate_password(newpassword)
+
password_hash = yield self._auth_handler.hash(newpassword)
except_device_id = requester.device_id if requester else None
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index 6b364befd5..f065970c40 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -48,7 +48,7 @@ class StateDeltasHandler(object):
if not event and not prev_event:
logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
- defer.returnValue(None)
+ return None
prev_value = None
value = None
@@ -62,8 +62,8 @@ class StateDeltasHandler(object):
logger.debug("prev_value: %r -> value: %r", prev_value, value)
if value == public_value and prev_value != public_value:
- defer.returnValue(True)
+ return True
elif value != public_value and prev_value == public_value:
- defer.returnValue(False)
+ return False
else:
- defer.returnValue(None)
+ return None
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index a0ee8db988..4449da6669 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -86,7 +86,7 @@ class StatsHandler(StateDeltasHandler):
# If still None then the initial background update hasn't happened yet
if self.pos is None:
- defer.returnValue(None)
+ return None
# Loop round handling deltas until we're up to date
while True:
@@ -328,6 +328,6 @@ class StatsHandler(StateDeltasHandler):
== "world_readable"
)
):
- defer.returnValue(True)
+ return True
else:
- defer.returnValue(False)
+ return False
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index cd1ac0a27a..4007284e5b 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -263,7 +263,7 @@ class SyncHandler(object):
timeout,
full_state,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout, full_state):
@@ -303,7 +303,7 @@ class SyncHandler(object):
lazy_loaded = "false"
non_empty_sync_counter.labels(sync_type, lazy_loaded).inc()
- defer.returnValue(result)
+ return result
def current_sync_for_user(self, sync_config, since_token=None, full_state=False):
"""Get the sync for client needed to match what the server has now.
@@ -317,7 +317,7 @@ class SyncHandler(object):
user_id = user.to_string()
rules = yield self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
- defer.returnValue(rules)
+ return rules
@defer.inlineCallbacks
def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
@@ -378,7 +378,7 @@ class SyncHandler(object):
event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
- defer.returnValue((now_token, ephemeral_by_room))
+ return (now_token, ephemeral_by_room)
@defer.inlineCallbacks
def _load_filtered_recents(
@@ -426,8 +426,8 @@ class SyncHandler(object):
recents = []
if not limited or block_all_timeline:
- defer.returnValue(
- TimelineBatch(events=recents, prev_batch=now_token, limited=False)
+ return TimelineBatch(
+ events=recents, prev_batch=now_token, limited=False
)
filtering_factor = 2
@@ -490,12 +490,10 @@ class SyncHandler(object):
prev_batch_token = now_token.copy_and_replace("room_key", room_key)
- defer.returnValue(
- TimelineBatch(
- events=recents,
- prev_batch=prev_batch_token,
- limited=limited or newly_joined_room,
- )
+ return TimelineBatch(
+ events=recents,
+ prev_batch=prev_batch_token,
+ limited=limited or newly_joined_room,
)
@defer.inlineCallbacks
@@ -517,7 +515,7 @@ class SyncHandler(object):
if event.is_state():
state_ids = state_ids.copy()
state_ids[(event.type, event.state_key)] = event.event_id
- defer.returnValue(state_ids)
+ return state_ids
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
@@ -549,7 +547,7 @@ class SyncHandler(object):
else:
# no events in this room - so presumably no state
state = {}
- defer.returnValue(state)
+ return state
@defer.inlineCallbacks
def compute_summary(self, room_id, sync_config, batch, state, now_token):
@@ -579,7 +577,7 @@ class SyncHandler(object):
)
if not last_events:
- defer.returnValue(None)
+ return None
return
last_event = last_events[-1]
@@ -611,14 +609,14 @@ class SyncHandler(object):
if name_id:
name = yield self.store.get_event(name_id, allow_none=True)
if name and name.content.get("name"):
- defer.returnValue(summary)
+ return summary
if canonical_alias_id:
canonical_alias = yield self.store.get_event(
canonical_alias_id, allow_none=True
)
if canonical_alias and canonical_alias.content.get("alias"):
- defer.returnValue(summary)
+ return summary
me = sync_config.user.to_string()
@@ -652,7 +650,7 @@ class SyncHandler(object):
summary["m.heroes"] = sorted([user_id for user_id in gone_user_ids])[0:5]
if not sync_config.filter_collection.lazy_load_members():
- defer.returnValue(summary)
+ return summary
# ensure we send membership events for heroes if needed
cache_key = (sync_config.user.to_string(), sync_config.device_id)
@@ -686,7 +684,7 @@ class SyncHandler(object):
cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s
- defer.returnValue(summary)
+ return summary
def get_lazy_loaded_members_cache(self, cache_key):
cache = self.lazy_loaded_members_cache.get(cache_key)
@@ -871,14 +869,12 @@ class SyncHandler(object):
if state_ids:
state = yield self.store.get_events(list(state_ids.values()))
- defer.returnValue(
- {
- (e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(
- list(state.values())
- )
- }
- )
+ return {
+ (e.type, e.state_key): e
+ for e in sync_config.filter_collection.filter_room_state(
+ list(state.values())
+ )
+ }
@defer.inlineCallbacks
def unread_notifs_for_room_id(self, room_id, sync_config):
@@ -894,11 +890,11 @@ class SyncHandler(object):
notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
- defer.returnValue(notifs)
+ return notifs
# There is no new information in this period, so your notification
# count is whatever it was last time.
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def generate_sync_result(self, sync_config, since_token=None, full_state=False):
@@ -989,19 +985,17 @@ class SyncHandler(object):
"Sync result for newly joined room %s: %r", room_id, joined_room
)
- defer.returnValue(
- SyncResult(
- presence=sync_result_builder.presence,
- account_data=sync_result_builder.account_data,
- joined=sync_result_builder.joined,
- invited=sync_result_builder.invited,
- archived=sync_result_builder.archived,
- to_device=sync_result_builder.to_device,
- device_lists=device_lists,
- groups=sync_result_builder.groups,
- device_one_time_keys_count=one_time_key_counts,
- next_batch=sync_result_builder.now_token,
- )
+ return SyncResult(
+ presence=sync_result_builder.presence,
+ account_data=sync_result_builder.account_data,
+ joined=sync_result_builder.joined,
+ invited=sync_result_builder.invited,
+ archived=sync_result_builder.archived,
+ to_device=sync_result_builder.to_device,
+ device_lists=device_lists,
+ groups=sync_result_builder.groups,
+ device_one_time_keys_count=one_time_key_counts,
+ next_batch=sync_result_builder.now_token,
)
@measure_func("_generate_sync_entry_for_groups")
@@ -1124,11 +1118,9 @@ class SyncHandler(object):
# Remove any users that we still share a room with.
newly_left_users -= users_who_share_room
- defer.returnValue(
- DeviceLists(changed=users_that_have_changed, left=newly_left_users)
- )
+ return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
else:
- defer.returnValue(DeviceLists(changed=[], left=[]))
+ return DeviceLists(changed=[], left=[])
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
@@ -1225,7 +1217,7 @@ class SyncHandler(object):
sync_result_builder.account_data = account_data_for_user
- defer.returnValue(account_data_by_room)
+ return account_data_by_room
@defer.inlineCallbacks
def _generate_sync_entry_for_presence(
@@ -1325,7 +1317,7 @@ class SyncHandler(object):
)
if not tags_by_room:
logger.debug("no-oping sync")
- defer.returnValue(([], [], [], []))
+ return ([], [], [], [])
ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id
@@ -1388,13 +1380,11 @@ class SyncHandler(object):
newly_left_users -= newly_joined_or_invited_users
- defer.returnValue(
- (
- newly_joined_rooms,
- newly_joined_or_invited_users,
- newly_left_rooms,
- newly_left_users,
- )
+ return (
+ newly_joined_rooms,
+ newly_joined_or_invited_users,
+ newly_left_rooms,
+ newly_left_users,
)
@defer.inlineCallbacks
@@ -1414,13 +1404,13 @@ class SyncHandler(object):
)
if rooms_changed:
- defer.returnValue(True)
+ return True
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id):
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
@defer.inlineCallbacks
def _get_rooms_changed(self, sync_result_builder, ignored_users):
@@ -1637,7 +1627,7 @@ class SyncHandler(object):
)
room_entries.append(entry)
- defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms))
+ return (room_entries, invited, newly_joined_rooms, newly_left_rooms)
@defer.inlineCallbacks
def _get_all_rooms(self, sync_result_builder, ignored_users):
@@ -1711,7 +1701,7 @@ class SyncHandler(object):
)
)
- defer.returnValue((room_entries, invited, []))
+ return (room_entries, invited, [])
@defer.inlineCallbacks
def _generate_room_entry(
@@ -1912,7 +1902,7 @@ class SyncHandler(object):
joined_room_ids.add(room_id)
joined_room_ids = frozenset(joined_room_ids)
- defer.returnValue(joined_room_ids)
+ return joined_room_ids
def _action_has_highlight(actions):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index c3e0c8fc7e..6b661aa93d 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -140,7 +140,7 @@ class TypingHandler(object):
if was_present:
# No point sending another notification
- defer.returnValue(None)
+ return None
self._push_update(member=member, typing=True)
@@ -173,7 +173,7 @@ class TypingHandler(object):
def _stopped_typing(self, member):
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
- defer.returnValue(None)
+ return None
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 5de9630950..e53669e40d 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -133,7 +133,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# If still None then the initial background update hasn't happened yet
if self.pos is None:
- defer.returnValue(None)
+ return None
# Loop round handling deltas until we're up to date
while True:
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 45d5010952..ccdcb7a770 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -45,6 +45,7 @@ from synapse.http import (
cancelled_to_request_timed_out_error,
redact_uri,
)
+from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.util.async_helpers import timeout_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
@@ -182,7 +183,15 @@ class SimpleHttpClient(object):
using HTTP in Matrix
"""
- def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ hs,
+ treq_args={},
+ ip_whitelist=None,
+ ip_blacklist=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
"""
Args:
hs (synapse.server.HomeServer)
@@ -191,6 +200,8 @@ class SimpleHttpClient(object):
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
+ http_proxy (bytes): proxy server to use for http connections. host[:port]
+ https_proxy (bytes): proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -235,11 +246,13 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
- self.agent = Agent(
+ self.agent = ProxyAgent(
self.reactor,
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
+ http_proxy=http_proxy,
+ https_proxy=https_proxy,
)
if self._ip_blacklist:
@@ -294,7 +307,7 @@ class SimpleHttpClient(object):
logger.info(
"Received response to %s %s: %s", method, redact_uri(uri), response.code
)
- defer.returnValue(response)
+ return response
except Exception as e:
incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
@@ -345,7 +358,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(json.loads(body))
+ return json.loads(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -385,7 +398,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(json.loads(body))
+ return json.loads(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -410,7 +423,7 @@ class SimpleHttpClient(object):
ValueError: if the response was not JSON
"""
body = yield self.get_raw(uri, args, headers=headers)
- defer.returnValue(json.loads(body))
+ return json.loads(body)
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}, headers=None):
@@ -453,7 +466,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(json.loads(body))
+ return json.loads(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -488,7 +501,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(body)
+ return body
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -545,13 +558,11 @@ class SimpleHttpClient(object):
except Exception as e:
raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e)
- defer.returnValue(
- (
- length,
- resp_headers,
- response.request.absoluteURI.decode("ascii"),
- response.code,
- )
+ return (
+ length,
+ resp_headers,
+ response.request.absoluteURI.decode("ascii"),
+ response.code,
)
@@ -627,10 +638,10 @@ class CaptchaServerHttpClient(SimpleHttpClient):
try:
body = yield make_deferred_yieldable(readBody(response))
- defer.returnValue(body)
+ return body
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.
- defer.returnValue(e.response)
+ return e.response
def encode_urlencode_args(args):
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
new file mode 100644
index 0000000000..be7b2ceb8e
--- /dev/null
+++ b/synapse/http/connectproxyclient.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from zope.interface import implementer
+
+from twisted.internet import defer, protocol
+from twisted.internet.error import ConnectError
+from twisted.internet.interfaces import IStreamClientEndpoint
+from twisted.internet.protocol import connectionDone
+from twisted.web import http
+
+logger = logging.getLogger(__name__)
+
+
+class ProxyConnectError(ConnectError):
+ pass
+
+
+@implementer(IStreamClientEndpoint)
+class HTTPConnectProxyEndpoint(object):
+ """An Endpoint implementation which will send a CONNECT request to an http proxy
+
+ Wraps an existing HostnameEndpoint for the proxy.
+
+ When we get the connect() request from the connection pool (via the TLS wrapper),
+ we'll first connect to the proxy endpoint with a ProtocolFactory which will make the
+ CONNECT request. Once that completes, we invoke the protocolFactory which was passed
+ in.
+
+ Args:
+ reactor: the Twisted reactor to use for the connection
+ proxy_endpoint (IStreamClientEndpoint): the endpoint to use to connect to the
+ proxy
+ host (bytes): hostname that we want to CONNECT to
+ port (int): port that we want to connect to
+ """
+
+ def __init__(self, reactor, proxy_endpoint, host, port):
+ self._reactor = reactor
+ self._proxy_endpoint = proxy_endpoint
+ self._host = host
+ self._port = port
+
+ def __repr__(self):
+ return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
+
+ def connect(self, protocolFactory):
+ f = HTTPProxiedClientFactory(self._host, self._port, protocolFactory)
+ d = self._proxy_endpoint.connect(f)
+ # once the tcp socket connects successfully, we need to wait for the
+ # CONNECT to complete.
+ d.addCallback(lambda conn: f.on_connection)
+ return d
+
+
+class HTTPProxiedClientFactory(protocol.ClientFactory):
+ """ClientFactory wrapper that triggers an HTTP proxy CONNECT on connect.
+
+ Once the CONNECT completes, invokes the original ClientFactory to build the
+ HTTP Protocol object and run the rest of the connection.
+
+ Args:
+ dst_host (bytes): hostname that we want to CONNECT to
+ dst_port (int): port that we want to connect to
+ wrapped_factory (protocol.ClientFactory): The original Factory
+ """
+
+ def __init__(self, dst_host, dst_port, wrapped_factory):
+ self.dst_host = dst_host
+ self.dst_port = dst_port
+ self.wrapped_factory = wrapped_factory
+ self.on_connection = defer.Deferred()
+
+ def startedConnecting(self, connector):
+ return self.wrapped_factory.startedConnecting(connector)
+
+ def buildProtocol(self, addr):
+ wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
+
+ return HTTPConnectProtocol(
+ self.dst_host, self.dst_port, wrapped_protocol, self.on_connection
+ )
+
+ def clientConnectionFailed(self, connector, reason):
+ logger.debug("Connection to proxy failed: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionFailed(connector, reason)
+
+ def clientConnectionLost(self, connector, reason):
+ logger.debug("Connection to proxy lost: %s", reason)
+ if not self.on_connection.called:
+ self.on_connection.errback(reason)
+ return self.wrapped_factory.clientConnectionLost(connector, reason)
+
+
+class HTTPConnectProtocol(protocol.Protocol):
+ """Protocol that wraps an existing Protocol to do a CONNECT handshake at connect
+
+ Args:
+ host (bytes): The original HTTP(s) hostname or IPv4 or IPv6 address literal
+ to put in the CONNECT request
+
+ port (int): The original HTTP(s) port to put in the CONNECT request
+
+ wrapped_protocol (interfaces.IProtocol): the original protocol (probably
+ HTTPChannel or TLSMemoryBIOProtocol, but could be anything really)
+
+ connected_deferred (Deferred): a Deferred which will be callbacked with
+ wrapped_protocol when the CONNECT completes
+ """
+
+ def __init__(self, host, port, wrapped_protocol, connected_deferred):
+ self.host = host
+ self.port = port
+ self.wrapped_protocol = wrapped_protocol
+ self.connected_deferred = connected_deferred
+ self.http_setup_client = HTTPConnectSetupClient(self.host, self.port)
+ self.http_setup_client.on_connected.addCallback(self.proxyConnected)
+
+ def connectionMade(self):
+ self.http_setup_client.makeConnection(self.transport)
+
+ def connectionLost(self, reason=connectionDone):
+ if self.wrapped_protocol.connected:
+ self.wrapped_protocol.connectionLost(reason)
+
+ self.http_setup_client.connectionLost(reason)
+
+ if not self.connected_deferred.called:
+ self.connected_deferred.errback(reason)
+
+ def proxyConnected(self, _):
+ self.wrapped_protocol.makeConnection(self.transport)
+
+ self.connected_deferred.callback(self.wrapped_protocol)
+
+ # Get any pending data from the http buf and forward it to the original protocol
+ buf = self.http_setup_client.clearLineBuffer()
+ if buf:
+ self.wrapped_protocol.dataReceived(buf)
+
+ def dataReceived(self, data):
+ # if we've set up the HTTP protocol, we can send the data there
+ if self.wrapped_protocol.connected:
+ return self.wrapped_protocol.dataReceived(data)
+
+ # otherwise, we must still be setting up the connection: send the data to the
+ # setup client
+ return self.http_setup_client.dataReceived(data)
+
+
+class HTTPConnectSetupClient(http.HTTPClient):
+ """HTTPClient protocol to send a CONNECT message for proxies and read the response.
+
+ Args:
+ host (bytes): The hostname to send in the CONNECT message
+ port (int): The port to send in the CONNECT message
+ """
+
+ def __init__(self, host, port):
+ self.host = host
+ self.port = port
+ self.on_connected = defer.Deferred()
+
+ def connectionMade(self):
+ logger.debug("Connected to proxy, sending CONNECT")
+ self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
+ self.endHeaders()
+
+ def handleStatus(self, version, status, message):
+ logger.debug("Got Status: %s %s %s", status, message, version)
+ if status != b"200":
+ raise ProxyConnectError("Unexpected status on CONNECT: %s" % status)
+
+ def handleEndHeaders(self):
+ logger.debug("End Headers")
+ self.on_connected.callback(None)
+
+ def handleResponse(self, body):
+ pass
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 054c321a20..c03ddb724f 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -177,7 +177,7 @@ class MatrixFederationAgent(object):
res = yield make_deferred_yieldable(
agent.request(method, uri, headers, bodyProducer)
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
@@ -205,24 +205,20 @@ class MatrixFederationAgent(object):
port = parsed_uri.port
if port == -1:
port = 8448
- defer.returnValue(
- _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=port,
- )
+ return _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=port,
)
if parsed_uri.port != -1:
# there is an explicit port
- defer.returnValue(
- _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=parsed_uri.port,
- )
+ return _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=parsed_uri.host,
+ target_port=parsed_uri.port,
)
if lookup_well_known:
@@ -259,7 +255,7 @@ class MatrixFederationAgent(object):
)
res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
- defer.returnValue(res)
+ return res
# try a SRV lookup
service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
@@ -283,13 +279,11 @@ class MatrixFederationAgent(object):
parsed_uri.host.decode("ascii"),
)
- defer.returnValue(
- _RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=target_host,
- target_port=port,
- )
+ return _RoutingResult(
+ host_header=parsed_uri.netloc,
+ tls_server_name=parsed_uri.host,
+ target_host=target_host,
+ target_port=port,
)
@defer.inlineCallbacks
@@ -314,7 +308,7 @@ class MatrixFederationAgent(object):
if cache_period > 0:
self._well_known_cache.set(server_name, result, cache_period)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _do_get_well_known(self, server_name):
@@ -354,7 +348,7 @@ class MatrixFederationAgent(object):
# after startup
cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
- defer.returnValue((None, cache_period))
+ return (None, cache_period)
result = parsed_body["m.server"].encode("ascii")
@@ -369,7 +363,7 @@ class MatrixFederationAgent(object):
else:
cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
- defer.returnValue((result, cache_period))
+ return (result, cache_period)
@implementer(IStreamClientEndpoint)
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index ecc88f9b96..b32188766d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -120,7 +120,7 @@ class SrvResolver(object):
if cache_entry:
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
- defer.returnValue(servers)
+ return servers
try:
answers, _, _ = yield make_deferred_yieldable(
@@ -129,7 +129,7 @@ class SrvResolver(object):
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
- defer.returnValue([])
+ return []
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
@@ -138,7 +138,7 @@ class SrvResolver(object):
logger.warn(
"Failed to resolve %r, falling back to cache. %r", service_name, e
)
- defer.returnValue(list(cache_entry))
+ return list(cache_entry)
else:
raise e
@@ -169,4 +169,4 @@ class SrvResolver(object):
)
self._cache[service_name] = list(servers)
- defer.returnValue(servers)
+ return servers
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index e60334547e..d07d356464 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -158,7 +158,7 @@ def _handle_json_response(reactor, timeout_sec, request, response):
response.code,
response.phrase.decode("ascii", errors="replace"),
)
- defer.returnValue(body)
+ return body
class MatrixFederationHttpClient(object):
@@ -256,7 +256,7 @@ class MatrixFederationHttpClient(object):
response = yield self._send_request(request, **send_request_args)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
def _send_request(
@@ -520,7 +520,7 @@ class MatrixFederationHttpClient(object):
_flatten_response_never_received(e),
)
raise
- defer.returnValue(response)
+ return response
def build_auth_headers(
self, destination, method, url_bytes, content=None, destination_is=None
@@ -644,7 +644,7 @@ class MatrixFederationHttpClient(object):
self.reactor, self.default_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
def post_json(
@@ -713,7 +713,7 @@ class MatrixFederationHttpClient(object):
body = yield _handle_json_response(
self.reactor, _sec_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
def get_json(
@@ -778,7 +778,7 @@ class MatrixFederationHttpClient(object):
self.reactor, self.default_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
def delete_json(
@@ -836,7 +836,7 @@ class MatrixFederationHttpClient(object):
body = yield _handle_json_response(
self.reactor, self.default_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
def get_file(
@@ -902,7 +902,7 @@ class MatrixFederationHttpClient(object):
response.phrase.decode("ascii", errors="replace"),
length,
)
- defer.returnValue((length, headers))
+ return (length, headers)
class _ReadBodyToFileProtocol(protocol.Protocol):
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
new file mode 100644
index 0000000000..332da02a8d
--- /dev/null
+++ b/synapse/http/proxyagent.py
@@ -0,0 +1,195 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
+from twisted.python.failure import Failure
+from twisted.web.client import URI, BrowserLikePolicyForHTTPS, _AgentBase
+from twisted.web.error import SchemeNotSupported
+from twisted.web.iweb import IAgent
+
+from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint
+
+logger = logging.getLogger(__name__)
+
+_VALID_URI = re.compile(br"\A[\x21-\x7e]+\Z")
+
+
+@implementer(IAgent)
+class ProxyAgent(_AgentBase):
+ """An Agent implementation which will use an HTTP proxy if one was requested
+
+ Args:
+ reactor: twisted reactor to place outgoing
+ connections.
+
+ contextFactory (IPolicyForHTTPS): A factory for TLS contexts, to control the
+ verification parameters of OpenSSL. The default is to use a
+ `BrowserLikePolicyForHTTPS`, so unless you have special
+ requirements you can leave this as-is.
+
+ connectTimeout (float): The amount of time that this Agent will wait
+ for the peer to accept a connection.
+
+ bindAddress (bytes): The local address for client sockets to bind to.
+
+ pool (HTTPConnectionPool|None): connection pool to be used. If None, a
+ non-persistent pool instance will be created.
+ """
+
+ def __init__(
+ self,
+ reactor,
+ contextFactory=BrowserLikePolicyForHTTPS(),
+ connectTimeout=None,
+ bindAddress=None,
+ pool=None,
+ http_proxy=None,
+ https_proxy=None,
+ ):
+ _AgentBase.__init__(self, reactor, pool)
+
+ self._endpoint_kwargs = {}
+ if connectTimeout is not None:
+ self._endpoint_kwargs["timeout"] = connectTimeout
+ if bindAddress is not None:
+ self._endpoint_kwargs["bindAddress"] = bindAddress
+
+ self.http_proxy_endpoint = _http_proxy_endpoint(
+ http_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self.https_proxy_endpoint = _http_proxy_endpoint(
+ https_proxy, reactor, **self._endpoint_kwargs
+ )
+
+ self._policy_for_https = contextFactory
+ self._reactor = reactor
+
+ def request(self, method, uri, headers=None, bodyProducer=None):
+ """
+ Issue a request to the server indicated by the given uri.
+
+ Supports `http` and `https` schemes.
+
+ An existing connection from the connection pool may be used or a new one may be
+ created.
+
+ See also: twisted.web.iweb.IAgent.request
+
+ Args:
+ method (bytes): The request method to use, such as `GET`, `POST`, etc
+
+ uri (bytes): The location of the resource to request.
+
+ headers (Headers|None): Extra headers to send with the request
+
+ bodyProducer (IBodyProducer|None): An object which can generate bytes to
+ make up the body of this request (for example, the properly encoded
+ contents of a file for a file upload). Or, None if the request is to
+ have no body.
+
+ Returns:
+ Deferred[IResponse]: completes when the header of the response has
+ been received (regardless of the response status code).
+ """
+ uri = uri.strip()
+ if not _VALID_URI.match(uri):
+ raise ValueError("Invalid URI {!r}".format(uri))
+
+ parsed_uri = URI.fromBytes(uri)
+ pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
+ request_path = parsed_uri.originForm
+
+ if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+ # Cache *all* connections under the same key, since we are only
+ # connecting to a single destination, the proxy:
+ pool_key = ("http-proxy", self.http_proxy_endpoint)
+ endpoint = self.http_proxy_endpoint
+ request_path = uri
+ elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+ endpoint = HTTPConnectProxyEndpoint(
+ self._reactor,
+ self.https_proxy_endpoint,
+ parsed_uri.host,
+ parsed_uri.port,
+ )
+ else:
+ # not using a proxy
+ endpoint = HostnameEndpoint(
+ self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
+ )
+
+ logger.debug("Requesting %s via %s", uri, endpoint)
+
+ if parsed_uri.scheme == b"https":
+ tls_connection_creator = self._policy_for_https.creatorForNetloc(
+ parsed_uri.host, parsed_uri.port
+ )
+ endpoint = wrapClientTLS(tls_connection_creator, endpoint)
+ elif parsed_uri.scheme == b"http":
+ pass
+ else:
+ return defer.fail(
+ Failure(
+ SchemeNotSupported("Unsupported scheme: %r" % (parsed_uri.scheme,))
+ )
+ )
+
+ return self._requestWithEndpoint(
+ pool_key, endpoint, method, parsed_uri, headers, bodyProducer, request_path
+ )
+
+
+def _http_proxy_endpoint(proxy, reactor, **kwargs):
+ """Parses an http proxy setting and returns an endpoint for the proxy
+
+ Args:
+ proxy (bytes|None): the proxy setting
+ reactor: reactor to be used to connect to the proxy
+ kwargs: other args to be passed to HostnameEndpoint
+
+ Returns:
+ interfaces.IStreamClientEndpoint|None: endpoint to use to connect to the proxy,
+ or None
+ """
+ if proxy is None:
+ return None
+
+ # currently we only support hostname:port. Some apps also support
+ # protocol://<host>[:port], which allows a way of requiring a TLS connection to the
+ # proxy.
+
+ host, port = parse_host_port(proxy, default_port=1080)
+ return HostnameEndpoint(reactor, host, port, **kwargs)
+
+
+def parse_host_port(hostport, default_port=None):
+ # could have sworn we had one of these somewhere else...
+ if b":" in hostport:
+ host, port = hostport.rsplit(b":", 1)
+ try:
+ port = int(port)
+ return host, port
+ except ValueError:
+ # the thing after the : wasn't a valid port; presumably this is an
+ # IPv6 address.
+ pass
+
+ return hostport, default_port
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
index 3da33d7826..d2c209c471 100644
--- a/synapse/logging/opentracing.py
+++ b/synapse/logging/opentracing.py
@@ -11,7 +11,7 @@
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
-# limitations under the License.import opentracing
+# limitations under the License.
# NOTE
@@ -89,7 +89,7 @@ the function becomes the operation name for the span.
# We start
yield we_wait
# we finish
- defer.returnValue(something_usual_and_useful)
+ return something_usual_and_useful
Operation names can be explicitly set for functions by using
``trace_using_operation_name`` and
@@ -113,7 +113,7 @@ Operation names can be explicitly set for functions by using
# We start
yield we_wait
# we finish
- defer.returnValue(something_usual_and_useful)
+ return something_usual_and_useful
Contexts and carriers
---------------------
@@ -150,10 +150,13 @@ Gotchas
"""
import contextlib
+import inspect
import logging
import re
from functools import wraps
+from canonicaljson import json
+
from twisted.internet import defer
from synapse.config import ConfigError
@@ -173,36 +176,12 @@ except ImportError:
logger = logging.getLogger(__name__)
-class _DumTagNames(object):
- """wrapper of opentracings tags. We need to have them if we
- want to reference them without opentracing around. Clearly they
- should never actually show up in a trace. `set_tags` overwrites
- these with the correct ones."""
+# Block everything by default
+# A regex which matches the server_names to expose traces for.
+# None means 'block everything'.
+_homeserver_whitelist = None
- INVALID_TAG = "invalid-tag"
- COMPONENT = INVALID_TAG
- DATABASE_INSTANCE = INVALID_TAG
- DATABASE_STATEMENT = INVALID_TAG
- DATABASE_TYPE = INVALID_TAG
- DATABASE_USER = INVALID_TAG
- ERROR = INVALID_TAG
- HTTP_METHOD = INVALID_TAG
- HTTP_STATUS_CODE = INVALID_TAG
- HTTP_URL = INVALID_TAG
- MESSAGE_BUS_DESTINATION = INVALID_TAG
- PEER_ADDRESS = INVALID_TAG
- PEER_HOSTNAME = INVALID_TAG
- PEER_HOST_IPV4 = INVALID_TAG
- PEER_HOST_IPV6 = INVALID_TAG
- PEER_PORT = INVALID_TAG
- PEER_SERVICE = INVALID_TAG
- SAMPLING_PRIORITY = INVALID_TAG
- SERVICE = INVALID_TAG
- SPAN_KIND = INVALID_TAG
- SPAN_KIND_CONSUMER = INVALID_TAG
- SPAN_KIND_PRODUCER = INVALID_TAG
- SPAN_KIND_RPC_CLIENT = INVALID_TAG
- SPAN_KIND_RPC_SERVER = INVALID_TAG
+# Util methods
def only_if_tracing(func):
@@ -219,11 +198,13 @@ def only_if_tracing(func):
return _only_if_tracing_inner
-# A regex which matches the server_names to expose traces for.
-# None means 'block everything'.
-_homeserver_whitelist = None
+@contextlib.contextmanager
+def _noop_context_manager(*args, **kwargs):
+ """Does exactly what it says on the tin"""
+ yield
+
-tags = _DumTagNames
+# Setup
def init_tracer(config):
@@ -247,26 +228,55 @@ def init_tracer(config):
# Include the worker name
name = config.worker_name if config.worker_name else "master"
+ # Pull out the jaeger config if it was given. Otherwise set it to something sensible.
+ # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
+
set_homeserver_whitelist(config.opentracer_whitelist)
- jaeger_config = JaegerConfig(
- config={"sampler": {"type": "const", "param": 1}, "logging": True},
+
+ JaegerConfig(
+ config=config.jaeger_config,
service_name="{} {}".format(config.server_name, name),
scope_manager=LogContextScopeManager(config),
- )
- jaeger_config.initialize_tracer()
+ ).initialize_tracer()
# Set up tags to be opentracing's tags
global tags
tags = opentracing.tags
-@contextlib.contextmanager
-def _noop_context_manager(*args, **kwargs):
- """Does absolutely nothing really well. Can be entered and exited arbitrarily.
- Good substitute for an opentracing scope."""
- yield
+# Whitelisting
+
+
+@only_if_tracing
+def set_homeserver_whitelist(homeserver_whitelist):
+ """Sets the homeserver whitelist
+
+ Args:
+ homeserver_whitelist (Iterable[str]): regex of whitelisted homeservers
+ """
+ global _homeserver_whitelist
+ if homeserver_whitelist:
+ # Makes a single regex which accepts all passed in regexes in the list
+ _homeserver_whitelist = re.compile(
+ "({})".format(")|(".join(homeserver_whitelist))
+ )
+
+
+@only_if_tracing
+def whitelisted_homeserver(destination):
+ """Checks if a destination matches the whitelist
+
+ Args:
+ destination (str)
+ """
+ _homeserver_whitelist
+ if _homeserver_whitelist:
+ return _homeserver_whitelist.match(destination)
+ return False
+# Start spans and scopes
+
# Could use kwargs but I want these to be explicit
def start_active_span(
operation_name,
@@ -285,8 +295,10 @@ def start_active_span(
Returns:
scope (Scope) or noop_context_manager
"""
+
if opentracing is None:
return _noop_context_manager()
+
else:
# We need to enter the scope here for the logcontext to become active
return opentracing.tracer.start_active_span(
@@ -300,63 +312,13 @@ def start_active_span(
)
-@only_if_tracing
-def close_active_span():
- """Closes the active span. This will close it's logcontext if the context
- was made for the span"""
- opentracing.tracer.scope_manager.active.__exit__(None, None, None)
-
-
-@only_if_tracing
-def set_tag(key, value):
- """Set's a tag on the active span"""
- opentracing.tracer.active_span.set_tag(key, value)
-
-
-@only_if_tracing
-def log_kv(key_values, timestamp=None):
- """Log to the active span"""
- opentracing.tracer.active_span.log_kv(key_values, timestamp)
-
-
-# Note: we don't have a get baggage items because we're trying to hide all
-# scope and span state from synapse. I think this method may also be useless
-# as a result
-@only_if_tracing
-def set_baggage_item(key, value):
- """Attach baggage to the active span"""
- opentracing.tracer.active_span.set_baggage_item(key, value)
-
-
-@only_if_tracing
-def set_operation_name(operation_name):
- """Sets the operation name of the active span"""
- opentracing.tracer.active_span.set_operation_name(operation_name)
-
-
-@only_if_tracing
-def set_homeserver_whitelist(homeserver_whitelist):
- """Sets the whitelist
-
- Args:
- homeserver_whitelist (iterable of strings): regex of whitelisted homeservers
- """
- global _homeserver_whitelist
- if homeserver_whitelist:
- # Makes a single regex which accepts all passed in regexes in the list
- _homeserver_whitelist = re.compile(
- "({})".format(")|(".join(homeserver_whitelist))
- )
-
-
-@only_if_tracing
-def whitelisted_homeserver(destination):
- """Checks if a destination matches the whitelist
- Args:
- destination (String)"""
- if _homeserver_whitelist:
- return _homeserver_whitelist.match(destination)
- return False
+def start_active_span_follows_from(operation_name, contexts):
+ if opentracing is None:
+ return _noop_context_manager()
+ else:
+ references = [opentracing.follows_from(context) for context in contexts]
+ scope = start_active_span(operation_name, references=references)
+ return scope
def start_active_span_from_context(
@@ -372,12 +334,16 @@ def start_active_span_from_context(
Extracts a span context from Twisted Headers.
args:
headers (twisted.web.http_headers.Headers)
+
+ For the other args see opentracing.tracer
+
returns:
span_context (opentracing.span.SpanContext)
"""
# Twisted encodes the values as lists whereas opentracing doesn't.
# So, we take the first item in the list.
# Also, twisted uses byte arrays while opentracing expects strings.
+
if opentracing is None:
return _noop_context_manager()
@@ -395,17 +361,90 @@ def start_active_span_from_context(
)
+def start_active_span_from_edu(
+ edu_content,
+ operation_name,
+ references=[],
+ tags=None,
+ start_time=None,
+ ignore_active_span=False,
+ finish_on_close=True,
+):
+ """
+ Extracts a span context from an edu and uses it to start a new active span
+
+ Args:
+ edu_content (dict): and edu_content with a `context` field whose value is
+ canonical json for a dict which contains opentracing information.
+
+ For the other args see opentracing.tracer
+ """
+
+ if opentracing is None:
+ return _noop_context_manager()
+
+ carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
+ context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+ _references = [
+ opentracing.child_of(span_context_from_string(x))
+ for x in carrier.get("references", [])
+ ]
+
+ # For some reason jaeger decided not to support the visualization of multiple parent
+ # spans or explicitely show references. I include the span context as a tag here as
+ # an aid to people debugging but it's really not an ideal solution.
+
+ references += _references
+
+ scope = opentracing.tracer.start_active_span(
+ operation_name,
+ child_of=context,
+ references=references,
+ tags=tags,
+ start_time=start_time,
+ ignore_active_span=ignore_active_span,
+ finish_on_close=finish_on_close,
+ )
+
+ scope.span.set_tag("references", carrier.get("references", []))
+ return scope
+
+
+# Opentracing setters for tags, logs, etc
+
+
+@only_if_tracing
+def set_tag(key, value):
+ """Sets a tag on the active span"""
+ opentracing.tracer.active_span.set_tag(key, value)
+
+
+@only_if_tracing
+def log_kv(key_values, timestamp=None):
+ """Log to the active span"""
+ opentracing.tracer.active_span.log_kv(key_values, timestamp)
+
+
+@only_if_tracing
+def set_operation_name(operation_name):
+ """Sets the operation name of the active span"""
+ opentracing.tracer.active_span.set_operation_name(operation_name)
+
+
+# Injection and extraction
+
+
@only_if_tracing
def inject_active_span_twisted_headers(headers, destination):
"""
- Injects a span context into twisted headers inplace
+ Injects a span context into twisted headers in-place
Args:
headers (twisted.web.http_headers.Headers)
span (opentracing.Span)
Returns:
- Inplace modification of headers
+ In-place modification of headers
Note:
The headers set by the tracer are custom to the tracer implementation which
@@ -437,7 +476,7 @@ def inject_active_span_byte_dict(headers, destination):
span (opentracing.Span)
Returns:
- Inplace modification of headers
+ In-place modification of headers
Note:
The headers set by the tracer are custom to the tracer implementation which
@@ -458,15 +497,195 @@ def inject_active_span_byte_dict(headers, destination):
headers[key.encode()] = [value.encode()]
+@only_if_tracing
+def inject_active_span_text_map(carrier, destination=None):
+ """
+ Injects a span context into a dict
+
+ Args:
+ carrier (dict)
+ destination (str): the name of the remote server. The span context
+ will only be injected if the destination matches the homeserver_whitelist
+ or destination is None.
+
+ Returns:
+ In-place modification of carrier
+
+ Note:
+ The headers set by the tracer are custom to the tracer implementation which
+ should be unique enough that they don't interfere with any headers set by
+ synapse or twisted. If we're still using jaeger these headers would be those
+ here:
+ https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
+ """
+
+ if destination and not whitelisted_homeserver(destination):
+ return
+
+ opentracing.tracer.inject(
+ opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+ )
+
+
+def active_span_context_as_string():
+ """
+ Returns:
+ The active span context encoded as a string.
+ """
+ carrier = {}
+ if opentracing:
+ opentracing.tracer.inject(
+ opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+ )
+ return json.dumps(carrier)
+
+
+@only_if_tracing
+def span_context_from_string(carrier):
+ """
+ Returns:
+ The active span context decoded from a string.
+ """
+ carrier = json.loads(carrier)
+ return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+
+
+@only_if_tracing
+def extract_text_map(carrier):
+ """
+ Wrapper method for opentracing's tracer.extract for TEXT_MAP.
+ Args:
+ carrier (dict): a dict possibly containing a span context.
+
+ Returns:
+ The active span context extracted from carrier.
+ """
+ return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+
+
+# Tracing decorators
+
+
+def trace(func):
+ """
+ Decorator to trace a function.
+ Sets the operation name to that of the function's.
+ """
+ if opentracing is None:
+ return func
+
+ @wraps(func)
+ def _trace_inner(self, *args, **kwargs):
+ if opentracing is None:
+ return func(self, *args, **kwargs)
+
+ scope = start_active_span(func.__name__)
+ scope.__enter__()
+
+ try:
+ result = func(self, *args, **kwargs)
+ if isinstance(result, defer.Deferred):
+
+ def call_back(result):
+ scope.__exit__(None, None, None)
+ return result
+
+ def err_back(result):
+ scope.span.set_tag(tags.ERROR, True)
+ scope.__exit__(None, None, None)
+ return result
+
+ result.addCallbacks(call_back, err_back)
+
+ else:
+ scope.__exit__(None, None, None)
+
+ return result
+
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ return _trace_inner
+
+
+def trace_using_operation_name(operation_name):
+ """Decorator to trace a function. Explicitely sets the operation_name."""
+
+ def trace(func):
+ """
+ Decorator to trace a function.
+ Sets the operation name to that of the function's.
+ """
+ if opentracing is None:
+ return func
+
+ @wraps(func)
+ def _trace_inner(self, *args, **kwargs):
+ if opentracing is None:
+ return func(self, *args, **kwargs)
+
+ scope = start_active_span(operation_name)
+ scope.__enter__()
+
+ try:
+ result = func(self, *args, **kwargs)
+ if isinstance(result, defer.Deferred):
+
+ def call_back(result):
+ scope.__exit__(None, None, None)
+ return result
+
+ def err_back(result):
+ scope.span.set_tag(tags.ERROR, True)
+ scope.__exit__(None, None, None)
+ return result
+
+ result.addCallbacks(call_back, err_back)
+ else:
+ scope.__exit__(None, None, None)
+
+ return result
+
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ return _trace_inner
+
+ return trace
+
+
+def tag_args(func):
+ """
+ Tags all of the args to the active span.
+ """
+
+ if not opentracing:
+ return func
+
+ @wraps(func)
+ def _tag_args_inner(self, *args, **kwargs):
+ argspec = inspect.getargspec(func)
+ for i, arg in enumerate(argspec.args[1:]):
+ set_tag("ARG_" + arg, args[i])
+ set_tag("args", args[len(argspec.args) :])
+ set_tag("kwargs", kwargs)
+ return func(self, *args, **kwargs)
+
+ return _tag_args_inner
+
+
def trace_servlet(servlet_name, func):
"""Decorator which traces a serlet. It starts a span with some servlet specific
tags such as the servlet_name and request information"""
+ if not opentracing:
+ return func
@wraps(func)
@defer.inlineCallbacks
def _trace_servlet_inner(request, *args, **kwargs):
- with start_active_span_from_context(
- request.requestHeaders,
+ with start_active_span(
"incoming-client-request",
tags={
"request_id": request.get_request_id(),
@@ -478,6 +697,44 @@ def trace_servlet(servlet_name, func):
},
):
result = yield defer.maybeDeferred(func, request, *args, **kwargs)
- defer.returnValue(result)
+ return result
return _trace_servlet_inner
+
+
+# Helper class
+
+
+class _DummyTagNames(object):
+ """wrapper of opentracings tags. We need to have them if we
+ want to reference them without opentracing around. Clearly they
+ should never actually show up in a trace. `set_tags` overwrites
+ these with the correct ones."""
+
+ INVALID_TAG = "invalid-tag"
+ COMPONENT = INVALID_TAG
+ DATABASE_INSTANCE = INVALID_TAG
+ DATABASE_STATEMENT = INVALID_TAG
+ DATABASE_TYPE = INVALID_TAG
+ DATABASE_USER = INVALID_TAG
+ ERROR = INVALID_TAG
+ HTTP_METHOD = INVALID_TAG
+ HTTP_STATUS_CODE = INVALID_TAG
+ HTTP_URL = INVALID_TAG
+ MESSAGE_BUS_DESTINATION = INVALID_TAG
+ PEER_ADDRESS = INVALID_TAG
+ PEER_HOSTNAME = INVALID_TAG
+ PEER_HOST_IPV4 = INVALID_TAG
+ PEER_HOST_IPV6 = INVALID_TAG
+ PEER_PORT = INVALID_TAG
+ PEER_SERVICE = INVALID_TAG
+ SAMPLING_PRIORITY = INVALID_TAG
+ SERVICE = INVALID_TAG
+ SPAN_KIND = INVALID_TAG
+ SPAN_KIND_CONSUMER = INVALID_TAG
+ SPAN_KIND_PRODUCER = INVALID_TAG
+ SPAN_KIND_RPC_CLIENT = INVALID_TAG
+ SPAN_KIND_RPC_SERVER = INVALID_TAG
+
+
+tags = _DummyTagNames
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
index 8c661302c9..4eed4f2338 100644
--- a/synapse/logging/scopecontextmanager.py
+++ b/synapse/logging/scopecontextmanager.py
@@ -131,7 +131,7 @@ class _LogContextScope(Scope):
def close(self):
if self.manager.active is not self:
- logger.error("Tried to close a none active scope!")
+ logger.error("Tried to close a non-active scope!")
return
if self._finish_on_close:
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 7bb020cb45..41147d4292 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -101,7 +101,7 @@ class ModuleApi(object):
)
user_id = yield self.register_user(localpart, displayname, emails)
_, access_token = yield self.register_device(user_id)
- defer.returnValue((user_id, access_token))
+ return (user_id, access_token)
def register_user(self, localpart, displayname=None, emails=[]):
"""Registers a new user with given localpart and optional displayname, emails.
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 918ef64897..bd80c801b6 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -365,7 +365,7 @@ class Notifier(object):
current_token = user_stream.current_token
result = yield callback(prev_token, current_token)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def get_events_for(
@@ -400,7 +400,7 @@ class Notifier(object):
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
- defer.returnValue(EventStreamResult([], (from_token, from_token)))
+ return EventStreamResult([], (from_token, from_token))
events = []
end_token = from_token
@@ -440,7 +440,7 @@ class Notifier(object):
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
- defer.returnValue(EventStreamResult(events, (from_token, end_token)))
+ return EventStreamResult(events, (from_token, end_token))
user_id_for_stream = user.to_string()
if is_peeking:
@@ -465,18 +465,18 @@ class Notifier(object):
from_token=from_token,
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _get_room_ids(self, user, explicit_room_id):
joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
- defer.returnValue(([explicit_room_id], True))
+ return ([explicit_room_id], True)
if (yield self._is_world_readable(explicit_room_id)):
- defer.returnValue(([explicit_room_id], False))
+ return ([explicit_room_id], False)
raise AuthError(403, "Non-joined access not allowed")
- defer.returnValue((joined_room_ids, True))
+ return (joined_room_ids, True)
@defer.inlineCallbacks
def _is_world_readable(self, room_id):
@@ -484,9 +484,9 @@ class Notifier(object):
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
- defer.returnValue(state.content["history_visibility"] == "world_readable")
+ return state.content["history_visibility"] == "world_readable"
else:
- defer.returnValue(False)
+ return False
@log_function
def remove_expired_streams(self):
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index c8a5b381da..c831975635 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -95,7 +95,7 @@ class BulkPushRuleEvaluator(object):
invited
)
- defer.returnValue(rules_by_user)
+ return rules_by_user
@cached()
def _get_rules_for_room(self, room_id):
@@ -134,7 +134,7 @@ class BulkPushRuleEvaluator(object):
pl_event = auth_events.get(POWER_KEY)
- defer.returnValue((pl_event.content if pl_event else {}, sender_level))
+ return (pl_event.content if pl_event else {}, sender_level)
@defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
@@ -283,13 +283,13 @@ class RulesForRoom(object):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
- defer.returnValue(self.rules_by_user)
+ return self.rules_by_user
with (yield self.linearizer.queue(())):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
- defer.returnValue(self.rules_by_user)
+ return self.rules_by_user
self.room_push_rule_cache_metrics.inc_misses()
@@ -366,7 +366,7 @@ class RulesForRoom(object):
logger.debug(
"Returning push rules for %r %r", self.room_id, ret_rules_by_user.keys()
)
- defer.returnValue(ret_rules_by_user)
+ return ret_rules_by_user
@defer.inlineCallbacks
def _update_rules_with_member_event_ids(
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 4e7b6a5531..454297e6a9 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -101,7 +101,7 @@ class HttpPusher(object):
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
self.url = self.data["url"]
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
@@ -258,17 +258,17 @@ class HttpPusher(object):
@defer.inlineCallbacks
def _process_one(self, push_action):
if "notify" not in push_action["actions"]:
- defer.returnValue(True)
+ return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
event = yield self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:
- defer.returnValue(True) # It's been redacted
+ return True # It's been redacted
rejected = yield self.dispatch_push(event, tweaks, badge)
if rejected is False:
- defer.returnValue(False)
+ return False
if isinstance(rejected, list) or isinstance(rejected, tuple):
for pk in rejected:
@@ -282,7 +282,7 @@ class HttpPusher(object):
else:
logger.info("Pushkey %s was rejected: removing", pk)
yield self.hs.remove_pusher(self.app_id, pk, self.user_id)
- defer.returnValue(True)
+ return True
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
@@ -302,7 +302,7 @@ class HttpPusher(object):
],
}
}
- defer.returnValue(d)
+ return d
ctx = yield push_tools.get_context_for_event(
self.store, self.state_handler, event, self.user_id
@@ -345,13 +345,13 @@ class HttpPusher(object):
if "name" in ctx and len(ctx["name"]) > 0:
d["notification"]["room_name"] = ctx["name"]
- defer.returnValue(d)
+ return d
@defer.inlineCallbacks
def dispatch_push(self, event, tweaks, badge):
notification_dict = yield self._build_notification_dict(event, tweaks, badge)
if not notification_dict:
- defer.returnValue([])
+ return []
try:
resp = yield self.http_client.post_json_get_json(
self.url, notification_dict
@@ -364,11 +364,11 @@ class HttpPusher(object):
type(e),
e,
)
- defer.returnValue(False)
+ return False
rejected = []
if "rejected" in resp:
rejected = resp["rejected"]
- defer.returnValue(rejected)
+ return rejected
@defer.inlineCallbacks
def _send_badge(self, badge):
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 521c6e2cd7..4245ce26f3 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -316,7 +316,7 @@ class Mailer(object):
if not merge:
room_vars["notifs"].append(notifvars)
- defer.returnValue(room_vars)
+ return room_vars
@defer.inlineCallbacks
def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
@@ -343,7 +343,7 @@ class Mailer(object):
if messagevars is not None:
ret["messages"].append(messagevars)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_message_vars(self, notif, event, room_state_ids):
@@ -379,7 +379,7 @@ class Mailer(object):
if "body" in event.content:
ret["body_text_plain"] = event.content["body"]
- defer.returnValue(ret)
+ return ret
def add_text_message_vars(self, messagevars, event):
msgformat = event.content.get("format")
@@ -428,19 +428,16 @@ class Mailer(object):
inviter_name = name_from_member_event(inviter_member_event)
if room_name is None:
- defer.returnValue(
- INVITE_FROM_PERSON
- % {"person": inviter_name, "app": self.app_name}
- )
+ return INVITE_FROM_PERSON % {
+ "person": inviter_name,
+ "app": self.app_name,
+ }
else:
- defer.returnValue(
- INVITE_FROM_PERSON_TO_ROOM
- % {
- "person": inviter_name,
- "room": room_name,
- "app": self.app_name,
- }
- )
+ return INVITE_FROM_PERSON_TO_ROOM % {
+ "person": inviter_name,
+ "room": room_name,
+ "app": self.app_name,
+ }
sender_name = None
if len(notifs_by_room[room_id]) == 1:
@@ -454,26 +451,21 @@ class Mailer(object):
sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None:
- defer.returnValue(
- MESSAGE_FROM_PERSON_IN_ROOM
- % {
- "person": sender_name,
- "room": room_name,
- "app": self.app_name,
- }
- )
+ return MESSAGE_FROM_PERSON_IN_ROOM % {
+ "person": sender_name,
+ "room": room_name,
+ "app": self.app_name,
+ }
elif sender_name is not None:
- defer.returnValue(
- MESSAGE_FROM_PERSON
- % {"person": sender_name, "app": self.app_name}
- )
+ return MESSAGE_FROM_PERSON % {
+ "person": sender_name,
+ "app": self.app_name,
+ }
else:
# There's more than one notification for this room, so just
# say there are several
if room_name is not None:
- defer.returnValue(
- MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name}
- )
+ return MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name}
else:
# If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
@@ -493,24 +485,19 @@ class Mailer(object):
]
)
- defer.returnValue(
- MESSAGES_FROM_PERSON
- % {
- "person": descriptor_from_member_events(
- member_events.values()
- ),
- "app": self.app_name,
- }
- )
+ return MESSAGES_FROM_PERSON % {
+ "person": descriptor_from_member_events(member_events.values()),
+ "app": self.app_name,
+ }
else:
# Stuff's happened in multiple different rooms
# ...but we still refer to the 'reason' room which triggered the mail
if reason["room_name"] is not None:
- defer.returnValue(
- MESSAGES_IN_ROOM_AND_OTHERS
- % {"room": reason["room_name"], "app": self.app_name}
- )
+ return MESSAGES_IN_ROOM_AND_OTHERS % {
+ "room": reason["room_name"],
+ "app": self.app_name,
+ }
else:
# If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
@@ -527,13 +514,10 @@ class Mailer(object):
[room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
)
- defer.returnValue(
- MESSAGES_FROM_PERSON_AND_OTHERS
- % {
- "person": descriptor_from_member_events(member_events.values()),
- "app": self.app_name,
- }
- )
+ return MESSAGES_FROM_PERSON_AND_OTHERS % {
+ "person": descriptor_from_member_events(member_events.values()),
+ "app": self.app_name,
+ }
def make_room_link(self, room_id):
if self.hs.config.email_riot_base_url:
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 06056fbf4f..16a7e8e31d 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -55,7 +55,7 @@ def calculate_room_name(
room_state_ids[("m.room.name", "")], allow_none=True
)
if m_room_name and m_room_name.content and m_room_name.content["name"]:
- defer.returnValue(m_room_name.content["name"])
+ return m_room_name.content["name"]
# does it have a canonical alias?
if ("m.room.canonical_alias", "") in room_state_ids:
@@ -68,7 +68,7 @@ def calculate_room_name(
and canon_alias.content["alias"]
and _looks_like_an_alias(canon_alias.content["alias"])
):
- defer.returnValue(canon_alias.content["alias"])
+ return canon_alias.content["alias"]
# at this point we're going to need to search the state by all state keys
# for an event type, so rearrange the data structure
@@ -82,10 +82,10 @@ def calculate_room_name(
if alias_event and alias_event.content.get("aliases"):
the_aliases = alias_event.content["aliases"]
if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
- defer.returnValue(the_aliases[0])
+ return the_aliases[0]
if not fallback_to_members:
- defer.returnValue(None)
+ return None
my_member_event = None
if ("m.room.member", user_id) in room_state_ids:
@@ -104,14 +104,13 @@ def calculate_room_name(
)
if inviter_member_event:
if fallback_to_single_member:
- defer.returnValue(
- "Invite from %s"
- % (name_from_member_event(inviter_member_event),)
+ return "Invite from %s" % (
+ name_from_member_event(inviter_member_event),
)
else:
return
else:
- defer.returnValue("Room Invite")
+ return "Room Invite"
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
@@ -154,17 +153,17 @@ def calculate_room_name(
# return "Inviting %s" % (
# descriptor_from_member_events(third_party_invites)
# )
- defer.returnValue("Inviting email address")
+ return "Inviting email address"
else:
- defer.returnValue(ALL_ALONE)
+ return ALL_ALONE
else:
- defer.returnValue(name_from_member_event(all_members[0]))
+ return name_from_member_event(all_members[0])
else:
- defer.returnValue(ALL_ALONE)
+ return ALL_ALONE
elif len(other_members) == 1 and not fallback_to_single_member:
return
else:
- defer.returnValue(descriptor_from_member_events(other_members))
+ return descriptor_from_member_events(other_members)
def descriptor_from_member_events(member_events):
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index e37269cdb9..a54051a726 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -39,7 +39,7 @@ def get_badge_count(store, user_id):
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
- defer.returnValue(badge)
+ return badge
@defer.inlineCallbacks
@@ -61,4 +61,4 @@ def get_context_for_event(store, state_handler, ev, user_id):
sender_state_event = yield store.get_event(sender_state_event_id)
ctx["sender_display_name"] = name_from_member_event(sender_state_event)
- defer.returnValue(ctx)
+ return ctx
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index df6f670740..08e840fdc2 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -123,7 +123,7 @@ class PusherPool:
)
pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id)
- defer.returnValue(pusher)
+ return pusher
@defer.inlineCallbacks
def remove_pushers_by_app_id_and_pushkey_not_user(
@@ -224,7 +224,7 @@ class PusherPool:
if pusher_dict:
pusher = yield self._start_pusher(pusher_dict)
- defer.returnValue(pusher)
+ return pusher
@defer.inlineCallbacks
def _start_pushers(self):
@@ -293,7 +293,7 @@ class PusherPool:
p.on_started(have_notifs)
- defer.returnValue(p)
+ return p
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id):
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index c6465c0386..195a7a70c8 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -72,6 +72,7 @@ REQUIREMENTS = [
"netaddr>=0.7.18",
"Jinja2>=2.9",
"bleach>=1.4.3",
+ "sdnotify>=0.3",
]
CONDITIONAL_REQUIREMENTS = {
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 43c89e36dd..2e0594e581 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -185,7 +185,7 @@ class ReplicationEndpoint(object):
except RequestSendFailed as e:
raise_from(SynapseError(502, "Failed to talk to master"), e)
- defer.returnValue(result)
+ return result
return send_request
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 61eafbe708..fed4f08820 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -80,7 +80,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
payload = {"events": event_payloads, "backfilled": backfilled}
- defer.returnValue(payload)
+ return payload
@defer.inlineCallbacks
def _handle_request(self, request):
@@ -113,7 +113,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_and_contexts, backfilled
)
- defer.returnValue((200, {}))
+ return (200, {})
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
@@ -156,7 +156,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
result = yield self.registry.on_edu(edu_type, origin, edu_content)
- defer.returnValue((200, result))
+ return (200, result)
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
@@ -204,7 +204,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
result = yield self.registry.on_query(query_type, args)
- defer.returnValue((200, result))
+ return (200, result)
class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
@@ -238,7 +238,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
def _handle_request(self, request, room_id):
yield self.store.clean_room_for_join(room_id)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 7c1197e5dd..f17d3a2da4 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -64,7 +64,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
user_id, device_id, initial_display_name, is_guest
)
- defer.returnValue((200, {"device_id": device_id, "access_token": access_token}))
+ return (200, {"device_id": device_id, "access_token": access_token})
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 2d9cbbaefc..4217335d88 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -83,7 +83,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
remote_room_hosts, room_id, user_id, event_content
)
- defer.returnValue((200, {}))
+ return (200, {})
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@@ -153,7 +153,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
yield self.store.locally_reject_invite(user_id, room_id)
ret = {}
- defer.returnValue((200, ret))
+ return (200, ret)
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 2bf2173895..3341320a87 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -90,7 +90,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
address=content["address"],
)
- defer.returnValue((200, {}))
+ return (200, {})
class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
@@ -143,7 +143,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
bind_msisdn=bind_msisdn,
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 034763fe99..eff7bd7305 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -85,7 +85,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"extra_users": [u.to_string() for u in extra_users],
}
- defer.returnValue(payload)
+ return payload
@defer.inlineCallbacks
def _handle_request(self, request, event_id):
@@ -117,7 +117,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 7ef67a5a73..c10b85d2ff 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -158,7 +158,7 @@ class Stream(object):
updates, current_token = yield self.get_updates_since(self.last_token)
self.last_token = current_token
- defer.returnValue((updates, current_token))
+ return (updates, current_token)
@defer.inlineCallbacks
def get_updates_since(self, from_token):
@@ -172,14 +172,14 @@ class Stream(object):
sent over the replication steam.
"""
if from_token in ("NOW", "now"):
- defer.returnValue(([], self.upto_token))
+ return ([], self.upto_token)
current_token = self.upto_token
from_token = int(from_token)
if from_token == current_token:
- defer.returnValue(([], current_token))
+ return ([], current_token)
if self._LIMITED:
rows = yield self.update_function(
@@ -198,7 +198,7 @@ class Stream(object):
if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
- defer.returnValue((updates, current_token))
+ return (updates, current_token)
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -297,7 +297,7 @@ class PushRulesStream(Stream):
@defer.inlineCallbacks
def update_function(self, from_token, to_token, limit):
rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
- defer.returnValue([(row[0], row[2]) for row in rows])
+ return [(row[0], row[2]) for row in rows]
class PushersStream(Stream):
@@ -424,7 +424,7 @@ class AccountDataStream(Stream):
for stream_id, user_id, account_data_type, content in global_results
)
- defer.returnValue(results)
+ return results
class GroupServerStream(Stream):
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 3d0694bb11..d97669c886 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -134,7 +134,7 @@ class EventsStream(Stream):
all_updates = heapq.merge(event_updates, state_updates)
- defer.returnValue(all_updates)
+ return all_updates
@classmethod
def parse_row(cls, row):
diff --git a/synapse/res/templates/account_renewed.html b/synapse/res/templates/account_renewed.html
new file mode 100644
index 0000000000..894da030af
--- /dev/null
+++ b/synapse/res/templates/account_renewed.html
@@ -0,0 +1 @@
+<html><body>Your account has been successfully renewed.</body><html>
diff --git a/synapse/res/templates/invalid_token.html b/synapse/res/templates/invalid_token.html
new file mode 100644
index 0000000000..6bd2b98364
--- /dev/null
+++ b/synapse/res/templates/invalid_token.html
@@ -0,0 +1 @@
+<html><body>Invalid renewal token.</body><html>
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 1d20b96d03..f161bc51a5 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
keys,
notifications,
openid,
+ password_policy,
read_marker,
receipts,
register,
@@ -117,6 +118,7 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
+ password_policy.register_servlets(hs, client_resource)
relations.register_servlets(hs, client_resource)
# moving to /_synapse/admin
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6888ae5590..0a7d9b81b2 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -84,7 +84,7 @@ class UsersRestServlet(RestServlet):
ret = yield self.handlers.admin_handler.get_users()
- defer.returnValue((200, ret))
+ return (200, ret)
class VersionServlet(RestServlet):
@@ -227,7 +227,7 @@ class UserRegisterServlet(RestServlet):
)
result = yield register._create_registration_details(user_id, body)
- defer.returnValue((200, result))
+ return (200, result)
class WhoisRestServlet(RestServlet):
@@ -252,7 +252,7 @@ class WhoisRestServlet(RestServlet):
ret = yield self.handlers.admin_handler.get_whois(target_user)
- defer.returnValue((200, ret))
+ return (200, ret)
class PurgeMediaCacheRestServlet(RestServlet):
@@ -271,7 +271,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
ret = yield self.media_repository.delete_old_remote_media(before_ts)
- defer.returnValue((200, ret))
+ return (200, ret)
class PurgeHistoryRestServlet(RestServlet):
@@ -356,7 +356,7 @@ class PurgeHistoryRestServlet(RestServlet):
room_id, token, delete_local_events=delete_local_events
)
- defer.returnValue((200, {"purge_id": purge_id}))
+ return (200, {"purge_id": purge_id})
class PurgeHistoryStatusRestServlet(RestServlet):
@@ -381,7 +381,7 @@ class PurgeHistoryStatusRestServlet(RestServlet):
if purge_status is None:
raise NotFoundError("purge id '%s' not found" % purge_id)
- defer.returnValue((200, purge_status.asdict()))
+ return (200, purge_status.asdict())
class DeactivateAccountRestServlet(RestServlet):
@@ -413,7 +413,7 @@ class DeactivateAccountRestServlet(RestServlet):
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
+ return (200, {"id_server_unbind_result": id_server_unbind_result})
class ShutdownRoomRestServlet(RestServlet):
@@ -531,16 +531,14 @@ class ShutdownRoomRestServlet(RestServlet):
room_id, new_room_id, requester_user_id
)
- defer.returnValue(
- (
- 200,
- {
- "kicked_users": kicked_users,
- "failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
- "new_room_id": new_room_id,
- },
- )
+ return (
+ 200,
+ {
+ "kicked_users": kicked_users,
+ "failed_to_kick_users": failed_to_kick_users,
+ "local_aliases": aliases_for_room,
+ "new_room_id": new_room_id,
+ },
)
@@ -564,7 +562,7 @@ class QuarantineMediaInRoom(RestServlet):
room_id, requester.user.to_string()
)
- defer.returnValue((200, {"num_quarantined": num_quarantined}))
+ return (200, {"num_quarantined": num_quarantined})
class ListMediaInRoom(RestServlet):
@@ -585,7 +583,7 @@ class ListMediaInRoom(RestServlet):
local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
- defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+ return (200, {"local": local_mxcs, "remote": remote_mxcs})
class ResetPasswordRestServlet(RestServlet):
@@ -629,7 +627,7 @@ class ResetPasswordRestServlet(RestServlet):
yield self._set_password_handler.set_password(
target_user_id, new_password, requester
)
- defer.returnValue((200, {}))
+ return (200, {})
class GetUsersPaginatedRestServlet(RestServlet):
@@ -671,7 +669,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
- defer.returnValue((200, ret))
+ return (200, ret)
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
@@ -699,7 +697,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
logger.info("limit: %s, start: %s", limit, start)
ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
- defer.returnValue((200, ret))
+ return (200, ret)
class SearchUsersRestServlet(RestServlet):
@@ -742,7 +740,7 @@ class SearchUsersRestServlet(RestServlet):
logger.info("term: %s ", term)
ret = yield self.handlers.admin_handler.search_users(term)
- defer.returnValue((200, ret))
+ return (200, ret)
class DeleteGroupAdminRestServlet(RestServlet):
@@ -765,7 +763,7 @@ class DeleteGroupAdminRestServlet(RestServlet):
raise SynapseError(400, "Can only delete local groups")
yield self.group_server.delete_group(group_id, requester.user.to_string())
- defer.returnValue((200, {}))
+ return (200, {})
class AccountValidityRenewServlet(RestServlet):
@@ -796,7 +794,7 @@ class AccountValidityRenewServlet(RestServlet):
)
res = {"expiration_ts": expiration_ts}
- defer.returnValue((200, res))
+ return (200, res)
########################################################################################
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index d9c71261f2..656526fea5 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -92,7 +92,7 @@ class SendServerNoticeServlet(RestServlet):
event_content=body["content"],
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ return (200, {"event_id": event.event_id})
def on_PUT(self, request, txn_id):
return self.txns.fetch_or_execute_request(
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 57542c2b4b..4284738021 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -54,7 +54,7 @@ class ClientDirectoryServer(RestServlet):
dir_handler = self.handlers.directory_handler
res = yield dir_handler.get_association(room_alias)
- defer.returnValue((200, res))
+ return (200, res)
@defer.inlineCallbacks
def on_PUT(self, request, room_alias):
@@ -87,7 +87,7 @@ class ClientDirectoryServer(RestServlet):
requester, room_alias, room_id, servers
)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_DELETE(self, request, room_alias):
@@ -102,7 +102,7 @@ class ClientDirectoryServer(RestServlet):
service.url,
room_alias.to_string(),
)
- defer.returnValue((200, {}))
+ return (200, {})
except InvalidClientCredentialsError:
# fallback to default user behaviour if they aren't an AS
pass
@@ -118,7 +118,7 @@ class ClientDirectoryServer(RestServlet):
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
)
- defer.returnValue((200, {}))
+ return (200, {})
class ClientDirectoryListServer(RestServlet):
@@ -136,9 +136,7 @@ class ClientDirectoryListServer(RestServlet):
if room is None:
raise NotFoundError("Unknown room")
- defer.returnValue(
- (200, {"visibility": "public" if room["is_public"] else "private"})
- )
+ return (200, {"visibility": "public" if room["is_public"] else "private"})
@defer.inlineCallbacks
def on_PUT(self, request, room_id):
@@ -151,7 +149,7 @@ class ClientDirectoryListServer(RestServlet):
requester, room_id, visibility
)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_DELETE(self, request, room_id):
@@ -161,7 +159,7 @@ class ClientDirectoryListServer(RestServlet):
requester, room_id, "private"
)
- defer.returnValue((200, {}))
+ return (200, {})
class ClientAppserviceDirectoryListServer(RestServlet):
@@ -195,4 +193,4 @@ class ClientAppserviceDirectoryListServer(RestServlet):
requester.app_service.id, network_id, room_id, visibility
)
- defer.returnValue((200, {}))
+ return (200, {})
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index d6de2b7360..53ebed2203 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -67,7 +67,7 @@ class EventStreamRestServlet(RestServlet):
is_guest=is_guest,
)
- defer.returnValue((200, chunk))
+ return (200, chunk)
def on_OPTIONS(self, request):
return (200, {})
@@ -91,9 +91,9 @@ class EventRestServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
event = yield self._event_serializer.serialize_event(event, time_now)
- defer.returnValue((200, event))
+ return (200, event)
else:
- defer.returnValue((404, "Event not found."))
+ return (404, "Event not found.")
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 0fe5f2d79b..70b8478e90 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -42,7 +42,7 @@ class InitialSyncRestServlet(RestServlet):
include_archived=include_archived,
)
- defer.returnValue((200, content))
+ return (200, content)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 0d05945f0a..4ddf194fd7 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -152,7 +152,7 @@ class LoginRestServlet(RestServlet):
well_known_data = self._well_known_builder.get_well_known()
if well_known_data:
result["well_known"] = well_known_data
- defer.returnValue((200, result))
+ return (200, result)
@defer.inlineCallbacks
def _do_other_login(self, login_submission):
@@ -212,7 +212,7 @@ class LoginRestServlet(RestServlet):
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback_3pid
)
- defer.returnValue(result)
+ return result
# No password providers were able to handle this 3pid
# Check local store
@@ -241,7 +241,7 @@ class LoginRestServlet(RestServlet):
result = yield self._register_device_with_callback(
canonical_user_id, login_submission, callback
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _register_device_with_callback(self, user_id, login_submission, callback=None):
@@ -273,7 +273,7 @@ class LoginRestServlet(RestServlet):
if callback is not None:
yield callback(result)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def do_token_login(self, login_submission):
@@ -284,7 +284,7 @@ class LoginRestServlet(RestServlet):
)
result = yield self._register_device_with_callback(user_id, login_submission)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
@@ -321,7 +321,7 @@ class LoginRestServlet(RestServlet):
result = yield self._register_device_with_callback(
registered_user_id, login_submission
)
- defer.returnValue(result)
+ return result
class BaseSSORedirectServlet(RestServlet):
@@ -378,7 +378,7 @@ class CasTicketServlet(RestServlet):
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
- self._http_client = hs.get_simple_http_client()
+ self._http_client = hs.get_proxied_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -395,7 +395,7 @@ class CasTicketServlet(RestServlet):
# even if that's being used old-http style to signal end-of-data
body = pde.response
result = yield self.handle_cas_response(request, body, client_redirect_url)
- defer.returnValue(result)
+ return result
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index cd711be519..2769f3a189 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -49,7 +49,7 @@ class LogoutRestServlet(RestServlet):
requester.user.to_string(), requester.device_id
)
- defer.returnValue((200, {}))
+ return (200, {})
class LogoutAllRestServlet(RestServlet):
@@ -75,7 +75,7 @@ class LogoutAllRestServlet(RestServlet):
# .. and then delete any access tokens which weren't associated with
# devices.
yield self._auth_handler.delete_access_tokens_for_user(user_id)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 3e87f0fdb3..1eb1068c98 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -56,7 +56,7 @@ class PresenceStatusRestServlet(RestServlet):
state = yield self.presence_handler.get_state(target_user=user)
state = format_user_presence_state(state, self.clock.time_msec())
- defer.returnValue((200, state))
+ return (200, state)
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@@ -88,7 +88,7 @@ class PresenceStatusRestServlet(RestServlet):
if self.hs.config.use_presence:
yield self.presence_handler.set_state(user, state)
- defer.returnValue((200, {}))
+ return (200, {})
def on_OPTIONS(self, request):
return (200, {})
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 4d8ab1f47e..d09eb16fb0 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,12 +14,16 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
+import logging
+
from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
+logger = logging.getLogger(__name__)
+
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
@@ -28,6 +32,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
super(ProfileDisplaynameRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
@defer.inlineCallbacks
@@ -48,7 +53,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
if displayname is not None:
ret["displayname"] = displayname
- defer.returnValue((200, ret))
+ return (200, ret)
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@@ -61,15 +66,31 @@ class ProfileDisplaynameRestServlet(RestServlet):
try:
new_name = content["displayname"]
except Exception:
- defer.returnValue((400, "Unable to parse name"))
+ return (400, "Unable to parse name")
yield self.profile_handler.set_displayname(user, requester, new_name, is_admin)
- defer.returnValue((200, {}))
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_displayname(shadow_user.to_string(), content)
+
+ return (200, {})
def on_OPTIONS(self, request, user_id):
return (200, {})
+ @defer.inlineCallbacks
+ def shadow_displayname(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileAvatarURLRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
@@ -78,6 +99,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
super(ProfileAvatarURLRestServlet, self).__init__()
self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
@defer.inlineCallbacks
@@ -98,7 +120,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
if avatar_url is not None:
ret["avatar_url"] = avatar_url
- defer.returnValue((200, ret))
+ return (200, ret)
@defer.inlineCallbacks
def on_PUT(self, request, user_id):
@@ -108,17 +130,35 @@ class ProfileAvatarURLRestServlet(RestServlet):
content = parse_json_object_from_request(request)
try:
- new_name = content["avatar_url"]
+ new_avatar_url = content["avatar_url"]
except Exception:
- defer.returnValue((400, "Unable to parse name"))
+ return (400, "Unable to parse name")
+
+ yield self.profile_handler.set_avatar_url(
+ user, requester, new_avatar_url, is_admin
+ )
- yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
+ self.shadow_avatar_url(shadow_user.to_string(), content)
- defer.returnValue((200, {}))
+ return (200, {})
def on_OPTIONS(self, request, user_id):
return (200, {})
+ @defer.inlineCallbacks
+ def shadow_avatar_url(self, user_id, body):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.put_json(
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
+ )
+
class ProfileRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
@@ -150,7 +190,7 @@ class ProfileRestServlet(RestServlet):
if avatar_url is not None:
ret["avatar_url"] = avatar_url
- defer.returnValue((200, ret))
+ return (200, ret)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index e635efb420..c3ae8b98a8 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -69,7 +69,7 @@ class PushRuleRestServlet(RestServlet):
if "attr" in spec:
yield self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
- defer.returnValue((200, {}))
+ return (200, {})
if spec["rule_id"].startswith("."):
# Rule ids starting with '.' are reserved for server default rules.
@@ -106,7 +106,7 @@ class PushRuleRestServlet(RestServlet):
except RuleNotFoundException as e:
raise SynapseError(400, str(e))
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_DELETE(self, request, path):
@@ -123,7 +123,7 @@ class PushRuleRestServlet(RestServlet):
try:
yield self.store.delete_push_rule(user_id, namespaced_rule_id)
self.notify_user(user_id)
- defer.returnValue((200, {}))
+ return (200, {})
except StoreError as e:
if e.code == 404:
raise NotFoundError()
@@ -151,10 +151,10 @@ class PushRuleRestServlet(RestServlet):
)
if path[0] == "":
- defer.returnValue((200, rules))
+ return (200, rules)
elif path[0] == "global":
result = _filter_ruleset_with_path(rules["global"], path[1:])
- defer.returnValue((200, result))
+ return (200, result)
else:
raise UnrecognizedRequestError()
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index e9246018df..ebc3dec516 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -62,7 +62,7 @@ class PushersRestServlet(RestServlet):
if k not in allowed_keys:
del p[k]
- defer.returnValue((200, {"pushers": pushers}))
+ return (200, {"pushers": pushers})
def on_OPTIONS(self, _):
return 200, {}
@@ -94,7 +94,7 @@ class PushersSetRestServlet(RestServlet):
yield self.pusher_pool.remove_pusher(
content["app_id"], content["pushkey"], user_id=user.to_string()
)
- defer.returnValue((200, {}))
+ return (200, {})
assert_params_in_dict(
content,
@@ -143,7 +143,7 @@ class PushersSetRestServlet(RestServlet):
self.notifier.on_new_replication_data()
- defer.returnValue((200, {}))
+ return (200, {})
def on_OPTIONS(self, _):
return 200, {}
@@ -190,7 +190,7 @@ class PushersRemoveRestServlet(RestServlet):
)
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
finish_request(request)
- defer.returnValue(None)
+ return None
def on_OPTIONS(self, _):
return 200, {}
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 6276e97f89..d42717effd 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -91,7 +91,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
requester, self.get_room_config(request)
)
- defer.returnValue((200, info))
+ return (200, info)
def get_room_config(self, request):
user_supplied_config = parse_json_object_from_request(request)
@@ -173,9 +173,9 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if format == "event":
event = format_event_for_client_v2(data.get_dict())
- defer.returnValue((200, event))
+ return (200, event)
elif format == "content":
- defer.returnValue((200, data.get_dict()["content"]))
+ return (200, data.get_dict()["content"])
@defer.inlineCallbacks
def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
@@ -210,7 +210,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
ret = {}
if event:
ret = {"event_id": event.event_id}
- defer.returnValue((200, ret))
+ return (200, ret)
# TODO: Needs unit testing for generic events + feedback
@@ -244,7 +244,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
requester, event_dict, txn_id=txn_id
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ return (200, {"event_id": event.event_id})
def on_GET(self, request, room_id, event_type, txn_id):
return (200, "Not implemented")
@@ -307,7 +307,7 @@ class JoinRoomAliasServlet(TransactionRestServlet):
third_party_signed=content.get("third_party_signed", None),
)
- defer.returnValue((200, {"room_id": room_id}))
+ return (200, {"room_id": room_id})
def on_PUT(self, request, room_identifier, txn_id):
return self.txns.fetch_or_execute_request(
@@ -360,7 +360,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit=limit, since_token=since_token
)
- defer.returnValue((200, data))
+ return (200, data)
@defer.inlineCallbacks
def on_POST(self, request):
@@ -405,7 +405,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
network_tuple=network_tuple,
)
- defer.returnValue((200, data))
+ return (200, data)
# TODO: Needs unit testing
@@ -456,7 +456,7 @@ class RoomMemberListRestServlet(RestServlet):
continue
chunk.append(event)
- defer.returnValue((200, {"chunk": chunk}))
+ return (200, {"chunk": chunk})
# deprecated in favour of /members?membership=join?
@@ -477,7 +477,7 @@ class JoinedRoomMemberListRestServlet(RestServlet):
requester, room_id
)
- defer.returnValue((200, {"joined": users_with_profile}))
+ return (200, {"joined": users_with_profile})
# TODO: Needs better unit testing
@@ -510,7 +510,7 @@ class RoomMessageListRestServlet(RestServlet):
event_filter=event_filter,
)
- defer.returnValue((200, msgs))
+ return (200, msgs)
# TODO: Needs unit testing
@@ -531,7 +531,7 @@ class RoomStateRestServlet(RestServlet):
user_id=requester.user.to_string(),
is_guest=requester.is_guest,
)
- defer.returnValue((200, events))
+ return (200, events)
# TODO: Needs unit testing
@@ -550,7 +550,7 @@ class RoomInitialSyncRestServlet(RestServlet):
content = yield self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
- defer.returnValue((200, content))
+ return (200, content)
class RoomEventServlet(RestServlet):
@@ -573,9 +573,9 @@ class RoomEventServlet(RestServlet):
time_now = self.clock.time_msec()
if event:
event = yield self._event_serializer.serialize_event(event, time_now)
- defer.returnValue((200, event))
+ return (200, event)
else:
- defer.returnValue((404, "Event not found."))
+ return (404, "Event not found.")
class RoomEventContextServlet(RestServlet):
@@ -625,7 +625,7 @@ class RoomEventContextServlet(RestServlet):
results["state"], time_now
)
- defer.returnValue((200, results))
+ return (200, results)
class RoomForgetRestServlet(TransactionRestServlet):
@@ -644,7 +644,7 @@ class RoomForgetRestServlet(TransactionRestServlet):
yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
- defer.returnValue((200, {}))
+ return (200, {})
def on_PUT(self, request, room_id, txn_id):
return self.txns.fetch_or_execute_request(
@@ -693,8 +693,9 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
+ new_room=False,
)
- defer.returnValue((200, {}))
+ return (200, {})
return
target = requester.user
@@ -721,7 +722,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
if membership_action == "join":
return_value["room_id"] = room_id
- defer.returnValue((200, return_value))
+ return (200, return_value)
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address"}:
@@ -763,7 +764,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
txn_id=txn_id,
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ return (200, {"event_id": event.event_id})
def on_PUT(self, request, room_id, event_id, txn_id):
return self.txns.fetch_or_execute_request(
@@ -808,7 +809,7 @@ class RoomTypingRestServlet(RestServlet):
target_user=target_user, auth_user=requester.user, room_id=room_id
)
- defer.returnValue((200, {}))
+ return (200, {})
class SearchRestServlet(RestServlet):
@@ -830,7 +831,7 @@ class SearchRestServlet(RestServlet):
requester.user, content, batch
)
- defer.returnValue((200, results))
+ return (200, results)
class JoinedRoomsRestServlet(RestServlet):
@@ -846,7 +847,7 @@ class JoinedRoomsRestServlet(RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
- defer.returnValue((200, {"joined_rooms": list(room_ids)}))
+ return (200, {"joined_rooms": list(room_ids)})
def register_txn_path(servlet, regex_string, http_server, with_get=False):
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 41b3171ac8..497cddf8b8 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -60,18 +60,16 @@ class VoipRestServlet(RestServlet):
password = turnPassword
else:
- defer.returnValue((200, {}))
-
- defer.returnValue(
- (
- 200,
- {
- "username": username,
- "password": password,
- "ttl": userLifetime / 1000,
- "uris": turnUris,
- },
- )
+ return (200, {})
+
+ return (
+ 200,
+ {
+ "username": username,
+ "password": password,
+ "ttl": userLifetime / 1000,
+ "uris": turnUris,
+ },
)
def on_OPTIONS(self, request):
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index f143d8b85c..2580e2bc63 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import re
from six.moves import http_client
@@ -31,8 +32,9 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
-from synapse.util.stringutils import random_string
+from synapse.util.stringutils import assert_valid_client_secret, random_string
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -82,6 +84,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Extract params from body
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
+
email = body["email"]
send_attempt = body["send_attempt"]
next_link = body.get("next_link") # Optional param
@@ -89,7 +93,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
if not check_3pid_allowed(self.hs, "email", email):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
@@ -117,7 +121,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Wrap the session id in a JSON object
ret = {"sid": sid}
- defer.returnValue((200, ret))
+ return (200, ret)
@defer.inlineCallbacks
def send_password_reset(self, email, client_secret, send_attempt, next_link=None):
@@ -149,7 +153,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Check that the send_attempt is higher than previous attempts
if send_attempt <= last_send_attempt:
# If not, just return a success without sending an email
- defer.returnValue(session_id)
+ return session_id
else:
# An non-validated session does not exist yet.
# Generate a session id
@@ -185,7 +189,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
token_expires,
)
- defer.returnValue(session_id)
+ return session_id
class MsisdnPasswordRequestTokenRestServlet(RestServlet):
@@ -208,20 +212,22 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn)
if existingUid is None:
raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
ret = yield self.identity_handler.requestMsisdnToken(**body)
- defer.returnValue((200, ret))
+ return (200, ret)
class PasswordResetSubmitTokenServlet(RestServlet):
@@ -260,6 +266,9 @@ class PasswordResetSubmitTokenServlet(RestServlet):
sid = parse_string(request, "sid")
client_secret = parse_string(request, "client_secret")
+
+ assert_valid_client_secret(client_secret)
+
token = parse_string(request, "token")
# Attempt to validate a 3PID sesssion
@@ -279,7 +288,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
request.setResponseCode(302)
request.setHeader("Location", next_link)
finish_request(request)
- defer.returnValue(None)
+ return None
# Otherwise show the success template
html = self.config.email_password_reset_success_html_content
@@ -295,7 +304,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
request.write(html.encode("utf-8"))
finish_request(request)
- defer.returnValue(None)
+ return None
def load_jinja2_template(self, template_dir, template_filename, template_vars):
"""Loads a jinja2 template with variables to insert
@@ -325,12 +334,14 @@ class PasswordResetSubmitTokenServlet(RestServlet):
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["sid", "client_secret", "token"])
- valid, _ = yield self.datastore.validate_threepid_validation_token(
+ assert_valid_client_secret(body["client_secret"])
+
+ valid, _ = yield self.datastore.validate_threepid_session(
body["sid"], body["client_secret"], body["token"], self.clock.time_msec()
)
response_code = 200 if valid else 400
- defer.returnValue((response_code, {"success": valid}))
+ return (response_code, {"success": valid})
class PasswordRestServlet(RestServlet):
@@ -343,6 +354,7 @@ class PasswordRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
self._set_password_handler = hs.get_set_password_handler()
+ self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
@defer.inlineCallbacks
@@ -361,9 +373,13 @@ class PasswordRestServlet(RestServlet):
if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
- params = yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request)
- )
+ # blindly trust ASes without UI-authing them
+ if requester.app_service:
+ params = body
+ else:
+ params = yield self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request)
+ )
user_id = requester.user.to_string()
else:
requester = None
@@ -399,11 +415,29 @@ class PasswordRestServlet(RestServlet):
yield self._set_password_handler.set_password(user_id, new_password, requester)
- defer.returnValue((200, {}))
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_password(params, shadow_user.to_string())
+
+ return (200, {})
def on_OPTIONS(self, _):
return 200, {}
+ @defer.inlineCallbacks
+ def shadow_password(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = client_patterns("/account/deactivate$")
@@ -434,7 +468,7 @@ class DeactivateAccountRestServlet(RestServlet):
yield self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase
)
- defer.returnValue((200, {}))
+ return (200, {})
yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
@@ -447,7 +481,7 @@ class DeactivateAccountRestServlet(RestServlet):
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
+ return (200, {"id_server_unbind_result": id_server_unbind_result})
class EmailThreepidRequestTokenRestServlet(RestServlet):
@@ -466,13 +500,15 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
body, ["id_server", "client_secret", "email", "send_attempt"]
)
- if not check_3pid_allowed(self.hs, "email", body["email"]):
+ if not (yield check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
- "Your email domain is not authorized on this server",
+ "Your email is not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existingUid = yield self.datastore.get_user_id_by_threepid(
"email", body["email"]
)
@@ -481,7 +517,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
- defer.returnValue((200, ret))
+ return (200, ret)
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@@ -503,20 +539,22 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
+ assert_valid_client_secret(body["client_secret"])
+
existingUid = yield self.datastore.get_user_id_by_threepid("msisdn", msisdn)
if existingUid is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestMsisdnToken(**body)
- defer.returnValue((200, ret))
+ return (200, ret)
class ThreepidRestServlet(RestServlet):
@@ -528,7 +566,8 @@ class ThreepidRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- self.datastore = self.hs.get_datastore()
+ self.datastore = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -536,39 +575,70 @@ class ThreepidRestServlet(RestServlet):
threepids = yield self.datastore.user_get_threepids(requester.user.to_string())
- defer.returnValue((200, {"threepids": threepids}))
+ return (200, {"threepids": threepids})
@defer.inlineCallbacks
def on_POST(self, request):
- body = parse_json_object_from_request(request)
+ if self.hs.config.disable_3pid_changes:
+ raise SynapseError(400, "3PID changes disabled on this server")
- threePidCreds = body.get("threePidCreds")
- threePidCreds = body.get("three_pid_creds", threePidCreds)
- if threePidCreds is None:
- raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
+ body = parse_json_object_from_request(request)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
+ # skip validation if this is a shadow 3PID from an AS
+ if not requester.app_service:
+ threePidCreds = body.get("threePidCreds")
+ threePidCreds = body.get("three_pid_creds", threePidCreds)
+ if threePidCreds is None:
+ raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
- if not threepid:
- raise SynapseError(400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED)
+ threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
- for reqd in ["medium", "address", "validated_at"]:
- if reqd not in threepid:
- logger.warn("Couldn't add 3pid: invalid response from ID server")
- raise SynapseError(500, "Invalid response from ID Server")
+ if not threepid:
+ raise SynapseError(
+ 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
+ )
+
+ for reqd in ["medium", "address", "validated_at"]:
+ if reqd not in threepid:
+ logger.warn("Couldn't add 3pid: invalid response from ID server")
+ raise SynapseError(500, "Invalid response from ID Server")
+ else:
+ # XXX: ASes pass in a validated threepid directly to bypass the IS.
+ # This makes the API entirely change shape when we have an AS token;
+ # it really should be an entirely separate API - perhaps
+ # /account/3pid/replicate or something.
+ threepid = body.get("threepid")
yield self.auth_handler.add_threepid(
user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
)
- if "bind" in body and body["bind"]:
+ if not requester.app_service and ("bind" in body and body["bind"]):
logger.debug("Binding threepid %s to %s", threepid, user_id)
yield self.identity_handler.bind_threepid(threePidCreds, user_id)
- defer.returnValue((200, {}))
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ return (200, {})
+
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
class ThreepidDeleteRestServlet(RestServlet):
@@ -576,11 +646,16 @@ class ThreepidDeleteRestServlet(RestServlet):
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
+ self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks
def on_POST(self, request):
+ if self.hs.config.disable_3pid_changes:
+ raise SynapseError(400, "3PID changes disabled on this server")
+
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
@@ -598,12 +673,89 @@ class ThreepidDeleteRestServlet(RestServlet):
logger.exception("Failed to remove threepid")
raise SynapseError(500, "Failed to remove threepid")
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid_delete(body, shadow_user.to_string())
+
if ret:
id_server_unbind_result = "success"
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {"id_server_unbind_result": id_server_unbind_result}))
+ return (200, {"id_server_unbind_result": id_server_unbind_result})
+
+ @defer.inlineCallbacks
+ def shadow_3pid_delete(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
+
+class ThreepidLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_GET(self, request):
+ """Proxy a /_matrix/identity/api/v1/lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ # Verify query parameters
+ query_params = request.args
+ assert_params_in_dict(query_params, [b"medium", b"address", b"id_server"])
+
+ # Retrieve needed information from query parameters
+ medium = parse_string(request, "medium")
+ address = parse_string(request, "address")
+ id_server = parse_string(request, "id_server")
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.lookup_3pid(id_server, medium, address)
+
+ defer.returnValue((200, ret))
+
+
+class ThreepidBulkLookupRestServlet(RestServlet):
+ PATTERNS = [re.compile("^/_matrix/client/unstable/account/3pid/bulk_lookup$")]
+
+ def __init__(self, hs):
+ super(ThreepidBulkLookupRestServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
+ server
+ """
+ yield self.auth.get_user_by_req(request)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["threepids", "id_server"])
+
+ # Proxy the request to the identity server. lookup_3pid handles checking
+ # if the lookup is allowed so we don't need to do it here.
+ ret = yield self.identity_handler.bulk_lookup_3pid(
+ body["id_server"], body["threepids"]
+ )
+
+ defer.returnValue((200, ret))
class WhoamiRestServlet(RestServlet):
@@ -617,7 +769,7 @@ class WhoamiRestServlet(RestServlet):
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
- defer.returnValue((200, {"user_id": requester.user.to_string()}))
+ return (200, {"user_id": requester.user.to_string()})
def register_servlets(hs, http_server):
@@ -630,4 +782,6 @@ def register_servlets(hs, http_server):
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
+ ThreepidLookupRestServlet(hs).register(http_server)
+ ThreepidBulkLookupRestServlet(hs).register(http_server)
WhoamiRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index f155c26259..be1360eca3 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -40,6 +41,7 @@ class AccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
+ self._profile_handler = hs.get_profile_handler()
@defer.inlineCallbacks
def on_PUT(self, request, user_id, account_data_type):
@@ -49,13 +51,18 @@ class AccountDataServlet(RestServlet):
body = parse_json_object_from_request(request)
+ if account_data_type == "im.vector.hide_profile":
+ user = UserID.from_string(user_id)
+ hide_profile = body.get("hide_profile")
+ yield self._profile_handler.set_active(user, not hide_profile, True)
+
max_id = yield self.store.add_account_data_for_user(
user_id, account_data_type, body
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_GET(self, request, user_id, account_data_type):
@@ -70,7 +77,7 @@ class AccountDataServlet(RestServlet):
if event is None:
raise NotFoundError("Account data not found")
- defer.returnValue((200, event))
+ return (200, event)
class RoomAccountDataServlet(RestServlet):
@@ -112,7 +119,7 @@ class RoomAccountDataServlet(RestServlet):
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_GET(self, request, user_id, room_id, account_data_type):
@@ -127,7 +134,7 @@ class RoomAccountDataServlet(RestServlet):
if event is None:
raise NotFoundError("Room account data not found")
- defer.returnValue((200, event))
+ return (200, event)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index d29c10b83d..327df56acc 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -15,6 +15,8 @@
import logging
+from six import ensure_binary
+
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
@@ -42,6 +44,8 @@ class AccountValidityRenewServlet(RestServlet):
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
+ self.success_html = hs.config.account_validity.account_renewed_html_content
+ self.failure_html = hs.config.account_validity.invalid_token_html_content
@defer.inlineCallbacks
def on_GET(self, request):
@@ -49,16 +53,23 @@ class AccountValidityRenewServlet(RestServlet):
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
- yield self.account_activity_handler.renew_account(renewal_token.decode("utf8"))
+ token_valid = yield self.account_activity_handler.renew_account(
+ renewal_token.decode("utf8")
+ )
+
+ if token_valid:
+ status_code = 200
+ response = self.success_html
+ else:
+ status_code = 404
+ response = self.failure_html
- request.setResponseCode(200)
+ request.setResponseCode(status_code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(
- b"Content-Length", b"%d" % (len(AccountValidityRenewServlet.SUCCESS_HTML),)
- )
- request.write(AccountValidityRenewServlet.SUCCESS_HTML)
+ request.setHeader(b"Content-Length", b"%d" % (len(response),))
+ request.write(ensure_binary(response))
finish_request(request)
- defer.returnValue(None)
+ return None
class AccountValiditySendMailServlet(RestServlet):
@@ -87,7 +98,7 @@ class AccountValiditySendMailServlet(RestServlet):
user_id = requester.user.to_string()
yield self.account_activity_handler.send_renewal_email_to_user(user_id)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index bebc2951e7..f21aff39e5 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -207,7 +207,7 @@ class AuthRestServlet(RestServlet):
request.write(html_bytes)
finish_request(request)
- defer.returnValue(None)
+ return None
elif stagetype == LoginType.TERMS:
if ("session" not in request.args or len(request.args["session"])) == 0:
raise SynapseError(400, "No session supplied")
@@ -239,7 +239,7 @@ class AuthRestServlet(RestServlet):
request.write(html_bytes)
finish_request(request)
- defer.returnValue(None)
+ return None
else:
raise SynapseError(404, "Unknown auth stage type")
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index fc7e2f4dd5..a4fa45fe11 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -58,7 +58,7 @@ class CapabilitiesRestServlet(RestServlet):
"m.change_password": {"enabled": change_password},
}
}
- defer.returnValue((200, response))
+ return (200, response)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index d279229d74..9adf76cc0c 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -48,7 +48,7 @@ class DevicesRestServlet(RestServlet):
devices = yield self.device_handler.get_devices_by_user(
requester.user.to_string()
)
- defer.returnValue((200, {"devices": devices}))
+ return (200, {"devices": devices})
class DeleteDevicesRestServlet(RestServlet):
@@ -91,7 +91,7 @@ class DeleteDevicesRestServlet(RestServlet):
yield self.device_handler.delete_devices(
requester.user.to_string(), body["devices"]
)
- defer.returnValue((200, {}))
+ return (200, {})
class DeviceRestServlet(RestServlet):
@@ -114,7 +114,7 @@ class DeviceRestServlet(RestServlet):
device = yield self.device_handler.get_device(
requester.user.to_string(), device_id
)
- defer.returnValue((200, device))
+ return (200, device)
@interactive_auth_handler
@defer.inlineCallbacks
@@ -137,7 +137,7 @@ class DeviceRestServlet(RestServlet):
)
yield self.device_handler.delete_device(requester.user.to_string(), device_id)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_PUT(self, request, device_id):
@@ -147,7 +147,7 @@ class DeviceRestServlet(RestServlet):
yield self.device_handler.update_device(
requester.user.to_string(), device_id, body
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 3f0adf4a21..22be0ee3c5 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -56,7 +56,7 @@ class GetFilterRestServlet(RestServlet):
user_localpart=target_user.localpart, filter_id=filter_id
)
- defer.returnValue((200, filter.get_filter_json()))
+ return (200, filter.get_filter_json())
except (KeyError, StoreError):
raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
@@ -89,7 +89,7 @@ class CreateFilterRestServlet(RestServlet):
user_localpart=target_user.localpart, user_filter=content
)
- defer.returnValue((200, {"filter_id": str(filter_id)}))
+ return (200, {"filter_id": str(filter_id)})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index a312dd2593..e629c4256d 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -47,7 +47,7 @@ class GroupServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, group_description))
+ return (200, group_description)
@defer.inlineCallbacks
def on_POST(self, request, group_id):
@@ -59,7 +59,7 @@ class GroupServlet(RestServlet):
group_id, requester_user_id, content
)
- defer.returnValue((200, {}))
+ return (200, {})
class GroupSummaryServlet(RestServlet):
@@ -83,7 +83,7 @@ class GroupSummaryServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, get_group_summary))
+ return (200, get_group_summary)
class GroupSummaryRoomsCatServlet(RestServlet):
@@ -120,7 +120,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
content=content,
)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id, room_id):
@@ -131,7 +131,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
- defer.returnValue((200, resp))
+ return (200, resp)
class GroupCategoryServlet(RestServlet):
@@ -157,7 +157,7 @@ class GroupCategoryServlet(RestServlet):
group_id, requester_user_id, category_id=category_id
)
- defer.returnValue((200, category))
+ return (200, category)
@defer.inlineCallbacks
def on_PUT(self, request, group_id, category_id):
@@ -169,7 +169,7 @@ class GroupCategoryServlet(RestServlet):
group_id, requester_user_id, category_id=category_id, content=content
)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, category_id):
@@ -180,7 +180,7 @@ class GroupCategoryServlet(RestServlet):
group_id, requester_user_id, category_id=category_id
)
- defer.returnValue((200, resp))
+ return (200, resp)
class GroupCategoriesServlet(RestServlet):
@@ -204,7 +204,7 @@ class GroupCategoriesServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, category))
+ return (200, category)
class GroupRoleServlet(RestServlet):
@@ -228,7 +228,7 @@ class GroupRoleServlet(RestServlet):
group_id, requester_user_id, role_id=role_id
)
- defer.returnValue((200, category))
+ return (200, category)
@defer.inlineCallbacks
def on_PUT(self, request, group_id, role_id):
@@ -240,7 +240,7 @@ class GroupRoleServlet(RestServlet):
group_id, requester_user_id, role_id=role_id, content=content
)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, role_id):
@@ -251,7 +251,7 @@ class GroupRoleServlet(RestServlet):
group_id, requester_user_id, role_id=role_id
)
- defer.returnValue((200, resp))
+ return (200, resp)
class GroupRolesServlet(RestServlet):
@@ -275,7 +275,7 @@ class GroupRolesServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, category))
+ return (200, category)
class GroupSummaryUsersRoleServlet(RestServlet):
@@ -312,7 +312,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
content=content,
)
- defer.returnValue((200, resp))
+ return (200, resp)
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, role_id, user_id):
@@ -323,7 +323,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
- defer.returnValue((200, resp))
+ return (200, resp)
class GroupRoomServlet(RestServlet):
@@ -347,7 +347,7 @@ class GroupRoomServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupUsersServlet(RestServlet):
@@ -371,7 +371,7 @@ class GroupUsersServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupInvitedUsersServlet(RestServlet):
@@ -395,7 +395,7 @@ class GroupInvitedUsersServlet(RestServlet):
group_id, requester_user_id
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupSettingJoinPolicyServlet(RestServlet):
@@ -420,7 +420,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupCreateServlet(RestServlet):
@@ -450,7 +450,7 @@ class GroupCreateServlet(RestServlet):
group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupAdminRoomsServlet(RestServlet):
@@ -477,7 +477,7 @@ class GroupAdminRoomsServlet(RestServlet):
group_id, requester_user_id, room_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
@defer.inlineCallbacks
def on_DELETE(self, request, group_id, room_id):
@@ -488,7 +488,7 @@ class GroupAdminRoomsServlet(RestServlet):
group_id, requester_user_id, room_id
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupAdminRoomsConfigServlet(RestServlet):
@@ -516,7 +516,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
group_id, requester_user_id, room_id, config_key, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupAdminUsersInviteServlet(RestServlet):
@@ -546,7 +546,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
group_id, user_id, requester_user_id, config
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupAdminUsersKickServlet(RestServlet):
@@ -573,7 +573,7 @@ class GroupAdminUsersKickServlet(RestServlet):
group_id, user_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupSelfLeaveServlet(RestServlet):
@@ -598,7 +598,7 @@ class GroupSelfLeaveServlet(RestServlet):
group_id, requester_user_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupSelfJoinServlet(RestServlet):
@@ -623,7 +623,7 @@ class GroupSelfJoinServlet(RestServlet):
group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupSelfAcceptInviteServlet(RestServlet):
@@ -648,7 +648,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return (200, result)
class GroupSelfUpdatePublicityServlet(RestServlet):
@@ -672,7 +672,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
publicise = content["publicise"]
yield self.store.update_group_publicity(group_id, requester_user_id, publicise)
- defer.returnValue((200, {}))
+ return (200, {})
class PublicisedGroupsForUserServlet(RestServlet):
@@ -694,7 +694,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
result = yield self.groups_handler.get_publicised_groups_for_user(user_id)
- defer.returnValue((200, result))
+ return (200, result)
class PublicisedGroupsForUsersServlet(RestServlet):
@@ -719,7 +719,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
result = yield self.groups_handler.bulk_get_publicised_groups(user_ids)
- defer.returnValue((200, result))
+ return (200, result)
class GroupsForUserServlet(RestServlet):
@@ -741,7 +741,7 @@ class GroupsForUserServlet(RestServlet):
result = yield self.groups_handler.get_joined_groups(requester_user_id)
- defer.returnValue((200, result))
+ return (200, result)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 45c9928b65..6008adec7c 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -95,7 +95,7 @@ class KeyUploadServlet(RestServlet):
result = yield self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body
)
- defer.returnValue((200, result))
+ return (200, result)
class KeyQueryServlet(RestServlet):
@@ -149,7 +149,7 @@ class KeyQueryServlet(RestServlet):
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.query_devices(body, timeout)
- defer.returnValue((200, result))
+ return (200, result)
class KeyChangesServlet(RestServlet):
@@ -189,7 +189,7 @@ class KeyChangesServlet(RestServlet):
results = yield self.device_handler.get_user_ids_changed(user_id, from_token)
- defer.returnValue((200, results))
+ return (200, results)
class OneTimeKeyServlet(RestServlet):
@@ -224,7 +224,7 @@ class OneTimeKeyServlet(RestServlet):
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout)
- defer.returnValue((200, result))
+ return (200, result)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index 728a52328f..d034863a3c 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -88,9 +88,7 @@ class NotificationsServlet(RestServlet):
returned_push_actions.append(returned_pa)
next_token = str(pa["stream_ordering"])
- defer.returnValue(
- (200, {"notifications": returned_push_actions, "next_token": next_token})
- )
+ return (200, {"notifications": returned_push_actions, "next_token": next_token})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index b1b5385b09..b4925c0f59 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -83,16 +83,14 @@ class IdTokenServlet(RestServlet):
yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
- defer.returnValue(
- (
- 200,
- {
- "access_token": token,
- "token_type": "Bearer",
- "matrix_server_name": self.server_name,
- "expires_in": self.EXPIRES_MS / 1000,
- },
- )
+ return (
+ 200,
+ {
+ "access_token": token,
+ "token_type": "Bearer",
+ "matrix_server_name": self.server_name,
+ "expires_in": self.EXPIRES_MS / 1000,
+ },
)
diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
new file mode 100644
index 0000000000..968403cca4
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.http.servlet import RestServlet
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyServlet(RestServlet):
+ PATTERNS = client_patterns("/password_policy$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(PasswordPolicyServlet, self).__init__()
+
+ self.policy = hs.config.password_policy
+ self.enabled = hs.config.password_policy_enabled
+
+ def on_GET(self, request):
+ if not self.enabled or not self.policy:
+ return (200, {})
+
+ policy = {}
+
+ for param in [
+ "minimum_length",
+ "require_digit",
+ "require_symbol",
+ "require_lowercase",
+ "require_uppercase",
+ ]:
+ if param in self.policy:
+ policy["m.%s" % param] = self.policy[param]
+
+ return (200, policy)
+
+
+def register_servlets(hs, http_server):
+ PasswordPolicyServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index e75664279b..d93d6a9f24 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -59,7 +59,7 @@ class ReadMarkerRestServlet(RestServlet):
event_id=read_marker_event_id,
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 488905626a..98a97b7059 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -52,7 +52,7 @@ class ReceiptRestServlet(RestServlet):
room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index f327999e59..c4ab2e53cf 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015 - 2016 OpenMarket Ltd
-# Copyright 2017 Vector Creations Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
import hmac
import logging
+import re
from hashlib import sha1
from six import string_types
@@ -41,6 +43,7 @@ from synapse.http.servlet import (
)
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
+from synapse.util.stringutils import assert_valid_client_secret
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -80,13 +83,15 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
body, ["id_server", "client_secret", "email", "send_attempt"]
)
- if not check_3pid_allowed(self.hs, "email", body["email"]):
+ if not (yield check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
- "Your email domain is not authorized to register on this server",
+ "Your email is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
+ assert_params_in_dict(body["client_secret"])
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
"email", body["email"]
)
@@ -95,7 +100,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
ret = yield self.identity_handler.requestEmailToken(**body)
- defer.returnValue((200, ret))
+ return (200, ret)
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
@@ -121,7 +126,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body["country"], body["phone_number"])
- if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ assert_valid_client_secret(body["client_secret"])
+
+ if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
@@ -138,7 +145,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
)
ret = yield self.identity_handler.requestMsisdnToken(**body)
- defer.returnValue((200, ret))
+ return (200, ret)
class UsernameAvailabilityRestServlet(RestServlet):
@@ -178,7 +185,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
yield self.registration_handler.check_username(username)
- defer.returnValue((200, {"available": True}))
+ return (200, {"available": True})
class RegisterRestServlet(RestServlet):
@@ -200,6 +207,7 @@ class RegisterRestServlet(RestServlet):
self.room_member_handler = hs.get_room_member_handler()
self.macaroon_gen = hs.get_macaroon_generator()
self.ratelimiter = hs.get_registration_ratelimiter()
+ self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
@interactive_auth_handler
@@ -230,7 +238,7 @@ class RegisterRestServlet(RestServlet):
if kind == b"guest":
ret = yield self._do_guest_registration(body, address=client_addr)
- defer.returnValue(ret)
+ return ret
return
elif kind != b"user":
raise UnrecognizedRequestError(
@@ -246,6 +254,7 @@ class RegisterRestServlet(RestServlet):
or len(body["password"]) > 512
):
raise SynapseError(400, "Invalid password")
+ self.password_policy_handler.validate_password(body["password"])
desired_password = body["password"]
desired_username = None
@@ -257,6 +266,8 @@ class RegisterRestServlet(RestServlet):
raise SynapseError(400, "Invalid username")
desired_username = body["username"]
+ desired_display_name = body.get("display_name")
+
appservice = None
if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
@@ -280,9 +291,13 @@ class RegisterRestServlet(RestServlet):
if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
- desired_username, access_token, body
+ desired_username,
+ desired_password,
+ desired_display_name,
+ access_token,
+ body,
)
- defer.returnValue((200, result)) # we throw for non 200 responses
+ return (200, result) # we throw for non 200 responses
return
# for either shared secret or regular registration, downcase the
@@ -301,7 +316,7 @@ class RegisterRestServlet(RestServlet):
result = yield self._do_shared_secret_registration(
desired_username, desired_password, body
)
- defer.returnValue((200, result)) # we throw for non 200 responses
+ return (200, result) # we throw for non 200 responses
return
# == Normal User Registration == (everyone else)
@@ -413,7 +428,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- if not check_3pid_allowed(self.hs, medium, address):
+ if not (yield check_3pid_allowed(self.hs, medium, address)):
raise SynapseError(
403,
"Third party identifiers (email/phone numbers)"
@@ -421,6 +436,80 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ existingUid = yield self.store.get_user_id_by_threepid(
+ medium, address
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE
+ )
+
+ if self.hs.config.register_mxid_from_3pid:
+ # override the desired_username based on the 3PID if any.
+ # reset it first to avoid folks picking their own username.
+ desired_username = None
+
+ # we should have an auth_result at this point if we're going to progress
+ # to register the user (i.e. we haven't picked up a registered_user_id
+ # from our session store), in which case get ready and gen the
+ # desired_username
+ if auth_result:
+ if (
+ self.hs.config.register_mxid_from_3pid == "email"
+ and LoginType.EMAIL_IDENTITY in auth_result
+ ):
+ address = auth_result[LoginType.EMAIL_IDENTITY]["address"]
+ desired_username = synapse.types.strip_invalid_mxid_characters(
+ address.replace("@", "-").lower()
+ )
+
+ # find a unique mxid for the account, suffixing numbers
+ # if needed
+ while True:
+ try:
+ yield self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+ # if we got this far we passed the check.
+ break
+ except SynapseError as e:
+ if e.errcode == Codes.USER_IN_USE:
+ m = re.match(r"^(.*?)(\d+)$", desired_username)
+ if m:
+ desired_username = m.group(1) + str(
+ int(m.group(2)) + 1
+ )
+ else:
+ desired_username += "1"
+ else:
+ # something else went wrong.
+ break
+
+ if self.hs.config.register_just_use_email_for_display_name:
+ desired_display_name = address
+ else:
+ # Custom mapping between email address and display name
+ desired_display_name = self._map_email_to_displayname(address)
+ elif (
+ self.hs.config.register_mxid_from_3pid == "msisdn"
+ and LoginType.MSISDN in auth_result
+ ):
+ desired_username = auth_result[LoginType.MSISDN]["address"]
+ else:
+ raise SynapseError(
+ 400, "Cannot derive mxid from 3pid; no recognised 3pid"
+ )
+
+ if desired_username is not None:
+ yield self.registration_handler.check_username(
+ desired_username,
+ guest_access_token=guest_access_token,
+ assigned_user_id=registered_user_id,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session", registered_user_id
@@ -431,9 +520,16 @@ class RegisterRestServlet(RestServlet):
# NB: This may be from the auth handler and NOT from the POST
assert_params_in_dict(params, ["password"])
- desired_username = params.get("username", None)
+ if not self.hs.config.register_mxid_from_3pid:
+ desired_username = params.get("username", None)
+ else:
+ # we keep the original desired_username derived from the 3pid above
+ pass
+
guest_access_token = params.get("guest_access_token", None)
- new_password = params.get("password", None)
+
+ # XXX: don't we need to validate these for length etc like we did on
+ # the ones from the JSON body earlier on in the method?
if desired_username is not None:
desired_username = desired_username.lower()
@@ -466,8 +562,9 @@ class RegisterRestServlet(RestServlet):
registered_user_id = yield self.registration_handler.register_user(
localpart=desired_username,
- password=new_password,
+ password=params.get("password", None),
guest_access_token=guest_access_token,
+ default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
)
@@ -479,6 +576,14 @@ class RegisterRestServlet(RestServlet):
):
yield self.store.upsert_monthly_active_user(registered_user_id)
+ if self.hs.config.shadow_server:
+ yield self.registration_handler.shadow_register(
+ localpart=desired_username,
+ display_name=desired_display_name,
+ auth_result=auth_result,
+ params=params,
+ )
+
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
@@ -500,17 +605,37 @@ class RegisterRestServlet(RestServlet):
bind_msisdn=params.get("bind_msisdn"),
)
- defer.returnValue((200, return_dict))
+ return (200, return_dict)
def on_OPTIONS(self, _):
return 200, {}
@defer.inlineCallbacks
- def _do_appservice_registration(self, username, as_token, body):
+ def _do_appservice_registration(
+ self, username, password, display_name, as_token, body
+ ):
+
+ # FIXME: appservice_register() is horribly duplicated with register()
+ # and they should probably just be combined together with a config flag.
user_id = yield self.registration_handler.appservice_register(
- username, as_token
+ username, as_token, password, display_name
)
- defer.returnValue((yield self._create_registration_details(user_id, body)))
+ result = yield self._create_registration_details(user_id, body)
+
+ auth_result = body.get("auth_result")
+ if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
+ threepid = auth_result[LoginType.EMAIL_IDENTITY]
+ yield self._register_email_threepid(
+ user_id, threepid, result["access_token"], body.get("bind_email")
+ )
+
+ if auth_result and LoginType.MSISDN in auth_result:
+ threepid = auth_result[LoginType.MSISDN]
+ yield self._register_msisdn_threepid(
+ user_id, threepid, result["access_token"], body.get("bind_msisdn")
+ )
+
+ return result
@defer.inlineCallbacks
def _do_shared_secret_registration(self, username, password, body):
@@ -546,7 +671,7 @@ class RegisterRestServlet(RestServlet):
)
result = yield self._create_registration_details(user_id, body)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _create_registration_details(self, user_id, params):
@@ -570,7 +695,7 @@ class RegisterRestServlet(RestServlet):
)
result.update({"access_token": access_token, "device_id": device_id})
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _do_guest_registration(self, params, address=None):
@@ -588,19 +713,71 @@ class RegisterRestServlet(RestServlet):
user_id, device_id, initial_display_name, is_guest=True
)
- defer.returnValue(
- (
- 200,
- {
- "user_id": user_id,
- "device_id": device_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- },
- )
+ return (
+ 200,
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "access_token": access_token,
+ "home_server": self.hs.hostname,
+ },
)
+def cap(name):
+ """Capitalise parts of a name containing different words, including those
+ separated by hyphens.
+ For example, 'John-Doe'
+
+ Args:
+ name (str): The name to parse
+ """
+ if not name:
+ return name
+
+ # Split the name by whitespace then hyphens, capitalizing each part then
+ # joining it back together.
+ capatilized_name = " ".join(
+ "-".join(part.capitalize() for part in space_part.split("-"))
+ for space_part in name.split()
+ )
+ return capatilized_name
+
+
+def _map_email_to_displayname(address):
+ """Custom mapping from an email address to a user displayname
+
+ Args:
+ address (str): The email address to process
+ Returns:
+ str: The new displayname
+ """
+ # Split the part before and after the @ in the email.
+ # Replace all . with spaces in the first part
+ parts = address.replace(".", " ").split("@")
+
+ # Figure out which org this email address belongs to
+ org_parts = parts[1].split(" ")
+
+ # If this is a ...matrix.org email, mark them as an Admin
+ if org_parts[-2] == "matrix" and org_parts[-1] == "org":
+ org = "Tchap Admin"
+
+ # Is this is a ...gouv.fr address, set the org to whatever is before
+ # gouv.fr. If there isn't anything (a @gouv.fr email) simply mark their
+ # org as "gouv"
+ elif org_parts[-2] == "gouv" and org_parts[-1] == "fr":
+ org = org_parts[-3] if len(org_parts) > 2 else org_parts[-2]
+
+ # Otherwise, mark their org as the email's second-level domain name
+ else:
+ org = org_parts[-2]
+
+ desired_display_name = cap(parts[0]) + " [" + cap(org) + "]"
+
+ return desired_display_name
+
+
def register_servlets(hs, http_server):
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 9e9a639055..1538b247e5 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -118,7 +118,7 @@ class RelationSendServlet(RestServlet):
requester, event_dict=event_dict, txn_id=txn_id
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ return (200, {"event_id": event.event_id})
class RelationPaginationServlet(RestServlet):
@@ -198,7 +198,7 @@ class RelationPaginationServlet(RestServlet):
return_value["chunk"] = events
return_value["original_event"] = original_event
- defer.returnValue((200, return_value))
+ return (200, return_value)
class RelationAggregationPaginationServlet(RestServlet):
@@ -270,7 +270,7 @@ class RelationAggregationPaginationServlet(RestServlet):
to_token=to_token,
)
- defer.returnValue((200, pagination_chunk.to_dict()))
+ return (200, pagination_chunk.to_dict())
class RelationAggregationGroupPaginationServlet(RestServlet):
@@ -356,7 +356,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
return_value = result.to_dict()
return_value["chunk"] = events
- defer.returnValue((200, return_value))
+ return (200, return_value)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index e7578af804..3fdd4584a3 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -72,7 +72,7 @@ class ReportEventRestServlet(RestServlet):
received_ts=self.clock.time_msec(),
)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 8d1b810565..10dec96208 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -135,7 +135,7 @@ class RoomKeysServlet(RestServlet):
body = {"rooms": {room_id: body}}
yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_GET(self, request, room_id, session_id):
@@ -218,7 +218,7 @@ class RoomKeysServlet(RestServlet):
else:
room_keys = room_keys["rooms"][room_id]
- defer.returnValue((200, room_keys))
+ return (200, room_keys)
@defer.inlineCallbacks
def on_DELETE(self, request, room_id, session_id):
@@ -242,7 +242,7 @@ class RoomKeysServlet(RestServlet):
yield self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id
)
- defer.returnValue((200, {}))
+ return (200, {})
class RoomKeysNewVersionServlet(RestServlet):
@@ -293,7 +293,7 @@ class RoomKeysNewVersionServlet(RestServlet):
info = parse_json_object_from_request(request)
new_version = yield self.e2e_room_keys_handler.create_version(user_id, info)
- defer.returnValue((200, {"version": new_version}))
+ return (200, {"version": new_version})
# we deliberately don't have a PUT /version, as these things really should
# be immutable to avoid people footgunning
@@ -338,7 +338,7 @@ class RoomKeysVersionServlet(RestServlet):
except SynapseError as e:
if e.code == 404:
raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
- defer.returnValue((200, info))
+ return (200, info)
@defer.inlineCallbacks
def on_DELETE(self, request, version):
@@ -358,7 +358,7 @@ class RoomKeysVersionServlet(RestServlet):
user_id = requester.user.to_string()
yield self.e2e_room_keys_handler.delete_version(user_id, version)
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_PUT(self, request, version):
@@ -392,7 +392,7 @@ class RoomKeysVersionServlet(RestServlet):
)
yield self.e2e_room_keys_handler.update_version(user_id, version, info)
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index d7f7faa029..14ba61a63e 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -80,7 +80,7 @@ class RoomUpgradeRestServlet(RestServlet):
ret = {"replacement_room": new_room_id}
- defer.returnValue((200, ret))
+ return (200, ret)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 78075b8fc0..2613648d82 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -60,7 +60,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
)
response = (200, {})
- defer.returnValue(response)
+ return response
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 02d56dee6c..7b32dd2212 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -174,7 +174,7 @@ class SyncRestServlet(RestServlet):
time_now, sync_result, requester.access_token_id, filter
)
- defer.returnValue((200, response_content))
+ return (200, response_content)
@defer.inlineCallbacks
def encode_response(self, time_now, sync_result, access_token_id, filter):
@@ -205,27 +205,23 @@ class SyncRestServlet(RestServlet):
event_formatter,
)
- defer.returnValue(
- {
- "account_data": {"events": sync_result.account_data},
- "to_device": {"events": sync_result.to_device},
- "device_lists": {
- "changed": list(sync_result.device_lists.changed),
- "left": list(sync_result.device_lists.left),
- },
- "presence": SyncRestServlet.encode_presence(
- sync_result.presence, time_now
- ),
- "rooms": {"join": joined, "invite": invited, "leave": archived},
- "groups": {
- "join": sync_result.groups.join,
- "invite": sync_result.groups.invite,
- "leave": sync_result.groups.leave,
- },
- "device_one_time_keys_count": sync_result.device_one_time_keys_count,
- "next_batch": sync_result.next_batch.to_string(),
- }
- )
+ return {
+ "account_data": {"events": sync_result.account_data},
+ "to_device": {"events": sync_result.to_device},
+ "device_lists": {
+ "changed": list(sync_result.device_lists.changed),
+ "left": list(sync_result.device_lists.left),
+ },
+ "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now),
+ "rooms": {"join": joined, "invite": invited, "leave": archived},
+ "groups": {
+ "join": sync_result.groups.join,
+ "invite": sync_result.groups.invite,
+ "leave": sync_result.groups.leave,
+ },
+ "device_one_time_keys_count": sync_result.device_one_time_keys_count,
+ "next_batch": sync_result.next_batch.to_string(),
+ }
@staticmethod
def encode_presence(events, time_now):
@@ -273,7 +269,7 @@ class SyncRestServlet(RestServlet):
event_formatter=event_formatter,
)
- defer.returnValue(joined)
+ return joined
@defer.inlineCallbacks
def encode_invited(self, rooms, time_now, token_id, event_formatter):
@@ -309,7 +305,7 @@ class SyncRestServlet(RestServlet):
invited_state.append(invite)
invited[room.room_id] = {"invite_state": {"events": invited_state}}
- defer.returnValue(invited)
+ return invited
@defer.inlineCallbacks
def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter):
@@ -342,7 +338,7 @@ class SyncRestServlet(RestServlet):
event_formatter=event_formatter,
)
- defer.returnValue(joined)
+ return joined
@defer.inlineCallbacks
def encode_room(
@@ -414,7 +410,7 @@ class SyncRestServlet(RestServlet):
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
- defer.returnValue(result)
+ return result
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index 07b6ede603..d173544355 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -45,7 +45,7 @@ class TagListServlet(RestServlet):
tags = yield self.store.get_tags_for_room(user_id, room_id)
- defer.returnValue((200, {"tags": tags}))
+ return (200, {"tags": tags})
class TagServlet(RestServlet):
@@ -76,7 +76,7 @@ class TagServlet(RestServlet):
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return (200, {})
@defer.inlineCallbacks
def on_DELETE(self, request, user_id, room_id, tag):
@@ -88,7 +88,7 @@ class TagServlet(RestServlet):
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return (200, {})
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 1e66662a05..158e686b01 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -40,7 +40,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
yield self.auth.get_user_by_req(request, allow_guest=True)
protocols = yield self.appservice_handler.get_3pe_protocols()
- defer.returnValue((200, protocols))
+ return (200, protocols)
class ThirdPartyProtocolServlet(RestServlet):
@@ -60,9 +60,9 @@ class ThirdPartyProtocolServlet(RestServlet):
only_protocol=protocol
)
if protocol in protocols:
- defer.returnValue((200, protocols[protocol]))
+ return (200, protocols[protocol])
else:
- defer.returnValue((404, {"error": "Unknown protocol"}))
+ return (404, {"error": "Unknown protocol"})
class ThirdPartyUserServlet(RestServlet):
@@ -85,7 +85,7 @@ class ThirdPartyUserServlet(RestServlet):
ThirdPartyEntityKind.USER, protocol, fields
)
- defer.returnValue((200, results))
+ return (200, results)
class ThirdPartyLocationServlet(RestServlet):
@@ -108,7 +108,7 @@ class ThirdPartyLocationServlet(RestServlet):
ThirdPartyEntityKind.LOCATION, protocol, fields
)
- defer.returnValue((200, results))
+ return (200, results)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index e19fb6d583..079a823c53 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -15,10 +15,13 @@
import logging
+from signedjson.sign import sign_json
+
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.types import UserID
from ._base import client_patterns
@@ -37,6 +40,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.http_client = hs.get_simple_http_client()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -60,10 +64,20 @@ class UserDirectorySearchRestServlet(RestServlet):
user_id = requester.user.to_string()
if not self.hs.config.user_directory_search_enabled:
- defer.returnValue((200, {"limited": False, "results": []}))
+ return (200, {"limited": False, "results": []})
body = parse_json_object_from_request(request)
+ if self.hs.config.user_directory_defer_to_id_server:
+ signed_body = sign_json(
+ body, self.hs.hostname, self.hs.config.signing_key[0]
+ )
+ url = "%s/_matrix/identity/api/v1/user_directory/search" % (
+ self.hs.config.user_directory_defer_to_id_server,
+ )
+ resp = yield self.http_client.post_json_get_json(url, signed_body)
+ defer.returnValue((200, resp))
+
limit = body.get("limit", 10)
limit = min(limit, 50)
@@ -76,8 +90,90 @@ class UserDirectorySearchRestServlet(RestServlet):
user_id, search_term, limit
)
- defer.returnValue((200, results))
+ return (200, results)
+
+
+class UserInfoServlet(RestServlet):
+ """
+ GET /user/{user_id}/info HTTP/1.1
+ """
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
+
+ def __init__(self, hs):
+ super(UserInfoServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.notifier = hs.get_notifier()
+ self.clock = hs.get_clock()
+ self.transport_layer = hs.get_federation_transport_client()
+ registry = hs.get_federation_registry()
+
+ if not registry.query_handlers.get("user_info"):
+ registry.register_query_handler("user_info", self._on_federation_query)
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id):
+ # Ensure the user is authenticated
+ yield self.auth.get_user_by_req(request, allow_guest=False)
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ # Attempt to make a federation request to the server that owns this user
+ args = {"user_id": user_id}
+ res = yield self.transport_layer.make_query(
+ user.domain, "user_info", args, retry_on_dns_fail=True
+ )
+ defer.returnValue((200, res))
+
+ res = yield self._get_user_info(user_id)
+ defer.returnValue((200, res))
+
+ @defer.inlineCallbacks
+ def _on_federation_query(self, args):
+ """Called when a request for user information appears over federation
+
+ Args:
+ args (dict): Dictionary of query arguments provided by the request
+
+ Returns:
+ Deferred[dict]: Deactivation and expiration information for a given user
+ """
+ user_id = args.get("user_id")
+ if not user_id:
+ raise SynapseError(400, "user_id not provided")
+
+ user = UserID.from_string(user_id)
+ if not self.hs.is_mine(user):
+ raise SynapseError(400, "User is not hosted on this homeserver")
+
+ res = yield self._get_user_info(user_id)
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
+ def _get_user_info(self, user_id):
+ """Retrieve information about a given user
+
+ Args:
+ user_id (str): The User ID of a given user on this homeserver
+
+ Returns:
+ Deferred[dict]: Deactivation and expiration information for a given user
+ """
+ # Check whether user is deactivated
+ is_deactivated = yield self.store.get_user_deactivated_status(user_id)
+
+ # Check whether user is expired
+ expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
+ is_expired = (
+ expiration_ts is not None and self.clock.time_msec() >= expiration_ts
+ )
+
+ res = {"expired": is_expired, "deactivated": is_deactivated}
+ defer.returnValue(res)
def register_servlets(hs, http_server):
UserDirectorySearchRestServlet(hs).register(http_server)
+ UserInfoServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 65afffbb42..92beefa176 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -171,7 +171,7 @@ class MediaRepository(object):
yield self._generate_thumbnails(None, media_id, media_id, media_type)
- defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
+ return "mxc://%s/%s" % (self.server_name, media_id)
@defer.inlineCallbacks
def get_local_media(self, request, media_id, name):
@@ -282,7 +282,7 @@ class MediaRepository(object):
with responder:
pass
- defer.returnValue(media_info)
+ return media_info
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
@@ -317,14 +317,14 @@ class MediaRepository(object):
responder = yield self.media_storage.fetch_media(file_info)
if responder:
- defer.returnValue((responder, media_info))
+ return (responder, media_info)
# Failed to find the file anywhere, lets download it.
media_info = yield self._download_remote_file(server_name, media_id, file_id)
responder = yield self.media_storage.fetch_media(file_info)
- defer.returnValue((responder, media_info))
+ return (responder, media_info)
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id, file_id):
@@ -421,7 +421,7 @@ class MediaRepository(object):
yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
- defer.returnValue(media_info)
+ return media_info
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
@@ -500,7 +500,7 @@ class MediaRepository(object):
media_id, t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue(output_path)
+ return output_path
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(
@@ -554,7 +554,7 @@ class MediaRepository(object):
t_len,
)
- defer.returnValue(output_path)
+ return output_path
@defer.inlineCallbacks
def _generate_thumbnails(
@@ -667,7 +667,7 @@ class MediaRepository(object):
media_id, t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue({"width": m_width, "height": m_height})
+ return {"width": m_width, "height": m_height}
@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
@@ -704,7 +704,7 @@ class MediaRepository(object):
yield self.store.delete_remote_media(origin, media_id)
deleted += 1
- defer.returnValue({"deleted": deleted})
+ return {"deleted": deleted}
class MediaRepositoryResource(Resource):
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 25e5ac2848..3b87717a5a 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -69,7 +69,7 @@ class MediaStorage(object):
)
yield finish_cb()
- defer.returnValue(fname)
+ return fname
@contextlib.contextmanager
def store_into_file(self, file_info):
@@ -143,14 +143,14 @@ class MediaStorage(object):
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
- defer.returnValue(FileResponder(open(local_path, "rb")))
+ return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
res = yield provider.fetch(path, file_info)
if res:
- defer.returnValue(res)
+ return res
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def ensure_media_is_in_local_cache(self, file_info):
@@ -166,7 +166,7 @@ class MediaStorage(object):
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
- defer.returnValue(local_path)
+ return local_path
dirname = os.path.dirname(local_path)
if not os.path.exists(dirname):
@@ -181,7 +181,7 @@ class MediaStorage(object):
)
yield res.write_to_consumer(consumer)
yield consumer.wait()
- defer.returnValue(local_path)
+ return local_path
raise Exception("file could not be found")
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 5871737bfd..da4ee52a4d 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -74,6 +74,8 @@ class PreviewUrlResource(DirectServeResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
+ http_proxy=os.getenv("http_proxy"),
+ https_proxy=os.getenv("HTTPS_PROXY"),
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
@@ -182,7 +184,7 @@ class PreviewUrlResource(DirectServeResource):
og = cache_result["og"]
if isinstance(og, six.text_type):
og = og.encode("utf8")
- defer.returnValue(og)
+ return og
return
media_info = yield self._download_url(url, user)
@@ -284,7 +286,7 @@ class PreviewUrlResource(DirectServeResource):
media_info["created_ts"],
)
- defer.returnValue(jsonog)
+ return jsonog
@defer.inlineCallbacks
def _download_url(self, url, user):
@@ -354,22 +356,20 @@ class PreviewUrlResource(DirectServeResource):
# therefore not expire it.
raise
- defer.returnValue(
- {
- "media_type": media_type,
- "media_length": length,
- "download_name": download_name,
- "created_ts": time_now_ms,
- "filesystem_id": file_id,
- "filename": fname,
- "uri": uri,
- "response_code": code,
- # FIXME: we should calculate a proper expiration based on the
- # Cache-Control and Expire headers. But for now, assume 1 hour.
- "expires": 60 * 60 * 1000,
- "etag": headers["ETag"][0] if "ETag" in headers else None,
- }
- )
+ return {
+ "media_type": media_type,
+ "media_length": length,
+ "download_name": download_name,
+ "created_ts": time_now_ms,
+ "filesystem_id": file_id,
+ "filename": fname,
+ "uri": uri,
+ "response_code": code,
+ # FIXME: we should calculate a proper expiration based on the
+ # Cache-Control and Expire headers. But for now, assume 1 hour.
+ "expires": 60 * 60 * 1000,
+ "etag": headers["ETag"][0] if "ETag" in headers else None,
+ }
def _start_expire_url_cache_data(self):
return run_as_background_process(
diff --git a/synapse/rulecheck/__init__.py b/synapse/rulecheck/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rulecheck/__init__.py
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py
new file mode 100644
index 0000000000..6f2a1931c5
--- /dev/null
+++ b/synapse/rulecheck/domain_rule_checker.py
@@ -0,0 +1,181 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.config._base import ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DomainRuleChecker(object):
+ """
+ A re-implementation of the SpamChecker that prevents users in one domain from
+ inviting users in other domains to rooms, based on a configuration.
+
+ Takes a config in the format:
+
+ spam_checker:
+ module: "rulecheck.DomainRuleChecker"
+ config:
+ domain_mapping:
+ "inviter_domain": [ "invitee_domain_permitted", "other_domain_permitted" ]
+ "other_inviter_domain": [ "invitee_domain_permitted" ]
+ default: False
+
+ # Only let local users join rooms if they were explicitly invited.
+ can_only_join_rooms_with_invite: false
+
+ # Only let local users create rooms if they are inviting only one
+ # other user, and that user matches the rules above.
+ can_only_create_one_to_one_rooms: false
+
+ # Only let local users invite during room creation, regardless of the
+ # domain mapping rules above.
+ can_only_invite_during_room_creation: false
+
+ # Prevent local users from inviting users from certain domains to
+ # rooms published in the room directory.
+ domains_prevented_from_being_invited_to_published_rooms: []
+
+ # Allow third party invites
+ can_invite_by_third_party_id: true
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config):
+ self.domain_mapping = config["domain_mapping"] or {}
+ self.default = config["default"]
+
+ self.can_only_join_rooms_with_invite = config.get(
+ "can_only_join_rooms_with_invite", False
+ )
+ self.can_only_create_one_to_one_rooms = config.get(
+ "can_only_create_one_to_one_rooms", False
+ )
+ self.can_only_invite_during_room_creation = config.get(
+ "can_only_invite_during_room_creation", False
+ )
+ self.can_invite_by_third_party_id = config.get(
+ "can_invite_by_third_party_id", True
+ )
+ self.domains_prevented_from_being_invited_to_published_rooms = config.get(
+ "domains_prevented_from_being_invited_to_published_rooms", []
+ )
+
+ def check_event_for_spam(self, event):
+ """Implements synapse.events.SpamChecker.check_event_for_spam
+ """
+ return False
+
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room=False,
+ ):
+ """Implements synapse.events.SpamChecker.user_may_invite
+ """
+ if self.can_only_invite_during_room_creation and not new_room:
+ return False
+
+ if not self.can_invite_by_third_party_id and third_party_invite:
+ return False
+
+ # This is a third party invite (without a bound mxid), so unless we have
+ # banned all third party invites (above) we allow it.
+ if not invitee_userid:
+ return True
+
+ inviter_domain = self._get_domain_from_id(inviter_userid)
+ invitee_domain = self._get_domain_from_id(invitee_userid)
+
+ if inviter_domain not in self.domain_mapping:
+ return self.default
+
+ if (
+ published_room
+ and invitee_domain
+ in self.domains_prevented_from_being_invited_to_published_rooms
+ ):
+ return False
+
+ return invitee_domain in self.domain_mapping[inviter_domain]
+
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
+ """Implements synapse.events.SpamChecker.user_may_create_room
+ """
+
+ if cloning:
+ return True
+
+ if not self.can_invite_by_third_party_id and third_party_invite_list:
+ return False
+
+ number_of_invites = len(invite_list) + len(third_party_invite_list)
+
+ if self.can_only_create_one_to_one_rooms and number_of_invites != 1:
+ return False
+
+ return True
+
+ def user_may_create_room_alias(self, userid, room_alias):
+ """Implements synapse.events.SpamChecker.user_may_create_room_alias
+ """
+ return True
+
+ def user_may_publish_room(self, userid, room_id):
+ """Implements synapse.events.SpamChecker.user_may_publish_room
+ """
+ return True
+
+ def user_may_join_room(self, userid, room_id, is_invited):
+ """Implements synapse.events.SpamChecker.user_may_join_room
+ """
+ if self.can_only_join_rooms_with_invite and not is_invited:
+ return False
+
+ return True
+
+ @staticmethod
+ def parse_config(config):
+ """Implements synapse.events.SpamChecker.parse_config
+ """
+ if "default" in config:
+ return config
+ else:
+ raise ConfigError("No default set for spam_config DomainRuleChecker")
+
+ @staticmethod
+ def _get_domain_from_id(mxid):
+ """Parses a string and returns the domain part of the mxid.
+
+ Args:
+ mxid (str): a valid mxid
+
+ Returns:
+ str: the domain part of the mxid
+
+ """
+ idx = mxid.find(":")
+ if idx == -1:
+ raise Exception("Invalid ID: %r" % (mxid,))
+ return mxid[idx + 1 :]
diff --git a/synapse/server.py b/synapse/server.py
index 9e28dba2b1..23be3560fa 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -23,6 +23,7 @@
# Imports required for the default HomeServer() implementation
import abc
import logging
+import os
from twisted.enterprise import adbapi
from twisted.mail.smtp import sendmail
@@ -65,6 +66,7 @@ from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.message import EventCreationHandler, MessageHandler
from synapse.handlers.pagination import PaginationHandler
+from synapse.handlers.password_policy import PasswordPolicyHandler
from synapse.handlers.presence import PresenceHandler
from synapse.handlers.profile import BaseProfileHandler, MasterProfileHandler
from synapse.handlers.read_marker import ReadMarkerHandler
@@ -166,6 +168,7 @@ class HomeServer(object):
"event_builder_factory",
"filtering",
"http_client_context_factory",
+ "proxied_http_client",
"simple_http_client",
"media_repository",
"media_repository_resource",
@@ -196,6 +199,7 @@ class HomeServer(object):
"account_validity_handler",
"saml_handler",
"event_client_serializer",
+ "password_policy_handler",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
@@ -304,6 +308,13 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
+ def build_proxied_http_client(self):
+ return SimpleHttpClient(
+ self,
+ http_proxy=os.getenv("http_proxy"),
+ https_proxy=os.getenv("HTTPS_PROXY"),
+ )
+
def build_room_creation_handler(self):
return RoomCreationHandler(self)
@@ -533,6 +544,9 @@ class HomeServer(object):
def build_event_client_serializer(self):
return EventClientSerializer(self)
+ def build_password_policy_handler(self):
+ return PasswordPolicyHandler(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 16f8f6b573..56f9cd06e5 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -12,6 +12,7 @@ import synapse.handlers.message
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
+import synapse.http.client
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
@@ -38,6 +39,14 @@ class HomeServer(object):
pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
+ def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ or support any HTTP_PROXY settings"""
+ pass
+ def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
+ """Fetch an HTTP client implementation which doesn't do any blacklisting
+ but does support HTTP_PROXY settings"""
+ pass
def get_deactivate_account_handler(
self
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index f183743f31..729c097e6d 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -193,4 +193,4 @@ class ResourceLimitsServerNotices(object):
if event_id in referenced_events:
referenced_events.remove(event.event_id)
- defer.returnValue((currently_blocked, referenced_events))
+ return (currently_blocked, referenced_events)
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index 71e7e75320..2dac90578c 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -86,7 +86,7 @@ class ServerNoticesManager(object):
res = yield self._event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, ratelimit=False
)
- defer.returnValue(res)
+ return res
@cachedInlineCallbacks()
def get_notice_room_for_user(self, user_id):
@@ -120,7 +120,7 @@ class ServerNoticesManager(object):
# we found a room which our user shares with the system notice
# user
logger.info("Using room %s", room.room_id)
- defer.returnValue(room.room_id)
+ return room.room_id
# apparently no existing notice room: create a new one
logger.info("Creating server notices room for %s", user_id)
@@ -158,4 +158,4 @@ class ServerNoticesManager(object):
self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
logger.info("Created server notices room %s for %s", room_id, user_id)
- defer.returnValue(room_id)
+ return room_id
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 9f708fa205..a0d34f16ea 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -135,7 +135,7 @@ class StateHandler(object):
event = None
if event_id:
event = yield self.store.get_event(event_id, allow_none=True)
- defer.returnValue(event)
+ return event
return
state_map = yield self.store.get_events(
@@ -145,7 +145,7 @@ class StateHandler(object):
key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
}
- defer.returnValue(state)
+ return state
@defer.inlineCallbacks
def get_current_state_ids(self, room_id, latest_event_ids=None):
@@ -169,7 +169,7 @@ class StateHandler(object):
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
- defer.returnValue(state)
+ return state
@defer.inlineCallbacks
def get_current_users_in_room(self, room_id, latest_event_ids=None):
@@ -189,7 +189,7 @@ class StateHandler(object):
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
- defer.returnValue(joined_users)
+ return joined_users
@defer.inlineCallbacks
def get_current_hosts_in_room(self, room_id, latest_event_ids=None):
@@ -198,7 +198,7 @@ class StateHandler(object):
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
- defer.returnValue(joined_hosts)
+ return joined_hosts
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None):
@@ -241,7 +241,7 @@ class StateHandler(object):
prev_state_ids=prev_state_ids,
)
- defer.returnValue(context)
+ return context
if old_state:
# We already have the state, so we don't need to calculate it.
@@ -275,7 +275,7 @@ class StateHandler(object):
prev_state_ids=prev_state_ids,
)
- defer.returnValue(context)
+ return context
logger.debug("calling resolve_state_groups from compute_event_context")
@@ -343,7 +343,7 @@ class StateHandler(object):
delta_ids=delta_ids,
)
- defer.returnValue(context)
+ return context
@defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids):
@@ -368,19 +368,17 @@ class StateHandler(object):
state_groups_ids = yield self.store.get_state_groups_ids(room_id, event_ids)
if len(state_groups_ids) == 0:
- defer.returnValue(_StateCacheEntry(state={}, state_group=None))
+ return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
- defer.returnValue(
- _StateCacheEntry(
- state=state_list,
- state_group=name,
- prev_group=prev_group,
- delta_ids=delta_ids,
- )
+ return _StateCacheEntry(
+ state=state_list,
+ state_group=name,
+ prev_group=prev_group,
+ delta_ids=delta_ids,
)
room_version = yield self.store.get_room_version(room_id)
@@ -392,7 +390,7 @@ class StateHandler(object):
None,
state_res_store=StateResolutionStore(self.store),
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
@@ -415,7 +413,7 @@ class StateHandler(object):
new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)}
- defer.returnValue(new_state)
+ return new_state
class StateResolutionHandler(object):
@@ -479,7 +477,7 @@ class StateResolutionHandler(object):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
- defer.returnValue(cache)
+ return cache
logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
@@ -525,7 +523,7 @@ class StateResolutionHandler(object):
if self._state_cache is not None:
self._state_cache[group_names] = cache
- defer.returnValue(cache)
+ return cache
def _make_state_cache_entry(new_state, state_groups_ids):
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 88acd4817e..a2f92d9ff9 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -55,7 +55,7 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
- defer.returnValue(state_sets[0])
+ return state_sets[0]
unconflicted_state, conflicted_state = _seperate(state_sets)
@@ -97,10 +97,8 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
- defer.returnValue(
- _resolve_with_state(
- unconflicted_state, conflicted_state, auth_events, state_map
- )
+ return _resolve_with_state(
+ unconflicted_state, conflicted_state, auth_events, state_map
)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index db969e8997..b327c86f40 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -63,7 +63,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
unconflicted_state, conflicted_state = _seperate(state_sets)
if not conflicted_state:
- defer.returnValue(unconflicted_state)
+ return unconflicted_state
logger.debug("%d conflicted state entries", len(conflicted_state))
logger.debug("Calculating auth chain difference")
@@ -137,7 +137,7 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
logger.debug("done")
- defer.returnValue(resolved_state)
+ return resolved_state
@defer.inlineCallbacks
@@ -168,18 +168,18 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
aev = yield _get_event(aid, event_map, state_res_store)
if (aev.type, aev.state_key) == (EventTypes.Create, ""):
if aev.content.get("creator") == event.sender:
- defer.returnValue(100)
+ return 100
break
- defer.returnValue(0)
+ return 0
level = pl.content.get("users", {}).get(event.sender)
if level is None:
level = pl.content.get("users_default", 0)
if level is None:
- defer.returnValue(0)
+ return 0
else:
- defer.returnValue(int(level))
+ return int(level)
@defer.inlineCallbacks
@@ -224,7 +224,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
union = set().union(*auth_sets)
- defer.returnValue(union - intersection)
+ return union - intersection
def _seperate(state_sets):
@@ -343,7 +343,7 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
it = lexicographical_topological_sort(graph, key=_get_power_order)
sorted_events = list(it)
- defer.returnValue(sorted_events)
+ return sorted_events
@defer.inlineCallbacks
@@ -396,7 +396,7 @@ def _iterative_auth_checks(
except AuthError:
pass
- defer.returnValue(resolved_state)
+ return resolved_state
@defer.inlineCallbacks
@@ -439,7 +439,7 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor
event_ids.sort(key=lambda ev_id: order_map[ev_id])
- defer.returnValue(event_ids)
+ return event_ids
@defer.inlineCallbacks
@@ -462,7 +462,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
while event:
depth = mainline_map.get(event.event_id)
if depth is not None:
- defer.returnValue(depth)
+ return depth
auth_events = event.auth_event_ids()
event = None
@@ -474,7 +474,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
break
# Didn't find a power level auth event, so we just return 0
- defer.returnValue(0)
+ return 0
@defer.inlineCallbacks
@@ -493,7 +493,7 @@ def _get_event(event_id, event_map, state_res_store):
if event_id not in event_map:
events = yield state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events)
- defer.returnValue(event_map[event_id])
+ return event_map[event_id]
def lexicographical_topological_sort(graph, key):
diff --git a/synapse/static/index.html b/synapse/static/index.html
index d3f1c7dce0..bf46df9097 100644
--- a/synapse/static/index.html
+++ b/synapse/static/index.html
@@ -48,13 +48,13 @@
</div>
<h1>It works! Synapse is running</h1>
<p>Your Synapse server is listening on this port and is ready for messages.</p>
- <p>To use this server you'll need <a href="https://matrix.org/docs/projects/try-matrix-now.html#clients" target="_blank">a Matrix client</a>.
+ <p>To use this server you'll need <a href="https://matrix.org/docs/projects/try-matrix-now.html#clients" target="_blank" rel="noopener noreferrer">a Matrix client</a>.
</p>
<p>Welcome to the Matrix universe :)</p>
<hr>
<p>
<small>
- <a href="https://matrix.org" target="_blank">
+ <a href="https://matrix.org" target="_blank" rel="noopener noreferrer">
matrix.org
</a>
</small>
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 6b0ca80087..e7f6ea7286 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -469,7 +469,7 @@ class DataStore(
return self._simple_select_list(
table="users",
keyvalues={},
- retcols=["name", "password_hash", "is_guest", "admin"],
+ retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
desc="get_users",
)
@@ -494,11 +494,11 @@ class DataStore(
orderby=order,
start=start,
limit=limit,
- retcols=["name", "password_hash", "is_guest", "admin"],
+ retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
)
count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
retval = {"users": users, "total": count}
- defer.returnValue(retval)
+ return retval
def search_users(self, term):
"""Function to search users list for one or more users with
@@ -514,7 +514,7 @@ class DataStore(
table="users",
term=term,
col="name",
- retcols=["name", "password_hash", "is_guest", "admin"],
+ retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
desc="search_users",
)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2f940dbae6..56a54267af 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -86,7 +86,21 @@ _CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
passed to the constructor. Adds logging and metrics to the .execute()
- method."""
+ method.
+
+ Args:
+ txn: The database transcation object to wrap.
+ name (str): The name of this transactions for logging.
+ database_engine (Sqlite3Engine|PostgresEngine)
+ after_callbacks(list|None): A list that callbacks will be appended to
+ that have been added by `call_after` which should be run on
+ successful completion of the transaction. None indicates that no
+ callbacks should be allowed to be scheduled to run.
+ exception_callbacks(list|None): A list that callbacks will be appended
+ to that have been added by `call_on_exception` which should be run
+ if transaction ends with an error. None indicates that no callbacks
+ should be allowed to be scheduled to run.
+ """
__slots__ = [
"txn",
@@ -97,7 +111,7 @@ class LoggingTransaction(object):
]
def __init__(
- self, txn, name, database_engine, after_callbacks, exception_callbacks
+ self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
):
object.__setattr__(self, "txn", txn)
object.__setattr__(self, "name", name)
@@ -346,14 +360,11 @@ class SQLBaseStore(object):
expiration_ts,
)
- self._simple_insert_txn(
+ self._simple_upsert_txn(
txn,
"account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- "email_sent": False,
- },
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
)
def start_profiling(self):
@@ -499,7 +510,7 @@ class SQLBaseStore(object):
after_callback(*after_args, **after_kwargs)
raise
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def runWithConnection(self, func, *args, **kwargs):
@@ -539,7 +550,7 @@ class SQLBaseStore(object):
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
- defer.returnValue(result)
+ return result
@staticmethod
def cursor_to_dict(cursor):
@@ -601,8 +612,8 @@ class SQLBaseStore(object):
# a cursor after we receive an error from the db.
if not or_ignore:
raise
- defer.returnValue(False)
- defer.returnValue(True)
+ return False
+ return True
@staticmethod
def _simple_insert_txn(txn, table, values):
@@ -694,7 +705,7 @@ class SQLBaseStore(object):
insertion_values,
lock=lock,
)
- defer.returnValue(result)
+ return result
except self.database_engine.module.IntegrityError as e:
attempts += 1
if attempts >= 5:
@@ -1107,7 +1118,7 @@ class SQLBaseStore(object):
results = []
if not iterable:
- defer.returnValue(results)
+ return results
# iterables can not be sliced, so convert it to a list first
it_list = list(iterable)
@@ -1128,7 +1139,7 @@ class SQLBaseStore(object):
results.extend(rows)
- defer.returnValue(results)
+ return results
@classmethod
def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index 8394389073..9fa5b4f3d6 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -111,9 +111,9 @@ class AccountDataWorkerStore(SQLBaseStore):
)
if result:
- defer.returnValue(json.loads(result))
+ return json.loads(result)
else:
- defer.returnValue(None)
+ return None
@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
@@ -264,11 +264,9 @@ class AccountDataWorkerStore(SQLBaseStore):
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
- defer.returnValue(False)
+ return False
- defer.returnValue(
- ignored_user_id in ignored_account_data.get("ignored_users", {})
- )
+ return ignored_user_id in ignored_account_data.get("ignored_users", {})
class AccountDataStore(AccountDataWorkerStore):
@@ -332,7 +330,7 @@ class AccountDataStore(AccountDataWorkerStore):
)
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def add_account_data_for_user(self, user_id, account_data_type, content):
@@ -373,7 +371,7 @@ class AccountDataStore(AccountDataWorkerStore):
)
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
def _update_max_stream_id(self, next_id):
"""Update the max stream_id
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index eb329ebd8b..1e9977e1bc 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -35,7 +35,7 @@ def _make_exclusive_regex(services_cache):
exclusive_user_regexes = [
regex.pattern
for service in services_cache
- for regex in service.get_exlusive_user_regexes()
+ for regex in service.get_exclusive_user_regexes()
]
if exclusive_user_regexes:
exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
@@ -145,7 +145,7 @@ class ApplicationServiceTransactionWorkerStore(
for service in as_list:
if service.id == res["as_id"]:
services.append(service)
- defer.returnValue(services)
+ return services
@defer.inlineCallbacks
def get_appservice_state(self, service):
@@ -164,9 +164,9 @@ class ApplicationServiceTransactionWorkerStore(
desc="get_appservice_state",
)
if result:
- defer.returnValue(result.get("state"))
+ return result.get("state")
return
- defer.returnValue(None)
+ return None
def set_appservice_state(self, service, state):
"""Set the application service state.
@@ -298,15 +298,13 @@ class ApplicationServiceTransactionWorkerStore(
)
if not entry:
- defer.returnValue(None)
+ return None
event_ids = json.loads(entry["event_ids"])
events = yield self.get_events_as_list(event_ids)
- defer.returnValue(
- AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
- )
+ return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
def _get_last_txn(self, txn, service_id):
txn.execute(
@@ -360,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore(
events = yield self.get_events_as_list(event_ids)
- defer.returnValue((upper_bound, events))
+ return (upper_bound, events)
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 50f913a414..e5f0668f09 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -115,7 +115,7 @@ class BackgroundUpdateStore(SQLBaseStore):
" Unscheduling background update task."
)
self._all_done = True
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def has_completed_background_updates(self):
@@ -127,11 +127,11 @@ class BackgroundUpdateStore(SQLBaseStore):
# if we've previously determined that there is nothing left to do, that
# is easy
if self._all_done:
- defer.returnValue(True)
+ return True
# obviously, if we have things in our queue, we're not done.
if self._background_update_queue:
- defer.returnValue(False)
+ return False
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
@@ -144,9 +144,9 @@ class BackgroundUpdateStore(SQLBaseStore):
)
if not updates:
self._all_done = True
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def do_next_background_update(self, desired_duration_ms):
@@ -173,14 +173,14 @@ class BackgroundUpdateStore(SQLBaseStore):
if not self._background_update_queue:
# no work left to do
- defer.returnValue(None)
+ return None
# pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
res = yield self._do_background_update(update_name, desired_duration_ms)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def _do_background_update(self, update_name, desired_duration_ms):
@@ -231,7 +231,7 @@ class BackgroundUpdateStore(SQLBaseStore):
performance.update(items_updated, duration_ms)
- defer.returnValue(len(self._background_update_performance))
+ return len(self._background_update_performance)
def register_background_update_handler(self, update_name, update_handler):
"""Register a handler for doing a background update.
@@ -266,7 +266,7 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
- defer.returnValue(1)
+ return 1
self.register_background_update_handler(update_name, noop_update)
@@ -370,7 +370,7 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.info("Adding index %s to %s", index_name, table)
yield self.runWithConnection(runner)
yield self._end_background_update(update_name)
- defer.returnValue(1)
+ return 1
self.register_background_update_handler(update_name, updater)
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index bda68de5be..6db8c54077 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -104,7 +104,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
yield self.runWithConnection(f)
yield self._end_background_update("user_ips_drop_nonunique_index")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _analyze_user_ip(self, progress, batch_size):
@@ -121,7 +121,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
yield self._end_background_update("user_ips_analyze")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _remove_user_ip_dupes(self, progress, batch_size):
@@ -291,7 +291,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
if last:
yield self._end_background_update("user_ips_remove_dupes")
- defer.returnValue(batch_size)
+ return batch_size
@defer.inlineCallbacks
def insert_client_ip(
@@ -401,7 +401,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id": did,
"last_seen": last_seen,
}
- defer.returnValue(ret)
+ return ret
@classmethod
def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
@@ -461,14 +461,12 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows
)
- defer.returnValue(
- list(
- {
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "last_seen": last_seen,
- }
- for (access_token, ip), (user_agent, last_seen) in iteritems(results)
- )
+ return list(
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in iteritems(results)
)
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 4ea0deea4f..79bb0ea46d 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -92,7 +92,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
user_id, last_deleted_stream_id
)
if not has_changed:
- defer.returnValue(0)
+ return 0
def delete_messages_for_device_txn(txn):
sql = (
@@ -115,7 +115,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
last_deleted_stream_id, up_to_stream_id
)
- defer.returnValue(count)
+ return count
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
@@ -263,7 +263,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
destination, stream_id
)
- defer.returnValue(self._device_inbox_id_gen.get_current_token())
+ return self._device_inbox_id_gen.get_current_token()
@defer.inlineCallbacks
def add_messages_from_remote_to_device_inbox(
@@ -312,7 +312,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
- defer.returnValue(stream_id)
+ return stream_id
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
@@ -426,4 +426,4 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
- defer.returnValue(1)
+ return 1
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d2b113a4e7..8f72d92895 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -71,7 +71,7 @@ class DeviceWorkerStore(SQLBaseStore):
desc="get_devices_by_user",
)
- defer.returnValue({d["device_id"]: d for d in devices})
+ return {d["device_id"]: d for d in devices}
@defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
@@ -88,7 +88,7 @@ class DeviceWorkerStore(SQLBaseStore):
destination, int(from_stream_id)
)
if not has_changed:
- defer.returnValue((now_stream_id, []))
+ return (now_stream_id, [])
# We retrieve n+1 devices from the list of outbound pokes where n is
# our outbound device update limit. We then check if the very last
@@ -111,7 +111,7 @@ class DeviceWorkerStore(SQLBaseStore):
# Return an empty list if there are no updates
if not updates:
- defer.returnValue((now_stream_id, []))
+ return (now_stream_id, [])
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
@@ -147,13 +147,13 @@ class DeviceWorkerStore(SQLBaseStore):
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map:
- defer.returnValue((stream_id_cutoff, []))
+ return (stream_id_cutoff, [])
results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
- defer.returnValue((now_stream_id, results))
+ return (now_stream_id, results)
def _get_devices_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
@@ -232,7 +232,7 @@ class DeviceWorkerStore(SQLBaseStore):
results.append(result)
- defer.returnValue(results)
+ return results
def _get_last_device_update_for_remote_user(
self, destination, user_id, from_stream_id
@@ -330,7 +330,7 @@ class DeviceWorkerStore(SQLBaseStore):
else:
results[user_id] = yield self._get_cached_devices_for_user(user_id)
- defer.returnValue((user_ids_not_in_cache, results))
+ return (user_ids_not_in_cache, results)
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
@@ -340,7 +340,7 @@ class DeviceWorkerStore(SQLBaseStore):
retcol="content",
desc="_get_cached_user_device",
)
- defer.returnValue(db_to_json(content))
+ return db_to_json(content)
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
@@ -350,9 +350,9 @@ class DeviceWorkerStore(SQLBaseStore):
retcols=("device_id", "content"),
desc="_get_cached_devices_for_user",
)
- defer.returnValue(
- {device["device_id"]: db_to_json(device["content"]) for device in devices}
- )
+ return {
+ device["device_id"]: db_to_json(device["content"]) for device in devices
+ }
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
@@ -482,7 +482,7 @@ class DeviceWorkerStore(SQLBaseStore):
results = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})
- defer.returnValue(results)
+ return results
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
@@ -543,7 +543,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"""
key = (user_id, device_id)
if self.device_id_exists_cache.get(key, None):
- defer.returnValue(False)
+ return False
try:
inserted = yield self._simple_insert(
@@ -557,7 +557,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
or_ignore=True,
)
self.device_id_exists_cache.prefill(key, True)
- defer.returnValue(inserted)
+ return inserted
except Exception as e:
logger.error(
"store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -780,7 +780,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
hosts,
stream_id,
)
- defer.returnValue(stream_id)
+ return stream_id
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
@@ -889,4 +889,4 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
- defer.returnValue(1)
+ return 1
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index 201bbd430c..e966a73f3d 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -46,7 +46,7 @@ class DirectoryWorkerStore(SQLBaseStore):
)
if not room_id:
- defer.returnValue(None)
+ return None
return
servers = yield self._simple_select_onecol(
@@ -57,10 +57,10 @@ class DirectoryWorkerStore(SQLBaseStore):
)
if not servers:
- defer.returnValue(None)
+ return None
return
- defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers))
+ return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
return self._simple_select_one_onecol(
@@ -125,7 +125,7 @@ class DirectoryStore(DirectoryWorkerStore):
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
@@ -133,7 +133,7 @@ class DirectoryStore(DirectoryWorkerStore):
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
- defer.returnValue(room_id)
+ return room_id
def _delete_room_alias_txn(self, txn, room_alias):
txn.execute(
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/e2e_room_keys.py
index f40ef2ab64..99128f2df7 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/e2e_room_keys.py
@@ -61,7 +61,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
row["session_data"] = json.loads(row["session_data"])
- defer.returnValue(row)
+ return row
@defer.inlineCallbacks
def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
@@ -118,7 +118,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
try:
version = int(version)
except ValueError:
- defer.returnValue({"rooms": {}})
+ return {"rooms": {}}
keyvalues = {"user_id": user_id, "version": version}
if room_id:
@@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"session_data": json.loads(row["session_data"]),
}
- defer.returnValue(sessions)
+ return sessions
@defer.inlineCallbacks
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2fabb9e2cb..1e07474e70 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -41,7 +41,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
dict containing "key_json", "device_display_name".
"""
if not query_list:
- defer.returnValue({})
+ return {}
results = yield self.runInteraction(
"get_e2e_device_keys",
@@ -55,7 +55,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for device_id, device_info in iteritems(device_keys):
device_info["keys"] = db_to_json(device_info.pop("key_json"))
- defer.returnValue(results)
+ return results
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
@@ -130,9 +130,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
desc="add_e2e_one_time_keys_check",
)
- defer.returnValue(
- {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
- )
+ return {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
@defer.inlineCallbacks
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index cb4478342f..4f500d893e 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -131,9 +131,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
if not rows:
- defer.returnValue(0)
+ return 0
else:
- defer.returnValue(max(row["depth"] for row in rows))
+ return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self._simple_select_onecol_txn(
@@ -169,7 +169,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# make sure that we don't completely ignore the older events.
res = res[0:5] + random.sample(res[5:], 5)
- defer.returnValue(res)
+ return res
def get_latest_event_ids_and_hashes_in_room(self, room_id):
"""
@@ -411,7 +411,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit,
)
events = yield self.get_events_as_list(ids)
- defer.returnValue(events)
+ return events
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@@ -463,7 +463,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
desc="get_successor_events",
)
- defer.returnValue([row["event_id"] for row in rows])
+ return [row["event_id"] for row in rows]
class EventFederationStore(EventFederationWorkerStore):
@@ -654,4 +654,4 @@ class EventFederationStore(EventFederationWorkerStore):
if not result:
yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
- defer.returnValue(batch_size)
+ return batch_size
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index eca77069fd..22025effbc 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -79,8 +79,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine,
- after_callbacks=[],
- exception_callbacks=[],
)
self._find_stream_orderings_for_times_txn(cur)
cur.close()
@@ -102,7 +100,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
user_id,
last_read_event_id,
)
- defer.returnValue(ret)
+ return ret
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
@@ -180,7 +178,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return [r[0] for r in txn]
ret = yield self.runInteraction("get_push_action_users_in_range", f)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_http(
@@ -281,7 +279,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit.
- defer.returnValue(notifs[:limit])
+ return notifs[:limit]
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_email(
@@ -382,7 +380,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
notifs.sort(key=lambda r: -(r["received_ts"] or 0))
# Now return the first `limit`
- defer.returnValue(notifs[:limit])
+ return notifs[:limit]
def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
"""A fast check to see if there might be something to push for the
@@ -479,7 +477,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
)
- defer.returnValue(res)
+ return res
except Exception:
# this method is called from an exception handler, so propagating
# another exception here really isn't helpful - there's nothing
@@ -734,7 +732,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
push_actions = yield self.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
- defer.returnValue(push_actions)
+ return push_actions
@defer.inlineCallbacks
def get_time_of_last_push_action_before(self, stream_ordering):
@@ -751,7 +749,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return txn.fetchone()
result = yield self.runInteraction("get_time_of_last_push_action_before", f)
- defer.returnValue(result[0] if result else None)
+ return result[0] if result else None
@defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self):
@@ -760,7 +758,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return txn.fetchone()
result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
- defer.returnValue(result[0] or 0)
+ return result[0] or 0
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index b486ca50eb..511f0f251f 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -223,7 +223,7 @@ def _retry_on_integrity_error(func):
except self.database_engine.module.IntegrityError:
logger.exception("IntegrityError, retrying.")
res = yield func(self, *args, delete_existing=True, **kwargs)
- defer.returnValue(res)
+ return res
return f
@@ -309,7 +309,7 @@ class EventsStore(
max_persisted_id = yield self._stream_id_gen.get_current_token()
- defer.returnValue(max_persisted_id)
+ return max_persisted_id
@defer.inlineCallbacks
@log_function
@@ -334,7 +334,7 @@ class EventsStore(
yield make_deferred_yieldable(deferred)
max_persisted_id = yield self._stream_id_gen.get_current_token()
- defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
+ return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks
@@ -595,7 +595,7 @@ class EventsStore(
stale = latest_event_ids & result
stale_forward_extremities_counter.observe(len(stale))
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _get_events_which_are_prevs(self, event_ids):
@@ -633,7 +633,7 @@ class EventsStore(
"_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def _get_prevs_before_rejected(self, event_ids):
@@ -695,7 +695,7 @@ class EventsStore(
"_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
- defer.returnValue(existing_prevs)
+ return existing_prevs
@defer.inlineCallbacks
def _get_new_state_after_events(
@@ -796,7 +796,7 @@ class EventsStore(
# If they old and new groups are the same then we don't need to do
# anything.
if old_state_groups == new_state_groups:
- defer.returnValue((None, None))
+ return (None, None)
if len(new_state_groups) == 1 and len(old_state_groups) == 1:
# If we're going from one state group to another, lets check if
@@ -813,7 +813,7 @@ class EventsStore(
# the current state in memory then lets also return that,
# but it doesn't matter if we don't.
new_state = state_groups_map.get(new_state_group)
- defer.returnValue((new_state, delta_ids))
+ return (new_state, delta_ids)
# Now that we have calculated new_state_groups we need to get
# their state IDs so we can resolve to a single state set.
@@ -825,7 +825,7 @@ class EventsStore(
if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- defer.returnValue((state_groups_map[new_state_groups.pop()], None))
+ return (state_groups_map[new_state_groups.pop()], None)
# Ok, we need to defer to the state handler to resolve our state sets.
@@ -854,7 +854,7 @@ class EventsStore(
state_res_store=StateResolutionStore(self),
)
- defer.returnValue((res.state, None))
+ return (res.state, None)
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, current_state):
@@ -877,7 +877,7 @@ class EventsStore(
if ev_id != existing_state.get(key)
}
- defer.returnValue((to_delete, to_insert))
+ return (to_delete, to_insert)
@log_function
def _persist_events_txn(
@@ -918,8 +918,6 @@ class EventsStore(
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
-
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -993,6 +991,10 @@ class EventsStore(
backfilled=backfilled,
)
+ # We call this last as it assumes we've inserted the events into
+ # room_memberships, where applicable.
+ self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
for room_id, current_state_tuple in iteritems(state_delta_by_room):
to_delete, to_insert = current_state_tuple
@@ -1062,16 +1064,16 @@ class EventsStore(
),
)
- self._simple_insert_many_txn(
- txn,
- table="current_state_events",
- values=[
- {
- "event_id": ev_id,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- }
+ # We include the membership in the current state table, hence we do
+ # a lookup when we insert. This assumes that all events have already
+ # been inserted into room_memberships.
+ txn.executemany(
+ """INSERT INTO current_state_events
+ (room_id, type, state_key, event_id, membership)
+ VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+ """,
+ [
+ (room_id, key[0], key[1], ev_id, ev_id)
for key, ev_id in iteritems(to_insert)
],
)
@@ -1455,6 +1457,9 @@ class EventsStore(
elif event.type == EventTypes.GuestAccess:
# Insert into the event_search table.
self._store_guest_access_txn(txn, event)
+ elif event.type == EventTypes.Retention:
+ # Update the room_retention table.
+ self._store_retention_policy_for_room_txn(txn, event)
self._handle_event_relations(txn, event)
@@ -1562,7 +1567,7 @@ class EventsStore(
return count
ret = yield self.runInteraction("count_messages", _count_messages)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def count_daily_sent_messages(self):
@@ -1583,7 +1588,7 @@ class EventsStore(
return count
ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def count_daily_active_rooms(self):
@@ -1598,7 +1603,7 @@ class EventsStore(
return count
ret = yield self.runInteraction("count_daily_active_rooms", _count)
- defer.returnValue(ret)
+ return ret
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
@@ -2181,7 +2186,7 @@ class EventsStore(
"""
to_1, so_1 = yield self._get_event_ordering(event_id1)
to_2, so_2 = yield self._get_event_ordering(event_id2)
- defer.returnValue((to_1, so_1) > (to_2, so_2))
+ return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
@@ -2195,9 +2200,7 @@ class EventsStore(
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
- defer.returnValue(
- (int(res["topological_ordering"]), int(res["stream_ordering"]))
- )
+ return (int(res["topological_ordering"]), int(res["stream_ordering"]))
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py
index 1ce21d190c..6587f31e2b 100644
--- a/synapse/storage/events_bg_updates.py
+++ b/synapse/storage/events_bg_updates.py
@@ -135,7 +135,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if not result:
yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
@@ -212,7 +212,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if not result:
yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _cleanup_extremities_bg_update(self, progress, batch_size):
@@ -396,4 +396,4 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"_cleanup_extremities_bg_update_drop_table", _drop_table_txn
)
- defer.returnValue(num_handled)
+ return num_handled
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 06379281b6..79680ee856 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -139,8 +139,11 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
- Deferred : A FrozenEvent.
+ Deferred[EventBase|None]
"""
+ if not isinstance(event_id, str):
+ raise TypeError("Invalid event event_id %r" % (event_id,))
+
events = yield self.get_events_as_list(
[event_id],
check_redacted=check_redacted,
@@ -157,7 +160,7 @@ class EventsWorkerStore(SQLBaseStore):
if event is None and not allow_none:
raise NotFoundError("Could not find event %s" % (event_id,))
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
def get_events(
@@ -187,7 +190,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected=allow_rejected,
)
- defer.returnValue({e.event_id: e for e in events})
+ return {e.event_id: e for e in events}
@defer.inlineCallbacks
def get_events_as_list(
@@ -217,7 +220,7 @@ class EventsWorkerStore(SQLBaseStore):
"""
if not event_ids:
- defer.returnValue([])
+ return []
# there may be duplicates so we cast the list to a set
event_entry_map = yield self._get_events_from_cache_or_db(
@@ -313,7 +316,7 @@ class EventsWorkerStore(SQLBaseStore):
event.unsigned["prev_content"] = prev.content
event.unsigned["prev_sender"] = prev.sender
- defer.returnValue(events)
+ return events
@defer.inlineCallbacks
def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
@@ -460,7 +463,7 @@ class EventsWorkerStore(SQLBaseStore):
without having to create a new transaction for each request for events.
"""
if not events:
- defer.returnValue({})
+ return {}
events_d = defer.Deferred()
with self._event_fetch_lock:
@@ -504,7 +507,7 @@ class EventsWorkerStore(SQLBaseStore):
)
)
- defer.returnValue({e.event.event_id: e for e in res if e})
+ return {e.event.event_id: e for e in res if e}
def _fetch_event_rows(self, txn, event_ids):
"""Fetch event rows from the database
@@ -617,7 +620,7 @@ class EventsWorkerStore(SQLBaseStore):
self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
- defer.returnValue(cache_entry)
+ return cache_entry
@defer.inlineCallbacks
def _maybe_redact_event_row(self, original_ev, redactions):
@@ -710,7 +713,7 @@ class EventsWorkerStore(SQLBaseStore):
desc="have_events_in_timeline",
)
- defer.returnValue(set(r["event_id"] for r in rows))
+ return set(r["event_id"] for r in rows)
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
@@ -736,7 +739,7 @@ class EventsWorkerStore(SQLBaseStore):
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
- defer.returnValue(results)
+ return results
def get_seen_events_with_rejections(self, event_ids):
"""Given a list of event ids, check if we rejected them.
@@ -847,4 +850,4 @@ class EventsWorkerStore(SQLBaseStore):
# it.
complexity_v1 = round(state_events / 500, 2)
- defer.returnValue({"v1": complexity_v1})
+ return {"v1": complexity_v1}
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index b195dc66a0..23b48f6cea 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -15,8 +15,6 @@
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -41,7 +39,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(db_to_json(def_json))
+ return db_to_json(def_json)
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index 73e6fc6de2..15b01c6958 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -307,15 +307,13 @@ class GroupServerStore(SQLBaseStore):
desc="get_group_categories",
)
- defer.returnValue(
- {
- row["category_id"]: {
- "is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
- }
- for row in rows
+ return {
+ row["category_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
}
- )
+ for row in rows
+ }
@defer.inlineCallbacks
def get_group_category(self, group_id, category_id):
@@ -328,7 +326,7 @@ class GroupServerStore(SQLBaseStore):
category["profile"] = json.loads(category["profile"])
- defer.returnValue(category)
+ return category
def upsert_group_category(self, group_id, category_id, profile, is_public):
"""Add/update room category for group
@@ -370,15 +368,13 @@ class GroupServerStore(SQLBaseStore):
desc="get_group_roles",
)
- defer.returnValue(
- {
- row["role_id"]: {
- "is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
- }
- for row in rows
+ return {
+ row["role_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
}
- )
+ for row in rows
+ }
@defer.inlineCallbacks
def get_group_role(self, group_id, role_id):
@@ -391,7 +387,7 @@ class GroupServerStore(SQLBaseStore):
role["profile"] = json.loads(role["profile"])
- defer.returnValue(role)
+ return role
def upsert_group_role(self, group_id, role_id, profile, is_public):
"""Add/remove user role
@@ -960,7 +956,7 @@ class GroupServerStore(SQLBaseStore):
_register_user_group_membership_txn,
next_id,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def create_group(
@@ -1057,9 +1053,9 @@ class GroupServerStore(SQLBaseStore):
now = int(self._clock.time_msec())
if row and now < row["valid_until_ms"]:
- defer.returnValue(json.loads(row["attestation_json"]))
+ return json.loads(row["attestation_json"])
- defer.returnValue(None)
+ return None
def get_joined_groups(self, user_id):
return self._simple_select_onecol(
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 081564360f..752e9788a2 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -173,7 +173,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
)
if user_id:
count = count + 1
- defer.returnValue(count)
+ return count
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 7c4e1dc7ec..d20eacda59 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 55
+SCHEMA_VERSION = 56
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 42ec8c6bb8..1a0f2d5768 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -90,9 +90,7 @@ class PresenceStore(SQLBaseStore):
presence_states,
)
- defer.returnValue(
- (stream_orderings[-1], self._presence_id_gen.get_current_token())
- )
+ return (stream_orderings[-1], self._presence_id_gen.get_current_token())
def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states):
@@ -180,7 +178,7 @@ class PresenceStore(SQLBaseStore):
for row in rows:
row["currently_active"] = bool(row["currently_active"])
- defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows})
+ return {row["user_id"]: UserPresenceState(**row) for row in rows}
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
diff --git a/synapse/storage/profile.py b/synapse/storage/profile.py
index 0ff392bdb4..96a7e32fca 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,8 +19,11 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.storage.roommember import ProfileInfo
+from . import background_updates
from ._base import SQLBaseStore
+BATCH_SIZE = 100
+
class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
@@ -34,15 +38,13 @@ class ProfileWorkerStore(SQLBaseStore):
except StoreError as e:
if e.code == 404:
# no match
- defer.returnValue(ProfileInfo(None, None))
+ return ProfileInfo(None, None)
return
else:
raise
- defer.returnValue(
- ProfileInfo(
- avatar_url=profile["avatar_url"], display_name=profile["displayname"]
- )
+ return ProfileInfo(
+ avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
def get_profile_displayname(self, user_localpart):
@@ -61,6 +63,54 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ def get_latest_profile_replication_batch_number(self):
+ def f(txn):
+ txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
+ rows = self.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return self.runInteraction("get_latest_profile_replication_batch_number", f)
+
+ def get_profile_batch(self, batchnum):
+ return self._simple_select_list(
+ table="profiles",
+ keyvalues={"batch": batchnum},
+ retcols=("user_id", "displayname", "avatar_url", "active"),
+ desc="get_profile_batch",
+ )
+
+ def assign_profile_batch(self):
+ def f(txn):
+ sql = (
+ "UPDATE profiles SET batch = "
+ "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
+ "WHERE user_id in ("
+ " SELECT user_id FROM profiles WHERE batch is NULL limit ?"
+ ")"
+ )
+ txn.execute(sql, (BATCH_SIZE,))
+ return txn.rowcount
+
+ return self.runInteraction("assign_profile_batch", f)
+
+ def get_replication_hosts(self):
+ def f(txn):
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return self.runInteraction("get_replication_hosts", f)
+
+ def update_replication_batch_for_host(self, host, last_synced_batch):
+ return self._simple_upsert(
+ table="profile_replication_status",
+ keyvalues={"host": host},
+ values={"last_synced_batch": last_synced_batch},
+ desc="update_replication_batch_for_host",
+ )
+
def get_from_remote_profile_cache(self, user_id):
return self._simple_select_one(
table="remote_profile_cache",
@@ -70,29 +120,53 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
- def create_profile(self, user_localpart):
- return self._simple_insert(
- table="profiles", values={"user_id": user_localpart}, desc="create_profile"
- )
-
- def set_profile_displayname(self, user_localpart, new_displayname):
- return self._simple_update_one(
+ def set_profile_displayname(self, user_localpart, new_displayname, batchnum):
+ return self._simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
+ lock=False, # we can do this because user_id has a unique index
)
- def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self._simple_update_one(
+ def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum):
+ return self._simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
+ lock=False, # we can do this because user_id has a unique index
+ )
+
+ def set_profile_active(self, user_localpart, active, hide, batchnum):
+ values = {"active": int(active), "batch": batchnum}
+ if not active and not hide:
+ # we are deactivating for real (not in hide mode)
+ # so clear the profile.
+ values["avatar_url"] = None
+ values["displayname"] = None
+ return self._simple_upsert(
+ table="profiles",
+ keyvalues={"user_id": user_localpart},
+ values=values,
+ desc="set_profile_active",
+ lock=False, # we can do this because user_id has a unique index
)
-class ProfileStore(ProfileWorkerStore):
+class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore):
+ def __init__(self, db_conn, hs):
+
+ super(ProfileStore, self).__init__(db_conn, hs)
+
+ self.register_background_index_update(
+ "profile_replication_status_host_index",
+ index_name="profile_replication_status_idx",
+ table="profile_replication_status",
+ columns=["host"],
+ unique=True,
+ )
+
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
"""Ensure we are caching the remote user's profiles.
@@ -168,7 +242,7 @@ class ProfileStore(ProfileWorkerStore):
)
if res:
- defer.returnValue(True)
+ return True
res = yield self._simple_select_one_onecol(
table="group_invites",
@@ -179,4 +253,4 @@ class ProfileStore(ProfileWorkerStore):
)
if res:
- defer.returnValue(True)
+ return True
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 98cec8c82b..a6517c4cf3 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -120,7 +120,7 @@ class PushRulesWorkerStore(
rules = _load_rules(rows, enabled_map)
- defer.returnValue(rules)
+ return rules
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
@@ -130,9 +130,7 @@ class PushRulesWorkerStore(
retcols=("user_name", "rule_id", "enabled"),
desc="get_push_rules_enabled_for_user",
)
- defer.returnValue(
- {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
- )
+ return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
@@ -160,7 +158,7 @@ class PushRulesWorkerStore(
)
def bulk_get_push_rules(self, user_ids):
if not user_ids:
- defer.returnValue({})
+ return {}
results = {user_id: [] for user_id in user_ids}
@@ -182,7 +180,7 @@ class PushRulesWorkerStore(
for user_id, rules in results.items():
results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
@@ -253,7 +251,7 @@ class PushRulesWorkerStore(
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
- defer.returnValue(result)
+ return result
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(
@@ -312,7 +310,7 @@ class PushRulesWorkerStore(
rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
- defer.returnValue(rules_by_user)
+ return rules_by_user
@cachedList(
cached_method_name="get_push_rules_enabled_for_user",
@@ -322,7 +320,7 @@ class PushRulesWorkerStore(
)
def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
- defer.returnValue({})
+ return {}
results = {user_id: {} for user_id in user_ids}
@@ -336,7 +334,7 @@ class PushRulesWorkerStore(
for row in rows:
enabled = bool(row["enabled"])
results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
- defer.returnValue(results)
+ return results
class PushRuleStore(PushRulesWorkerStore):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index cfe0a94330..be3d4d9ded 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -63,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore):
ret = yield self._simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
- defer.returnValue(ret is not None)
+ return ret is not None
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
@@ -95,7 +95,7 @@ class PusherWorkerStore(SQLBaseStore):
],
desc="get_pushers_by",
)
- defer.returnValue(self._decode_pushers_rows(ret))
+ return self._decode_pushers_rows(ret)
@defer.inlineCallbacks
def get_all_pushers(self):
@@ -106,7 +106,7 @@ class PusherWorkerStore(SQLBaseStore):
return self._decode_pushers_rows(rows)
rows = yield self.runInteraction("get_all_pushers", get_pushers)
- defer.returnValue(rows)
+ return rows
def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id:
@@ -205,7 +205,7 @@ class PusherWorkerStore(SQLBaseStore):
result = {user_id: False for user_id in user_ids}
result.update({r["user_name"]: True for r in rows})
- defer.returnValue(result)
+ return result
class PusherStore(PusherWorkerStore):
@@ -343,7 +343,7 @@ class PusherStore(PusherWorkerStore):
"throttle_ms": row["throttle_ms"],
}
- defer.returnValue(params_by_room)
+ return params_by_room
@defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params):
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index b477da12b1..6aa6d98ebb 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -58,7 +58,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
- defer.returnValue(set(r["user_id"] for r in receipts))
+ return set(r["user_id"] for r in receipts)
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
@@ -92,7 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
desc="get_receipts_for_user",
)
- defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
+ return {row["room_id"]: row["event_id"] for row in rows}
@defer.inlineCallbacks
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
@@ -110,16 +110,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
return txn.fetchall()
rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
- defer.returnValue(
- {
- row[0]: {
- "event_id": row[1],
- "topological_ordering": row[2],
- "stream_ordering": row[3],
- }
- for row in rows
+ return {
+ row[0]: {
+ "event_id": row[1],
+ "topological_ordering": row[2],
+ "stream_ordering": row[3],
}
- )
+ for row in rows
+ }
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@@ -147,7 +145,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
room_ids, to_key, from_key=from_key
)
- defer.returnValue([ev for res in results.values() for ev in res])
+ return [ev for res in results.values() for ev in res]
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
@@ -197,7 +195,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
- defer.returnValue([])
+ return []
content = {}
for row in rows:
@@ -205,9 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
row["user_id"]
] = json.loads(row["data"])
- defer.returnValue(
- [{"type": "m.receipt", "room_id": room_id, "content": content}]
- )
+ return [{"type": "m.receipt", "room_id": room_id, "content": content}]
@cachedList(
cached_method_name="_get_linearized_receipts_for_room",
@@ -217,7 +213,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
- defer.returnValue({})
+ return {}
def f(txn):
if from_key:
@@ -264,7 +260,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
room_id: [results[room_id]] if room_id in results else []
for room_id in room_ids
}
- defer.returnValue(results)
+ return results
def get_all_updated_receipts(self, last_id, current_id, limit=None):
if last_id == current_id:
@@ -468,7 +464,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
)
if event_ts is None:
- defer.returnValue(None)
+ return None
now = self._clock.time_msec()
logger.debug(
@@ -482,7 +478,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
max_persisted_id = self._receipts_id_gen.get_current_token()
- defer.returnValue((stream_id, max_persisted_id))
+ return (stream_id, max_persisted_id)
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
return self.runInteraction(
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 8b2c2a97ab..ea5b2be0f7 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -75,12 +75,12 @@ class RegistrationWorkerStore(SQLBaseStore):
info = yield self.get_user_by_id(user_id)
if not info:
- defer.returnValue(False)
+ return False
now = self.clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
- defer.returnValue(is_trial)
+ return is_trial
@cached()
def get_user_by_access_token(self, token):
@@ -115,7 +115,7 @@ class RegistrationWorkerStore(SQLBaseStore):
allow_none=True,
desc="get_expiration_ts_for_user",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def set_account_validity_for_user(
@@ -155,6 +155,28 @@ class RegistrationWorkerStore(SQLBaseStore):
)
@defer.inlineCallbacks
+ def get_expired_users(self):
+ """Get IDs of all expired users
+
+ Returns:
+ Deferred[list[str]]: List of expired user IDs
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ sql = """
+ SELECT user_id from account_validity
+ WHERE expiration_ts_ms <= ?
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+ return [row[0] for row in rows]
+
+ res = yield self.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
+ )
+ defer.returnValue(res)
+
+ @defer.inlineCallbacks
def set_renewal_token_for_user(self, user_id, renewal_token):
"""Defines a renewal token for a given user.
@@ -190,7 +212,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_user_from_renewal_token",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def get_renewal_token_for_user(self, user_id):
@@ -209,7 +231,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_renewal_token_for_user",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def get_users_expiring_soon(self):
@@ -237,7 +259,7 @@ class RegistrationWorkerStore(SQLBaseStore):
self.config.account_validity.renew_at,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def set_renewal_mail_status(self, user_id, email_sent):
@@ -280,7 +302,7 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="is_server_admin",
)
- defer.returnValue(res if res else False)
+ return res if res else False
def _query_for_auth(self, txn, token):
sql = (
@@ -311,7 +333,7 @@ class RegistrationWorkerStore(SQLBaseStore):
res = yield self.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
- defer.returnValue(res)
+ return res
def is_support_user_txn(self, txn, user_id):
res = self._simple_select_one_onecol_txn(
@@ -349,7 +371,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return 0
ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ return ret
def count_daily_user_type(self):
"""
@@ -395,7 +417,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return count
ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
@@ -425,7 +447,7 @@ class RegistrationWorkerStore(SQLBaseStore):
if i not in found:
return i
- defer.returnValue(
+ return (
(
yield self.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
@@ -447,7 +469,7 @@ class RegistrationWorkerStore(SQLBaseStore):
user_id = yield self.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
- defer.returnValue(user_id)
+ return user_id
def get_user_id_by_threepid_txn(self, txn, medium, address):
"""Returns user id from threepid
@@ -487,7 +509,7 @@ class RegistrationWorkerStore(SQLBaseStore):
["medium", "address", "validated_at", "added_at"],
"user_get_threepids",
)
- defer.returnValue(ret)
+ return ret
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
@@ -677,7 +699,7 @@ class RegistrationStore(
if end:
yield self._end_background_update("users_set_deactivated_flag")
- defer.returnValue(batch_size)
+ return batch_size
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
@@ -957,7 +979,7 @@ class RegistrationStore(
desc="is_guest",
)
- defer.returnValue(res if res else False)
+ return res if res else False
def add_user_pending_deactivation(self, user_id):
"""
@@ -1024,7 +1046,7 @@ class RegistrationStore(
yield self._end_background_update("user_threepids_grandfather")
- defer.returnValue(1)
+ return 1
def get_threepid_validation_session(
self, medium, client_secret, address=None, sid=None, validated=True
@@ -1337,4 +1359,4 @@ class RegistrationStore(
)
# Convert the integer into a boolean.
- defer.returnValue(res == 1)
+ return res == 1
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 9954bc094f..fcb5f2f23a 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -17,8 +17,6 @@ import logging
import attr
-from twisted.internet import defer
-
from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
@@ -363,7 +361,7 @@ class RelationsWorkerStore(SQLBaseStore):
return
edit_event = yield self.get_event(edit_id, allow_none=True)
- defer.returnValue(edit_event)
+ return edit_event
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
"""Check if a user has already annotated an event with the same key
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index fe9d79d792..1ca01a4c70 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -17,10 +17,13 @@ import collections
import logging
import re
+from six import integer_types
+
from canonicaljson import json
from twisted.internet import defer
+from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.search import SearchStore
@@ -171,6 +174,24 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ @defer.inlineCallbacks
+ def is_room_published(self, room_id):
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id (str)
+ Returns:
+ bool: Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = yield self.get_room(room_id)
+ if not room_info:
+ defer.returnValue(False)
+
+ # Check the is_public value
+ defer.returnValue(room_info.get("is_public", False))
+
@cachedInlineCallbacks(max_entries=10000)
def get_ratelimit_for_user(self, user_id):
"""Check if there are any overrides for ratelimiting for the given
@@ -193,17 +214,153 @@ class RoomWorkerStore(SQLBaseStore):
)
if row:
- defer.returnValue(
- RatelimitOverride(
- messages_per_second=row["messages_per_second"],
- burst_count=row["burst_count"],
- )
+ return RatelimitOverride(
+ messages_per_second=row["messages_per_second"],
+ burst_count=row["burst_count"],
)
else:
- defer.returnValue(None)
+ return None
+
+ @cachedInlineCallbacks()
+ def get_retention_policy_for_room(self, room_id):
+ """Get the retention policy for a given room.
+
+ If no retention policy has been found for this room, returns a policy defined
+ by the configured default policy (which has None as both the 'min_lifetime' and
+ the 'max_lifetime' if no default policy has been defined in the server's
+ configuration).
+
+ Args:
+ room_id (str): The ID of the room to get the retention policy of.
+
+ Returns:
+ dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+ """
+ # If the room retention feature is disabled, return a policy with no minimum nor
+ # maximum, in order not to filter out events we should filter out when sending to
+ # the client.
+ if not self.config.retention_enabled:
+ defer.returnValue({"min_lifetime": None, "max_lifetime": None})
+
+ def get_retention_policy_for_room_txn(txn):
+ txn.execute(
+ """
+ SELECT min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ WHERE room_id = ?;
+ """,
+ (room_id,),
+ )
+
+ return self.cursor_to_dict(txn)
+
+ ret = yield self.runInteraction(
+ "get_retention_policy_for_room", get_retention_policy_for_room_txn
+ )
+
+ # If we don't know this room ID, ret will be None, in this case return the default
+ # policy.
+ if not ret:
+ defer.returnValue(
+ {
+ "min_lifetime": self.config.retention_default_min_lifetime,
+ "max_lifetime": self.config.retention_default_max_lifetime,
+ }
+ )
+
+ row = ret[0]
+
+ # If one of the room's policy's attributes isn't defined, use the matching
+ # attribute from the default policy.
+ # The default values will be None if no default policy has been defined, or if one
+ # of the attributes is missing from the default policy.
+ if row["min_lifetime"] is None:
+ row["min_lifetime"] = self.config.retention_default_min_lifetime
+
+ if row["max_lifetime"] is None:
+ row["max_lifetime"] = self.config.retention_default_max_lifetime
+
+ defer.returnValue(row)
class RoomStore(RoomWorkerStore, SearchStore):
+ def __init__(self, db_conn, hs):
+ super(RoomStore, self).__init__(db_conn, hs)
+
+ self.config = hs.config
+
+ self.register_background_update_handler(
+ "insert_room_retention", self._background_insert_retention
+ )
+
+ @defer.inlineCallbacks
+ def _background_insert_retention(self, progress, batch_size):
+ """Retrieves a list of all rooms within a range and inserts an entry for each of
+ them into the room_retention table.
+ NULLs the property's columns if missing from the retention event in the room's
+ state (or NULLs all of them if there's no retention event in the room's state),
+ so that we fall back to the server's retention policy.
+ """
+
+ last_room = progress.get("room_id", "")
+
+ def _background_insert_retention_txn(txn):
+ txn.execute(
+ """
+ SELECT state.room_id, state.event_id, events.json
+ FROM current_state_events as state
+ LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
+ WHERE state.room_id > ? AND state.type = '%s'
+ ORDER BY state.room_id ASC
+ LIMIT ?;
+ """
+ % EventTypes.Retention,
+ (last_room, batch_size),
+ )
+
+ rows = self.cursor_to_dict(txn)
+
+ if not rows:
+ return True
+
+ for row in rows:
+ if not row["json"]:
+ retention_policy = {}
+ else:
+ ev = json.loads(row["json"])
+ retention_policy = json.dumps(ev["content"])
+
+ self._simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": row["room_id"],
+ "event_id": row["event_id"],
+ "min_lifetime": retention_policy.get("min_lifetime"),
+ "max_lifetime": retention_policy.get("max_lifetime"),
+ },
+ )
+
+ logger.info("Inserted %d rows into room_retention", len(rows))
+
+ self._background_update_progress_txn(
+ txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ )
+
+ if batch_size > len(rows):
+ return True
+ else:
+ return False
+
+ end = yield self.runInteraction(
+ "insert_room_retention", _background_insert_retention_txn
+ )
+
+ if end:
+ yield self._end_background_update("insert_room_retention")
+
+ defer.returnValue(batch_size)
+
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
"""Stores a room.
@@ -439,6 +596,35 @@ class RoomStore(RoomWorkerStore, SearchStore):
)
txn.execute(sql, (event.event_id, event.room_id, event.content[key]))
+ def _store_retention_policy_for_room_txn(self, txn, event):
+ if hasattr(event, "content") and (
+ "min_lifetime" in event.content or "max_lifetime" in event.content
+ ):
+ if (
+ "min_lifetime" in event.content
+ and not isinstance(event.content.get("min_lifetime"), integer_types)
+ ) or (
+ "max_lifetime" in event.content
+ and not isinstance(event.content.get("max_lifetime"), integer_types)
+ ):
+ # Ignore the event if one of the value isn't an integer.
+ return
+
+ self._simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ "min_lifetime": event.content.get("min_lifetime"),
+ "max_lifetime": event.content.get("max_lifetime"),
+ },
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_retention_policy_for_room, (event.room_id,)
+ )
+
def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts
):
@@ -620,3 +806,89 @@ class RoomStore(RoomWorkerStore, SearchStore):
remote_media_mxcs.append((hostname, media_id))
return local_media_mxcs, remote_media_mxcs
+
+ @defer.inlineCallbacks
+ def get_rooms_for_retention_period_in_range(
+ self, min_ms, max_ms, include_null=False
+ ):
+ """Retrieves all of the rooms within the given retention range.
+
+ Optionally includes the rooms which don't have a retention policy.
+
+ Args:
+ min_ms (int|None): Duration in milliseconds that define the lower limit of
+ the range to handle (exclusive). If None, doesn't set a lower limit.
+ max_ms (int|None): Duration in milliseconds that define the upper limit of
+ the range to handle (inclusive). If None, doesn't set an upper limit.
+ include_null (bool): Whether to include rooms which retention policy is NULL
+ in the returned set.
+
+ Returns:
+ dict[str, dict]: The rooms within this range, along with their retention
+ policy. The key is "room_id", and maps to a dict describing the retention
+ policy associated with this room ID. The keys for this nested dict are
+ "min_lifetime" (int|None), and "max_lifetime" (int|None).
+ """
+
+ def get_rooms_for_retention_period_in_range_txn(txn):
+ range_conditions = []
+ args = []
+
+ if min_ms is not None:
+ range_conditions.append("max_lifetime > ?")
+ args.append(min_ms)
+
+ if max_ms is not None:
+ range_conditions.append("max_lifetime <= ?")
+ args.append(max_ms)
+
+ # Do a first query which will retrieve the rooms that have a retention policy
+ # in their current state.
+ sql = """
+ SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ """
+
+ if len(range_conditions):
+ sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+ if include_null:
+ sql += " OR max_lifetime IS NULL"
+
+ txn.execute(sql, args)
+
+ rows = self.cursor_to_dict(txn)
+ rooms_dict = {}
+
+ for row in rows:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": row["min_lifetime"],
+ "max_lifetime": row["max_lifetime"],
+ }
+
+ if include_null:
+ # If required, do a second query that retrieves all of the rooms we know
+ # of so we can handle rooms with no retention policy.
+ sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+ txn.execute(sql)
+
+ rows = self.cursor_to_dict(txn)
+
+ # If a room isn't already in the dict (i.e. it doesn't have a retention
+ # policy in its state), add it with a null policy.
+ for row in rows:
+ if row["room_id"] not in rooms_dict:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": None,
+ "max_lifetime": None,
+ }
+
+ return rooms_dict
+
+ rooms = yield self.runInteraction(
+ "get_rooms_for_retention_period_in_range",
+ get_rooms_for_retention_period_in_range_txn,
+ )
+
+ defer.returnValue(rooms)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 32cfd010a5..cb88e49b51 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -24,6 +24,8 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import LoggingTransaction
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id
from synapse.util.async_helpers import Linearizer
@@ -53,9 +55,51 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
MemberSummary = namedtuple("MemberSummary", ("members", "count"))
_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
+_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
+ def __init__(self, db_conn, hs):
+ super(RoomMemberWorkerStore, self).__init__(db_conn, hs)
+
+ # Is the current_state_events.membership up to date? Or is the
+ # background update still running?
+ self._current_state_events_membership_up_to_date = False
+
+ txn = LoggingTransaction(
+ db_conn.cursor(),
+ name="_check_safe_current_state_events_membership_updated",
+ database_engine=self.database_engine,
+ )
+ self._check_safe_current_state_events_membership_updated_txn(txn)
+ txn.close()
+
+ def _check_safe_current_state_events_membership_updated_txn(self, txn):
+ """Checks if it is safe to assume the new current_state_events
+ membership column is up to date
+ """
+
+ pending_update = self._simple_select_one_txn(
+ txn,
+ table="background_updates",
+ keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
+ retcols=["update_name"],
+ allow_none=True,
+ )
+
+ self._current_state_events_membership_up_to_date = not pending_update
+
+ # If the update is still running, reschedule to run.
+ if pending_update:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "_check_safe_current_state_events_membership_updated",
+ self.runInteraction,
+ "_check_safe_current_state_events_membership_updated",
+ self._check_safe_current_state_events_membership_updated_txn,
+ )
+
@cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
def get_hosts_in_room(self, room_id, cache_context):
"""Returns the set of all hosts currently in the room
@@ -64,19 +108,28 @@ class RoomMemberWorkerStore(EventsWorkerStore):
room_id, on_invalidate=cache_context.invalidate
)
hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
- defer.returnValue(hosts)
+ return hosts
@cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id):
def f(txn):
- sql = (
- "SELECT m.user_id FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id "
- " AND m.room_id = c.room_id "
- " AND m.user_id = c.state_key"
- " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
- )
+ # If we can assume current_state_events.membership is up to date
+ # then we can avoid a join, which is a Very Good Thing given how
+ # frequently this function gets called.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+ """
+ else:
+ sql = """
+ SELECT state_key FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+ """
txn.execute(sql, (room_id, Membership.JOIN))
return [to_ascii(r[0]) for r in txn]
@@ -98,15 +151,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# first get counts.
# We do this all in one transaction to keep the cache small.
# FIXME: get rid of this when we have room_stats
- sql = """
- SELECT count(*), m.membership FROM room_memberships as m
- INNER JOIN current_state_events as c
- ON m.event_id = c.event_id
- AND m.room_id = c.room_id
- AND m.user_id = c.state_key
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
+
+ # If we can assume current_state_events.membership is up to date
+ # then we can avoid a join, which is a Very Good Thing given how
+ # frequently this function gets called.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT count(*), membership FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ GROUP BY membership
+ """
+ else:
+ sql = """
+ SELECT count(*), m.membership FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
txn.execute(sql, (room_id,))
res = {}
@@ -189,8 +253,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
invites = yield self.get_invited_rooms_for_user(user_id)
for invite in invites:
if invite.room_id == room_id:
- defer.returnValue(invite)
- defer.returnValue(None)
+ return invite
+ return None
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@@ -224,7 +288,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
results = []
if membership_list:
where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
- " OR ".join(["membership = ?" for _ in membership_list]),
+ " OR ".join(["m.membership = ?" for _ in membership_list]),
)
args = [user_id]
@@ -283,11 +347,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rooms = yield self.get_rooms_for_user_where_membership_is(
user_id, membership_list=[Membership.JOIN]
)
- defer.returnValue(
- frozenset(
- GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
- for r in rooms
- )
+ return frozenset(
+ GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
+ for r in rooms
)
@defer.inlineCallbacks
@@ -297,7 +359,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rooms = yield self.get_rooms_for_user_with_stream_ordering(
user_id, on_invalidate=on_invalidate
)
- defer.returnValue(frozenset(r.room_id for r in rooms))
+ return frozenset(r.room_id for r in rooms)
@cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
@@ -314,7 +376,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
)
user_who_share_room.update(user_ids)
- defer.returnValue(user_who_share_room)
+ return user_who_share_room
@defer.inlineCallbacks
def get_joined_users_from_context(self, event, context):
@@ -330,7 +392,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids, event=event, context=context
)
- defer.returnValue(result)
+ return result
def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group
@@ -444,7 +506,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
avatar_url=to_ascii(event.content.get("avatar_url", None)),
)
- defer.returnValue(users_in_room)
+ return users_in_room
@cachedInlineCallbacks(max_entries=10000)
def is_host_joined(self, room_id, host):
@@ -453,8 +515,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
sql = """
SELECT state_key FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
- WHERE membership = 'join'
+ INNER JOIN room_memberships AS m USING (event_id)
+ WHERE m.membership = 'join'
AND type = 'm.room.member'
AND c.room_id = ?
AND state_key LIKE ?
@@ -469,14 +531,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
if not rows:
- defer.returnValue(False)
+ return False
user_id = rows[0][0]
if get_domain_from_id(user_id) != host:
# This can only happen if the host name has something funky in it
raise Exception("Invalid host name")
- defer.returnValue(True)
+ return True
@cachedInlineCallbacks()
def was_host_joined(self, room_id, host):
@@ -509,14 +571,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
if not rows:
- defer.returnValue(False)
+ return False
user_id = rows[0][0]
if get_domain_from_id(user_id) != host:
# This can only happen if the host name has something funky in it
raise Exception("Invalid host name")
- defer.returnValue(True)
+ return True
def get_joined_hosts(self, room_id, state_entry):
state_group = state_entry.state_group
@@ -543,7 +605,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
cache = self._get_joined_hosts_cache(room_id)
joined_hosts = yield cache.get_destinations(state_entry)
- defer.returnValue(joined_hosts)
+ return joined_hosts
@cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id):
@@ -573,7 +635,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return rows[0][0]
count = yield self.runInteraction("did_forget_membership", f)
- defer.returnValue(count == 0)
+ return count == 0
@defer.inlineCallbacks
def get_rooms_user_has_been_in(self, user_id):
@@ -602,6 +664,10 @@ class RoomMemberStore(RoomMemberWorkerStore):
self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
+ self.register_background_update_handler(
+ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+ self._background_current_state_membership,
+ )
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
@@ -779,7 +845,65 @@ class RoomMemberStore(RoomMemberWorkerStore):
if not result:
yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
- defer.returnValue(result)
+ return result
+
+ @defer.inlineCallbacks
+ def _background_current_state_membership(self, progress, batch_size):
+ """Update the new membership column on current_state_events.
+
+ This works by iterating over all rooms in alphebetical order.
+ """
+
+ def _background_current_state_membership_txn(txn, last_processed_room):
+ processed = 0
+ while processed < batch_size:
+ txn.execute(
+ """
+ SELECT MIN(room_id) FROM rooms WHERE room_id > ?
+ """,
+ (last_processed_room,),
+ )
+ row = txn.fetchone()
+ if not row or not row[0]:
+ return processed, True
+
+ next_room, = row
+
+ sql = """
+ UPDATE current_state_events AS c
+ SET membership = (
+ SELECT membership FROM room_memberships
+ WHERE event_id = c.event_id
+ )
+ WHERE room_id = ?
+ """
+ txn.execute(sql, (next_room,))
+ processed += txn.rowcount
+
+ last_processed_room = next_room
+
+ self._background_update_progress_txn(
+ txn,
+ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+ {"last_processed_room": last_processed_room},
+ )
+
+ return processed, False
+
+ # If we haven't got a last processed room then just use the empty
+ # string, which will compare before all room IDs correctly.
+ last_processed_room = progress.get("last_processed_room", "")
+
+ row_count, finished = yield self.runInteraction(
+ "_background_current_state_membership_update",
+ _background_current_state_membership_txn,
+ last_processed_room,
+ )
+
+ if finished:
+ yield self._end_background_update(_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME)
+
+ return row_count
class _JoinedHostsCache(object):
@@ -807,7 +931,7 @@ class _JoinedHostsCache(object):
state_entry(synapse.state._StateCacheEntry)
"""
if state_entry.state_group == self.state_group:
- defer.returnValue(frozenset(self.hosts_to_joined_users))
+ return frozenset(self.hosts_to_joined_users)
with (yield self.linearizer.queue(())):
if state_entry.state_group == self.state_group:
@@ -844,7 +968,7 @@ class _JoinedHostsCache(object):
else:
self.state_group = object()
self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
- defer.returnValue(frozenset(self.hosts_to_joined_users))
+ return frozenset(self.hosts_to_joined_users)
def __len__(self):
return self._len
diff --git a/synapse/storage/schema/delta/48/profiles_batch.sql b/synapse/storage/schema/delta/48/profiles_batch.sql
new file mode 100644
index 0000000000..e744c02fe8
--- /dev/null
+++ b/synapse/storage/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Add a batch number to track changes to profiles and the
+ * order they're made in so we can replicate user profiles
+ * to other hosts as they change
+ */
+ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
+
+/*
+ * Index on the batch number so we can get profiles
+ * by their batch
+ */
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+
+/*
+ * A table to track what batch of user profiles has been
+ * synced to what profile replication target.
+ */
+CREATE TABLE profile_replication_status (
+ host TEXT NOT NULL,
+ last_synced_batch BIGINT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/schema/delta/50/profiles_deactivated_users.sql
new file mode 100644
index 0000000000..c8893ecbe8
--- /dev/null
+++ b/synapse/storage/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * realy ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
diff --git a/synapse/storage/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/schema/delta/55/profile_replication_status_index.sql
new file mode 100644
index 0000000000..18a0f7e10c
--- /dev/null
+++ b/synapse/storage/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('profile_replication_status_host_index', '{}');
diff --git a/synapse/storage/schema/delta/55/room_retention.sql b/synapse/storage/schema/delta/55/room_retention.sql
new file mode 100644
index 0000000000..ee6cdf7a14
--- /dev/null
+++ b/synapse/storage/schema/delta/55/room_retention.sql
@@ -0,0 +1,33 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks the retention policy of a room.
+-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in
+-- the room's retention policy state event.
+-- If a room doesn't have a retention policy state event in its state, both max_lifetime
+-- and min_lifetime are NULL.
+CREATE TABLE IF NOT EXISTS room_retention(
+ room_id TEXT,
+ event_id TEXT,
+ min_lifetime BIGINT,
+ max_lifetime BIGINT,
+
+ PRIMARY KEY(room_id, event_id)
+);
+
+CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime);
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('insert_room_retention', '{}');
diff --git a/synapse/storage/schema/delta/56/current_state_events_membership.sql b/synapse/storage/schema/delta/56/current_state_events_membership.sql
new file mode 100644
index 0000000000..b2e08cd85d
--- /dev/null
+++ b/synapse/storage/schema/delta/56/current_state_events_membership.sql
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We add membership to current state so that we don't need to join against
+-- room_memberships, which can be surprisingly costly (we do such queries
+-- very frequently).
+-- This will be null for non-membership events and the content.membership key
+-- for membership events. (Will also be null for membership events until the
+-- background update job has finished).
+ALTER TABLE current_state_events ADD membership TEXT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('current_state_events_membership', '{}');
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.postgres b/synapse/storage/schema/full_schemas/54/full.sql.postgres
index 098434356f..01a2b0e024 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/schema/full_schemas/54/full.sql.postgres
@@ -667,10 +667,19 @@ CREATE TABLE presence_stream (
+CREATE TABLE profile_replication_status (
+ host text NOT NULL,
+ last_synced_batch bigint NOT NULL
+);
+
+
+
CREATE TABLE profiles (
user_id text NOT NULL,
displayname text,
- avatar_url text
+ avatar_url text,
+ batch bigint,
+ active smallint DEFAULT 1 NOT NULL
);
@@ -1842,6 +1851,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id);
+CREATE INDEX profiles_batch_idx ON profiles USING btree (batch);
+
+
+
CREATE INDEX public_room_index ON rooms USING btree (is_public);
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/schema/full_schemas/54/full.sql.sqlite
index be9295e4c9..f1a71627f0 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us
CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) );
CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) );
CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL );
-CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) );
+CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) );
CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) );
CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER );
CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) );
@@ -208,6 +208,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id);
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL );
CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL );
CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index f3b1cec933..df87ab6a6d 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -166,7 +166,7 @@ class SearchStore(BackgroundUpdateStore):
if not result:
yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _background_reindex_gin_search(self, progress, batch_size):
@@ -209,7 +209,7 @@ class SearchStore(BackgroundUpdateStore):
yield self.runWithConnection(create_index)
yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _background_reindex_search_order(self, progress, batch_size):
@@ -287,7 +287,7 @@ class SearchStore(BackgroundUpdateStore):
if not finished:
yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
- defer.returnValue(num_rows)
+ return num_rows
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
@@ -454,17 +454,15 @@ class SearchStore(BackgroundUpdateStore):
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
- defer.returnValue(
- {
- "results": [
- {"event": event_map[r["event_id"]], "rank": r["rank"]}
- for r in results
- if r["event_id"] in event_map
- ],
- "highlights": highlights,
- "count": count,
- }
- )
+ return {
+ "results": [
+ {"event": event_map[r["event_id"]], "rank": r["rank"]}
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ "count": count,
+ }
@defer.inlineCallbacks
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
@@ -599,22 +597,20 @@ class SearchStore(BackgroundUpdateStore):
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
- defer.returnValue(
- {
- "results": [
- {
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s"
- % (r["origin_server_ts"], r["stream_ordering"]),
- }
- for r in results
- if r["event_id"] in event_map
- ],
- "highlights": highlights,
- "count": count,
- }
- )
+ return {
+ "results": [
+ {
+ "event": event_map[r["event_id"]],
+ "rank": r["rank"],
+ "pagination_token": "%s,%s"
+ % (r["origin_server_ts"], r["stream_ordering"]),
+ }
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ "count": count,
+ }
def _find_highlights_in_postgres(self, search_query, events):
"""Given a list of events and a search term, return a list of words
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 6bd81e84ad..fb83218f90 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -59,7 +59,7 @@ class SignatureWorkerStore(SQLBaseStore):
for e_id, h in hashes.items()
}
- defer.returnValue(list(hashes.items()))
+ return list(hashes.items())
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0bfe1b4550..1980a87108 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -422,7 +422,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# Retrieve the room's create event
create_event = yield self.get_create_event_for_room(room_id)
- defer.returnValue(create_event.content.get("room_version", "1"))
+ return create_event.content.get("room_version", "1")
@defer.inlineCallbacks
def get_room_predecessor(self, room_id):
@@ -442,7 +442,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
create_event = yield self.get_create_event_for_room(room_id)
# Return predecessor if present
- defer.returnValue(create_event.content.get("predecessor", None))
+ return create_event.content.get("predecessor", None)
@defer.inlineCallbacks
def get_create_event_for_room(self, room_id):
@@ -466,7 +466,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# Retrieve the room's create event and return
create_event = yield self.get_event(create_id)
- defer.returnValue(create_event)
+ return create_event
@cached(max_entries=100000, iterable=True)
def get_current_state_ids(self, room_id):
@@ -510,6 +510,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
event ID.
"""
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ if not where_clause:
+ # We delegate to the cached version
+ return self.get_current_state_ids(room_id)
+
def _get_filtered_current_state_ids_txn(txn):
results = {}
sql = """
@@ -517,8 +523,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
WHERE room_id = ?
"""
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
if where_clause:
sql += " AND (%s)" % (where_clause,)
@@ -559,7 +563,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not event:
return
- defer.returnValue(event.content.get("canonical_alias"))
+ return event.content.get("canonical_alias")
@cached(max_entries=10000, iterable=True)
def get_state_group_delta(self, state_group):
@@ -609,14 +613,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
- defer.returnValue({})
+ return {}
event_to_groups = yield self._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups)
- defer.returnValue(group_to_state)
+ return group_to_state
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
@@ -630,7 +634,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
group_to_state = yield self._get_state_for_groups((state_group,))
- defer.returnValue(group_to_state[state_group])
+ return group_to_state[state_group]
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
@@ -641,7 +645,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
dict of state_group_id -> list of state events.
"""
if not event_ids:
- defer.returnValue({})
+ return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
@@ -654,16 +658,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
get_prev_content=False,
)
- defer.returnValue(
- {
- group: [
- state_event_map[v]
- for v in itervalues(event_id_map)
- if v in state_event_map
- ]
- for group, event_id_map in iteritems(group_to_ids)
- }
- )
+ return {
+ group: [
+ state_event_map[v]
+ for v in itervalues(event_id_map)
+ if v in state_event_map
+ ]
+ for group, event_id_map in iteritems(group_to_ids)
+ }
@defer.inlineCallbacks
def _get_state_groups_from_groups(self, groups, state_filter):
@@ -690,7 +692,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
results.update(res)
- defer.returnValue(results)
+ return results
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
@@ -825,7 +827,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for event_id, group in iteritems(event_to_groups)
}
- defer.returnValue({event: event_to_state[event] for event in event_ids})
+ return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
@@ -851,7 +853,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for event_id, group in iteritems(event_to_groups)
}
- defer.returnValue({event: event_to_state[event] for event in event_ids})
+ return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -867,7 +869,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], state_filter)
- defer.returnValue(state_map[event_id])
+ return state_map[event_id]
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -883,7 +885,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
- defer.returnValue(state_map[event_id])
+ return state_map[event_id]
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
@@ -913,7 +915,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="_get_state_group_for_events",
)
- defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
+ return {row["event_id"]: row["state_group"] for row in rows}
def _get_state_for_group_using_cache(self, cache, group, state_filter):
"""Checks if group is in cache. See `_get_state_for_groups`
@@ -993,7 +995,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
incomplete_groups = incomplete_groups_m | incomplete_groups_nm
if not incomplete_groups:
- defer.returnValue(state)
+ return state
cache_sequence_nm = self._state_group_cache.sequence
cache_sequence_m = self._state_group_members_cache.sequence
@@ -1020,7 +1022,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# everything we need from the database anyway.
state[group] = state_filter.filter_state(group_state_dict)
- defer.returnValue(state)
+ return state
def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
"""Gets the state at each of a list of state groups, optionally
@@ -1494,7 +1496,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
)
- defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
+ return result * BATCH_SIZE_SCALE_FACTOR
@defer.inlineCallbacks
def _background_index_state(self, progress, batch_size):
@@ -1524,4 +1526,4 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
- defer.returnValue(1)
+ return 1
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
index 1cec84ee2e..e13efed417 100644
--- a/synapse/storage/stats.py
+++ b/synapse/storage/stats.py
@@ -66,7 +66,7 @@ class StatsStore(StateDeltasStore):
if not self.stats_enabled:
yield self._end_background_update("populate_stats_createtables")
- defer.returnValue(1)
+ return 1
# Get all the rooms that we want to process.
def _make_staging_area(txn):
@@ -120,7 +120,7 @@ class StatsStore(StateDeltasStore):
self.get_earliest_token_for_room_stats.invalidate_all()
yield self._end_background_update("populate_stats_createtables")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _populate_stats_cleanup(self, progress, batch_size):
@@ -129,7 +129,7 @@ class StatsStore(StateDeltasStore):
"""
if not self.stats_enabled:
yield self._end_background_update("populate_stats_cleanup")
- defer.returnValue(1)
+ return 1
position = yield self._simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
@@ -143,14 +143,14 @@ class StatsStore(StateDeltasStore):
yield self.runInteraction("populate_stats_cleanup", _delete_staging_area)
yield self._end_background_update("populate_stats_cleanup")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _populate_stats_process_rooms(self, progress, batch_size):
if not self.stats_enabled:
yield self._end_background_update("populate_stats_process_rooms")
- defer.returnValue(1)
+ return 1
# If we don't have progress filed, delete everything.
if not progress:
@@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore):
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
yield self._end_background_update("populate_stats_process_rooms")
- defer.returnValue(1)
+ return 1
logger.info(
"Processing the next %d rooms of %d remaining",
@@ -211,16 +211,18 @@ class StatsStore(StateDeltasStore):
avatar_id = current_state_ids.get((EventTypes.RoomAvatar, ""))
canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, ""))
+ event_ids = [
+ join_rules_id,
+ history_visibility_id,
+ encryption_id,
+ name_id,
+ topic_id,
+ avatar_id,
+ canonical_alias_id,
+ ]
+
state_events = yield self.get_events(
- [
- join_rules_id,
- history_visibility_id,
- encryption_id,
- name_id,
- topic_id,
- avatar_id,
- canonical_alias_id,
- ]
+ [ev for ev in event_ids if ev is not None]
)
def _get_or_none(event_id, arg):
@@ -303,9 +305,9 @@ class StatsStore(StateDeltasStore):
if processed_event_count > batch_size:
# Don't process any more rooms, we've hit our batch size.
- defer.returnValue(processed_event_count)
+ return processed_event_count
- defer.returnValue(processed_event_count)
+ return processed_event_count
def delete_all_stats(self):
"""
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index a0465484df..856c2ee8d8 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -300,7 +300,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
if not room_ids:
- defer.returnValue({})
+ return {}
results = {}
room_ids = list(room_ids)
@@ -323,7 +323,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
results.update(dict(zip(rm_ids, res)))
- defer.returnValue(results)
+ return results
def get_rooms_that_changed(self, room_ids, from_key):
"""Given a list of rooms and a token, return rooms where there may have
@@ -364,7 +364,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
the chunk of events returned.
"""
if from_key == to_key:
- defer.returnValue(([], from_key))
+ return ([], from_key)
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -374,7 +374,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
if not has_changed:
- defer.returnValue(([], from_key))
+ return ([], from_key)
def f(txn):
sql = (
@@ -407,7 +407,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# get.
key = from_key
- defer.returnValue((ret, key))
+ return (ret, key)
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
@@ -415,14 +415,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
- defer.returnValue([])
+ return []
if from_id:
has_changed = self._membership_stream_cache.has_entity_changed(
user_id, int(from_id)
)
if not has_changed:
- defer.returnValue([])
+ return []
def f(txn):
sql = (
@@ -447,7 +447,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(ret, rows, topo_order=False)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token):
@@ -477,7 +477,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(events, rows)
- defer.returnValue((events, token))
+ return (events, token)
@defer.inlineCallbacks
def get_recent_event_ids_for_room(self, room_id, limit, end_token):
@@ -496,7 +496,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
# Allow a zero limit here, and no-op.
if limit == 0:
- defer.returnValue(([], end_token))
+ return ([], end_token)
end_token = RoomStreamToken.parse(end_token)
@@ -511,7 +511,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# We want to return the results in ascending order.
rows.reverse()
- defer.returnValue((rows, token))
+ return (rows, token)
def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
"""Gets details of the first event in a room at or after a stream ordering
@@ -549,12 +549,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
token = yield self.get_room_max_stream_ordering()
if room_id is None:
- defer.returnValue("s%d" % (token,))
+ return "s%d" % (token,)
else:
topo = yield self.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- defer.returnValue("t%d-%d" % (topo, token))
+ return "t%d-%d" % (topo, token)
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
@@ -674,14 +674,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
[e for e in results["after"]["event_ids"]], get_prev_content=True
)
- defer.returnValue(
- {
- "events_before": events_before,
- "events_after": events_after,
- "start": results["before"]["token"],
- "end": results["after"]["token"],
- }
- )
+ return {
+ "events_before": events_before,
+ "events_after": events_after,
+ "start": results["before"]["token"],
+ "end": results["after"]["token"],
+ }
def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter
@@ -785,7 +783,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
events = yield self.get_events_as_list(event_ids)
- defer.returnValue((upper_bound, events))
+ return (upper_bound, events)
def get_federation_out_pos(self, typ):
return self._simple_select_one_onecol(
@@ -939,7 +937,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(events, rows)
- defer.returnValue((events, token))
+ return (events, token)
class StreamStore(StreamWorkerStore):
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index e88f8ea35f..20dd6bd53d 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -66,7 +66,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
room_id string, tag string and content string.
"""
if last_id == current_id:
- defer.returnValue([])
+ return []
def get_all_updated_tags_txn(txn):
sql = (
@@ -107,7 +107,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
results.extend(tags)
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
@@ -135,7 +135,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
user_id, int(stream_id)
)
if not changed:
- defer.returnValue({})
+ return {}
room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
@@ -145,7 +145,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})
- defer.returnValue(results)
+ return results
def get_tags_for_room(self, user_id, room_id):
"""Get all the tags for the given room
@@ -194,7 +194,7 @@ class TagsStore(TagsWorkerStore):
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
@@ -217,7 +217,7 @@ class TagsStore(TagsWorkerStore):
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
def _update_revision_txn(self, txn, user_id, room_id, next_id):
"""Update the latest revision of the tags for the given user and room.
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index fd18619178..b3c3bf55bc 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -147,7 +147,7 @@ class TransactionStore(SQLBaseStore):
result = self._destination_retry_cache.get(destination, SENTINEL)
if result is not SENTINEL:
- defer.returnValue(result)
+ return result
result = yield self.runInteraction(
"get_destination_retry_timings",
@@ -158,7 +158,7 @@ class TransactionStore(SQLBaseStore):
# We don't hugely care about race conditions between getting and
# invalidating the cache, since we time out fairly quickly anyway.
self._destination_retry_cache[destination] = result
- defer.returnValue(result)
+ return result
def _get_destination_retry_timings(self, txn, destination):
result = self._simple_select_one_txn(
@@ -196,6 +196,26 @@ class TransactionStore(SQLBaseStore):
def _set_destination_retry_timings(
self, txn, destination, retry_last_ts, retry_interval
):
+
+ if self.database_engine.can_native_upsert:
+ # Upsert retry time interval if retry_interval is zero (i.e. we're
+ # resetting it) or greater than the existing retry interval.
+
+ sql = """
+ INSERT INTO destinations (destination, retry_last_ts, retry_interval)
+ VALUES (?, ?, ?)
+ ON CONFLICT (destination) DO UPDATE SET
+ retry_last_ts = EXCLUDED.retry_last_ts,
+ retry_interval = EXCLUDED.retry_interval
+ WHERE
+ EXCLUDED.retry_interval = 0
+ OR destinations.retry_interval < EXCLUDED.retry_interval
+ """
+
+ txn.execute(sql, (destination, retry_last_ts, retry_interval))
+
+ return
+
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 83466e25d9..b5188d9bee 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -109,7 +109,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
yield self._end_background_update("populate_user_directory_createtables")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _populate_user_directory_cleanup(self, progress, batch_size):
@@ -131,7 +131,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
)
yield self._end_background_update("populate_user_directory_cleanup")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _populate_user_directory_process_rooms(self, progress, batch_size):
@@ -177,7 +177,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
yield self._end_background_update("populate_user_directory_process_rooms")
- defer.returnValue(1)
+ return 1
logger.info(
"Processing the next %d rooms of %d remaining"
@@ -257,9 +257,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
if processed_event_count > batch_size:
# Don't process any more rooms, we've hit our batch size.
- defer.returnValue(processed_event_count)
+ return processed_event_count
- defer.returnValue(processed_event_count)
+ return processed_event_count
@defer.inlineCallbacks
def _populate_user_directory_process_users(self, progress, batch_size):
@@ -268,7 +268,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"""
if not self.hs.config.user_directory_search_all_users:
yield self._end_background_update("populate_user_directory_process_users")
- defer.returnValue(1)
+ return 1
def _get_next_batch(txn):
sql = "SELECT user_id FROM %s LIMIT %s" % (
@@ -298,7 +298,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
# No more users -- complete the transaction.
if not users_to_work_on:
yield self._end_background_update("populate_user_directory_process_users")
- defer.returnValue(1)
+ return 1
logger.info(
"Processing the next %d users of %d remaining"
@@ -322,7 +322,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
progress,
)
- defer.returnValue(len(users_to_work_on))
+ return len(users_to_work_on)
@defer.inlineCallbacks
def is_room_world_readable_or_publicly_joinable(self, room_id):
@@ -344,16 +344,16 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
- defer.returnValue(True)
+ return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
if hist_vis_ev.content.get("history_visibility") == "world_readable":
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
"""
@@ -499,7 +499,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
user_ids = set(user_ids_share_pub)
user_ids.update(user_ids_share_priv)
- defer.returnValue(user_ids)
+ return user_ids
def add_users_who_share_private_room(self, room_id, user_id_tuples):
"""Insert entries into the users_who_share_private_rooms table. The first
@@ -609,7 +609,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
users = set(pub_rows)
users.update(rows)
- defer.returnValue(list(users))
+ return list(users)
@defer.inlineCallbacks
def get_rooms_in_common_for_users(self, user_id, other_user_id):
@@ -618,15 +618,15 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
sql = """
SELECT room_id FROM (
SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
+ INNER JOIN room_memberships AS m USING (event_id)
WHERE type = 'm.room.member'
- AND membership = 'join'
+ AND m.membership = 'join'
AND state_key = ?
) AS f1 INNER JOIN (
SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
+ INNER JOIN room_memberships AS m USING (event_id)
WHERE type = 'm.room.member'
- AND membership = 'join'
+ AND m.membership = 'join'
AND state_key = ?
) f2 USING (room_id)
"""
@@ -635,7 +635,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
)
- defer.returnValue([room_id for room_id, in rows])
+ return [room_id for room_id, in rows]
def delete_all_from_user_dir(self):
"""Delete the entire user directory
@@ -782,7 +782,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
limited = len(results) > limit
- defer.returnValue({"limited": limited, "results": results})
+ return {"limited": limited, "results": results}
def _parse_query_sqlite(search_term):
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/user_erasure_store.py
index 1815fdc0dd..05cabc2282 100644
--- a/synapse/storage/user_erasure_store.py
+++ b/synapse/storage/user_erasure_store.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import operator
-from twisted.internet import defer
+import operator
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@@ -67,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
res = dict((u, u in erased_users) for u in user_ids)
- defer.returnValue(res)
+ return res
class UserErasureStore(UserErasureWorkerStore):
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 488c49747a..b91fb2db7b 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -56,7 +56,7 @@ class EventSources(object):
device_list_key=device_list_key,
groups_key=groups_key,
)
- defer.returnValue(token)
+ return token
@defer.inlineCallbacks
def get_current_token_for_pagination(self):
@@ -80,4 +80,4 @@ class EventSources(object):
device_list_key=0,
groups_key=0,
)
- defer.returnValue(token)
+ return token
diff --git a/synapse/third_party_rules/__init__.py b/synapse/third_party_rules/__init__.py
new file mode 100644
index 0000000000..1453d04571
--- /dev/null
+++ b/synapse/third_party_rules/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
new file mode 100644
index 0000000000..253bba664b
--- /dev/null
+++ b/synapse/third_party_rules/access_rules.py
@@ -0,0 +1,586 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import email.utils
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership, RoomCreationPreset
+from synapse.api.errors import SynapseError
+from synapse.config._base import ConfigError
+from synapse.types import get_domain_from_id
+
+ACCESS_RULES_TYPE = "im.vector.room.access_rules"
+ACCESS_RULE_RESTRICTED = "restricted"
+ACCESS_RULE_UNRESTRICTED = "unrestricted"
+ACCESS_RULE_DIRECT = "direct"
+
+VALID_ACCESS_RULES = (
+ ACCESS_RULE_DIRECT,
+ ACCESS_RULE_RESTRICTED,
+ ACCESS_RULE_UNRESTRICTED,
+)
+
+# Rules to which we need to apply the power levels restrictions.
+#
+# These are all of the rules that neither:
+# * forbid users from joining based on a server blacklist (which means that there
+# is no need to apply power level restrictions), nor
+# * target direct chats (since we allow both users to be room admins in this case).
+#
+# The power-level restrictions, when they are applied, prevent the following:
+# * the default power level for users (users_default) being set to anything other than 0.
+# * a non-default power level being assigned to any user which would be forbidden from
+# joining a restricted room.
+RULES_WITH_RESTRICTED_POWER_LEVELS = (ACCESS_RULE_UNRESTRICTED,)
+
+
+class RoomAccessRules(object):
+ """Implementation of the ThirdPartyEventRules module API that allows federation admins
+ to define custom rules for specific events and actions.
+ Implements the custom behaviour for the "im.vector.room.access_rules" state event.
+
+ Takes a config in the format:
+
+ third_party_event_rules:
+ module: third_party_rules.RoomAccessRules
+ config:
+ # List of domains (server names) that can't be invited to rooms if the
+ # "restricted" rule is set. Defaults to an empty list.
+ domains_forbidden_when_restricted: []
+
+ # Identity server to use when checking the HS an email address belongs to
+ # using the /info endpoint. Required.
+ id_server: "vector.im"
+
+ Don't forget to consider if you can invite users from your own domain.
+ """
+
+ def __init__(self, config, http_client):
+ self.http_client = http_client
+
+ self.id_server = config["id_server"]
+
+ self.domains_forbidden_when_restricted = config.get(
+ "domains_forbidden_when_restricted", []
+ )
+
+ @staticmethod
+ def parse_config(config):
+ if "id_server" in config:
+ return config
+ else:
+ raise ConfigError("No IS for event rules TchapEventRules")
+
+ def on_create_room(self, requester, config, is_requester_admin):
+ """Implements synapse.events.ThirdPartyEventRules.on_create_room
+
+ Checks if a im.vector.room.access_rules event is being set during room creation.
+ If yes, make sure the event is correct. Otherwise, append an event with the
+ default rule to the initial state.
+ """
+ is_direct = config.get("is_direct")
+ preset = config.get("preset")
+ access_rule = None
+ join_rule = None
+
+ # If there's a rules event in the initial state, check if it complies with the
+ # spec for im.vector.room.access_rules and deny the request if not.
+ for event in config.get("initial_state", []):
+ if event["type"] == ACCESS_RULES_TYPE:
+ access_rule = event["content"].get("rule")
+
+ # Make sure the event has a valid content.
+ if access_rule is None:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule name is valid.
+ if access_rule not in VALID_ACCESS_RULES:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Make sure the rule is "direct" if the room is a direct chat.
+ if (is_direct and access_rule != ACCESS_RULE_DIRECT) or (
+ access_rule == ACCESS_RULE_DIRECT and not is_direct
+ ):
+ raise SynapseError(400, "Invalid access rule")
+
+ if event["type"] == EventTypes.JoinRules:
+ join_rule = event["content"].get("join_rule")
+
+ if access_rule is None:
+ # If there's no access rules event in the initial state, create one with the
+ # default setting.
+ if is_direct:
+ default_rule = ACCESS_RULE_DIRECT
+ else:
+ # If the default value for non-direct chat changes, we should make another
+ # case here for rooms created with either a "public" join_rule or the
+ # "public_chat" preset to make sure those keep defaulting to "restricted"
+ default_rule = ACCESS_RULE_RESTRICTED
+
+ if not config.get("initial_state"):
+ config["initial_state"] = []
+
+ config["initial_state"].append(
+ {
+ "type": ACCESS_RULES_TYPE,
+ "state_key": "",
+ "content": {"rule": default_rule},
+ }
+ )
+
+ access_rule = default_rule
+
+ # Check that the preset or the join rule in use is compatible with the access
+ # rule, whether it's a user-defined one or the default one (i.e. if it involves
+ # a "public" join rule, the access rule must be "restricted").
+ if (
+ join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT
+ ) and access_rule != ACCESS_RULE_RESTRICTED:
+ raise SynapseError(400, "Invalid access rule")
+
+ # Check if the creator can override values for the power levels.
+ allowed = self._is_power_level_content_allowed(
+ config.get("power_level_content_override", {}), access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content override")
+
+ # Second loop for events we need to know the current rule to process.
+ for event in config.get("initial_state", []):
+ if event["type"] == EventTypes.PowerLevels:
+ allowed = self._is_power_level_content_allowed(
+ event["content"], access_rule
+ )
+ if not allowed:
+ raise SynapseError(400, "Invalid power levels content")
+
+ @defer.inlineCallbacks
+ def check_threepid_can_be_invited(self, medium, address, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited
+
+ Check if a threepid can be invited to the room via a 3PID invite given the current
+ rules and the threepid's address, by retrieving the HS it's mapped to from the
+ configured identity server, and checking if we can invite users from it.
+ """
+ rule = self._get_rule_from_state(state_events)
+
+ if medium != "email":
+ defer.returnValue(False)
+
+ if rule != ACCESS_RULE_RESTRICTED:
+ # Only "restricted" requires filtering 3PID invites. We don't need to do
+ # anything for "direct" here, because only "restricted" requires filtering
+ # based on the HS the address is mapped to.
+ defer.returnValue(True)
+
+ parsed_address = email.utils.parseaddr(address)[1]
+ if parsed_address != address:
+ # Avoid reproducing the security issue described here:
+ # https://matrix.org/blog/2019/04/18/security-update-sydent-1-0-2
+ # It's probably not worth it but let's just be overly safe here.
+ defer.returnValue(False)
+
+ # Get the HS this address belongs to from the identity server.
+ res = yield self.http_client.get_json(
+ "https://%s/_matrix/identity/api/v1/info" % (self.id_server,),
+ {"medium": medium, "address": address},
+ )
+
+ # Look for a domain that's not forbidden from being invited.
+ if not res.get("hs"):
+ defer.returnValue(False)
+ if res.get("hs") in self.domains_forbidden_when_restricted:
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
+ def check_event_allowed(self, event, state_events):
+ """Implements synapse.events.ThirdPartyEventRules.check_event_allowed
+
+ Checks the event's type and the current rule and calls the right function to
+ determine whether the event can be allowed.
+ """
+ if event.type == ACCESS_RULES_TYPE:
+ return self._on_rules_change(event, state_events)
+
+ # We need to know the rule to apply when processing the event types below.
+ rule = self._get_rule_from_state(state_events)
+
+ if event.type == EventTypes.PowerLevels:
+ return self._is_power_level_content_allowed(event.content, rule)
+
+ if event.type == EventTypes.Member or event.type == EventTypes.ThirdPartyInvite:
+ return self._on_membership_or_invite(event, rule, state_events)
+
+ if event.type == EventTypes.JoinRules:
+ return self._on_join_rule_change(event, rule)
+
+ if event.type == EventTypes.RoomAvatar:
+ return self._on_room_avatar_change(event, rule)
+
+ if event.type == EventTypes.Name:
+ return self._on_room_name_change(event, rule)
+
+ if event.type == EventTypes.Topic:
+ return self._on_room_topic_change(event, rule)
+
+ return True
+
+ def _on_rules_change(self, event, state_events):
+ """Implement the checks and behaviour specified on allowing or forbidding a new
+ im.vector.room.access_rules event.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ new_rule = event.content.get("rule")
+
+ # Check for invalid values.
+ if new_rule not in VALID_ACCESS_RULES:
+ return False
+
+ # We must not allow rooms with the "public" join rule to be given any other access
+ # rule than "restricted".
+ join_rule = self._get_join_rule_from_state(state_events)
+ if join_rule == JoinRules.PUBLIC and new_rule != ACCESS_RULE_RESTRICTED:
+ return False
+
+ # Make sure we don't apply "direct" if the room has more than two members.
+ if new_rule == ACCESS_RULE_DIRECT:
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ if len(existing_members) > 2 or len(threepid_tokens) > 1:
+ return False
+
+ prev_rules_event = state_events.get((ACCESS_RULES_TYPE, ""))
+
+ # Now that we know the new rule doesn't break the "direct" case, we can allow any
+ # new rule in rooms that had none before.
+ if prev_rules_event is None:
+ return True
+
+ prev_rule = prev_rules_event.content.get("rule")
+
+ # Currently, we can only go from "restricted" to "unrestricted".
+ if prev_rule == ACCESS_RULE_RESTRICTED and new_rule == ACCESS_RULE_UNRESTRICTED:
+ return True
+
+ return False
+
+ def _on_membership_or_invite(self, event, rule, state_events):
+ """Applies the correct rule for incoming m.room.member and
+ m.room.third_party_invite events.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if rule == ACCESS_RULE_RESTRICTED:
+ ret = self._on_membership_or_invite_restricted(event)
+ elif rule == ACCESS_RULE_UNRESTRICTED:
+ ret = self._on_membership_or_invite_unrestricted()
+ elif rule == ACCESS_RULE_DIRECT:
+ ret = self._on_membership_or_invite_direct(event, state_events)
+ else:
+ # We currently apply the default (restricted) if we don't know the rule, we
+ # might want to change that in the future.
+ ret = self._on_membership_or_invite_restricted(event)
+
+ return ret
+
+ def _on_membership_or_invite_restricted(self, event):
+ """Implements the checks and behaviour specified for the "restricted" rule.
+
+ "restricted" currently means that users can only invite users if their server is
+ included in a limited list of domains.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # We're not applying the rules on m.room.third_party_member events here because
+ # the filtering on threepids is done in check_threepid_can_be_invited, which is
+ # called before check_event_allowed.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return True
+
+ # We only need to process "join" and "invite" memberships, in order to be backward
+ # compatible, e.g. if a user from a blacklisted server joined a restricted room
+ # before the rules started being enforced on the server, that user must be able to
+ # leave it.
+ if event.membership not in [Membership.JOIN, Membership.INVITE]:
+ return True
+
+ invitee_domain = get_domain_from_id(event.state_key)
+ return invitee_domain not in self.domains_forbidden_when_restricted
+
+ def _on_membership_or_invite_unrestricted(self):
+ """Implements the checks and behaviour specified for the "unrestricted" rule.
+
+ "unrestricted" currently means that every event is allowed.
+
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return True
+
+ def _on_membership_or_invite_direct(self, event, state_events):
+ """Implements the checks and behaviour specified for the "direct" rule.
+
+ "direct" currently means that no member is allowed apart from the two initial
+ members the room was created for (i.e. the room's creator and their first
+ invitee).
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ state_events (dict[tuple[event type, state key], EventBase]): The state of the
+ room before the event was sent.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # Get the room memberships and 3PID invite tokens from the room's state.
+ existing_members, threepid_tokens = self._get_members_and_tokens_from_state(
+ state_events
+ )
+
+ # There should never be more than one 3PID invite in the room state: if the second
+ # original user came and left, and we're inviting them using their email address,
+ # given we know they have a Matrix account binded to the address (so they could
+ # join the first time), Synapse will successfully look it up before attempting to
+ # store an invite on the IS.
+ if len(threepid_tokens) == 1 and event.type == EventTypes.ThirdPartyInvite:
+ # If we already have a 3PID invite in flight, don't accept another one, unless
+ # the new one has the same invite token as its state key. This is because 3PID
+ # invite revocations must be allowed, and a revocation is basically a new 3PID
+ # invite event with an empty content and the same token as the invite it
+ # revokes.
+ return event.state_key in threepid_tokens
+
+ if len(existing_members) == 2:
+ # If the user was within the two initial user of the room, Synapse would have
+ # looked it up successfully and thus sent a m.room.member here instead of
+ # m.room.third_party_invite.
+ if event.type == EventTypes.ThirdPartyInvite:
+ return False
+
+ # We can only have m.room.member events here. The rule in this case is to only
+ # allow the event if its target is one of the initial two members in the room,
+ # i.e. the state key of one of the two m.room.member states in the room.
+ return event.state_key in existing_members
+
+ # We're alone in the room (and always have been) and there's one 3PID invite in
+ # flight.
+ if len(existing_members) == 1 and len(threepid_tokens) == 1:
+ # We can only have m.room.member events here. In this case, we can only allow
+ # the event if it's either a m.room.member from the joined user (we can assume
+ # that the only m.room.member event is a join otherwise we wouldn't be able to
+ # send an event to the room) or an an invite event which target is the invited
+ # user.
+ target = event.state_key
+ is_from_threepid_invite = self._is_invite_from_threepid(
+ event, threepid_tokens[0]
+ )
+ if is_from_threepid_invite or target == existing_members[0]:
+ return True
+
+ return False
+
+ return True
+
+ def _is_power_level_content_allowed(self, content, access_rule):
+ """Check if a given power levels event is permitted under the given access rule.
+
+ It shouldn't be allowed if it either changes the default PL to a non-0 value or
+ gives a non-0 PL to a user that would have been forbidden from joining the room
+ under a more restrictive access rule.
+
+ Args:
+ content (dict[]): The content of the m.room.power_levels event to check.
+ access_rule (str): The access rule in place in this room.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ # Check if we need to apply the restrictions with the current rule.
+ if access_rule not in RULES_WITH_RESTRICTED_POWER_LEVELS:
+ return True
+
+ # If users_default is explicitly set to a non-0 value, deny the event.
+ users_default = content.get("users_default", 0)
+ if users_default:
+ return False
+
+ users = content.get("users", {})
+ for user_id, power_level in users.items():
+ server_name = get_domain_from_id(user_id)
+ # Check the domain against the blacklist. If found, and the PL isn't 0, deny
+ # the event.
+ if (
+ server_name in self.domains_forbidden_when_restricted
+ and power_level != 0
+ ):
+ return False
+
+ return True
+
+ def _on_join_rule_change(self, event, rule):
+ """Check whether a join rule change is allowed. A join rule change is always
+ allowed unless the new join rule is "public" and the current access rule isn't
+ "restricted".
+ The rationale is that external users (those whose server would be denied access
+ to rooms enforcing the "restricted" access rule) should always rely on non-
+ external users for access to rooms, therefore they shouldn't be able to access
+ rooms that don't require an invite to be joined.
+
+ Note that we currently rely on the default access rule being "restricted": during
+ room creation, the m.room.join_rules event will be sent *before* the
+ im.vector.room.access_rules one, so the access rule that will be considered here
+ in this case will be the default "restricted" one. This is fine since the
+ "restricted" access rule allows any value for the join rule, but we should keep
+ that in mind if we need to change the default access rule in the future.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ if event.content.get("join_rule") == JoinRules.PUBLIC:
+ return rule == ACCESS_RULE_RESTRICTED
+
+ return True
+
+ def _on_room_avatar_change(self, event, rule):
+ """Check whether a change of room avatar is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ def _on_room_name_change(self, event, rule):
+ """Check whether a change of room name is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ def _on_room_topic_change(self, event, rule):
+ """Check whether a change of room topic is allowed.
+ The current rule is to forbid such a change in direct chats but allow it
+ everywhere else.
+
+ Args:
+ event (synapse.events.EventBase): The event to check.
+ rule (str): The name of the rule to apply.
+ Returns:
+ bool, True if the event can be allowed, False otherwise.
+ """
+ return rule != ACCESS_RULE_DIRECT
+
+ @staticmethod
+ def _get_rule_from_state(state_events):
+ """Extract the rule to be applied from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str, the name of the rule (either "direct", "restricted" or "unrestricted")
+ """
+ access_rules = state_events.get((ACCESS_RULES_TYPE, ""))
+ if access_rules is None:
+ rule = ACCESS_RULE_RESTRICTED
+ else:
+ rule = access_rules.content.get("rule")
+ return rule
+
+ @staticmethod
+ def _get_join_rule_from_state(state_events):
+ """Extract the room's join rule from the given set of state events.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ str, the name of the join rule (either "public", or "invite")
+ """
+ join_rule_event = state_events.get((EventTypes.JoinRules, ""))
+ if join_rule_event is None:
+ return None
+ return join_rule_event.content.get("join_rule")
+
+ @staticmethod
+ def _get_members_and_tokens_from_state(state_events):
+ """Retrieves from a list of state events the list of users that have a
+ m.room.member event in the room, and the tokens of 3PID invites in the room.
+
+ Args:
+ state_events (dict[tuple[event type, state key], EventBase]): The set of state
+ events.
+ Returns:
+ existing_members (list[str]): List of targets of the m.room.member events in
+ the state.
+ threepid_invite_tokens (list[str]): List of tokens of the 3PID invites in the
+ state.
+ """
+ existing_members = []
+ threepid_invite_tokens = []
+ for key, state_event in state_events.items():
+ if key[0] == EventTypes.Member and state_event.content:
+ existing_members.append(state_event.state_key)
+ if key[0] == EventTypes.ThirdPartyInvite and state_event.content:
+ # Don't include revoked invites.
+ threepid_invite_tokens.append(state_event.state_key)
+
+ return existing_members, threepid_invite_tokens
+
+ @staticmethod
+ def _is_invite_from_threepid(invite, threepid_invite_token):
+ """Checks whether the given invite follows the given 3PID invite.
+
+ Args:
+ invite (EventBase): The m.room.member event with "invite" membership.
+ threepid_invite_token (str): The state key from the 3PID invite.
+ """
+ token = (
+ invite.content.get("third_party_invite", {})
+ .get("signed", {})
+ .get("token", "")
+ )
+
+ return token == threepid_invite_token
diff --git a/synapse/types.py b/synapse/types.py
index 51eadb6ad4..94c01b0a18 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -16,6 +16,8 @@ import re
import string
from collections import namedtuple
+from six.moves import filter
+
import attr
from synapse.api.errors import SynapseError
@@ -235,6 +237,19 @@ def contains_invalid_mxid_characters(localpart):
return any(c not in mxid_localpart_allowed_characters for c in localpart)
+def strip_invalid_mxid_characters(localpart):
+ """Removes any invalid characters from an mxid
+
+ Args:
+ localpart (basestring): the localpart to be stripped
+
+ Returns:
+ localpart (basestring): the localpart having been stripped
+ """
+ filtered = filter(lambda c: c in mxid_localpart_allowed_characters, localpart)
+ return "".join(filtered)
+
+
UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index f506b2a695..7856353002 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -49,7 +49,7 @@ class Clock(object):
with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
- defer.returnValue(res)
+ return res
def time(self):
"""Returns the current system time in seconds since epoch."""
@@ -59,7 +59,7 @@ class Clock(object):
"""Returns the current system time in miliseconds since epoch."""
return int(self.time() * 1000)
- def looping_call(self, f, msec):
+ def looping_call(self, f, msec, *args, **kwargs):
"""Call a function repeatedly.
Waits `msec` initially before calling `f` for the first time.
@@ -70,8 +70,10 @@ class Clock(object):
Args:
f(function): The function to call repeatedly.
msec(float): How long to wait between calls in milliseconds.
+ *args: Postional arguments to pass to function.
+ **kwargs: Key arguments to pass to function.
"""
- call = task.LoopingCall(f)
+ call = task.LoopingCall(f, *args, **kwargs)
call.clock = self._reactor
d = call.start(msec / 1000.0, now=False)
d.addErrback(log_failure, "Looping call died", consumeErrors=False)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 58a6b8764f..f1c46836b1 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -366,7 +366,7 @@ class ReadWriteLock(object):
new_defer.callback(None)
self.key_to_current_readers.get(key, set()).discard(new_defer)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
@defer.inlineCallbacks
def write(self, key):
@@ -396,7 +396,7 @@ class ReadWriteLock(object):
if self.key_to_current_writer[key] == new_defer:
self.key_to_current_writer.pop(key)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
def _cancelled_to_timed_out_error(value, timeout):
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 8271229015..b50e3503f0 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -51,7 +52,19 @@ response_cache_evicted = Gauge(
response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
-def register_cache(cache_type, cache_name, cache):
+def register_cache(cache_type, cache_name, cache, collect_callback=None):
+ """Register a cache object for metric collection.
+
+ Args:
+ cache_type (str):
+ cache_name (str): name of the cache
+ cache (object): cache itself
+ collect_callback (callable|None): if not None, a function which is called during
+ metric collection to update additional metrics.
+
+ Returns:
+ CacheMetric: an object which provides inc_{hits,misses,evictions} methods
+ """
# Check if the metric is already registered. Unregister it, if so.
# This usually happens during tests, as at runtime these caches are
@@ -90,6 +103,8 @@ def register_cache(cache_type, cache_name, cache):
cache_hits.labels(cache_name).set(self.hits)
cache_evicted.labels(cache_name).set(self.evicted_size)
cache_total.labels(cache_name).set(self.hits + self.misses)
+ if collect_callback:
+ collect_callback()
except Exception as e:
logger.warn("Error calculating metrics for %s: %s", cache_name, e)
raise
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 675db2f448..43f66ec4be 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -19,8 +19,9 @@ import logging
import threading
from collections import namedtuple
-import six
-from six import itervalues, string_types
+from six import itervalues
+
+from prometheus_client import Gauge
from twisted.internet import defer
@@ -30,13 +31,18 @@ from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.stringutils import to_ascii
from . import register_cache
logger = logging.getLogger(__name__)
+cache_pending_metric = Gauge(
+ "synapse_util_caches_cache_pending",
+ "Number of lookups currently pending for this cache",
+ ["name"],
+)
+
_CacheSentinel = object()
@@ -82,11 +88,19 @@ class Cache(object):
self.name = name
self.keylen = keylen
self.thread = None
- self.metrics = register_cache("cache", name, self.cache)
+ self.metrics = register_cache(
+ "cache",
+ name,
+ self.cache,
+ collect_callback=self._metrics_collection_callback,
+ )
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
+ def _metrics_collection_callback(self):
+ cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
+
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@@ -108,7 +122,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
- Either a Deferred or the raw result
+ Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
@@ -132,9 +146,14 @@ class Cache(object):
return default
def set(self, key, value, callback=None):
+ if not isinstance(value, defer.Deferred):
+ raise TypeError("not a Deferred")
+
callbacks = [callback] if callback else []
self.check_thread()
- entry = CacheEntry(deferred=value, callbacks=callbacks)
+ observable = ObservableDeferred(value, consumeErrors=True)
+ observer = defer.maybeDeferred(observable.observe)
+ entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
@@ -142,20 +161,31 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
- def shuffle(result):
+ def compare_and_pop():
+ """Check if our entry is still the one in _pending_deferred_cache, and
+ if so, pop it.
+
+ Returns true if the entries matched.
+ """
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
+ return True
+
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ return False
+
+ def cb(result):
+ if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
@@ -163,9 +193,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
- return result
- entry.deferred.addCallback(shuffle)
+ def eb(_fail):
+ compare_and_pop()
+ entry.invalidate()
+
+ # once the deferred completes, we can move the entry from the
+ # _pending_deferred_cache to the real cache.
+ #
+ observer.addCallbacks(cb, eb)
+ return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
@@ -289,7 +326,7 @@ class CacheDescriptor(_CacheDescriptorBase):
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
- defer.returnValue(r1 + r2)
+ return r1 + r2
Args:
num_args (int): number of positional arguments (excluding ``self`` and
@@ -398,20 +435,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
- # If our cache_key is a string on py2, try to convert to ascii
- # to save a bit of space in large caches. Py3 does this
- # internally automatically.
- if six.PY2 and isinstance(cache_key, string_types):
- cache_key = to_ascii(cache_key)
-
- result_d = ObservableDeferred(ret, consumeErrors=True)
- cache.set(cache_key, result_d, callback=invalidate_callback)
+ result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
- if isinstance(observer, defer.Deferred):
- return make_deferred_yieldable(observer)
- else:
- return observer
+ return make_deferred_yieldable(observer)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
@@ -527,7 +554,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg)
if missing:
- # we need an observable deferred for each entry in the list,
+ # we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
@@ -535,8 +562,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
- observable = ObservableDeferred(deferred)
- cache.set(key, observable, callback=invalidate_callback)
+ cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
# the wrapped function has completed. It returns a
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index d6908e169d..82d3eefe0e 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -121,7 +121,7 @@ class ResponseCache(object):
@defer.inlineCallbacks
def handle_request(request):
# etc
- defer.returnValue(result)
+ return result
result = yield response_cache.wrap(
key,
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index c30b6de19c..0910930c21 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -67,7 +67,7 @@ def measure_func(name):
def measured_func(self, *args, **kwargs):
with Measure(self.clock, name):
r = yield func(self, *args, **kwargs)
- defer.returnValue(r)
+ return r
return measured_func
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index d8d0ceae51..0862b5ca5a 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -95,15 +95,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
# maximum backoff even though it might only have been down briefly
backoff_on_failure = not ignore_backoff
- defer.returnValue(
- RetryDestinationLimiter(
- destination,
- clock,
- store,
- retry_interval,
- backoff_on_failure=backoff_on_failure,
- **kwargs
- )
+ return RetryDestinationLimiter(
+ destination,
+ clock,
+ store,
+ retry_interval,
+ backoff_on_failure=backoff_on_failure,
+ **kwargs
)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 982c6d81ca..6a2464cab3 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,12 +15,15 @@
# limitations under the License.
import random
+import re
import string
import six
from six import PY2, PY3
from six.moves import range
+from synapse.api.errors import Codes, SynapseError
+
_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# random_string and random_string_with_symbols are used for a range of things,
@@ -27,6 +31,8 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
# we get cryptographically-secure randoms.
rand = random.SystemRandom()
+client_secret_regex = re.compile(r"^[0-9a-zA-Z.=_-]+$")
+
def random_string(length):
return "".join(rand.choice(string.ascii_letters) for _ in range(length))
@@ -109,3 +115,11 @@ def exception_to_unicode(e):
return msg.decode("utf-8", errors="replace")
else:
return msg
+
+
+def assert_valid_client_secret(client_secret):
+ """Validate that a given string matches the client_secret regex defined by the spec"""
+ if client_secret_regex.match(client_secret) is None:
+ raise SynapseError(
+ 400, "Invalid client_secret parameter", errcode=Codes.INVALID_PARAM
+ )
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 3ec1dfb0c2..34ce7cac16 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -16,11 +16,14 @@
import logging
import re
+from twisted.internet import defer
+
logger = logging.getLogger(__name__)
+@defer.inlineCallbacks
def check_3pid_allowed(hs, medium, address):
- """Checks whether a given format of 3PID is allowed to be used on this HS
+ """Checks whether a given 3PID is allowed to be used on this HS
Args:
hs (synapse.server.HomeServer): server
@@ -28,9 +31,36 @@ def check_3pid_allowed(hs, medium, address):
address (str): address within that medium (e.g. "wotan@matrix.org")
msisdns need to first have been canonicalised
Returns:
- bool: whether the 3PID medium/address is allowed to be added to this HS
+ defered bool: whether the 3PID medium/address is allowed to be added to this HS
"""
+ if hs.config.check_is_for_allowed_local_3pids:
+ data = yield hs.get_simple_http_client().get_json(
+ "https://%s%s"
+ % (
+ hs.config.check_is_for_allowed_local_3pids,
+ "/_matrix/identity/api/v1/internal-info",
+ ),
+ {"medium": medium, "address": address},
+ )
+
+ # Check for invalid response
+ if "hs" not in data and "shadow_hs" not in data:
+ defer.returnValue(False)
+
+ # Check if this user is intended to register for this homeserver
+ if (
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
+ ):
+ defer.returnValue(False)
+
+ if data.get("requires_invite", False) and not data.get("invited", False):
+ # Requires an invite but hasn't been invited
+ defer.returnValue(False)
+
+ defer.returnValue(True)
+
if hs.config.allowed_local_3pids:
for constraint in hs.config.allowed_local_3pids:
logger.debug(
@@ -43,8 +73,8 @@ def check_3pid_allowed(hs, medium, address):
if medium == constraint["medium"] and re.match(
constraint["pattern"], address
):
- return True
+ defer.returnValue(True)
else:
- return True
+ defer.returnValue(True)
- return False
+ defer.returnValue(False)
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index a4d9a462f7..fa404b9d75 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -22,6 +22,23 @@ logger = logging.getLogger(__name__)
def get_version_string(module):
+ """Given a module calculate a git-aware version string for it.
+
+ If called on a module not in a git checkout will return `__verison__`.
+
+ Args:
+ module (module)
+
+ Returns:
+ str
+ """
+
+ cached_version = getattr(module, "_synapse_version_string_cache", None)
+ if cached_version:
+ return cached_version
+
+ version_string = module.__version__
+
try:
null = open(os.devnull, "w")
cwd = os.path.dirname(os.path.abspath(module.__file__))
@@ -80,8 +97,10 @@ def get_version_string(module):
s for s in (git_branch, git_tag, git_commit, git_dirty) if s
)
- return "%s (%s)" % (module.__version__, git_version)
+ version_string = "%s (%s)" % (module.__version__, git_version)
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- return module.__version__
+ module._synapse_version_string_cache = version_string
+
+ return version_string
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 2a11c83596..a19011b793 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -43,7 +43,12 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks
def filter_events_for_client(
- store, user_id, events, is_peeking=False, always_include_ids=frozenset()
+ store,
+ user_id,
+ events,
+ is_peeking=False,
+ always_include_ids=frozenset(),
+ apply_retention_policies=True,
):
"""
Check which events a user is allowed to see
@@ -59,6 +64,10 @@ def filter_events_for_client(
events
always_include_ids (set(event_id)): set of event ids to specifically
include (unless sender is ignored)
+ apply_retention_policies (bool): Whether to filter out events that's older than
+ allowed by the room's retention policy. Useful when this function is called
+ to e.g. check whether a user should be allowed to see the state at a given
+ event rather than to know if it should send an event to a user's client(s).
Returns:
Deferred[list[synapse.events.EventBase]]
@@ -86,6 +95,15 @@ def filter_events_for_client(
erased_senders = yield store.are_users_erased((e.sender for e in events))
+ if apply_retention_policies:
+ room_ids = set(e.room_id for e in events)
+ retention_policies = {}
+
+ for room_id in room_ids:
+ retention_policies[room_id] = (
+ yield store.get_retention_policy_for_room(room_id)
+ )
+
def allowed(event):
"""
Args:
@@ -103,6 +121,18 @@ def filter_events_for_client(
if not event.is_state() and event.sender in ignore_list:
return None
+ # Don't try to apply the room's retention policy if the event is a state event, as
+ # MSC1763 states that retention is only considered for non-state events.
+ if apply_retention_policies and not event.is_state():
+ retention_policy = retention_policies[event.room_id]
+ max_lifetime = retention_policy.get("max_lifetime")
+
+ if max_lifetime is not None:
+ oldest_allowed_ts = store.clock.time_msec() - max_lifetime
+
+ if event.origin_server_ts < oldest_allowed_ts:
+ return None
+
if event.event_id in always_include_ids:
return event
@@ -208,7 +238,7 @@ def filter_events_for_client(
filtered_events = filter(operator.truth, filtered_events)
# we turn it into a list before returning it.
- defer.returnValue(list(filtered_events))
+ return list(filtered_events)
@defer.inlineCallbacks
@@ -317,11 +347,11 @@ def filter_events_for_server(
elif redact:
to_return.append(prune_event(e))
- defer.returnValue(to_return)
+ return to_return
# If there are no erased users then we can just return the given list
# of events without having to copy it.
- defer.returnValue(events)
+ return events
# Ok, so we're dealing with events that have non-trivial visibility
# rules, so we need to also get the memberships of the room.
@@ -384,4 +414,4 @@ def filter_events_for_server(
elif redact:
to_return.append(prune_event(e))
- defer.returnValue(to_return)
+ return to_return
|