diff --git a/synapse/__init__.py b/synapse/__init__.py
index 5bc24863d9..d8d340f426 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -14,17 +14,33 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-""" This is a reference implementation of a Matrix home server.
+""" This is a reference implementation of a Matrix homeserver.
"""
+import os
+import sys
+
+# Check that we're not running on an unsupported Python version.
+if sys.version_info < (3, 5):
+ print("Synapse requires Python 3.5 or above.")
+ sys.exit(1)
+
try:
from twisted.internet import protocol
from twisted.internet.protocol import Factory
from twisted.names.dns import DNSDatagramProtocol
+
protocol.Factory.noisy = False
Factory.noisy = False
DNSDatagramProtocol.noisy = False
except ImportError:
pass
-__version__ = "1.0.0"
+__version__ = "1.12.4"
+
+if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
+ # We import here so that we don't have to install a bunch of deps when
+ # running the packaging tox test.
+ from synapse.util.patch_inline_callbacks import do_patch
+
+ do_patch()
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index 6e93f5a0c6..d528450c78 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -57,18 +57,18 @@ def request_registration(
nonce = r.json()["nonce"]
- mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1)
+ mac = hmac.new(key=shared_secret.encode("utf8"), digestmod=hashlib.sha1)
- mac.update(nonce.encode('utf8'))
+ mac.update(nonce.encode("utf8"))
mac.update(b"\x00")
- mac.update(user.encode('utf8'))
+ mac.update(user.encode("utf8"))
mac.update(b"\x00")
- mac.update(password.encode('utf8'))
+ mac.update(password.encode("utf8"))
mac.update(b"\x00")
mac.update(b"admin" if admin else b"notadmin")
if user_type:
mac.update(b"\x00")
- mac.update(user_type.encode('utf8'))
+ mac.update(user_type.encode("utf8"))
mac = mac.hexdigest()
@@ -134,8 +134,9 @@ def register_new_user(user, password, server_location, shared_secret, admin, use
else:
admin = False
- request_registration(user, password, server_location, shared_secret,
- bool(admin), user_type)
+ request_registration(
+ user, password, server_location, shared_secret, bool(admin), user_type
+ )
def main():
@@ -143,8 +144,8 @@ def main():
logging.captureWarnings(True)
parser = argparse.ArgumentParser(
- description="Used to register new users with a given home server when"
- " registration has been disabled. The home server must be"
+ description="Used to register new users with a given homeserver when"
+ " registration has been disabled. The homeserver must be"
" configured with the 'registration_shared_secret' option"
" set."
)
@@ -189,7 +190,7 @@ def main():
group.add_argument(
"-c",
"--config",
- type=argparse.FileType('r'),
+ type=argparse.FileType("r"),
help="Path to server config file. Used to read in shared secret.",
)
@@ -200,8 +201,8 @@ def main():
parser.add_argument(
"server_url",
default="https://localhost:8448",
- nargs='?',
- help="URL to use to talk to the home server. Defaults to "
+ nargs="?",
+ help="URL to use to talk to the homeserver. Defaults to "
" 'https://localhost:8448'.",
)
@@ -220,8 +221,9 @@ def main():
if args.admin or args.no_admin:
admin = args.admin
- register_new_user(args.user, args.password, args.server_url, secret,
- admin, args.user_type)
+ register_new_user(
+ args.user, args.password, args.server_url, secret, admin, args.user_type
+ )
if __name__ == "__main__":
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index f505f1ac63..89850f43e0 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Optional
from six import itervalues
@@ -22,12 +23,21 @@ from netaddr import IPAddress
from twisted.internet import defer
+import synapse.logging.opentracing as opentracing
import synapse.types
from synapse import event_auth
-from synapse.api.constants import EventTypes, JoinRules, Membership
-from synapse.api.errors import AuthError, Codes, ResourceLimitError
+from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InvalidClientTokenError,
+ MissingClientTokenError,
+ ResourceLimitError,
+)
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.server import is_threepid_reserved
-from synapse.types import UserID
+from synapse.events import EventBase
+from synapse.types import StateMap, UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
@@ -36,8 +46,11 @@ logger = logging.getLogger(__name__)
AuthEventTypes = (
- EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
- EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
+ EventTypes.Create,
+ EventTypes.Member,
+ EventTypes.PowerLevels,
+ EventTypes.JoinRules,
+ EventTypes.RoomHistoryVisibility,
EventTypes.ThirdPartyInvite,
)
@@ -54,12 +67,12 @@ class Auth(object):
FIXME: This class contains a mix of functions for authenticating users
of our client-server API and authenticating events added to room graphs.
"""
+
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
- self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
register_cache("cache", "token_cache", self.token_cache)
@@ -67,113 +80,74 @@ class Auth(object):
self._account_validity = hs.config.account_validity
@defer.inlineCallbacks
- def check_from_context(self, room_version, event, context, do_sig_check=True):
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ def check_from_context(self, room_version: str, event, context, do_sig_check=True):
+ prev_state_ids = yield context.get_prev_state_ids()
auth_events_ids = yield self.compute_auth_events(
- event, prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {
- (e.type, e.state_key): e for e in itervalues(auth_events)
- }
- self.check(
- room_version, event,
- auth_events=auth_events, do_sig_check=do_sig_check,
- )
+ auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
- def check(self, room_version, event, auth_events, do_sig_check=True):
- """ Checks if this event is correctly authed.
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+ event_auth.check(
+ room_version_obj, event, auth_events=auth_events, do_sig_check=do_sig_check
+ )
+ @defer.inlineCallbacks
+ def check_user_in_room(
+ self,
+ room_id: str,
+ user_id: str,
+ current_state: Optional[StateMap[EventBase]] = None,
+ allow_departed_users: bool = False,
+ ):
+ """Check if the user is in the room, or was at some point.
Args:
- room_version (str): version of the room
- event: the event being checked.
- auth_events (dict: event-key -> event): the existing room state.
-
+ room_id: The room to check.
- Returns:
- True if the auth checks pass.
- """
- with Measure(self.clock, "auth.check"):
- event_auth.check(
- room_version, event, auth_events, do_sig_check=do_sig_check
- )
+ user_id: The user to check.
- @defer.inlineCallbacks
- def check_joined_room(self, room_id, user_id, current_state=None):
- """Check if the user is currently joined in the room
- Args:
- room_id(str): The room to check.
- user_id(str): The user to check.
- current_state(dict): Optional map of the current state of the room.
+ current_state: Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
+
+ allow_departed_users: if True, accept users that were previously
+ members but have now departed.
+
Raises:
- AuthError if the user is not in the room.
+ AuthError if the user is/was not in the room.
Returns:
- A deferred membership event for the user if the user is in
- the room.
+ Deferred[Optional[EventBase]]:
+ Membership event for the user if the user was in the
+ room. This will be the join event if they are currently joined to
+ the room. This will be the leave event if they have left the room.
"""
if current_state:
- member = current_state.get(
- (EventTypes.Member, user_id),
- None
- )
+ member = current_state.get((EventTypes.Member, user_id), None)
else:
member = yield self.state.get_current_state(
- room_id=room_id,
- event_type=EventTypes.Member,
- state_key=user_id
+ room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
-
- self._check_joined_room(member, user_id, room_id)
- defer.returnValue(member)
-
- @defer.inlineCallbacks
- def check_user_was_in_room(self, room_id, user_id):
- """Check if the user was in the room at some point.
- Args:
- room_id(str): The room to check.
- user_id(str): The user to check.
- Raises:
- AuthError if the user was never in the room.
- Returns:
- A deferred membership event for the user if the user was in the
- room. This will be the join event if they are currently joined to
- the room. This will be the leave event if they have left the room.
- """
- member = yield self.state.get_current_state(
- room_id=room_id,
- event_type=EventTypes.Member,
- state_key=user_id
- )
membership = member.membership if member else None
- if membership not in (Membership.JOIN, Membership.LEAVE):
- raise AuthError(403, "User %s not in room %s" % (
- user_id, room_id
- ))
+ if membership == Membership.JOIN:
+ return member
- if membership == Membership.LEAVE:
+ # XXX this looks totally bogus. Why do we not allow users who have been banned,
+ # or those who were members previously and have been re-invited?
+ if allow_departed_users and membership == Membership.LEAVE:
forgot = yield self.store.did_forget(user_id, room_id)
- if forgot:
- raise AuthError(403, "User %s not in room %s" % (
- user_id, room_id
- ))
+ if not forgot:
+ return member
- defer.returnValue(member)
+ raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.is_host_joined(room_id, host)
- defer.returnValue(latest_event_ids)
-
- def _check_joined_room(self, member, user_id, room_id):
- if not member or member.membership != Membership.JOIN:
- raise AuthError(403, "User %s not in room %s (%s)" % (
- user_id, room_id, repr(member)
- ))
+ return latest_event_ids
def can_federate(self, event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
@@ -185,11 +159,7 @@ class Auth(object):
@defer.inlineCallbacks
def get_user_by_req(
- self,
- request,
- allow_guest=False,
- rights="access",
- allow_expired=False,
+ self, request, allow_guest=False, rights="access", allow_expired=False
):
""" Get a registered user's ID.
@@ -203,24 +173,24 @@ class Auth(object):
Returns:
defer.Deferred: resolves to a ``synapse.types.Requester`` object
Raises:
- AuthError if no user by that token exists or the token is invalid.
+ InvalidClientCredentialsError if no user by that token exists or the token
+ is invalid.
+ AuthError if access is denied for the user in the access token
"""
- # Can optionally look elsewhere in the request (e.g. headers)
try:
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
- b"User-Agent",
- default=[b""]
- )[0].decode('ascii', 'surrogateescape')
+ b"User-Agent", default=[b""]
+ )[0].decode("ascii", "surrogateescape")
- access_token = self.get_access_token_from_request(
- request, self.TOKEN_NOT_FOUND_HTTP_STATUS
- )
+ access_token = self.get_access_token_from_request(request)
user_id, app_service = yield self._get_appservice_user_id(request)
if user_id:
request.authenticated_entity = user_id
+ opentracing.set_tag("authenticated_entity", user_id)
+ opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self.hs.config.track_appservice_user_ips:
yield self.store.insert_client_ip(
@@ -231,9 +201,7 @@ class Auth(object):
device_id="dummy-device", # stubbed
)
- defer.returnValue(
- synapse.types.create_requester(user_id, app_service=app_service)
- )
+ return synapse.types.create_requester(user_id, app_service=app_service)
user_info = yield self.get_user_by_access_token(access_token, rights)
user = user_info["user"]
@@ -244,11 +212,12 @@ class Auth(object):
if self._account_validity.enabled and not allow_expired:
user_id = user.to_string()
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
- if expiration_ts is not None and self.clock.time_msec() >= expiration_ts:
+ if (
+ expiration_ts is not None
+ and self.clock.time_msec() >= expiration_ts
+ ):
raise AuthError(
- 403,
- "User account has expired",
- errcode=Codes.EXPIRED_ACCOUNT,
+ 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
)
# device_id may not be present if get_user_by_access_token has been
@@ -266,54 +235,51 @@ class Auth(object):
if is_guest and not allow_guest:
raise AuthError(
- 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+ 403,
+ "Guest access not allowed",
+ errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
request.authenticated_entity = user.to_string()
+ opentracing.set_tag("authenticated_entity", user.to_string())
+ if device_id:
+ opentracing.set_tag("device_id", device_id)
- defer.returnValue(synapse.types.create_requester(
- user, token_id, is_guest, device_id, app_service=app_service)
+ return synapse.types.create_requester(
+ user, token_id, is_guest, device_id, app_service=app_service
)
except KeyError:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
- errcode=Codes.MISSING_TOKEN
- )
+ raise MissingClientTokenError()
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
- self.get_access_token_from_request(
- request, self.TOKEN_NOT_FOUND_HTTP_STATUS
- )
+ self.get_access_token_from_request(request)
)
if app_service is None:
- return(None, None)
+ return None, None
if app_service.ip_range_whitelist:
ip_address = IPAddress(self.hs.get_ip_from_request(request))
if ip_address not in app_service.ip_range_whitelist:
- return(None, None)
+ return None, None
if b"user_id" not in request.args:
- return(app_service.sender, app_service)
+ return app_service.sender, app_service
- user_id = request.args[b"user_id"][0].decode('utf8')
+ user_id = request.args[b"user_id"][0].decode("utf8")
if app_service.sender == user_id:
- return(app_service.sender, app_service)
+ return app_service.sender, app_service
if not app_service.is_interested_in_user(user_id):
- raise AuthError(
- 403,
- "Application service cannot masquerade as this user."
- )
+ raise AuthError(403, "Application service cannot masquerade as this user.")
# Let ASes manipulate nonexistent users (e.g. to shadow-register them)
# if not (yield self.store.get_user_by_id(user_id)):
# raise AuthError(
# 403,
# "Application service has not registered this user"
# )
- return(user_id, app_service)
+ return user_id, app_service
@defer.inlineCallbacks
def get_user_by_access_token(self, token, rights="access"):
@@ -330,14 +296,26 @@ class Auth(object):
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises:
- AuthError if no user by that token exists or the token is invalid.
+ InvalidClientCredentialsError if no user by that token exists or the token
+ is invalid.
"""
if rights == "access":
# first look in the database
r = yield self._look_up_user_by_access_token(token)
if r:
- defer.returnValue(r)
+ valid_until_ms = r["valid_until_ms"]
+ if (
+ valid_until_ms is not None
+ and valid_until_ms < self.clock.time_msec()
+ ):
+ # there was a valid access token, but it has expired.
+ # soft-logout the user.
+ raise InvalidClientTokenError(
+ msg="Access token has expired", soft_logout=True
+ )
+
+ return r
# otherwise it needs to be a valid macaroon
try:
@@ -348,11 +326,7 @@ class Auth(object):
if not guest:
# non-guest access tokens must be in the database
logger.warning("Unrecognised access token - not in store.")
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Unrecognised access token.",
- errcode=Codes.UNKNOWN_TOKEN,
- )
+ raise InvalidClientTokenError()
# Guest access tokens are not stored in the database (there can
# only be one access token per guest, anyway).
@@ -367,16 +341,10 @@ class Auth(object):
# guest tokens.
stored_user = yield self.store.get_user_by_id(user_id)
if not stored_user:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Unknown user_id %s" % user_id,
- errcode=Codes.UNKNOWN_TOKEN
- )
+ raise InvalidClientTokenError("Unknown user_id %s" % user_id)
if not stored_user["is_guest"]:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Guest access token used for regular user",
- errcode=Codes.UNKNOWN_TOKEN
+ raise InvalidClientTokenError(
+ "Guest access token used for regular user"
)
ret = {
"user": user,
@@ -395,7 +363,7 @@ class Auth(object):
}
else:
raise RuntimeError("Unknown rights setting %s", rights)
- defer.returnValue(ret)
+ return ret
except (
_InvalidMacaroonException,
pymacaroons.exceptions.MacaroonException,
@@ -403,10 +371,7 @@ class Auth(object):
ValueError,
) as e:
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
- errcode=Codes.UNKNOWN_TOKEN
- )
+ raise InvalidClientTokenError("Invalid macaroon passed.")
def _parse_and_validate_macaroon(self, token, rights="access"):
"""Takes a macaroon and tries to parse and validate it. This is cached
@@ -434,25 +399,16 @@ class Auth(object):
try:
user_id = self.get_user_id_from_macaroon(macaroon)
- has_expiry = False
guest = False
for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith("time "):
- has_expiry = True
- elif caveat.caveat_id == "guest = true":
+ if caveat.caveat_id == "guest = true":
guest = True
- self.validate_macaroon(
- macaroon, rights, self.hs.config.expire_access_token,
- user_id=user_id,
- )
+ self.validate_macaroon(macaroon, rights, user_id=user_id)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
- errcode=Codes.UNKNOWN_TOKEN
- )
+ raise InvalidClientTokenError("Invalid macaroon passed.")
- if not has_expiry and rights == "access":
+ if rights == "access":
self.token_cache[token] = (user_id, guest)
return user_id, guest
@@ -469,18 +425,16 @@ class Auth(object):
(str) user id
Raises:
- AuthError if there is no user_id caveat in the macaroon
+ InvalidClientCredentialsError if there is no user_id caveat in the
+ macaroon
"""
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
- return caveat.caveat_id[len(user_prefix):]
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
- errcode=Codes.UNKNOWN_TOKEN
- )
+ return caveat.caveat_id[len(user_prefix) :]
+ raise InvalidClientTokenError("No user caveat in macaroon")
- def validate_macaroon(self, macaroon, type_string, verify_expiry, user_id):
+ def validate_macaroon(self, macaroon, type_string, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@@ -488,7 +442,6 @@ class Auth(object):
macaroon(pymacaroons.Macaroon): The macaroon to validate
type_string(str): The kind of token required (e.g. "access",
"delete_pusher")
- verify_expiry(bool): Whether to verify whether the macaroon has expired.
user_id (str): The user_id required
"""
v = pymacaroons.Verifier()
@@ -501,19 +454,7 @@ class Auth(object):
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true")
-
- # verify_expiry should really always be True, but there exist access
- # tokens in the wild which expire when they should not, so we can't
- # enforce expiry yet (so we have to allow any caveat starting with
- # 'time < ' in access tokens).
- #
- # On the other hand, short-term login tokens (as used by CAS login, for
- # example) have an expiry time which we do want to enforce.
-
- if verify_expiry:
- v.satisfy_general(self._verify_expiry)
- else:
- v.satisfy_general(lambda c: c.startswith("time < "))
+ v.satisfy_general(self._verify_expiry)
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
@@ -524,7 +465,7 @@ class Auth(object):
prefix = "time < "
if not caveat.startswith(prefix):
return False
- expiry = int(caveat[len(prefix):])
+ expiry = int(caveat[len(prefix) :])
now = self.hs.get_clock().time_msec()
return now < expiry
@@ -532,7 +473,7 @@ class Auth(object):
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
- defer.returnValue(None)
+ return None
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
@@ -542,19 +483,18 @@ class Auth(object):
"token_id": ret.get("token_id", None),
"is_guest": False,
"device_id": ret.get("device_id"),
+ "valid_until_ms": ret.get("valid_until_ms"),
}
- defer.returnValue(user_info)
+ return user_info
def get_appservice_by_req(self, request):
- (user_id, app_service) = self._get_appservice_user_id(request)
- if not app_service:
- raise AuthError(
- self.TOKEN_NOT_FOUND_HTTP_STATUS,
- "Unrecognised access token.",
- errcode=Codes.UNKNOWN_TOKEN,
- )
- request.authenticated_entity = app_service.sender
- return app_service
+ token = self.get_access_token_from_request(request)
+ service = self.store.get_app_service_by_token(token)
+ if not service:
+ logger.warning("Unrecognised appservice access token.")
+ raise InvalidClientTokenError()
+ request.authenticated_entity = service.sender
+ return defer.succeed(service)
def is_server_admin(self, user):
""" Check if the given user is a local server admin.
@@ -567,107 +507,61 @@ class Auth(object):
"""
return self.store.is_server_admin(user)
- @defer.inlineCallbacks
- def compute_auth_events(self, event, current_state_ids, for_verification=False):
- if event.type == EventTypes.Create:
- defer.returnValue([])
-
- auth_ids = []
-
- key = (EventTypes.PowerLevels, "", )
- power_level_event_id = current_state_ids.get(key)
-
- if power_level_event_id:
- auth_ids.append(power_level_event_id)
-
- key = (EventTypes.JoinRules, "", )
- join_rule_event_id = current_state_ids.get(key)
-
- key = (EventTypes.Member, event.sender, )
- member_event_id = current_state_ids.get(key)
-
- key = (EventTypes.Create, "", )
- create_event_id = current_state_ids.get(key)
- if create_event_id:
- auth_ids.append(create_event_id)
-
- if join_rule_event_id:
- join_rule_event = yield self.store.get_event(join_rule_event_id)
- join_rule = join_rule_event.content.get("join_rule")
- is_public = join_rule == JoinRules.PUBLIC if join_rule else False
- else:
- is_public = False
-
- if event.type == EventTypes.Member:
- e_type = event.content["membership"]
- if e_type in [Membership.JOIN, Membership.INVITE]:
- if join_rule_event_id:
- auth_ids.append(join_rule_event_id)
+ def compute_auth_events(
+ self, event, current_state_ids: StateMap[str], for_verification: bool = False,
+ ):
+ """Given an event and current state return the list of event IDs used
+ to auth an event.
- if e_type == Membership.JOIN:
- if member_event_id and not is_public:
- auth_ids.append(member_event_id)
- else:
- if member_event_id:
- auth_ids.append(member_event_id)
-
- if for_verification:
- key = (EventTypes.Member, event.state_key, )
- existing_event_id = current_state_ids.get(key)
- if existing_event_id:
- auth_ids.append(existing_event_id)
-
- if e_type == Membership.INVITE:
- if "third_party_invite" in event.content:
- key = (
- EventTypes.ThirdPartyInvite,
- event.content["third_party_invite"]["signed"]["token"]
- )
- third_party_invite_id = current_state_ids.get(key)
- if third_party_invite_id:
- auth_ids.append(third_party_invite_id)
- elif member_event_id:
- member_event = yield self.store.get_event(member_event_id)
- if member_event.content["membership"] == Membership.JOIN:
- auth_ids.append(member_event.event_id)
+ If `for_verification` is False then only return auth events that
+ should be added to the event's `auth_events`.
- defer.returnValue(auth_ids)
+ Returns:
+ defer.Deferred(list[str]): List of event IDs.
+ """
- def check_redaction(self, room_version, event, auth_events):
- """Check whether the event sender is allowed to redact the target event.
+ if event.type == EventTypes.Create:
+ return defer.succeed([])
+
+ # Currently we ignore the `for_verification` flag even though there are
+ # some situations where we can drop particular auth events when adding
+ # to the event's `auth_events` (e.g. joins pointing to previous joins
+ # when room is publically joinable). Dropping event IDs has the
+ # advantage that the auth chain for the room grows slower, but we use
+ # the auth chain in state resolution v2 to order events, which means
+ # care must be taken if dropping events to ensure that it doesn't
+ # introduce undesirable "state reset" behaviour.
+ #
+ # All of which sounds a bit tricky so we don't bother for now.
- Returns:
- True if the the sender is allowed to redact the target event if the
- target event was created by them.
- False if the sender is allowed to redact the target event with no
- further checks.
+ auth_ids = []
+ for etype, state_key in event_auth.auth_types_for_event(event):
+ auth_ev_id = current_state_ids.get((etype, state_key))
+ if auth_ev_id:
+ auth_ids.append(auth_ev_id)
- Raises:
- AuthError if the event sender is definitely not allowed to redact
- the target event.
- """
- return event_auth.check_redaction(room_version, event, auth_events)
+ return defer.succeed(auth_ids)
@defer.inlineCallbacks
- def check_can_change_room_list(self, room_id, user):
- """Check if the user is allowed to edit the room's entry in the
+ def check_can_change_room_list(self, room_id: str, user: UserID):
+ """Determine whether the user is allowed to edit the room's entry in the
published room list.
Args:
- room_id (str)
- user (UserID)
+ room_id
+ user
"""
is_admin = yield self.is_server_admin(user)
if is_admin:
- defer.returnValue(True)
+ return True
user_id = user.to_string()
- yield self.check_joined_room(room_id, user_id)
+ yield self.check_user_in_room(room_id, user_id)
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
- # m.room.aliases events
+ # m.room.canonical_alias events
power_level_event = yield self.state.get_current_state(
room_id, EventTypes.PowerLevels, ""
)
@@ -677,16 +571,11 @@ class Auth(object):
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = event_auth.get_send_level(
- EventTypes.Aliases, "", power_level_event,
+ EventTypes.CanonicalAlias, "", power_level_event
)
user_level = event_auth.get_user_power_level(user_id, auth_events)
- if user_level < send_level:
- raise AuthError(
- 403,
- "This server requires you to be a moderator in the room to"
- " edit its room list entry"
- )
+ return user_level >= send_level
@staticmethod
def has_access_token(request):
@@ -700,20 +589,16 @@ class Auth(object):
return bool(query_params) or bool(auth_headers)
@staticmethod
- def get_access_token_from_request(request, token_not_found_http_status=401):
+ def get_access_token_from_request(request):
"""Extracts the access_token from the request.
Args:
request: The http request.
- token_not_found_http_status(int): The HTTP status code to set in the
- AuthError if the token isn't found. This is used in some of the
- legacy APIs to change the status code to 403 from the default of
- 401 since some of the old clients depended on auth errors returning
- 403.
Returns:
unicode: The access_token
Raises:
- AuthError: If there isn't an access_token in the request.
+ MissingClientTokenError: If there isn't a single access_token in the
+ request
"""
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
@@ -722,42 +607,36 @@ class Auth(object):
# Try the get the access_token from a "Authorization: Bearer"
# header
if query_params is not None:
- raise AuthError(
- token_not_found_http_status,
- "Mixing Authorization headers and access_token query parameters.",
- errcode=Codes.MISSING_TOKEN,
+ raise MissingClientTokenError(
+ "Mixing Authorization headers and access_token query parameters."
)
if len(auth_headers) > 1:
- raise AuthError(
- token_not_found_http_status,
- "Too many Authorization headers.",
- errcode=Codes.MISSING_TOKEN,
- )
+ raise MissingClientTokenError("Too many Authorization headers.")
parts = auth_headers[0].split(b" ")
if parts[0] == b"Bearer" and len(parts) == 2:
- return parts[1].decode('ascii')
+ return parts[1].decode("ascii")
else:
- raise AuthError(
- token_not_found_http_status,
- "Invalid Authorization header.",
- errcode=Codes.MISSING_TOKEN,
- )
+ raise MissingClientTokenError("Invalid Authorization header.")
else:
# Try to get the access_token from the query params.
if not query_params:
- raise AuthError(
- token_not_found_http_status,
- "Missing access token.",
- errcode=Codes.MISSING_TOKEN
- )
+ raise MissingClientTokenError()
- return query_params[0].decode('ascii')
+ return query_params[0].decode("ascii")
@defer.inlineCallbacks
- def check_in_room_or_world_readable(self, room_id, user_id):
+ def check_user_in_room_or_world_readable(
+ self, room_id: str, user_id: str, allow_departed_users: bool = False
+ ):
"""Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised.
+ Args:
+ room_id: room to check
+ user_id: user to check
+ allow_departed_users: if True, accept users that were previously
+ members but have now departed
+
Returns:
Deferred[tuple[str, str|None]]: Resolves to the current membership of
the user in the room and the membership event ID of the user. If
@@ -766,29 +645,32 @@ class Auth(object):
"""
try:
- # check_user_was_in_room will return the most recent membership
+ # check_user_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
- member_event = yield self.check_user_was_in_room(room_id, user_id)
- defer.returnValue((member_event.membership, member_event.event_id))
+ member_event = yield self.check_user_in_room(
+ room_id, user_id, allow_departed_users=allow_departed_users
+ )
+ return member_event.membership, member_event.event_id
except AuthError:
visibility = yield self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
- visibility and
- visibility.content["history_visibility"] == "world_readable"
+ visibility
+ and visibility.content["history_visibility"] == "world_readable"
):
- defer.returnValue((Membership.JOIN, None))
- return
+ return Membership.JOIN, None
raise AuthError(
- 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+ 403,
+ "User %s not in room %s, and room previews are disabled"
+ % (user_id, room_id),
)
@defer.inlineCallbacks
- def check_auth_blocking(self, user_id=None, threepid=None):
+ def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
@@ -801,6 +683,9 @@ class Auth(object):
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
+
+ user_type(str|None): If present, is used to decide whether to check against
+ certain blocking reasons like MAU.
"""
# Never fail an auth check for the server notices users or support user
@@ -813,10 +698,11 @@ class Auth(object):
if self.hs.config.hs_disabled:
raise ResourceLimitError(
- 403, self.hs.config.hs_disabled_message,
+ 403,
+ self.hs.config.hs_disabled_message,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact=self.hs.config.admin_contact,
- limit_type=self.hs.config.hs_disabled_limit_type
+ limit_type=LimitBlockingTypes.HS_DISABLED,
)
if self.hs.config.limit_usage_by_mau is True:
assert not (user_id and threepid)
@@ -837,12 +723,17 @@ class Auth(object):
self.hs.config.mau_limits_reserved_threepids, threepid
):
return
+ elif user_type == UserTypes.SUPPORT:
+ # If the user does not exist yet and is of type "support",
+ # allow registration. Support users are excluded from MAU checks.
+ return
# Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise ResourceLimitError(
- 403, "Monthly Active User Limit Exceeded",
+ 403,
+ "Monthly Active User Limit Exceeded",
admin_contact=self.hs.config.admin_contact,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
- limit_type="monthly_active_user"
+ limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER,
)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 731c200c8d..cc8577552b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +19,7 @@
"""Contains constants from the specification."""
# the "depth" field on events is limited to 2**63 - 1
-MAX_DEPTH = 2**63 - 1
+MAX_DEPTH = 2 ** 63 - 1
# the maximum length for a room alias is 255 characters
MAX_ALIAS_LENGTH = 255
@@ -30,39 +31,41 @@ MAX_USERID_LENGTH = 255
class Membership(object):
"""Represents the membership states of a user in a room."""
- INVITE = u"invite"
- JOIN = u"join"
- KNOCK = u"knock"
- LEAVE = u"leave"
- BAN = u"ban"
+
+ INVITE = "invite"
+ JOIN = "join"
+ KNOCK = "knock"
+ LEAVE = "leave"
+ BAN = "ban"
LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
class PresenceState(object):
"""Represents the presence state of a user."""
- OFFLINE = u"offline"
- UNAVAILABLE = u"unavailable"
- ONLINE = u"online"
+
+ OFFLINE = "offline"
+ UNAVAILABLE = "unavailable"
+ ONLINE = "online"
class JoinRules(object):
- PUBLIC = u"public"
- KNOCK = u"knock"
- INVITE = u"invite"
- PRIVATE = u"private"
+ PUBLIC = "public"
+ KNOCK = "knock"
+ INVITE = "invite"
+ PRIVATE = "private"
class LoginType(object):
- PASSWORD = u"m.login.password"
- EMAIL_IDENTITY = u"m.login.email.identity"
- MSISDN = u"m.login.msisdn"
- RECAPTCHA = u"m.login.recaptcha"
- TERMS = u"m.login.terms"
- DUMMY = u"m.login.dummy"
+ PASSWORD = "m.login.password"
+ EMAIL_IDENTITY = "m.login.email.identity"
+ MSISDN = "m.login.msisdn"
+ RECAPTCHA = "m.login.recaptcha"
+ TERMS = "m.login.terms"
+ DUMMY = "m.login.dummy"
# Only for C/S API v1
- APPLICATION_SERVICE = u"m.login.application_service"
- SHARED_SECRET = u"org.matrix.login.shared_secret"
+ APPLICATION_SERVICE = "m.login.application_service"
+ SHARED_SECRET = "org.matrix.login.shared_secret"
class EventTypes(object):
@@ -74,16 +77,14 @@ class EventTypes(object):
Aliases = "m.room.aliases"
Redaction = "m.room.redaction"
ThirdPartyInvite = "m.room.third_party_invite"
- Encryption = "m.room.encryption"
RelatedGroups = "m.room.related_groups"
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
- Encryption = "m.room.encryption"
+ Encrypted = "m.room.encrypted"
RoomAvatar = "m.room.avatar"
RoomEncryption = "m.room.encryption"
GuestAccess = "m.room.guest_access"
- Encryption = "m.room.encryption"
# These are used for validation
Message = "m.room.message"
@@ -98,8 +99,6 @@ class EventTypes(object):
class RejectedReason(object):
AUTH_ERROR = "auth_error"
- REPLACED = "replaced"
- NOT_ANCESTOR = "not_ancestor"
class RoomCreationPreset(object):
@@ -121,13 +120,34 @@ class UserTypes(object):
"""Allows for user type specific behaviour. With the benefit of hindsight
'admin' and 'guest' users should also be UserTypes. Normal users are type None
"""
+
SUPPORT = "support"
- ALL_USER_TYPES = (SUPPORT,)
+ BOT = "bot"
+ ALL_USER_TYPES = (SUPPORT, BOT)
class RelationTypes(object):
"""The types of relations known to this server.
"""
+
ANNOTATION = "m.annotation"
REPLACE = "m.replace"
REFERENCE = "m.reference"
+
+
+class LimitBlockingTypes(object):
+ """Reasons that a server may be blocked"""
+
+ MONTHLY_ACTIVE_USER = "monthly_active_user"
+ HS_DISABLED = "hs_disabled"
+
+
+class EventContentFields(object):
+ """Fields found in events' content, regardless of type."""
+
+ # Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
+ LABELS = "org.matrix.labels"
+
+ # Timestamp to delete the event after
+ # cf https://github.com/matrix-org/matrix-doc/pull/2228
+ SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index e46bfdfcb9..9dd8d48386 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -18,12 +18,15 @@
"""Contains exceptions and error codes."""
import logging
+from typing import Dict, List
from six import iteritems
from six.moves import http_client
from canonicaljson import json
+from twisted.web import http
+
logger = logging.getLogger(__name__)
@@ -62,6 +65,9 @@ class Codes(object):
INCOMPATIBLE_ROOM_VERSION = "M_INCOMPATIBLE_ROOM_VERSION"
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
EXPIRED_ACCOUNT = "ORG_MATRIX_EXPIRED_ACCOUNT"
+ INVALID_SIGNATURE = "M_INVALID_SIGNATURE"
+ USER_DEACTIVATED = "M_USER_DEACTIVATED"
+ BAD_ALIAS = "M_BAD_ALIAS"
PASSWORD_TOO_SHORT = "M_PASSWORD_TOO_SHORT"
PASSWORD_NO_DIGIT = "M_PASSWORD_NO_DIGIT"
PASSWORD_NO_UPPERCASE = "M_PASSWORD_NO_UPPERCASE"
@@ -78,12 +84,36 @@ class CodeMessageException(RuntimeError):
code (int): HTTP error code
msg (str): string describing the error
"""
+
def __init__(self, code, msg):
super(CodeMessageException, self).__init__("%d: %s" % (code, msg))
self.code = code
self.msg = msg
+class RedirectException(CodeMessageException):
+ """A pseudo-error indicating that we want to redirect the client to a different
+ location
+
+ Attributes:
+ cookies: a list of set-cookies values to add to the response. For example:
+ b"sessionId=a3fWa; Expires=Wed, 21 Oct 2015 07:28:00 GMT"
+ """
+
+ def __init__(self, location: bytes, http_code: int = http.FOUND):
+ """
+
+ Args:
+ location: the URI to redirect to
+ http_code: the HTTP response code
+ """
+ msg = "Redirect to %s" % (location.decode("utf-8"),)
+ super().__init__(code=http_code, msg=msg)
+ self.location = location
+
+ self.cookies = [] # type: List[bytes]
+
+
class SynapseError(CodeMessageException):
"""A base exception type for matrix errors which have an errcode and error
message (as well as an HTTP status code).
@@ -91,6 +121,7 @@ class SynapseError(CodeMessageException):
Attributes:
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
"""
+
def __init__(self, code, msg, errcode=Codes.UNKNOWN):
"""Constructs a synapse error.
@@ -103,10 +134,7 @@ class SynapseError(CodeMessageException):
self.errcode = errcode
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- )
+ return cs_error(self.msg, self.errcode)
class ProxiedRequestError(SynapseError):
@@ -115,27 +143,23 @@ class ProxiedRequestError(SynapseError):
Attributes:
errcode (str): Matrix error code e.g 'M_FORBIDDEN'
"""
+
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
- super(ProxiedRequestError, self).__init__(
- code, msg, errcode
- )
+ super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None:
- self._additional_fields = {}
+ self._additional_fields = {} # type: Dict
else:
self._additional_fields = dict(additional_fields)
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- **self._additional_fields
- )
+ return cs_error(self.msg, self.errcode, **self._additional_fields)
class ConsentNotGivenError(SynapseError):
"""The error returned to the client when the user has not consented to the
privacy policy.
"""
+
def __init__(self, msg, consent_uri):
"""Constructs a ConsentNotGivenError
@@ -144,23 +168,28 @@ class ConsentNotGivenError(SynapseError):
consent_url (str): The URL where the user can give their consent
"""
super(ConsentNotGivenError, self).__init__(
- code=http_client.FORBIDDEN,
- msg=msg,
- errcode=Codes.CONSENT_NOT_GIVEN
+ code=http_client.FORBIDDEN, msg=msg, errcode=Codes.CONSENT_NOT_GIVEN
)
self._consent_uri = consent_uri
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- consent_uri=self._consent_uri
- )
+ return cs_error(self.msg, self.errcode, consent_uri=self._consent_uri)
-class RegistrationError(SynapseError):
- """An error raised when a registration event fails."""
- pass
+class UserDeactivatedError(SynapseError):
+ """The error returned to the client when the user attempted to access an
+ authenticated endpoint, but the account has been deactivated.
+ """
+
+ def __init__(self, msg):
+ """Constructs a UserDeactivatedError
+
+ Args:
+ msg (str): The human-readable error message
+ """
+ super(UserDeactivatedError, self).__init__(
+ code=http_client.FORBIDDEN, msg=msg, errcode=Codes.USER_DEACTIVATED
+ )
class FederationDeniedError(SynapseError):
@@ -198,15 +227,17 @@ class InteractiveAuthIncompleteError(Exception):
result (dict): the server response to the request, which should be
passed back to the client
"""
+
def __init__(self, result):
super(InteractiveAuthIncompleteError, self).__init__(
- "Interactive auth not yet complete",
+ "Interactive auth not yet complete"
)
self.result = result
class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
+
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.UNRECOGNIZED
@@ -215,25 +246,20 @@ class UnrecognizedRequestError(SynapseError):
message = "Unrecognized request"
else:
message = args[0]
- super(UnrecognizedRequestError, self).__init__(
- 400,
- message,
- **kwargs
- )
+ super(UnrecognizedRequestError, self).__init__(400, message, **kwargs)
class NotFoundError(SynapseError):
"""An error indicating we can't find the thing you asked for"""
+
def __init__(self, msg="Not found", errcode=Codes.NOT_FOUND):
- super(NotFoundError, self).__init__(
- 404,
- msg,
- errcode=errcode
- )
+ super(NotFoundError, self).__init__(404, msg, errcode=errcode)
class AuthError(SynapseError):
- """An error raised when there was a problem authorising an event."""
+ """An error raised when there was a problem authorising an event, and at various
+ other poorly-defined times.
+ """
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
@@ -241,13 +267,51 @@ class AuthError(SynapseError):
super(AuthError, self).__init__(*args, **kwargs)
+class InvalidClientCredentialsError(SynapseError):
+ """An error raised when there was a problem with the authorisation credentials
+ in a client request.
+
+ https://matrix.org/docs/spec/client_server/r0.5.0#using-access-tokens:
+
+ When credentials are required but missing or invalid, the HTTP call will
+ return with a status of 401 and the error code, M_MISSING_TOKEN or
+ M_UNKNOWN_TOKEN respectively.
+ """
+
+ def __init__(self, msg, errcode):
+ super().__init__(code=401, msg=msg, errcode=errcode)
+
+
+class MissingClientTokenError(InvalidClientCredentialsError):
+ """Raised when we couldn't find the access token in a request"""
+
+ def __init__(self, msg="Missing access token"):
+ super().__init__(msg=msg, errcode="M_MISSING_TOKEN")
+
+
+class InvalidClientTokenError(InvalidClientCredentialsError):
+ """Raised when we didn't understand the access token in a request"""
+
+ def __init__(self, msg="Unrecognised access token", soft_logout=False):
+ super().__init__(msg=msg, errcode="M_UNKNOWN_TOKEN")
+ self._soft_logout = soft_logout
+
+ def error_dict(self):
+ d = super().error_dict()
+ d["soft_logout"] = self._soft_logout
+ return d
+
+
class ResourceLimitError(SynapseError):
"""
Any error raised when there is a problem with resource usage.
For instance, the monthly active user limit for the server has been exceeded
"""
+
def __init__(
- self, code, msg,
+ self,
+ code,
+ msg,
errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
admin_contact=None,
limit_type=None,
@@ -261,7 +325,7 @@ class ResourceLimitError(SynapseError):
self.msg,
self.errcode,
admin_contact=self.admin_contact,
- limit_type=self.limit_type
+ limit_type=self.limit_type,
)
@@ -276,6 +340,7 @@ class EventSizeError(SynapseError):
class EventStreamError(SynapseError):
"""An error raised when there a problem with the event stream."""
+
def __init__(self, *args, **kwargs):
if "errcode" not in kwargs:
kwargs["errcode"] = Codes.BAD_PAGINATION
@@ -284,47 +349,53 @@ class EventStreamError(SynapseError):
class LoginError(SynapseError):
"""An error raised when there was a problem logging in."""
+
pass
class StoreError(SynapseError):
"""An error raised when there was a problem storing some data."""
+
pass
class InvalidCaptchaError(SynapseError):
- def __init__(self, code=400, msg="Invalid captcha.", error_url=None,
- errcode=Codes.CAPTCHA_INVALID):
+ def __init__(
+ self,
+ code=400,
+ msg="Invalid captcha.",
+ error_url=None,
+ errcode=Codes.CAPTCHA_INVALID,
+ ):
super(InvalidCaptchaError, self).__init__(code, msg, errcode)
self.error_url = error_url
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- error_url=self.error_url,
- )
+ return cs_error(self.msg, self.errcode, error_url=self.error_url)
class LimitExceededError(SynapseError):
"""A client has sent too many requests and is being throttled.
"""
- def __init__(self, code=429, msg="Too Many Requests", retry_after_ms=None,
- errcode=Codes.LIMIT_EXCEEDED):
+
+ def __init__(
+ self,
+ code=429,
+ msg="Too Many Requests",
+ retry_after_ms=None,
+ errcode=Codes.LIMIT_EXCEEDED,
+ ):
super(LimitExceededError, self).__init__(code, msg, errcode)
self.retry_after_ms = retry_after_ms
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- retry_after_ms=self.retry_after_ms,
- )
+ return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
class RoomKeysVersionError(SynapseError):
"""A client has tried to upload to a non-current version of the room_keys store
"""
+
def __init__(self, current_version):
"""
Args:
@@ -339,11 +410,10 @@ class RoomKeysVersionError(SynapseError):
class UnsupportedRoomVersionError(SynapseError):
"""The client's request to create a room used a room version that the server does
not support."""
- def __init__(self):
+
+ def __init__(self, msg="Homeserver does not support this room version"):
super(UnsupportedRoomVersionError, self).__init__(
- code=400,
- msg="Homeserver does not support this room version",
- errcode=Codes.UNSUPPORTED_ROOM_VERSION,
+ code=400, msg=msg, errcode=Codes.UNSUPPORTED_ROOM_VERSION,
)
@@ -362,22 +432,19 @@ class IncompatibleRoomVersionError(SynapseError):
Unlike UnsupportedRoomVersionError, it is specific to the case of the make_join
failing.
"""
+
def __init__(self, room_version):
super(IncompatibleRoomVersionError, self).__init__(
code=400,
msg="Your homeserver does not support the features required to "
- "join this room",
+ "join this room",
errcode=Codes.INCOMPATIBLE_ROOM_VERSION,
)
self._room_version = room_version
def error_dict(self):
- return cs_error(
- self.msg,
- self.errcode,
- room_version=self._room_version,
- )
+ return cs_error(self.msg, self.errcode, room_version=self._room_version)
class PasswordRefusedError(SynapseError):
@@ -389,11 +456,7 @@ class PasswordRefusedError(SynapseError):
msg="This password doesn't comply with the server's policy",
errcode=Codes.WEAK_PASSWORD,
):
- super(PasswordRefusedError, self).__init__(
- code=400,
- msg=msg,
- errcode=errcode,
- )
+ super(PasswordRefusedError, self).__init__(code=400, msg=msg, errcode=errcode)
class RequestSendFailed(RuntimeError):
@@ -404,11 +467,11 @@ class RequestSendFailed(RuntimeError):
networking (e.g. DNS failures, connection timeouts etc), versus unexpected
errors (like programming errors).
"""
+
def __init__(self, inner_exception, can_retry):
super(RequestSendFailed, self).__init__(
- "Failed to send request: %s: %s" % (
- type(inner_exception).__name__, inner_exception,
- )
+ "Failed to send request: %s: %s"
+ % (type(inner_exception).__name__, inner_exception)
)
self.inner_exception = inner_exception
self.can_retry = can_retry
@@ -432,7 +495,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
class FederationError(RuntimeError):
- """ This class is used to inform remote home servers about erroneous
+ """ This class is used to inform remote homeservers about erroneous
PDUs they sent us.
FATAL: The remote server could not interpret the source event.
@@ -452,7 +515,7 @@ class FederationError(RuntimeError):
self.affected = affected
self.source = source
- msg = "%s %s: %s" % (level, code, reason,)
+ msg = "%s %s: %s" % (level, code, reason)
super(FederationError, self).__init__(msg)
def get_dict(self):
@@ -472,6 +535,7 @@ class HttpResponseException(CodeMessageException):
Attributes:
response (bytes): body of response
"""
+
def __init__(self, code, msg, response):
"""
@@ -510,7 +574,7 @@ class HttpResponseException(CodeMessageException):
if not isinstance(j, dict):
j = {}
- errcode = j.pop('errcode', Codes.UNKNOWN)
- errmsg = j.pop('error', self.msg)
+ errcode = j.pop("errcode", Codes.UNKNOWN)
+ errmsg = j.pop("error", self.msg)
return ProxiedRequestError(self.code, errmsg, errcode, j)
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 3906475403..8b64d0a285 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +15,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import List
+
from six import text_type
import jsonschema
@@ -20,6 +25,7 @@ from jsonschema import FormatChecker
from twisted.internet import defer
+from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.storage.presence import UserPresenceState
from synapse.types import RoomID, UserID
@@ -28,117 +34,59 @@ FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
- "limit": {
- "type": "number"
- },
- "senders": {
- "$ref": "#/definitions/user_id_array"
- },
- "not_senders": {
- "$ref": "#/definitions/user_id_array"
- },
+ "limit": {"type": "number"},
+ "senders": {"$ref": "#/definitions/user_id_array"},
+ "not_senders": {"$ref": "#/definitions/user_id_array"},
# TODO: We don't limit event type values but we probably should...
# check types are valid event types
- "types": {
- "type": "array",
- "items": {
- "type": "string"
- }
- },
- "not_types": {
- "type": "array",
- "items": {
- "type": "string"
- }
- }
- }
+ "types": {"type": "array", "items": {"type": "string"}},
+ "not_types": {"type": "array", "items": {"type": "string"}},
+ },
}
ROOM_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
- "not_rooms": {
- "$ref": "#/definitions/room_id_array"
- },
- "rooms": {
- "$ref": "#/definitions/room_id_array"
- },
- "ephemeral": {
- "$ref": "#/definitions/room_event_filter"
- },
- "include_leave": {
- "type": "boolean"
- },
- "state": {
- "$ref": "#/definitions/room_event_filter"
- },
- "timeline": {
- "$ref": "#/definitions/room_event_filter"
- },
- "account_data": {
- "$ref": "#/definitions/room_event_filter"
- },
- }
+ "not_rooms": {"$ref": "#/definitions/room_id_array"},
+ "rooms": {"$ref": "#/definitions/room_id_array"},
+ "ephemeral": {"$ref": "#/definitions/room_event_filter"},
+ "include_leave": {"type": "boolean"},
+ "state": {"$ref": "#/definitions/room_event_filter"},
+ "timeline": {"$ref": "#/definitions/room_event_filter"},
+ "account_data": {"$ref": "#/definitions/room_event_filter"},
+ },
}
ROOM_EVENT_FILTER_SCHEMA = {
"additionalProperties": False,
"type": "object",
"properties": {
- "limit": {
- "type": "number"
- },
- "senders": {
- "$ref": "#/definitions/user_id_array"
- },
- "not_senders": {
- "$ref": "#/definitions/user_id_array"
- },
- "types": {
- "type": "array",
- "items": {
- "type": "string"
- }
- },
- "not_types": {
- "type": "array",
- "items": {
- "type": "string"
- }
- },
- "rooms": {
- "$ref": "#/definitions/room_id_array"
- },
- "not_rooms": {
- "$ref": "#/definitions/room_id_array"
- },
- "contains_url": {
- "type": "boolean"
- },
- "lazy_load_members": {
- "type": "boolean"
- },
- "include_redundant_members": {
- "type": "boolean"
- },
- }
+ "limit": {"type": "number"},
+ "senders": {"$ref": "#/definitions/user_id_array"},
+ "not_senders": {"$ref": "#/definitions/user_id_array"},
+ "types": {"type": "array", "items": {"type": "string"}},
+ "not_types": {"type": "array", "items": {"type": "string"}},
+ "rooms": {"$ref": "#/definitions/room_id_array"},
+ "not_rooms": {"$ref": "#/definitions/room_id_array"},
+ "contains_url": {"type": "boolean"},
+ "lazy_load_members": {"type": "boolean"},
+ "include_redundant_members": {"type": "boolean"},
+ # Include or exclude events with the provided labels.
+ # cf https://github.com/matrix-org/matrix-doc/pull/2326
+ "org.matrix.labels": {"type": "array", "items": {"type": "string"}},
+ "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
+ },
}
USER_ID_ARRAY_SCHEMA = {
"type": "array",
- "items": {
- "type": "string",
- "format": "matrix_user_id"
- }
+ "items": {"type": "string", "format": "matrix_user_id"},
}
ROOM_ID_ARRAY_SCHEMA = {
"type": "array",
- "items": {
- "type": "string",
- "format": "matrix_room_id"
- }
+ "items": {"type": "string", "format": "matrix_room_id"},
}
USER_FILTER_SCHEMA = {
@@ -150,22 +98,13 @@ USER_FILTER_SCHEMA = {
"user_id_array": USER_ID_ARRAY_SCHEMA,
"filter": FILTER_SCHEMA,
"room_filter": ROOM_FILTER_SCHEMA,
- "room_event_filter": ROOM_EVENT_FILTER_SCHEMA
+ "room_event_filter": ROOM_EVENT_FILTER_SCHEMA,
},
"properties": {
- "presence": {
- "$ref": "#/definitions/filter"
- },
- "account_data": {
- "$ref": "#/definitions/filter"
- },
- "room": {
- "$ref": "#/definitions/room_filter"
- },
- "event_format": {
- "type": "string",
- "enum": ["client", "federation"]
- },
+ "presence": {"$ref": "#/definitions/filter"},
+ "account_data": {"$ref": "#/definitions/filter"},
+ "room": {"$ref": "#/definitions/room_filter"},
+ "event_format": {"type": "string", "enum": ["client", "federation"]},
"event_fields": {
"type": "array",
"items": {
@@ -177,26 +116,25 @@ USER_FILTER_SCHEMA = {
#
# Note that because this is a regular expression, we have to escape
# each backslash in the pattern.
- "pattern": r"^((?!\\\\).)*$"
- }
- }
+ "pattern": r"^((?!\\\\).)*$",
+ },
+ },
},
- "additionalProperties": False
+ "additionalProperties": False,
}
-@FormatChecker.cls_checks('matrix_room_id')
+@FormatChecker.cls_checks("matrix_room_id")
def matrix_room_id_validator(room_id_str):
return RoomID.from_string(room_id_str)
-@FormatChecker.cls_checks('matrix_user_id')
+@FormatChecker.cls_checks("matrix_user_id")
def matrix_user_id_validator(user_id_str):
return UserID.from_string(user_id_str)
class Filtering(object):
-
def __init__(self, hs):
super(Filtering, self).__init__()
self.store = hs.get_datastore()
@@ -204,7 +142,7 @@ class Filtering(object):
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id):
result = yield self.store.get_user_filter(user_localpart, filter_id)
- defer.returnValue(FilterCollection(result))
+ return FilterCollection(result)
def add_user_filter(self, user_localpart, user_filter):
self.check_valid_filter(user_filter)
@@ -228,8 +166,9 @@ class Filtering(object):
# individual top-level key e.g. public_user_data. Filters are made of
# many definitions.
try:
- jsonschema.validate(user_filter_json, USER_FILTER_SCHEMA,
- format_checker=FormatChecker())
+ jsonschema.validate(
+ user_filter_json, USER_FILTER_SCHEMA, format_checker=FormatChecker()
+ )
except jsonschema.ValidationError as e:
raise SynapseError(400, str(e))
@@ -240,10 +179,9 @@ class FilterCollection(object):
room_filter_json = self._filter_json.get("room", {})
- self._room_filter = Filter({
- k: v for k, v in room_filter_json.items()
- if k in ("rooms", "not_rooms")
- })
+ self._room_filter = Filter(
+ {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
+ )
self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
self._room_state_filter = Filter(room_filter_json.get("state", {}))
@@ -252,9 +190,7 @@ class FilterCollection(object):
self._presence_filter = Filter(filter_json.get("presence", {}))
self._account_data = Filter(filter_json.get("account_data", {}))
- self.include_leave = filter_json.get("room", {}).get(
- "include_leave", False
- )
+ self.include_leave = filter_json.get("room", {}).get("include_leave", False)
self.event_fields = filter_json.get("event_fields", [])
self.event_format = filter_json.get("event_format", "client")
@@ -299,22 +235,22 @@ class FilterCollection(object):
def blocks_all_presence(self):
return (
- self._presence_filter.filters_all_types() or
- self._presence_filter.filters_all_senders()
+ self._presence_filter.filters_all_types()
+ or self._presence_filter.filters_all_senders()
)
def blocks_all_room_ephemeral(self):
return (
- self._room_ephemeral_filter.filters_all_types() or
- self._room_ephemeral_filter.filters_all_senders() or
- self._room_ephemeral_filter.filters_all_rooms()
+ self._room_ephemeral_filter.filters_all_types()
+ or self._room_ephemeral_filter.filters_all_senders()
+ or self._room_ephemeral_filter.filters_all_rooms()
)
def blocks_all_room_timeline(self):
return (
- self._room_timeline_filter.filters_all_types() or
- self._room_timeline_filter.filters_all_senders() or
- self._room_timeline_filter.filters_all_rooms()
+ self._room_timeline_filter.filters_all_types()
+ or self._room_timeline_filter.filters_all_senders()
+ or self._room_timeline_filter.filters_all_rooms()
)
@@ -333,6 +269,9 @@ class Filter(object):
self.contains_url = self.filter_json.get("contains_url", None)
+ self.labels = self.filter_json.get("org.matrix.labels", None)
+ self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
+
def filters_all_types(self):
return "*" in self.not_types
@@ -356,6 +295,7 @@ class Filter(object):
room_id = None
ev_type = "m.presence"
contains_url = False
+ labels = [] # type: List[str]
else:
sender = event.get("sender", None)
if not sender:
@@ -374,15 +314,11 @@ class Filter(object):
content = event.get("content", {})
# check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), text_type)
+ labels = content.get(EventContentFields.LABELS, [])
- return self.check_fields(
- room_id,
- sender,
- ev_type,
- contains_url,
- )
+ return self.check_fields(room_id, sender, ev_type, labels, contains_url)
- def check_fields(self, room_id, sender, event_type, contains_url):
+ def check_fields(self, room_id, sender, event_type, labels, contains_url):
"""Checks whether the filter matches the given event fields.
Returns:
@@ -391,7 +327,8 @@ class Filter(object):
literal_keys = {
"rooms": lambda v: room_id == v,
"senders": lambda v: sender == v,
- "types": lambda v: _matches_wildcard(event_type, v)
+ "types": lambda v: _matches_wildcard(event_type, v),
+ "labels": lambda v: v in labels,
}
for name, match_func in literal_keys.items():
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 296c4a1c17..7a049b3af7 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
+from collections import OrderedDict
+from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError
@@ -23,7 +24,9 @@ class Ratelimiter(object):
"""
def __init__(self):
- self.message_counts = collections.OrderedDict()
+ self.message_counts = (
+ OrderedDict()
+ ) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]]
def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
"""Can the entity (e.g. user or IP address) perform the action?
@@ -44,29 +47,25 @@ class Ratelimiter(object):
"""
self.prune_message_counts(time_now_s)
message_count, time_start, _ignored = self.message_counts.get(
- key, (0., time_now_s, None),
+ key, (0.0, time_now_s, None)
)
time_delta = time_now_s - time_start
sent_count = message_count - time_delta * rate_hz
if sent_count < 0:
allowed = True
time_start = time_now_s
- message_count = 1.
- elif sent_count > burst_count - 1.:
+ message_count = 1.0
+ elif sent_count > burst_count - 1.0:
allowed = False
else:
allowed = True
message_count += 1
if update:
- self.message_counts[key] = (
- message_count, time_start, rate_hz
- )
+ self.message_counts[key] = (message_count, time_start, rate_hz)
if rate_hz > 0:
- time_allowed = (
- time_start + (message_count - burst_count + 1) / rate_hz
- )
+ time_allowed = time_start + (message_count - burst_count + 1) / rate_hz
if time_allowed < time_now_s:
time_allowed = time_now_s
else:
@@ -76,9 +75,7 @@ class Ratelimiter(object):
def prune_message_counts(self, time_now_s):
for key in list(self.message_counts.keys()):
- message_count, time_start, rate_hz = (
- self.message_counts[key]
- )
+ message_count, time_start, rate_hz = self.message_counts[key]
time_delta = time_now_s - time_start
if message_count - time_delta * rate_hz > 0:
break
@@ -92,5 +89,5 @@ class Ratelimiter(object):
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now_s)),
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
)
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index d644803d38..871179749a 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from typing import Dict
+
import attr
@@ -19,9 +22,10 @@ class EventFormatVersions(object):
"""This is an internal enum for tracking the version of the event format,
independently from the room version.
"""
- V1 = 1 # $id:server event id format
- V2 = 2 # MSC1659-style $hash event id format: introduced for room v3
- V3 = 3 # MSC1884-style $hash format: introduced for room v4
+
+ V1 = 1 # $id:server event id format
+ V2 = 2 # MSC1659-style $hash event id format: introduced for room v3
+ V3 = 3 # MSC1884-style $hash format: introduced for room v4
KNOWN_EVENT_FORMAT_VERSIONS = {
@@ -33,8 +37,9 @@ KNOWN_EVENT_FORMAT_VERSIONS = {
class StateResolutionVersions(object):
"""Enum to identify the state resolution algorithms"""
- V1 = 1 # room v1 state res
- V2 = 2 # MSC1442 state res: room v2 and later
+
+ V1 = 1 # room v1 state res
+ V2 = 2 # MSC1442 state res: room v2 and later
class RoomDisposition(object):
@@ -46,12 +51,15 @@ class RoomDisposition(object):
class RoomVersion(object):
"""An object which describes the unique attributes of a room version."""
- identifier = attr.ib() # str; the identifier for this version
- disposition = attr.ib() # str; one of the RoomDispositions
- event_format = attr.ib() # int; one of the EventFormatVersions
- state_res = attr.ib() # int; one of the StateResolutionVersions
+ identifier = attr.ib() # str; the identifier for this version
+ disposition = attr.ib() # str; one of the RoomDispositions
+ event_format = attr.ib() # int; one of the EventFormatVersions
+ state_res = attr.ib() # int; one of the StateResolutionVersions
enforce_key_validity = attr.ib() # bool
+ # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
+ special_case_aliases_auth = attr.ib(type=bool, default=False)
+
class RoomVersions(object):
V1 = RoomVersion(
@@ -60,6 +68,7 @@ class RoomVersions(object):
EventFormatVersions.V1,
StateResolutionVersions.V1,
enforce_key_validity=False,
+ special_case_aliases_auth=True,
)
V2 = RoomVersion(
"2",
@@ -67,6 +76,7 @@ class RoomVersions(object):
EventFormatVersions.V1,
StateResolutionVersions.V2,
enforce_key_validity=False,
+ special_case_aliases_auth=True,
)
V3 = RoomVersion(
"3",
@@ -74,6 +84,7 @@ class RoomVersions(object):
EventFormatVersions.V2,
StateResolutionVersions.V2,
enforce_key_validity=False,
+ special_case_aliases_auth=True,
)
V4 = RoomVersion(
"4",
@@ -81,6 +92,7 @@ class RoomVersions(object):
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=False,
+ special_case_aliases_auth=True,
)
V5 = RoomVersion(
"5",
@@ -88,15 +100,26 @@ class RoomVersions(object):
EventFormatVersions.V3,
StateResolutionVersions.V2,
enforce_key_validity=True,
+ special_case_aliases_auth=True,
+ )
+ MSC2432_DEV = RoomVersion(
+ "org.matrix.msc2432",
+ RoomDisposition.UNSTABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
+ special_case_aliases_auth=False,
)
KNOWN_ROOM_VERSIONS = {
- v.identifier: v for v in (
+ v.identifier: v
+ for v in (
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.V3,
RoomVersions.V4,
RoomVersions.V5,
+ RoomVersions.MSC2432_DEV,
)
-} # type: dict[str, RoomVersion]
+} # type: Dict[str, RoomVersion]
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index e16c386a14..f34434bd67 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -29,7 +29,6 @@ FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
-CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
@@ -42,13 +41,9 @@ class ConsentURIBuilder(object):
hs_config (synapse.config.homeserver.HomeServerConfig):
"""
if hs_config.form_secret is None:
- raise ConfigError(
- "form_secret not set in config",
- )
+ raise ConfigError("form_secret not set in config")
if hs_config.public_baseurl is None:
- raise ConfigError(
- "public_baseurl not set in config",
- )
+ raise ConfigError("public_baseurl not set in config")
self._hmac_secret = hs_config.form_secret.encode("utf-8")
self._public_baseurl = hs_config.public_baseurl
@@ -64,15 +59,10 @@ class ConsentURIBuilder(object):
(str) the URI where the user can do consent
"""
mac = hmac.new(
- key=self._hmac_secret,
- msg=user_id.encode('ascii'),
- digestmod=sha256,
+ key=self._hmac_secret, msg=user_id.encode("ascii"), digestmod=sha256
).hexdigest()
consent_uri = "%s_matrix/consent?%s" % (
self._public_baseurl,
- urlencode({
- "u": user_id,
- "h": mac
- }),
+ urlencode({"u": user_id, "h": mac}),
)
return consent_uri
diff --git a/synapse/app/__init__.py b/synapse/app/__init__.py
index f56f5fcc13..a01bac2997 100644
--- a/synapse/app/__init__.py
+++ b/synapse/app/__init__.py
@@ -43,7 +43,9 @@ def check_bind_error(e, address, bind_addresses):
address (str): Address on which binding was attempted.
bind_addresses (list): Addresses on which the service listens.
"""
- if address == '0.0.0.0' and '::' in bind_addresses:
- logger.warn('Failed to listen on 0.0.0.0, continuing because listening on [::]')
+ if address == "0.0.0.0" and "::" in bind_addresses:
+ logger.warning(
+ "Failed to listen on 0.0.0.0, continuing because listening on [::]"
+ )
else:
raise e
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 8cc990399f..4d84f4595a 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -15,11 +15,12 @@
import gc
import logging
+import os
import signal
+import socket
import sys
import traceback
-import psutil
from daemonize import Daemonize
from twisted.internet import defer, error, reactor
@@ -28,28 +29,30 @@ from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
from synapse.crypto import context_factory
-from synapse.util import PreserveLoggingContext
+from synapse.logging.context import PreserveLoggingContext
from synapse.util.async_helpers import Linearizer
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
+# list of tuples of function, args list, kwargs dict
_sighup_callbacks = []
-def register_sighup(func):
+def register_sighup(func, *args, **kwargs):
"""
Register a function to be called when a SIGHUP occurs.
Args:
func (function): Function to be called when sent a SIGHUP signal.
- Will be called with a single argument, the homeserver.
+ Will be called with a single default argument, the homeserver.
+ *args, **kwargs: args and kwargs to be passed to the target function.
"""
- _sighup_callbacks.append(func)
+ _sighup_callbacks.append((func, args, kwargs))
-def start_worker_reactor(appname, config):
+def start_worker_reactor(appname, config, run_command=reactor.run):
""" Run the reactor in the main process
Daemonizes if necessary, and then configures some resources, before starting
@@ -58,6 +61,7 @@ def start_worker_reactor(appname, config):
Args:
appname (str): application name which will be sent to syslog
config (synapse.config.Config): config object
+ run_command (Callable[]): callable that actually runs the reactor
"""
logger = logging.getLogger(config.worker_app)
@@ -68,21 +72,21 @@ def start_worker_reactor(appname, config):
gc_thresholds=config.gc_thresholds,
pid_file=config.worker_pid_file,
daemonize=config.worker_daemonize,
- cpu_affinity=config.worker_cpu_affinity,
print_pidfile=config.print_pidfile,
logger=logger,
+ run_command=run_command,
)
def start_reactor(
- appname,
- soft_file_limit,
- gc_thresholds,
- pid_file,
- daemonize,
- cpu_affinity,
- print_pidfile,
- logger,
+ appname,
+ soft_file_limit,
+ gc_thresholds,
+ pid_file,
+ daemonize,
+ print_pidfile,
+ logger,
+ run_command=reactor.run,
):
""" Run the reactor in the main process
@@ -95,64 +99,53 @@ def start_reactor(
gc_thresholds:
pid_file (str): name of pid file to write to if daemonize is True
daemonize (bool): true to run the reactor in a background process
- cpu_affinity (int|None): cpu affinity mask
print_pidfile (bool): whether to print the pid file, if daemonize is True
logger (logging.Logger): logger instance to pass to Daemonize
+ run_command (Callable[]): callable that actually runs the reactor
"""
install_dns_limiter(reactor)
def run():
- # make sure that we run the reactor with the sentinel log context,
- # otherwise other PreserveLoggingContext instances will get confused
- # and complain when they see the logcontext arbitrarily swapping
- # between the sentinel and `run` logcontexts.
- with PreserveLoggingContext():
- logger.info("Running")
- if cpu_affinity is not None:
- # Turn the bitmask into bits, reverse it so we go from 0 up
- mask_to_bits = bin(cpu_affinity)[2:][::-1]
-
- cpus = []
- cpu_num = 0
-
- for i in mask_to_bits:
- if i == "1":
- cpus.append(cpu_num)
- cpu_num += 1
-
- p = psutil.Process()
- p.cpu_affinity(cpus)
-
- change_resource_limit(soft_file_limit)
- if gc_thresholds:
- gc.set_threshold(*gc_thresholds)
- reactor.run()
-
- if daemonize:
- if print_pidfile:
- print(pid_file)
-
- daemon = Daemonize(
- app=appname,
- pid=pid_file,
- action=run,
- auto_close_fds=False,
- verbose=True,
- logger=logger,
- )
- daemon.start()
- else:
- run()
+ logger.info("Running")
+ change_resource_limit(soft_file_limit)
+ if gc_thresholds:
+ gc.set_threshold(*gc_thresholds)
+ run_command()
+
+ # make sure that we run the reactor with the sentinel log context,
+ # otherwise other PreserveLoggingContext instances will get confused
+ # and complain when they see the logcontext arbitrarily swapping
+ # between the sentinel and `run` logcontexts.
+ #
+ # We also need to drop the logcontext before forking if we're daemonizing,
+ # otherwise the cputime metrics get confused about the per-thread resource usage
+ # appearing to go backwards.
+ with PreserveLoggingContext():
+ if daemonize:
+ if print_pidfile:
+ print(pid_file)
+
+ daemon = Daemonize(
+ app=appname,
+ pid=pid_file,
+ action=run,
+ auto_close_fds=False,
+ verbose=True,
+ logger=logger,
+ )
+ daemon.start()
+ else:
+ run()
def quit_with_error(error_string):
message_lines = error_string.split("\n")
- line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
- sys.stderr.write("*" * line_length + '\n')
+ line_length = max(len(l) for l in message_lines if len(l) < 80) + 2
+ sys.stderr.write("*" * line_length + "\n")
for line in message_lines:
sys.stderr.write(" %s\n" % (line.rstrip(),))
- sys.stderr.write("*" * line_length + '\n')
+ sys.stderr.write("*" * line_length + "\n")
sys.exit(1)
@@ -160,8 +153,7 @@ def listen_metrics(bind_addresses, port):
"""
Start Prometheus metrics server.
"""
- from synapse.metrics import RegistryProxy
- from prometheus_client import start_http_server
+ from synapse.metrics import RegistryProxy, start_http_server
for host in bind_addresses:
logger.info("Starting metrics listener on %s:%d", host, port)
@@ -178,14 +170,7 @@ def listen_tcp(bind_addresses, port, factory, reactor=reactor, backlog=50):
r = []
for address in bind_addresses:
try:
- r.append(
- reactor.listenTCP(
- port,
- factory,
- backlog,
- address
- )
- )
+ r.append(reactor.listenTCP(port, factory, backlog, address))
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
@@ -205,13 +190,7 @@ def listen_ssl(
for address in bind_addresses:
try:
r.append(
- reactor.listenSSL(
- port,
- factory,
- context_factory,
- backlog,
- address
- )
+ reactor.listenSSL(port, factory, context_factory, backlog, address)
)
except error.CannotListenError as e:
check_bind_error(e, address, bind_addresses)
@@ -243,15 +222,13 @@ def refresh_certificate(hs):
if isinstance(i.factory, TLSMemoryBIOFactory):
addr = i.getHost()
logger.info(
- "Replacing TLS context factory on [%s]:%i", addr.host, addr.port,
+ "Replacing TLS context factory on [%s]:%i", addr.host, addr.port
)
# We want to replace TLS factories with a new one, with the new
# TLS configuration. We do this by reaching in and pulling out
# the wrappedFactory, and then re-wrapping it.
i.factory = TLSMemoryBIOFactory(
- hs.tls_server_context_factory,
- False,
- i.factory.wrappedFactory
+ hs.tls_server_context_factory, False, i.factory.wrappedFactory
)
logger.info("Context factories updated.")
@@ -260,6 +237,12 @@ def start(hs, listeners=None):
"""
Start a Synapse server or worker.
+ Should be called once the reactor is running and (if we're using ACME) the
+ TLS certificates are in place.
+
+ Will start the main HTTP listeners and do some other startup tasks, and then
+ notify systemd.
+
Args:
hs (synapse.server.HomeServer)
listeners (list[dict]): Listener configuration ('listeners' in homeserver.yaml)
@@ -267,9 +250,16 @@ def start(hs, listeners=None):
try:
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
+
def handle_sighup(*args, **kwargs):
- for i in _sighup_callbacks:
- i(hs)
+ # Tell systemd our state, if we're using it. This will silently fail if
+ # we're not using systemd.
+ sdnotify(b"RELOADING=1")
+
+ for i, args, kwargs in _sighup_callbacks:
+ i(hs, *args, **kwargs)
+
+ sdnotify(b"READY=1")
signal.signal(signal.SIGHUP, handle_sighup)
@@ -278,11 +268,27 @@ def start(hs, listeners=None):
# Load the certificate from disk.
refresh_certificate(hs)
+ # Start the tracer
+ synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
+ hs.config
+ )
+
# It is now safe to start your Synapse.
hs.start_listening(listeners)
- hs.get_datastore().start_profiling()
+ hs.get_datastore().db.start_profiling()
+ hs.get_pusherpool().start()
setup_sentry(hs)
+ setup_sdnotify(hs)
+
+ # We now freeze all allocated objects in the hopes that (almost)
+ # everything currently allocated are things that will be used for the
+ # rest of time. Doing so means less work each GC (hopefully).
+ #
+ # This only works on Python 3.7
+ if sys.version_info >= (3, 7):
+ gc.collect()
+ gc.freeze()
except Exception:
traceback.print_exc(file=sys.stderr)
reactor = hs.get_reactor()
@@ -302,10 +308,8 @@ def setup_sentry(hs):
return
import sentry_sdk
- sentry_sdk.init(
- dsn=hs.config.sentry_dsn,
- release=get_version_string(synapse),
- )
+
+ sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse))
# We set some default tags that give some context to this instance
with sentry_sdk.configure_scope() as scope:
@@ -317,6 +321,19 @@ def setup_sentry(hs):
scope.set_tag("worker_name", name)
+def setup_sdnotify(hs):
+ """Adds process state hooks to tell systemd what we are up to.
+ """
+
+ # Tell systemd our state, if we're using it. This will silently fail if
+ # we're not using systemd.
+ sdnotify(b"READY=1\nMAINPID=%i" % (os.getpid(),))
+
+ hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", sdnotify, b"STOPPING=1"
+ )
+
+
def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
"""Replaces the resolver with one that limits the number of in flight DNS
requests.
@@ -326,7 +343,7 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
many DNS queries at once
"""
new_resolver = _LimitedHostnameResolver(
- reactor.nameResolver, max_dns_requests_in_flight,
+ reactor.nameResolver, max_dns_requests_in_flight
)
reactor.installNameResolver(new_resolver)
@@ -339,11 +356,17 @@ class _LimitedHostnameResolver(object):
def __init__(self, resolver, max_dns_requests_in_flight):
self._resolver = resolver
self._limiter = Linearizer(
- name="dns_client_limiter", max_count=max_dns_requests_in_flight,
+ name="dns_client_limiter", max_count=max_dns_requests_in_flight
)
- def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
- addressTypes=None, transportSemantics='TCP'):
+ def resolveHostName(
+ self,
+ resolutionReceiver,
+ hostName,
+ portNumber=0,
+ addressTypes=None,
+ transportSemantics="TCP",
+ ):
# We need this function to return `resolutionReceiver` so we do all the
# actual logic involving deferreds in a separate function.
@@ -363,8 +386,14 @@ class _LimitedHostnameResolver(object):
return resolutionReceiver
@defer.inlineCallbacks
- def _resolve(self, resolutionReceiver, hostName, portNumber=0,
- addressTypes=None, transportSemantics='TCP'):
+ def _resolve(
+ self,
+ resolutionReceiver,
+ hostName,
+ portNumber=0,
+ addressTypes=None,
+ transportSemantics="TCP",
+ ):
with (yield self._limiter.queue(())):
# resolveHostName doesn't return a Deferred, so we need to hook into
@@ -374,8 +403,7 @@ class _LimitedHostnameResolver(object):
receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred)
self._resolver.resolveHostName(
- receiver, hostName, portNumber,
- addressTypes, transportSemantics,
+ receiver, hostName, portNumber, addressTypes, transportSemantics
)
yield deferred
@@ -399,3 +427,35 @@ class _DeferredResolutionReceiver(object):
def resolutionComplete(self):
self._deferred.callback(())
self._receiver.resolutionComplete()
+
+
+sdnotify_sockaddr = os.getenv("NOTIFY_SOCKET")
+
+
+def sdnotify(state):
+ """
+ Send a notification to systemd, if the NOTIFY_SOCKET env var is set.
+
+ This function is based on the sdnotify python package, but since it's only a few
+ lines of code, it's easier to duplicate it here than to add a dependency on a
+ package which many OSes don't include as a matter of principle.
+
+ Args:
+ state (bytes): notification to send
+ """
+ if not isinstance(state, bytes):
+ raise TypeError("sdnotify should be called with a bytes")
+ if not sdnotify_sockaddr:
+ return
+ addr = sdnotify_sockaddr
+ if addr[0] == "@":
+ addr = "\0" + addr[1:]
+
+ try:
+ with socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock:
+ sock.connect(addr)
+ sock.sendall(state)
+ except Exception as e:
+ # this is a bit surprising, since we don't expect to have a NOTIFY_SOCKET
+ # unless systemd is expecting us to notify it.
+ logger.warning("Unable to send notification to systemd: %s", e)
diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py
new file mode 100644
index 0000000000..1c7c6ec0c8
--- /dev/null
+++ b/synapse/app/admin_cmd.py
@@ -0,0 +1,260 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2019 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import logging
+import os
+import sys
+import tempfile
+
+from canonicaljson import json
+
+from twisted.internet import defer, task
+
+import synapse
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.handlers.admin import ExfiltrationWriter
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.filtering import SlavedFilteringStore
+from synapse.replication.slave.storage.groups import SlavedGroupServerStore
+from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.server import HomeServer
+from synapse.util.logcontext import LoggingContext
+from synapse.util.versionstring import get_version_string
+
+logger = logging.getLogger("synapse.app.admin_cmd")
+
+
+class AdminCmdSlavedStore(
+ SlavedReceiptsStore,
+ SlavedAccountDataStore,
+ SlavedApplicationServiceStore,
+ SlavedRegistrationStore,
+ SlavedFilteringStore,
+ SlavedPresenceStore,
+ SlavedGroupServerStore,
+ SlavedDeviceInboxStore,
+ SlavedDeviceStore,
+ SlavedPushRuleStore,
+ SlavedEventStore,
+ SlavedClientIpStore,
+ RoomStore,
+ BaseSlavedStore,
+):
+ pass
+
+
+class AdminCmdServer(HomeServer):
+ DATASTORE_CLASS = AdminCmdSlavedStore
+
+ def _listen_http(self, listener_config):
+ pass
+
+ def start_listening(self, listeners):
+ pass
+
+ def build_tcp_replication(self):
+ return AdminCmdReplicationHandler(self)
+
+
+class AdminCmdReplicationHandler(ReplicationClientHandler):
+ async def on_rdata(self, stream_name, token, rows):
+ pass
+
+ def get_streams_to_replicate(self):
+ return {}
+
+
+@defer.inlineCallbacks
+def export_data_command(hs, args):
+ """Export data for a user.
+
+ Args:
+ hs (HomeServer)
+ args (argparse.Namespace)
+ """
+
+ user_id = args.user_id
+ directory = args.output_directory
+
+ res = yield defer.ensureDeferred(
+ hs.get_handlers().admin_handler.export_user_data(
+ user_id, FileExfiltrationWriter(user_id, directory=directory)
+ )
+ )
+ print(res)
+
+
+class FileExfiltrationWriter(ExfiltrationWriter):
+ """An ExfiltrationWriter that writes the users data to a directory.
+ Returns the directory location on completion.
+
+ Note: This writes to disk on the main reactor thread.
+
+ Args:
+ user_id (str): The user whose data is being exfiltrated.
+ directory (str|None): The directory to write the data to, if None then
+ will write to a temporary directory.
+ """
+
+ def __init__(self, user_id, directory=None):
+ self.user_id = user_id
+
+ if directory:
+ self.base_directory = directory
+ else:
+ self.base_directory = tempfile.mkdtemp(
+ prefix="synapse-exfiltrate__%s__" % (user_id,)
+ )
+
+ os.makedirs(self.base_directory, exist_ok=True)
+ if list(os.listdir(self.base_directory)):
+ raise Exception("Directory must be empty")
+
+ def write_events(self, room_id, events):
+ room_directory = os.path.join(self.base_directory, "rooms", room_id)
+ os.makedirs(room_directory, exist_ok=True)
+ events_file = os.path.join(room_directory, "events")
+
+ with open(events_file, "a") as f:
+ for event in events:
+ print(json.dumps(event.get_pdu_json()), file=f)
+
+ def write_state(self, room_id, event_id, state):
+ room_directory = os.path.join(self.base_directory, "rooms", room_id)
+ state_directory = os.path.join(room_directory, "state")
+ os.makedirs(state_directory, exist_ok=True)
+
+ event_file = os.path.join(state_directory, event_id)
+
+ with open(event_file, "a") as f:
+ for event in state.values():
+ print(json.dumps(event.get_pdu_json()), file=f)
+
+ def write_invite(self, room_id, event, state):
+ self.write_events(room_id, [event])
+
+ # We write the invite state somewhere else as they aren't full events
+ # and are only a subset of the state at the event.
+ room_directory = os.path.join(self.base_directory, "rooms", room_id)
+ os.makedirs(room_directory, exist_ok=True)
+
+ invite_state = os.path.join(room_directory, "invite_state")
+
+ with open(invite_state, "a") as f:
+ for event in state.values():
+ print(json.dumps(event), file=f)
+
+ def finished(self):
+ return self.base_directory
+
+
+def start(config_options):
+ parser = argparse.ArgumentParser(description="Synapse Admin Command")
+ HomeServerConfig.add_arguments_to_parser(parser)
+
+ subparser = parser.add_subparsers(
+ title="Admin Commands",
+ required=True,
+ dest="command",
+ metavar="<admin_command>",
+ help="The admin command to perform.",
+ )
+ export_data_parser = subparser.add_parser(
+ "export-data", help="Export all data for a user"
+ )
+ export_data_parser.add_argument("user_id", help="User to extra data from")
+ export_data_parser.add_argument(
+ "--output-directory",
+ action="store",
+ metavar="DIRECTORY",
+ required=False,
+ help="The directory to store the exported data in. Must be empty. Defaults"
+ " to creating a temp directory.",
+ )
+ export_data_parser.set_defaults(func=export_data_command)
+
+ try:
+ config, args = HomeServerConfig.load_config_with_parser(parser, config_options)
+ except ConfigError as e:
+ sys.stderr.write("\n" + str(e) + "\n")
+ sys.exit(1)
+
+ if config.worker_app is not None:
+ assert config.worker_app == "synapse.app.admin_cmd"
+
+ # Update the config with some basic overrides so that don't have to specify
+ # a full worker config.
+ config.worker_app = "synapse.app.admin_cmd"
+
+ if (
+ not config.worker_daemonize
+ and not config.worker_log_file
+ and not config.worker_log_config
+ ):
+ # Since we're meant to be run as a "command" let's not redirect stdio
+ # unless we've actually set log config.
+ config.no_redirect_stdio = True
+
+ # Explicitly disable background processes
+ config.update_user_directory = False
+ config.start_pushers = False
+ config.send_federation = False
+
+ synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ ss = AdminCmdServer(
+ config.server_name,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ )
+
+ setup_logging(ss, config, use_worker_options=True)
+
+ ss.setup()
+
+ # We use task.react as the basic run command as it correctly handles tearing
+ # down the reactor when the deferreds resolve and setting the return value.
+ # We also make sure that `_base.start` gets run before we actually run the
+ # command.
+
+ @defer.inlineCallbacks
+ def run(_reactor):
+ with LoggingContext("command"):
+ yield _base.start(ss, [])
+ yield args.func(ss, args)
+
+ _base.start_worker_reactor(
+ "synapse-admin-cmd", config, run_command=lambda: task.react(run)
+ )
+
+
+if __name__ == "__main__":
+ with LoggingContext("main"):
+ start(sys.argv[1:])
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 33107f56d1..add43147b3 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -13,166 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, run_in_background
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.appservice")
-
-
-class AppserviceSlaveStore(
- DirectoryStore, SlavedEventStore, SlavedApplicationServiceStore,
- SlavedRegistrationStore,
-):
- pass
-
-
-class AppserviceServer(HomeServer):
- DATASTORE_CLASS = AppserviceSlaveStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- )
- )
-
- logger.info("Synapse appservice now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
- else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
- else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ASReplicationHandler(self)
-
-class ASReplicationHandler(ReplicationClientHandler):
- def __init__(self, hs):
- super(ASReplicationHandler, self).__init__(hs.get_datastore())
- self.appservice_handler = hs.get_application_service_handler()
-
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
- yield super(ASReplicationHandler, self).on_rdata(stream_name, token, rows)
-
- if stream_name == "events":
- max_stream_id = self.store.get_room_max_stream_ordering()
- run_in_background(self._notify_app_services, max_stream_id)
-
- @defer.inlineCallbacks
- def _notify_app_services(self, room_stream_id):
- try:
- yield self.appservice_handler.notify_interested_services(room_stream_id)
- except Exception:
- logger.exception("Error notifying application services of event")
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse appservice", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.appservice"
-
- setup_logging(config, use_worker_options=True)
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- database_engine = create_engine(config.database_config)
-
- if config.notify_appservices:
- sys.stderr.write(
- "\nThe appservices must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``notify_appservices: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.notify_appservices = True
-
- ps = AppserviceServer(
- config.server_name,
- db_config=config.database_config,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
- )
-
- ps.setup()
- reactor.callWhenRunning(_base.start, ps, config.worker_listeners)
-
- _base.start_worker_reactor("synapse-appservice", config)
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index cd49fc5cd3..add43147b3 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -13,194 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-from twisted.internet import reactor
-from twisted.web.resource import NoResource
+import sys
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.server import JsonResource
-from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.profile import SlavedProfileStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.slave.storage.user_directory import SlavedUserDirectoryStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.client.v1.login import LoginRestServlet
-from synapse.rest.client.v1.push_rule import PushRuleRestServlet
-from synapse.rest.client.v1.room import (
- JoinedRoomMemberListRestServlet,
- PublicRoomListRestServlet,
- RoomEventContextServlet,
- RoomMemberListRestServlet,
- RoomStateRestServlet,
-)
-from synapse.rest.client.v1.voip import VoipRestServlet
-from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
-from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
-from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-from synapse.rest.client.versions import VersionsRestServlet
-from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
-from synapse.util.httpresourcetree import create_resource_tree
+from synapse.app.generic_worker import start
from synapse.util.logcontext import LoggingContext
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.client_reader")
-
-
-class ClientReaderSlavedStore(
- SlavedDeviceInboxStore,
- SlavedDeviceStore,
- SlavedReceiptsStore,
- SlavedPushRuleStore,
- SlavedAccountDataStore,
- SlavedEventStore,
- SlavedKeyStore,
- RoomStore,
- DirectoryStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- SlavedTransactionStore,
- SlavedProfileStore,
- SlavedClientIpStore,
- SlavedUserDirectoryStore,
- BaseSlavedStore,
-):
- pass
-
-
-class ClientReaderServer(HomeServer):
- DATASTORE_CLASS = ClientReaderSlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "client":
- resource = JsonResource(self, canonical_json=False)
-
- PublicRoomListRestServlet(self).register(resource)
- RoomMemberListRestServlet(self).register(resource)
- JoinedRoomMemberListRestServlet(self).register(resource)
- RoomStateRestServlet(self).register(resource)
- RoomEventContextServlet(self).register(resource)
- RegisterRestServlet(self).register(resource)
- LoginRestServlet(self).register(resource)
- ThreepidRestServlet(self).register(resource)
- KeyQueryServlet(self).register(resource)
- KeyChangesServlet(self).register(resource)
- VoipRestServlet(self).register(resource)
- PushRuleRestServlet(self).register(resource)
- VersionsRestServlet().register(resource)
-
- resources.update({
- "/_matrix/client": resource,
- })
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- )
- )
-
- logger.info("Synapse client reader now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
- else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
- else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ReplicationClientHandler(self.get_datastore())
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse client reader", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.client_reader"
-
- setup_logging(config, use_worker_options=True)
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- database_engine = create_engine(config.database_config)
-
- ss = ClientReaderServer(
- config.server_name,
- db_config=config.database_config,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
- )
-
- ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
-
- _base.start_worker_reactor("synapse-client-reader", config)
-
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index b8e5196152..e9c098c4e7 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -13,191 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-from twisted.internet import reactor
-from twisted.web.resource import NoResource
+import sys
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.server import JsonResource
-from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.profile import SlavedProfileStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.pushers import SlavedPusherStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.client.v1.profile import (
- ProfileAvatarURLRestServlet,
- ProfileDisplaynameRestServlet,
- ProfileRestServlet,
-)
-from synapse.rest.client.v1.room import (
- JoinRoomAliasServlet,
- RoomMembershipRestServlet,
- RoomSendEventRestServlet,
- RoomStateEventRestServlet,
-)
-from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
-from synapse.storage.user_directory import UserDirectoryStore
-from synapse.util.httpresourcetree import create_resource_tree
+from synapse.app.generic_worker import start
from synapse.util.logcontext import LoggingContext
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.event_creator")
-
-
-class EventCreatorSlavedStore(
- # FIXME(#3714): We need to add UserDirectoryStore as we write directly
- # rather than going via the correct worker.
- UserDirectoryStore,
- DirectoryStore,
- SlavedTransactionStore,
- SlavedProfileStore,
- SlavedAccountDataStore,
- SlavedPusherStore,
- SlavedReceiptsStore,
- SlavedPushRuleStore,
- SlavedDeviceStore,
- SlavedClientIpStore,
- SlavedApplicationServiceStore,
- SlavedEventStore,
- SlavedRegistrationStore,
- RoomStore,
- BaseSlavedStore,
-):
- pass
-
-
-class EventCreatorServer(HomeServer):
- DATASTORE_CLASS = EventCreatorSlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "client":
- resource = JsonResource(self, canonical_json=False)
- RoomSendEventRestServlet(self).register(resource)
- RoomMembershipRestServlet(self).register(resource)
- RoomStateEventRestServlet(self).register(resource)
- JoinRoomAliasServlet(self).register(resource)
- ProfileAvatarURLRestServlet(self).register(resource)
- ProfileDisplaynameRestServlet(self).register(resource)
- ProfileRestServlet(self).register(resource)
- resources.update({
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- })
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- )
- )
-
- logger.info("Synapse event creator now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
- else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
- else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ReplicationClientHandler(self.get_datastore())
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse event creator", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.event_creator"
-
- assert config.worker_replication_http_port is not None
-
- setup_logging(config, use_worker_options=True)
-
- # This should only be done on the user directory worker or the master
- config.update_user_directory = False
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- database_engine = create_engine(config.database_config)
-
- ss = EventCreatorServer(
- config.server_name,
- db_config=config.database_config,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
- )
-
- ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
-
- _base.start_worker_reactor("synapse-event-creator", config)
-
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 7da79dc827..add43147b3 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -13,174 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-from twisted.internet import reactor
-from twisted.web.resource import NoResource
+import sys
-import synapse
-from synapse import events
-from synapse.api.urls import FEDERATION_PREFIX, SERVER_KEY_V2_PREFIX
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.federation.transport.server import TransportLayerServer
-from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.directory import DirectoryStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.keys import SlavedKeyStore
-from synapse.replication.slave.storage.profile import SlavedProfileStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.pushers import SlavedPusherStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
-from synapse.util.httpresourcetree import create_resource_tree
+from synapse.app.generic_worker import start
from synapse.util.logcontext import LoggingContext
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.federation_reader")
-
-
-class FederationReaderSlavedStore(
- SlavedAccountDataStore,
- SlavedProfileStore,
- SlavedApplicationServiceStore,
- SlavedPusherStore,
- SlavedPushRuleStore,
- SlavedReceiptsStore,
- SlavedEventStore,
- SlavedKeyStore,
- SlavedRegistrationStore,
- RoomStore,
- DirectoryStore,
- SlavedTransactionStore,
- BaseSlavedStore,
-):
- pass
-
-
-class FederationReaderServer(HomeServer):
- DATASTORE_CLASS = FederationReaderSlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "federation":
- resources.update({
- FEDERATION_PREFIX: TransportLayerServer(self),
- })
- if name == "openid" and "federation" not in res["names"]:
- # Only load the openid resource separately if federation resource
- # is not specified since federation resource includes openid
- # resource.
- resources.update({
- FEDERATION_PREFIX: TransportLayerServer(
- self,
- servlet_groups=["openid"],
- ),
- })
-
- if name in ["keys", "federation"]:
- resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- reactor=self.get_reactor()
- )
-
- logger.info("Synapse federation reader now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
- else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
- else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ReplicationClientHandler(self.get_datastore())
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse federation reader", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.federation_reader"
-
- setup_logging(config, use_worker_options=True)
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- database_engine = create_engine(config.database_config)
-
- ss = FederationReaderServer(
- config.server_name,
- db_config=config.database_config,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
- )
-
- ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
-
- _base.start_worker_reactor("synapse-federation-reader", config)
-
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 1d43f2b075..add43147b3 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -13,277 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.federation import send_queue
-from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams._base import ReceiptsStream
-from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
-from synapse.types import ReadReceipt
-from synapse.util.async_helpers import Linearizer
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, run_in_background
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.federation_sender")
-
-
-class FederationSenderSlaveStore(
- SlavedDeviceInboxStore, SlavedTransactionStore, SlavedReceiptsStore, SlavedEventStore,
- SlavedRegistrationStore, SlavedDeviceStore, SlavedPresenceStore,
-):
- def __init__(self, db_conn, hs):
- super(FederationSenderSlaveStore, self).__init__(db_conn, hs)
-
- # We pull out the current federation stream position now so that we
- # always have a known value for the federation position in memory so
- # that we don't have to bounce via a deferred once when we start the
- # replication streams.
- self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
-
- def _get_federation_out_pos(self, db_conn):
- sql = (
- "SELECT stream_id FROM federation_stream_position"
- " WHERE type = ?"
- )
- sql = self.database_engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
- txn.execute(sql, ("federation",))
- rows = txn.fetchall()
- txn.close()
-
- return rows[0][0] if rows else -1
-
-
-class FederationSenderServer(HomeServer):
- DATASTORE_CLASS = FederationSenderSlaveStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- )
- )
-
- logger.info("Synapse federation_sender now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
- else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
- else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return FederationSenderReplicationHandler(self)
-
-
-class FederationSenderReplicationHandler(ReplicationClientHandler):
- def __init__(self, hs):
- super(FederationSenderReplicationHandler, self).__init__(hs.get_datastore())
- self.send_handler = FederationSenderHandler(hs, self)
-
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
- yield super(FederationSenderReplicationHandler, self).on_rdata(
- stream_name, token, rows
- )
- self.send_handler.process_replication_rows(stream_name, token, rows)
-
- def get_streams_to_replicate(self):
- args = super(FederationSenderReplicationHandler, self).get_streams_to_replicate()
- args.update(self.send_handler.stream_positions())
- return args
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse federation sender", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.federation_sender"
-
- setup_logging(config, use_worker_options=True)
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- database_engine = create_engine(config.database_config)
-
- if config.send_federation:
- sys.stderr.write(
- "\nThe send_federation must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``send_federation: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.send_federation = True
-
- ss = FederationSenderServer(
- config.server_name,
- db_config=config.database_config,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
- )
-
- ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
-
- _base.start_worker_reactor("synapse-federation-sender", config)
-
-
-class FederationSenderHandler(object):
- """Processes the replication stream and forwards the appropriate entries
- to the federation sender.
- """
- def __init__(self, hs, replication_client):
- self.store = hs.get_datastore()
- self._is_mine_id = hs.is_mine_id
- self.federation_sender = hs.get_federation_sender()
- self.replication_client = replication_client
-
- self.federation_position = self.store.federation_out_pos_startup
- self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
-
- self._last_ack = self.federation_position
-
- self._room_serials = {}
- self._room_typing = {}
-
- def on_start(self):
- # There may be some events that are persisted but haven't been sent,
- # so send them now.
- self.federation_sender.notify_new_events(
- self.store.get_room_max_stream_ordering()
- )
-
- def stream_positions(self):
- return {"federation": self.federation_position}
-
- def process_replication_rows(self, stream_name, token, rows):
- # The federation stream contains things that we want to send out, e.g.
- # presence, typing, etc.
- if stream_name == "federation":
- send_queue.process_rows_for_federation(self.federation_sender, rows)
- run_in_background(self.update_token, token)
-
- # We also need to poke the federation sender when new events happen
- elif stream_name == "events":
- self.federation_sender.notify_new_events(token)
-
- # ... and when new receipts happen
- elif stream_name == ReceiptsStream.NAME:
- run_as_background_process(
- "process_receipts_for_federation", self._on_new_receipts, rows,
- )
-
- @defer.inlineCallbacks
- def _on_new_receipts(self, rows):
- """
- Args:
- rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
- new receipts to be processed
- """
- for receipt in rows:
- # we only want to send on receipts for our own users
- if not self._is_mine_id(receipt.user_id):
- continue
- receipt_info = ReadReceipt(
- receipt.room_id,
- receipt.receipt_type,
- receipt.user_id,
- [receipt.event_id],
- receipt.data,
- )
- yield self.federation_sender.send_read_receipt(receipt_info)
-
- @defer.inlineCallbacks
- def update_token(self, token):
- try:
- self.federation_position = token
-
- # We linearize here to ensure we don't have races updating the token
- with (yield self._fed_position_linearizer.queue(None)):
- if self._last_ack < self.federation_position:
- yield self.store.update_federation_out_pos(
- "federation", self.federation_position
- )
-
- # We ACK this token over replication so that the master can drop
- # its in memory queues
- self.replication_client.send_federation_ack(self.federation_position)
- self._last_ack = self.federation_position
- except Exception:
- logger.exception("Error updating federation stream position")
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 6504da5278..add43147b3 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -13,251 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
+import sys
-import synapse
-from synapse import events
-from synapse.api.errors import HttpResponseException, SynapseError
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.server import JsonResource
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.client.v2_alpha._base import client_patterns
-from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
-from synapse.util.httpresourcetree import create_resource_tree
+from synapse.app.generic_worker import start
from synapse.util.logcontext import LoggingContext
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.frontend_proxy")
-
-
-class PresenceStatusStubServlet(RestServlet):
- PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
-
- def __init__(self, hs):
- super(PresenceStatusStubServlet, self).__init__()
- self.http_client = hs.get_simple_http_client()
- self.auth = hs.get_auth()
- self.main_uri = hs.config.worker_main_http_uri
-
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- # Pass through the auth headers, if any, in case the access token
- # is there.
- auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
- headers = {
- "Authorization": auth_headers,
- }
-
- try:
- result = yield self.http_client.get_json(
- self.main_uri + request.uri.decode('ascii'),
- headers=headers,
- )
- except HttpResponseException as e:
- raise e.to_synapse_error()
-
- defer.returnValue((200, result))
-
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- yield self.auth.get_user_by_req(request)
- defer.returnValue((200, {}))
-
-
-class KeyUploadServlet(RestServlet):
- PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
- super(KeyUploadServlet, self).__init__()
- self.auth = hs.get_auth()
- self.store = hs.get_datastore()
- self.http_client = hs.get_simple_http_client()
- self.main_uri = hs.config.worker_main_http_uri
-
- @defer.inlineCallbacks
- def on_POST(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- user_id = requester.user.to_string()
- body = parse_json_object_from_request(request)
-
- if device_id is not None:
- # passing the device_id here is deprecated; however, we allow it
- # for now for compatibility with older clients.
- if (requester.device_id is not None and
- device_id != requester.device_id):
- logger.warning("Client uploading keys for a different device "
- "(logged in as %s, uploading for %s)",
- requester.device_id, device_id)
- else:
- device_id = requester.device_id
-
- if device_id is None:
- raise SynapseError(
- 400,
- "To upload keys, you must pass device_id when authenticating"
- )
-
- if body:
- # They're actually trying to upload something, proxy to main synapse.
- # Pass through the auth headers, if any, in case the access token
- # is there.
- auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
- headers = {
- "Authorization": auth_headers,
- }
- result = yield self.http_client.post_json_get_json(
- self.main_uri + request.uri.decode('ascii'),
- body,
- headers=headers,
- )
-
- defer.returnValue((200, result))
- else:
- # Just interested in counts.
- result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
- defer.returnValue((200, {"one_time_key_counts": result}))
-
-
-class FrontendProxySlavedStore(
- SlavedDeviceStore,
- SlavedClientIpStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- BaseSlavedStore,
-):
- pass
-
-
-class FrontendProxyServer(HomeServer):
- DATASTORE_CLASS = FrontendProxySlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "client":
- resource = JsonResource(self, canonical_json=False)
- KeyUploadServlet(self).register(resource)
-
- # If presence is disabled, use the stub servlet that does
- # not allow sending presence
- if not self.config.use_presence:
- PresenceStatusStubServlet(self).register(resource)
-
- resources.update({
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- })
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- ),
- reactor=self.get_reactor()
- )
-
- logger.info("Synapse client reader now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
- else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
- else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ReplicationClientHandler(self.get_datastore())
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse frontend proxy", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.frontend_proxy"
-
- assert config.worker_main_http_uri is not None
-
- setup_logging(config, use_worker_options=True)
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- database_engine = create_engine(config.database_config)
-
- ss = FrontendProxyServer(
- config.server_name,
- db_config=config.database_config,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
- )
-
- ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
-
- _base.start_worker_reactor("synapse-frontend-proxy", config)
-
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
new file mode 100644
index 0000000000..66be6ea2ec
--- /dev/null
+++ b/synapse/app/generic_worker.py
@@ -0,0 +1,941 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import contextlib
+import logging
+import sys
+
+from twisted.internet import defer, reactor
+from twisted.web.resource import NoResource
+
+import synapse
+import synapse.events
+from synapse.api.constants import EventTypes
+from synapse.api.errors import HttpResponseException, SynapseError
+from synapse.api.urls import (
+ CLIENT_API_PREFIX,
+ FEDERATION_PREFIX,
+ LEGACY_MEDIA_PREFIX,
+ MEDIA_PREFIX,
+ SERVER_KEY_V2_PREFIX,
+)
+from synapse.app import _base
+from synapse.config._base import ConfigError
+from synapse.config.homeserver import HomeServerConfig
+from synapse.config.logger import setup_logging
+from synapse.federation import send_queue
+from synapse.federation.transport.server import TransportLayerServer
+from synapse.handlers.presence import PresenceHandler, get_interested_parties
+from synapse.http.server import JsonResource
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseSite
+from synapse.logging.context import LoggingContext, run_in_background
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
+from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
+from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
+from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
+from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.devices import SlavedDeviceStore
+from synapse.replication.slave.storage.directory import DirectoryStore
+from synapse.replication.slave.storage.events import SlavedEventStore
+from synapse.replication.slave.storage.filtering import SlavedFilteringStore
+from synapse.replication.slave.storage.groups import SlavedGroupServerStore
+from synapse.replication.slave.storage.keys import SlavedKeyStore
+from synapse.replication.slave.storage.presence import SlavedPresenceStore
+from synapse.replication.slave.storage.profile import SlavedProfileStore
+from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
+from synapse.replication.slave.storage.pushers import SlavedPusherStore
+from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
+from synapse.replication.slave.storage.registration import SlavedRegistrationStore
+from synapse.replication.slave.storage.room import RoomStore
+from synapse.replication.slave.storage.transactions import SlavedTransactionStore
+from synapse.replication.tcp.client import ReplicationClientHandler
+from synapse.replication.tcp.streams._base import (
+ DeviceListsStream,
+ ReceiptsStream,
+ ToDeviceStream,
+)
+from synapse.replication.tcp.streams.events import EventsStreamEventRow, EventsStreamRow
+from synapse.rest.admin import register_servlets_for_media_repo
+from synapse.rest.client.v1 import events
+from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
+from synapse.rest.client.v1.login import LoginRestServlet
+from synapse.rest.client.v1.profile import (
+ ProfileAvatarURLRestServlet,
+ ProfileDisplaynameRestServlet,
+ ProfileRestServlet,
+)
+from synapse.rest.client.v1.push_rule import PushRuleRestServlet
+from synapse.rest.client.v1.room import (
+ JoinedRoomMemberListRestServlet,
+ JoinRoomAliasServlet,
+ PublicRoomListRestServlet,
+ RoomEventContextServlet,
+ RoomInitialSyncRestServlet,
+ RoomMemberListRestServlet,
+ RoomMembershipRestServlet,
+ RoomMessageListRestServlet,
+ RoomSendEventRestServlet,
+ RoomStateEventRestServlet,
+ RoomStateRestServlet,
+)
+from synapse.rest.client.v1.voip import VoipRestServlet
+from synapse.rest.client.v2_alpha import groups, sync, user_directory
+from synapse.rest.client.v2_alpha._base import client_patterns
+from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
+from synapse.rest.client.v2_alpha.account_data import (
+ AccountDataServlet,
+ RoomAccountDataServlet,
+)
+from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
+from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+from synapse.rest.client.versions import VersionsRestServlet
+from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.server import HomeServer
+from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
+from synapse.storage.data_stores.main.monthly_active_users import (
+ MonthlyActiveUsersWorkerStore,
+)
+from synapse.storage.data_stores.main.presence import UserPresenceState
+from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
+from synapse.types import ReadReceipt
+from synapse.util.async_helpers import Linearizer
+from synapse.util.httpresourcetree import create_resource_tree
+from synapse.util.manhole import manhole
+from synapse.util.stringutils import random_string
+from synapse.util.versionstring import get_version_string
+
+logger = logging.getLogger("synapse.app.generic_worker")
+
+
+class PresenceStatusStubServlet(RestServlet):
+ """If presence is disabled this servlet can be used to stub out setting
+ presence status, while proxying the getters to the master instance.
+ """
+
+ PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
+
+ def __init__(self, hs):
+ super(PresenceStatusStubServlet, self).__init__()
+ self.http_client = hs.get_simple_http_client()
+ self.auth = hs.get_auth()
+ self.main_uri = hs.config.worker_main_http_uri
+
+ async def on_GET(self, request, user_id):
+ # Pass through the auth headers, if any, in case the access token
+ # is there.
+ auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
+ headers = {"Authorization": auth_headers}
+
+ try:
+ result = await self.http_client.get_json(
+ self.main_uri + request.uri.decode("ascii"), headers=headers
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+
+ return 200, result
+
+ async def on_PUT(self, request, user_id):
+ await self.auth.get_user_by_req(request)
+ return 200, {}
+
+
+class KeyUploadServlet(RestServlet):
+ """An implementation of the `KeyUploadServlet` that responds to read only
+ requests, but otherwise proxies through to the master instance.
+ """
+
+ PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(KeyUploadServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.http_client = hs.get_simple_http_client()
+ self.main_uri = hs.config.worker_main_http_uri
+
+ async def on_POST(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
+
+ if device_id is not None:
+ # passing the device_id here is deprecated; however, we allow it
+ # for now for compatibility with older clients.
+ if requester.device_id is not None and device_id != requester.device_id:
+ logger.warning(
+ "Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id,
+ device_id,
+ )
+ else:
+ device_id = requester.device_id
+
+ if device_id is None:
+ raise SynapseError(
+ 400, "To upload keys, you must pass device_id when authenticating"
+ )
+
+ if body:
+ # They're actually trying to upload something, proxy to main synapse.
+ # Pass through the auth headers, if any, in case the access token
+ # is there.
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
+ headers = {"Authorization": auth_headers}
+ result = await self.http_client.post_json_get_json(
+ self.main_uri + request.uri.decode("ascii"), body, headers=headers
+ )
+
+ return 200, result
+ else:
+ # Just interested in counts.
+ result = await self.store.count_e2e_one_time_keys(user_id, device_id)
+ return 200, {"one_time_key_counts": result}
+
+
+UPDATE_SYNCING_USERS_MS = 10 * 1000
+
+
+class GenericWorkerPresence(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.is_mine_id = hs.is_mine_id
+ self.http_client = hs.get_simple_http_client()
+ self.store = hs.get_datastore()
+ self.user_to_num_current_syncs = {}
+ self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
+
+ active_presence = self.store.take_presence_startup_info()
+ self.user_to_current_state = {state.user_id: state for state in active_presence}
+
+ # user_id -> last_sync_ms. Lists the users that have stopped syncing
+ # but we haven't notified the master of that yet
+ self.users_going_offline = {}
+
+ self._send_stop_syncing_loop = self.clock.looping_call(
+ self.send_stop_syncing, UPDATE_SYNCING_USERS_MS
+ )
+
+ self.process_id = random_string(16)
+ logger.info("Presence process_id is %r", self.process_id)
+
+ def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+ if self.hs.config.use_presence:
+ self.hs.get_tcp_replication().send_user_sync(
+ user_id, is_syncing, last_sync_ms
+ )
+
+ def mark_as_coming_online(self, user_id):
+ """A user has started syncing. Send a UserSync to the master, unless they
+ had recently stopped syncing.
+
+ Args:
+ user_id (str)
+ """
+ going_offline = self.users_going_offline.pop(user_id, None)
+ if not going_offline:
+ # Safe to skip because we haven't yet told the master they were offline
+ self.send_user_sync(user_id, True, self.clock.time_msec())
+
+ def mark_as_going_offline(self, user_id):
+ """A user has stopped syncing. We wait before notifying the master as
+ its likely they'll come back soon. This allows us to avoid sending
+ a stopped syncing immediately followed by a started syncing notification
+ to the master
+
+ Args:
+ user_id (str)
+ """
+ self.users_going_offline[user_id] = self.clock.time_msec()
+
+ def send_stop_syncing(self):
+ """Check if there are any users who have stopped syncing a while ago
+ and haven't come back yet. If there are poke the master about them.
+ """
+ now = self.clock.time_msec()
+ for user_id, last_sync_ms in list(self.users_going_offline.items()):
+ if now - last_sync_ms > UPDATE_SYNCING_USERS_MS:
+ self.users_going_offline.pop(user_id, None)
+ self.send_user_sync(user_id, False, last_sync_ms)
+
+ def set_state(self, user, state, ignore_status_msg=False):
+ # TODO Hows this supposed to work?
+ return defer.succeed(None)
+
+ get_states = __func__(PresenceHandler.get_states)
+ get_state = __func__(PresenceHandler.get_state)
+ current_state_for_users = __func__(PresenceHandler.current_state_for_users)
+
+ def user_syncing(self, user_id, affect_presence):
+ if affect_presence:
+ curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
+ self.user_to_num_current_syncs[user_id] = curr_sync + 1
+
+ # If we went from no in flight sync to some, notify replication
+ if self.user_to_num_current_syncs[user_id] == 1:
+ self.mark_as_coming_online(user_id)
+
+ def _end():
+ # We check that the user_id is in user_to_num_current_syncs because
+ # user_to_num_current_syncs may have been cleared if we are
+ # shutting down.
+ if affect_presence and user_id in self.user_to_num_current_syncs:
+ self.user_to_num_current_syncs[user_id] -= 1
+
+ # If we went from one in flight sync to non, notify replication
+ if self.user_to_num_current_syncs[user_id] == 0:
+ self.mark_as_going_offline(user_id)
+
+ @contextlib.contextmanager
+ def _user_syncing():
+ try:
+ yield
+ finally:
+ _end()
+
+ return defer.succeed(_user_syncing())
+
+ @defer.inlineCallbacks
+ def notify_from_replication(self, states, stream_id):
+ parties = yield get_interested_parties(self.store, states)
+ room_ids_to_states, users_to_states = parties
+
+ self.notifier.on_new_event(
+ "presence_key",
+ stream_id,
+ rooms=room_ids_to_states.keys(),
+ users=users_to_states.keys(),
+ )
+
+ @defer.inlineCallbacks
+ def process_replication_rows(self, token, rows):
+ states = [
+ UserPresenceState(
+ row.user_id,
+ row.state,
+ row.last_active_ts,
+ row.last_federation_update_ts,
+ row.last_user_sync_ts,
+ row.status_msg,
+ row.currently_active,
+ )
+ for row in rows
+ ]
+
+ for state in states:
+ self.user_to_current_state[state.user_id] = state
+
+ stream_id = token
+ yield self.notify_from_replication(states, stream_id)
+
+ def get_currently_syncing_users(self):
+ if self.hs.config.use_presence:
+ return [
+ user_id
+ for user_id, count in self.user_to_num_current_syncs.items()
+ if count > 0
+ ]
+ else:
+ return set()
+
+
+class GenericWorkerTyping(object):
+ def __init__(self, hs):
+ self._latest_room_serial = 0
+ self._reset()
+
+ def _reset(self):
+ """
+ Reset the typing handler's data caches.
+ """
+ # map room IDs to serial numbers
+ self._room_serials = {}
+ # map room IDs to sets of users currently typing
+ self._room_typing = {}
+
+ def stream_positions(self):
+ # We must update this typing token from the response of the previous
+ # sync. In particular, the stream id may "reset" back to zero/a low
+ # value which we *must* use for the next replication request.
+ return {"typing": self._latest_room_serial}
+
+ def process_replication_rows(self, token, rows):
+ if self._latest_room_serial > token:
+ # The master has gone backwards. To prevent inconsistent data, just
+ # clear everything.
+ self._reset()
+
+ # Set the latest serial token to whatever the server gave us.
+ self._latest_room_serial = token
+
+ for row in rows:
+ self._room_serials[row.room_id] = token
+ self._room_typing[row.room_id] = row.user_ids
+
+
+class GenericWorkerSlavedStore(
+ # FIXME(#3714): We need to add UserDirectoryStore as we write directly
+ # rather than going via the correct worker.
+ UserDirectoryStore,
+ SlavedDeviceInboxStore,
+ SlavedDeviceStore,
+ SlavedReceiptsStore,
+ SlavedPushRuleStore,
+ SlavedGroupServerStore,
+ SlavedAccountDataStore,
+ SlavedPusherStore,
+ SlavedEventStore,
+ SlavedKeyStore,
+ RoomStore,
+ DirectoryStore,
+ SlavedApplicationServiceStore,
+ SlavedRegistrationStore,
+ SlavedTransactionStore,
+ SlavedProfileStore,
+ SlavedClientIpStore,
+ SlavedPresenceStore,
+ SlavedFilteringStore,
+ MonthlyActiveUsersWorkerStore,
+ MediaRepositoryStore,
+ BaseSlavedStore,
+):
+ def __init__(self, database, db_conn, hs):
+ super(GenericWorkerSlavedStore, self).__init__(database, db_conn, hs)
+
+ # We pull out the current federation stream position now so that we
+ # always have a known value for the federation position in memory so
+ # that we don't have to bounce via a deferred once when we start the
+ # replication streams.
+ self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
+
+ def _get_federation_out_pos(self, db_conn):
+ sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?"
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, ("federation",))
+ rows = txn.fetchall()
+ txn.close()
+
+ return rows[0][0] if rows else -1
+
+
+class GenericWorkerServer(HomeServer):
+ DATASTORE_CLASS = GenericWorkerSlavedStore
+
+ def _listen_http(self, listener_config):
+ port = listener_config["port"]
+ bind_addresses = listener_config["bind_addresses"]
+ site_tag = listener_config.get("tag", port)
+ resources = {}
+ for res in listener_config["resources"]:
+ for name in res["names"]:
+ if name == "metrics":
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
+ elif name == "client":
+ resource = JsonResource(self, canonical_json=False)
+
+ PublicRoomListRestServlet(self).register(resource)
+ RoomMemberListRestServlet(self).register(resource)
+ JoinedRoomMemberListRestServlet(self).register(resource)
+ RoomStateRestServlet(self).register(resource)
+ RoomEventContextServlet(self).register(resource)
+ RoomMessageListRestServlet(self).register(resource)
+ RegisterRestServlet(self).register(resource)
+ LoginRestServlet(self).register(resource)
+ ThreepidRestServlet(self).register(resource)
+ KeyQueryServlet(self).register(resource)
+ KeyChangesServlet(self).register(resource)
+ VoipRestServlet(self).register(resource)
+ PushRuleRestServlet(self).register(resource)
+ VersionsRestServlet(self).register(resource)
+ RoomSendEventRestServlet(self).register(resource)
+ RoomMembershipRestServlet(self).register(resource)
+ RoomStateEventRestServlet(self).register(resource)
+ JoinRoomAliasServlet(self).register(resource)
+ ProfileAvatarURLRestServlet(self).register(resource)
+ ProfileDisplaynameRestServlet(self).register(resource)
+ ProfileRestServlet(self).register(resource)
+ KeyUploadServlet(self).register(resource)
+ AccountDataServlet(self).register(resource)
+ RoomAccountDataServlet(self).register(resource)
+
+ sync.register_servlets(self, resource)
+ events.register_servlets(self, resource)
+ InitialSyncRestServlet(self).register(resource)
+ RoomInitialSyncRestServlet(self).register(resource)
+
+ user_directory.register_servlets(self, resource)
+
+ # If presence is disabled, use the stub servlet that does
+ # not allow sending presence
+ if not self.config.use_presence:
+ PresenceStatusStubServlet(self).register(resource)
+
+ groups.register_servlets(self, resource)
+
+ resources.update({CLIENT_API_PREFIX: resource})
+ elif name == "federation":
+ resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
+ elif name == "media":
+ if self.config.can_load_media_repo:
+ media_repo = self.get_media_repository_resource()
+
+ # We need to serve the admin servlets for media on the
+ # worker.
+ admin_resource = JsonResource(self, canonical_json=False)
+ register_servlets_for_media_repo(self, admin_resource)
+
+ resources.update(
+ {
+ MEDIA_PREFIX: media_repo,
+ LEGACY_MEDIA_PREFIX: media_repo,
+ "/_synapse/admin": admin_resource,
+ }
+ )
+ else:
+ logger.warning(
+ "A 'media' listener is configured but the media"
+ " repository is disabled. Ignoring."
+ )
+
+ if name == "openid" and "federation" not in res["names"]:
+ # Only load the openid resource separately if federation resource
+ # is not specified since federation resource includes openid
+ # resource.
+ resources.update(
+ {
+ FEDERATION_PREFIX: TransportLayerServer(
+ self, servlet_groups=["openid"]
+ )
+ }
+ )
+
+ if name in ["keys", "federation"]:
+ resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
+
+ root_resource = create_resource_tree(resources, NoResource())
+
+ _base.listen_tcp(
+ bind_addresses,
+ port,
+ SynapseSite(
+ "synapse.access.http.%s" % (site_tag,),
+ site_tag,
+ listener_config,
+ root_resource,
+ self.version_string,
+ ),
+ reactor=self.get_reactor(),
+ )
+
+ logger.info("Synapse worker now listening on port %d", port)
+
+ def start_listening(self, listeners):
+ for listener in listeners:
+ if listener["type"] == "http":
+ self._listen_http(listener)
+ elif listener["type"] == "manhole":
+ _base.listen_tcp(
+ listener["bind_addresses"],
+ listener["port"],
+ manhole(
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
+ )
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warning(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
+ else:
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
+ else:
+ logger.warning("Unrecognized listener type: %s", listener["type"])
+
+ self.get_tcp_replication().start_replication(self)
+
+ def remove_pusher(self, app_id, push_key, user_id):
+ self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
+
+ def build_tcp_replication(self):
+ return GenericWorkerReplicationHandler(self)
+
+ def build_presence_handler(self):
+ return GenericWorkerPresence(self)
+
+ def build_typing_handler(self):
+ return GenericWorkerTyping(self)
+
+
+class GenericWorkerReplicationHandler(ReplicationClientHandler):
+ def __init__(self, hs):
+ super(GenericWorkerReplicationHandler, self).__init__(hs.get_datastore())
+
+ self.store = hs.get_datastore()
+ self.typing_handler = hs.get_typing_handler()
+ # NB this is a SynchrotronPresence, not a normal PresenceHandler
+ self.presence_handler = hs.get_presence_handler()
+ self.notifier = hs.get_notifier()
+
+ self.notify_pushers = hs.config.start_pushers
+ self.pusher_pool = hs.get_pusherpool()
+
+ if hs.config.send_federation:
+ self.send_handler = FederationSenderHandler(hs, self)
+ else:
+ self.send_handler = None
+
+ async def on_rdata(self, stream_name, token, rows):
+ await super(GenericWorkerReplicationHandler, self).on_rdata(
+ stream_name, token, rows
+ )
+ run_in_background(self.process_and_notify, stream_name, token, rows)
+
+ def get_streams_to_replicate(self):
+ args = super(GenericWorkerReplicationHandler, self).get_streams_to_replicate()
+ args.update(self.typing_handler.stream_positions())
+ if self.send_handler:
+ args.update(self.send_handler.stream_positions())
+ return args
+
+ def get_currently_syncing_users(self):
+ return self.presence_handler.get_currently_syncing_users()
+
+ async def process_and_notify(self, stream_name, token, rows):
+ try:
+ if self.send_handler:
+ self.send_handler.process_replication_rows(stream_name, token, rows)
+
+ if stream_name == "events":
+ # We shouldn't get multiple rows per token for events stream, so
+ # we don't need to optimise this for multiple rows.
+ for row in rows:
+ if row.type != EventsStreamEventRow.TypeId:
+ continue
+ assert isinstance(row, EventsStreamRow)
+
+ event = await self.store.get_event(
+ row.data.event_id, allow_rejected=True
+ )
+ if event.rejected_reason:
+ continue
+
+ extra_users = ()
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(
+ event, token, max_token, extra_users
+ )
+
+ await self.pusher_pool.on_new_notifications(token, token)
+ elif stream_name == "push_rules":
+ self.notifier.on_new_event(
+ "push_rules_key", token, users=[row.user_id for row in rows]
+ )
+ elif stream_name in ("account_data", "tag_account_data"):
+ self.notifier.on_new_event(
+ "account_data_key", token, users=[row.user_id for row in rows]
+ )
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "receipt_key", token, rooms=[row.room_id for row in rows]
+ )
+ await self.pusher_pool.on_new_receipts(
+ token, token, {row.room_id for row in rows}
+ )
+ elif stream_name == "typing":
+ self.typing_handler.process_replication_rows(token, rows)
+ self.notifier.on_new_event(
+ "typing_key", token, rooms=[row.room_id for row in rows]
+ )
+ elif stream_name == "to_device":
+ entities = [row.entity for row in rows if row.entity.startswith("@")]
+ if entities:
+ self.notifier.on_new_event("to_device_key", token, users=entities)
+ elif stream_name == "device_lists":
+ all_room_ids = set()
+ for row in rows:
+ room_ids = await self.store.get_rooms_for_user(row.user_id)
+ all_room_ids.update(room_ids)
+ self.notifier.on_new_event("device_list_key", token, rooms=all_room_ids)
+ elif stream_name == "presence":
+ await self.presence_handler.process_replication_rows(token, rows)
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "groups_key", token, users=[row.user_id for row in rows]
+ )
+ elif stream_name == "pushers":
+ for row in rows:
+ if row.deleted:
+ self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+ else:
+ await self.start_pusher(row.user_id, row.app_id, row.pushkey)
+ except Exception:
+ logger.exception("Error processing replication")
+
+ def stop_pusher(self, user_id, app_id, pushkey):
+ if not self.notify_pushers:
+ return
+
+ key = "%s:%s" % (app_id, pushkey)
+ pushers_for_user = self.pusher_pool.pushers.get(user_id, {})
+ pusher = pushers_for_user.pop(key, None)
+ if pusher is None:
+ return
+ logger.info("Stopping pusher %r / %r", user_id, key)
+ pusher.on_stop()
+
+ async def start_pusher(self, user_id, app_id, pushkey):
+ if not self.notify_pushers:
+ return
+
+ key = "%s:%s" % (app_id, pushkey)
+ logger.info("Starting pusher %r / %r", user_id, key)
+ return await self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+
+ def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+
+ # Let's wake up the transaction queue for the server in case we have
+ # pending stuff to send to it.
+ if self.send_handler:
+ self.send_handler.wake_destination(server)
+
+
+class FederationSenderHandler(object):
+ """Processes the replication stream and forwards the appropriate entries
+ to the federation sender.
+ """
+
+ def __init__(self, hs: GenericWorkerServer, replication_client):
+ self.store = hs.get_datastore()
+ self._is_mine_id = hs.is_mine_id
+ self.federation_sender = hs.get_federation_sender()
+ self.replication_client = replication_client
+
+ self.federation_position = self.store.federation_out_pos_startup
+ self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")
+
+ self._last_ack = self.federation_position
+
+ self._room_serials = {}
+ self._room_typing = {}
+
+ def on_start(self):
+ # There may be some events that are persisted but haven't been sent,
+ # so send them now.
+ self.federation_sender.notify_new_events(
+ self.store.get_room_max_stream_ordering()
+ )
+
+ def wake_destination(self, server: str):
+ self.federation_sender.wake_destination(server)
+
+ def stream_positions(self):
+ return {"federation": self.federation_position}
+
+ def process_replication_rows(self, stream_name, token, rows):
+ # The federation stream contains things that we want to send out, e.g.
+ # presence, typing, etc.
+ if stream_name == "federation":
+ send_queue.process_rows_for_federation(self.federation_sender, rows)
+ run_in_background(self.update_token, token)
+
+ # We also need to poke the federation sender when new events happen
+ elif stream_name == "events":
+ self.federation_sender.notify_new_events(token)
+
+ # ... and when new receipts happen
+ elif stream_name == ReceiptsStream.NAME:
+ run_as_background_process(
+ "process_receipts_for_federation", self._on_new_receipts, rows
+ )
+
+ # ... as well as device updates and messages
+ elif stream_name == DeviceListsStream.NAME:
+ hosts = {row.destination for row in rows}
+ for host in hosts:
+ self.federation_sender.send_device_messages(host)
+
+ elif stream_name == ToDeviceStream.NAME:
+ # The to_device stream includes stuff to be pushed to both local
+ # clients and remote servers, so we ignore entities that start with
+ # '@' (since they'll be local users rather than destinations).
+ hosts = {row.entity for row in rows if not row.entity.startswith("@")}
+ for host in hosts:
+ self.federation_sender.send_device_messages(host)
+
+ async def _on_new_receipts(self, rows):
+ """
+ Args:
+ rows (iterable[synapse.replication.tcp.streams.ReceiptsStreamRow]):
+ new receipts to be processed
+ """
+ for receipt in rows:
+ # we only want to send on receipts for our own users
+ if not self._is_mine_id(receipt.user_id):
+ continue
+ receipt_info = ReadReceipt(
+ receipt.room_id,
+ receipt.receipt_type,
+ receipt.user_id,
+ [receipt.event_id],
+ receipt.data,
+ )
+ await self.federation_sender.send_read_receipt(receipt_info)
+
+ async def update_token(self, token):
+ try:
+ self.federation_position = token
+
+ # We linearize here to ensure we don't have races updating the token
+ with (await self._fed_position_linearizer.queue(None)):
+ if self._last_ack < self.federation_position:
+ await self.store.update_federation_out_pos(
+ "federation", self.federation_position
+ )
+
+ # We ACK this token over replication so that the master can drop
+ # its in memory queues
+ self.replication_client.send_federation_ack(
+ self.federation_position
+ )
+ self._last_ack = self.federation_position
+ except Exception:
+ logger.exception("Error updating federation stream position")
+
+
+def start(config_options):
+ try:
+ config = HomeServerConfig.load_config("Synapse worker", config_options)
+ except ConfigError as e:
+ sys.stderr.write("\n" + str(e) + "\n")
+ sys.exit(1)
+
+ # For backwards compatibility let any of the old app names.
+ assert config.worker_app in (
+ "synapse.app.appservice",
+ "synapse.app.client_reader",
+ "synapse.app.event_creator",
+ "synapse.app.federation_reader",
+ "synapse.app.federation_sender",
+ "synapse.app.frontend_proxy",
+ "synapse.app.generic_worker",
+ "synapse.app.media_repository",
+ "synapse.app.pusher",
+ "synapse.app.synchrotron",
+ "synapse.app.user_dir",
+ )
+
+ if config.worker_app == "synapse.app.appservice":
+ if config.notify_appservices:
+ sys.stderr.write(
+ "\nThe appservices must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``notify_appservices: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the appservice to start since they will be disabled in the main config
+ config.notify_appservices = True
+ else:
+ # For other worker types we force this to off.
+ config.notify_appservices = False
+
+ if config.worker_app == "synapse.app.pusher":
+ if config.start_pushers:
+ sys.stderr.write(
+ "\nThe pushers must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``start_pushers: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the pushers to start since they will be disabled in the main config
+ config.start_pushers = True
+ else:
+ # For other worker types we force this to off.
+ config.start_pushers = False
+
+ if config.worker_app == "synapse.app.user_dir":
+ if config.update_user_directory:
+ sys.stderr.write(
+ "\nThe update_user_directory must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``update_user_directory: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the pushers to start since they will be disabled in the main config
+ config.update_user_directory = True
+ else:
+ # For other worker types we force this to off.
+ config.update_user_directory = False
+
+ if config.worker_app == "synapse.app.federation_sender":
+ if config.send_federation:
+ sys.stderr.write(
+ "\nThe send_federation must be disabled in the main synapse process"
+ "\nbefore they can be run in a separate worker."
+ "\nPlease add ``send_federation: false`` to the main config"
+ "\n"
+ )
+ sys.exit(1)
+
+ # Force the pushers to start since they will be disabled in the main config
+ config.send_federation = True
+ else:
+ # For other worker types we force this to off.
+ config.send_federation = False
+
+ synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
+
+ ss = GenericWorkerServer(
+ config.server_name,
+ config=config,
+ version_string="Synapse/" + get_version_string(synapse),
+ )
+
+ setup_logging(ss, config, use_worker_options=True)
+
+ ss.setup()
+ reactor.addSystemEventTrigger(
+ "before", "startup", _base.start, ss, config.worker_listeners
+ )
+
+ _base.start_worker_reactor("synapse-generic-worker", config)
+
+
+if __name__ == "__main__":
+ with LoggingContext("main"):
+ start(sys.argv[1:])
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 5a46b11bc0..f2b56a636f 100755..100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -19,18 +19,19 @@ from __future__ import print_function
import gc
import logging
+import math
import os
+import resource
import sys
from six import iteritems
-import psutil
from prometheus_client import Gauge
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.python.failure import Failure
-from twisted.web.resource import EncodingResourceWrapper, NoResource
+from twisted.web.resource import EncodingResourceWrapper, IResource, NoResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
@@ -38,7 +39,6 @@ import synapse
import synapse.config.logger
from synapse import events
from synapse.api.urls import (
- CONTENT_REPO_PREFIX,
FEDERATION_PREFIX,
LEGACY_MEDIA_PREFIX,
MEDIA_PREFIX,
@@ -54,9 +54,9 @@ from synapse.federation.transport.server import TransportLayerServer
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import RootRedirect
from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
+from synapse.logging.context import LoggingContext
+from synapse.metrics import METRICS_PREFIX, MetricsResource, RegistryProxy
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.module_api import ModuleApi
from synapse.python_dependencies import check_requirements
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
@@ -64,15 +64,13 @@ from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource
from synapse.rest.admin import AdminRestResource
from synapse.rest.key.v2 import KeyApiV2Resource
-from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
-from synapse.storage import DataStore, are_all_users_on_domain
-from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
-from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
+from synapse.storage import DataStore
+from synapse.storage.engines import IncorrectDatabaseSetup
+from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.module_loader import load_module
from synapse.util.rlimit import change_resource_limit
@@ -101,18 +99,26 @@ class SynapseHomeServer(HomeServer):
# Skip loading openid resource if federation is defined
# since federation resource will include openid
continue
- resources.update(self._configure_named_resource(
- name, res.get("compress", False),
- ))
+ resources.update(
+ self._configure_named_resource(name, res.get("compress", False))
+ )
additional_resources = listener_config.get("additional_resources", {})
- logger.debug("Configuring additional resources: %r",
- additional_resources)
+ logger.debug("Configuring additional resources: %r", additional_resources)
module_api = ModuleApi(self, self.get_auth_handler())
for path, resmodule in additional_resources.items():
handler_cls, config = load_module(resmodule)
handler = handler_cls(config, module_api)
- resources[path] = AdditionalResource(self, handler.handle_request)
+ if IResource.providedBy(handler):
+ resource = handler
+ elif hasattr(handler, "handle_request"):
+ resource = AdditionalResource(self, handler.handle_request)
+ else:
+ raise ConfigError(
+ "additional_resource %s does not implement a known interface"
+ % (resmodule["module"],)
+ )
+ resources[path] = resource
# try to find something useful to redirect '/' to
if WEB_CLIENT_PREFIX in resources:
@@ -174,59 +180,61 @@ class SynapseHomeServer(HomeServer):
if compress:
client_resource = gz_wrap(client_resource)
- resources.update({
- "/_matrix/client/api/v1": client_resource,
- "/_matrix/client/r0": client_resource,
- "/_matrix/client/unstable": client_resource,
- "/_matrix/client/v2_alpha": client_resource,
- "/_matrix/client/versions": client_resource,
- "/.well-known/matrix/client": WellKnownResource(self),
- "/_synapse/admin": AdminRestResource(self),
- })
+ resources.update(
+ {
+ "/_matrix/client/api/v1": client_resource,
+ "/_matrix/client/r0": client_resource,
+ "/_matrix/client/unstable": client_resource,
+ "/_matrix/client/v2_alpha": client_resource,
+ "/_matrix/client/versions": client_resource,
+ "/.well-known/matrix/client": WellKnownResource(self),
+ "/_synapse/admin": AdminRestResource(self),
+ }
+ )
if self.get_config().saml2_enabled:
from synapse.rest.saml2 import SAML2Resource
+
resources["/_matrix/saml2"] = SAML2Resource(self)
if name == "consent":
from synapse.rest.consent.consent_resource import ConsentResource
+
consent_resource = ConsentResource(self)
if compress:
consent_resource = gz_wrap(consent_resource)
- resources.update({
- "/_matrix/consent": consent_resource,
- })
+ resources.update({"/_matrix/consent": consent_resource})
if name == "federation":
- resources.update({
- FEDERATION_PREFIX: TransportLayerServer(self),
- })
+ resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
if name == "openid":
- resources.update({
- FEDERATION_PREFIX: TransportLayerServer(self, servlet_groups=["openid"]),
- })
+ resources.update(
+ {
+ FEDERATION_PREFIX: TransportLayerServer(
+ self, servlet_groups=["openid"]
+ )
+ }
+ )
if name in ["static", "client"]:
- resources.update({
- STATIC_PREFIX: File(
- os.path.join(os.path.dirname(synapse.__file__), "static")
- ),
- })
+ resources.update(
+ {
+ STATIC_PREFIX: File(
+ os.path.join(os.path.dirname(synapse.__file__), "static")
+ )
+ }
+ )
if name in ["media", "federation", "client"]:
if self.get_config().enable_media_repo:
media_repo = self.get_media_repository_resource()
- resources.update({
- MEDIA_PREFIX: media_repo,
- LEGACY_MEDIA_PREFIX: media_repo,
- CONTENT_REPO_PREFIX: ContentRepoResource(
- self, self.config.uploads_path
- ),
- })
+ resources.update(
+ {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo}
+ )
elif name == "media":
raise ConfigError(
- "'media' resource conflicts with enable_media_repo=False",
+ "'media' resource conflicts with enable_media_repo=False"
)
if name in ["keys", "federation"]:
@@ -257,18 +265,14 @@ class SynapseHomeServer(HomeServer):
for listener in listeners:
if listener["type"] == "http":
- self._listening_services.extend(
- self._listener_http(config, listener)
- )
+ self._listening_services.extend(self._listener_http(config, listener))
elif listener["type"] == "manhole":
listen_tcp(
listener["bind_addresses"],
listener["port"],
manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
+ username="matrix", password="rabbithole", globals={"hs": self}
+ ),
)
elif listener["type"] == "replication":
services = listen_tcp(
@@ -277,42 +281,32 @@ class SynapseHomeServer(HomeServer):
ReplicationStreamProtocolFactory(self),
)
for s in services:
- reactor.addSystemEventTrigger(
- "before", "shutdown", s.stopListening,
- )
+ reactor.addSystemEventTrigger("before", "shutdown", s.stopListening)
elif listener["type"] == "metrics":
if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
+ logger.warning(
+ (
+ "Metrics listener configured, but "
+ "enable_metrics is not True!"
+ )
+ )
else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
+ _base.listen_metrics(listener["bind_addresses"], listener["port"])
else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- def run_startup_checks(self, db_conn, database_engine):
- all_users_native = are_all_users_on_domain(
- db_conn.cursor(), database_engine, self.hostname
- )
- if not all_users_native:
- quit_with_error(
- "Found users in database not native to %s!\n"
- "You cannot changed a synapse server_name after it's been configured"
- % (self.hostname,)
- )
-
- try:
- database_engine.check_database(db_conn.cursor())
- except IncorrectDatabaseSetup as e:
- quit_with_error(str(e))
+ logger.warning("Unrecognized listener type: %s", listener["type"])
# Gauges to expose monthly active user control metrics
current_mau_gauge = Gauge("synapse_admin_mau:current", "Current MAU")
+current_mau_by_service_gauge = Gauge(
+ "synapse_admin_mau_current_mau_by_service",
+ "Current MAU by service",
+ ["app_service"],
+)
max_mau_gauge = Gauge("synapse_admin_mau:max", "MAU Limit")
registered_reserved_users_mau_gauge = Gauge(
"synapse_admin_mau:registered_reserved_users",
- "Registered users with reserved threepids"
+ "Registered users with reserved threepids",
)
@@ -327,11 +321,10 @@ def setup(config_options):
"""
try:
config = HomeServerConfig.load_or_generate_config(
- "Synapse Homeserver",
- config_options,
+ "Synapse Homeserver", config_options
)
except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
+ sys.stderr.write("\nERROR: %s\n" % (e,))
sys.exit(1)
if not config:
@@ -339,45 +332,25 @@ def setup(config_options):
# generating config files and shouldn't try to continue.
sys.exit(0)
- synapse.config.logger.setup_logging(
- config,
- use_worker_options=False
- )
-
events.USE_FROZEN_DICTS = config.use_frozen_dicts
- database_engine = create_engine(config.database_config)
- config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
-
hs = SynapseHomeServer(
config.server_name,
- db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
)
- logger.info("Preparing database: %s...", config.database_config['name'])
+ synapse.config.logger.setup_logging(hs, config, use_worker_options=False)
- try:
- with hs.get_db_conn(run_new_connection=False) as db_conn:
- prepare_database(db_conn, database_engine, config=config)
- database_engine.on_new_connection(db_conn)
-
- hs.run_startup_checks(db_conn, database_engine)
-
- db_conn.commit()
- except UpgradeDatabaseException:
- sys.stderr.write(
- "\nFailed to upgrade database.\n"
- "Have you checked for version specific instructions in"
- " UPGRADES.rst?\n"
- )
- sys.exit(1)
+ logger.info("Setting up server")
- logger.info("Database prepared in %s.", config.database_config['name'])
+ try:
+ hs.setup()
+ except IncorrectDatabaseSetup as e:
+ quit_with_error(str(e))
+ except UpgradeDatabaseException as e:
+ quit_with_error("Failed to upgrade database: %s" % (e,))
- hs.setup()
hs.setup_master()
@defer.inlineCallbacks
@@ -391,9 +364,7 @@ def setup(config_options):
acme = hs.get_acme_handler()
# Check how long the certificate is active for.
- cert_days_remaining = hs.config.is_disk_cert_valid(
- allow_self_signed=False
- )
+ cert_days_remaining = hs.config.is_disk_cert_valid(allow_self_signed=False)
# We want to reprovision if cert_days_remaining is None (meaning no
# certificate exists), or the days remaining number it returns
@@ -401,15 +372,15 @@ def setup(config_options):
provision = False
if (
- cert_days_remaining is None or
- cert_days_remaining < hs.config.acme_reprovision_threshold
+ cert_days_remaining is None
+ or cert_days_remaining < hs.config.acme_reprovision_threshold
):
provision = True
if provision:
yield acme.provision_certificate()
- defer.returnValue(provision)
+ return provision
@defer.inlineCallbacks
def reprovision_acme():
@@ -433,15 +404,11 @@ def setup(config_options):
yield do_acme()
# Check if it needs to be reprovisioned every day.
- hs.get_clock().looping_call(
- reprovision_acme,
- 24 * 60 * 60 * 1000
- )
+ hs.get_clock().looping_call(reprovision_acme, 24 * 60 * 60 * 1000)
_base.start(hs, config.listeners)
- hs.get_pusherpool().start()
- hs.get_datastore().start_doing_background_updates()
+ hs.get_datastore().db.updates.start_doing_background_updates()
except Exception:
# Print the exception and bail out.
print("Error during startup:", file=sys.stderr)
@@ -463,6 +430,7 @@ class SynapseService(service.Service):
A twisted Service class that will start synapse. Used to run synapse
via twistd and a .tac.
"""
+
def __init__(self, config):
self.config = config
@@ -476,9 +444,93 @@ class SynapseService(service.Service):
return self._port.stopListening()
+# Contains the list of processes we will be monitoring
+# currently either 0 or 1
+_stats_process = []
+
+
+@defer.inlineCallbacks
+def phone_stats_home(hs, stats, stats_process=_stats_process):
+ logger.info("Gathering stats for reporting")
+ now = int(hs.get_clock().time())
+ uptime = int(now - hs.start_time)
+ if uptime < 0:
+ uptime = 0
+
+ stats["homeserver"] = hs.config.server_name
+ stats["server_context"] = hs.config.server_context
+ stats["timestamp"] = now
+ stats["uptime_seconds"] = uptime
+ version = sys.version_info
+ stats["python_version"] = "{}.{}.{}".format(
+ version.major, version.minor, version.micro
+ )
+ stats["total_users"] = yield hs.get_datastore().count_all_users()
+
+ total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
+ stats["total_nonbridged_users"] = total_nonbridged_users
+
+ daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
+ for name, count in iteritems(daily_user_type_results):
+ stats["daily_user_type_" + name] = count
+
+ room_count = yield hs.get_datastore().get_room_count()
+ stats["total_room_count"] = room_count
+
+ stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
+ stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
+ stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
+ stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+
+ r30_results = yield hs.get_datastore().count_r30_users()
+ for name, count in iteritems(r30_results):
+ stats["r30_users_" + name] = count
+
+ daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
+ stats["daily_sent_messages"] = daily_sent_messages
+ stats["cache_factor"] = CACHE_SIZE_FACTOR
+ stats["event_cache_size"] = hs.config.event_cache_size
+
+ #
+ # Performance statistics
+ #
+ old = stats_process[0]
+ new = (now, resource.getrusage(resource.RUSAGE_SELF))
+ stats_process[0] = new
+
+ # Get RSS in bytes
+ stats["memory_rss"] = new[1].ru_maxrss
+
+ # Get CPU time in % of a single core, not % of all cores
+ used_cpu_time = (new[1].ru_utime + new[1].ru_stime) - (
+ old[1].ru_utime + old[1].ru_stime
+ )
+ if used_cpu_time == 0 or new[0] == old[0]:
+ stats["cpu_average"] = 0
+ else:
+ stats["cpu_average"] = math.floor(used_cpu_time / (new[0] - old[0]) * 100)
+
+ #
+ # Database version
+ #
+
+ # This only reports info about the *main* database.
+ stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
+ stats["database_server_version"] = hs.get_datastore().db.engine.server_version
+
+ logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
+ try:
+ yield hs.get_proxied_http_client().put_json(
+ hs.config.report_stats_endpoint, stats
+ )
+ except Exception as e:
+ logger.warning("Error reporting stats: %s", e)
+
+
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
+
def profile(func):
from cProfile import Profile
from threading import current_thread
@@ -489,104 +541,35 @@ def run(hs):
func(*args, **kargs)
profile.disable()
ident = current_thread().ident
- profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
- hs.hostname, func.__name__, ident
- ))
+ profile.dump_stats(
+ "/tmp/%s.%s.%i.pstat" % (hs.hostname, func.__name__, ident)
+ )
return profiled
from twisted.python.threadpool import ThreadPool
+
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
clock = hs.get_clock()
- start_time = clock.time()
stats = {}
- # Contains the list of processes we will be monitoring
- # currently either 0 or 1
- stats_process = []
+ def performance_stats_init():
+ _stats_process.clear()
+ _stats_process.append(
+ (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
+ )
def start_phone_stats_home():
- return run_as_background_process("phone_stats_home", phone_stats_home)
-
- @defer.inlineCallbacks
- def phone_stats_home():
- logger.info("Gathering stats for reporting")
- now = int(hs.get_clock().time())
- uptime = int(now - start_time)
- if uptime < 0:
- uptime = 0
-
- stats["homeserver"] = hs.config.server_name
- stats["server_context"] = hs.config.server_context
- stats["timestamp"] = now
- stats["uptime_seconds"] = uptime
- version = sys.version_info
- stats["python_version"] = "{}.{}.{}".format(
- version.major, version.minor, version.micro
+ return run_as_background_process(
+ "phone_stats_home", phone_stats_home, hs, stats
)
- stats["total_users"] = yield hs.get_datastore().count_all_users()
-
- total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
- stats["total_nonbridged_users"] = total_nonbridged_users
-
- daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
- for name, count in iteritems(daily_user_type_results):
- stats["daily_user_type_" + name] = count
-
- room_count = yield hs.get_datastore().get_room_count()
- stats["total_room_count"] = room_count
-
- stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
- stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
- stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
-
- r30_results = yield hs.get_datastore().count_r30_users()
- for name, count in iteritems(r30_results):
- stats["r30_users_" + name] = count
-
- daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
- stats["daily_sent_messages"] = daily_sent_messages
- stats["cache_factor"] = CACHE_SIZE_FACTOR
- stats["event_cache_size"] = hs.config.event_cache_size
-
- if len(stats_process) > 0:
- stats["memory_rss"] = 0
- stats["cpu_average"] = 0
- for process in stats_process:
- stats["memory_rss"] += process.memory_info().rss
- stats["cpu_average"] += int(process.cpu_percent(interval=None))
-
- stats["database_engine"] = hs.get_datastore().database_engine_name
- stats["database_server_version"] = hs.get_datastore().get_server_version()
- logger.info("Reporting stats to matrix.org: %s" % (stats,))
- try:
- yield hs.get_proxied_http_client().put_json(
- "https://matrix.org/report-usage-stats/push", stats
- )
- except Exception as e:
- logger.warn("Error reporting stats: %s", e)
-
- def performance_stats_init():
- try:
- process = psutil.Process()
- # Ensure we can fetch both, and make the initial request for cpu_percent
- # so the next request will use this as the initial point.
- process.memory_info().rss
- process.cpu_percent(interval=None)
- logger.info("report_stats can use psutil")
- stats_process.append(process)
- except (AttributeError):
- logger.warning(
- "Unable to read memory/cpu stats. Disabling reporting."
- )
def generate_user_daily_visit_stats():
return run_as_background_process(
- "generate_user_daily_visits",
- hs.get_datastore().generate_user_daily_visits,
+ "generate_user_daily_visits", hs.get_datastore().generate_user_daily_visits
)
# Rather than update on per session basis, batch up the requests.
@@ -597,28 +580,35 @@ def run(hs):
# monthly active user limiting functionality
def reap_monthly_active_users():
return run_as_background_process(
- "reap_monthly_active_users",
- hs.get_datastore().reap_monthly_active_users,
+ "reap_monthly_active_users", hs.get_datastore().reap_monthly_active_users
)
+
clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
reap_monthly_active_users()
@defer.inlineCallbacks
def generate_monthly_active_users():
current_mau_count = 0
- reserved_count = 0
+ current_mau_count_by_service = {}
+ reserved_users = ()
store = hs.get_datastore()
if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
current_mau_count = yield store.get_monthly_active_count()
- reserved_count = yield store.get_registered_reserved_users_count()
+ current_mau_count_by_service = (
+ yield store.get_monthly_active_count_by_service()
+ )
+ reserved_users = yield store.get_registered_reserved_users()
current_mau_gauge.set(float(current_mau_count))
- registered_reserved_users_mau_gauge.set(float(reserved_count))
+
+ for app_service, count in current_mau_count_by_service.items():
+ current_mau_by_service_gauge.labels(app_service).set(float(count))
+
+ registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
max_mau_gauge.set(float(hs.config.max_mau_value))
def start_generate_monthly_active_users():
return run_as_background_process(
- "generate_monthly_active_users",
- generate_monthly_active_users,
+ "generate_monthly_active_users", generate_monthly_active_users
)
start_generate_monthly_active_users()
@@ -644,7 +634,6 @@ def run(hs):
gc_thresholds=hs.config.gc_thresholds,
pid_file=hs.config.pid_file,
daemonize=hs.config.daemonize,
- cpu_affinity=hs.config.cpu_affinity,
print_pidfile=hs.config.print_pidfile,
logger=logger,
)
@@ -658,5 +647,5 @@ def main():
run(hs)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index d4cc4e9443..add43147b3 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -13,157 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-from twisted.internet import reactor
-from twisted.web.resource import NoResource
+import sys
-import synapse
-from synapse import events
-from synapse.api.urls import CONTENT_REPO_PREFIX, LEGACY_MEDIA_PREFIX, MEDIA_PREFIX
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.transactions import SlavedTransactionStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.media.v0.content_repository import ContentRepoResource
-from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
-from synapse.storage.media_repository import MediaRepositoryStore
-from synapse.util.httpresourcetree import create_resource_tree
+from synapse.app.generic_worker import start
from synapse.util.logcontext import LoggingContext
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.media_repository")
-
-
-class MediaRepositorySlavedStore(
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- SlavedClientIpStore,
- SlavedTransactionStore,
- BaseSlavedStore,
- MediaRepositoryStore,
-):
- pass
-
-
-class MediaRepositoryServer(HomeServer):
- DATASTORE_CLASS = MediaRepositorySlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "media":
- media_repo = self.get_media_repository_resource()
- resources.update({
- MEDIA_PREFIX: media_repo,
- LEGACY_MEDIA_PREFIX: media_repo,
- CONTENT_REPO_PREFIX: ContentRepoResource(
- self, self.config.uploads_path
- ),
- })
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- )
- )
-
- logger.info("Synapse media repository now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
- else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
- else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return ReplicationClientHandler(self.get_datastore())
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse media repository", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.media_repository"
-
- if config.enable_media_repo:
- _base.quit_with_error(
- "enable_media_repo must be disabled in the main synapse process\n"
- "before the media repo can be run in a separate worker.\n"
- "Please add ``enable_media_repo: false`` to the main config\n"
- )
-
- setup_logging(config, use_worker_options=True)
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- database_engine = create_engine(config.database_config)
-
- ss = MediaRepositoryServer(
- config.server_name,
- db_config=config.database_config,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
- )
-
- ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
-
- _base.start_worker_reactor("synapse-media-repository", config)
-
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index cbf0d67f51..add43147b3 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -13,227 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import sys
-
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
-from synapse.replication.slave.storage._base import __func__
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.pushers import SlavedPusherStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.server import HomeServer
-from synapse.storage import DataStore
-from synapse.storage.engines import create_engine
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, run_in_background
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.pusher")
-
-
-class PusherSlaveStore(
- SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore,
- SlavedAccountDataStore
-):
- update_pusher_last_stream_ordering_and_success = (
- __func__(DataStore.update_pusher_last_stream_ordering_and_success)
- )
-
- update_pusher_failing_since = (
- __func__(DataStore.update_pusher_failing_since)
- )
-
- update_pusher_last_stream_ordering = (
- __func__(DataStore.update_pusher_last_stream_ordering)
- )
-
- get_throttle_params_by_room = (
- __func__(DataStore.get_throttle_params_by_room)
- )
-
- set_throttle_params = (
- __func__(DataStore.set_throttle_params)
- )
-
- get_time_of_last_push_action_before = (
- __func__(DataStore.get_time_of_last_push_action_before)
- )
-
- get_profile_displayname = (
- __func__(DataStore.get_profile_displayname)
- )
-
-
-class PusherServer(HomeServer):
- DATASTORE_CLASS = PusherSlaveStore
-
- def remove_pusher(self, app_id, push_key, user_id):
- self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- )
- )
-
- logger.info("Synapse pusher now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
- else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
- else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
- def build_tcp_replication(self):
- return PusherReplicationHandler(self)
-
-
-class PusherReplicationHandler(ReplicationClientHandler):
- def __init__(self, hs):
- super(PusherReplicationHandler, self).__init__(hs.get_datastore())
-
- self.pusher_pool = hs.get_pusherpool()
-
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
- yield super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
- run_in_background(self.poke_pushers, stream_name, token, rows)
-
- @defer.inlineCallbacks
- def poke_pushers(self, stream_name, token, rows):
- try:
- if stream_name == "pushers":
- for row in rows:
- if row.deleted:
- yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
- else:
- yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
- elif stream_name == "events":
- yield self.pusher_pool.on_new_notifications(
- token, token,
- )
- elif stream_name == "receipts":
- yield self.pusher_pool.on_new_receipts(
- token, token, set(row.room_id for row in rows)
- )
- except Exception:
- logger.exception("Error poking pushers")
-
- def stop_pusher(self, user_id, app_id, pushkey):
- key = "%s:%s" % (app_id, pushkey)
- pushers_for_user = self.pusher_pool.pushers.get(user_id, {})
- pusher = pushers_for_user.pop(key, None)
- if pusher is None:
- return
- logger.info("Stopping pusher %r / %r", user_id, key)
- pusher.on_stop()
-
- def start_pusher(self, user_id, app_id, pushkey):
- key = "%s:%s" % (app_id, pushkey)
- logger.info("Starting pusher %r / %r", user_id, key)
- return self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse pusher", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.pusher"
-
- setup_logging(config, use_worker_options=True)
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- if config.start_pushers:
- sys.stderr.write(
- "\nThe pushers must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``start_pushers: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.start_pushers = True
-
- database_engine = create_engine(config.database_config)
-
- ps = PusherServer(
- config.server_name,
- db_config=config.database_config,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
- )
-
- ps.setup()
-
- def start():
- _base.start(ps, config.worker_listeners)
- ps.get_pusherpool().start()
-
- reactor.callWhenRunning(start)
-
- _base.start_worker_reactor("synapse-pusher", config)
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
- ps = start(sys.argv[1:])
+ start(sys.argv[1:])
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 5388def28a..add43147b3 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -13,446 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import contextlib
-import logging
-import sys
-
-from six import iteritems
-
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
-
-import synapse
-from synapse.api.constants import EventTypes
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.handlers.presence import PresenceHandler, get_interested_parties
-from synapse.http.server import JsonResource
-from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
-from synapse.replication.slave.storage._base import BaseSlavedStore, __func__
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
-from synapse.replication.slave.storage.devices import SlavedDeviceStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.filtering import SlavedFilteringStore
-from synapse.replication.slave.storage.groups import SlavedGroupServerStore
-from synapse.replication.slave.storage.presence import SlavedPresenceStore
-from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.slave.storage.room import RoomStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams.events import EventsStreamEventRow
-from synapse.rest.client.v1 import events
-from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
-from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
-from synapse.rest.client.v2_alpha import sync
-from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
-from synapse.storage.presence import UserPresenceState
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, run_in_background
-from synapse.util.manhole import manhole
-from synapse.util.stringutils import random_string
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.synchrotron")
-
-
-class SynchrotronSlavedStore(
- SlavedReceiptsStore,
- SlavedAccountDataStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- SlavedFilteringStore,
- SlavedPresenceStore,
- SlavedGroupServerStore,
- SlavedDeviceInboxStore,
- SlavedDeviceStore,
- SlavedPushRuleStore,
- SlavedEventStore,
- SlavedClientIpStore,
- RoomStore,
- BaseSlavedStore,
-):
- pass
-
-
-UPDATE_SYNCING_USERS_MS = 10 * 1000
-
-
-class SynchrotronPresence(object):
- def __init__(self, hs):
- self.hs = hs
- self.is_mine_id = hs.is_mine_id
- self.http_client = hs.get_simple_http_client()
- self.store = hs.get_datastore()
- self.user_to_num_current_syncs = {}
- self.clock = hs.get_clock()
- self.notifier = hs.get_notifier()
-
- active_presence = self.store.take_presence_startup_info()
- self.user_to_current_state = {
- state.user_id: state
- for state in active_presence
- }
-
- # user_id -> last_sync_ms. Lists the users that have stopped syncing
- # but we haven't notified the master of that yet
- self.users_going_offline = {}
-
- self._send_stop_syncing_loop = self.clock.looping_call(
- self.send_stop_syncing, 10 * 1000
- )
-
- self.process_id = random_string(16)
- logger.info("Presence process_id is %r", self.process_id)
-
- def send_user_sync(self, user_id, is_syncing, last_sync_ms):
- if self.hs.config.use_presence:
- self.hs.get_tcp_replication().send_user_sync(
- user_id, is_syncing, last_sync_ms
- )
-
- def mark_as_coming_online(self, user_id):
- """A user has started syncing. Send a UserSync to the master, unless they
- had recently stopped syncing.
-
- Args:
- user_id (str)
- """
- going_offline = self.users_going_offline.pop(user_id, None)
- if not going_offline:
- # Safe to skip because we haven't yet told the master they were offline
- self.send_user_sync(user_id, True, self.clock.time_msec())
-
- def mark_as_going_offline(self, user_id):
- """A user has stopped syncing. We wait before notifying the master as
- its likely they'll come back soon. This allows us to avoid sending
- a stopped syncing immediately followed by a started syncing notification
- to the master
-
- Args:
- user_id (str)
- """
- self.users_going_offline[user_id] = self.clock.time_msec()
-
- def send_stop_syncing(self):
- """Check if there are any users who have stopped syncing a while ago
- and haven't come back yet. If there are poke the master about them.
- """
- now = self.clock.time_msec()
- for user_id, last_sync_ms in list(self.users_going_offline.items()):
- if now - last_sync_ms > 10 * 1000:
- self.users_going_offline.pop(user_id, None)
- self.send_user_sync(user_id, False, last_sync_ms)
-
- def set_state(self, user, state, ignore_status_msg=False):
- # TODO Hows this supposed to work?
- pass
-
- get_states = __func__(PresenceHandler.get_states)
- get_state = __func__(PresenceHandler.get_state)
- current_state_for_users = __func__(PresenceHandler.current_state_for_users)
-
- def user_syncing(self, user_id, affect_presence):
- if affect_presence:
- curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
- self.user_to_num_current_syncs[user_id] = curr_sync + 1
-
- # If we went from no in flight sync to some, notify replication
- if self.user_to_num_current_syncs[user_id] == 1:
- self.mark_as_coming_online(user_id)
-
- def _end():
- # We check that the user_id is in user_to_num_current_syncs because
- # user_to_num_current_syncs may have been cleared if we are
- # shutting down.
- if affect_presence and user_id in self.user_to_num_current_syncs:
- self.user_to_num_current_syncs[user_id] -= 1
-
- # If we went from one in flight sync to non, notify replication
- if self.user_to_num_current_syncs[user_id] == 0:
- self.mark_as_going_offline(user_id)
-
- @contextlib.contextmanager
- def _user_syncing():
- try:
- yield
- finally:
- _end()
-
- return defer.succeed(_user_syncing())
-
- @defer.inlineCallbacks
- def notify_from_replication(self, states, stream_id):
- parties = yield get_interested_parties(self.store, states)
- room_ids_to_states, users_to_states = parties
-
- self.notifier.on_new_event(
- "presence_key", stream_id, rooms=room_ids_to_states.keys(),
- users=users_to_states.keys()
- )
-
- @defer.inlineCallbacks
- def process_replication_rows(self, token, rows):
- states = [UserPresenceState(
- row.user_id, row.state, row.last_active_ts,
- row.last_federation_update_ts, row.last_user_sync_ts, row.status_msg,
- row.currently_active
- ) for row in rows]
-
- for state in states:
- self.user_to_current_state[state.user_id] = state
-
- stream_id = token
- yield self.notify_from_replication(states, stream_id)
-
- def get_currently_syncing_users(self):
- if self.hs.config.use_presence:
- return [
- user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
- if count > 0
- ]
- else:
- return set()
-
-class SynchrotronTyping(object):
- def __init__(self, hs):
- self._latest_room_serial = 0
- self._reset()
-
- def _reset(self):
- """
- Reset the typing handler's data caches.
- """
- # map room IDs to serial numbers
- self._room_serials = {}
- # map room IDs to sets of users currently typing
- self._room_typing = {}
-
- def stream_positions(self):
- # We must update this typing token from the response of the previous
- # sync. In particular, the stream id may "reset" back to zero/a low
- # value which we *must* use for the next replication request.
- return {"typing": self._latest_room_serial}
-
- def process_replication_rows(self, token, rows):
- if self._latest_room_serial > token:
- # The master has gone backwards. To prevent inconsistent data, just
- # clear everything.
- self._reset()
-
- # Set the latest serial token to whatever the server gave us.
- self._latest_room_serial = token
-
- for row in rows:
- self._room_serials[row.room_id] = token
- self._room_typing[row.room_id] = row.user_ids
-
-
-class SynchrotronApplicationService(object):
- def notify_interested_services(self, event):
- pass
-
-
-class SynchrotronServer(HomeServer):
- DATASTORE_CLASS = SynchrotronSlavedStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "client":
- resource = JsonResource(self, canonical_json=False)
- sync.register_servlets(self, resource)
- events.register_servlets(self, resource)
- InitialSyncRestServlet(self).register(resource)
- RoomInitialSyncRestServlet(self).register(resource)
- resources.update({
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- })
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- )
- )
-
- logger.info("Synapse synchrotron now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
- else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
- else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return SyncReplicationHandler(self)
-
- def build_presence_handler(self):
- return SynchrotronPresence(self)
-
- def build_typing_handler(self):
- return SynchrotronTyping(self)
-
-
-class SyncReplicationHandler(ReplicationClientHandler):
- def __init__(self, hs):
- super(SyncReplicationHandler, self).__init__(hs.get_datastore())
-
- self.store = hs.get_datastore()
- self.typing_handler = hs.get_typing_handler()
- # NB this is a SynchrotronPresence, not a normal PresenceHandler
- self.presence_handler = hs.get_presence_handler()
- self.notifier = hs.get_notifier()
-
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
- yield super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
- run_in_background(self.process_and_notify, stream_name, token, rows)
-
- def get_streams_to_replicate(self):
- args = super(SyncReplicationHandler, self).get_streams_to_replicate()
- args.update(self.typing_handler.stream_positions())
- return args
-
- def get_currently_syncing_users(self):
- return self.presence_handler.get_currently_syncing_users()
-
- @defer.inlineCallbacks
- def process_and_notify(self, stream_name, token, rows):
- try:
- if stream_name == "events":
- # We shouldn't get multiple rows per token for events stream, so
- # we don't need to optimise this for multiple rows.
- for row in rows:
- if row.type != EventsStreamEventRow.TypeId:
- continue
- event = yield self.store.get_event(row.data.event_id)
- extra_users = ()
- if event.type == EventTypes.Member:
- extra_users = (event.state_key,)
- max_token = self.store.get_room_max_stream_ordering()
- self.notifier.on_new_room_event(
- event, token, max_token, extra_users
- )
- elif stream_name == "push_rules":
- self.notifier.on_new_event(
- "push_rules_key", token, users=[row.user_id for row in rows],
- )
- elif stream_name in ("account_data", "tag_account_data",):
- self.notifier.on_new_event(
- "account_data_key", token, users=[row.user_id for row in rows],
- )
- elif stream_name == "receipts":
- self.notifier.on_new_event(
- "receipt_key", token, rooms=[row.room_id for row in rows],
- )
- elif stream_name == "typing":
- self.typing_handler.process_replication_rows(token, rows)
- self.notifier.on_new_event(
- "typing_key", token, rooms=[row.room_id for row in rows],
- )
- elif stream_name == "to_device":
- entities = [row.entity for row in rows if row.entity.startswith("@")]
- if entities:
- self.notifier.on_new_event(
- "to_device_key", token, users=entities,
- )
- elif stream_name == "device_lists":
- all_room_ids = set()
- for row in rows:
- room_ids = yield self.store.get_rooms_for_user(row.user_id)
- all_room_ids.update(room_ids)
- self.notifier.on_new_event(
- "device_list_key", token, rooms=all_room_ids,
- )
- elif stream_name == "presence":
- yield self.presence_handler.process_replication_rows(token, rows)
- elif stream_name == "receipts":
- self.notifier.on_new_event(
- "groups_key", token, users=[row.user_id for row in rows],
- )
- except Exception:
- logger.exception("Error processing replication")
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse synchrotron", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.synchrotron"
-
- setup_logging(config, use_worker_options=True)
-
- synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- database_engine = create_engine(config.database_config)
-
- ss = SynchrotronServer(
- config.server_name,
- db_config=config.database_config,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
- application_service_handler=SynchrotronApplicationService(),
- )
-
- ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
-
- _base.start_worker_reactor("synapse-synchrotron", config)
+import sys
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 355f5aa71d..503d44f687 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -14,219 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
import sys
-from twisted.internet import defer, reactor
-from twisted.web.resource import NoResource
+from synapse.app.generic_worker import start
+from synapse.util.logcontext import LoggingContext
-import synapse
-from synapse import events
-from synapse.app import _base
-from synapse.config._base import ConfigError
-from synapse.config.homeserver import HomeServerConfig
-from synapse.config.logger import setup_logging
-from synapse.http.server import JsonResource
-from synapse.http.site import SynapseSite
-from synapse.metrics import RegistryProxy
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
-from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
-from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.replication.slave.storage.registration import SlavedRegistrationStore
-from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.replication.tcp.streams.events import (
- EventsStream,
- EventsStreamCurrentStateRow,
-)
-from synapse.rest.client.v2_alpha import user_directory
-from synapse.server import HomeServer
-from synapse.storage.engines import create_engine
-from synapse.storage.user_directory import UserDirectoryStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, run_in_background
-from synapse.util.manhole import manhole
-from synapse.util.versionstring import get_version_string
-
-logger = logging.getLogger("synapse.app.user_dir")
-
-
-class UserDirectorySlaveStore(
- SlavedEventStore,
- SlavedApplicationServiceStore,
- SlavedRegistrationStore,
- SlavedClientIpStore,
- UserDirectoryStore,
- BaseSlavedStore,
-):
- def __init__(self, db_conn, hs):
- super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
-
- events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
- db_conn, "current_state_delta_stream",
- entity_column="room_id",
- stream_column="stream_id",
- max_value=events_max, # As we share the stream id with events token
- limit=1000,
- )
- self._curr_state_delta_stream_cache = StreamChangeCache(
- "_curr_state_delta_stream_cache", min_curr_state_delta_id,
- prefilled_cache=curr_state_delta_prefill,
- )
-
- def stream_positions(self):
- result = super(UserDirectorySlaveStore, self).stream_positions()
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
- if stream_name == EventsStream.NAME:
- self._stream_id_gen.advance(token)
- for row in rows:
- if row.type != EventsStreamCurrentStateRow.TypeId:
- continue
- self._curr_state_delta_stream_cache.entity_has_changed(
- row.data.room_id, token
- )
- return super(UserDirectorySlaveStore, self).process_replication_rows(
- stream_name, token, rows
- )
-
-
-class UserDirectoryServer(HomeServer):
- DATASTORE_CLASS = UserDirectorySlaveStore
-
- def _listen_http(self, listener_config):
- port = listener_config["port"]
- bind_addresses = listener_config["bind_addresses"]
- site_tag = listener_config.get("tag", port)
- resources = {}
- for res in listener_config["resources"]:
- for name in res["names"]:
- if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
- elif name == "client":
- resource = JsonResource(self, canonical_json=False)
- user_directory.register_servlets(self, resource)
- resources.update({
- "/_matrix/client/r0": resource,
- "/_matrix/client/unstable": resource,
- "/_matrix/client/v2_alpha": resource,
- "/_matrix/client/api/v1": resource,
- })
-
- root_resource = create_resource_tree(resources, NoResource())
-
- _base.listen_tcp(
- bind_addresses,
- port,
- SynapseSite(
- "synapse.access.http.%s" % (site_tag,),
- site_tag,
- listener_config,
- root_resource,
- self.version_string,
- )
- )
-
- logger.info("Synapse user_dir now listening on port %d", port)
-
- def start_listening(self, listeners):
- for listener in listeners:
- if listener["type"] == "http":
- self._listen_http(listener)
- elif listener["type"] == "manhole":
- _base.listen_tcp(
- listener["bind_addresses"],
- listener["port"],
- manhole(
- username="matrix",
- password="rabbithole",
- globals={"hs": self},
- )
- )
- elif listener["type"] == "metrics":
- if not self.get_config().enable_metrics:
- logger.warn(("Metrics listener configured, but "
- "enable_metrics is not True!"))
- else:
- _base.listen_metrics(listener["bind_addresses"],
- listener["port"])
- else:
- logger.warn("Unrecognized listener type: %s", listener["type"])
-
- self.get_tcp_replication().start_replication(self)
-
- def build_tcp_replication(self):
- return UserDirectoryReplicationHandler(self)
-
-
-class UserDirectoryReplicationHandler(ReplicationClientHandler):
- def __init__(self, hs):
- super(UserDirectoryReplicationHandler, self).__init__(hs.get_datastore())
- self.user_directory = hs.get_user_directory_handler()
-
- @defer.inlineCallbacks
- def on_rdata(self, stream_name, token, rows):
- yield super(UserDirectoryReplicationHandler, self).on_rdata(
- stream_name, token, rows
- )
- if stream_name == EventsStream.NAME:
- run_in_background(self._notify_directory)
-
- @defer.inlineCallbacks
- def _notify_directory(self):
- try:
- yield self.user_directory.notify_new_event()
- except Exception:
- logger.exception("Error notifiying user directory of state update")
-
-
-def start(config_options):
- try:
- config = HomeServerConfig.load_config(
- "Synapse user directory", config_options
- )
- except ConfigError as e:
- sys.stderr.write("\n" + str(e) + "\n")
- sys.exit(1)
-
- assert config.worker_app == "synapse.app.user_dir"
-
- setup_logging(config, use_worker_options=True)
-
- events.USE_FROZEN_DICTS = config.use_frozen_dicts
-
- database_engine = create_engine(config.database_config)
-
- if config.update_user_directory:
- sys.stderr.write(
- "\nThe update_user_directory must be disabled in the main synapse process"
- "\nbefore they can be run in a separate worker."
- "\nPlease add ``update_user_directory: false`` to the main config"
- "\n"
- )
- sys.exit(1)
-
- # Force the pushers to start since they will be disabled in the main config
- config.update_user_directory = True
-
- ss = UserDirectoryServer(
- config.server_name,
- db_config=config.database_config,
- config=config,
- version_string="Synapse/" + get_version_string(synapse),
- database_engine=database_engine,
- )
-
- ss.setup()
- reactor.callWhenRunning(_base.start, ss, config.worker_listeners)
-
- _base.start_worker_reactor("synapse-user-dir", config)
-
-
-if __name__ == '__main__':
+if __name__ == "__main__":
with LoggingContext("main"):
start(sys.argv[1:])
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index c58f83d268..1b13e84425 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -48,9 +48,7 @@ class AppServiceTransaction(object):
A Deferred which resolves to True if the transaction was sent.
"""
return as_api.push_bulk(
- service=self.service,
- events=self.events,
- txn_id=self.id
+ service=self.service, events=self.events, txn_id=self.id
)
def complete(self, store):
@@ -64,10 +62,7 @@ class AppServiceTransaction(object):
Returns:
A Deferred which resolves to True if the transaction was completed.
"""
- return store.complete_appservice_txn(
- service=self.service,
- txn_id=self.id
- )
+ return store.complete_appservice_txn(service=self.service, txn_id=self.id)
class ApplicationService(object):
@@ -76,6 +71,7 @@ class ApplicationService(object):
Provides methods to check if this service is "interested" in events.
"""
+
NS_USERS = "users"
NS_ALIASES = "aliases"
NS_ROOMS = "rooms"
@@ -84,11 +80,23 @@ class ApplicationService(object):
# values.
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
- def __init__(self, token, hostname, url=None, namespaces=None, hs_token=None,
- sender=None, id=None, protocols=None, rate_limited=True,
- ip_range_whitelist=None):
+ def __init__(
+ self,
+ token,
+ hostname,
+ url=None,
+ namespaces=None,
+ hs_token=None,
+ sender=None,
+ id=None,
+ protocols=None,
+ rate_limited=True,
+ ip_range_whitelist=None,
+ ):
self.token = token
- self.url = url
+ self.url = (
+ url.rstrip("/") if isinstance(url, str) else None
+ ) # url must not end with a slash
self.hs_token = hs_token
self.sender = sender
self.server_name = hostname
@@ -128,9 +136,7 @@ class ApplicationService(object):
if not isinstance(regex_obj, dict):
raise ValueError("Expected dict regex for ns '%s'" % ns)
if not isinstance(regex_obj.get("exclusive"), bool):
- raise ValueError(
- "Expected bool for 'exclusive' in ns '%s'" % ns
- )
+ raise ValueError("Expected bool for 'exclusive' in ns '%s'" % ns)
group_id = regex_obj.get("group_id")
if group_id:
if not isinstance(group_id, str):
@@ -153,9 +159,7 @@ class ApplicationService(object):
if isinstance(regex, string_types):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
- raise ValueError(
- "Expected string for 'regex' in ns '%s'" % ns
- )
+ raise ValueError("Expected string for 'regex' in ns '%s'" % ns)
return namespaces
def _matches_regex(self, test_string, namespace_key):
@@ -173,20 +177,21 @@ class ApplicationService(object):
@defer.inlineCallbacks
def _matches_user(self, event, store):
if not event:
- defer.returnValue(False)
+ return False
if self.is_interested_in_user(event.sender):
- defer.returnValue(True)
+ return True
# also check m.room.member state key
- if (event.type == EventTypes.Member and
- self.is_interested_in_user(event.state_key)):
- defer.returnValue(True)
+ if event.type == EventTypes.Member and self.is_interested_in_user(
+ event.state_key
+ ):
+ return True
if not store:
- defer.returnValue(False)
+ return False
does_match = yield self._matches_user_in_member_list(event.room_id, store)
- defer.returnValue(does_match)
+ return does_match
@cachedInlineCallbacks(num_args=1, cache_context=True)
def _matches_user_in_member_list(self, room_id, store, cache_context):
@@ -197,8 +202,8 @@ class ApplicationService(object):
# check joined member events
for user_id in member_list:
if self.is_interested_in_user(user_id):
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
def _matches_room_id(self, event):
if hasattr(event, "room_id"):
@@ -208,13 +213,13 @@ class ApplicationService(object):
@defer.inlineCallbacks
def _matches_aliases(self, event, store):
if not store or not event:
- defer.returnValue(False)
+ return False
alias_list = yield store.get_aliases_for_room(event.room_id)
for alias in alias_list:
if self.is_interested_in_alias(alias):
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
@defer.inlineCallbacks
def is_interested(self, event, store=None):
@@ -228,15 +233,15 @@ class ApplicationService(object):
"""
# Do cheap checks first
if self._matches_room_id(event):
- defer.returnValue(True)
+ return True
if (yield self._matches_aliases(event, store)):
- defer.returnValue(True)
+ return True
if (yield self._matches_user(event, store)):
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
def is_interested_in_user(self, user_id):
return (
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 9ccc5a80fc..57174da021 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -32,19 +32,17 @@ logger = logging.getLogger(__name__)
sent_transactions_counter = Counter(
"synapse_appservice_api_sent_transactions",
"Number of /transactions/ requests sent",
- ["service"]
+ ["service"],
)
failed_transactions_counter = Counter(
"synapse_appservice_api_failed_transactions",
"Number of /transactions/ requests that failed to send",
- ["service"]
+ ["service"],
)
sent_events_counter = Counter(
- "synapse_appservice_api_sent_events",
- "Number of events sent to the AS",
- ["service"]
+ "synapse_appservice_api_sent_events", "Number of events sent to the AS", ["service"]
)
HOUR_IN_MS = 60 * 60 * 1000
@@ -92,50 +90,45 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
- self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta",
- timeout_ms=HOUR_IN_MS)
+ self.protocol_meta_cache = ResponseCache(
+ hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
+ )
@defer.inlineCallbacks
def query_user(self, service, user_id):
if service.url is None:
- defer.returnValue(False)
+ return False
uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
response = None
try:
- response = yield self.get_json(uri, {
- "access_token": service.hs_token
- })
+ response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
- defer.returnValue(True)
+ return True
except CodeMessageException as e:
if e.code == 404:
- defer.returnValue(False)
- return
+ return False
logger.warning("query_user to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("query_user to %s threw exception %s", uri, ex)
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def query_alias(self, service, alias):
if service.url is None:
- defer.returnValue(False)
+ return False
uri = service.url + ("/rooms/%s" % urllib.parse.quote(alias))
response = None
try:
- response = yield self.get_json(uri, {
- "access_token": service.hs_token
- })
+ response = yield self.get_json(uri, {"access_token": service.hs_token})
if response is not None: # just an empty json object
- defer.returnValue(True)
+ return True
except CodeMessageException as e:
logger.warning("query_alias to %s received %s", uri, e.code)
if e.code == 404:
- defer.returnValue(False)
- return
+ return False
except Exception as ex:
logger.warning("query_alias to %s threw exception %s", uri, ex)
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def query_3pe(self, service, kind, protocol, fields):
@@ -144,26 +137,23 @@ class ApplicationServiceApi(SimpleHttpClient):
elif kind == ThirdPartyEntityKind.LOCATION:
required_field = "alias"
else:
- raise ValueError(
- "Unrecognised 'kind' argument %r to query_3pe()", kind
- )
+ raise ValueError("Unrecognised 'kind' argument %r to query_3pe()", kind)
if service.url is None:
- defer.returnValue([])
+ return []
uri = "%s%s/thirdparty/%s/%s" % (
service.url,
APP_SERVICE_PREFIX,
kind,
- urllib.parse.quote(protocol)
+ urllib.parse.quote(protocol),
)
try:
response = yield self.get_json(uri, fields)
if not isinstance(response, list):
logger.warning(
- "query_3pe to %s returned an invalid response %r",
- uri, response
+ "query_3pe to %s returned an invalid response %r", uri, response
)
- defer.returnValue([])
+ return []
ret = []
for r in response:
@@ -171,46 +161,45 @@ class ApplicationServiceApi(SimpleHttpClient):
ret.append(r)
else:
logger.warning(
- "query_3pe to %s returned an invalid result %r",
- uri, r
+ "query_3pe to %s returned an invalid result %r", uri, r
)
- defer.returnValue(ret)
+ return ret
except Exception as ex:
logger.warning("query_3pe to %s threw exception %s", uri, ex)
- defer.returnValue([])
+ return []
def get_3pe_protocol(self, service, protocol):
if service.url is None:
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def _get():
uri = "%s%s/thirdparty/protocol/%s" % (
service.url,
APP_SERVICE_PREFIX,
- urllib.parse.quote(protocol)
+ urllib.parse.quote(protocol),
)
try:
info = yield self.get_json(uri, {})
if not _is_valid_3pe_metadata(info):
- logger.warning("query_3pe_protocol to %s did not return a"
- " valid result", uri)
- defer.returnValue(None)
+ logger.warning(
+ "query_3pe_protocol to %s did not return a valid result", uri
+ )
+ return None
for instance in info.get("instances", []):
network_id = instance.get("network_id", None)
if network_id is not None:
instance["instance_id"] = ThirdPartyInstanceID(
- service.id, network_id,
+ service.id, network_id
).to_string()
- defer.returnValue(info)
+ return info
except Exception as ex:
- logger.warning("query_3pe_protocol to %s threw exception %s",
- uri, ex)
- defer.returnValue(None)
+ logger.warning("query_3pe_protocol to %s threw exception %s", uri, ex)
+ return None
key = (service.id, protocol)
return self.protocol_meta_cache.wrap(key, _get)
@@ -218,40 +207,34 @@ class ApplicationServiceApi(SimpleHttpClient):
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
if service.url is None:
- defer.returnValue(True)
+ return True
events = self._serialize(events)
if txn_id is None:
- logger.warning("push_bulk: Missing txn ID sending events to %s",
- service.url)
+ logger.warning(
+ "push_bulk: Missing txn ID sending events to %s", service.url
+ )
txn_id = str(0)
txn_id = str(txn_id)
- uri = service.url + ("/transactions/%s" %
- urllib.parse.quote(txn_id))
+ uri = service.url + ("/transactions/%s" % urllib.parse.quote(txn_id))
try:
yield self.put_json(
uri=uri,
- json_body={
- "events": events
- },
- args={
- "access_token": service.hs_token
- })
+ json_body={"events": events},
+ args={"access_token": service.hs_token},
+ )
sent_transactions_counter.labels(service.id).inc()
sent_events_counter.labels(service.id).inc(len(events))
- defer.returnValue(True)
- return
+ return True
except CodeMessageException as e:
logger.warning("push_bulk to %s received %s", uri, e.code)
except Exception as ex:
logger.warning("push_bulk to %s threw exception %s", uri, ex)
failed_transactions_counter.labels(service.id).inc()
- defer.returnValue(False)
+ return False
def _serialize(self, events):
time_now = self.clock.time_msec()
- return [
- serialize_event(e, time_now, as_client_event=True) for e in events
- ]
+ return [serialize_event(e, time_now, as_client_event=True) for e in events]
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 685f15c061..9998f822f1 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -53,8 +53,8 @@ import logging
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.logcontext import run_in_background
logger = logging.getLogger(__name__)
@@ -70,35 +70,37 @@ class ApplicationServiceScheduler(object):
self.store = hs.get_datastore()
self.as_api = hs.get_application_service_api()
- def create_recoverer(service, callback):
- return _Recoverer(self.clock, self.store, self.as_api, service, callback)
-
- self.txn_ctrl = _TransactionController(
- self.clock, self.store, self.as_api, create_recoverer
- )
+ self.txn_ctrl = _TransactionController(self.clock, self.store, self.as_api)
self.queuer = _ServiceQueuer(self.txn_ctrl, self.clock)
@defer.inlineCallbacks
def start(self):
logger.info("Starting appservice scheduler")
+
# check for any DOWN ASes and start recoverers for them.
- recoverers = yield _Recoverer.start(
- self.clock, self.store, self.as_api, self.txn_ctrl.on_recovered
+ services = yield self.store.get_appservices_by_state(
+ ApplicationServiceState.DOWN
)
- self.txn_ctrl.add_recoverers(recoverers)
+
+ for service in services:
+ self.txn_ctrl.start_recoverer(service)
def submit_event_for_as(self, service, event):
self.queuer.enqueue(service, event)
class _ServiceQueuer(object):
- """Queues events for the same application service together, sending
- transactions as soon as possible. Once a transaction is sent successfully,
- this schedules any other events in the queue to run.
+ """Queue of events waiting to be sent to appservices.
+
+ Groups events into transactions per-appservice, and sends them on to the
+ TransactionController. Makes sure that we only have one transaction in flight per
+ appservice at a given time.
"""
def __init__(self, txn_ctrl, clock):
self.queued_events = {} # dict of {service_id: [events]}
+
+ # the appservices which currently have a transaction in flight
self.requests_in_flight = set()
self.txn_ctrl = txn_ctrl
self.clock = clock
@@ -112,15 +114,14 @@ class _ServiceQueuer(object):
return
run_as_background_process(
- "as-sender-%s" % (service.id, ),
- self._send_request, service,
+ "as-sender-%s" % (service.id,), self._send_request, service
)
@defer.inlineCallbacks
def _send_request(self, service):
# sanity-check: we shouldn't get here if this service already has a sender
# running.
- assert(service.id not in self.requests_in_flight)
+ assert service.id not in self.requests_in_flight
self.requests_in_flight.add(service.id)
try:
@@ -137,89 +138,97 @@ class _ServiceQueuer(object):
class _TransactionController(object):
+ """Transaction manager.
+
+ Builds AppServiceTransactions and runs their lifecycle. Also starts a Recoverer
+ if a transaction fails.
- def __init__(self, clock, store, as_api, recoverer_fn):
+ (Note we have only have one of these in the homeserver.)
+
+ Args:
+ clock (synapse.util.Clock):
+ store (synapse.storage.DataStore):
+ as_api (synapse.appservice.api.ApplicationServiceApi):
+ """
+
+ def __init__(self, clock, store, as_api):
self.clock = clock
self.store = store
self.as_api = as_api
- self.recoverer_fn = recoverer_fn
- # keep track of how many recoverers there are
- self.recoverers = []
+
+ # map from service id to recoverer instance
+ self.recoverers = {}
+
+ # for UTs
+ self.RECOVERER_CLASS = _Recoverer
@defer.inlineCallbacks
def send(self, service, events):
try:
- txn = yield self.store.create_appservice_txn(
- service=service,
- events=events
- )
+ txn = yield self.store.create_appservice_txn(service=service, events=events)
service_is_up = yield self._is_service_up(service)
if service_is_up:
sent = yield txn.send(self.as_api)
if sent:
yield txn.complete(self.store)
else:
- run_in_background(self._start_recoverer, service)
+ run_in_background(self._on_txn_fail, service)
except Exception:
logger.exception("Error creating appservice transaction")
- run_in_background(self._start_recoverer, service)
+ run_in_background(self._on_txn_fail, service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):
- self.recoverers.remove(recoverer)
- logger.info("Successfully recovered application service AS ID %s",
- recoverer.service.id)
+ logger.info(
+ "Successfully recovered application service AS ID %s", recoverer.service.id
+ )
+ self.recoverers.pop(recoverer.service.id)
logger.info("Remaining active recoverers: %s", len(self.recoverers))
yield self.store.set_appservice_state(
- recoverer.service,
- ApplicationServiceState.UP
+ recoverer.service, ApplicationServiceState.UP
)
- def add_recoverers(self, recoverers):
- for r in recoverers:
- self.recoverers.append(r)
- if len(recoverers) > 0:
- logger.info("New active recoverers: %s", len(self.recoverers))
-
@defer.inlineCallbacks
- def _start_recoverer(self, service):
+ def _on_txn_fail(self, service):
try:
- yield self.store.set_appservice_state(
- service,
- ApplicationServiceState.DOWN
- )
- logger.info(
- "Application service falling behind. Starting recoverer. AS ID %s",
- service.id
- )
- recoverer = self.recoverer_fn(service, self.on_recovered)
- self.add_recoverers([recoverer])
- recoverer.recover()
+ yield self.store.set_appservice_state(service, ApplicationServiceState.DOWN)
+ self.start_recoverer(service)
except Exception:
logger.exception("Error starting AS recoverer")
+ def start_recoverer(self, service):
+ """Start a Recoverer for the given service
+
+ Args:
+ service (synapse.appservice.ApplicationService):
+ """
+ logger.info("Starting recoverer for AS ID %s", service.id)
+ assert service.id not in self.recoverers
+ recoverer = self.RECOVERER_CLASS(
+ self.clock, self.store, self.as_api, service, self.on_recovered
+ )
+ self.recoverers[service.id] = recoverer
+ recoverer.recover()
+ logger.info("Now %i active recoverers", len(self.recoverers))
+
@defer.inlineCallbacks
def _is_service_up(self, service):
state = yield self.store.get_appservice_state(service)
- defer.returnValue(state == ApplicationServiceState.UP or state is None)
+ return state == ApplicationServiceState.UP or state is None
class _Recoverer(object):
+ """Manages retries and backoff for a DOWN appservice.
- @staticmethod
- @defer.inlineCallbacks
- def start(clock, store, as_api, callback):
- services = yield store.get_appservices_by_state(
- ApplicationServiceState.DOWN
- )
- recoverers = [
- _Recoverer(clock, store, as_api, s, callback) for s in services
- ]
- for r in recoverers:
- logger.info("Starting recoverer for AS ID %s which was marked as "
- "DOWN", r.service.id)
- r.recover()
- defer.returnValue(recoverers)
+ We have one of these for each appservice which is currently considered DOWN.
+
+ Args:
+ clock (synapse.util.Clock):
+ store (synapse.storage.DataStore):
+ as_api (synapse.appservice.api.ApplicationServiceApi):
+ service (synapse.appservice.ApplicationService): the service we are managing
+ callback (callable[_Recoverer]): called once the service recovers.
+ """
def __init__(self, clock, store, as_api, service, callback):
self.clock = clock
@@ -232,10 +241,12 @@ class _Recoverer(object):
def recover(self):
def _retry():
run_as_background_process(
- "as-recoverer-%s" % (self.service.id,),
- self.retry,
+ "as-recoverer-%s" % (self.service.id,), self.retry
)
- self.clock.call_later((2 ** self.backoff_counter), _retry)
+
+ delay = 2 ** self.backoff_counter
+ logger.info("Scheduling retries on %s in %fs", self.service.id, delay)
+ self.clock.call_later(delay, _retry)
def _backoff(self):
# cap the backoff to be around 8.5min => (2^9) = 512 secs
@@ -245,24 +256,30 @@ class _Recoverer(object):
@defer.inlineCallbacks
def retry(self):
+ logger.info("Starting retries on %s", self.service.id)
try:
- txn = yield self.store.get_oldest_unsent_txn(self.service)
- if txn:
- logger.info("Retrying transaction %s for AS ID %s",
- txn.id, txn.service.id)
+ while True:
+ txn = yield self.store.get_oldest_unsent_txn(self.service)
+ if not txn:
+ # nothing left: we're done!
+ self.callback(self)
+ return
+
+ logger.info(
+ "Retrying transaction %s for AS ID %s", txn.id, txn.service.id
+ )
sent = yield txn.send(self.as_api)
- if sent:
- yield txn.complete(self.store)
- # reset the backoff counter and retry immediately
- self.backoff_counter = 1
- yield self.retry()
- else:
- self._backoff()
- else:
- self._set_service_recovered()
- except Exception as e:
- logger.exception(e)
- self._backoff()
-
- def _set_service_recovered(self):
- self.callback(self)
+ if not sent:
+ break
+
+ yield txn.complete(self.store)
+
+ # reset the backoff counter and then process the next transaction
+ self.backoff_counter = 1
+
+ except Exception:
+ logger.exception("Unexpected error running retries")
+
+ # we didn't manage to send all of the transactions before we got an error of
+ # some flavour: reschedule the next retry.
+ self._backoff()
diff --git a/synapse/config/__init__.py b/synapse/config/__init__.py
index f2a5a41e92..1e76e9559d 100644
--- a/synapse/config/__init__.py
+++ b/synapse/config/__init__.py
@@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import ConfigError
+from ._base import ConfigError, find_config_files
-# export ConfigError if somebody does import *
+# export ConfigError and find_config_files if somebody does
+# import *
# this is largely a fudge to stop PEP8 moaning about the import
-__all__ = ["ConfigError"]
+__all__ = ["ConfigError", "find_config_files"]
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index bf039e5823..132e48447c 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,8 +18,10 @@
import argparse
import errno
import os
+from collections import OrderedDict
from io import open as io_open
from textwrap import dedent
+from typing import Any, MutableMapping, Optional
from six import integer_types
@@ -50,7 +54,68 @@ Missing mandatory `server_name` config option.
"""
+CONFIG_FILE_HEADER = """\
+# Configuration file for Synapse.
+#
+# This is a YAML file: see [1] for a quick introduction. Note in particular
+# that *indentation is important*: all the elements of a list or dictionary
+# should have the same indentation.
+#
+# [1] https://docs.ansible.com/ansible/latest/reference_appendices/YAMLSyntax.html
+
+"""
+
+
+def path_exists(file_path):
+ """Check if a file exists
+
+ Unlike os.path.exists, this throws an exception if there is an error
+ checking if the file exists (for example, if there is a perms error on
+ the parent dir).
+
+ Returns:
+ bool: True if the file exists; False if not.
+ """
+ try:
+ os.stat(file_path)
+ return True
+ except OSError as e:
+ if e.errno != errno.ENOENT:
+ raise e
+ return False
+
+
class Config(object):
+ """
+ A configuration section, containing configuration keys and values.
+
+ Attributes:
+ section (str): The section title of this config object, such as
+ "tls" or "logger". This is used to refer to it on the root
+ logger (for example, `config.tls.some_option`). Must be
+ defined in subclasses.
+ """
+
+ section = None
+
+ def __init__(self, root_config=None):
+ self.root = root_config
+
+ def __getattr__(self, item: str) -> Any:
+ """
+ Try and fetch a configuration option that does not exist on this class.
+
+ This is so that existing configs that rely on `self.value`, where value
+ is actually from a different config section, continue to work.
+ """
+ if item in ["generate_config_section", "read_config"]:
+ raise AttributeError(item)
+
+ if self.root is None:
+ raise AttributeError(item)
+ else:
+ return self.root._get_unclassed_config(self.section, item)
+
@staticmethod
def parse_size(value):
if isinstance(value, integer_types):
@@ -87,22 +152,7 @@ class Config(object):
@classmethod
def path_exists(cls, file_path):
- """Check if a file exists
-
- Unlike os.path.exists, this throws an exception if there is an error
- checking if the file exists (for example, if there is a perms error on
- the parent dir).
-
- Returns:
- bool: True if the file exists; False if not.
- """
- try:
- os.stat(file_path)
- return True
- except OSError as e:
- if e.errno != errno.ENOENT:
- raise e
- return False
+ return path_exists(file_path)
@classmethod
def check_file(cls, file_path, config_name):
@@ -135,17 +185,106 @@ class Config(object):
with io_open(file_path, encoding="utf-8") as file_stream:
return file_stream.read()
- @staticmethod
- def read_config_file(file_path):
- with open(file_path) as file_stream:
- return yaml.safe_load(file_stream)
- def invoke_all(self, name, *args, **kargs):
- results = []
- for cls in type(self).mro():
- if name in cls.__dict__:
- results.append(getattr(cls, name)(self, *args, **kargs))
- return results
+class RootConfig(object):
+ """
+ Holder of an application's configuration.
+
+ What configuration this object holds is defined by `config_classes`, a list
+ of Config classes that will be instantiated and given the contents of a
+ configuration file to read. They can then be accessed on this class by their
+ section name, defined in the Config or dynamically set to be the name of the
+ class, lower-cased and with "Config" removed.
+ """
+
+ config_classes = []
+
+ def __init__(self):
+ self._configs = OrderedDict()
+
+ for config_class in self.config_classes:
+ if config_class.section is None:
+ raise ValueError("%r requires a section name" % (config_class,))
+
+ try:
+ conf = config_class(self)
+ except Exception as e:
+ raise Exception("Failed making %s: %r" % (config_class.section, e))
+ self._configs[config_class.section] = conf
+
+ def __getattr__(self, item: str) -> Any:
+ """
+ Redirect lookups on this object either to config objects, or values on
+ config objects, so that `config.tls.blah` works, as well as legacy uses
+ of things like `config.server_name`. It will first look up the config
+ section name, and then values on those config classes.
+ """
+ if item in self._configs.keys():
+ return self._configs[item]
+
+ return self._get_unclassed_config(None, item)
+
+ def _get_unclassed_config(self, asking_section: Optional[str], item: str):
+ """
+ Fetch a config value from one of the instantiated config classes that
+ has not been fetched directly.
+
+ Args:
+ asking_section: If this check is coming from a Config child, which
+ one? This section will not be asked if it has the value.
+ item: The configuration value key.
+
+ Raises:
+ AttributeError if no config classes have the config key. The body
+ will contain what sections were checked.
+ """
+ for key, val in self._configs.items():
+ if key == asking_section:
+ continue
+
+ if item in dir(val):
+ return getattr(val, item)
+
+ raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
+
+ def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
+ """
+ Invoke a function on all instantiated config objects this RootConfig is
+ configured to use.
+
+ Args:
+ func_name: Name of function to invoke
+ *args
+ **kwargs
+ Returns:
+ ordered dictionary of config section name and the result of the
+ function from it.
+ """
+ res = OrderedDict()
+
+ for name, config in self._configs.items():
+ if hasattr(config, func_name):
+ res[name] = getattr(config, func_name)(*args, **kwargs)
+
+ return res
+
+ @classmethod
+ def invoke_all_static(cls, func_name: str, *args, **kwargs):
+ """
+ Invoke a static function on config objects this RootConfig is
+ configured to use.
+
+ Args:
+ func_name: Name of function to invoke
+ *args
+ **kwargs
+ Returns:
+ ordered dictionary of config section name and the result of the
+ function from it.
+ """
+ for config in cls.config_classes:
+ if hasattr(config, func_name):
+ getattr(config, func_name)(*args, **kwargs)
def generate_config(
self,
@@ -154,12 +293,18 @@ class Config(object):
server_name,
generate_secrets=False,
report_stats=None,
+ open_private_ports=False,
+ listeners=None,
+ database_conf=None,
+ tls_certificate_path=None,
+ tls_private_key_path=None,
+ acme_domain=None,
):
- """Build a default configuration file
+ """
+ Build a default configuration file
- This is used both when the user explicitly asks us to generate a config file
- (eg with --generate_config), and before loading the config at runtime (to give
- a base which the config files override)
+ This is used when the user explicitly asks us to generate a config file
+ (eg with --generate_config).
Args:
config_dir_path (str): The path where the config files are kept. Used to
@@ -178,26 +323,84 @@ class Config(object):
report_stats (bool|None): Initial setting for the report_stats setting.
If None, report_stats will be left unset.
+ open_private_ports (bool): True to leave private ports (such as the non-TLS
+ HTTP listener) open to the internet.
+
+ listeners (list(dict)|None): A list of descriptions of the listeners
+ synapse should start with each of which specifies a port (str), a list of
+ resources (list(str)), tls (bool) and type (str). For example:
+ [{
+ "port": 8448,
+ "resources": [{"names": ["federation"]}],
+ "tls": True,
+ "type": "http",
+ },
+ {
+ "port": 443,
+ "resources": [{"names": ["client"]}],
+ "tls": False,
+ "type": "http",
+ }],
+
+
+ database (str|None): The database type to configure, either `psycog2`
+ or `sqlite3`.
+
+ tls_certificate_path (str|None): The path to the tls certificate.
+
+ tls_private_key_path (str|None): The path to the tls private key.
+
+ acme_domain (str|None): The domain acme will try to validate. If
+ specified acme will be enabled.
+
Returns:
str: the yaml config file
"""
- default_config = "\n\n".join(
+
+ return CONFIG_FILE_HEADER + "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
- "default_config",
+ "generate_config_section",
config_dir_path=config_dir_path,
data_dir_path=data_dir_path,
server_name=server_name,
generate_secrets=generate_secrets,
report_stats=report_stats,
- )
+ open_private_ports=open_private_ports,
+ listeners=listeners,
+ database_conf=database_conf,
+ tls_certificate_path=tls_certificate_path,
+ tls_private_key_path=tls_private_key_path,
+ acme_domain=acme_domain,
+ ).values()
)
- return default_config
-
@classmethod
def load_config(cls, description, argv):
+ """Parse the commandline and config files
+
+ Doesn't support config-file-generation: used by the worker apps.
+
+ Returns: Config object.
+ """
config_parser = argparse.ArgumentParser(description=description)
+ cls.add_arguments_to_parser(config_parser)
+ obj, _ = cls.load_config_with_parser(config_parser, argv)
+
+ return obj
+
+ @classmethod
+ def add_arguments_to_parser(cls, config_parser):
+ """Adds all the config flags to an ArgumentParser.
+
+ Doesn't support config-file-generation: used by the worker apps.
+
+ Used for workers where we want to add extra flags/subcommands.
+
+ Args:
+ config_parser (ArgumentParser): App description
+ """
+
config_parser.add_argument(
"-c",
"--config-path",
@@ -211,28 +414,63 @@ class Config(object):
"--keys-directory",
metavar="DIRECTORY",
help="Where files such as certs and signing keys are stored when"
- " their location is given explicitly in the config."
+ " their location is not given explicitly in the config."
" Defaults to the directory containing the last config file",
)
- obj = cls()
+ cls.invoke_all_static("add_arguments", config_parser)
+
+ @classmethod
+ def load_config_with_parser(cls, parser, argv):
+ """Parse the commandline and config files with the given parser
+
+ Doesn't support config-file-generation: used by the worker apps.
- obj.invoke_all("add_arguments", config_parser)
+ Used for workers where we want to add extra flags/subcommands.
- config_args = config_parser.parse_args(argv)
+ Args:
+ parser (ArgumentParser)
+ argv (list[str])
+
+ Returns:
+ tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed
+ config object and the parsed argparse.Namespace object from
+ `parser.parse_args(..)`
+ """
+
+ obj = cls()
+
+ config_args = parser.parse_args(argv)
config_files = find_config_files(search_paths=config_args.config_path)
- obj.read_config_files(
- config_files, keys_directory=config_args.keys_directory, generate_keys=False
+ if not config_files:
+ parser.error("Must supply a config file.")
+
+ if config_args.keys_directory:
+ config_dir_path = config_args.keys_directory
+ else:
+ config_dir_path = os.path.dirname(config_files[-1])
+ config_dir_path = os.path.abspath(config_dir_path)
+ data_dir_path = os.getcwd()
+
+ config_dict = read_config_files(config_files)
+ obj.parse_config_dict(
+ config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
)
obj.invoke_all("read_arguments", config_args)
- return obj
+ return obj, config_args
@classmethod
def load_or_generate_config(cls, description, argv):
+ """Parse the commandline and config files
+
+ Supports generation of config files, so is used for the main homeserver app.
+
+ Returns: Config object, or None if --generate-config or --generate-keys was set
+ """
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument(
"-c",
@@ -242,37 +480,74 @@ class Config(object):
help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files.",
)
- config_parser.add_argument(
+
+ generate_group = config_parser.add_argument_group("Config generation")
+ generate_group.add_argument(
"--generate-config",
action="store_true",
- help="Generate a config file for the server name",
+ help="Generate a config file, then exit.",
)
- config_parser.add_argument(
+ generate_group.add_argument(
+ "--generate-missing-configs",
+ "--generate-keys",
+ action="store_true",
+ help="Generate any missing additional config files, then exit.",
+ )
+ generate_group.add_argument(
+ "-H", "--server-name", help="The server name to generate a config file for."
+ )
+ generate_group.add_argument(
"--report-stats",
action="store",
- help="Whether the generated config reports anonymized usage statistics",
+ help="Whether the generated config reports anonymized usage statistics.",
choices=["yes", "no"],
)
- config_parser.add_argument(
- "--generate-keys",
- action="store_true",
- help="Generate any missing key files then exit",
- )
- config_parser.add_argument(
+ generate_group.add_argument(
+ "--config-directory",
"--keys-directory",
metavar="DIRECTORY",
- help="Used with 'generate-*' options to specify where files such as"
- " signing keys should be stored, unless explicitly"
- " specified in the config.",
+ help=(
+ "Specify where additional config files such as signing keys and log"
+ " config should be stored. Defaults to the same directory as the last"
+ " config file."
+ ),
)
- config_parser.add_argument(
- "-H", "--server-name", help="The server name to generate a config file for"
+ generate_group.add_argument(
+ "--data-directory",
+ metavar="DIRECTORY",
+ help=(
+ "Specify where data such as the media store and database file should be"
+ " stored. Defaults to the current working directory."
+ ),
)
+ generate_group.add_argument(
+ "--open-private-ports",
+ action="store_true",
+ help=(
+ "Leave private ports (such as the non-TLS HTTP listener) open to the"
+ " internet. Do not use this unless you know what you are doing."
+ ),
+ )
+
config_args, remaining_args = config_parser.parse_known_args(argv)
config_files = find_config_files(search_paths=config_args.config_path)
- generate_keys = config_args.generate_keys
+ if not config_files:
+ config_parser.error(
+ "Must supply a config file.\nA config file can be automatically"
+ ' generated using "--generate-config -H SERVER_NAME'
+ ' -c CONFIG-FILE"'
+ )
+
+ if config_args.config_directory:
+ config_dir_path = config_args.config_directory
+ else:
+ config_dir_path = os.path.dirname(config_files[-1])
+ config_dir_path = os.path.abspath(config_dir_path)
+ data_dir_path = os.getcwd()
+
+ generate_missing_configs = config_args.generate_missing_configs
obj = cls()
@@ -282,19 +557,16 @@ class Config(object):
"Please specify either --report-stats=yes or --report-stats=no\n\n"
+ MISSING_REPORT_STATS_SPIEL
)
- if not config_files:
- config_parser.error(
- "Must supply a config file.\nA config file can be automatically"
- " generated using \"--generate-config -H SERVER_NAME"
- " -c CONFIG-FILE\""
- )
+
(config_path,) = config_files
- if not cls.path_exists(config_path):
- if config_args.keys_directory:
- config_dir_path = config_args.keys_directory
+ if not path_exists(config_path):
+ print("Generating config file %s" % (config_path,))
+
+ if config_args.data_directory:
+ data_dir_path = config_args.data_directory
else:
- config_dir_path = os.path.dirname(config_path)
- config_dir_path = os.path.abspath(config_dir_path)
+ data_dir_path = os.getcwd()
+ data_dir_path = os.path.abspath(data_dir_path)
server_name = config_args.server_name
if not server_name:
@@ -305,22 +577,21 @@ class Config(object):
config_str = obj.generate_config(
config_dir_path=config_dir_path,
- data_dir_path=os.getcwd(),
+ data_dir_path=data_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
generate_secrets=True,
+ open_private_ports=config_args.open_private_ports,
)
- if not cls.path_exists(config_dir_path):
+ if not path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
- config_file.write(
- "# vim:ft=yaml\n\n"
- )
config_file.write(config_str)
+ config_file.write("\n\n# vim:ft=yaml")
- config = yaml.safe_load(config_str)
- obj.invoke_all("generate_files", config)
+ config_dict = yaml.safe_load(config_str)
+ obj.generate_missing_files(config_dict, config_dir_path)
print(
(
@@ -334,12 +605,12 @@ class Config(object):
else:
print(
(
- "Config file %r already exists. Generating any missing key"
+ "Config file %r already exists. Generating any missing config"
" files."
)
% (config_path,)
)
- generate_keys = True
+ generate_missing_configs = True
parser = argparse.ArgumentParser(
parents=[config_parser],
@@ -347,69 +618,66 @@ class Config(object):
formatter_class=argparse.RawDescriptionHelpFormatter,
)
- obj.invoke_all("add_arguments", parser)
+ obj.invoke_all_static("add_arguments", parser)
args = parser.parse_args(remaining_args)
- if not config_files:
- config_parser.error(
- "Must supply a config file.\nA config file can be automatically"
- " generated using \"--generate-config -H SERVER_NAME"
- " -c CONFIG-FILE\""
- )
-
- obj.read_config_files(
- config_files,
- keys_directory=config_args.keys_directory,
- generate_keys=generate_keys,
- )
-
- if generate_keys:
+ config_dict = read_config_files(config_files)
+ if generate_missing_configs:
+ obj.generate_missing_files(config_dict, config_dir_path)
return None
+ obj.parse_config_dict(
+ config_dict, config_dir_path=config_dir_path, data_dir_path=data_dir_path
+ )
obj.invoke_all("read_arguments", args)
return obj
- def read_config_files(self, config_files, keys_directory=None, generate_keys=False):
- if not keys_directory:
- keys_directory = os.path.dirname(config_files[-1])
-
- self.config_dir_path = os.path.abspath(keys_directory)
+ def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
+ """Read the information from the config dict into this Config object.
- specified_config = {}
- for config_file in config_files:
- yaml_config = self.read_config_file(config_file)
- specified_config.update(yaml_config)
+ Args:
+ config_dict (dict): Configuration data, as read from the yaml
- if "server_name" not in specified_config:
- raise ConfigError(MISSING_SERVER_NAME)
+ config_dir_path (str): The path where the config files are kept. Used to
+ create filenames for things like the log config and the signing key.
- server_name = specified_config["server_name"]
- config_string = self.generate_config(
- config_dir_path=self.config_dir_path,
- data_dir_path=os.getcwd(),
- server_name=server_name,
- generate_secrets=False,
+ data_dir_path (str): The path where the data files are kept. Used to create
+ filenames for things like the database and media store.
+ """
+ self.invoke_all(
+ "read_config",
+ config_dict,
+ config_dir_path=config_dir_path,
+ data_dir_path=data_dir_path,
)
- config = yaml.safe_load(config_string)
- config.pop("log_config")
- config.update(specified_config)
- if "report_stats" not in config:
- raise ConfigError(
- MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
- + "\n"
- + MISSING_REPORT_STATS_SPIEL
- )
+ def generate_missing_files(self, config_dict, config_dir_path):
+ self.invoke_all("generate_files", config_dict, config_dir_path)
- if generate_keys:
- self.invoke_all("generate_files", config)
- return
- self.parse_config_dict(config)
+def read_config_files(config_files):
+ """Read the config files into a dict
- def parse_config_dict(self, config_dict):
- self.invoke_all("read_config", config_dict)
+ Args:
+ config_files (iterable[str]): A list of the config files to read
+
+ Returns: dict
+ """
+ specified_config = {}
+ for config_file in config_files:
+ with open(config_file) as file_stream:
+ yaml_config = yaml.safe_load(file_stream)
+ specified_config.update(yaml_config)
+
+ if "server_name" not in specified_config:
+ raise ConfigError(MISSING_SERVER_NAME)
+
+ if "report_stats" not in specified_config:
+ raise ConfigError(
+ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" + MISSING_REPORT_STATS_SPIEL
+ )
+ return specified_config
def find_config_files(search_paths):
@@ -454,3 +722,6 @@ def find_config_files(search_paths):
else:
config_files.append(config_path)
return config_files
+
+
+__all__ = ["Config", "RootConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
new file mode 100644
index 0000000000..3053fc9d27
--- /dev/null
+++ b/synapse/config/_base.pyi
@@ -0,0 +1,137 @@
+from typing import Any, List, Optional
+
+from synapse.config import (
+ api,
+ appservice,
+ captcha,
+ cas,
+ consent_config,
+ database,
+ emailconfig,
+ groups,
+ jwt_config,
+ key,
+ logger,
+ metrics,
+ password,
+ password_auth_providers,
+ push,
+ ratelimiting,
+ registration,
+ repository,
+ room_directory,
+ saml2_config,
+ server,
+ server_notices_config,
+ spam_checker,
+ sso,
+ stats,
+ third_party_event_rules,
+ tls,
+ tracer,
+ user_directory,
+ voip,
+ workers,
+)
+
+class ConfigError(Exception): ...
+
+MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
+MISSING_REPORT_STATS_SPIEL: str
+MISSING_SERVER_NAME: str
+
+def path_exists(file_path: str): ...
+
+class RootConfig:
+ server: server.ServerConfig
+ tls: tls.TlsConfig
+ database: database.DatabaseConfig
+ logging: logger.LoggingConfig
+ ratelimit: ratelimiting.RatelimitConfig
+ media: repository.ContentRepositoryConfig
+ captcha: captcha.CaptchaConfig
+ voip: voip.VoipConfig
+ registration: registration.RegistrationConfig
+ metrics: metrics.MetricsConfig
+ api: api.ApiConfig
+ appservice: appservice.AppServiceConfig
+ key: key.KeyConfig
+ saml2: saml2_config.SAML2Config
+ cas: cas.CasConfig
+ sso: sso.SSOConfig
+ jwt: jwt_config.JWTConfig
+ password: password.PasswordConfig
+ email: emailconfig.EmailConfig
+ worker: workers.WorkerConfig
+ authproviders: password_auth_providers.PasswordAuthProviderConfig
+ push: push.PushConfig
+ spamchecker: spam_checker.SpamCheckerConfig
+ groups: groups.GroupsConfig
+ userdirectory: user_directory.UserDirectoryConfig
+ consent: consent_config.ConsentConfig
+ stats: stats.StatsConfig
+ servernotices: server_notices_config.ServerNoticesConfig
+ roomdirectory: room_directory.RoomDirectoryConfig
+ thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
+ tracer: tracer.TracerConfig
+
+ config_classes: List = ...
+ def __init__(self) -> None: ...
+ def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ...
+ @classmethod
+ def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
+ def __getattr__(self, item: str): ...
+ def parse_config_dict(
+ self,
+ config_dict: Any,
+ config_dir_path: Optional[Any] = ...,
+ data_dir_path: Optional[Any] = ...,
+ ) -> None: ...
+ read_config: Any = ...
+ def generate_config(
+ self,
+ config_dir_path: str,
+ data_dir_path: str,
+ server_name: str,
+ generate_secrets: bool = ...,
+ report_stats: Optional[str] = ...,
+ open_private_ports: bool = ...,
+ listeners: Optional[Any] = ...,
+ database_conf: Optional[Any] = ...,
+ tls_certificate_path: Optional[str] = ...,
+ tls_private_key_path: Optional[str] = ...,
+ acme_domain: Optional[str] = ...,
+ ): ...
+ @classmethod
+ def load_or_generate_config(cls, description: Any, argv: Any): ...
+ @classmethod
+ def load_config(cls, description: Any, argv: Any): ...
+ @classmethod
+ def add_arguments_to_parser(cls, config_parser: Any) -> None: ...
+ @classmethod
+ def load_config_with_parser(cls, parser: Any, argv: Any): ...
+ def generate_missing_files(
+ self, config_dict: dict, config_dir_path: str
+ ) -> None: ...
+
+class Config:
+ root: RootConfig
+ def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ...
+ def __getattr__(self, item: str, from_root: bool = ...): ...
+ @staticmethod
+ def parse_size(value: Any): ...
+ @staticmethod
+ def parse_duration(value: Any): ...
+ @staticmethod
+ def abspath(file_path: Optional[str]): ...
+ @classmethod
+ def path_exists(cls, file_path: str): ...
+ @classmethod
+ def check_file(cls, file_path: str, config_name: str): ...
+ @classmethod
+ def ensure_directory(cls, dir_path: str): ...
+ @classmethod
+ def read_file(cls, file_path: str, config_name: str): ...
+
+def read_config_files(config_files: List[str]): ...
+def find_config_files(search_paths: List[str]): ...
diff --git a/synapse/config/api.py b/synapse/config/api.py
index 5eb4f86fa2..74cd53a8ed 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -18,17 +18,21 @@ from ._base import Config
class ApiConfig(Config):
+ section = "api"
- def read_config(self, config):
- self.room_invite_state_types = config.get("room_invite_state_types", [
- EventTypes.JoinRules,
- EventTypes.CanonicalAlias,
- EventTypes.RoomAvatar,
- EventTypes.RoomEncryption,
- EventTypes.Name,
- ])
+ def read_config(self, config, **kwargs):
+ self.room_invite_state_types = config.get(
+ "room_invite_state_types",
+ [
+ EventTypes.JoinRules,
+ EventTypes.CanonicalAlias,
+ EventTypes.RoomAvatar,
+ EventTypes.RoomEncryption,
+ EventTypes.Name,
+ ],
+ )
- def default_config(cls, **kwargs):
+ def generate_config_section(cls, **kwargs):
return """\
## API Configuration ##
@@ -40,4 +44,6 @@ class ApiConfig(Config):
# - "{RoomAvatar}"
# - "{RoomEncryption}"
# - "{Name}"
- """.format(**vars(EventTypes))
+ """.format(
+ **vars(EventTypes)
+ )
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 7e89d345d8..ca43e96bd1 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -13,6 +13,7 @@
# limitations under the License.
import logging
+from typing import Dict
from six import string_types
from six.moves.urllib import parse as urlparse
@@ -29,13 +30,14 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
+ section = "appservice"
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
self.track_appservice_user_ips = config.get("track_appservice_user_ips", False)
- def default_config(cls, **kwargs):
+ def generate_config_section(cls, **kwargs):
return """\
# A list of application service config files to use
#
@@ -46,42 +48,38 @@ class AppServiceConfig(Config):
# Uncomment to enable tracking of application service IP addresses. Implicitly
# enables MAU tracking for application service users.
#
- #track_appservice_user_ips: True
+ #track_appservice_user_ips: true
"""
def load_appservices(hostname, config_files):
"""Returns a list of Application Services from the config files."""
if not isinstance(config_files, list):
- logger.warning(
- "Expected %s to be a list of AS config files.", config_files
- )
+ logger.warning("Expected %s to be a list of AS config files.", config_files)
return []
# Dicts of value -> filename
- seen_as_tokens = {}
- seen_ids = {}
+ seen_as_tokens = {} # type: Dict[str, str]
+ seen_ids = {} # type: Dict[str, str]
appservices = []
for config_file in config_files:
try:
- with open(config_file, 'r') as f:
- appservice = _load_appservice(
- hostname, yaml.safe_load(f), config_file
- )
+ with open(config_file, "r") as f:
+ appservice = _load_appservice(hostname, yaml.safe_load(f), config_file)
if appservice.id in seen_ids:
raise ConfigError(
"Cannot reuse ID across application services: "
- "%s (files: %s, %s)" % (
- appservice.id, config_file, seen_ids[appservice.id],
- )
+ "%s (files: %s, %s)"
+ % (appservice.id, config_file, seen_ids[appservice.id])
)
seen_ids[appservice.id] = config_file
if appservice.token in seen_as_tokens:
raise ConfigError(
"Cannot reuse as_token across application services: "
- "%s (files: %s, %s)" % (
+ "%s (files: %s, %s)"
+ % (
appservice.token,
config_file,
seen_as_tokens[appservice.token],
@@ -98,28 +96,26 @@ def load_appservices(hostname, config_files):
def _load_appservice(hostname, as_info, config_filename):
- required_string_fields = [
- "id", "as_token", "hs_token", "sender_localpart"
- ]
+ required_string_fields = ["id", "as_token", "hs_token", "sender_localpart"]
for field in required_string_fields:
if not isinstance(as_info.get(field), string_types):
- raise KeyError("Required string field: '%s' (%s)" % (
- field, config_filename,
- ))
+ raise KeyError(
+ "Required string field: '%s' (%s)" % (field, config_filename)
+ )
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
- if (not isinstance(as_info.get("url"), string_types) and
- as_info.get("url", "") is not None):
+ if (
+ not isinstance(as_info.get("url"), string_types)
+ and as_info.get("url", "") is not None
+ ):
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
localpart = as_info["sender_localpart"]
if urlparse.quote(localpart) != localpart:
- raise ValueError(
- "sender_localpart needs characters which are not URL encoded."
- )
+ raise ValueError("sender_localpart needs characters which are not URL encoded.")
user = UserID(localpart, hostname)
user_id = user.to_string()
@@ -138,13 +134,12 @@ def _load_appservice(hostname, as_info, config_filename):
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
- "Expected namespace entry in %s to be an object,"
- " but got %s", ns, regex_obj
+ "Expected namespace entry in %s to be an object, but got %s",
+ ns,
+ regex_obj,
)
if not isinstance(regex_obj.get("regex"), string_types):
- raise ValueError(
- "Missing/bad type 'regex' key in %s", regex_obj
- )
+ raise ValueError("Missing/bad type 'regex' key in %s", regex_obj)
if not isinstance(regex_obj.get("exclusive"), bool):
raise ValueError(
"Missing/bad type 'exclusive' key in %s", regex_obj
@@ -167,10 +162,8 @@ def _load_appservice(hostname, as_info, config_filename):
)
ip_range_whitelist = None
- if as_info.get('ip_range_whitelist'):
- ip_range_whitelist = IPSet(
- as_info.get('ip_range_whitelist')
- )
+ if as_info.get("ip_range_whitelist"):
+ ip_range_whitelist = IPSet(as_info.get("ip_range_whitelist"))
return ApplicationService(
token=as_info["as_token"],
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index f7eebf26d2..f0171bb5b2 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -16,8 +16,9 @@ from ._base import Config
class CaptchaConfig(Config):
+ section = "captcha"
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.recaptcha_private_key = config.get("recaptcha_private_key")
self.recaptcha_public_key = config.get("recaptcha_public_key")
self.enable_registration_captcha = config.get(
@@ -29,16 +30,16 @@ class CaptchaConfig(Config):
"https://www.recaptcha.net/recaptcha/api/siteverify",
)
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this.
- # This Home Server's ReCAPTCHA public key.
+ # This homeserver's ReCAPTCHA public key.
#
#recaptcha_public_key: "YOUR_PUBLIC_KEY"
- # This Home Server's ReCAPTCHA private key.
+ # This homeserver's ReCAPTCHA private key.
#
#recaptcha_private_key: "YOUR_PRIVATE_KEY"
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index 609c0815c8..4526c1a67b 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -22,20 +22,24 @@ class CasConfig(Config):
cas_server_url: URL of CAS server
"""
- def read_config(self, config):
+ section = "cas"
+
+ def read_config(self, config, **kwargs):
cas_config = config.get("cas_config", None)
if cas_config:
self.cas_enabled = cas_config.get("enabled", True)
self.cas_server_url = cas_config["server_url"]
self.cas_service_url = cas_config["service_url"]
+ self.cas_displayname_attribute = cas_config.get("displayname_attribute")
self.cas_required_attributes = cas_config.get("required_attributes", {})
else:
self.cas_enabled = False
self.cas_server_url = None
self.cas_service_url = None
+ self.cas_displayname_attribute = None
self.cas_required_attributes = {}
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Enable CAS for registration and login.
#
@@ -43,6 +47,7 @@ class CasConfig(Config):
# enabled: true
# server_url: "https://cas-server.com"
# service_url: "https://homeserver.domain.com:8448"
+ # #displayname_attribute: name
# #required_attributes:
# # name: value
"""
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index abeb0180d3..aec9c4bbce 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -62,19 +62,22 @@ DEFAULT_CONFIG = """\
# body: >-
# To continue using this homeserver you must review and agree to the
# terms and conditions at %(consent_uri)s
-# send_server_notice_to_guests: True
+# send_server_notice_to_guests: true
# block_events_error: >-
# To continue using this homeserver you must review and agree to the
# terms and conditions at %(consent_uri)s
-# require_at_registration: False
+# require_at_registration: false
# policy_name: Privacy Policy
#
"""
class ConsentConfig(Config):
- def __init__(self):
- super(ConsentConfig, self).__init__()
+
+ section = "consent"
+
+ def __init__(self, *args):
+ super(ConsentConfig, self).__init__(*args)
self.user_consent_version = None
self.user_consent_template_dir = None
@@ -84,35 +87,32 @@ class ConsentConfig(Config):
self.user_consent_at_registration = False
self.user_consent_policy_name = "Privacy Policy"
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
consent_config = config.get("user_consent")
if consent_config is None:
return
self.user_consent_version = str(consent_config["version"])
- self.user_consent_template_dir = self.abspath(
- consent_config["template_dir"]
- )
+ self.user_consent_template_dir = self.abspath(consent_config["template_dir"])
if not path.isdir(self.user_consent_template_dir):
raise ConfigError(
- "Could not find template directory '%s'" % (
- self.user_consent_template_dir,
- ),
+ "Could not find template directory '%s'"
+ % (self.user_consent_template_dir,)
)
self.user_consent_server_notice_content = consent_config.get(
- "server_notice_content",
+ "server_notice_content"
)
self.block_events_without_consent_error = consent_config.get(
- "block_events_error",
+ "block_events_error"
+ )
+ self.user_consent_server_notice_to_guests = bool(
+ consent_config.get("send_server_notice_to_guests", False)
+ )
+ self.user_consent_at_registration = bool(
+ consent_config.get("require_at_registration", False)
)
- self.user_consent_server_notice_to_guests = bool(consent_config.get(
- "send_server_notice_to_guests", False,
- ))
- self.user_consent_at_registration = bool(consent_config.get(
- "require_at_registration", False,
- ))
self.user_consent_policy_name = consent_config.get(
- "policy_name", "Privacy Policy",
+ "policy_name", "Privacy Policy"
)
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return DEFAULT_CONFIG
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 3c27ed6b4a..219b32f670 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -12,71 +12,157 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
import os
+from textwrap import indent
-from ._base import Config
+import yaml
+
+from synapse.config._base import Config, ConfigError
+
+logger = logging.getLogger(__name__)
+
+
+class DatabaseConnectionConfig:
+ """Contains the connection config for a particular database.
+
+ Args:
+ name: A label for the database, used for logging.
+ db_config: The config for a particular database, as per `database`
+ section of main config. Has three fields: `name` for database
+ module name, `args` for the args to give to the database
+ connector, and optional `data_stores` that is a list of stores to
+ provision on this database (defaulting to all).
+ """
+
+ def __init__(self, name: str, db_config: dict):
+ if db_config["name"] not in ("sqlite3", "psycopg2"):
+ raise ConfigError("Unsupported database type %r" % (db_config["name"],))
+
+ if db_config["name"] == "sqlite3":
+ db_config.setdefault("args", {}).update(
+ {"cp_min": 1, "cp_max": 1, "check_same_thread": False}
+ )
+
+ data_stores = db_config.get("data_stores")
+ if data_stores is None:
+ data_stores = ["main", "state"]
+
+ self.name = name
+ self.config = db_config
+ self.data_stores = data_stores
class DatabaseConfig(Config):
+ section = "database"
- def read_config(self, config):
- self.event_cache_size = self.parse_size(
- config.get("event_cache_size", "10K")
- )
+ def read_config(self, config, **kwargs):
+ self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
+
+ # We *experimentally* support specifying multiple databases via the
+ # `databases` key. This is a map from a label to database config in the
+ # same format as the `database` config option, plus an extra
+ # `data_stores` key to specify which data store goes where. For example:
+ #
+ # databases:
+ # master:
+ # name: psycopg2
+ # data_stores: ["main"]
+ # args: {}
+ # state:
+ # name: psycopg2
+ # data_stores: ["state"]
+ # args: {}
+
+ multi_database_config = config.get("databases")
+ database_config = config.get("database")
+
+ if multi_database_config and database_config:
+ raise ConfigError("Can't specify both 'database' and 'datbases' in config")
+
+ if multi_database_config:
+ if config.get("database_path"):
+ raise ConfigError("Can't specify 'database_path' with 'databases'")
+
+ self.databases = [
+ DatabaseConnectionConfig(name, db_conf)
+ for name, db_conf in multi_database_config.items()
+ ]
- self.database_config = config.get("database")
-
- if self.database_config is None:
- self.database_config = {
- "name": "sqlite3",
- "args": {},
- }
-
- name = self.database_config.get("name", None)
- if name == "psycopg2":
- pass
- elif name == "sqlite3":
- self.database_config.setdefault("args", {}).update({
- "cp_min": 1,
- "cp_max": 1,
- "check_same_thread": False,
- })
else:
- raise RuntimeError("Unsupported database type '%s'" % (name,))
+ if database_config is None:
+ database_config = {"name": "sqlite3", "args": {}}
- self.set_databasepath(config.get("database_path"))
+ self.databases = [DatabaseConnectionConfig("master", database_config)]
- def default_config(self, data_dir_path, **kwargs):
- database_path = os.path.join(data_dir_path, "homeserver.db")
- return """\
- ## Database ##
+ self.set_databasepath(config.get("database_path"))
- database:
- # The database engine name
+ def generate_config_section(self, data_dir_path, database_conf, **kwargs):
+ if not database_conf:
+ database_path = os.path.join(data_dir_path, "homeserver.db")
+ database_conf = (
+ """# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
# Path to the database
database: "%(database_path)s"
+ """
+ % locals()
+ )
+ else:
+ database_conf = indent(yaml.dump(database_conf), " " * 10).lstrip()
+ return (
+ """\
+ ## Database ##
+
+ database:
+ %(database_conf)s
# Number of events to cache in memory.
#
#event_cache_size: 10K
- """ % locals()
+ """
+ % locals()
+ )
def read_arguments(self, args):
self.set_databasepath(args.database_path)
def set_databasepath(self, database_path):
+ if database_path is None:
+ return
+
if database_path != ":memory:":
database_path = self.abspath(database_path)
- if self.database_config.get("name", None) == "sqlite3":
- if database_path is not None:
- self.database_config["args"]["database"] = database_path
- def add_arguments(self, parser):
+ # We only support setting a database path if we have a single sqlite3
+ # database.
+ if len(self.databases) != 1:
+ raise ConfigError("Cannot specify 'database_path' with multiple databases")
+
+ database = self.get_single_database()
+ if database.config["name"] != "sqlite3":
+ # We don't raise here as we haven't done so before for this case.
+ logger.warn("Ignoring 'database_path' for non-sqlite3 database")
+ return
+
+ database.config["args"]["database"] = database_path
+
+ @staticmethod
+ def add_arguments(parser):
db_group = parser.add_argument_group("database")
db_group.add_argument(
- "-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
- help="The path to a sqlite database to use."
+ "-d",
+ "--database-path",
+ metavar="SQLITE_DATABASE_PATH",
+ help="The path to a sqlite database to use.",
)
+
+ def get_single_database(self) -> DatabaseConnectionConfig:
+ """Returns the database if there is only one, useful for e.g. tests
+ """
+ if len(self.databases) != 1:
+ raise Exception("More than one database exists")
+
+ return self.databases[0]
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index ae04252906..f31fc85ec8 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -19,27 +19,36 @@ from __future__ import print_function
# This file can't be called email.py because if it is, we cannot:
import email.utils
-import logging
import os
+from enum import Enum
+from typing import Optional
import pkg_resources
from ._base import Config, ConfigError
-logger = logging.getLogger(__name__)
+MISSING_PASSWORD_RESET_CONFIG_ERROR = """\
+Password reset emails are enabled on this homeserver due to a partial
+'email' block. However, the following required keys are missing:
+ %s
+"""
class EmailConfig(Config):
- def read_config(self, config):
+ section = "email"
+
+ def read_config(self, config, **kwargs):
# TODO: We should separate better the email configuration from the notification
# and account validity config.
self.email_enable_notifs = False
- email_config = config.get("email", {})
+ email_config = config.get("email")
+ if email_config is None:
+ email_config = {}
- self.email_smtp_host = email_config.get("smtp_host", None)
- self.email_smtp_port = email_config.get("smtp_port", None)
+ self.email_smtp_host = email_config.get("smtp_host", "localhost")
+ self.email_smtp_port = email_config.get("smtp_port", 25)
self.email_smtp_user = email_config.get("smtp_user", None)
self.email_smtp_pass = email_config.get("smtp_pass", None)
self.require_transport_security = email_config.get(
@@ -59,7 +68,7 @@ class EmailConfig(Config):
if self.email_notif_from is not None:
# make sure it's valid
parsed = email.utils.parseaddr(self.email_notif_from)
- if parsed[1] == '':
+ if parsed[1] == "":
raise RuntimeError("Invalid notif_from address")
template_dir = email_config.get("template_dir")
@@ -68,28 +77,57 @@ class EmailConfig(Config):
# (Note that loading as package_resources with jinja.PackageLoader doesn't
# work for the same reason.)
if not template_dir:
- template_dir = pkg_resources.resource_filename(
- 'synapse', 'res/templates'
- )
+ template_dir = pkg_resources.resource_filename("synapse", "res/templates")
self.email_template_dir = os.path.abspath(template_dir)
self.email_enable_notifs = email_config.get("enable_notifs", False)
- account_validity_renewal_enabled = config.get(
- "account_validity", {},
- ).get("renew_at")
- email_trust_identity_server_for_password_resets = email_config.get(
- "trust_identity_server_for_password_resets", False,
+ account_validity_config = config.get("account_validity") or {}
+ account_validity_renewal_enabled = account_validity_config.get("renew_at")
+
+ self.threepid_behaviour_email = (
+ # Have Synapse handle the email sending if account_threepid_delegates.email
+ # is not defined
+ # msisdn is currently always remote while Synapse does not support any method of
+ # sending SMS messages
+ ThreepidBehaviour.REMOTE
+ if self.account_threepid_delegate_email
+ else ThreepidBehaviour.LOCAL
)
- self.email_password_reset_behaviour = (
- "remote" if email_trust_identity_server_for_password_resets else "local"
- )
- if self.email_password_reset_behaviour == "local" and email_config == {}:
- logger.warn(
- "User password resets have been disabled due to lack of email config"
- )
- self.email_password_reset_behaviour = "off"
+ # Prior to Synapse v1.4.0, there was another option that defined whether Synapse would
+ # use an identity server to password reset tokens on its behalf. We now warn the user
+ # if they have this set and tell them to use the updated option, while using a default
+ # identity server in the process.
+ self.using_identity_server_from_trusted_list = False
+ if (
+ not self.account_threepid_delegate_email
+ and config.get("trust_identity_server_for_password_resets", False) is True
+ ):
+ # Use the first entry in self.trusted_third_party_id_servers instead
+ if self.trusted_third_party_id_servers:
+ # XXX: It's a little confusing that account_threepid_delegate_email is modified
+ # both in RegistrationConfig and here. We should factor this bit out
+ self.account_threepid_delegate_email = self.trusted_third_party_id_servers[
+ 0
+ ] # type: Optional[str]
+ self.using_identity_server_from_trusted_list = True
+ else:
+ raise ConfigError(
+ "Attempted to use an identity server from"
+ '"trusted_third_party_id_servers" but it is empty.'
+ )
+
+ self.local_threepid_handling_disabled_due_to_email_config = False
+ if (
+ self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+ and email_config == {}
+ ):
+ # We cannot warn the user this has happened here
+ # Instead do so when a user attempts to reset their password
+ self.local_threepid_handling_disabled_due_to_email_config = True
+
+ self.threepid_behaviour_email = ThreepidBehaviour.OFF
# Get lifetime of a validation token in milliseconds
self.email_validation_token_lifetime = self.parse_duration(
@@ -99,196 +137,279 @@ class EmailConfig(Config):
if (
self.email_enable_notifs
or account_validity_renewal_enabled
- or self.email_password_reset_behaviour == "local"
+ or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL
):
# make sure we can import the required deps
import jinja2
import bleach
+
# prevent unused warnings
jinja2
bleach
- if self.email_password_reset_behaviour == "local":
- required = [
- "smtp_host",
- "smtp_port",
- "notif_from",
- ]
-
+ if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
missing = []
- for k in required:
- if k not in email_config:
- missing.append(k)
-
- if (len(missing) > 0):
- raise RuntimeError(
- "email.password_reset_behaviour is set to 'local' "
- "but required keys are missing: %s" %
- (", ".join(["email." + k for k in missing]),)
+ if not self.email_notif_from:
+ missing.append("email.notif_from")
+
+ # public_baseurl is required to build password reset and validation links that
+ # will be emailed to users
+ if config.get("public_baseurl") is None:
+ missing.append("public_baseurl")
+
+ if missing:
+ raise ConfigError(
+ MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),)
)
- # Templates for password reset emails
+ # These email templates have placeholders in them, and thus must be
+ # parsed using a templating engine during a request
self.email_password_reset_template_html = email_config.get(
- "password_reset_template_html", "password_reset.html",
+ "password_reset_template_html", "password_reset.html"
)
self.email_password_reset_template_text = email_config.get(
- "password_reset_template_text", "password_reset.txt",
+ "password_reset_template_text", "password_reset.txt"
+ )
+ self.email_registration_template_html = email_config.get(
+ "registration_template_html", "registration.html"
+ )
+ self.email_registration_template_text = email_config.get(
+ "registration_template_text", "registration.txt"
)
- self.email_password_reset_failure_template = email_config.get(
- "password_reset_failure_template", "password_reset_failure.html",
+ self.email_add_threepid_template_html = email_config.get(
+ "add_threepid_template_html", "add_threepid.html"
)
- # This template does not support any replaceable variables, so we will
- # read it from the disk once during setup
- email_password_reset_success_template = email_config.get(
- "password_reset_success_template", "password_reset_success.html",
+ self.email_add_threepid_template_text = email_config.get(
+ "add_threepid_template_text", "add_threepid.txt"
+ )
+
+ self.email_password_reset_template_failure_html = email_config.get(
+ "password_reset_template_failure_html", "password_reset_failure.html"
+ )
+ self.email_registration_template_failure_html = email_config.get(
+ "registration_template_failure_html", "registration_failure.html"
+ )
+ self.email_add_threepid_template_failure_html = email_config.get(
+ "add_threepid_template_failure_html", "add_threepid_failure.html"
+ )
+
+ # These templates do not support any placeholder variables, so we
+ # will read them from disk once during setup
+ email_password_reset_template_success_html = email_config.get(
+ "password_reset_template_success_html", "password_reset_success.html"
+ )
+ email_registration_template_success_html = email_config.get(
+ "registration_template_success_html", "registration_success.html"
+ )
+ email_add_threepid_template_success_html = email_config.get(
+ "add_threepid_template_success_html", "add_threepid_success.html"
)
# Check templates exist
- for f in [self.email_password_reset_template_html,
- self.email_password_reset_template_text,
- self.email_password_reset_failure_template,
- email_password_reset_success_template]:
+ for f in [
+ self.email_password_reset_template_html,
+ self.email_password_reset_template_text,
+ self.email_registration_template_html,
+ self.email_registration_template_text,
+ self.email_add_threepid_template_html,
+ self.email_add_threepid_template_text,
+ self.email_password_reset_template_failure_html,
+ self.email_registration_template_failure_html,
+ self.email_add_threepid_template_failure_html,
+ email_password_reset_template_success_html,
+ email_registration_template_success_html,
+ email_add_threepid_template_success_html,
+ ]:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
- raise ConfigError("Unable to find template file %s" % (p, ))
+ raise ConfigError("Unable to find template file %s" % (p,))
# Retrieve content of web templates
filepath = os.path.join(
- self.email_template_dir,
- email_password_reset_success_template,
+ self.email_template_dir, email_password_reset_template_success_html
)
- self.email_password_reset_success_html_content = self.read_file(
- filepath,
- "email.password_reset_template_success_html",
+ self.email_password_reset_template_success_html = self.read_file(
+ filepath, "email.password_reset_template_success_html"
+ )
+ filepath = os.path.join(
+ self.email_template_dir, email_registration_template_success_html
+ )
+ self.email_registration_template_success_html_content = self.read_file(
+ filepath, "email.registration_template_success_html"
+ )
+ filepath = os.path.join(
+ self.email_template_dir, email_add_threepid_template_success_html
+ )
+ self.email_add_threepid_template_success_html_content = self.read_file(
+ filepath, "email.add_threepid_template_success_html"
)
-
- if config.get("public_baseurl") is None:
- raise RuntimeError(
- "email.password_reset_behaviour is set to 'local' but no "
- "public_baseurl is set. This is necessary to generate password "
- "reset links"
- )
if self.email_enable_notifs:
- required = [
- "smtp_host",
- "smtp_port",
- "notif_from",
- "notif_template_html",
- "notif_template_text",
- ]
-
missing = []
- for k in required:
- if k not in email_config:
- missing.append(k)
-
- if (len(missing) > 0):
- raise RuntimeError(
- "email.enable_notifs is True but required keys are missing: %s" %
- (", ".join(["email." + k for k in missing]),)
- )
+ if not self.email_notif_from:
+ missing.append("email.notif_from")
if config.get("public_baseurl") is None:
- raise RuntimeError(
- "email.enable_notifs is True but no public_baseurl is set"
+ missing.append("public_baseurl")
+
+ if missing:
+ raise ConfigError(
+ "email.enable_notifs is True but required keys are missing: %s"
+ % (", ".join(missing),)
)
- self.email_notif_template_html = email_config["notif_template_html"]
- self.email_notif_template_text = email_config["notif_template_text"]
+ self.email_notif_template_html = email_config.get(
+ "notif_template_html", "notif_mail.html"
+ )
+ self.email_notif_template_text = email_config.get(
+ "notif_template_text", "notif_mail.txt"
+ )
for f in self.email_notif_template_text, self.email_notif_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
- raise ConfigError("Unable to find email template file %s" % (p, ))
+ raise ConfigError("Unable to find email template file %s" % (p,))
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
)
self.email_riot_base_url = email_config.get(
- "riot_base_url", None
+ "client_base_url", email_config.get("riot_base_url", None)
)
if account_validity_renewal_enabled:
self.email_expiry_template_html = email_config.get(
- "expiry_template_html", "notice_expiry.html",
+ "expiry_template_html", "notice_expiry.html"
)
self.email_expiry_template_text = email_config.get(
- "expiry_template_text", "notice_expiry.txt",
+ "expiry_template_text", "notice_expiry.txt"
)
for f in self.email_expiry_template_text, self.email_expiry_template_html:
p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
- raise ConfigError("Unable to find email template file %s" % (p, ))
+ raise ConfigError("Unable to find email template file %s" % (p,))
- def default_config(self, config_dir_path, server_name, **kwargs):
- return """
- # Enable sending emails for password resets, notification events or
- # account expiry notices
- #
- # If your SMTP server requires authentication, the optional smtp_user &
- # smtp_pass variables should be used
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
+ return """\
+ # Configuration for sending emails from Synapse.
#
- #email:
- # enable_notifs: false
- # smtp_host: "localhost"
- # smtp_port: 25 # SSL: 465, STARTTLS: 587
- # smtp_user: "exampleusername"
- # smtp_pass: "examplepassword"
- # require_transport_security: False
- # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
- # app_name: Matrix
- #
- # # Enable email notifications by default
- # notif_for_new_users: True
- #
- # # Defining a custom URL for Riot is only needed if email notifications
- # # should contain links to a self-hosted installation of Riot; when set
- # # the "app_name" setting is ignored
- # riot_base_url: "http://localhost/riot"
- #
- # # Enable sending password reset emails via the configured, trusted
- # # identity servers
- # #
- # # IMPORTANT! This will give a malicious or overtaken identity server
- # # the ability to reset passwords for your users! Make absolutely sure
- # # that you want to do this! It is strongly recommended that password
- # # reset emails be sent by the homeserver instead
- # #
- # # If this option is set to false and SMTP options have not been
- # # configured, resetting user passwords via email will be disabled
- # #trust_identity_server_for_password_resets: false
- #
- # # Configure the time that a validation email or text message code
- # # will expire after sending
- # #
- # # This is currently used for password resets
- # #validation_token_lifetime: 1h
- #
- # # Template directory. All template files should be stored within this
- # # directory
- # #
- # #template_dir: res/templates
- #
- # # Templates for email notifications
- # #
- # notif_template_html: notif_mail.html
- # notif_template_text: notif_mail.txt
- #
- # # Templates for account expiry notices
- # #
- # expiry_template_html: notice_expiry.html
- # expiry_template_text: notice_expiry.txt
- #
- # # Templates for password reset emails sent by the homeserver
- # #
- # #password_reset_template_html: password_reset.html
- # #password_reset_template_text: password_reset.txt
- #
- # # Templates for password reset success and failure pages that a user
- # # will see after attempting to reset their password
- # #
- # #password_reset_template_success_html: password_reset_success.html
- # #password_reset_template_failure_html: password_reset_failure.html
+ email:
+ # The hostname of the outgoing SMTP server to use. Defaults to 'localhost'.
+ #
+ #smtp_host: mail.server
+
+ # The port on the mail server for outgoing SMTP. Defaults to 25.
+ #
+ #smtp_port: 587
+
+ # Username/password for authentication to the SMTP server. By default, no
+ # authentication is attempted.
+ #
+ # smtp_user: "exampleusername"
+ # smtp_pass: "examplepassword"
+
+ # Uncomment the following to require TLS transport security for SMTP.
+ # By default, Synapse will connect over plain text, and will then switch to
+ # TLS via STARTTLS *if the SMTP server supports it*. If this option is set,
+ # Synapse will refuse to connect unless the server supports STARTTLS.
+ #
+ #require_transport_security: true
+
+ # notif_from defines the "From" address to use when sending emails.
+ # It must be set if email sending is enabled.
+ #
+ # The placeholder '%(app)s' will be replaced by the application name,
+ # which is normally 'app_name' (below), but may be overridden by the
+ # Matrix client application.
+ #
+ # Note that the placeholder must be written '%(app)s', including the
+ # trailing 's'.
+ #
+ #notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
+
+ # app_name defines the default value for '%(app)s' in notif_from. It
+ # defaults to 'Matrix'.
+ #
+ #app_name: my_branded_matrix_server
+
+ # Uncomment the following to enable sending emails for messages that the user
+ # has missed. Disabled by default.
+ #
+ #enable_notifs: true
+
+ # Uncomment the following to disable automatic subscription to email
+ # notifications for new users. Enabled by default.
+ #
+ #notif_for_new_users: false
+
+ # Custom URL for client links within the email notifications. By default
+ # links will be based on "https://matrix.to".
+ #
+ # (This setting used to be called riot_base_url; the old name is still
+ # supported for backwards-compatibility but is now deprecated.)
+ #
+ #client_base_url: "http://localhost/riot"
+
+ # Configure the time that a validation email will expire after sending.
+ # Defaults to 1h.
+ #
+ #validation_token_lifetime: 15m
+
+ # Directory in which Synapse will try to find the template files below.
+ # If not set, default templates from within the Synapse package will be used.
+ #
+ # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
+ # If you *do* uncomment it, you will need to make sure that all the templates
+ # below are in the directory.
+ #
+ # Synapse will look for the following templates in this directory:
+ #
+ # * The contents of email notifications of missed events: 'notif_mail.html' and
+ # 'notif_mail.txt'.
+ #
+ # * The contents of account expiry notice emails: 'notice_expiry.html' and
+ # 'notice_expiry.txt'.
+ #
+ # * The contents of password reset emails sent by the homeserver:
+ # 'password_reset.html' and 'password_reset.txt'
+ #
+ # * HTML pages for success and failure that a user will see when they follow
+ # the link in the password reset email: 'password_reset_success.html' and
+ # 'password_reset_failure.html'
+ #
+ # * The contents of address verification emails sent during registration:
+ # 'registration.html' and 'registration.txt'
+ #
+ # * HTML pages for success and failure that a user will see when they follow
+ # the link in an address verification email sent during registration:
+ # 'registration_success.html' and 'registration_failure.html'
+ #
+ # * The contents of address verification emails sent when an address is added
+ # to a Matrix account: 'add_threepid.html' and 'add_threepid.txt'
+ #
+ # * HTML pages for success and failure that a user will see when they follow
+ # the link in an address verification email sent when an address is added
+ # to a Matrix account: 'add_threepid_success.html' and
+ # 'add_threepid_failure.html'
+ #
+ # You can see the default templates at:
+ # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
+ #
+ #template_dir: "res/templates"
"""
+
+
+class ThreepidBehaviour(Enum):
+ """
+ Enum to define the behaviour of Synapse with regards to when it contacts an identity
+ server for 3pid registration and password resets
+
+ REMOTE = use an external server to send tokens
+ LOCAL = send tokens ourselves
+ OFF = disable registration via 3pid and password resets
+ """
+
+ REMOTE = "remote"
+ LOCAL = "local"
+ OFF = "off"
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
index e4be172a79..d6862d9a64 100644
--- a/synapse/config/groups.py
+++ b/synapse/config/groups.py
@@ -17,11 +17,13 @@ from ._base import Config
class GroupsConfig(Config):
- def read_config(self, config):
+ section = "groups"
+
+ def read_config(self, config, **kwargs):
self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "")
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
# Uncomment to allow non-server-admin users to create groups on this server
#
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index acadef4fd3..b4bca08b20 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from ._base import RootConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .captcha import CaptchaConfig
@@ -37,43 +38,48 @@ from .saml2_config import SAML2Config
from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
+from .sso import SSOConfig
from .stats import StatsConfig
from .third_party_event_rules import ThirdPartyRulesConfig
from .tls import TlsConfig
+from .tracer import TracerConfig
from .user_directory import UserDirectoryConfig
from .voip import VoipConfig
from .workers import WorkerConfig
-class HomeServerConfig(
- ServerConfig,
- TlsConfig,
- DatabaseConfig,
- LoggingConfig,
- RatelimitConfig,
- ContentRepositoryConfig,
- CaptchaConfig,
- VoipConfig,
- RegistrationConfig,
- MetricsConfig,
- ApiConfig,
- AppServiceConfig,
- KeyConfig,
- SAML2Config,
- CasConfig,
- JWTConfig,
- PasswordConfig,
- EmailConfig,
- WorkerConfig,
- PasswordAuthProviderConfig,
- PushConfig,
- SpamCheckerConfig,
- GroupsConfig,
- UserDirectoryConfig,
- ConsentConfig,
- StatsConfig,
- ServerNoticesConfig,
- RoomDirectoryConfig,
- ThirdPartyRulesConfig,
-):
- pass
+class HomeServerConfig(RootConfig):
+
+ config_classes = [
+ ServerConfig,
+ TlsConfig,
+ DatabaseConfig,
+ LoggingConfig,
+ RatelimitConfig,
+ ContentRepositoryConfig,
+ CaptchaConfig,
+ VoipConfig,
+ RegistrationConfig,
+ MetricsConfig,
+ ApiConfig,
+ AppServiceConfig,
+ KeyConfig,
+ SAML2Config,
+ CasConfig,
+ SSOConfig,
+ JWTConfig,
+ PasswordConfig,
+ EmailConfig,
+ WorkerConfig,
+ PasswordAuthProviderConfig,
+ PushConfig,
+ SpamCheckerConfig,
+ GroupsConfig,
+ UserDirectoryConfig,
+ ConsentConfig,
+ StatsConfig,
+ ServerNoticesConfig,
+ RoomDirectoryConfig,
+ ThirdPartyRulesConfig,
+ TracerConfig,
+ ]
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index ecb4124096..a568726985 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -15,17 +15,17 @@
from ._base import Config, ConfigError
-MISSING_JWT = (
- """Missing jwt library. This is required for jwt login.
+MISSING_JWT = """Missing jwt library. This is required for jwt login.
Install by running:
pip install pyjwt
"""
-)
class JWTConfig(Config):
- def read_config(self, config):
+ section = "jwt"
+
+ def read_config(self, config, **kwargs):
jwt_config = config.get("jwt_config", None)
if jwt_config:
self.jwt_enabled = jwt_config.get("enabled", False)
@@ -34,6 +34,7 @@ class JWTConfig(Config):
try:
import jwt
+
jwt # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_JWT)
@@ -42,7 +43,7 @@ class JWTConfig(Config):
self.jwt_secret = None
self.jwt_algorithm = None
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
# The JWT needs to contain a globally unique "sub" (subject) claim.
#
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 424875feae..066e7838c3 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -50,6 +50,33 @@ and you should enable 'federation_verify_certificates' in your configuration.
If you are *sure* you want to do this, set 'accept_keys_insecurely' on the
trusted_key_server configuration."""
+TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN = """\
+Synapse requires that a list of trusted key servers are specified in order to
+provide signing keys for other servers in the federation.
+
+This homeserver does not have a trusted key server configured in
+homeserver.yaml and will fall back to the default of 'matrix.org'.
+
+Trusted key servers should be long-lived and stable which makes matrix.org a
+good choice for many admins, but some admins may wish to choose another. To
+suppress this warning, the admin should set 'trusted_key_servers' in
+homeserver.yaml to their desired key server and 'suppress_key_server_warning'
+to 'true'.
+
+In a future release the software-defined default will be removed entirely and
+the trusted key server will be defined exclusively by the value of
+'trusted_key_servers'.
+--------------------------------------------------------------------------------"""
+
+TRUSTED_KEY_SERVER_CONFIGURED_AS_M_ORG_WARN = """\
+This server is configured to use 'matrix.org' as its trusted key server via the
+'trusted_key_servers' config option. 'matrix.org' is a good choice for a key
+server since it is long-lived, stable and trusted. However, some admins may
+wish to use another server for this purpose.
+
+To suppress this warning and continue using 'matrix.org', admins should set
+'suppress_key_server_warning' to 'true' in homeserver.yaml.
+--------------------------------------------------------------------------------"""
logger = logging.getLogger(__name__)
@@ -65,23 +92,40 @@ class TrustedKeyServer(object):
class KeyConfig(Config):
- def read_config(self, config):
+ section = "key"
+
+ def read_config(self, config, config_dir_path, **kwargs):
# the signing key can be specified inline or in a separate file
if "signing_key" in config:
self.signing_key = read_signing_keys([config["signing_key"]])
else:
- self.signing_key_path = config["signing_key_path"]
- self.signing_key = self.read_signing_key(self.signing_key_path)
+ signing_key_path = config.get("signing_key_path")
+ if signing_key_path is None:
+ signing_key_path = os.path.join(
+ config_dir_path, config["server_name"] + ".signing.key"
+ )
+
+ self.signing_key = self.read_signing_keys(signing_key_path, "signing_key")
self.old_signing_keys = self.read_old_signing_keys(
- config.get("old_signing_keys", {})
+ config.get("old_signing_keys")
)
self.key_refresh_interval = self.parse_duration(
config.get("key_refresh_interval", "1d")
)
+ suppress_key_server_warning = config.get("suppress_key_server_warning", False)
+ key_server_signing_keys_path = config.get("key_server_signing_keys_path")
+ if key_server_signing_keys_path:
+ self.key_server_signing_keys = self.read_signing_keys(
+ key_server_signing_keys_path, "key_server_signing_keys_path"
+ )
+ else:
+ self.key_server_signing_keys = list(self.signing_key)
+
# if neither trusted_key_servers nor perspectives are given, use the default.
if "perspectives" not in config and "trusted_key_servers" not in config:
+ logger.warning(TRUSTED_KEY_SERVER_NOT_CONFIGURED_WARN)
key_servers = [{"server_name": "matrix.org"}]
else:
key_servers = config.get("trusted_key_servers", [])
@@ -95,6 +139,11 @@ class KeyConfig(Config):
# merge the 'perspectives' config into the 'trusted_key_servers' config.
key_servers.extend(_perspectives_to_key_servers(config))
+ if not suppress_key_server_warning and "matrix.org" in (
+ s["server_name"] for s in key_servers
+ ):
+ logger.warning(TRUSTED_KEY_SERVER_CONFIGURED_AS_M_ORG_WARN)
+
# list of TrustedKeyServer objects
self.key_servers = list(
_parse_key_servers(key_servers, self.federation_verify_certificates)
@@ -107,17 +156,15 @@ class KeyConfig(Config):
if not self.macaroon_secret_key:
# Unfortunately, there are people out there that don't have this
# set. Lets just be "nice" and derive one from their secret key.
- logger.warn("Config is missing macaroon_secret_key")
+ logger.warning("Config is missing macaroon_secret_key")
seed = bytes(self.signing_key[0])
self.macaroon_secret_key = hashlib.sha256(seed).digest()
- self.expire_access_token = config.get("expire_access_token", False)
-
# a secret which is used to calculate HMACs for form values, to stop
# falsification of values
self.form_secret = config.get("form_secret", None)
- def default_config(
+ def generate_config_section(
self, config_dir_path, server_name, generate_secrets=False, **kwargs
):
base_key_name = os.path.join(config_dir_path, server_name)
@@ -139,10 +186,6 @@ class KeyConfig(Config):
#
%(macaroon_secret_key)s
- # Used to enable access token expiration.
- #
- #expire_access_token: False
-
# a secret which is used to calculate HMACs for form values, to stop
# falsification of values. Must be specified for the User Consent
# forms to work.
@@ -156,14 +199,19 @@ class KeyConfig(Config):
signing_key_path: "%(base_key_name)s.signing.key"
# The keys that the server used to sign messages with but won't use
- # to sign new messages. E.g. it has lost its private key
+ # to sign new messages.
#
- #old_signing_keys:
- # "ed25519:auto":
- # # Base64 encoded public key
- # key: "The public part of your old signing key."
- # # Millisecond POSIX timestamp when the key expired.
- # expired_ts: 123456789123
+ old_signing_keys:
+ # For each key, `key` should be the base64-encoded public key, and
+ # `expired_ts`should be the time (in milliseconds since the unix epoch) that
+ # it was last used.
+ #
+ # It is possible to build an entry from an old signing.key file using the
+ # `export_signing_key` script which is provided with synapse.
+ #
+ # For example:
+ #
+ #"ed25519:id": { key: "base64string", expired_ts: 123456789123 }
# How long key response published by this server is valid for.
# Used to set the valid_until_ts in /key/v2 APIs.
@@ -183,6 +231,10 @@ class KeyConfig(Config):
# This setting supercedes an older setting named `perspectives`. The old format
# is still supported for backwards-compatibility, but it is deprecated.
#
+ # 'trusted_key_servers' defaults to matrix.org, but using it will generate a
+ # warning on start-up. To suppress this warning, set
+ # 'suppress_key_server_warning' to true.
+ #
# Options for each entry in the list include:
#
# server_name: the name of the server. required.
@@ -207,22 +259,44 @@ class KeyConfig(Config):
# "ed25519:auto": "abcdefghijklmnopqrstuvwxyzabcdefghijklmopqr"
# - server_name: "my_other_trusted_server.example.com"
#
- # The default configuration is:
+ trusted_key_servers:
+ - server_name: "matrix.org"
+
+ # Uncomment the following to disable the warning that is emitted when the
+ # trusted_key_servers include 'matrix.org'. See above.
#
- #trusted_key_servers:
- # - server_name: "matrix.org"
+ #suppress_key_server_warning: true
+
+ # The signing keys to use when acting as a trusted key server. If not specified
+ # defaults to the server signing key.
+ #
+ # Can contain multiple keys, one per line.
+ #
+ #key_server_signing_keys_path: "key_server_signing_keys.key"
"""
% locals()
)
- def read_signing_key(self, signing_key_path):
- signing_keys = self.read_file(signing_key_path, "signing_key")
+ def read_signing_keys(self, signing_key_path, name):
+ """Read the signing keys in the given path.
+
+ Args:
+ signing_key_path (str)
+ name (str): Associated config key name
+
+ Returns:
+ list[SigningKey]
+ """
+
+ signing_keys = self.read_file(signing_key_path, name)
try:
return read_signing_keys(signing_keys.splitlines(True))
except Exception as e:
- raise ConfigError("Error reading signing_key: %s" % (str(e)))
+ raise ConfigError("Error reading %s: %s" % (name, str(e)))
def read_old_signing_keys(self, old_signing_keys):
+ if old_signing_keys is None:
+ return {}
keys = {}
for key_id, key_data in old_signing_keys.items():
if is_signing_algorithm_supported(key_id):
@@ -237,10 +311,18 @@ class KeyConfig(Config):
)
return keys
- def generate_files(self, config):
- signing_key_path = config["signing_key_path"]
+ def generate_files(self, config, config_dir_path):
+ if "signing_key" in config:
+ return
+
+ signing_key_path = config.get("signing_key_path")
+ if signing_key_path is None:
+ signing_key_path = os.path.join(
+ config_dir_path, config["server_name"] + ".signing.key"
+ )
if not self.path_exists(signing_key_path):
+ print("Generating signing key file %s" % (signing_key_path,))
with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4)
write_signing_keys(signing_key_file, (generate_signing_key(key_id),))
@@ -348,9 +430,8 @@ def _parse_key_servers(key_servers, federation_verify_certificates):
result.verify_keys[key_id] = verify_key
- if (
- not federation_verify_certificates and
- not server.get("accept_keys_insecurely")
+ if not federation_verify_certificates and not server.get(
+ "accept_keys_insecurely"
):
_assert_keyserver_has_verify_keys(result)
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index c1febbe9d3..a25c70e928 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import argparse
import logging
import logging.config
import os
@@ -20,16 +21,33 @@ from string import Template
import yaml
-from twisted.logger import STDLibLogObserver, globalLogBeginner
+from twisted.logger import (
+ ILogObserver,
+ LogBeginner,
+ STDLibLogObserver,
+ globalLogBeginner,
+)
import synapse
from synapse.app import _base as appbase
-from synapse.util.logcontext import LoggingContextFilter
+from synapse.logging._structured import (
+ reload_structured_logging,
+ setup_structured_logging,
+)
+from synapse.logging.context import LoggingContextFilter
from synapse.util.versionstring import get_version_string
-from ._base import Config
+from ._base import Config, ConfigError
+
+DEFAULT_LOG_CONFIG = Template(
+ """\
+# Log configuration for Synapse.
+#
+# This is a YAML file containing a standard Python logging configuration
+# dictionary. See [1] for details on the valid settings.
+#
+# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
-DEFAULT_LOG_CONFIG = Template("""
version: 1
formatters:
@@ -39,7 +57,7 @@ formatters:
filters:
context:
- (): synapse.util.logcontext.LoggingContextFilter
+ (): synapse.logging.context.LoggingContextFilter
request: ""
handlers:
@@ -57,9 +75,6 @@ handlers:
filters: [context]
loggers:
- synapse:
- level: INFO
-
synapse.storage.SQL:
# beware: increasing this to DEBUG will make synapse log sensitive
# information such as access tokens.
@@ -68,163 +83,98 @@ loggers:
root:
level: INFO
handlers: [file, console]
-""")
+
+disable_existing_loggers: false
+"""
+)
+
+LOG_FILE_ERROR = """\
+Support for the log_file configuration option and --log-file command-line option was
+removed in Synapse 1.3.0. You should instead set up a separate log configuration file.
+"""
class LoggingConfig(Config):
+ section = "logging"
- def read_config(self, config):
- self.verbosity = config.get("verbose", 0)
- self.no_redirect_stdio = config.get("no_redirect_stdio", False)
+ def read_config(self, config, **kwargs):
+ if config.get("log_file"):
+ raise ConfigError(LOG_FILE_ERROR)
self.log_config = self.abspath(config.get("log_config"))
- self.log_file = self.abspath(config.get("log_file"))
+ self.no_redirect_stdio = config.get("no_redirect_stdio", False)
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
log_config = os.path.join(config_dir_path, server_name + ".log.config")
- return """\
+ return (
+ """\
## Logging ##
- # A yaml python logging config file
+ # A yaml python logging config file as described by
+ # https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema
#
log_config: "%(log_config)s"
- """ % locals()
+ """
+ % locals()
+ )
def read_arguments(self, args):
- if args.verbose is not None:
- self.verbosity = args.verbose
if args.no_redirect_stdio is not None:
self.no_redirect_stdio = args.no_redirect_stdio
- if args.log_config is not None:
- self.log_config = args.log_config
if args.log_file is not None:
- self.log_file = args.log_file
+ raise ConfigError(LOG_FILE_ERROR)
- def add_arguments(cls, parser):
+ @staticmethod
+ def add_arguments(parser):
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
- '-v', '--verbose', dest="verbose", action='count',
- help="The verbosity level. Specify multiple times to increase "
- "verbosity. (Ignored if --log-config is specified.)"
- )
- logging_group.add_argument(
- '-f', '--log-file', dest="log_file",
- help="File to log to. (Ignored if --log-config is specified.)"
- )
- logging_group.add_argument(
- '--log-config', dest="log_config", default=None,
- help="Python logging config file"
+ "-n",
+ "--no-redirect-stdio",
+ action="store_true",
+ default=None,
+ help="Do not redirect stdout/stderr to the log",
)
+
logging_group.add_argument(
- '-n', '--no-redirect-stdio',
- action='store_true', default=None,
- help="Do not redirect stdout/stderr to the log"
+ "-f", "--log-file", dest="log_file", help=argparse.SUPPRESS,
)
- def generate_files(self, config):
+ def generate_files(self, config, config_dir_path):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")
+ print(
+ "Generating log config file %s which will log to %s"
+ % (log_config, log_file)
+ )
with open(log_config, "w") as log_config_file:
- log_config_file.write(
- DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
- )
+ log_config_file.write(DEFAULT_LOG_CONFIG.substitute(log_file=log_file))
-def setup_logging(config, use_worker_options=False):
- """ Set up python logging
-
- Args:
- config (LoggingConfig | synapse.config.workers.WorkerConfig):
- configuration data
-
- use_worker_options (bool): True to use 'worker_log_config' and
- 'worker_log_file' options instead of 'log_config' and 'log_file'.
-
- register_sighup (func | None): Function to call to register a
- sighup handler.
+def _setup_stdlib_logging(config, log_config, logBeginner: LogBeginner):
+ """
+ Set up Python stdlib logging.
"""
- log_config = (config.worker_log_config if use_worker_options
- else config.log_config)
- log_file = (config.worker_log_file if use_worker_options
- else config.log_file)
-
- log_format = (
- "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
- " - %(message)s"
- )
-
if log_config is None:
- # We don't have a logfile, so fall back to the 'verbosity' param from
- # the config or cmdline. (Note that we generate a log config for new
- # installs, so this will be an unusual case)
- level = logging.INFO
- level_for_storage = logging.INFO
- if config.verbosity:
- level = logging.DEBUG
- if config.verbosity > 1:
- level_for_storage = logging.DEBUG
-
- logger = logging.getLogger('')
- logger.setLevel(level)
+ log_format = (
+ "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
+ " - %(message)s"
+ )
- logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
+ logger = logging.getLogger("")
+ logger.setLevel(logging.INFO)
+ logging.getLogger("synapse.storage.SQL").setLevel(logging.INFO)
formatter = logging.Formatter(log_format)
- if log_file:
- # TODO: Customisable file size / backup count
- handler = logging.handlers.RotatingFileHandler(
- log_file, maxBytes=(1000 * 1000 * 100), backupCount=3,
- encoding='utf8'
- )
-
- def sighup(signum, stack):
- logger.info("Closing log file due to SIGHUP")
- handler.doRollover()
- logger.info("Opened new log file due to SIGHUP")
- else:
- handler = logging.StreamHandler()
-
- def sighup(*args):
- pass
+ handler = logging.StreamHandler()
handler.setFormatter(formatter)
-
handler.addFilter(LoggingContextFilter(request=""))
-
logger.addHandler(handler)
else:
- def load_log_config():
- with open(log_config, 'r') as f:
- logging.config.dictConfig(yaml.safe_load(f))
-
- def sighup(*args):
- # it might be better to use a file watcher or something for this.
- load_log_config()
- logging.info("Reloaded log config from %s due to SIGHUP", log_config)
-
- load_log_config()
+ logging.config.dictConfig(log_config)
- appbase.register_sighup(sighup)
-
- # make sure that the first thing we log is a thing we can grep backwards
- # for
- logging.warn("***** STARTING SERVER *****")
- logging.warn(
- "Server %s version %s",
- sys.argv[0], get_version_string(synapse),
- )
- logging.info("Server hostname: %s", config.server_name)
-
- # It's critical to point twisted's internal logging somewhere, otherwise it
- # stacks up and leaks kup to 64K object;
- # see: https://twistedmatrix.com/trac/ticket/8164
- #
- # Routing to the python logging framework could be a performance problem if
- # the handlers blocked for a long time as python.logging is a blocking API
- # see https://twistedmatrix.com/documents/current/core/howto/logger.html
- # filed as https://github.com/matrix-org/synapse/issues/1727
- #
- # However this may not be too much of a problem if we are just writing to a file.
+ # Route Twisted's native logging through to the standard library logging
+ # system.
observer = STDLibLogObserver()
def _log(event):
@@ -241,9 +191,71 @@ def setup_logging(config, use_worker_options=False):
return observer(event)
- globalLogBeginner.beginLoggingTo(
- [_log],
- redirectStandardIO=not config.no_redirect_stdio,
- )
+ logBeginner.beginLoggingTo([_log], redirectStandardIO=not config.no_redirect_stdio)
if not config.no_redirect_stdio:
print("Redirected stdout/stderr to logs")
+
+ return observer
+
+
+def _reload_stdlib_logging(*args, log_config=None):
+ logger = logging.getLogger("")
+
+ if not log_config:
+ logger.warning("Reloaded a blank config?")
+
+ logging.config.dictConfig(log_config)
+
+
+def setup_logging(
+ hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
+) -> ILogObserver:
+ """
+ Set up the logging subsystem.
+
+ Args:
+ config (LoggingConfig | synapse.config.workers.WorkerConfig):
+ configuration data
+
+ use_worker_options (bool): True to use the 'worker_log_config' option
+ instead of 'log_config'.
+
+ logBeginner: The Twisted logBeginner to use.
+
+ Returns:
+ The "root" Twisted Logger observer, suitable for sending logs to from a
+ Logger instance.
+ """
+ log_config = config.worker_log_config if use_worker_options else config.log_config
+
+ def read_config(*args, callback=None):
+ if log_config is None:
+ return None
+
+ with open(log_config, "rb") as f:
+ log_config_body = yaml.safe_load(f.read())
+
+ if callback:
+ callback(log_config=log_config_body)
+ logging.info("Reloaded log config from %s due to SIGHUP", log_config)
+
+ return log_config_body
+
+ log_config_body = read_config()
+
+ if log_config_body and log_config_body.get("structured") is True:
+ logger = setup_structured_logging(
+ hs, config, log_config_body, logBeginner=logBeginner
+ )
+ appbase.register_sighup(read_config, callback=reload_structured_logging)
+ else:
+ logger = _setup_stdlib_logging(config, log_config_body, logBeginner=logBeginner)
+ appbase.register_sighup(read_config, callback=_reload_stdlib_logging)
+
+ # make sure that the first thing we log is a thing we can grep backwards
+ # for
+ logging.warning("***** STARTING SERVER *****")
+ logging.warning("Server %s version %s", sys.argv[0], get_version_string(synapse))
+ logging.info("Server hostname: %s", config.server_name)
+
+ return logger
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index 2de51979d8..22538153e1 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,42 +14,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import attr
+
+from synapse.python_dependencies import DependencyException, check_requirements
+
from ._base import Config, ConfigError
-MISSING_SENTRY = (
- """Missing sentry-sdk library. This is required to enable sentry
- integration.
- """
-)
+
+@attr.s
+class MetricsFlags(object):
+ known_servers = attr.ib(default=False, validator=attr.validators.instance_of(bool))
+
+ @classmethod
+ def all_off(cls):
+ """
+ Instantiate the flags with all options set to off.
+ """
+ return cls(**{x.name: False for x in attr.fields(cls)})
class MetricsConfig(Config):
- def read_config(self, config):
+ section = "metrics"
+
+ def read_config(self, config, **kwargs):
self.enable_metrics = config.get("enable_metrics", False)
self.report_stats = config.get("report_stats", None)
+ self.report_stats_endpoint = config.get(
+ "report_stats_endpoint", "https://matrix.org/report-usage-stats/push"
+ )
self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
+ if self.enable_metrics:
+ _metrics_config = config.get("metrics_flags") or {}
+ self.metrics_flags = MetricsFlags(**_metrics_config)
+ else:
+ self.metrics_flags = MetricsFlags.all_off()
+
self.sentry_enabled = "sentry" in config
if self.sentry_enabled:
try:
- import sentry_sdk # noqa F401
- except ImportError:
- raise ConfigError(MISSING_SENTRY)
+ check_requirements("sentry")
+ except DependencyException as e:
+ raise ConfigError(e.message)
self.sentry_dsn = config["sentry"].get("dsn")
if not self.sentry_dsn:
raise ConfigError(
- "sentry.dsn field is required when sentry integration is enabled",
+ "sentry.dsn field is required when sentry integration is enabled"
)
- def default_config(self, report_stats=None, **kwargs):
+ def generate_config_section(self, report_stats=None, **kwargs):
res = """\
## Metrics ###
# Enable collection and rendering of performance metrics
#
- #enable_metrics: False
+ #enable_metrics: false
# Enable sentry integration
# NOTE: While attempts are made to ensure that the logs don't contain
@@ -60,12 +82,28 @@ class MetricsConfig(Config):
#sentry:
# dsn: "..."
+ # Flags to enable Prometheus metrics which are not suitable to be
+ # enabled by default, either for performance reasons or limited use.
+ #
+ metrics_flags:
+ # Publish synapse_federation_known_servers, a g auge of the number of
+ # servers this homeserver knows about, including itself. May cause
+ # performance problems on large homeservers.
+ #
+ #known_servers: true
+
# Whether or not to report anonymized homeserver usage statistics.
"""
if report_stats is None:
res += "# report_stats: true|false\n"
else:
- res += "report_stats: %s\n" % ('true' if report_stats else 'false')
+ res += "report_stats: %s\n" % ("true" if report_stats else "false")
+ res += """
+ # The endpoint to report the anonymized homeserver usage statistics to.
+ # Defaults to https://matrix.org/report-usage-stats/push
+ #
+ #report_stats_endpoint: https://example.com/report-usage-stats/push
+ """
return res
diff --git a/synapse/config/password.py b/synapse/config/password.py
index 48a38512cb..2c13810ab8 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -22,25 +22,34 @@ class PasswordConfig(Config):
"""Password login configuration
"""
- def read_config(self, config):
+ section = "password"
+
+ def read_config(self, config, **kwargs):
password_config = config.get("password_config", {})
if password_config is None:
password_config = {}
self.password_enabled = password_config.get("enabled", True)
+ self.password_localdb_enabled = password_config.get("localdb_enabled", True)
self.password_pepper = password_config.get("pepper", "")
# Password policy
self.password_policy = password_config.get("policy", {})
self.password_policy_enabled = self.password_policy.pop("enabled", False)
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
password_config:
# Uncomment to disable password login
#
#enabled: false
+ # Uncomment to disable authentication against the local password
+ # database. This is ignored if `enabled` is false, and is only useful
+ # if you have other password_providers.
+ #
+ #localdb_enabled: false
+
# Uncomment and change to a secret random string for extra security.
# DO NOT CHANGE THIS AFTER INITIAL SETUP!
#
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index f0a6be0679..9746bbc681 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -13,44 +13,44 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, List
+
from synapse.util.module_loader import load_module
from ._base import Config
-LDAP_PROVIDER = 'ldap_auth_provider.LdapAuthProvider'
+LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config):
- def read_config(self, config):
- self.password_providers = []
+ section = "authproviders"
+
+ def read_config(self, config, **kwargs):
+ self.password_providers = [] # type: List[Any]
providers = []
# We want to be backwards compatible with the old `ldap_config`
# param.
ldap_config = config.get("ldap_config", {})
if ldap_config.get("enabled", False):
- providers.append({
- 'module': LDAP_PROVIDER,
- 'config': ldap_config,
- })
+ providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
providers.extend(config.get("password_providers", []))
for provider in providers:
- mod_name = provider['module']
+ mod_name = provider["module"]
# This is for backwards compat when the ldap auth provider resided
# in this package.
if mod_name == "synapse.util.ldap_auth_provider.LdapAuthProvider":
mod_name = LDAP_PROVIDER
- (provider_class, provider_config) = load_module({
- "module": mod_name,
- "config": provider['config'],
- })
+ (provider_class, provider_config) = load_module(
+ {"module": mod_name, "config": provider["config"]}
+ )
self.password_providers.append((provider_class, provider_config))
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
#password_providers:
# - module: "ldap_auth_provider.LdapAuthProvider"
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 62c0060c9c..6f2b3a7faa 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -18,7 +18,9 @@ from ._base import Config
class PushConfig(Config):
- def read_config(self, config):
+ section = "push"
+
+ def read_config(self, config, **kwargs):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
@@ -33,7 +35,7 @@ class PushConfig(Config):
# Now check for the one in the 'email' section and honour it,
# with a warning.
- push_config = config.get("email", {})
+ push_config = config.get("email") or {}
redact_content = push_config.get("redact_content")
if redact_content is not None:
print(
@@ -42,7 +44,7 @@ class PushConfig(Config):
)
self.push_include_content = not redact_content
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Clients requesting push notifications can either have the body of
# the message sent in the notification poke along with other details
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 2a4fe43406..dbc3dd7a2c 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -23,7 +23,7 @@ class RateLimitConfig(object):
class FederationRateLimitConfig(object):
_items_and_default = {
- "window_size": 10000,
+ "window_size": 1000,
"sleep_limit": 10,
"sleep_delay": 500,
"reject_limit": 50,
@@ -36,7 +36,9 @@ class FederationRateLimitConfig(object):
class RatelimitConfig(Config):
- def read_config(self, config):
+ section = "ratelimiting"
+
+ def read_config(self, config, **kwargs):
# Load the new-style messages config if it exists. Otherwise fall back
# to the old method.
@@ -54,7 +56,7 @@ class RatelimitConfig(Config):
# Load the new-style federation config, if it exists. Otherwise, fall
# back to the old method.
- if "federation_rc" in config:
+ if "rc_federation" in config:
self.rc_federation = FederationRateLimitConfig(**config["rc_federation"])
else:
self.rc_federation = FederationRateLimitConfig(
@@ -83,7 +85,12 @@ class RatelimitConfig(Config):
"federation_rr_transactions_per_room_per_second", 50
)
- def default_config(self, **kwargs):
+ rc_admin_redaction = config.get("rc_admin_redaction")
+ self.rc_admin_redaction = None
+ if rc_admin_redaction:
+ self.rc_admin_redaction = RateLimitConfig(rc_admin_redaction)
+
+ def generate_config_section(self, **kwargs):
return """\
## Ratelimiting ##
@@ -107,6 +114,9 @@ class RatelimitConfig(Config):
# attempts for this account.
# - one that ratelimits third-party invites requests based on the account
# that's making the requests.
+ # - one for ratelimiting redactions by room admins. If this is not explicitly
+ # set then it uses the same ratelimiting as per rc_message. This is useful
+ # to allow room admins to deal with abuse quickly.
#
# The defaults are as shown below.
#
@@ -132,6 +142,10 @@ class RatelimitConfig(Config):
#rc_third_party_invite:
# per_second: 0.2
# burst_count: 10
+ #
+ #rc_admin_redaction:
+ # per_second: 1
+ # burst_count: 50
# Ratelimiting settings for incoming federation
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 14752298e9..687433e88a 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -24,9 +24,14 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config):
+ section = "accountvalidity"
+
def __init__(self, config, synapse_config):
+ if config is None:
+ return
+ super(AccountValidityConfig, self).__init__()
self.enabled = config.get("enabled", False)
- self.renew_by_email_enabled = ("renew_at" in config)
+ self.renew_by_email_enabled = "renew_at" in config
if self.enabled:
if "period" in config:
@@ -42,7 +47,7 @@ class AccountValidityConfig(Config):
else:
self.renew_email_subject = "Renew your %(app)s account"
- self.startup_job_max_delta = self.period * 10. / 100.
+ self.startup_job_max_delta = self.period * 10.0 / 100.0
if self.renew_by_email_enabled:
if "public_baseurl" not in synapse_config:
@@ -77,8 +82,9 @@ class AccountValidityConfig(Config):
class RegistrationConfig(Config):
+ section = "registration"
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.enable_registration = bool(
strtobool(str(config.get("enable_registration", False)))
)
@@ -88,7 +94,7 @@ class RegistrationConfig(Config):
)
self.account_validity = AccountValidityConfig(
- config.get("account_validity", {}), config,
+ config.get("account_validity") or {}, config
)
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
@@ -104,25 +110,48 @@ class RegistrationConfig(Config):
self.registration_shared_secret = config.get("registration_shared_secret")
self.register_mxid_from_3pid = config.get("register_mxid_from_3pid")
self.register_just_use_email_for_display_name = config.get(
- "register_just_use_email_for_display_name", False,
+ "register_just_use_email_for_display_name", False
)
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
self.trusted_third_party_id_servers = config.get(
- "trusted_third_party_id_servers",
- ["matrix.org", "vector.im"],
+ "trusted_third_party_id_servers", ["matrix.org", "vector.im"]
)
+ account_threepid_delegates = config.get("account_threepid_delegates") or {}
+ self.account_threepid_delegate_email = account_threepid_delegates.get("email")
+ if (
+ self.account_threepid_delegate_email
+ and not self.account_threepid_delegate_email.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.email must begin with http:// or https://"
+ )
+ self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
+ if (
+ self.account_threepid_delegate_msisdn
+ and not self.account_threepid_delegate_msisdn.startswith("http")
+ ):
+ raise ConfigError(
+ "account_threepid_delegates.msisdn must begin with http:// or https://"
+ )
+ if self.account_threepid_delegate_msisdn and not self.public_baseurl:
+ raise ConfigError(
+ "The configuration option `public_baseurl` is required if "
+ "`account_threepid_delegate.msisdn` is set, such that "
+ "clients know where to submit validation tokens to. Please "
+ "configure `public_baseurl`."
+ )
+
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
- self.invite_3pid_guest = (
- self.allow_guest_access and config.get("invite_3pid_guest", False)
- )
+ if config.get("invite_3pid_guest", False):
+ raise ConfigError("invite_3pid_guest is no longer supported")
self.auto_join_rooms = config.get("auto_join_rooms", [])
for room_alias in self.auto_join_rooms:
if not RoomAlias.is_valid(room_alias):
- raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
+ raise ConfigError("Invalid auto_join_rooms entry %s" % (room_alias,))
self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
self.disable_set_displayname = config.get("disable_set_displayname", False)
@@ -130,24 +159,34 @@ class RegistrationConfig(Config):
self.replicate_user_profiles_to = config.get("replicate_user_profiles_to", [])
if not isinstance(self.replicate_user_profiles_to, list):
- self.replicate_user_profiles_to = [self.replicate_user_profiles_to, ]
+ self.replicate_user_profiles_to = [self.replicate_user_profiles_to]
self.shadow_server = config.get("shadow_server", None)
- self.rewrite_identity_server_urls = config.get("rewrite_identity_server_urls", {})
+ self.rewrite_identity_server_urls = (
+ config.get("rewrite_identity_server_urls") or {}
+ )
- self.disable_msisdn_registration = (
- config.get("disable_msisdn_registration", False)
+ self.disable_msisdn_registration = config.get(
+ "disable_msisdn_registration", False
)
- def default_config(self, generate_secrets=False, **kwargs):
+ session_lifetime = config.get("session_lifetime")
+ if session_lifetime is not None:
+ session_lifetime = self.parse_duration(session_lifetime)
+ self.session_lifetime = session_lifetime
+
+ def generate_config_section(self, generate_secrets=False, **kwargs):
if generate_secrets:
registration_shared_secret = 'registration_shared_secret: "%s"' % (
random_string_with_symbols(50),
)
else:
- registration_shared_secret = '# registration_shared_secret: <PRIVATE STRING>'
+ registration_shared_secret = (
+ "# registration_shared_secret: <PRIVATE STRING>"
+ )
- return """\
+ return (
+ """\
## Registration ##
#
# Registration can be rate-limited using the parameters in the "Ratelimiting"
@@ -160,23 +199,6 @@ class RegistrationConfig(Config):
# Optional account validity configuration. This allows for accounts to be denied
# any request after a given period.
#
- # ``enabled`` defines whether the account validity feature is enabled. Defaults
- # to False.
- #
- # ``period`` allows setting the period after which an account is valid
- # after its registration. When renewing the account, its validity period
- # will be extended by this amount of time. This parameter is required when using
- # the account validity feature.
- #
- # ``renew_at`` is the amount of time before an account's expiry date at which
- # Synapse will send an email to the account's email address with a renewal link.
- # This needs the ``email`` and ``public_baseurl`` configuration sections to be
- # filled.
- #
- # ``renew_email_subject`` is the subject of the email sent out with the renewal
- # link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter
- # from the ``email`` section.
- #
# Once this feature is enabled, Synapse will look for registered users without an
# expiration date at startup and will add one to every account it found using the
# current settings at that time.
@@ -187,21 +209,66 @@ class RegistrationConfig(Config):
# date will be randomly selected within a range [now + period - d ; now + period],
# where d is equal to 10%% of the validity period.
#
- #account_validity:
- # enabled: True
- # period: 6w
- # renew_at: 1w
- # renew_email_subject: "Renew your %%(app)s account"
- # # Directory in which Synapse will try to find the HTML files to serve to the
- # # user when trying to renew an account. Optional, defaults to
- # # synapse/res/templates.
- # template_dir: "res/templates"
- # # HTML to be displayed to the user after they successfully renewed their
- # # account. Optional.
- # account_renewed_html_path: "account_renewed.html"
- # # HTML to be displayed when the user tries to renew an account with an invalid
- # # renewal token. Optional.
- # invalid_token_html_path: "invalid_token.html"
+ account_validity:
+ # The account validity feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # The period after which an account is valid after its registration. When
+ # renewing the account, its validity period will be extended by this amount
+ # of time. This parameter is required when using the account validity
+ # feature.
+ #
+ #period: 6w
+
+ # The amount of time before an account's expiry date at which Synapse will
+ # send an email to the account's email address with a renewal link. By
+ # default, no such emails are sent.
+ #
+ # If you enable this setting, you will also need to fill out the 'email' and
+ # 'public_baseurl' configuration sections.
+ #
+ #renew_at: 1w
+
+ # The subject of the email sent out with the renewal link. '%%(app)s' can be
+ # used as a placeholder for the 'app_name' parameter from the 'email'
+ # section.
+ #
+ # Note that the placeholder must be written '%%(app)s', including the
+ # trailing 's'.
+ #
+ # If this is not set, a default value is used.
+ #
+ #renew_email_subject: "Renew your %%(app)s account"
+
+ # Directory in which Synapse will try to find templates for the HTML files to
+ # serve to the user when trying to renew an account. If not set, default
+ # templates from within the Synapse package will be used.
+ #
+ #template_dir: "res/templates"
+
+ # File within 'template_dir' giving the HTML to be displayed to the user after
+ # they successfully renewed their account. If not set, default text is used.
+ #
+ #account_renewed_html_path: "account_renewed.html"
+
+ # File within 'template_dir' giving the HTML to be displayed when the user
+ # tries to renew an account with an invalid renewal token. If not set,
+ # default text is used.
+ #
+ #invalid_token_html_path: "invalid_token.html"
+
+ # Time that a user's session remains valid for, after they log in.
+ #
+ # Note that this is not currently compatible with guest logins.
+ #
+ # Note also that this is calculated at login time: changes are not applied
+ # retrospectively to users who have already logged in.
+ #
+ # By default, this is infinite.
+ #
+ #session_lifetime: 24h
# The user must provide all of the below types of 3PID when registering.
#
@@ -238,7 +305,7 @@ class RegistrationConfig(Config):
# pending invites for the given 3PID (and then allow it to sign up on
# the platform):
#
- #allow_invited_3pids: False
+ #allow_invited_3pids: false
#
#allowed_local_3pids:
# - medium: email
@@ -251,7 +318,7 @@ class RegistrationConfig(Config):
# If true, stop users from trying to change the 3PIDs associated with
# their accounts.
#
- #disable_3pid_changes: False
+ #disable_3pid_changes: false
# Enable 3PIDs lookup requests to identity servers from this server.
#
@@ -290,6 +357,14 @@ class RegistrationConfig(Config):
# Also defines the ID server which will be called when an account is
# deactivated (one will be picked arbitrarily).
#
+ # Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity
+ # server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a
+ # background migration script, informing itself that the identity server all of its
+ # 3PIDs have been bound to is likely one of the below.
+ #
+ # As of Synapse v1.4.0, all other functionality of this option has been deprecated, and
+ # it is now solely used for the purposes of the background migration script, and can be
+ # removed once it has run.
#trusted_third_party_id_servers:
# - matrix.org
# - vector.im
@@ -315,8 +390,34 @@ class RegistrationConfig(Config):
# Useful when provisioning users based on the contents of a 3rd party
# directory and to avoid ambiguities.
#
- #disable_set_displayname: False
- #disable_set_avatar_url: False
+ #disable_set_displayname: false
+ #disable_set_avatar_url: false
+
+ # Handle threepid (email/phone etc) registration and password resets through a set of
+ # *trusted* identity servers. Note that this allows the configured identity server to
+ # reset passwords for accounts!
+ #
+ # Be aware that if `email` is not set, and SMTP options have not been
+ # configured in the email config block, registration and user password resets via
+ # email will be globally disabled.
+ #
+ # Additionally, if `msisdn` is not set, registration and password resets via msisdn
+ # will be disabled regardless. This is due to Synapse currently not supporting any
+ # method of sending SMS messages on its own.
+ #
+ # To enable using an identity server for operations regarding a particular third-party
+ # identifier type, set the value to the URL of that identity server as shown in the
+ # examples below.
+ #
+ # Servers handling the these requests must answer the `/requestToken` endpoints defined
+ # by the Matrix Identity Service API specification:
+ # https://matrix.org/docs/spec/identity_service/latest
+ #
+ # If a delegate is specified, the config option public_baseurl must also be filled out.
+ #
+ account_threepid_delegates:
+ #email: https://example.com # Delegate email sending to example.com
+ #msisdn: http://localhost:8090 # Delegate SMS sending to this local process
# Users who register on this homeserver will automatically be joined
# to these rooms
@@ -331,17 +432,27 @@ class RegistrationConfig(Config):
# users cannot be auto-joined since they do not exist.
#
#autocreate_auto_join_rooms: true
- """ % locals()
- def add_arguments(self, parser):
+ # Rewrite identity server URLs with a map from one URL to another. Applies to URLs
+ # provided by clients (which have https:// prepended) and those specified
+ # in `account_threepid_delegates`. URLs should not feature a trailing slash.
+ #
+ #rewrite_identity_server_urls:
+ # "https://somewhere.example.com": "https://somewhereelse.example.com"
+ """
+ % locals()
+ )
+
+ @staticmethod
+ def add_arguments(parser):
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
- "--enable-registration", action="store_true", default=None,
- help="Enable registration for new users."
+ "--enable-registration",
+ action="store_true",
+ default=None,
+ help="Enable registration for new users.",
)
def read_arguments(self, args):
if args.enable_registration is not None:
- self.enable_registration = bool(
- strtobool(str(args.enable_registration))
- )
+ self.enable_registration = bool(strtobool(str(args.enable_registration)))
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 2abede409a..5ebc2ea1f1 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2014, 2015 matrix.org
+# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,35 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import os
from collections import namedtuple
+from typing import Dict, List
+from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module
from ._base import Config, ConfigError
DEFAULT_THUMBNAIL_SIZES = [
- {
- "width": 32,
- "height": 32,
- "method": "crop",
- }, {
- "width": 96,
- "height": 96,
- "method": "crop",
- }, {
- "width": 320,
- "height": 240,
- "method": "scale",
- }, {
- "width": 640,
- "height": 480,
- "method": "scale",
- }, {
- "width": 800,
- "height": 600,
- "method": "scale"
- },
+ {"width": 32, "height": 32, "method": "crop"},
+ {"width": 96, "height": 96, "method": "crop"},
+ {"width": 320, "height": 240, "method": "scale"},
+ {"width": 640, "height": 480, "method": "scale"},
+ {"width": 800, "height": 600, "method": "scale"},
]
THUMBNAIL_SIZE_YAML = """\
@@ -49,27 +36,13 @@ THUMBNAIL_SIZE_YAML = """\
# method: %(method)s
"""
-MISSING_NETADDR = (
- "Missing netaddr library. This is required for URL preview API."
-)
-
-MISSING_LXML = (
- """Missing lxml library. This is required for URL preview API.
-
- Install by running:
- pip install lxml
-
- Requires libxslt1-dev system package.
- """
-)
-
-
ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
MediaStorageProviderConfig = namedtuple(
- "MediaStorageProviderConfig", (
+ "MediaStorageProviderConfig",
+ (
"store_local", # Whether to store newly uploaded local files
"store_remote", # Whether to store newly downloaded remote files
"store_synchronous", # Whether to wait for successful storage for local uploads
@@ -89,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
"""
- requirements = {}
+ requirements = {} # type: Dict[str, List]
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]
@@ -100,13 +73,26 @@ def parse_thumbnail_requirements(thumbnail_sizes):
requirements.setdefault("image/gif", []).append(png_thumbnail)
requirements.setdefault("image/png", []).append(png_thumbnail)
return {
- media_type: tuple(thumbnails)
- for media_type, thumbnails in requirements.items()
+ media_type: tuple(thumbnails) for media_type, thumbnails in requirements.items()
}
class ContentRepositoryConfig(Config):
- def read_config(self, config):
+ section = "media"
+
+ def read_config(self, config, **kwargs):
+
+ # Only enable the media repo if either the media repo is enabled or the
+ # current worker app is the media repo.
+ if (
+ self.enable_media_repo is False
+ and config.get("worker_app") != "synapse.app.media_repository"
+ ):
+ self.can_load_media_repo = False
+ return
+ else:
+ self.can_load_media_repo = True
+
self.max_upload_size = self.parse_size(config.get("max_upload_size", "10M"))
self.max_image_pixels = self.parse_size(config.get("max_image_pixels", "32M"))
self.max_spider_size = self.parse_size(config.get("max_spider_size", "10M"))
@@ -117,7 +103,9 @@ class ContentRepositoryConfig(Config):
self.allowed_avatar_mimetypes = config.get("allowed_avatar_mimetypes", [])
- self.media_store_path = self.ensure_directory(config["media_store_path"])
+ self.media_store_path = self.ensure_directory(
+ config.get("media_store_path", "media_store")
+ )
backup_media_store_path = config.get("backup_media_store_path")
@@ -133,15 +121,15 @@ class ContentRepositoryConfig(Config):
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
)
- storage_providers = [{
- "module": "file_system",
- "store_local": True,
- "store_synchronous": synchronous_backup_media_store,
- "store_remote": True,
- "config": {
- "directory": backup_media_store_path,
+ storage_providers = [
+ {
+ "module": "file_system",
+ "store_local": True,
+ "store_synchronous": synchronous_backup_media_store,
+ "store_remote": True,
+ "config": {"directory": backup_media_store_path},
}
- }]
+ ]
# This is a list of config that can be used to create the storage
# providers. The entries are tuples of (Class, class_config,
@@ -151,7 +139,7 @@ class ContentRepositoryConfig(Config):
#
# We don't create the storage providers here as not all workers need
# them to be started.
- self.media_storage_providers = []
+ self.media_storage_providers = [] # type: List[tuple]
for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to
@@ -171,26 +159,20 @@ class ContentRepositoryConfig(Config):
)
self.media_storage_providers.append(
- (provider_class, parsed_config, wrapper_config,)
+ (provider_class, parsed_config, wrapper_config)
)
- self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config.get("dynamic_thumbnails", False)
self.thumbnail_requirements = parse_thumbnail_requirements(
- config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES),
+ config.get("thumbnail_sizes", DEFAULT_THUMBNAIL_SIZES)
)
self.url_preview_enabled = config.get("url_preview_enabled", False)
if self.url_preview_enabled:
try:
- import lxml
- lxml # To stop unused lint.
- except ImportError:
- raise ConfigError(MISSING_LXML)
+ check_requirements("url_preview")
- try:
- from netaddr import IPSet
- except ImportError:
- raise ConfigError(MISSING_NETADDR)
+ except DependencyException as e:
+ raise ConfigError(e.message)
if "url_preview_ip_range_blacklist" not in config:
raise ConfigError(
@@ -199,23 +181,24 @@ class ContentRepositoryConfig(Config):
"to work"
)
+ # netaddr is a dependency for url_preview
+ from netaddr import IPSet
+
self.url_preview_ip_range_blacklist = IPSet(
config["url_preview_ip_range_blacklist"]
)
# we always blacklist '0.0.0.0' and '::', which are supposed to be
# unroutable addresses.
- self.url_preview_ip_range_blacklist.update(['0.0.0.0', '::'])
+ self.url_preview_ip_range_blacklist.update(["0.0.0.0", "::"])
self.url_preview_ip_range_whitelist = IPSet(
config.get("url_preview_ip_range_whitelist", ())
)
- self.url_preview_url_blacklist = config.get(
- "url_preview_url_blacklist", ()
- )
+ self.url_preview_url_blacklist = config.get("url_preview_url_blacklist", ())
- def default_config(self, data_dir_path, **kwargs):
+ def generate_config_section(self, data_dir_path, **kwargs):
media_store = os.path.join(data_dir_path, "media_store")
uploads_path = os.path.join(data_dir_path, "uploads")
@@ -225,7 +208,15 @@ class ContentRepositoryConfig(Config):
# strip final NL
formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1]
- return r"""
+ return (
+ r"""
+ ## Media Store ##
+
+ # Enable the media store service in the Synapse master. Uncomment the
+ # following if you are using a separate media store worker.
+ #
+ #enable_media_repo: false
+
# Directory where uploaded images and attachments are stored.
#
media_store_path: "%(media_store)s"
@@ -245,10 +236,6 @@ class ContentRepositoryConfig(Config):
# config:
# directory: /mnt/some/other/directory
- # Directory where in-progress uploads are stored.
- #
- uploads_path: "%(uploads_path)s"
-
# The largest allowed upload size in bytes
#
#max_upload_size: 10M
@@ -372,4 +359,6 @@ class ContentRepositoryConfig(Config):
# The largest allowed URL preview spidering size in bytes
#
#max_spider_size: 10M
- """ % locals()
+ """
+ % locals()
+ )
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 8a9fded4c5..7ac7699676 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -19,10 +19,10 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
- def read_config(self, config):
- self.enable_room_list_search = config.get(
- "enable_room_list_search", True,
- )
+ section = "roomdirectory"
+
+ def read_config(self, config, **kwargs):
+ self.enable_room_list_search = config.get("enable_room_list_search", True)
alias_creation_rules = config.get("alias_creation_rules")
@@ -33,11 +33,7 @@ class RoomDirectoryConfig(Config):
]
else:
self._alias_creation_rules = [
- _RoomDirectoryRule(
- "alias_creation_rules", {
- "action": "allow",
- }
- )
+ _RoomDirectoryRule("alias_creation_rules", {"action": "allow"})
]
room_list_publication_rules = config.get("room_list_publication_rules")
@@ -49,14 +45,10 @@ class RoomDirectoryConfig(Config):
]
else:
self._room_list_publication_rules = [
- _RoomDirectoryRule(
- "room_list_publication_rules", {
- "action": "allow",
- }
- )
+ _RoomDirectoryRule("room_list_publication_rules", {"action": "allow"})
]
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Uncomment to disable searching the public room list. When disabled
# blocks searching local and remote room lists for local and remote
@@ -178,8 +170,7 @@ class _RoomDirectoryRule(object):
self.action = action
else:
raise ConfigError(
- "%s rules can only have action of 'allow'"
- " or 'deny'" % (option_name,)
+ "%s rules can only have action of 'allow' or 'deny'" % (option_name,)
)
self._alias_matches_all = alias == "*"
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index aa6eac271f..8fe64d90f8 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,11 +14,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import os
+
+import pkg_resources
+
+from synapse.python_dependencies import DependencyException, check_requirements
+from synapse.util.module_loader import load_module, load_python_module
+
from ._base import Config, ConfigError
+logger = logging.getLogger(__name__)
+
+DEFAULT_USER_MAPPING_PROVIDER = (
+ "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
+)
+
+
+def _dict_merge(merge_dict, into_dict):
+ """Do a deep merge of two dicts
+
+ Recursively merges `merge_dict` into `into_dict`:
+ * For keys where both `merge_dict` and `into_dict` have a dict value, the values
+ are recursively merged
+ * For all other keys, the values in `into_dict` (if any) are overwritten with
+ the value from `merge_dict`.
+
+ Args:
+ merge_dict (dict): dict to merge
+ into_dict (dict): target dict
+ """
+ for k, v in merge_dict.items():
+ if k not in into_dict:
+ into_dict[k] = v
+ continue
+
+ current_val = into_dict[k]
+
+ if isinstance(v, dict) and isinstance(current_val, dict):
+ _dict_merge(v, current_val)
+ continue
+
+ # otherwise we just overwrite
+ into_dict[k] = v
+
class SAML2Config(Config):
- def read_config(self, config):
+ section = "saml2"
+
+ def read_config(self, config, **kwargs):
self.saml2_enabled = False
saml2_config = config.get("saml2_config")
@@ -25,85 +70,291 @@ class SAML2Config(Config):
if not saml2_config or not saml2_config.get("enabled", True):
return
+ if not saml2_config.get("sp_config") and not saml2_config.get("config_path"):
+ return
+
+ try:
+ check_requirements("saml2")
+ except DependencyException as e:
+ raise ConfigError(e.message)
+
self.saml2_enabled = True
- import saml2.config
- self.saml2_sp_config = saml2.config.SPConfig()
- self.saml2_sp_config.load(self._default_saml_config_dict())
- self.saml2_sp_config.load(saml2_config.get("sp_config", {}))
+ self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
+ "grandfathered_mxid_source_attribute", "uid"
+ )
+
+ # user_mapping_provider may be None if the key is present but has no value
+ ump_dict = saml2_config.get("user_mapping_provider") or {}
+
+ # Use the default user mapping provider if not set
+ ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
+
+ # Ensure a config is present
+ ump_dict["config"] = ump_dict.get("config") or {}
+
+ if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
+ # Load deprecated options for use by the default module
+ old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
+ if old_mxid_source_attribute:
+ logger.warning(
+ "The config option saml2_config.mxid_source_attribute is deprecated. "
+ "Please use saml2_config.user_mapping_provider.config"
+ ".mxid_source_attribute instead."
+ )
+ ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
+
+ old_mxid_mapping = saml2_config.get("mxid_mapping")
+ if old_mxid_mapping:
+ logger.warning(
+ "The config option saml2_config.mxid_mapping is deprecated. Please "
+ "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
+ )
+ ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
+
+ # Retrieve an instance of the module's class
+ # Pass the config dictionary to the module for processing
+ (
+ self.saml2_user_mapping_provider_class,
+ self.saml2_user_mapping_provider_config,
+ ) = load_module(ump_dict)
+
+ # Ensure loaded user mapping module has defined all necessary methods
+ # Note parse_config() is already checked during the call to load_module
+ required_methods = [
+ "get_saml_attributes",
+ "saml_response_to_user_attributes",
+ "get_remote_user_id",
+ ]
+ missing_methods = [
+ method
+ for method in required_methods
+ if not hasattr(self.saml2_user_mapping_provider_class, method)
+ ]
+ if missing_methods:
+ raise ConfigError(
+ "Class specified by saml2_config."
+ "user_mapping_provider.module is missing required "
+ "methods: %s" % (", ".join(missing_methods),)
+ )
+
+ # Get the desired saml auth response attributes from the module
+ saml2_config_dict = self._default_saml_config_dict(
+ *self.saml2_user_mapping_provider_class.get_saml_attributes(
+ self.saml2_user_mapping_provider_config
+ )
+ )
+ _dict_merge(
+ merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
+ )
config_path = saml2_config.get("config_path", None)
if config_path is not None:
- self.saml2_sp_config.load_file(config_path)
+ mod = load_python_module(config_path)
+ _dict_merge(merge_dict=mod.CONFIG, into_dict=saml2_config_dict)
+
+ import saml2.config
+
+ self.saml2_sp_config = saml2.config.SPConfig()
+ self.saml2_sp_config.load(saml2_config_dict)
+
+ # session lifetime: in milliseconds
+ self.saml2_session_lifetime = self.parse_duration(
+ saml2_config.get("saml_session_lifetime", "5m")
+ )
- def _default_saml_config_dict(self):
+ template_dir = saml2_config.get("template_dir")
+ if not template_dir:
+ template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
+
+ self.saml2_error_html_content = self.read_file(
+ os.path.join(template_dir, "saml_error.html"), "saml2_config.saml_error",
+ )
+
+ def _default_saml_config_dict(
+ self, required_attributes: set, optional_attributes: set
+ ):
+ """Generate a configuration dictionary with required and optional attributes that
+ will be needed to process new user registration
+
+ Args:
+ required_attributes: SAML auth response attributes that are
+ necessary to function
+ optional_attributes: SAML auth response attributes that can be used to add
+ additional information to Synapse user accounts, but are not required
+
+ Returns:
+ dict: A SAML configuration dictionary
+ """
import saml2
public_baseurl = self.public_baseurl
if public_baseurl is None:
- raise ConfigError(
- "saml2_config requires a public_baseurl to be set"
- )
+ raise ConfigError("saml2_config requires a public_baseurl to be set")
+
+ if self.saml2_grandfathered_mxid_source_attribute:
+ optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
+ optional_attributes -= required_attributes
metadata_url = public_baseurl + "_matrix/saml2/metadata.xml"
response_url = public_baseurl + "_matrix/saml2/authn_response"
return {
"entityid": metadata_url,
-
"service": {
"sp": {
"endpoints": {
"assertion_consumer_service": [
- (response_url, saml2.BINDING_HTTP_POST),
- ],
+ (response_url, saml2.BINDING_HTTP_POST)
+ ]
},
- "required_attributes": ["uid"],
- "optional_attributes": ["mail", "surname", "givenname"],
- },
- }
+ "required_attributes": list(required_attributes),
+ "optional_attributes": list(optional_attributes),
+ # "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT,
+ }
+ },
}
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
# Enable SAML2 for registration and login. Uses pysaml2.
#
- # `sp_config` is the configuration for the pysaml2 Service Provider.
- # See pysaml2 docs for format of config.
- #
- # Default values will be used for the 'entityid' and 'service' settings,
- # so it is not normally necessary to specify them unless you need to
- # override them.
- #
- #saml2_config:
- # sp_config:
- # # point this to the IdP's metadata. You can use either a local file or
- # # (preferably) a URL.
- # metadata:
- # #local: ["saml2/idp.xml"]
- # remote:
- # - url: https://our_idp/metadata.xml
- #
- # # The rest of sp_config is just used to generate our metadata xml, and you
- # # may well not need it, depending on your setup. Alternatively you
- # # may need a whole lot more detail - see the pysaml2 docs!
+ # At least one of `sp_config` or `config_path` must be set in this section to
+ # enable SAML login.
#
- # description: ["My awesome SP", "en"]
- # name: ["Test SP", "en"]
+ # (You will probably also want to set the following options to `false` to
+ # disable the regular login/registration flows:
+ # * enable_registration
+ # * password_config.enabled
#
- # organization:
- # name: Example com
- # display_name:
- # - ["Example co", "en"]
- # url: "http://example.com"
+ # Once SAML support is enabled, a metadata file will be exposed at
+ # https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
+ # use to configure your SAML IdP with. Alternatively, you can manually configure
+ # the IdP to use an ACS location of
+ # https://<server>:<port>/_matrix/saml2/authn_response.
#
- # contact_person:
- # - given_name: Bob
- # sur_name: "the Sysadmin"
- # email_address": ["admin@example.com"]
- # contact_type": technical
- #
- # # Instead of putting the config inline as above, you can specify a
- # # separate pysaml2 configuration file:
- # #
- # config_path: "%(config_dir_path)s/sp_conf.py"
- """ % {"config_dir_path": config_dir_path}
+ saml2_config:
+ # `sp_config` is the configuration for the pysaml2 Service Provider.
+ # See pysaml2 docs for format of config.
+ #
+ # Default values will be used for the 'entityid' and 'service' settings,
+ # so it is not normally necessary to specify them unless you need to
+ # override them.
+ #
+ #sp_config:
+ # # point this to the IdP's metadata. You can use either a local file or
+ # # (preferably) a URL.
+ # metadata:
+ # #local: ["saml2/idp.xml"]
+ # remote:
+ # - url: https://our_idp/metadata.xml
+ #
+ # # By default, the user has to go to our login page first. If you'd like
+ # # to allow IdP-initiated login, set 'allow_unsolicited: true' in a
+ # # 'service.sp' section:
+ # #
+ # #service:
+ # # sp:
+ # # allow_unsolicited: true
+ #
+ # # The examples below are just used to generate our metadata xml, and you
+ # # may well not need them, depending on your setup. Alternatively you
+ # # may need a whole lot more detail - see the pysaml2 docs!
+ #
+ # description: ["My awesome SP", "en"]
+ # name: ["Test SP", "en"]
+ #
+ # organization:
+ # name: Example com
+ # display_name:
+ # - ["Example co", "en"]
+ # url: "http://example.com"
+ #
+ # contact_person:
+ # - given_name: Bob
+ # sur_name: "the Sysadmin"
+ # email_address": ["admin@example.com"]
+ # contact_type": technical
+
+ # Instead of putting the config inline as above, you can specify a
+ # separate pysaml2 configuration file:
+ #
+ #config_path: "%(config_dir_path)s/sp_conf.py"
+
+ # The lifetime of a SAML session. This defines how long a user has to
+ # complete the authentication process, if allow_unsolicited is unset.
+ # The default is 5 minutes.
+ #
+ #saml_session_lifetime: 5m
+
+ # An external module can be provided here as a custom solution to
+ # mapping attributes returned from a saml provider onto a matrix user.
+ #
+ user_mapping_provider:
+ # The custom module's class. Uncomment to use a custom module.
+ #
+ #module: mapping_provider.SamlMappingProvider
+
+ # Custom configuration values for the module. Below options are
+ # intended for the built-in provider, they should be changed if
+ # using a custom module. This section will be passed as a Python
+ # dictionary to the module's `parse_config` method.
+ #
+ config:
+ # The SAML attribute (after mapping via the attribute maps) to use
+ # to derive the Matrix ID from. 'uid' by default.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_source_attribute option. If that is still
+ # defined, its value will be used instead.
+ #
+ #mxid_source_attribute: displayName
+
+ # The mapping system to use for mapping the saml attribute onto a
+ # matrix ID.
+ #
+ # Options include:
+ # * 'hexencode' (which maps unpermitted characters to '=xx')
+ # * 'dotreplace' (which replaces unpermitted characters with
+ # '.').
+ # The default is 'hexencode'.
+ #
+ # Note: This used to be configured by the
+ # saml2_config.mxid_mapping option. If that is still defined, its
+ # value will be used instead.
+ #
+ #mxid_mapping: dotreplace
+
+ # In previous versions of synapse, the mapping from SAML attribute to
+ # MXID was always calculated dynamically rather than stored in a
+ # table. For backwards- compatibility, we will look for user_ids
+ # matching such a pattern before creating a new account.
+ #
+ # This setting controls the SAML attribute which will be used for this
+ # backwards-compatibility lookup. Typically it should be 'uid', but if
+ # the attribute maps are changed, it may be necessary to change it.
+ #
+ # The default is 'uid'.
+ #
+ #grandfathered_mxid_source_attribute: upn
+
+ # Directory in which Synapse will try to find the template files below.
+ # If not set, default templates from within the Synapse package will be used.
+ #
+ # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
+ # If you *do* uncomment it, you will need to make sure that all the templates
+ # below are in the directory.
+ #
+ # Synapse will look for the following templates in this directory:
+ #
+ # * HTML page to display to users if something goes wrong during the
+ # authentication process: 'saml_error.html'.
+ #
+ # This template doesn't currently need any variable to render.
+ #
+ # You can see the default templates at:
+ # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
+ #
+ #template_dir: "res/templates"
+ """ % {
+ "config_dir_path": config_dir_path
+ }
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 2ef1d940c4..f5942c45c2 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -17,7 +17,12 @@
import logging
import os.path
+import re
+from textwrap import indent
+from typing import Dict, List, Optional
+import attr
+import yaml
from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -34,14 +39,28 @@ logger = logging.Logger(__name__)
#
# We later check for errors when binding to 0.0.0.0 and ignore them if :: is also in
# in the list.
-DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0']
+DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
-DEFAULT_ROOM_VERSION = "4"
+DEFAULT_ROOM_VERSION = "5"
+
+ROOM_COMPLEXITY_TOO_GREAT = (
+ "Your homeserver is unable to join rooms this large or complex. "
+ "Please speak to your server administrator, or upgrade your instance "
+ "to join this room."
+)
+
+METRICS_PORT_WARNING = """\
+The metrics_port configuration option is deprecated in Synapse 0.31 in favour of
+a listener. Please see
+https://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.md
+on how to configure the new listener.
+--------------------------------------------------------------------------------"""
class ServerConfig(Config):
+ section = "server"
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.server_name = config["server_name"]
self.server_context = config.get("server_context", None)
@@ -58,7 +77,6 @@ class ServerConfig(Config):
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", False)
self.public_baseurl = config.get("public_baseurl")
- self.cpu_affinity = config.get("cpu_affinity")
# Whether to send federation traffic out in this process. This only
# applies to some federation traffic, and so shouldn't be used to
@@ -81,13 +99,13 @@ class ServerConfig(Config):
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API.
self.require_auth_for_profile_requests = config.get(
- "require_auth_for_profile_requests", False,
+ "require_auth_for_profile_requests", False
)
# Whether to require sharing a room with a user to retrieve their
# profile data
- self.limit_profile_requests_to_known_users = config.get(
- "limit_profile_requests_to_known_users", False,
+ self.limit_profile_requests_to_users_who_share_rooms = config.get(
+ "limit_profile_requests_to_users_who_share_rooms", False,
)
if "restrict_public_rooms_to_local_users" in config and (
@@ -106,28 +124,27 @@ class ServerConfig(Config):
self.allow_public_rooms_without_auth = False
self.allow_public_rooms_over_federation = False
else:
- # If set to 'False', requires authentication to access the server's public
- # rooms directory through the client API. Defaults to 'True'.
+ # If set to 'true', removes the need for authentication to access the server's
+ # public rooms directory through the client API, meaning that anyone can
+ # query the room directory. Defaults to 'false'.
self.allow_public_rooms_without_auth = config.get(
- "allow_public_rooms_without_auth", True
+ "allow_public_rooms_without_auth", False
)
- # If set to 'False', forbids any other homeserver to fetch the server's public
- # rooms directory via federation. Defaults to 'True'.
+ # If set to 'true', allows any other homeserver to fetch the server's public
+ # rooms directory via federation. Defaults to 'false'.
self.allow_public_rooms_over_federation = config.get(
- "allow_public_rooms_over_federation", True
+ "allow_public_rooms_over_federation", False
)
- default_room_version = config.get(
- "default_room_version", DEFAULT_ROOM_VERSION,
- )
+ default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION)
# Ensure room version is a str
default_room_version = str(default_room_version)
if default_room_version not in KNOWN_ROOM_VERSIONS:
raise ConfigError(
- "Unknown default_room_version: %s, known room versions: %s" %
- (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
+ "Unknown default_room_version: %s, known room versions: %s"
+ % (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
)
# Get the actual room version object rather than just the identifier
@@ -142,46 +159,55 @@ class ServerConfig(Config):
# Whether we should block invites sent to users on this server
# (other than those sent by local server admins)
- self.block_non_admin_invites = config.get(
- "block_non_admin_invites", False,
- )
+ self.block_non_admin_invites = config.get("block_non_admin_invites", False)
# Whether to enable experimental MSC1849 (aka relations) support
self.experimental_msc1849_support_enabled = config.get(
- "experimental_msc1849_support_enabled", False,
+ "experimental_msc1849_support_enabled", True
)
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
self.max_mau_value = 0
if self.limit_usage_by_mau:
- self.max_mau_value = config.get(
- "max_mau_value", 0,
- )
+ self.max_mau_value = config.get("max_mau_value", 0)
self.mau_stats_only = config.get("mau_stats_only", False)
self.mau_limits_reserved_threepids = config.get(
"mau_limit_reserved_threepids", []
)
- self.mau_trial_days = config.get(
- "mau_trial_days", 0,
- )
+ self.mau_trial_days = config.get("mau_trial_days", 0)
+ self.mau_limit_alerting = config.get("mau_limit_alerting", True)
+
+ # How long to keep redacted events in the database in unredacted form
+ # before redacting them.
+ redaction_retention_period = config.get("redaction_retention_period", "7d")
+ if redaction_retention_period is not None:
+ self.redaction_retention_period = self.parse_duration(
+ redaction_retention_period
+ )
+ else:
+ self.redaction_retention_period = None
+
+ # How long to keep entries in the `users_ips` table.
+ user_ips_max_age = config.get("user_ips_max_age", "28d")
+ if user_ips_max_age is not None:
+ self.user_ips_max_age = self.parse_duration(user_ips_max_age)
+ else:
+ self.user_ips_max_age = None
# Options to disable HS
self.hs_disabled = config.get("hs_disabled", False)
self.hs_disabled_message = config.get("hs_disabled_message", "")
- self.hs_disabled_limit_type = config.get("hs_disabled_limit_type", "")
# Admin uri to direct users at should their instance become blocked
# due to resource constraints
self.admin_contact = config.get("admin_contact", None)
# FIXME: federation_domain_whitelist needs sytests
- self.federation_domain_whitelist = None
- federation_domain_whitelist = config.get(
- "federation_domain_whitelist", None,
- )
+ self.federation_domain_whitelist = None # type: Optional[dict]
+ federation_domain_whitelist = config.get("federation_domain_whitelist", None)
if federation_domain_whitelist is not None:
# turn the whitelist into a hash for speed of lookup
@@ -191,7 +217,7 @@ class ServerConfig(Config):
self.federation_domain_whitelist[domain] = True
self.federation_ip_range_blacklist = config.get(
- "federation_ip_range_blacklist", [],
+ "federation_ip_range_blacklist", []
)
# Attempt to create an IPSet from the given ranges
@@ -204,13 +230,12 @@ class ServerConfig(Config):
self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
except Exception as e:
raise ConfigError(
- "Invalid range(s) provided in "
- "federation_ip_range_blacklist: %s" % e
+ "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
)
if self.public_baseurl is not None:
- if self.public_baseurl[-1] != '/':
- self.public_baseurl += '/'
+ if self.public_baseurl[-1] != "/":
+ self.public_baseurl += "/"
self.start_pushers = config.get("start_pushers", True)
# (undocumented) option for torturing the worker-mode replication a bit,
@@ -221,7 +246,7 @@ class ServerConfig(Config):
# Whether to require a user to be in the room to add an alias to it.
# Defaults to True.
self.require_membership_for_aliases = config.get(
- "require_membership_for_aliases", True,
+ "require_membership_for_aliases", True
)
# Whether to allow per-room membership profiles through the send of membership
@@ -231,7 +256,7 @@ class ServerConfig(Config):
# Whether to show the users on this homeserver in the user directory. Defaults to
# True.
self.show_users_in_user_directory = config.get(
- "show_users_in_user_directory", True,
+ "show_users_in_user_directory", True
)
retention_config = config.get("retention")
@@ -275,13 +300,25 @@ class ServerConfig(Config):
self.retention_default_min_lifetime = None
self.retention_default_max_lifetime = None
- self.retention_allowed_lifetime_min = retention_config.get("allowed_lifetime_min")
+ if self.retention_enabled:
+ logger.info(
+ "Message retention policies support enabled with the following default"
+ " policy: min_lifetime = %s ; max_lifetime = %s",
+ self.retention_default_min_lifetime,
+ self.retention_default_max_lifetime,
+ )
+
+ self.retention_allowed_lifetime_min = retention_config.get(
+ "allowed_lifetime_min"
+ )
if self.retention_allowed_lifetime_min is not None:
self.retention_allowed_lifetime_min = self.parse_duration(
self.retention_allowed_lifetime_min
)
- self.retention_allowed_lifetime_max = retention_config.get("allowed_lifetime_max")
+ self.retention_allowed_lifetime_max = retention_config.get(
+ "allowed_lifetime_max"
+ )
if self.retention_allowed_lifetime_max is not None:
self.retention_allowed_lifetime_max = self.parse_duration(
self.retention_allowed_lifetime_max
@@ -290,14 +327,15 @@ class ServerConfig(Config):
if (
self.retention_allowed_lifetime_min is not None
and self.retention_allowed_lifetime_max is not None
- and self.retention_allowed_lifetime_min > self.retention_allowed_lifetime_max
+ and self.retention_allowed_lifetime_min
+ > self.retention_allowed_lifetime_max
):
raise ConfigError(
"Invalid retention policy limits: 'allowed_lifetime_min' can not be"
" greater than 'allowed_lifetime_max'"
)
- self.retention_purge_jobs = []
+ self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]]
for purge_job_config in retention_config.get("purge_jobs", []):
interval_config = purge_job_config.get("interval")
@@ -330,20 +368,24 @@ class ServerConfig(Config):
" 'longest_max_lifetime' value."
)
- self.retention_purge_jobs.append({
- "interval": interval,
- "shortest_max_lifetime": shortest_max_lifetime,
- "longest_max_lifetime": longest_max_lifetime,
- })
+ self.retention_purge_jobs.append(
+ {
+ "interval": interval,
+ "shortest_max_lifetime": shortest_max_lifetime,
+ "longest_max_lifetime": longest_max_lifetime,
+ }
+ )
if not self.retention_purge_jobs:
- self.retention_purge_jobs = [{
- "interval": self.parse_duration("1d"),
- "shortest_max_lifetime": None,
- "longest_max_lifetime": None,
- }]
-
- self.listeners = []
+ self.retention_purge_jobs = [
+ {
+ "interval": self.parse_duration("1d"),
+ "shortest_max_lifetime": None,
+ "longest_max_lifetime": None,
+ }
+ ]
+
+ self.listeners = [] # type: List[dict]
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
raise ConfigError(
@@ -368,9 +410,9 @@ class ServerConfig(Config):
# if we still have an empty list of addresses, use the default list
if not bind_addresses:
- if listener['type'] == 'metrics':
+ if listener["type"] == "metrics":
# the metrics listener doesn't support IPv6
- bind_addresses.append('0.0.0.0')
+ bind_addresses.append("0.0.0.0")
else:
bind_addresses.extend(DEFAULT_BIND_ADDRESSES)
@@ -381,6 +423,26 @@ class ServerConfig(Config):
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
+ @attr.s
+ class LimitRemoteRoomsConfig(object):
+ enabled = attr.ib(
+ validator=attr.validators.instance_of(bool), default=False
+ )
+ complexity = attr.ib(
+ validator=attr.validators.instance_of(
+ (float, int) # type: ignore[arg-type] # noqa
+ ),
+ default=1.0,
+ )
+ complexity_error = attr.ib(
+ validator=attr.validators.instance_of(str),
+ default=ROOM_COMPLEXITY_TOO_GREAT,
+ )
+
+ self.limit_remote_rooms = LimitRemoteRoomsConfig(
+ **config.get("limit_remote_rooms", {})
+ )
+
bind_port = config.get("bind_port")
if bind_port:
if config.get("no_tls", False):
@@ -390,78 +452,73 @@ class ServerConfig(Config):
bind_host = config.get("bind_host", "")
gzip_responses = config.get("gzip_responses", True)
- self.listeners.append({
- "port": bind_port,
- "bind_addresses": [bind_host],
- "tls": True,
- "type": "http",
- "resources": [
- {
- "names": ["client"],
- "compress": gzip_responses,
- },
- {
- "names": ["federation"],
- "compress": False,
- }
- ]
- })
-
- unsecure_port = config.get("unsecure_port", bind_port - 400)
- if unsecure_port:
- self.listeners.append({
- "port": unsecure_port,
+ self.listeners.append(
+ {
+ "port": bind_port,
"bind_addresses": [bind_host],
- "tls": False,
+ "tls": True,
"type": "http",
"resources": [
- {
- "names": ["client"],
- "compress": gzip_responses,
- },
- {
- "names": ["federation"],
- "compress": False,
- }
- ]
- })
+ {"names": ["client"], "compress": gzip_responses},
+ {"names": ["federation"], "compress": False},
+ ],
+ }
+ )
+
+ unsecure_port = config.get("unsecure_port", bind_port - 400)
+ if unsecure_port:
+ self.listeners.append(
+ {
+ "port": unsecure_port,
+ "bind_addresses": [bind_host],
+ "tls": False,
+ "type": "http",
+ "resources": [
+ {"names": ["client"], "compress": gzip_responses},
+ {"names": ["federation"], "compress": False},
+ ],
+ }
+ )
manhole = config.get("manhole")
if manhole:
- self.listeners.append({
- "port": manhole,
- "bind_addresses": ["127.0.0.1"],
- "type": "manhole",
- "tls": False,
- })
+ self.listeners.append(
+ {
+ "port": manhole,
+ "bind_addresses": ["127.0.0.1"],
+ "type": "manhole",
+ "tls": False,
+ }
+ )
metrics_port = config.get("metrics_port")
if metrics_port:
- logger.warn(
- ("The metrics_port configuration option is deprecated in Synapse 0.31 "
- "in favour of a listener. Please see "
- "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
- " on how to configure the new listener."))
-
- self.listeners.append({
- "port": metrics_port,
- "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
- "tls": False,
- "type": "http",
- "resources": [
- {
- "names": ["metrics"],
- "compress": False,
- },
- ]
- })
+ logger.warning(METRICS_PORT_WARNING)
+
+ self.listeners.append(
+ {
+ "port": metrics_port,
+ "bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
+ "tls": False,
+ "type": "http",
+ "resources": [{"names": ["metrics"], "compress": False}],
+ }
+ )
_check_resource_config(self.listeners)
- def has_tls_listener(self):
+ self.cleanup_extremities_with_dummy_events = config.get(
+ "cleanup_extremities_with_dummy_events", True
+ )
+
+ self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False)
+
+ def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners)
- def default_config(self, server_name, data_dir_path, **kwargs):
+ def generate_config_section(
+ self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
+ ):
_, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None:
unsecure_port = bind_port - 400
@@ -474,7 +531,72 @@ class ServerConfig(Config):
# Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
# default config string
default_room_version = DEFAULT_ROOM_VERSION
- return """\
+ secure_listeners = []
+ unsecure_listeners = []
+ private_addresses = ["::1", "127.0.0.1"]
+ if listeners:
+ for listener in listeners:
+ if listener["tls"]:
+ secure_listeners.append(listener)
+ else:
+ # If we don't want open ports we need to bind the listeners
+ # to some address other than 0.0.0.0. Here we chose to use
+ # localhost.
+ # If the addresses are already bound we won't overwrite them
+ # however.
+ if not open_private_ports:
+ listener.setdefault("bind_addresses", private_addresses)
+
+ unsecure_listeners.append(listener)
+
+ secure_http_bindings = indent(
+ yaml.dump(secure_listeners), " " * 10
+ ).lstrip()
+
+ unsecure_http_bindings = indent(
+ yaml.dump(unsecure_listeners), " " * 10
+ ).lstrip()
+
+ if not unsecure_listeners:
+ unsecure_http_bindings = (
+ """- port: %(unsecure_port)s
+ tls: false
+ type: http
+ x_forwarded: true"""
+ % locals()
+ )
+
+ if not open_private_ports:
+ unsecure_http_bindings += (
+ "\n bind_addresses: ['::1', '127.0.0.1']"
+ )
+
+ unsecure_http_bindings += """
+
+ resources:
+ - names: [client, federation]
+ compress: false"""
+
+ if listeners:
+ # comment out this block
+ unsecure_http_bindings = "#" + re.sub(
+ "\n {10}",
+ lambda match: match.group(0) + "#",
+ unsecure_http_bindings,
+ )
+
+ if not secure_listeners:
+ secure_http_bindings = (
+ """#- port: %(bind_port)s
+ # type: http
+ # tls: true
+ # resources:
+ # - names: [client, federation]"""
+ % locals()
+ )
+
+ return (
+ """\
## Server ##
# The domain name of the server, with optional explicit port.
@@ -488,29 +610,6 @@ class ServerConfig(Config):
#
pid_file: %(pid_file)s
- # CPU affinity mask. Setting this restricts the CPUs on which the
- # process will be scheduled. It is represented as a bitmask, with the
- # lowest order bit corresponding to the first logical CPU and the
- # highest order bit corresponding to the last logical CPU. Not all CPUs
- # may exist on a given system but a mask may specify more CPUs than are
- # present.
- #
- # For example:
- # 0x00000001 is processor #0,
- # 0x00000003 is processors #0 and #1,
- # 0xFFFFFFFF is all processors (#0 through #31).
- #
- # Pinning a Python process to a single CPU is desirable, because Python
- # is inherently single-threaded due to the GIL, and can suffer a
- # 30-40%% slowdown due to cache blow-out and thread context switching
- # if the scheduler happens to schedule the underlying threads across
- # different cores. See
- # https://www.mirantis.com/blog/improve-performance-python-programs-restricting-single-cpu/.
- #
- # This setting requires the affinity package to be installed!
- #
- #cpu_affinity: 0xFFFFFFFF
-
# The path to the web client which will be served at /_matrix/client/
# if 'webclient' is configured under the 'listeners' configuration.
#
@@ -542,22 +641,23 @@ class ServerConfig(Config):
#
#require_auth_for_profile_requests: true
- # Whether to require a user to share a room with another user in order
+ # Uncomment to require a user to share a room with another user in order
# to retrieve their profile information. Only checked on Client-Server
# requests. Profile requests from other servers should be checked by the
# requesting server. Defaults to 'false'.
#
- # limit_profile_requests_to_known_users: true
+ #limit_profile_requests_to_users_who_share_rooms: true
- # If set to 'false', requires authentication to access the server's public rooms
- # directory through the client API. Defaults to 'true'.
+ # If set to 'true', removes the need for authentication to access the server's
+ # public rooms directory through the client API, meaning that anyone can
+ # query the room directory. Defaults to 'false'.
#
- #allow_public_rooms_without_auth: false
+ #allow_public_rooms_without_auth: true
- # If set to 'false', forbids any other homeserver to fetch the server's public
- # rooms directory via federation. Defaults to 'true'.
+ # If set to 'true', allows any other homeserver to fetch the server's public
+ # rooms directory via federation. Defaults to 'false'.
#
- #allow_public_rooms_over_federation: false
+ #allow_public_rooms_over_federation: true
# The default room version for newly created rooms.
#
@@ -581,7 +681,7 @@ class ServerConfig(Config):
# Whether room invites to users on this server should be blocked
# (except those sent by local server admins). The default is False.
#
- #block_non_admin_invites: True
+ #block_non_admin_invites: true
# Room searching
#
@@ -605,6 +705,9 @@ class ServerConfig(Config):
# blacklist IP address CIDR ranges. If this option is not specified, or
# specified with an empty list, no ip range blacklist will be enforced.
#
+ # As of Synapse v1.4.0 this option also affects any outbound requests to identity
+ # servers provided by user input.
+ #
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
# listed here, since they correspond to unroutable addresses.)
#
@@ -631,8 +734,8 @@ class ServerConfig(Config):
#
# type: the type of listener. Normally 'http', but other valid options are:
# 'manhole' (see docs/manhole.md),
- # 'metrics' (see docs/metrics-howto.rst),
- # 'replication' (see docs/workers.rst).
+ # 'metrics' (see docs/metrics-howto.md),
+ # 'replication' (see docs/workers.md).
#
# tls: set to true to enable TLS for this listener. Will use the TLS
# key/cert specified in tls_private_key_path / tls_certificate_path.
@@ -667,12 +770,12 @@ class ServerConfig(Config):
#
# media: the media API (/_matrix/media).
#
- # metrics: the metrics interface. See docs/metrics-howto.rst.
+ # metrics: the metrics interface. See docs/metrics-howto.md.
#
# openid: OpenID authentication.
#
# replication: the HTTP replication API (/_synapse/replication). See
- # docs/workers.rst.
+ # docs/workers.md.
#
# static: static resources under synapse/static (/_matrix/static). (Mostly
# useful for 'fallback authentication'.)
@@ -686,29 +789,17 @@ class ServerConfig(Config):
# will also need to give Synapse a TLS key and certificate: see the TLS section
# below.)
#
- #- port: %(bind_port)s
- # type: http
- # tls: true
- # resources:
- # - names: [client, federation]
+ %(secure_http_bindings)s
# Unsecure HTTP listener: for when matrix traffic passes through a reverse proxy
# that unwraps TLS.
#
# If you plan to use a reverse proxy, please see
- # https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.rst.
+ # https://github.com/matrix-org/synapse/blob/master/docs/reverse_proxy.md.
#
- - port: %(unsecure_port)s
- tls: false
- bind_addresses: ['::1', '127.0.0.1']
- type: http
- x_forwarded: true
-
- resources:
- - names: [client, federation]
- compress: false
+ %(unsecure_http_bindings)s
- # example additonal_resources:
+ # example additional_resources:
#
#additional_resources:
# "/_matrix/my/custom/endpoint":
@@ -731,9 +822,8 @@ class ServerConfig(Config):
# Global blocking
#
- #hs_disabled: False
+ #hs_disabled: false
#hs_disabled_message: 'Human readable reason for why the HS is blocked'
- #hs_disabled_limit_type: 'error code(str), to help clients decode reason'
# Monthly Active User Blocking
#
@@ -753,15 +843,22 @@ class ServerConfig(Config):
# sign up in a short space of time never to return after their initial
# session.
#
- #limit_usage_by_mau: False
+ # 'mau_limit_alerting' is a means of limiting client side alerting
+ # should the mau limit be reached. This is useful for small instances
+ # where the admin has 5 mau seats (say) for 5 specific people and no
+ # interest increasing the mau limit further. Defaults to True, which
+ # means that alerting is enabled
+ #
+ #limit_usage_by_mau: false
#max_mau_value: 50
#mau_trial_days: 2
+ #mau_limit_alerting: false
# If enabled, the metrics for the number of monthly active users will
# be populated, however no one will be limited. If limit_usage_by_mau
# is true, this is implied to be true.
#
- #mau_stats_only: False
+ #mau_stats_only: false
# Sometimes the server admin will want to ensure certain accounts are
# never blocked by mau checking. These accounts are specified here.
@@ -773,6 +870,23 @@ class ServerConfig(Config):
# Used by phonehome stats to group together related servers.
#server_context: context
+ # Resource-constrained homeserver Settings
+ #
+ # If limit_remote_rooms.enabled is True, the room complexity will be
+ # checked before a user joins a new remote room. If it is above
+ # limit_remote_rooms.complexity, it will disallow joining or
+ # instantly leave.
+ #
+ # limit_remote_rooms.complexity_error can be set to customise the text
+ # displayed to the user when a room above the complexity threshold has
+ # its join cancelled.
+ #
+ # Uncomment the below lines to enable:
+ #limit_remote_rooms:
+ # enabled: true
+ # complexity: 1.0
+ # complexity_error: "This room is too complex."
+
# Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'.
#
@@ -851,7 +965,85 @@ class ServerConfig(Config):
# - shortest_max_lifetime: 3d
# longest_max_lifetime: 1y
# interval: 24h
- """ % locals()
+
+ # How long to keep redacted events in unredacted form in the database. After
+ # this period redacted events get replaced with their redacted form in the DB.
+ #
+ # Defaults to `7d`. Set to `null` to disable.
+ #
+ #redaction_retention_period: 28d
+
+ # How long to track users' last seen time and IPs in the database.
+ #
+ # Defaults to `28d`. Set to `null` to disable clearing out of old rows.
+ #
+ #user_ips_max_age: 14d
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events which lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # which 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a more frequent basis than for the rest of the rooms
+ # (e.g. every 12h), but not want that purge to be performed by a job that's
+ # iterating over every room it knows, which could be heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 12h
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 1d
+ """
+ % locals()
+ )
def read_arguments(self, args):
if args.manhole is not None:
@@ -861,19 +1053,29 @@ class ServerConfig(Config):
if args.print_pidfile is not None:
self.print_pidfile = args.print_pidfile
- def add_arguments(self, parser):
+ @staticmethod
+ def add_arguments(parser):
server_group = parser.add_argument_group("server")
- server_group.add_argument("-D", "--daemonize", action='store_true',
- default=None,
- help="Daemonize the home server")
- server_group.add_argument("--print-pidfile", action='store_true',
- default=None,
- help="Print the path to the pidfile just"
- " before daemonizing")
- server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
- type=int,
- help="Turn on the twisted telnet manhole"
- " service on the given port.")
+ server_group.add_argument(
+ "-D",
+ "--daemonize",
+ action="store_true",
+ default=None,
+ help="Daemonize the homeserver",
+ )
+ server_group.add_argument(
+ "--print-pidfile",
+ action="store_true",
+ default=None,
+ help="Print the path to the pidfile just before daemonizing",
+ )
+ server_group.add_argument(
+ "--manhole",
+ metavar="PORT",
+ dest="manhole",
+ type=int,
+ help="Turn on the twisted telnet manhole service on the given port.",
+ )
def is_threepid_reserved(reserved_threepids, threepid):
@@ -887,7 +1089,7 @@ def is_threepid_reserved(reserved_threepids, threepid):
"""
for tp in reserved_threepids:
- if (threepid['medium'] == tp['medium'] and threepid['address'] == tp['address']):
+ if threepid["medium"] == tp["medium"] and threepid["address"] == tp["address"]:
return True
return False
@@ -900,9 +1102,7 @@ def read_gc_thresholds(thresholds):
return None
try:
assert len(thresholds) == 3
- return (
- int(thresholds[0]), int(thresholds[1]), int(thresholds[2]),
- )
+ return (int(thresholds[0]), int(thresholds[1]), int(thresholds[2]))
except Exception:
raise ConfigError(
"Value of `gc_threshold` must be a list of three integers if set"
@@ -920,40 +1120,38 @@ def _warn_if_webclient_configured(listeners):
for listener in listeners:
for res in listener.get("resources", []):
for name in res.get("names", []):
- if name == 'webclient':
+ if name == "webclient":
logger.warning(NO_MORE_WEB_CLIENT_WARNING)
return
KNOWN_RESOURCES = (
- 'client',
- 'consent',
- 'federation',
- 'keys',
- 'media',
- 'metrics',
- 'openid',
- 'replication',
- 'static',
- 'webclient',
+ "client",
+ "consent",
+ "federation",
+ "keys",
+ "media",
+ "metrics",
+ "openid",
+ "replication",
+ "static",
+ "webclient",
)
def _check_resource_config(listeners):
- resource_names = set(
+ resource_names = {
res_name
for listener in listeners
for res in listener.get("resources", [])
for res_name in res.get("names", [])
- )
+ }
for resource in resource_names:
if resource not in KNOWN_RESOURCES:
- raise ConfigError(
- "Unknown listener resource '%s'" % (resource, )
- )
+ raise ConfigError("Unknown listener resource '%s'" % (resource,))
if resource == "consent":
try:
- check_requirements('resources.consent')
+ check_requirements("resources.consent")
except DependencyException as e:
raise ConfigError(e.message)
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
index 529dc0a617..6ea2ea8869 100644
--- a/synapse/config/server_notices_config.py
+++ b/synapse/config/server_notices_config.py
@@ -58,30 +58,27 @@ class ServerNoticesConfig(Config):
The name to use for the server notices room.
None if server notices are not enabled.
"""
- def __init__(self):
- super(ServerNoticesConfig, self).__init__()
+
+ section = "servernotices"
+
+ def __init__(self, *args):
+ super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None
self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None
self.server_notices_room_name = None
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
c = config.get("server_notices")
if c is None:
return
- mxid_localpart = c['system_mxid_localpart']
- self.server_notices_mxid = UserID(
- mxid_localpart, self.server_name,
- ).to_string()
- self.server_notices_mxid_display_name = c.get(
- 'system_mxid_display_name', None,
- )
- self.server_notices_mxid_avatar_url = c.get(
- 'system_mxid_avatar_url', None,
- )
+ mxid_localpart = c["system_mxid_localpart"]
+ self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string()
+ self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None)
+ self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None)
# todo: i18n
- self.server_notices_room_name = c.get('room_name', "Server Notices")
+ self.server_notices_room_name = c.get("room_name", "Server Notices")
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return DEFAULT_CONFIG
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 1502e9faba..36e0ddab5c 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -19,14 +19,16 @@ from ._base import Config
class SpamCheckerConfig(Config):
- def read_config(self, config):
+ section = "spamchecker"
+
+ def read_config(self, config, **kwargs):
self.spam_checker = None
provider = config.get("spam_checker", None)
if provider is not None:
self.spam_checker = load_module(provider)
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
#spam_checker:
# module: "my_custom_project.SuperSpamChecker"
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
new file mode 100644
index 0000000000..95762689bc
--- /dev/null
+++ b/synapse/config/sso.py
@@ -0,0 +1,92 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict
+
+import pkg_resources
+
+from ._base import Config
+
+
+class SSOConfig(Config):
+ """SSO Configuration
+ """
+
+ section = "sso"
+
+ def read_config(self, config, **kwargs):
+ sso_config = config.get("sso") or {} # type: Dict[str, Any]
+
+ # Pick a template directory in order of:
+ # * The sso-specific template_dir
+ # * /path/to/synapse/install/res/templates
+ template_dir = sso_config.get("template_dir")
+ if not template_dir:
+ template_dir = pkg_resources.resource_filename("synapse", "res/templates",)
+
+ self.sso_redirect_confirm_template_dir = template_dir
+
+ self.sso_client_whitelist = sso_config.get("client_whitelist") or []
+
+ def generate_config_section(self, **kwargs):
+ return """\
+ # Additional settings to use with single-sign on systems such as SAML2 and CAS.
+ #
+ sso:
+ # A list of client URLs which are whitelisted so that the user does not
+ # have to confirm giving access to their account to the URL. Any client
+ # whose URL starts with an entry in the following list will not be subject
+ # to an additional confirmation step after the SSO login is completed.
+ #
+ # WARNING: An entry such as "https://my.client" is insecure, because it
+ # will also match "https://my.client.evil.site", exposing your users to
+ # phishing attacks from evil.site. To avoid this, include a slash after the
+ # hostname: "https://my.client/".
+ #
+ # By default, this list is empty.
+ #
+ #client_whitelist:
+ # - https://riot.im/develop
+ # - https://my.custom.client/
+
+ # Directory in which Synapse will try to find the template files below.
+ # If not set, default templates from within the Synapse package will be used.
+ #
+ # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
+ # If you *do* uncomment it, you will need to make sure that all the templates
+ # below are in the directory.
+ #
+ # Synapse will look for the following templates in this directory:
+ #
+ # * HTML page for a confirmation step before redirecting back to the client
+ # with the login token: 'sso_redirect_confirm.html'.
+ #
+ # When rendering, this template is given three variables:
+ # * redirect_url: the URL the user is about to be redirected to. Needs
+ # manual escaping (see
+ # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+ #
+ # * display_url: the same as `redirect_url`, but with the query
+ # parameters stripped. The intention is to have a
+ # human-readable URL to show to users, not to use it as
+ # the final address to redirect to. Needs manual escaping
+ # (see https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+ #
+ # * server_name: the homeserver's name.
+ #
+ # You can see the default templates at:
+ # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
+ #
+ #template_dir: "res/templates"
+ """
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index 80fc1b9dd0..62485189ea 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -25,24 +25,23 @@ class StatsConfig(Config):
Configuration for the behaviour of synapse's stats engine
"""
- def read_config(self, config):
+ section = "stats"
+
+ def read_config(self, config, **kwargs):
self.stats_enabled = True
- self.stats_bucket_size = 86400
+ self.stats_bucket_size = 86400 * 1000
self.stats_retention = sys.maxsize
stats_config = config.get("stats", None)
if stats_config:
self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
- self.stats_bucket_size = (
- self.parse_duration(stats_config.get("bucket_size", "1d")) / 1000
+ self.stats_bucket_size = self.parse_duration(
+ stats_config.get("bucket_size", "1d")
)
- self.stats_retention = (
- self.parse_duration(
- stats_config.get("retention", "%ds" % (sys.maxsize,))
- )
- / 1000
+ self.stats_retention = self.parse_duration(
+ stats_config.get("retention", "%ds" % (sys.maxsize,))
)
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# Local statistics collection. Used in populating the room directory.
#
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index a89dd5f98a..10a99c792e 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -19,14 +19,16 @@ from ._base import Config
class ThirdPartyRulesConfig(Config):
- def read_config(self, config):
+ section = "thirdpartyrules"
+
+ def read_config(self, config, **kwargs):
self.third_party_event_rules = None
provider = config.get("third_party_event_rules", None)
if provider is not None:
self.third_party_event_rules = load_module(provider)
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
# Server admins can define a Python module that implements extra rules for
# allowing or denying incoming events. In order to work, this module needs to
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 658f9dd361..a65538562b 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -18,12 +18,13 @@ import os
import warnings
from datetime import datetime
from hashlib import sha256
+from typing import List
import six
from unpaddedbase64 import encode_base64
-from OpenSSL import crypto
+from OpenSSL import SSL, crypto
from twisted.internet._sslverify import Certificate, trustRootFromCertificates
from synapse.config._base import Config, ConfigError
@@ -31,9 +32,22 @@ from synapse.util import glob_to_regex
logger = logging.getLogger(__name__)
+ACME_SUPPORT_ENABLED_WARN = """\
+This server uses Synapse's built-in ACME support. Note that ACME v1 has been
+deprecated by Let's Encrypt, and that Synapse doesn't currently support ACME v2,
+which means that this feature will not work with Synapse installs set up after
+November 2019, and that it may stop working on June 2020 for installs set up
+before that date.
+
+For more info and alternative solutions, see
+https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
+--------------------------------------------------------------------------------"""
+
class TlsConfig(Config):
- def read_config(self, config):
+ section = "tls"
+
+ def read_config(self, config: dict, config_dir_path: str, **kwargs):
acme_config = config.get("acme", None)
if acme_config is None:
@@ -41,19 +55,26 @@ class TlsConfig(Config):
self.acme_enabled = acme_config.get("enabled", False)
+ if self.acme_enabled:
+ logger.warning(ACME_SUPPORT_ENABLED_WARN)
+
# hyperlink complains on py2 if this is not a Unicode
- self.acme_url = six.text_type(acme_config.get(
- "url", u"https://acme-v01.api.letsencrypt.org/directory"
- ))
+ self.acme_url = six.text_type(
+ acme_config.get("url", "https://acme-v01.api.letsencrypt.org/directory")
+ )
self.acme_port = acme_config.get("port", 80)
- self.acme_bind_addresses = acme_config.get("bind_addresses", ['::', '0.0.0.0'])
+ self.acme_bind_addresses = acme_config.get("bind_addresses", ["::", "0.0.0.0"])
self.acme_reprovision_threshold = acme_config.get("reprovision_threshold", 30)
self.acme_domain = acme_config.get("domain", config.get("server_name"))
+ self.acme_account_key_file = self.abspath(
+ acme_config.get("account_key_file", config_dir_path + "/client.key")
+ )
+
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
- if self.has_tls_listener():
+ if self.root.server.has_tls_listener():
if not self.tls_certificate_file:
raise ConfigError(
"tls_certificate_path must be specified if TLS-enabled listeners are "
@@ -74,25 +95,53 @@ class TlsConfig(Config):
# Whether to verify certificates on outbound federation traffic
self.federation_verify_certificates = config.get(
- "federation_verify_certificates", True,
+ "federation_verify_certificates", True
)
+ # Minimum TLS version to use for outbound federation traffic
+ self.federation_client_minimum_tls_version = str(
+ config.get("federation_client_minimum_tls_version", 1)
+ )
+
+ if self.federation_client_minimum_tls_version not in ["1", "1.1", "1.2", "1.3"]:
+ raise ConfigError(
+ "federation_client_minimum_tls_version must be one of: 1, 1.1, 1.2, 1.3"
+ )
+
+ # Prevent people shooting themselves in the foot here by setting it to
+ # the biggest number blindly
+ if self.federation_client_minimum_tls_version == "1.3":
+ if getattr(SSL, "OP_NO_TLSv1_3", None) is None:
+ raise ConfigError(
+ (
+ "federation_client_minimum_tls_version cannot be 1.3, "
+ "your OpenSSL does not support it"
+ )
+ )
+
# Whitelist of domains to not verify certificates for
fed_whitelist_entries = config.get(
- "federation_certificate_verification_whitelist", [],
+ "federation_certificate_verification_whitelist", []
)
+ if fed_whitelist_entries is None:
+ fed_whitelist_entries = []
# Support globs (*) in whitelist values
- self.federation_certificate_verification_whitelist = []
+ self.federation_certificate_verification_whitelist = [] # type: List[str]
for entry in fed_whitelist_entries:
+ try:
+ entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))
+ except UnicodeEncodeError:
+ raise ConfigError(
+ "IDNA domain names are not allowed in the "
+ "federation_certificate_verification_whitelist: %s" % (entry,)
+ )
+
# Convert globs to regex
- entry_regex = glob_to_regex(entry)
self.federation_certificate_verification_whitelist.append(entry_regex)
# List of custom certificate authorities for federation traffic validation
- custom_ca_list = config.get(
- "federation_custom_ca_list", None,
- )
+ custom_ca_list = config.get("federation_custom_ca_list", None)
# Read in and parse custom CA certificates
self.federation_ca_trust_root = None
@@ -101,8 +150,10 @@ class TlsConfig(Config):
# A trustroot cannot be generated without any CA certificates.
# Raise an error if this option has been specified without any
# corresponding certificates.
- raise ConfigError("federation_custom_ca_list specified without "
- "any certificate files")
+ raise ConfigError(
+ "federation_custom_ca_list specified without "
+ "any certificate files"
+ )
certs = []
for ca_file in custom_ca_list:
@@ -114,8 +165,9 @@ class TlsConfig(Config):
cert_base = Certificate.loadPEM(content)
certs.append(cert_base)
except Exception as e:
- raise ConfigError("Error parsing custom CA certificate file %s: %s"
- % (ca_file, e))
+ raise ConfigError(
+ "Error parsing custom CA certificate file %s: %s" % (ca_file, e)
+ )
self.federation_ca_trust_root = trustRootFromCertificates(certs)
@@ -146,17 +198,21 @@ class TlsConfig(Config):
return None
try:
- with open(self.tls_certificate_file, 'rb') as f:
+ with open(self.tls_certificate_file, "rb") as f:
cert_pem = f.read()
except Exception as e:
- raise ConfigError("Failed to read existing certificate file %s: %s"
- % (self.tls_certificate_file, e))
+ raise ConfigError(
+ "Failed to read existing certificate file %s: %s"
+ % (self.tls_certificate_file, e)
+ )
try:
tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
except Exception as e:
- raise ConfigError("Failed to parse existing certificate file %s: %s"
- % (self.tls_certificate_file, e))
+ raise ConfigError(
+ "Failed to parse existing certificate file %s: %s"
+ % (self.tls_certificate_file, e)
+ )
if not allow_self_signed:
if tls_certificate.get_subject() == tls_certificate.get_issuer():
@@ -166,7 +222,7 @@ class TlsConfig(Config):
# YYYYMMDDhhmmssZ -- in UTC
expires_on = datetime.strptime(
- tls_certificate.get_notAfter().decode('ascii'), "%Y%m%d%H%M%SZ"
+ tls_certificate.get_notAfter().decode("ascii"), "%Y%m%d%H%M%SZ"
)
now = datetime.utcnow()
days_remaining = (expires_on - now).days
@@ -191,7 +247,8 @@ class TlsConfig(Config):
except Exception as e:
logger.info(
"Unable to read TLS certificate (%s). Ignoring as no "
- "tls listeners enabled.", e,
+ "tls listeners enabled.",
+ e,
)
self.tls_fingerprints = list(self._original_tls_fingerprints)
@@ -203,22 +260,54 @@ class TlsConfig(Config):
crypto.FILETYPE_ASN1, self.tls_certificate
)
sha256_fingerprint = encode_base64(sha256(x509_certificate_bytes).digest())
- sha256_fingerprints = set(f["sha256"] for f in self.tls_fingerprints)
+ sha256_fingerprints = {f["sha256"] for f in self.tls_fingerprints}
if sha256_fingerprint not in sha256_fingerprints:
- self.tls_fingerprints.append({u"sha256": sha256_fingerprint})
+ self.tls_fingerprints.append({"sha256": sha256_fingerprint})
+
+ def generate_config_section(
+ self,
+ config_dir_path,
+ server_name,
+ data_dir_path,
+ tls_certificate_path,
+ tls_private_key_path,
+ acme_domain,
+ **kwargs
+ ):
+ """If the acme_domain is specified acme will be enabled.
+ If the TLS paths are not specified the default will be certs in the
+ config directory"""
- def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
- tls_certificate_path = base_key_name + ".tls.crt"
- tls_private_key_path = base_key_name + ".tls.key"
+ if bool(tls_certificate_path) != bool(tls_private_key_path):
+ raise ConfigError(
+ "Please specify both a cert path and a key path or neither."
+ )
+
+ tls_enabled = (
+ "" if tls_certificate_path and tls_private_key_path or acme_domain else "#"
+ )
+
+ if not tls_certificate_path:
+ tls_certificate_path = base_key_name + ".tls.crt"
+ if not tls_private_key_path:
+ tls_private_key_path = base_key_name + ".tls.key"
+
+ acme_enabled = bool(acme_domain)
+ acme_domain = "matrix.example.com"
+
+ default_acme_account_file = os.path.join(data_dir_path, "acme_account.key")
# this is to avoid the max line length. Sorrynotsorry
proxypassline = (
- 'ProxyPass /.well-known/acme-challenge '
- 'http://localhost:8009/.well-known/acme-challenge'
+ "ProxyPass /.well-known/acme-challenge "
+ "http://localhost:8009/.well-known/acme-challenge"
)
+ # flake8 doesn't recognise that variables are used in the below string
+ _ = tls_enabled, proxypassline, acme_enabled, default_acme_account_file
+
return (
"""\
## TLS ##
@@ -235,11 +324,11 @@ class TlsConfig(Config):
# instance, if using certbot, use `fullchain.pem` as your certificate,
# not `cert.pem`).
#
- #tls_certificate_path: "%(tls_certificate_path)s"
+ %(tls_enabled)stls_certificate_path: "%(tls_certificate_path)s"
# PEM-encoded private key for TLS
#
- #tls_private_key_path: "%(tls_private_key_path)s"
+ %(tls_enabled)stls_private_key_path: "%(tls_private_key_path)s"
# Whether to verify TLS server certificates for outbound federation requests.
#
@@ -248,6 +337,15 @@ class TlsConfig(Config):
#
#federation_verify_certificates: false
+ # The minimum TLS version that will be used for outbound federation requests.
+ #
+ # Defaults to `1`. Configurable to `1`, `1.1`, `1.2`, or `1.3`. Note
+ # that setting this value higher than `1.2` will prevent federation to most
+ # of the public Matrix network: only configure it to `1.3` if you have an
+ # entirely private federation setup and you can ensure TLS 1.3 support.
+ #
+ #federation_client_minimum_tls_version: 1.2
+
# Skip federation certificate verification on the following whitelist
# of domains.
#
@@ -278,6 +376,11 @@ class TlsConfig(Config):
# ACME support: This will configure Synapse to request a valid TLS certificate
# for your configured `server_name` via Let's Encrypt.
#
+ # Note that ACME v1 is now deprecated, and Synapse currently doesn't support
+ # ACME v2. This means that this feature currently won't work with installs set
+ # up after November 2019. For more info, and alternative solutions, see
+ # https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
+ #
# Note that provisioning a certificate in this way requires port 80 to be
# routed to Synapse so that it can complete the http-01 ACME challenge.
# By default, if you enable ACME support, Synapse will attempt to listen on
@@ -297,10 +400,10 @@ class TlsConfig(Config):
# permission to listen on port 80.
#
acme:
- # ACME support is disabled by default. Uncomment the following line
- # (and tls_certificate_path and tls_private_key_path above) to enable it.
+ # ACME support is disabled by default. Set this to `true` and uncomment
+ # tls_certificate_path and tls_private_key_path above to enable it.
#
- #enabled: true
+ enabled: %(acme_enabled)s
# Endpoint to use to request certificates. If you only want to test,
# use Let's Encrypt's staging url:
@@ -311,17 +414,17 @@ class TlsConfig(Config):
# Port number to listen on for the HTTP-01 challenge. Change this if
# you are forwarding connections through Apache/Nginx/etc.
#
- #port: 80
+ port: 80
# Local addresses to listen on for incoming connections.
# Again, you may want to change this if you are forwarding connections
# through Apache/Nginx/etc.
#
- #bind_addresses: ['::', '0.0.0.0']
+ bind_addresses: ['::', '0.0.0.0']
# How many days remaining on a certificate before it is renewed.
#
- #reprovision_threshold: 30
+ reprovision_threshold: 30
# The domain that the certificate should be for. Normally this
# should be the same as your Matrix domain (i.e., 'server_name'), but,
@@ -335,7 +438,14 @@ class TlsConfig(Config):
#
# If not set, defaults to your 'server_name'.
#
- #domain: matrix.example.com
+ domain: %(acme_domain)s
+
+ # file to use for the account key. This will be generated if it doesn't
+ # exist.
+ #
+ # If unspecified, we will use CONFDIR/client.key.
+ #
+ account_key_file: %(default_acme_account_file)s
# List of allowed TLS fingerprints for this server to publish along
# with the signing keys for this server. Other matrix servers that
@@ -365,7 +475,11 @@ class TlsConfig(Config):
#tls_fingerprints: [{"sha256": "<base64_encoded_sha256_fingerprint>"}]
"""
- % locals()
+ # Lowercase the string representation of boolean values
+ % {
+ x[0]: str(x[1]).lower() if isinstance(x[1], bool) else x[1]
+ for x in locals().items()
+ }
)
def read_tls_certificate(self):
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
new file mode 100644
index 0000000000..8be1346113
--- /dev/null
+++ b/synapse/config/tracer.py
@@ -0,0 +1,90 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.d
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.python_dependencies import DependencyException, check_requirements
+
+from ._base import Config, ConfigError
+
+
+class TracerConfig(Config):
+ section = "tracing"
+
+ def read_config(self, config, **kwargs):
+ opentracing_config = config.get("opentracing")
+ if opentracing_config is None:
+ opentracing_config = {}
+
+ self.opentracer_enabled = opentracing_config.get("enabled", False)
+
+ self.jaeger_config = opentracing_config.get(
+ "jaeger_config",
+ {"sampler": {"type": "const", "param": 1}, "logging": False},
+ )
+
+ if not self.opentracer_enabled:
+ return
+
+ try:
+ check_requirements("opentracing")
+ except DependencyException as e:
+ raise ConfigError(e.message)
+
+ # The tracer is enabled so sanitize the config
+
+ self.opentracer_whitelist = opentracing_config.get("homeserver_whitelist", [])
+ if not isinstance(self.opentracer_whitelist, list):
+ raise ConfigError("Tracer homeserver_whitelist config is malformed")
+
+ def generate_config_section(cls, **kwargs):
+ return """\
+ ## Opentracing ##
+
+ # These settings enable opentracing, which implements distributed tracing.
+ # This allows you to observe the causal chains of events across servers
+ # including requests, key lookups etc., across any server running
+ # synapse or any other other services which supports opentracing
+ # (specifically those implemented with Jaeger).
+ #
+ opentracing:
+ # tracing is disabled by default. Uncomment the following line to enable it.
+ #
+ #enabled: true
+
+ # The list of homeservers we wish to send and receive span contexts and span baggage.
+ # See docs/opentracing.rst
+ # This is a list of regexes which are matched against the server_name of the
+ # homeserver.
+ #
+ # By defult, it is empty, so no servers are matched.
+ #
+ #homeserver_whitelist:
+ # - ".*"
+
+ # Jaeger can be configured to sample traces at different rates.
+ # All configuration options provided by Jaeger can be set here.
+ # Jaeger's configuration mostly related to trace sampling which
+ # is documented here:
+ # https://www.jaegertracing.io/docs/1.13/sampling/.
+ #
+ #jaeger_config:
+ # sampler:
+ # type: const
+ # param: 1
+
+ # Logging whether spans were started and reported
+ #
+ # logging:
+ # false
+ """
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 4376a23636..43b6c40456 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -21,23 +21,25 @@ class UserDirectoryConfig(Config):
Configuration for the behaviour of the /user_directory API
"""
- def read_config(self, config):
+ section = "userdirectory"
+
+ def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
self.user_directory_defer_to_id_server = None
user_directory_config = config.get("user_directory", None)
if user_directory_config:
- self.user_directory_search_enabled = (
- user_directory_config.get("enabled", True)
+ self.user_directory_search_enabled = user_directory_config.get(
+ "enabled", True
)
- self.user_directory_search_all_users = (
- user_directory_config.get("search_all_users", False)
+ self.user_directory_search_all_users = user_directory_config.get(
+ "search_all_users", False
)
- self.user_directory_defer_to_id_server = (
- user_directory_config.get("defer_to_id_server", None)
+ self.user_directory_defer_to_id_server = user_directory_config.get(
+ "defer_to_id_server", None
)
- def default_config(self, config_dir_path, server_name, **kwargs):
+ def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
# User Directory configuration
#
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index 2a1f005a37..b313bff140 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -16,18 +16,19 @@ from ._base import Config
class VoipConfig(Config):
+ section = "voip"
- def read_config(self, config):
+ def read_config(self, config, **kwargs):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret")
self.turn_username = config.get("turn_username")
self.turn_password = config.get("turn_password")
self.turn_user_lifetime = self.parse_duration(
- config.get("turn_user_lifetime", "1h"),
+ config.get("turn_user_lifetime", "1h")
)
self.turn_allow_guests = config.get("turn_allow_guests", True)
- def default_config(self, **kwargs):
+ def generate_config_section(self, **kwargs):
return """\
## TURN ##
@@ -55,5 +56,5 @@ class VoipConfig(Config):
# connect to arbitrary endpoints without having first signed up for a
# valid account (e.g. by passing a CAPTCHA).
#
- #turn_allow_guests: True
+ #turn_allow_guests: true
"""
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index bfbd8b6c91..fef72ed974 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2016 matrix.org
+# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,7 +21,9 @@ class WorkerConfig(Config):
They have their own pid_file and listener configuration. They use the
replication_url to talk to the main synapse process."""
- def read_config(self, config):
+ section = "worker"
+
+ def read_config(self, config, **kwargs):
self.worker_app = config.get("worker_app")
# Canonicalise worker_app so that master always has None
@@ -31,7 +33,6 @@ class WorkerConfig(Config):
self.worker_listeners = config.get("worker_listeners", [])
self.worker_daemonize = config.get("worker_daemonize")
self.worker_pid_file = config.get("worker_pid_file")
- self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config")
# The host used to connect to the main synapse
@@ -46,18 +47,19 @@ class WorkerConfig(Config):
self.worker_name = config.get("worker_name", self.worker_app)
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
- self.worker_cpu_affinity = config.get("worker_cpu_affinity")
# This option is really only here to support `--manhole` command line
# argument.
manhole = config.get("worker_manhole")
if manhole:
- self.worker_listeners.append({
- "port": manhole,
- "bind_addresses": ["127.0.0.1"],
- "type": "manhole",
- "tls": False,
- })
+ self.worker_listeners.append(
+ {
+ "port": manhole,
+ "bind_addresses": ["127.0.0.1"],
+ "type": "manhole",
+ "tls": False,
+ }
+ )
if self.worker_listeners:
for listener in self.worker_listeners:
@@ -67,7 +69,7 @@ class WorkerConfig(Config):
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
- bind_addresses.append('')
+ bind_addresses.append("")
def read_arguments(self, args):
# We support a bunch of command line arguments that override options in
@@ -77,9 +79,5 @@ class WorkerConfig(Config):
if args.daemonize is not None:
self.worker_daemonize = args.daemonize
- if args.log_config is not None:
- self.worker_log_config = args.log_config
- if args.log_file is not None:
- self.worker_log_file = args.log_file
if args.manhole is not None:
self.worker_manhole = args.worker_manhole
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index fec197a0d8..a5a2a7815d 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -15,7 +15,6 @@
import logging
-import idna
from service_identity import VerificationError
from service_identity.pyopenssl import verify_hostname, verify_ip_address
from zope.interface import implementer
@@ -24,13 +23,26 @@ from OpenSSL import SSL, crypto
from twisted.internet._sslverify import _defaultCurveName
from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
-from twisted.internet.ssl import CertificateOptions, ContextFactory, platformTrust
+from twisted.internet.ssl import (
+ CertificateOptions,
+ ContextFactory,
+ TLSVersion,
+ platformTrust,
+)
from twisted.python.failure import Failure
from twisted.web.iweb import IPolicyForHTTPS
logger = logging.getLogger(__name__)
+_TLS_VERSION_MAP = {
+ "1": TLSVersion.TLSv1_0,
+ "1.1": TLSVersion.TLSv1_1,
+ "1.2": TLSVersion.TLSv1_2,
+ "1.3": TLSVersion.TLSv1_3,
+}
+
+
class ServerContextFactory(ContextFactory):
"""Factory for PyOpenSSL SSL contexts that are used to handle incoming
connections."""
@@ -44,16 +56,18 @@ class ServerContextFactory(ContextFactory):
try:
_ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
context.set_tmp_ecdh(_ecCurve)
-
except Exception:
logger.exception("Failed to enable elliptic curve for TLS")
- context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
+
+ context.set_options(
+ SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1
+ )
context.use_certificate_chain_file(config.tls_certificate_file)
context.use_privatekey(config.tls_private_key)
# https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/
context.set_cipher_list(
- "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1"
+ "ECDH+AESGCM:ECDH+CHACHA20:ECDH+AES256:ECDH+AES128:!aNULL:!SHA1:!AESCCM"
)
def getContext(self):
@@ -81,20 +95,38 @@ class FederationPolicyForHTTPS(object):
# Use CA root certs provided by OpenSSL
trust_root = platformTrust()
- self._verify_ssl_context = CertificateOptions(trustRoot=trust_root).getContext()
+ # "insecurelyLowerMinimumTo" is the argument that will go lower than
+ # Twisted's default, which is why it is marked as "insecure" (since
+ # Twisted's defaults are reasonably secure). But, since Twisted is
+ # moving to TLS 1.2 by default, we want to respect the config option if
+ # it is set to 1.0 (which the alternate option, raiseMinimumTo, will not
+ # let us do).
+ minTLS = _TLS_VERSION_MAP[config.federation_client_minimum_tls_version]
+
+ _verify_ssl = CertificateOptions(
+ trustRoot=trust_root, insecurelyLowerMinimumTo=minTLS
+ )
+ self._verify_ssl_context = _verify_ssl.getContext()
self._verify_ssl_context.set_info_callback(_context_info_cb)
- self._no_verify_ssl_context = CertificateOptions().getContext()
+ _no_verify_ssl = CertificateOptions(insecurelyLowerMinimumTo=minTLS)
+ self._no_verify_ssl_context = _no_verify_ssl.getContext()
self._no_verify_ssl_context.set_info_callback(_context_info_cb)
- def get_options(self, host):
+ def get_options(self, host: bytes):
+
+ # IPolicyForHTTPS.get_options takes bytes, but we want to compare
+ # against the str whitelist. The hostnames in the whitelist are already
+ # IDNA-encoded like the hosts will be here.
+ ascii_host = host.decode("ascii")
+
# Check if certificate verification has been enabled
should_verify = self._config.federation_verify_certificates
# Check if we've disabled certificate verification for this host
if should_verify:
for regex in self._config.federation_certificate_verification_whitelist:
- if regex.match(host):
+ if regex.match(ascii_host):
should_verify = False
break
@@ -155,7 +187,7 @@ class SSLClientConnectionCreator(object):
Replaces twisted.internet.ssl.ClientTLSOptions
"""
- def __init__(self, hostname, ctx, verify_certs):
+ def __init__(self, hostname: bytes, ctx, verify_certs: bool):
self._ctx = ctx
self._verifier = ConnectionVerifier(hostname, verify_certs)
@@ -183,21 +215,16 @@ class ConnectionVerifier(object):
# This code is based on twisted.internet.ssl.ClientTLSOptions.
- def __init__(self, hostname, verify_certs):
+ def __init__(self, hostname: bytes, verify_certs):
self._verify_certs = verify_certs
- if isIPAddress(hostname) or isIPv6Address(hostname):
- self._hostnameBytes = hostname.encode("ascii")
+ _decoded = hostname.decode("ascii")
+ if isIPAddress(_decoded) or isIPv6Address(_decoded):
self._is_ip_address = True
else:
- # twisted's ClientTLSOptions falls back to the stdlib impl here if
- # idna is not installed, but points out that lacks support for
- # IDNA2008 (http://bugs.python.org/issue17305).
- #
- # We can rely on having idna.
- self._hostnameBytes = idna.encode(hostname)
self._is_ip_address = False
+ self._hostnameBytes = hostname
self._hostnameASCII = self._hostnameBytes.decode("ascii")
def verify_context_info_cb(self, ssl_connection, where):
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 99a586655b..0422c43fab 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
-
+#
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,16 +15,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import collections.abc
import hashlib
import logging
+from typing import Dict
from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
+from signedjson.types import SigningKey
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.errors import Codes, SynapseError
+from synapse.api.room_versions import RoomVersion
from synapse.events.utils import prune_event, prune_event_dict
+from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -40,15 +45,16 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
# some malformed events lack a 'hashes'. Protect against it being missing
# or a weird type by basically treating it the same as an unhashed event.
hashes = event.get("hashes")
- if not isinstance(hashes, dict):
- raise SynapseError(400, "Malformed 'hashes'", Codes.UNAUTHORIZED)
+ # nb it might be a frozendict or a dict
+ if not isinstance(hashes, collections.abc.Mapping):
+ raise SynapseError(
+ 400, "Malformed 'hashes': %s" % (type(hashes),), Codes.UNAUTHORIZED
+ )
if name not in hashes:
raise SynapseError(
400,
- "Algorithm %s not in hashes %s" % (
- name, list(hashes),
- ),
+ "Algorithm %s not in hashes %s" % (name, list(hashes)),
Codes.UNAUTHORIZED,
)
message_hash_base64 = hashes[name]
@@ -56,9 +62,7 @@ def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
message_hash_bytes = decode_base64(message_hash_base64)
except Exception:
raise SynapseError(
- 400,
- "Invalid base64: %s" % (message_hash_base64,),
- Codes.UNAUTHORIZED,
+ 400, "Invalid base64: %s" % (message_hash_base64,), Codes.UNAUTHORIZED
)
return message_hash_bytes == expected_hash
@@ -87,7 +91,7 @@ def compute_content_hash(event_dict, hash_algorithm):
event_json_bytes = encode_canonical_json(event_dict)
hashed = hash_algorithm(event_json_bytes)
- return (hashed.name, hashed.digest())
+ return hashed.name, hashed.digest()
def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
@@ -110,50 +114,64 @@ def compute_event_reference_hash(event, hash_algorithm=hashlib.sha256):
event_dict.pop("unsigned", None)
event_json_bytes = encode_canonical_json(event_dict)
hashed = hash_algorithm(event_json_bytes)
- return (hashed.name, hashed.digest())
+ return hashed.name, hashed.digest()
-def compute_event_signature(event_dict, signature_name, signing_key):
+def compute_event_signature(
+ room_version: RoomVersion,
+ event_dict: JsonDict,
+ signature_name: str,
+ signing_key: SigningKey,
+) -> Dict[str, Dict[str, str]]:
"""Compute the signature of the event for the given name and key.
Args:
- event_dict (dict): The event as a dict
- signature_name (str): The name of the entity signing the event
+ room_version: the version of the room that this event is in.
+ (the room version determines the redaction algorithm and hence the
+ json to be signed)
+
+ event_dict: The event as a dict
+
+ signature_name: The name of the entity signing the event
(typically the server's hostname).
- signing_key (syutil.crypto.SigningKey): The key to sign with
+
+ signing_key: The key to sign with
Returns:
- dict[str, dict[str, str]]: Returns a dictionary in the same format of
- an event's signatures field.
+ a dictionary in the same format of an event's signatures field.
"""
- redact_json = prune_event_dict(event_dict)
+ redact_json = prune_event_dict(room_version, event_dict)
redact_json.pop("age_ts", None)
redact_json.pop("unsigned", None)
- logger.debug("Signing event: %s", encode_canonical_json(redact_json))
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("Signing event: %s", encode_canonical_json(redact_json))
redact_json = sign_json(redact_json, signature_name, signing_key)
- logger.debug("Signed event: %s", encode_canonical_json(redact_json))
+ if logger.isEnabledFor(logging.DEBUG):
+ logger.debug("Signed event: %s", encode_canonical_json(redact_json))
return redact_json["signatures"]
-def add_hashes_and_signatures(event_dict, signature_name, signing_key,
- hash_algorithm=hashlib.sha256):
+def add_hashes_and_signatures(
+ room_version: RoomVersion,
+ event_dict: JsonDict,
+ signature_name: str,
+ signing_key: SigningKey,
+):
"""Add content hash and sign the event
Args:
- event_dict (dict): The event to add hashes to and sign
- signature_name (str): The name of the entity signing the event
+ room_version: the version of the room this event is in
+
+ event_dict: The event to add hashes to and sign
+ signature_name: The name of the entity signing the event
(typically the server's hostname).
- signing_key (syutil.crypto.SigningKey): The key to sign with
- hash_algorithm: A hasher from `hashlib`, e.g. hashlib.sha256, to use
- to hash the event
+ signing_key: The key to sign with
"""
- name, digest = compute_content_hash(event_dict, hash_algorithm=hash_algorithm)
+ name, digest = compute_content_hash(event_dict, hash_algorithm=hashlib.sha256)
event_dict.setdefault("hashes", {})[name] = encode_base64(digest)
event_dict["signatures"] = compute_event_signature(
- event_dict,
- signature_name=signature_name,
- signing_key=signing_key,
+ room_version, event_dict, signature_name=signature_name, signing_key=signing_key
)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 6f603f1961..983f0ead8c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -18,7 +18,6 @@ import logging
from collections import defaultdict
import six
-from six import raise_from
from six.moves import urllib
import attr
@@ -30,7 +29,6 @@ from signedjson.key import (
from signedjson.sign import (
SignatureVerifyException,
encode_canonical_json,
- sign_json,
signature_ids,
verify_signed_json,
)
@@ -44,15 +42,16 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.storage.keys import FetchKeyResult
-from synapse.util import logcontext, unwrapFirstError
-from synapse.util.async_helpers import yieldable_gather_results
-from synapse.util.logcontext import (
+from synapse.logging.context import (
LoggingContext,
PreserveLoggingContext,
+ make_deferred_yieldable,
preserve_fn,
run_in_background,
)
+from synapse.storage.keys import FetchKeyResult
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure
from synapse.util.retryutils import NotRetryingDestination
@@ -140,7 +139,7 @@ class Keyring(object):
"""
req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
requests = (req,)
- return logcontext.make_deferred_yieldable(self._verify_objects(requests)[0])
+ return make_deferred_yieldable(self._verify_objects(requests)[0])
def verify_json_objects_for_server(self, server_and_json):
"""Bulk verifies signatures of json objects, bulk fetching keys as
@@ -237,27 +236,9 @@ class Keyring(object):
"""
try:
- # create a deferred for each server we're going to look up the keys
- # for; we'll resolve them once we have completed our lookups.
- # These will be passed into wait_for_previous_lookups to block
- # any other lookups until we have finished.
- # The deferreds are called with no logcontext.
- server_to_deferred = {
- rq.server_name: defer.Deferred() for rq in verify_requests
- }
-
- # We want to wait for any previous lookups to complete before
- # proceeding.
- yield self.wait_for_previous_lookups(server_to_deferred)
-
- # Actually start fetching keys.
- self._get_server_verify_keys(verify_requests)
+ ctx = LoggingContext.current_context()
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- #
- # map from server name to a set of request ids
+ # map from server name to a set of outstanding request ids
server_to_request_ids = {}
for verify_request in verify_requests:
@@ -265,40 +246,61 @@ class Keyring(object):
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
- def remove_deferreds(res, verify_request):
+ # Wait for any previous lookups to complete before proceeding.
+ yield self.wait_for_previous_lookups(server_to_request_ids.keys())
+
+ # take out a lock on each of the servers by sticking a Deferred in
+ # key_downloads
+ for server_name in server_to_request_ids.keys():
+ self.key_downloads[server_name] = defer.Deferred()
+ logger.debug("Got key lookup lock on %s", server_name)
+
+ # When we've finished fetching all the keys for a given server_name,
+ # drop the lock by resolving the deferred in key_downloads.
+ def drop_server_lock(server_name):
+ d = self.key_downloads.pop(server_name)
+ d.callback(None)
+
+ def lookup_done(res, verify_request):
server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids[server_name].discard(request_id)
- if not server_to_request_ids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
+ server_requests = server_to_request_ids[server_name]
+ server_requests.remove(id(verify_request))
+
+ # if there are no more requests for this server, we can drop the lock.
+ if not server_requests:
+ with PreserveLoggingContext(ctx):
+ logger.debug("Releasing key lookup lock on %s", server_name)
+
+ # ... but not immediately, as that can cause stack explosions if
+ # we get a long queue of lookups.
+ self.clock.call_later(0, drop_server_lock, server_name)
+
return res
for verify_request in verify_requests:
- verify_request.key_ready.addBoth(remove_deferreds, verify_request)
+ verify_request.key_ready.addBoth(lookup_done, verify_request)
+
+ # Actually start fetching keys.
+ self._get_server_verify_keys(verify_requests)
except Exception:
logger.exception("Error starting key lookups")
@defer.inlineCallbacks
- def wait_for_previous_lookups(self, server_to_deferred):
+ def wait_for_previous_lookups(self, server_names):
"""Waits for any previous key lookups for the given servers to finish.
Args:
- server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
- resolved once we've finished looking up keys for that server.
- The Deferreds should be regular twisted ones which call their
- callbacks with no logcontext.
-
- Returns: a Deferred which resolves once all key lookups for the given
- servers have completed. Follows the synapse rules of logcontext
- preservation.
+ server_names (Iterable[str]): list of servers which we want to look up
+
+ Returns:
+ Deferred[None]: resolves once all key lookups for the given servers have
+ completed. Follows the synapse rules of logcontext preservation.
"""
loop_count = 1
while True:
wait_on = [
(server_name, self.key_downloads[server_name])
- for server_name in server_to_deferred.keys()
+ for server_name in server_names
if server_name in self.key_downloads
]
if not wait_on:
@@ -313,19 +315,6 @@ class Keyring(object):
loop_count += 1
- ctx = LoggingContext.current_context()
-
- def rm(r, server_name_):
- with PreserveLoggingContext(ctx):
- logger.debug("Releasing key lookup lock on %s", server_name_)
- self.key_downloads.pop(server_name_, None)
- return r
-
- for server_name, deferred in server_to_deferred.items():
- logger.debug("Got key lookup lock on %s", server_name)
- self.key_downloads[server_name] = deferred
- deferred.addBoth(rm, server_name)
-
def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
@@ -337,9 +326,7 @@ class Keyring(object):
verify_requests (list[VerifyJsonRequest]): list of verify requests
"""
- remaining_requests = set(
- (rq for rq in verify_requests if not rq.key_ready.called)
- )
+ remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@defer.inlineCallbacks
def do_iterations():
@@ -407,7 +394,7 @@ class Keyring(object):
results = yield fetcher.get_keys(missing_keys)
- completed = list()
+ completed = []
for verify_request in remaining_requests:
server_name = verify_request.server_name
@@ -471,7 +458,7 @@ class StoreKeyFetcher(KeyFetcher):
keys = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
- defer.returnValue(keys)
+ return keys
class BaseV2KeyFetcher(object):
@@ -505,7 +492,7 @@ class BaseV2KeyFetcher(object):
Returns:
Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
"""
- ts_valid_until_ms = response_json[u"valid_until_ts"]
+ ts_valid_until_ms = response_json["valid_until_ts"]
# start by extracting the keys from the response, since they may be required
# to validate the signature on the response.
@@ -522,17 +509,18 @@ class BaseV2KeyFetcher(object):
server_name = response_json["server_name"]
verified = False
for key_id in response_json["signatures"].get(server_name, {}):
- # each of the keys used for the signature must be present in the response
- # json.
key = verify_keys.get(key_id)
if not key:
- raise KeyLookupError(
- "Key response is signed by key id %s:%s but that key is not "
- "present in the response" % (server_name, key_id)
- )
+ # the key may not be present in verify_keys if:
+ # * we got the key from the notary server, and:
+ # * the key belongs to the notary server, and:
+ # * the notary server is using a different key to sign notary
+ # responses.
+ continue
verify_signed_json(response_json, server_name, key.verify_key)
verified = True
+ break
if not verified:
raise KeyLookupError(
@@ -549,15 +537,9 @@ class BaseV2KeyFetcher(object):
verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
)
- # re-sign the json with our own key, so that it is ready if we are asked to
- # give it out as a notary server
- signed_key_json = sign_json(
- response_json, self.config.server_name, self.config.signing_key[0]
- )
-
- signed_key_json_bytes = encode_canonical_json(signed_key_json)
+ key_json_bytes = encode_canonical_json(response_json)
- yield logcontext.make_deferred_yieldable(
+ yield make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -567,7 +549,7 @@ class BaseV2KeyFetcher(object):
from_server=from_server,
ts_now_ms=time_added_ms,
ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
+ key_json_bytes=key_json_bytes,
)
for key_id in verify_keys
],
@@ -575,7 +557,7 @@ class BaseV2KeyFetcher(object):
).addErrback(unwrapFirstError)
)
- defer.returnValue(verify_keys)
+ return verify_keys
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@@ -597,7 +579,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
result = yield self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
- defer.returnValue(result)
+ return result
except KeyLookupError as e:
logger.warning(
"Key lookup failed from %r: %s", key_server.server_name, e
@@ -610,14 +592,11 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
str(e),
)
- defer.returnValue({})
+ return {}
- results = yield logcontext.make_deferred_yieldable(
+ results = yield make_deferred_yieldable(
defer.gatherResults(
- [
- run_in_background(get_key, server)
- for server in self.key_servers
- ],
+ [run_in_background(get_key, server) for server in self.key_servers],
consumeErrors=True,
).addErrback(unwrapFirstError)
)
@@ -627,12 +606,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
- defer.returnValue(union_of_keys)
+ return union_of_keys
@defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(
- self, keys_to_fetch, key_server
- ):
+ def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
"""
Args:
keys_to_fetch (dict[str, dict[str, int]]):
@@ -661,9 +638,9 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
destination=perspective_name,
path="/_matrix/key/v2/query",
data={
- u"server_keys": {
+ "server_keys": {
server_name: {
- key_id: {u"minimum_valid_until_ts": min_valid_ts}
+ key_id: {"minimum_valid_until_ts": min_valid_ts}
for key_id, min_valid_ts in server_keys.items()
}
for server_name, server_keys in keys_to_fetch.items()
@@ -671,9 +648,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
},
)
except (NotRetryingDestination, RequestSendFailed) as e:
- raise_from(KeyLookupError("Failed to connect to remote server"), e)
+ # these both have str() representations which we can't really improve upon
+ raise KeyLookupError(str(e))
except HttpResponseException as e:
- raise_from(KeyLookupError("Remote server returned an error"), e)
+ raise KeyLookupError("Remote server returned an error: %s" % (e,))
keys = {}
added_keys = []
@@ -690,10 +668,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
try:
- self._validate_perspectives_response(
- key_server,
- response,
- )
+ self._validate_perspectives_response(key_server, response)
processed_response = yield self.process_v2_response(
perspective_name, response, time_added_ms=time_now_ms
@@ -718,11 +693,9 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
perspective_name, time_now_ms, added_keys
)
- defer.returnValue(keys)
+ return keys
- def _validate_perspectives_response(
- self, key_server, response,
- ):
+ def _validate_perspectives_response(self, key_server, response):
"""Optionally check the signature on the result of a /key/query request
Args:
@@ -739,13 +712,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return
if (
- u"signatures" not in response
- or perspective_name not in response[u"signatures"]
+ "signatures" not in response
+ or perspective_name not in response["signatures"]
):
raise KeyLookupError("Response not signed by the notary server")
verified = False
- for key_id in response[u"signatures"][perspective_name]:
+ for key_id in response["signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(response, perspective_name, perspective_keys[key_id])
verified = True
@@ -754,7 +727,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
raise KeyLookupError(
"Response not signed with a known key: signed with: %r, known keys: %r"
% (
- list(response[u"signatures"][perspective_name].keys()),
+ list(response["signatures"][perspective_name].keys()),
list(perspective_keys.keys()),
)
)
@@ -826,7 +799,6 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
ignore_backoff=True,
-
# we only give the remote server 10s to respond. It should be an
# easy request to handle, so if it doesn't reply within 10s, it's
# probably not going to.
@@ -841,9 +813,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
timeout=10000,
)
except (NotRetryingDestination, RequestSendFailed) as e:
- raise_from(KeyLookupError("Failed to connect to remote server"), e)
+ # these both have str() representations which we can't really improve
+ # upon
+ raise KeyLookupError(str(e))
except HttpResponseException as e:
- raise_from(KeyLookupError("Remote server returned an error"), e)
+ raise KeyLookupError("Remote server returned an error: %s" % (e,))
if response["server_name"] != server_name:
raise KeyLookupError(
@@ -863,7 +837,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
)
keys.update(response_keys)
- defer.returnValue(keys)
+ return keys
@defer.inlineCallbacks
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 203490fc36..46beb5334f 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
# limitations under the License.
import logging
+from typing import Set, Tuple
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -22,17 +24,27 @@ from unpaddedbase64 import decode_base64
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, EventSizeError, SynapseError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
+from synapse.api.room_versions import (
+ KNOWN_ROOM_VERSIONS,
+ EventFormatVersions,
+ RoomVersion,
+)
from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__)
-def check(room_version, event, auth_events, do_sig_check=True, do_size_check=True):
+def check(
+ room_version_obj: RoomVersion,
+ event,
+ auth_events,
+ do_sig_check=True,
+ do_size_check=True,
+):
""" Checks if this event is correctly authed.
Args:
- room_version (str): the version of the room
+ room_version_obj: the version of the room
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
@@ -42,12 +54,26 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
Returns:
if the auth checks pass.
"""
+ assert isinstance(auth_events, dict)
+
if do_size_check:
_check_size_limits(event)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
+ room_id = event.room_id
+
+ # I'm not really expecting to get auth events in the wrong room, but let's
+ # sanity-check it
+ for auth_event in auth_events.values():
+ if auth_event.room_id != room_id:
+ raise Exception(
+ "During auth for event %s in room %s, found event %s in the state "
+ "which is in room %s"
+ % (event.event_id, room_id, auth_event.event_id, auth_event.room_id)
+ )
+
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
@@ -74,75 +100,63 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
- if auth_events is None:
- # Oh, we don't know what the state of the room was, so we
- # are trusting that this is allowed (at least for now)
- logger.warn("Trusting event: %s", event.event_id)
- return
-
+ # Implementation of https://matrix.org/docs/spec/rooms/v1#authorization-rules
+ #
+ # 1. If type is m.room.create:
if event.type == EventTypes.Create:
+ # 1b. If the domain of the room_id does not match the domain of the sender,
+ # reject.
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
- 403,
- "Creation event's room_id domain does not match sender's"
+ 403, "Creation event's room_id domain does not match sender's"
)
- room_version = event.content.get("room_version", "1")
- if room_version not in KNOWN_ROOM_VERSIONS:
+ # 1c. If content.room_version is present and is not a recognised version, reject
+ room_version_prop = event.content.get("room_version", "1")
+ if room_version_prop not in KNOWN_ROOM_VERSIONS:
raise AuthError(
403,
- "room appears to have unsupported version %s" % (
- room_version,
- ))
- # FIXME
+ "room appears to have unsupported version %s" % (room_version_prop,),
+ )
+
logger.debug("Allowing! %s", event)
return
+ # 3. If event does not have a m.room.create in its auth_events, reject.
creation_event = auth_events.get((EventTypes.Create, ""), None)
-
if not creation_event:
- raise AuthError(
- 403,
- "No create event in auth events",
- )
+ raise AuthError(403, "No create event in auth events")
+ # additional check for m.federate
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not _can_federate(event, auth_events):
- raise AuthError(
- 403,
- "This room has been marked as unfederatable."
- )
+ raise AuthError(403, "This room has been marked as unfederatable.")
- # FIXME: Temp hack
- if event.type == EventTypes.Aliases:
+ # 4. If type is m.room.aliases
+ if event.type == EventTypes.Aliases and room_version_obj.special_case_aliases_auth:
+ # 4a. If event has no state_key, reject
if not event.is_state():
- raise AuthError(
- 403,
- "Alias event must be a state event",
- )
+ raise AuthError(403, "Alias event must be a state event")
if not event.state_key:
- raise AuthError(
- 403,
- "Alias event must have non-empty state_key"
- )
+ raise AuthError(403, "Alias event must have non-empty state_key")
+
+ # 4b. If sender's domain doesn't matches [sic] state_key, reject
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
- 403,
- "Alias event's state_key does not match sender's domain"
+ 403, "Alias event's state_key does not match sender's domain"
)
+
+ # 4c. Otherwise, allow.
logger.debug("Allowing! %s", event)
return
if logger.isEnabledFor(logging.DEBUG):
- logger.debug(
- "Auth events: %s",
- [a.event_id for a in auth_events.values()]
- )
+ logger.debug("Auth events: %s", [a.event_id for a in auth_events.values()])
if event.type == EventTypes.Member:
_is_membership_change_allowed(event, auth_events)
@@ -159,9 +173,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
- raise AuthError(
- 403, "You don't have permission to invite users",
- )
+ raise AuthError(403, "You don't have permission to invite users")
else:
logger.debug("Allowing! %s", event)
return
@@ -172,7 +184,7 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
_check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
- check_redaction(room_version, event, auth_events)
+ check_redaction(room_version_obj, event, auth_events)
logger.debug("Allowing! %s", event)
@@ -207,7 +219,7 @@ def _is_membership_change_allowed(event, auth_events):
# Check if this is the room creator joining:
if len(event.prev_event_ids()) == 1 and Membership.JOIN == membership:
# Get room creation event:
- key = (EventTypes.Create, "", )
+ key = (EventTypes.Create, "")
create = auth_events.get(key)
if create and event.prev_event_ids()[0] == create.event_id:
if create.content["creator"] == event.state_key:
@@ -219,38 +231,31 @@ def _is_membership_change_allowed(event, auth_events):
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not _can_federate(event, auth_events):
- raise AuthError(
- 403,
- "This room has been marked as unfederatable."
- )
+ raise AuthError(403, "This room has been marked as unfederatable.")
# get info about the caller
- key = (EventTypes.Member, event.user_id, )
+ key = (EventTypes.Member, event.user_id)
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
- key = (EventTypes.Member, target_user_id, )
+ key = (EventTypes.Member, target_user_id)
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
- key = (EventTypes.JoinRules, "", )
+ key = (EventTypes.JoinRules, "")
join_rule_event = auth_events.get(key)
if join_rule_event:
- join_rule = join_rule_event.content.get(
- "join_rule", JoinRules.INVITE
- )
+ join_rule = join_rule_event.content.get("join_rule", JoinRules.INVITE)
else:
join_rule = JoinRules.INVITE
user_level = get_user_power_level(event.user_id, auth_events)
- target_level = get_user_power_level(
- target_user_id, auth_events
- )
+ target_level = get_user_power_level(target_user_id, auth_events)
# FIXME (erikj): What should we do here as the default?
ban_level = _get_named_level(auth_events, "ban", 50)
@@ -266,29 +271,26 @@ def _is_membership_change_allowed(event, auth_events):
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
- }
+ },
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not _verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
- raise AuthError(
- 403, "%s is banned from the room" % (target_user_id,)
- )
+ raise AuthError(403, "%s is banned from the room" % (target_user_id,))
return
if Membership.JOIN != membership:
- if (caller_invited
- and Membership.LEAVE == membership
- and target_user_id == event.user_id):
+ if (
+ caller_invited
+ and Membership.LEAVE == membership
+ and target_user_id == event.user_id
+ ):
return
if not caller_in_room: # caller isn't joined
- raise AuthError(
- 403,
- "%s not in room %s." % (event.user_id, event.room_id,)
- )
+ raise AuthError(403, "%s not in room %s." % (event.user_id, event.room_id))
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
@@ -296,19 +298,14 @@ def _is_membership_change_allowed(event, auth_events):
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
- raise AuthError(
- 403, "%s is banned from the room" % (target_user_id,)
- )
+ raise AuthError(403, "%s is banned from the room" % (target_user_id,))
elif target_in_room: # the target is already in the room.
- raise AuthError(403, "%s is already in the room." %
- target_user_id)
+ raise AuthError(403, "%s is already in the room." % target_user_id)
else:
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
- raise AuthError(
- 403, "You don't have permission to invite users",
- )
+ raise AuthError(403, "You don't have permission to invite users")
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
@@ -329,16 +326,12 @@ def _is_membership_change_allowed(event, auth_events):
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
- raise AuthError(
- 403, "You cannot unban user %s." % (target_user_id,)
- )
+ raise AuthError(403, "You cannot unban user %s." % (target_user_id,))
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
- raise AuthError(
- 403, "You cannot kick user %s." % target_user_id
- )
+ raise AuthError(403, "You cannot kick user %s." % target_user_id)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
@@ -347,21 +340,17 @@ def _is_membership_change_allowed(event, auth_events):
def _check_event_sender_in_room(event, auth_events):
- key = (EventTypes.Member, event.user_id, )
+ key = (EventTypes.Member, event.user_id)
member_event = auth_events.get(key)
- return _check_joined_room(
- member_event,
- event.user_id,
- event.room_id
- )
+ return _check_joined_room(member_event, event.user_id, event.room_id)
def _check_joined_room(member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
- raise AuthError(403, "User %s not in room %s (%s)" % (
- user_id, room_id, repr(member)
- ))
+ raise AuthError(
+ 403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
+ )
def get_send_level(etype, state_key, power_levels_event):
@@ -402,31 +391,26 @@ def get_send_level(etype, state_key, power_levels_event):
def _can_send_event(event, auth_events):
power_levels_event = _get_power_level_event(auth_events)
- send_level = get_send_level(
- event.type, event.get("state_key"), power_levels_event,
- )
+ send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
- "You don't have permission to post that to the room. " +
- "user_level (%d) < send_level (%d)" % (user_level, send_level)
+ "You don't have permission to post that to the room. "
+ + "user_level (%d) < send_level (%d)" % (user_level, send_level),
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
- raise AuthError(
- 403,
- "You are not allowed to set others state"
- )
+ raise AuthError(403, "You are not allowed to set others state")
return True
-def check_redaction(room_version, event, auth_events):
+def check_redaction(room_version_obj: RoomVersion, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
@@ -446,11 +430,7 @@ def check_redaction(room_version, event, auth_events):
if user_level >= redact_level:
return False
- v = KNOWN_ROOM_VERSIONS.get(room_version)
- if not v:
- raise RuntimeError("Unrecognized room version %r" % (room_version,))
-
- if v.event_format == EventFormatVersions.V1:
+ if room_version_obj.event_format == EventFormatVersions.V1:
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
@@ -459,10 +439,7 @@ def check_redaction(room_version, event, auth_events):
event.internal_metadata.recheck_redaction = True
return True
- raise AuthError(
- 403,
- "You don't have permission to redact events"
- )
+ raise AuthError(403, "You don't have permission to redact events")
def _check_power_levels(event, auth_events):
@@ -479,7 +456,7 @@ def _check_power_levels(event, auth_events):
except Exception:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
- key = (event.type, event.state_key, )
+ key = (event.type, event.state_key)
current_state = auth_events.get(key)
if not current_state:
@@ -500,16 +477,12 @@ def _check_power_levels(event, auth_events):
old_list = current_state.content.get("users", {})
for user in set(list(old_list) + list(user_list)):
- levels_to_check.append(
- (user, "users")
- )
+ levels_to_check.append((user, "users"))
old_list = current_state.content.get("events", {})
new_list = event.content.get("events", {})
for ev_id in set(list(old_list) + list(new_list)):
- levels_to_check.append(
- (ev_id, "events")
- )
+ levels_to_check.append((ev_id, "events"))
old_state = current_state.content
new_state = event.content
@@ -540,7 +513,7 @@ def _check_power_levels(event, auth_events):
raise AuthError(
403,
"You don't have permission to remove ops level equal "
- "to your own"
+ "to your own",
)
# Check if the old and new levels are greater than the user level
@@ -549,9 +522,7 @@ def _check_power_levels(event, auth_events):
new_level_too_big = new_level is not None and new_level > user_level
if old_level_too_big or new_level_too_big:
raise AuthError(
- 403,
- "You don't have permission to add ops level greater "
- "than your own"
+ 403, "You don't have permission to add ops level greater than your own"
)
@@ -587,10 +558,9 @@ def get_user_power_level(user_id, auth_events):
# some things which call this don't pass the create event: hack around
# that.
- key = (EventTypes.Create, "", )
+ key = (EventTypes.Create, "")
create_event = auth_events.get(key)
- if (create_event is not None and
- create_event.content["creator"] == user_id):
+ if create_event is not None and create_event.content["creator"] == user_id:
return 100
else:
return 0
@@ -636,9 +606,7 @@ def _verify_third_party_invite(event, auth_events):
token = signed["token"]
- invite_event = auth_events.get(
- (EventTypes.ThirdPartyInvite, token,)
- )
+ invite_event = auth_events.get((EventTypes.ThirdPartyInvite, token))
if not invite_event:
return False
@@ -661,8 +629,7 @@ def _verify_third_party_invite(event, auth_events):
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
- key_name,
- decode_base64(public_key)
+ key_name, decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
@@ -671,7 +638,7 @@ def _verify_third_party_invite(event, auth_events):
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
- except (KeyError, SignatureVerifyException,):
+ except (KeyError, SignatureVerifyException):
continue
return False
@@ -679,9 +646,7 @@ def _verify_third_party_invite(event, auth_events):
def get_public_keys(invite_event):
public_keys = []
if "public_key" in invite_event.content:
- o = {
- "public_key": invite_event.content["public_key"],
- }
+ o = {"public_key": invite_event.content["public_key"]}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
@@ -689,7 +654,7 @@ def get_public_keys(invite_event):
return public_keys
-def auth_types_for_event(event):
+def auth_types_for_event(event) -> Set[Tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
@@ -698,27 +663,27 @@ def auth_types_for_event(event):
actually auth the event.
"""
if event.type == EventTypes.Create:
- return []
-
- auth_types = []
+ return set()
- auth_types.append((EventTypes.PowerLevels, "", ))
- auth_types.append((EventTypes.Member, event.sender, ))
- auth_types.append((EventTypes.Create, "", ))
+ auth_types = {
+ (EventTypes.PowerLevels, ""),
+ (EventTypes.Member, event.sender),
+ (EventTypes.Create, ""),
+ }
if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
- auth_types.append((EventTypes.JoinRules, "", ))
+ auth_types.add((EventTypes.JoinRules, ""))
- auth_types.append((EventTypes.Member, event.state_key, ))
+ auth_types.add((EventTypes.Member, event.state_key))
if membership == Membership.INVITE:
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
- event.content["third_party_invite"]["signed"]["token"]
+ event.content["third_party_invite"]["signed"]["token"],
)
- auth_types.append(key)
+ auth_types.add(key)
return auth_types
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 1edd19cc13..533ba327f5 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2019 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,15 +15,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import os
from distutils.util import strtobool
+from typing import Dict, Optional, Type
import six
from unpaddedbase64 import encode_base64
-from synapse.api.errors import UnsupportedRoomVersionError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
+from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
+from synapse.types import JsonDict
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
@@ -36,34 +39,115 @@ from synapse.util.frozenutils import freeze
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
-class _EventInternalMetadata(object):
- def __init__(self, internal_metadata_dict):
- self.__dict__ = dict(internal_metadata_dict)
+class DictProperty:
+ """An object property which delegates to the `_dict` within its parent object."""
+
+ __slots__ = ["key"]
+
+ def __init__(self, key: str):
+ self.key = key
+
+ def __get__(self, instance, owner=None):
+ # if the property is accessed as a class property rather than an instance
+ # property, return the property itself rather than the value
+ if instance is None:
+ return self
+ try:
+ return instance._dict[self.key]
+ except KeyError as e1:
+ # We want this to look like a regular attribute error (mostly so that
+ # hasattr() works correctly), so we convert the KeyError into an
+ # AttributeError.
+ #
+ # To exclude the KeyError from the traceback, we explicitly
+ # 'raise from e1.__context__' (which is better than 'raise from None',
+ # becuase that would omit any *earlier* exceptions).
+ #
+ raise AttributeError(
+ "'%s' has no '%s' property" % (type(instance), self.key)
+ ) from e1.__context__
+
+ def __set__(self, instance, v):
+ instance._dict[self.key] = v
+
+ def __delete__(self, instance):
+ try:
+ del instance._dict[self.key]
+ except KeyError as e1:
+ raise AttributeError(
+ "'%s' has no '%s' property" % (type(instance), self.key)
+ ) from e1.__context__
+
+
+class DefaultDictProperty(DictProperty):
+ """An extension of DictProperty which provides a default if the property is
+ not present in the parent's _dict.
+
+ Note that this means that hasattr() on the property always returns True.
+ """
+
+ __slots__ = ["default"]
- def get_dict(self):
- return dict(self.__dict__)
+ def __init__(self, key, default):
+ super().__init__(key)
+ self.default = default
- def is_outlier(self):
- return getattr(self, "outlier", False)
+ def __get__(self, instance, owner=None):
+ if instance is None:
+ return self
+ return instance._dict.get(self.key, self.default)
- def is_out_of_band_membership(self):
+
+class _EventInternalMetadata(object):
+ __slots__ = ["_dict"]
+
+ def __init__(self, internal_metadata_dict: JsonDict):
+ # we have to copy the dict, because it turns out that the same dict is
+ # reused. TODO: fix that
+ self._dict = dict(internal_metadata_dict)
+
+ outlier = DictProperty("outlier") # type: bool
+ out_of_band_membership = DictProperty("out_of_band_membership") # type: bool
+ send_on_behalf_of = DictProperty("send_on_behalf_of") # type: str
+ recheck_redaction = DictProperty("recheck_redaction") # type: bool
+ soft_failed = DictProperty("soft_failed") # type: bool
+ proactively_send = DictProperty("proactively_send") # type: bool
+ redacted = DictProperty("redacted") # type: bool
+ txn_id = DictProperty("txn_id") # type: str
+ token_id = DictProperty("token_id") # type: str
+ stream_ordering = DictProperty("stream_ordering") # type: int
+
+ # XXX: These are set by StreamWorkerStore._set_before_and_after.
+ # I'm pretty sure that these are never persisted to the database, so shouldn't
+ # be here
+ before = DictProperty("before") # type: str
+ after = DictProperty("after") # type: str
+ order = DictProperty("order") # type: int
+
+ def get_dict(self) -> JsonDict:
+ return dict(self._dict)
+
+ def is_outlier(self) -> bool:
+ return self._dict.get("outlier", False)
+
+ def is_out_of_band_membership(self) -> bool:
"""Whether this is an out of band membership, like an invite or an invite
rejection. This is needed as those events are marked as outliers, but
they still need to be processed as if they're new events (e.g. updating
invite state in the database, relaying to clients, etc).
"""
- return getattr(self, "out_of_band_membership", False)
+ return self._dict.get("out_of_band_membership", False)
- def get_send_on_behalf_of(self):
+ def get_send_on_behalf_of(self) -> Optional[str]:
"""Whether this server should send the event on behalf of another server.
This is used by the federation "send_join" API to forward the initial join
event for a server in the room.
returns a str with the name of the server this event is sent on behalf of.
"""
- return getattr(self, "send_on_behalf_of", None)
+ return self._dict.get("send_on_behalf_of")
- def need_to_check_redaction(self):
+ def need_to_check_redaction(self) -> bool:
"""Whether the redaction event needs to be rechecked when fetching
from the database.
@@ -76,9 +160,9 @@ class _EventInternalMetadata(object):
Returns:
bool
"""
- return getattr(self, "recheck_redaction", False)
+ return self._dict.get("recheck_redaction", False)
- def is_soft_failed(self):
+ def is_soft_failed(self) -> bool:
"""Whether the event has been soft failed.
Soft failed events should be handled as usual, except:
@@ -90,62 +174,76 @@ class _EventInternalMetadata(object):
Returns:
bool
"""
- return getattr(self, "soft_failed", False)
+ return self._dict.get("soft_failed", False)
+ def should_proactively_send(self):
+ """Whether the event, if ours, should be sent to other clients and
+ servers.
-def _event_dict_property(key):
- # We want to be able to use hasattr with the event dict properties.
- # However, (on python3) hasattr expects AttributeError to be raised. Hence,
- # we need to transform the KeyError into an AttributeError
- def getter(self):
- try:
- return self._event_dict[key]
- except KeyError:
- raise AttributeError(key)
+ This is used for sending dummy events internally. Servers and clients
+ can still explicitly fetch the event.
- def setter(self, v):
- try:
- self._event_dict[key] = v
- except KeyError:
- raise AttributeError(key)
+ Returns:
+ bool
+ """
+ return self._dict.get("proactively_send", True)
- def delete(self):
- try:
- del self._event_dict[key]
- except KeyError:
- raise AttributeError(key)
+ def is_redacted(self):
+ """Whether the event has been redacted.
- return property(
- getter,
- setter,
- delete,
- )
+ This is used for efficiently checking whether an event has been
+ marked as redacted without needing to make another database call.
+
+ Returns:
+ bool
+ """
+ return self._dict.get("redacted", False)
-class EventBase(object):
- def __init__(self, event_dict, signatures={}, unsigned={},
- internal_metadata_dict={}, rejected_reason=None):
+class EventBase(metaclass=abc.ABCMeta):
+ @property
+ @abc.abstractmethod
+ def format_version(self) -> int:
+ """The EventFormatVersion implemented by this event"""
+ ...
+
+ def __init__(
+ self,
+ event_dict: JsonDict,
+ room_version: RoomVersion,
+ signatures: Dict[str, Dict[str, str]],
+ unsigned: JsonDict,
+ internal_metadata_dict: JsonDict,
+ rejected_reason: Optional[str],
+ ):
+ assert room_version.event_format == self.format_version
+
+ self.room_version = room_version
self.signatures = signatures
self.unsigned = unsigned
self.rejected_reason = rejected_reason
- self._event_dict = event_dict
+ self._dict = event_dict
- self.internal_metadata = _EventInternalMetadata(
- internal_metadata_dict
- )
+ self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
+
+ auth_events = DictProperty("auth_events")
+ depth = DictProperty("depth")
+ content = DictProperty("content")
+ hashes = DictProperty("hashes")
+ origin = DictProperty("origin")
+ origin_server_ts = DictProperty("origin_server_ts")
+ prev_events = DictProperty("prev_events")
+ redacts = DefaultDictProperty("redacts", None)
+ room_id = DictProperty("room_id")
+ sender = DictProperty("sender")
+ state_key = DictProperty("state_key")
+ type = DictProperty("type")
+ user_id = DictProperty("sender")
- auth_events = _event_dict_property("auth_events")
- depth = _event_dict_property("depth")
- content = _event_dict_property("content")
- hashes = _event_dict_property("hashes")
- origin = _event_dict_property("origin")
- origin_server_ts = _event_dict_property("origin_server_ts")
- prev_events = _event_dict_property("prev_events")
- redacts = _event_dict_property("redacts")
- room_id = _event_dict_property("room_id")
- sender = _event_dict_property("sender")
- user_id = _event_dict_property("sender")
+ @property
+ def event_id(self) -> str:
+ raise NotImplementedError()
@property
def membership(self):
@@ -154,22 +252,19 @@ class EventBase(object):
def is_state(self):
return hasattr(self, "state_key") and self.state_key is not None
- def get_dict(self):
- d = dict(self._event_dict)
- d.update({
- "signatures": self.signatures,
- "unsigned": dict(self.unsigned),
- })
+ def get_dict(self) -> JsonDict:
+ d = dict(self._dict)
+ d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
return d
def get(self, key, default=None):
- return self._event_dict.get(key, default)
+ return self._dict.get(key, default)
def get_internal_metadata_dict(self):
return self.internal_metadata.get_dict()
- def get_pdu_json(self, time_now=None):
+ def get_pdu_json(self, time_now=None) -> JsonDict:
pdu_json = self.get_dict()
if time_now is not None and "age_ts" in pdu_json["unsigned"]:
@@ -186,16 +281,16 @@ class EventBase(object):
raise AttributeError("Unrecognized attribute %s" % (instance,))
def __getitem__(self, field):
- return self._event_dict[field]
+ return self._dict[field]
def __contains__(self, field):
- return field in self._event_dict
+ return field in self._dict
def items(self):
- return list(self._event_dict.items())
+ return list(self._dict.items())
def keys(self):
- return six.iterkeys(self._event_dict)
+ return six.iterkeys(self._dict)
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
@@ -219,7 +314,13 @@ class EventBase(object):
class FrozenEvent(EventBase):
format_version = EventFormatVersions.V1 # All events of this type are V1
- def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
+ def __init__(
+ self,
+ event_dict: JsonDict,
+ room_version: RoomVersion,
+ internal_metadata_dict: JsonDict = {},
+ rejected_reason: Optional[str] = None,
+ ):
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
@@ -240,19 +341,21 @@ class FrozenEvent(EventBase):
else:
frozen_dict = event_dict
- self.event_id = event_dict["event_id"]
- self.type = event_dict["type"]
- if "state_key" in event_dict:
- self.state_key = event_dict["state_key"]
+ self._event_id = event_dict["event_id"]
- super(FrozenEvent, self).__init__(
+ super().__init__(
frozen_dict,
+ room_version=room_version,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
rejected_reason=rejected_reason,
)
+ @property
+ def event_id(self) -> str:
+ return self._event_id
+
def __str__(self):
return self.__repr__()
@@ -267,7 +370,13 @@ class FrozenEvent(EventBase):
class FrozenEventV2(EventBase):
format_version = EventFormatVersions.V2 # All events of this type are V2
- def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
+ def __init__(
+ self,
+ event_dict: JsonDict,
+ room_version: RoomVersion,
+ internal_metadata_dict: JsonDict = {},
+ rejected_reason: Optional[str] = None,
+ ):
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
@@ -291,12 +400,10 @@ class FrozenEventV2(EventBase):
frozen_dict = event_dict
self._event_id = None
- self.type = event_dict["type"]
- if "state_key" in event_dict:
- self.state_key = event_dict["state_key"]
- super(FrozenEventV2, self).__init__(
+ super().__init__(
frozen_dict,
+ room_version=room_version,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
@@ -346,6 +453,7 @@ class FrozenEventV2(EventBase):
class FrozenEventV3(FrozenEventV2):
"""FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
+
format_version = EventFormatVersions.V3 # All events of this type are V3
@property
@@ -362,28 +470,7 @@ class FrozenEventV3(FrozenEventV2):
return self._event_id
-def room_version_to_event_format(room_version):
- """Converts a room version string to the event format
-
- Args:
- room_version (str)
-
- Returns:
- int
-
- Raises:
- UnsupportedRoomVersionError if the room version is unknown
- """
- v = KNOWN_ROOM_VERSIONS.get(room_version)
-
- if not v:
- # this can happen if support is withdrawn for a room version
- raise UnsupportedRoomVersionError()
-
- return v.event_format
-
-
-def event_type_from_format_version(format_version):
+def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
"""Returns the python type to use to construct an Event object for the
given event format version.
@@ -402,6 +489,15 @@ def event_type_from_format_version(format_version):
elif format_version == EventFormatVersions.V3:
return FrozenEventV3
else:
- raise Exception(
- "No event format %r" % (format_version,)
- )
+ raise Exception("No event format %r" % (format_version,))
+
+
+def make_event_from_dict(
+ event_dict: JsonDict,
+ room_version: RoomVersion = RoomVersions.V1,
+ internal_metadata_dict: JsonDict = {},
+ rejected_reason: Optional[str] = None,
+) -> EventBase:
+ """Construct an EventBase from the given event dict"""
+ event_type = _event_type_from_format_version(room_version.event_format)
+ return event_type(event_dict, room_version, internal_metadata_dict, rejected_reason)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 546b6f4982..a0c4a40c27 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -12,8 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
import attr
+from nacl.signing import SigningKey
from twisted.internet import defer
@@ -23,13 +25,14 @@ from synapse.api.room_versions import (
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
+ RoomVersion,
)
from synapse.crypto.event_signing import add_hashes_and_signatures
-from synapse.types import EventID
+from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
+from synapse.types import EventID, JsonDict
+from synapse.util import Clock
from synapse.util.stringutils import random_string
-from . import _EventInternalMetadata, event_type_from_format_version
-
@attr.s(slots=True, cmp=False, frozen=True)
class EventBuilder(object):
@@ -40,7 +43,7 @@ class EventBuilder(object):
content/unsigned/internal_metadata fields are still mutable)
Attributes:
- format_version (int): Event format version
+ room_version: Version of the target room
room_id (str)
type (str)
sender (str)
@@ -63,7 +66,7 @@ class EventBuilder(object):
_hostname = attr.ib()
_signing_key = attr.ib()
- format_version = attr.ib()
+ room_version = attr.ib(type=RoomVersion)
room_id = attr.ib()
type = attr.ib()
@@ -78,7 +81,9 @@ class EventBuilder(object):
_redacts = attr.ib(default=None)
_origin_server_ts = attr.ib(default=None)
- internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
+ internal_metadata = attr.ib(
+ default=attr.Factory(lambda: _EventInternalMetadata({}))
+ )
@property
def state_key(self):
@@ -102,22 +107,19 @@ class EventBuilder(object):
"""
state_ids = yield self._state.get_current_state_ids(
- self.room_id, prev_event_ids,
- )
- auth_ids = yield self._auth.compute_auth_events(
- self, state_ids,
+ self.room_id, prev_event_ids
)
+ auth_ids = yield self._auth.compute_auth_events(self, state_ids)
- if self.format_version == EventFormatVersions.V1:
+ format_version = self.room_version.event_format
+ if format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)
prev_events = yield self._store.add_event_hashes(prev_event_ids)
else:
auth_events = auth_ids
prev_events = prev_event_ids
- old_depth = yield self._store.get_max_depth_of(
- prev_event_ids,
- )
+ old_depth = yield self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not
@@ -146,15 +148,13 @@ class EventBuilder(object):
if self._origin_server_ts is not None:
event_dict["origin_server_ts"] = self._origin_server_ts
- defer.returnValue(
- create_local_event_from_event_dict(
- clock=self._clock,
- hostname=self._hostname,
- signing_key=self._signing_key,
- format_version=self.format_version,
- event_dict=event_dict,
- internal_metadata_dict=self.internal_metadata.get_dict(),
- )
+ return create_local_event_from_event_dict(
+ clock=self._clock,
+ hostname=self._hostname,
+ signing_key=self._signing_key,
+ room_version=self.room_version,
+ event_dict=event_dict,
+ internal_metadata_dict=self.internal_metadata.get_dict(),
)
@@ -205,7 +205,7 @@ class EventBuilderFactory(object):
clock=self.clock,
hostname=self.hostname,
signing_key=self.signing_key,
- format_version=room_version.event_format,
+ room_version=room_version,
type=key_values["type"],
state_key=key_values.get("state_key"),
room_id=key_values["room_id"],
@@ -217,29 +217,22 @@ class EventBuilderFactory(object):
)
-def create_local_event_from_event_dict(clock, hostname, signing_key,
- format_version, event_dict,
- internal_metadata_dict=None):
+def create_local_event_from_event_dict(
+ clock: Clock,
+ hostname: str,
+ signing_key: SigningKey,
+ room_version: RoomVersion,
+ event_dict: JsonDict,
+ internal_metadata_dict: Optional[JsonDict] = None,
+) -> EventBase:
"""Takes a fully formed event dict, ensuring that fields like `origin`
and `origin_server_ts` have correct values for a locally produced event,
then signs and hashes it.
-
- Args:
- clock (Clock)
- hostname (str)
- signing_key
- format_version (int)
- event_dict (dict)
- internal_metadata_dict (dict|None)
-
- Returns:
- FrozenEvent
"""
+ format_version = room_version.event_format
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
- raise Exception(
- "No event format defined for version %r" % (format_version,)
- )
+ raise Exception("No event format defined for version %r" % (format_version,))
if internal_metadata_dict is None:
internal_metadata_dict = {}
@@ -258,13 +251,9 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
event_dict.setdefault("signatures", {})
- add_hashes_and_signatures(
- event_dict,
- hostname,
- signing_key,
- )
- return event_type_from_format_version(format_version)(
- event_dict, internal_metadata_dict=internal_metadata_dict,
+ add_hashes_and_signatures(room_version, event_dict, hostname, signing_key)
+ return make_event_from_dict(
+ event_dict, room_version, internal_metadata_dict=internal_metadata_dict
)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index fa09c132a0..9ea85e93e6 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,103 +12,124 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional, Union
from six import iteritems
+import attr
from frozendict import frozendict
from twisted.internet import defer
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.appservice import ApplicationService
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import StateMap
-class EventContext(object):
+@attr.s(slots=True)
+class EventContext:
"""
+ Holds information relevant to persisting an event
+
Attributes:
- state_group (int|None): state group id, if the state has been stored
- as a state group. This is usually only None if e.g. the event is
- an outlier.
- rejected (bool|str): A rejection reason if the event was rejected, else
- False
-
- push_actions (list[(str, list[object])]): list of (user_id, actions)
- tuples
-
- prev_group (int): Previously persisted state group. ``None`` for an
- outlier.
- delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
- (type, state_key) -> event_id. ``None`` for an outlier.
-
- prev_state_events (?): XXX: is this ever set to anything other than
- the empty list?
-
- _current_state_ids (dict[(str, str), str]|None):
- The current state map including the current event. None if outlier
- or we haven't fetched the state from DB yet.
- (type, state_key) -> event_id
+ rejected: A rejection reason if the event was rejected, else False
+
+ _state_group: The ID of the state group for this event. Note that state events
+ are persisted with a state group which includes the new event, so this is
+ effectively the state *after* the event in question.
+
+ For a *rejected* state event, where the state of the rejected event is
+ ignored, this state_group should never make it into the
+ event_to_state_groups table. Indeed, inspecting this value for a rejected
+ state event is almost certainly incorrect.
+
+ For an outlier, where we don't have the state at the event, this will be
+ None.
+
+ Note that this is a private attribute: it should be accessed via
+ the ``state_group`` property.
+
+ state_group_before_event: The ID of the state group representing the state
+ of the room before this event.
+
+ If this is a non-state event, this will be the same as ``state_group``. If
+ it's a state event, it will be the same as ``prev_group``.
+
+ If ``state_group`` is None (ie, the event is an outlier),
+ ``state_group_before_event`` will always also be ``None``.
+
+ prev_group: If it is known, ``state_group``'s prev_group. Note that this being
+ None does not necessarily mean that ``state_group`` does not have
+ a prev_group!
+
+ If the event is a state event, this is normally the same as ``prev_group``.
+
+ If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
+ will always also be ``None``.
+
+ Note that this *not* (necessarily) the state group associated with
+ ``_prev_state_ids``.
+
+ delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
+ and ``state_group``.
+
+ app_service: If this event is being sent by a (local) application service, that
+ app service.
+
+ _current_state_ids: The room state map, including this event - ie, the state
+ in ``state_group``.
- _prev_state_ids (dict[(str, str), str]|None):
- The current state map excluding the current event. None if outlier
- or we haven't fetched the state from DB yet.
(type, state_key) -> event_id
- _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
- been calculated. None if we haven't started calculating yet
+ FIXME: what is this for an outlier? it seems ill-defined. It seems like
+ it could be either {}, or the state we were given by the remote
+ server, depending on $THINGS
- _event_type (str): The type of the event the context is associated with.
- Only set when state has not been fetched yet.
+ Note that this is a private attribute: it should be accessed via
+ ``get_current_state_ids``. _AsyncEventContext impl calculates this
+ on-demand: it will be None until that happens.
- _event_state_key (str|None): The state_key of the event the context is
- associated with. Only set when state has not been fetched yet.
+ _prev_state_ids: The room state map, excluding this event - ie, the state
+ in ``state_group_before_event``. For a non-state
+ event, this will be the same as _current_state_events.
- _prev_state_id (str|None): If the event associated with the context is
- a state event, then `_prev_state_id` is the event_id of the state
- that was replaced.
- Only set when state has not been fetched yet.
+ Note that it is a completely different thing to prev_group!
+
+ (type, state_key) -> event_id
+
+ FIXME: again, what is this for an outlier?
+
+ As with _current_state_ids, this is a private attribute. It should be
+ accessed via get_prev_state_ids.
"""
- __slots__ = [
- "state_group",
- "rejected",
- "prev_group",
- "delta_ids",
- "prev_state_events",
- "app_service",
- "_current_state_ids",
- "_prev_state_ids",
- "_prev_state_id",
- "_event_type",
- "_event_state_key",
- "_fetching_state_deferred",
- ]
-
- def __init__(self):
- self.prev_state_events = []
- self.rejected = False
- self.app_service = None
+ rejected = attr.ib(default=False, type=Union[bool, str])
+ _state_group = attr.ib(default=None, type=Optional[int])
+ state_group_before_event = attr.ib(default=None, type=Optional[int])
+ prev_group = attr.ib(default=None, type=Optional[int])
+ delta_ids = attr.ib(default=None, type=Optional[StateMap[str]])
+ app_service = attr.ib(default=None, type=Optional[ApplicationService])
- @staticmethod
- def with_state(state_group, current_state_ids, prev_state_ids,
- prev_group=None, delta_ids=None):
- context = EventContext()
-
- # The current state including the current event
- context._current_state_ids = current_state_ids
- # The current state excluding the current event
- context._prev_state_ids = prev_state_ids
- context.state_group = state_group
-
- context._prev_state_id = None
- context._event_type = None
- context._event_state_key = None
- context._fetching_state_deferred = defer.succeed(None)
-
- # A previously persisted state group and a delta between that
- # and this state.
- context.prev_group = prev_group
- context.delta_ids = delta_ids
+ _current_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
+ _prev_state_ids = attr.ib(default=None, type=Optional[StateMap[str]])
- return context
+ @staticmethod
+ def with_state(
+ state_group,
+ state_group_before_event,
+ current_state_ids,
+ prev_state_ids,
+ prev_group=None,
+ delta_ids=None,
+ ):
+ return EventContext(
+ current_state_ids=current_state_ids,
+ prev_state_ids=prev_state_ids,
+ state_group=state_group,
+ state_group_before_event=state_group_before_event,
+ prev_group=prev_group,
+ delta_ids=delta_ids,
+ )
@defer.inlineCallbacks
def serialize(self, event, store):
@@ -127,83 +148,102 @@ class EventContext(object):
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
- prev_state_ids = yield self.get_prev_state_ids(store)
+ prev_state_ids = yield self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
- defer.returnValue({
+ return {
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None,
- "state_group": self.state_group,
+ "state_group": self._state_group,
+ "state_group_before_event": self.state_group_before_event,
"rejected": self.rejected,
"prev_group": self.prev_group,
"delta_ids": _encode_state_dict(self.delta_ids),
- "prev_state_events": self.prev_state_events,
- "app_service_id": self.app_service.id if self.app_service else None
- })
+ "app_service_id": self.app_service.id if self.app_service else None,
+ }
@staticmethod
- def deserialize(store, input):
+ def deserialize(storage, input):
"""Converts a dict that was produced by `serialize` back into a
EventContext.
Args:
- store (DataStore): Used to convert AS ID to AS object
+ storage (Storage): Used to convert AS ID to AS object and fetch
+ state.
input (dict): A dict produced by `serialize`
Returns:
EventContext
"""
- context = EventContext()
+ context = _AsyncEventContextImpl(
+ # We use the state_group and prev_state_id stuff to pull the
+ # current_state_ids out of the DB and construct prev_state_ids.
+ storage=storage,
+ prev_state_id=input["prev_state_id"],
+ event_type=input["event_type"],
+ event_state_key=input["event_state_key"],
+ state_group=input["state_group"],
+ state_group_before_event=input["state_group_before_event"],
+ prev_group=input["prev_group"],
+ delta_ids=_decode_state_dict(input["delta_ids"]),
+ rejected=input["rejected"],
+ )
- # We use the state_group and prev_state_id stuff to pull the
- # current_state_ids out of the DB and construct prev_state_ids.
- context._prev_state_id = input["prev_state_id"]
- context._event_type = input["event_type"]
- context._event_state_key = input["event_state_key"]
+ app_service_id = input["app_service_id"]
+ if app_service_id:
+ context.app_service = storage.main.get_app_service_by_id(app_service_id)
- context._current_state_ids = None
- context._prev_state_ids = None
- context._fetching_state_deferred = None
+ return context
- context.state_group = input["state_group"]
- context.prev_group = input["prev_group"]
- context.delta_ids = _decode_state_dict(input["delta_ids"])
+ @property
+ def state_group(self) -> Optional[int]:
+ """The ID of the state group for this event.
- context.rejected = input["rejected"]
- context.prev_state_events = input["prev_state_events"]
+ Note that state events are persisted with a state group which includes the new
+ event, so this is effectively the state *after* the event in question.
- app_service_id = input["app_service_id"]
- if app_service_id:
- context.app_service = store.get_app_service_by_id(app_service_id)
+ For an outlier, where we don't have the state at the event, this will be None.
- return context
+ It is an error to access this for a rejected event, since rejected state should
+ not make it into the room state. Accessing this property will raise an exception
+ if ``rejected`` is set.
+ """
+ if self.rejected:
+ raise RuntimeError("Attempt to access state_group of rejected event")
+
+ return self._state_group
@defer.inlineCallbacks
- def get_current_state_ids(self, store):
- """Gets the current state IDs
+ def get_current_state_ids(self):
+ """
+ Gets the room state map, including this event - ie, the state in ``state_group``
+
+ It is an error to access this for a rejected event, since rejected state should
+ not make it into the room state. This method will raise an exception if
+ ``rejected`` is set.
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
+
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
+ if self.rejected:
+ raise RuntimeError("Attempt to access state_ids of rejected event")
- if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store,
- )
-
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
- defer.returnValue(self._current_state_ids)
+ yield self._ensure_fetched()
+ return self._current_state_ids
@defer.inlineCallbacks
- def get_prev_state_ids(self, store):
- """Gets the prev state IDs
+ def get_prev_state_ids(self):
+ """
+ Gets the room state map, excluding this event.
+
+ For a non-state event, this will be the same as get_current_state_ids().
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
@@ -211,37 +251,76 @@ class EventContext(object):
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
-
- if not self._fetching_state_deferred:
- self._fetching_state_deferred = run_in_background(
- self._fill_out_state, store,
- )
-
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
- defer.returnValue(self._prev_state_ids)
+ yield self._ensure_fetched()
+ return self._prev_state_ids
def get_cached_current_state_ids(self):
"""Gets the current state IDs if we have them already cached.
+ It is an error to access this for a rejected event, since rejected state should
+ not make it into the room state. This method will raise an exception if
+ ``rejected`` is set.
+
Returns:
dict[(str, str), str]|None: Returns None if we haven't cached the
state or if state_group is None, which happens when the associated
event is an outlier.
"""
+ if self.rejected:
+ raise RuntimeError("Attempt to access state_ids of rejected event")
return self._current_state_ids
+ def _ensure_fetched(self):
+ return defer.succeed(None)
+
+
+@attr.s(slots=True)
+class _AsyncEventContextImpl(EventContext):
+ """
+ An implementation of EventContext which fetches _current_state_ids and
+ _prev_state_ids from the database on demand.
+
+ Attributes:
+
+ _storage (Storage)
+
+ _fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
+ been calculated. None if we haven't started calculating yet
+
+ _event_type (str): The type of the event the context is associated with.
+
+ _event_state_key (str): The state_key of the event the context is
+ associated with.
+
+ _prev_state_id (str|None): If the event associated with the context is
+ a state event, then `_prev_state_id` is the event_id of the state
+ that was replaced.
+ """
+
+ # This needs to have a default as we're inheriting
+ _storage = attr.ib(default=None)
+ _prev_state_id = attr.ib(default=None)
+ _event_type = attr.ib(default=None)
+ _event_state_key = attr.ib(default=None)
+ _fetching_state_deferred = attr.ib(default=None)
+
+ def _ensure_fetched(self):
+ if not self._fetching_state_deferred:
+ self._fetching_state_deferred = run_in_background(self._fill_out_state)
+
+ return make_deferred_yieldable(self._fetching_state_deferred)
+
@defer.inlineCallbacks
- def _fill_out_state(self, store):
+ def _fill_out_state(self):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return
- self._current_state_ids = yield store.get_state_ids_for_group(
- self.state_group,
+ self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
+ self.state_group
)
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
@@ -251,26 +330,6 @@ class EventContext(object):
else:
self._prev_state_ids = self._current_state_ids
- @defer.inlineCallbacks
- def update_state(self, state_group, prev_state_ids, current_state_ids,
- prev_group, delta_ids):
- """Replace the state in the context
- """
-
- # We need to make sure we wait for any ongoing fetching of state
- # to complete so that the updated state doesn't get clobbered
- if self._fetching_state_deferred:
- yield make_deferred_yieldable(self._fetching_state_deferred)
-
- self.state_group = state_group
- self._prev_state_ids = prev_state_ids
- self.prev_group = prev_group
- self._current_state_ids = current_state_ids
- self.delta_ids = delta_ids
-
- # We need to ensure that that we've marked as having fetched the state
- self._fetching_state_deferred = defer.succeed(None)
-
def _encode_state_dict(state_dict):
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
@@ -279,10 +338,7 @@ def _encode_state_dict(state_dict):
if state_dict is None:
return None
- return [
- (etype, state_key, v)
- for (etype, state_key), v in iteritems(state_dict)
- ]
+ return [(etype, state_key, v) for (etype, state_key), v in iteritems(state_dict)]
def _decode_state_dict(input):
@@ -291,4 +347,4 @@ def _decode_state_dict(input):
if input is None:
return None
- return frozendict({(etype, state_key,): v for etype, state_key, v in input})
+ return frozendict({(etype, state_key): v for etype, state_key, v in input})
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index b8ccced43b..ae33b3f65d 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,9 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
+from typing import Dict, List, Optional
+
+from synapse.spam_checker_api import SpamCheckerApi
+
+MYPY = False
+if MYPY:
+ import synapse.server
+
class SpamChecker(object):
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.spam_checker = None
module = None
@@ -26,9 +36,16 @@ class SpamChecker(object):
pass
if module is not None:
- self.spam_checker = module(config=config)
-
- def check_event_for_spam(self, event):
+ # Older spam checkers don't accept the `api` argument, so we
+ # try and detect support.
+ spam_args = inspect.getfullargspec(module)
+ if "api" in spam_args.args:
+ api = SpamCheckerApi(hs)
+ self.spam_checker = module(config=config, api=api)
+ else:
+ self.spam_checker = module(config=config)
+
+ def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
"""Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if
@@ -36,101 +53,117 @@ class SpamChecker(object):
users receive a blank event.
Args:
- event (synapse.events.EventBase): the event to be checked
+ event: the event to be checked
Returns:
- bool: True if the event is spammy.
+ True if the event is spammy.
"""
if self.spam_checker is None:
return False
return self.spam_checker.check_event_for_spam(event)
- def user_may_invite(self, inviter_userid, invitee_userid, third_party_invite,
- room_id, new_room, published_room):
+ def user_may_invite(
+ self,
+ inviter_userid: str,
+ invitee_userid: str,
+ third_party_invite: Optional[Dict],
+ room_id: str,
+ new_room: bool,
+ published_room: bool,
+ ) -> bool:
"""Checks if a given user may send an invite
If this method returns false, the invite will be rejected.
Args:
- inviter_userid (str)
- invitee_userid (str|None): The user ID of the invitee. Is None
+ inviter_userid:
+ invitee_userid: The user ID of the invitee. Is None
if this is a third party invite and the 3PID is not bound to a
user ID.
- third_party_invite (dict|None): If a third party invite then is a
+ third_party_invite: If a third party invite then is a
dict containing the medium and address of the invitee.
- room_id (str)
- new_room (bool): Whether the user is being invited to the room as
+ room_id:
+ new_room: Whether the user is being invited to the room as
part of a room creation, if so the invitee would have been
included in the call to `user_may_create_room`.
- published_room (bool): Whether the room the user is being invited
+ published_room: Whether the room the user is being invited
to has been published in the local homeserver's public room
directory.
Returns:
- bool: True if the user may send an invite, otherwise False
+ True if the user may send an invite, otherwise False
"""
if self.spam_checker is None:
return True
return self.spam_checker.user_may_invite(
- inviter_userid, invitee_userid, third_party_invite, room_id, new_room,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
published_room,
)
- def user_may_create_room(self, userid, invite_list, third_party_invite_list,
- cloning):
+ def user_may_create_room(
+ self,
+ userid: str,
+ invite_list: List[str],
+ third_party_invite_list: List[Dict],
+ cloning: bool,
+ ) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
Args:
- userid (string): The sender's user ID
- invite_list (list[str]): List of user IDs that would be invited to
+ userid: The ID of the user attempting to create a room
+ invite_list: List of user IDs that would be invited to
the new room.
- third_party_invite_list (list[dict]): List of third party invites
+ third_party_invite_list: List of third party invites
for the new room.
- cloning (bool): Whether the user is cloning an existing room, e.g.
+ cloning: Whether the user is cloning an existing room, e.g.
upgrading a room.
Returns:
- bool: True if the user may create a room, otherwise False
+ True if the user may create a room, otherwise False
"""
if self.spam_checker is None:
return True
return self.spam_checker.user_may_create_room(
- userid, invite_list, third_party_invite_list, cloning,
+ userid, invite_list, third_party_invite_list, cloning
)
- def user_may_create_room_alias(self, userid, room_alias):
+ def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
If this method returns false, the association request will be rejected.
Args:
- userid (string): The sender's user ID
- room_alias (string): The alias to be created
+ userid: The ID of the user attempting to create a room alias
+ room_alias: The alias to be created
Returns:
- bool: True if the user may create a room alias, otherwise False
+ True if the user may create a room alias, otherwise False
"""
if self.spam_checker is None:
return True
return self.spam_checker.user_may_create_room_alias(userid, room_alias)
- def user_may_publish_room(self, userid, room_id):
+ def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory
If this method returns false, the publish request will be rejected.
Args:
- userid (string): The sender's user ID
- room_id (string): The ID of the room that would be published
+ userid: The user ID attempting to publish the room
+ room_id: The ID of the room that would be published
Returns:
- bool: True if the user may publish the room, otherwise False
+ True if the user may publish the room, otherwise False
"""
if self.spam_checker is None:
return True
@@ -154,3 +187,29 @@ class SpamChecker(object):
return True
return self.spam_checker.user_may_join_room(userid, room_id, is_invited)
+
+ def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+ """Checks if a user ID or display name are considered "spammy" by this server.
+
+ If the server considers a username spammy, then it will not be included in
+ user directory results.
+
+ Args:
+ user_profile: The user information to check, it contains the keys:
+ * user_id
+ * display_name
+ * avatar_url
+
+ Returns:
+ True if the user is spammy.
+ """
+ if self.spam_checker is None:
+ return False
+
+ # For backwards compatibility, if the method does not exist on the spam checker, fallback to not interfering.
+ checker = getattr(self.spam_checker, "check_username_for_spam", None)
+ if not checker:
+ return False
+ # Make a copy of the user profile object to ensure the spam checker
+ # cannot modify it.
+ return checker(user_profile.copy())
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 50ceeb1e8e..459132d388 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -36,8 +36,7 @@ class ThirdPartyEventRules(object):
if module is not None:
self.third_party_rules = module(
- config=config,
- http_client=hs.get_simple_http_client(),
+ config=config, http_client=hs.get_simple_http_client()
)
@defer.inlineCallbacks
@@ -52,9 +51,9 @@ class ThirdPartyEventRules(object):
defer.Deferred[bool]: True if the event should be allowed, False if not.
"""
if self.third_party_rules is None:
- defer.returnValue(True)
+ return True
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
# Retrieve the state events from the database.
state_events = {}
@@ -62,7 +61,7 @@ class ThirdPartyEventRules(object):
state_events[key] = yield self.store.get_event(event_id, allow_none=True)
ret = yield self.third_party_rules.check_event_allowed(event, state_events)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def on_create_room(self, requester, config, is_requester_admin):
@@ -75,15 +74,16 @@ class ThirdPartyEventRules(object):
is_requester_admin (bool): If the requester is an admin
Returns:
- defer.Deferred
+ defer.Deferred[bool]: Whether room creation is allowed or denied.
"""
if self.third_party_rules is None:
- return
+ return True
- yield self.third_party_rules.on_create_room(
+ ret = yield self.third_party_rules.on_create_room(
requester, config, is_requester_admin
)
+ return ret
@defer.inlineCallbacks
def check_threepid_can_be_invited(self, medium, address, room_id):
@@ -99,7 +99,7 @@ class ThirdPartyEventRules(object):
"""
if self.third_party_rules is None:
- defer.returnValue(True)
+ return True
state_ids = yield self.store.get_filtered_current_state_ids(room_id)
room_state_events = yield self.store.get_events(state_ids.values())
@@ -109,6 +109,6 @@ class ThirdPartyEventRules(object):
state_events[key] = room_state_events[event_id]
ret = yield self.third_party_rules.check_threepid_can_be_invited(
- medium, address, state_events,
+ medium, address, state_events
)
- defer.returnValue(ret)
+ return ret
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index e2d4384de1..b75b097e5e 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -12,8 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import collections
import re
+from typing import Mapping, Union
from six import string_types
@@ -22,6 +23,7 @@ from frozendict import frozendict
from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.room_versions import RoomVersion
from synapse.util.async_helpers import yieldable_gather_results
from . import EventBase
@@ -31,40 +33,37 @@ from . import EventBase
# by a match for 'stuff'.
# TODO: This is fast, but fails to handle "foo\\.bar" which should be treated as
# the literal fields "foo\" and "bar" but will instead be treated as "foo\\.bar"
-SPLIT_FIELD_REGEX = re.compile(r'(?<!\\)\.')
+SPLIT_FIELD_REGEX = re.compile(r"(?<!\\)\.")
-def prune_event(event):
+def prune_event(event: EventBase) -> EventBase:
""" Returns a pruned version of the given event, which removes all keys we
don't know about or think could potentially be dodgy.
This is used when we "redact" an event. We want to remove all fields that
the user has specified, but we do want to keep necessary information like
type, state_key etc.
-
- Args:
- event (FrozenEvent)
-
- Returns:
- FrozenEvent
"""
- pruned_event_dict = prune_event_dict(event.get_dict())
+ pruned_event_dict = prune_event_dict(event.room_version, event.get_dict())
+
+ from . import make_event_from_dict
- from . import event_type_from_format_version
- return event_type_from_format_version(event.format_version)(
- pruned_event_dict, event.internal_metadata.get_dict()
+ pruned_event = make_event_from_dict(
+ pruned_event_dict, event.room_version, event.internal_metadata.get_dict()
)
+ # Mark the event as redacted
+ pruned_event.internal_metadata.redacted = True
-def prune_event_dict(event_dict):
+ return pruned_event
+
+
+def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
"""Redacts the event_dict in the same way as `prune_event`, except it
operates on dicts rather than event objects
- Args:
- event_dict (dict)
-
Returns:
- dict: A copy of the pruned event dict
+ A copy of the pruned event dict
"""
allowed_keys = [
@@ -111,16 +110,12 @@ def prune_event_dict(event_dict):
"kick",
"redact",
)
- elif event_type == EventTypes.Aliases:
+ elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
add_fields("aliases")
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
- allowed_fields = {
- k: v
- for k, v in event_dict.items()
- if k in allowed_keys
- }
+ allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
allowed_fields["content"] = new_content
@@ -205,7 +200,7 @@ def only_fields(dictionary, fields):
# for each element of the output array of arrays:
# remove escaping so we can use the right key names.
split_fields[:] = [
- [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
+ [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
]
output = {}
@@ -226,7 +221,10 @@ def format_event_for_client_v1(d):
d["user_id"] = sender
copy_keys = (
- "age", "redacted_because", "replaces_state", "prev_content",
+ "age",
+ "redacted_because",
+ "replaces_state",
+ "prev_content",
"invite_room_state",
)
for key in copy_keys:
@@ -238,8 +236,13 @@ def format_event_for_client_v1(d):
def format_event_for_client_v2(d):
drop_keys = (
- "auth_events", "prev_events", "hashes", "signatures", "depth",
- "origin", "prev_state",
+ "auth_events",
+ "prev_events",
+ "hashes",
+ "signatures",
+ "depth",
+ "origin",
+ "prev_state",
)
for key in drop_keys:
d.pop(key, None)
@@ -252,9 +255,15 @@ def format_event_for_client_v2_without_room_id(d):
return d
-def serialize_event(e, time_now_ms, as_client_event=True,
- event_format=format_event_for_client_v1,
- token_id=None, only_event_fields=None, is_invite=False):
+def serialize_event(
+ e,
+ time_now_ms,
+ as_client_event=True,
+ event_format=format_event_for_client_v1,
+ token_id=None,
+ only_event_fields=None,
+ is_invite=False,
+):
"""Serialize event for clients
Args:
@@ -288,8 +297,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if "redacted_because" in e.unsigned:
d["unsigned"]["redacted_because"] = serialize_event(
- e.unsigned["redacted_because"], time_now_ms,
- event_format=event_format
+ e.unsigned["redacted_because"], time_now_ms, event_format=event_format
)
if token_id is not None:
@@ -308,8 +316,9 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d = event_format(d)
if only_event_fields:
- if (not isinstance(only_event_fields, list) or
- not all(isinstance(f, string_types) for f in only_event_fields)):
+ if not isinstance(only_event_fields, list) or not all(
+ isinstance(f, string_types) for f in only_event_fields
+ ):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
@@ -344,19 +353,20 @@ class EventClientSerializer(object):
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
- defer.returnValue(event)
+ return event
event_id = event.event_id
serialized_event = serialize_event(event, time_now, **kwargs)
- # If MSC1849 is enabled then we need to look if thre are any relations
- # we need to bundle in with the event
- if self.experimental_msc1849_support_enabled and bundle_aggregations:
- annotations = yield self.store.get_aggregation_groups_for_event(
- event_id,
- )
+ # If MSC1849 is enabled then we need to look if there are any relations
+ # we need to bundle in with the event.
+ # Do not bundle relations if the event has been redacted
+ if not event.internal_metadata.is_redacted() and (
+ self.experimental_msc1849_support_enabled and bundle_aggregations
+ ):
+ annotations = yield self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event(
- event_id, RelationTypes.REFERENCE, direction="f",
+ event_id, RelationTypes.REFERENCE, direction="f"
)
if annotations.chunk:
@@ -385,9 +395,11 @@ class EventClientSerializer(object):
r = serialized_event["unsigned"].setdefault("m.relations", {})
r[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
+ "origin_server_ts": edit.origin_server_ts,
+ "sender": edit.sender,
}
- defer.returnValue(serialized_event)
+ return serialized_event
def serialize_events(self, events, time_now, **kwargs):
"""Serializes multiple events.
@@ -401,6 +413,39 @@ class EventClientSerializer(object):
Deferred[list[dict]]: The list of serialized events
"""
return yieldable_gather_results(
- self.serialize_event, events,
- time_now=time_now, **kwargs
+ self.serialize_event, events, time_now=time_now, **kwargs
)
+
+
+def copy_power_levels_contents(
+ old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
+):
+ """Copy the content of a power_levels event, unfreezing frozendicts along the way
+
+ Raises:
+ TypeError if the input does not look like a valid power levels event content
+ """
+ if not isinstance(old_power_levels, collections.Mapping):
+ raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
+
+ power_levels = {}
+ for k, v in old_power_levels.items():
+
+ if isinstance(v, int):
+ power_levels[k] = v
+ continue
+
+ if isinstance(v, collections.Mapping):
+ power_levels[k] = h = {}
+ for k1, v1 in v.items():
+ # we should only have one level of nesting
+ if not isinstance(v1, int):
+ raise TypeError(
+ "Invalid power_levels value for %s.%s: %r" % (k, k1, v1)
+ )
+ h[k1] = v1
+ continue
+
+ raise TypeError("Invalid power_levels value for %s: %r" % (k, v))
+
+ return power_levels
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 6d2bd97317..9b90c9ce04 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -49,9 +49,7 @@ class EventValidator(object):
raise SynapseError(400, "Event does not have key %s" % (k,))
# Check that the following keys have string values
- event_strings = [
- "origin",
- ]
+ event_strings = ["origin"]
for s in event_strings:
if not isinstance(getattr(event, s), string_types):
@@ -63,8 +61,10 @@ class EventValidator(object):
if len(alias) > MAX_ALIAS_LENGTH:
raise SynapseError(
400,
- ("Can't create aliases longer than"
- " %d characters" % (MAX_ALIAS_LENGTH,)),
+ (
+ "Can't create aliases longer than"
+ " %d characters" % (MAX_ALIAS_LENGTH,)
+ ),
Codes.INVALID_PARAM,
)
@@ -170,11 +170,7 @@ class EventValidator(object):
event (EventBuilder|FrozenEvent)
"""
- strings = [
- "room_id",
- "sender",
- "type",
- ]
+ strings = ["room_id", "sender", "type"]
if hasattr(event, "state_key"):
strings.append("state_key")
@@ -187,19 +183,16 @@ class EventValidator(object):
UserID.from_string(event.sender)
if event.type == EventTypes.Message:
- strings = [
- "body",
- "msgtype",
- ]
+ strings = ["body", "msgtype"]
self._ensure_strings(event.content, strings)
elif event.type == EventTypes.Topic:
self._ensure_strings(event.content, ["topic"])
-
+ self._ensure_state_event(event)
elif event.type == EventTypes.Name:
self._ensure_strings(event.content, ["name"])
-
+ self._ensure_state_event(event)
elif event.type == EventTypes.Member:
if "membership" not in event.content:
raise SynapseError(400, "Content has not membership key")
@@ -207,9 +200,25 @@ class EventValidator(object):
if event.content["membership"] not in Membership.LIST:
raise SynapseError(400, "Invalid membership key")
+ self._ensure_state_event(event)
+ elif event.type == EventTypes.Tombstone:
+ if "replacement_room" not in event.content:
+ raise SynapseError(400, "Content has no replacement_room key")
+
+ if event.content["replacement_room"] == event.room_id:
+ raise SynapseError(
+ 400, "Tombstone cannot reference the room it was sent in"
+ )
+
+ self._ensure_state_event(event)
+
def _ensure_strings(self, d, keys):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
if not isinstance(d[s], string_types):
raise SynapseError(400, "'%s' not a string type" % (s,))
+
+ def _ensure_state_event(self, event):
+ if not event.is_state():
+ raise SynapseError(400, "'%s' must be state events" % (event.type,))
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index fc5cfb7d83..5c991e5412 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,21 +15,32 @@
# limitations under the License.
import logging
from collections import namedtuple
+from typing import Iterable, List
import six
from twisted.internet import defer
-from twisted.internet.defer import DeferredList
+from twisted.internet.defer import Deferred, DeferredList
+from twisted.python.failure import Failure
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
+from synapse.api.room_versions import (
+ KNOWN_ROOM_VERSIONS,
+ EventFormatVersions,
+ RoomVersion,
+)
from synapse.crypto.event_signing import check_event_content_hash
-from synapse.events import event_type_from_format_version
+from synapse.crypto.keyring import Keyring
+from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event
from synapse.http.servlet import assert_params_in_dict
-from synapse.types import get_domain_from_id
-from synapse.util import logcontext, unwrapFirstError
+from synapse.logging.context import (
+ LoggingContext,
+ PreserveLoggingContext,
+ make_deferred_yieldable,
+)
+from synapse.types import JsonDict, get_domain_from_id
logger = logging.getLogger(__name__)
@@ -43,112 +55,35 @@ class FederationBase(object):
self.store = hs.get_datastore()
self._clock = hs.get_clock()
- @defer.inlineCallbacks
- def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version,
- outlier=False, include_none=False):
- """Takes a list of PDUs and checks the signatures and hashs of each
- one. If a PDU fails its signature check then we check if we have it in
- the database and if not then request if from the originating server of
- that PDU.
-
- If a PDU fails its content hash check then it is redacted.
-
- The given list of PDUs are not modified, instead the function returns
- a new list.
-
- Args:
- origin (str)
- pdu (list)
- room_version (str)
- outlier (bool): Whether the events are outliers or not
- include_none (str): Whether to include None in the returned list
- for events that have failed their checks
-
- Returns:
- Deferred : A list of PDUs that have valid signatures and hashes.
- """
- deferreds = self._check_sigs_and_hashes(room_version, pdus)
-
- @defer.inlineCallbacks
- def handle_check_result(pdu, deferred):
- try:
- res = yield logcontext.make_deferred_yieldable(deferred)
- except SynapseError:
- res = None
-
- if not res:
- # Check local db.
- res = yield self.store.get_event(
- pdu.event_id,
- allow_rejected=True,
- allow_none=True,
- )
-
- if not res and pdu.origin != origin:
- try:
- res = yield self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- room_version=room_version,
- outlier=outlier,
- timeout=10000,
- )
- except SynapseError:
- pass
-
- if not res:
- logger.warn(
- "Failed to find copy of %s with valid signature",
- pdu.event_id,
- )
-
- defer.returnValue(res)
-
- handle = logcontext.preserve_fn(handle_check_result)
- deferreds2 = [
- handle(pdu, deferred)
- for pdu, deferred in zip(pdus, deferreds)
- ]
-
- valid_pdus = yield logcontext.make_deferred_yieldable(
- defer.gatherResults(
- deferreds2,
- consumeErrors=True,
- )
- ).addErrback(unwrapFirstError)
-
- if include_none:
- defer.returnValue(valid_pdus)
- else:
- defer.returnValue([p for p in valid_pdus if p])
-
- def _check_sigs_and_hash(self, room_version, pdu):
- return logcontext.make_deferred_yieldable(
- self._check_sigs_and_hashes(room_version, [pdu])[0],
+ def _check_sigs_and_hash(self, room_version: str, pdu: EventBase) -> Deferred:
+ return make_deferred_yieldable(
+ self._check_sigs_and_hashes(room_version, [pdu])[0]
)
- def _check_sigs_and_hashes(self, room_version, pdus):
+ def _check_sigs_and_hashes(
+ self, room_version: str, pdus: List[EventBase]
+ ) -> List[Deferred]:
"""Checks that each of the received events is correctly signed by the
sending server.
Args:
- room_version (str): The room version of the PDUs
- pdus (list[FrozenEvent]): the events to be checked
+ room_version: The room version of the PDUs
+ pdus: the events to be checked
Returns:
- list[Deferred]: for each input event, a deferred which:
+ For each input event, a deferred which:
* returns the original event if the checks pass
* returns a redacted version of the event (if the signature
matched but the hash did not)
* throws a SynapseError if the signature check failed.
- The deferreds run their callbacks in the sentinel logcontext.
+ The deferreds run their callbacks in the sentinel
"""
deferreds = _check_sigs_on_pdus(self.keyring, room_version, pdus)
- ctx = logcontext.LoggingContext.current_context()
+ ctx = LoggingContext.current_context()
- def callback(_, pdu):
- with logcontext.PreserveLoggingContext(ctx):
+ def callback(_, pdu: EventBase):
+ with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
# let's try to distinguish between failures because the event was
# redacted (which are somewhat expected) vs actual ball-tampering
@@ -159,11 +94,9 @@ class FederationBase(object):
# received event was probably a redacted copy (but we then use our
# *actual* redacted copy to be on the safe side.)
redacted_event = prune_event(pdu)
- if (
- set(redacted_event.keys()) == set(pdu.keys()) and
- set(six.iterkeys(redacted_event.content))
- == set(six.iterkeys(pdu.content))
- ):
+ if set(redacted_event.keys()) == set(pdu.keys()) and set(
+ six.iterkeys(redacted_event.content)
+ ) == set(six.iterkeys(pdu.content)):
logger.info(
"Event %s seems to have been redacted; using our redacted "
"copy",
@@ -172,54 +105,58 @@ class FederationBase(object):
else:
logger.warning(
"Event %s content has been tampered, redacting",
- pdu.event_id, pdu.get_pdu_json(),
+ pdu.event_id,
)
return redacted_event
if self.spam_checker.check_event_for_spam(pdu):
- logger.warn(
+ logger.warning(
"Event contains spam, redacting %s: %s",
- pdu.event_id, pdu.get_pdu_json()
+ pdu.event_id,
+ pdu.get_pdu_json(),
)
return prune_event(pdu)
return pdu
- def errback(failure, pdu):
+ def errback(failure: Failure, pdu: EventBase):
failure.trap(SynapseError)
- with logcontext.PreserveLoggingContext(ctx):
- logger.warn(
+ with PreserveLoggingContext(ctx):
+ logger.warning(
"Signature check failed for %s: %s",
- pdu.event_id, failure.getErrorMessage(),
+ pdu.event_id,
+ failure.getErrorMessage(),
)
return failure
for deferred, pdu in zip(deferreds, pdus):
deferred.addCallbacks(
- callback, errback,
- callbackArgs=[pdu],
- errbackArgs=[pdu],
+ callback, errback, callbackArgs=[pdu], errbackArgs=[pdu]
)
return deferreds
-class PduToCheckSig(namedtuple("PduToCheckSig", [
- "pdu", "redacted_pdu_json", "sender_domain", "deferreds",
-])):
+class PduToCheckSig(
+ namedtuple(
+ "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"]
+ )
+):
pass
-def _check_sigs_on_pdus(keyring, room_version, pdus):
+def _check_sigs_on_pdus(
+ keyring: Keyring, room_version: str, pdus: Iterable[EventBase]
+) -> List[Deferred]:
"""Check that the given events are correctly signed
Args:
- keyring (synapse.crypto.Keyring): keyring object to do the checks
- room_version (str): the room version of the PDUs
- pdus (Collection[EventBase]): the events to be checked
+ keyring: keyring object to do the checks
+ room_version: the room version of the PDUs
+ pdus: the events to be checked
Returns:
- List[Deferred]: a Deferred for each event in pdus, which will either succeed if
+ A Deferred for each event in pdus, which will either succeed if
the signatures are valid, or fail (with a SynapseError) if not.
"""
@@ -260,10 +197,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
- pdus_to_check_sender = [
- p for p in pdus_to_check
- if not _is_invite_via_3pid(p.pdu)
- ]
+ pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)]
more_deferreds = keyring.verify_json_objects_for_server(
[
@@ -283,9 +217,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
pdu_to_check.sender_domain,
e.getErrorMessage(),
)
- # XX not really sure if these are the right codes, but they are what
- # we've done for ages
- raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)
+ raise SynapseError(403, errmsg, Codes.FORBIDDEN)
for p, d in zip(pdus_to_check_sender, more_deferreds):
d.addErrback(sender_err, p)
@@ -297,7 +229,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
# (ie, the room version uses old-style non-hash event IDs).
if v.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
- p for p in pdus_to_check
+ p
+ for p in pdus_to_check
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
@@ -315,13 +248,10 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
def event_err(e, pdu_to_check):
errmsg = (
- "event id %s: unable to verify signature for event id domain: %s" % (
- pdu_to_check.pdu.event_id,
- e.getErrorMessage(),
- )
+ "event id %s: unable to verify signature for event id domain: %s"
+ % (pdu_to_check.pdu.event_id, e.getErrorMessage())
)
- # XX as above: not really sure if these are the right codes
- raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)
+ raise SynapseError(403, errmsg, Codes.FORBIDDEN)
for p, d in zip(pdus_to_check_event_id, more_deferreds):
d.addErrback(event_err, p)
@@ -331,7 +261,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
-def _flatten_deferred_list(deferreds):
+def _flatten_deferred_list(deferreds: List[Deferred]) -> Deferred:
"""Given a list of deferreds, either return the single deferred,
combine into a DeferredList, or return an already resolved deferred.
"""
@@ -343,7 +273,7 @@ def _flatten_deferred_list(deferreds):
return defer.succeed(None)
-def _is_invite_via_3pid(event):
+def _is_invite_via_3pid(event: EventBase) -> bool:
return (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
@@ -351,16 +281,15 @@ def _is_invite_via_3pid(event):
)
-def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
- """Construct a FrozenEvent from an event json received over federation
+def event_from_pdu_json(
+ pdu_json: JsonDict, room_version: RoomVersion, outlier: bool = False
+) -> EventBase:
+ """Construct an EventBase from an event json received over federation
Args:
- pdu_json (object): pdu as received over federation
- event_format_version (int): The event format version
- outlier (bool): True to mark this event as an outlier
-
- Returns:
- FrozenEvent
+ pdu_json: pdu as received over federation
+ room_version: The version of the room this event belongs to
+ outlier: True to mark this event as an outlier
Raises:
SynapseError: if the pdu is missing required fields or is otherwise
@@ -368,22 +297,18 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False):
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
- assert_params_in_dict(pdu_json, ('type', 'depth'))
+ assert_params_in_dict(pdu_json, ("type", "depth"))
- depth = pdu_json['depth']
+ depth = pdu_json["depth"]
if not isinstance(depth, six.integer_types):
- raise SynapseError(400, "Depth %r not an intger" % (depth, ),
- Codes.BAD_JSON)
+ raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON)
if depth < 0:
raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
elif depth > MAX_DEPTH:
raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
- event = event_type_from_format_version(event_format_version)(
- pdu_json,
- )
-
+ event = make_event_from_dict(pdu_json, room_version)
event.internal_metadata.outlier = outlier
return event
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 70573746d6..8c6b839478 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -17,12 +17,23 @@
import copy
import itertools
import logging
-
-from six.moves import range
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+)
from prometheus_client import Counter
from twisted.internet import defer
+from twisted.internet.defer import Deferred
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
@@ -31,18 +42,21 @@ from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
SynapseError,
+ UnsupportedRoomVersionError,
)
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
+ RoomVersion,
RoomVersions,
)
-from synapse.events import builder, room_version_to_event_format
+from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
-from synapse.util import logcontext, unwrapFirstError
+from synapse.logging.context import make_deferred_yieldable, preserve_fn
+from synapse.logging.utils import log_function
+from synapse.types import JsonDict
+from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
-from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -52,11 +66,14 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
PDU_RETRY_TIME_MS = 1 * 60 * 1000
+T = TypeVar("T")
+
class InvalidResponseError(RuntimeError):
"""Helper for _try_destination_list: indicates that the server returned a response
we couldn't parse
"""
+
pass
@@ -65,9 +82,7 @@ class FederationClient(FederationBase):
super(FederationClient, self).__init__(hs)
self.pdu_destination_tried = {}
- self._clock.looping_call(
- self._clear_tried_cache, 60 * 1000,
- )
+ self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
@@ -99,8 +114,14 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict
@log_function
- def make_query(self, destination, query_type, args,
- retry_on_dns_fail=False, ignore_backoff=False):
+ def make_query(
+ self,
+ destination,
+ query_type,
+ args,
+ retry_on_dns_fail=False,
+ ignore_backoff=False,
+ ):
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
@@ -120,7 +141,10 @@ class FederationClient(FederationBase):
sent_queries_counter.labels(query_type).inc()
return self.transport_layer.make_query(
- destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
+ destination,
+ query_type,
+ args,
+ retry_on_dns_fail=retry_on_dns_fail,
ignore_backoff=ignore_backoff,
)
@@ -137,9 +161,7 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.labels("client_device_keys").inc()
- return self.transport_layer.query_client_keys(
- destination, content, timeout
- )
+ return self.transport_layer.query_client_keys(destination, content, timeout)
@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
@@ -147,9 +169,7 @@ class FederationClient(FederationBase):
server.
"""
sent_queries_counter.labels("user_devices").inc()
- return self.transport_layer.query_user_devices(
- destination, user_id, timeout
- )
+ return self.transport_layer.query_user_devices(destination, user_id, timeout)
@log_function
def claim_client_keys(self, destination, content, timeout):
@@ -164,57 +184,57 @@ class FederationClient(FederationBase):
response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
- return self.transport_layer.claim_client_keys(
- destination, content, timeout
- )
+ return self.transport_layer.claim_client_keys(destination, content, timeout)
- @defer.inlineCallbacks
- @log_function
- def backfill(self, dest, room_id, limit, extremities):
- """Requests some more historic PDUs for the given context from the
+ async def backfill(
+ self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
+ ) -> Optional[List[EventBase]]:
+ """Requests some more historic PDUs for the given room from the
given destination server.
Args:
- dest (str): The remote home server to ask.
+ dest (str): The remote homeserver to ask.
room_id (str): The room_id to backfill.
- limit (int): The maximum number of PDUs to return.
- extremities (list): List of PDU id and origins of the first pdus
- we have seen from the context
-
- Returns:
- Deferred: Results in the received PDUs.
+ limit (int): The maximum number of events to return.
+ extremities (list): our current backwards extremities, to backfill from
"""
logger.debug("backfill extrem=%s", extremities)
- # If there are no extremeties then we've (probably) reached the start.
+ # If there are no extremities then we've (probably) reached the start.
if not extremities:
- return
+ return None
- transaction_data = yield self.transport_layer.backfill(
- dest, room_id, extremities, limit)
+ transaction_data = await self.transport_layer.backfill(
+ dest, room_id, extremities, limit
+ )
- logger.debug("backfill transaction_data=%s", repr(transaction_data))
+ logger.debug("backfill transaction_data=%r", transaction_data)
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
+ room_version = await self.store.get_room_version(room_id)
pdus = [
- event_from_pdu_json(p, format_ver, outlier=False)
+ event_from_pdu_json(p, room_version, outlier=False)
for p in transaction_data["pdus"]
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- self._check_sigs_and_hashes(room_version, pdus),
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
+ pdus[:] = await make_deferred_yieldable(
+ defer.gatherResults(
+ self._check_sigs_and_hashes(room_version.identifier, pdus),
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
- defer.returnValue(pdus)
+ return pdus
- @defer.inlineCallbacks
- @log_function
- def get_pdu(self, destinations, event_id, room_version, outlier=False,
- timeout=None):
+ async def get_pdu(
+ self,
+ destinations: Iterable[str],
+ event_id: str,
+ room_version: RoomVersion,
+ outlier: bool = False,
+ timeout: Optional[int] = None,
+ ) -> Optional[EventBase]:
"""Requests the PDU with given origin and ID from the remote home
servers.
@@ -222,30 +242,27 @@ class FederationClient(FederationBase):
one succeeds.
Args:
- destinations (list): Which home servers to query
- event_id (str): event to fetch
- room_version (str): version of the room
- outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
+ destinations: Which homeservers to query
+ event_id: event to fetch
+ room_version: version of the room
+ outlier: Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
- timeout (int): How long to try (in ms) each destination for before
+ timeout: How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
Returns:
- Deferred: Results in the requested PDU, or None if we were unable to find
- it.
+ The requested PDU, or None if we were unable to find it.
"""
# TODO: Rate limit the number of times we try and get the same event.
ev = self._get_pdu_cache.get(event_id)
if ev:
- defer.returnValue(ev)
+ return ev
pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
- format_ver = room_version_to_event_format(room_version)
-
signed_pdu = None
for destination in destinations:
now = self._clock.time_msec()
@@ -254,8 +271,8 @@ class FederationClient(FederationBase):
continue
try:
- transaction_data = yield self.transport_layer.get_event(
- destination, event_id, timeout=timeout,
+ transaction_data = await self.transport_layer.get_event(
+ destination, event_id, timeout=timeout
)
logger.debug(
@@ -266,15 +283,17 @@ class FederationClient(FederationBase):
)
pdu_list = [
- event_from_pdu_json(p, format_ver, outlier=outlier)
+ event_from_pdu_json(p, room_version, outlier=outlier)
for p in transaction_data["pdus"]
- ]
+ ] # type: List[EventBase]
if pdu_list and pdu_list[0]:
pdu = pdu_list[0]
# Check signatures are correct.
- signed_pdu = yield self._check_sigs_and_hash(room_version, pdu)
+ signed_pdu = await self._check_sigs_and_hash(
+ room_version.identifier, pdu
+ )
break
@@ -282,8 +301,7 @@ class FederationClient(FederationBase):
except SynapseError as e:
logger.info(
- "Failed to get PDU %s from %s because %s",
- event_id, destination, e,
+ "Failed to get PDU %s from %s because %s", event_id, destination, e
)
continue
except NotRetryingDestination as e:
@@ -296,212 +314,148 @@ class FederationClient(FederationBase):
pdu_attempts[destination] = now
logger.info(
- "Failed to get PDU %s from %s because %s",
- event_id, destination, e,
+ "Failed to get PDU %s from %s because %s", event_id, destination, e
)
continue
if signed_pdu:
self._get_pdu_cache[event_id] = signed_pdu
- defer.returnValue(signed_pdu)
+ return signed_pdu
- @defer.inlineCallbacks
- @log_function
- def get_state_for_room(self, destination, room_id, event_id):
- """Requests all of the room state at a given event from a remote home server.
-
- Args:
- destination (str): The remote homeserver to query for the state.
- room_id (str): The id of the room we're interested in.
- event_id (str): The id of the event we want the state at.
+ async def get_room_state_ids(
+ self, destination: str, room_id: str, event_id: str
+ ) -> Tuple[List[str], List[str]]:
+ """Calls the /state_ids endpoint to fetch the state at a particular point
+ in the room, and the auth events for the given event
Returns:
- Deferred[Tuple[List[EventBase], List[EventBase]]]:
- A list of events in the state, and a list of events in the auth chain
- for the given event.
+ a tuple of (state event_ids, auth event_ids)
"""
- try:
- # First we try and ask for just the IDs, as thats far quicker if
- # we have most of the state and auth_chain already.
- # However, this may 404 if the other side has an old synapse.
- result = yield self.transport_layer.get_room_state_ids(
- destination, room_id, event_id=event_id,
- )
-
- state_event_ids = result["pdu_ids"]
- auth_event_ids = result.get("auth_chain_ids", [])
-
- fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
- destination, room_id, set(state_event_ids + auth_event_ids)
- )
-
- if failed_to_fetch:
- logger.warning(
- "Failed to fetch missing state/auth events for %s: %s",
- room_id,
- failed_to_fetch
- )
-
- event_map = {
- ev.event_id: ev for ev in fetched_events
- }
-
- pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
- auth_chain = [
- event_map[e_id] for e_id in auth_event_ids if e_id in event_map
- ]
-
- auth_chain.sort(key=lambda e: e.depth)
-
- defer.returnValue((pdus, auth_chain))
- except HttpResponseException as e:
- if e.code == 400 or e.code == 404:
- logger.info("Failed to use get_room_state_ids API, falling back")
- else:
- raise e
-
- result = yield self.transport_layer.get_room_state(
- destination, room_id, event_id=event_id,
+ result = await self.transport_layer.get_room_state_ids(
+ destination, room_id, event_id=event_id
)
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
+ state_event_ids = result["pdu_ids"]
+ auth_event_ids = result.get("auth_chain_ids", [])
- pdus = [
- event_from_pdu_json(p, format_ver, outlier=True)
- for p in result["pdus"]
- ]
+ if not isinstance(state_event_ids, list) or not isinstance(
+ auth_event_ids, list
+ ):
+ raise Exception("invalid response from /state_ids")
- auth_chain = [
- event_from_pdu_json(p, format_ver, outlier=True)
- for p in result.get("auth_chain", [])
- ]
+ return state_event_ids, auth_event_ids
- seen_events = yield self.store.get_events([
- ev.event_id for ev in itertools.chain(pdus, auth_chain)
- ])
+ async def _check_sigs_and_hash_and_fetch(
+ self,
+ origin: str,
+ pdus: List[EventBase],
+ room_version: str,
+ outlier: bool = False,
+ include_none: bool = False,
+ ) -> List[EventBase]:
+ """Takes a list of PDUs and checks the signatures and hashs of each
+ one. If a PDU fails its signature check then we check if we have it in
+ the database and if not then request if from the originating server of
+ that PDU.
- signed_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination,
- [p for p in pdus if p.event_id not in seen_events],
- outlier=True,
- room_version=room_version,
- )
- signed_pdus.extend(
- seen_events[p.event_id] for p in pdus if p.event_id in seen_events
- )
-
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination,
- [p for p in auth_chain if p.event_id not in seen_events],
- outlier=True,
- room_version=room_version,
- )
- signed_auth.extend(
- seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events
- )
+ If a PDU fails its content hash check then it is redacted.
- signed_auth.sort(key=lambda e: e.depth)
-
- defer.returnValue((signed_pdus, signed_auth))
-
- @defer.inlineCallbacks
- def get_events_from_store_or_dest(self, destination, room_id, event_ids):
- """Fetch events from a remote destination, checking if we already have them.
+ The given list of PDUs are not modified, instead the function returns
+ a new list.
Args:
- destination (str)
- room_id (str)
- event_ids (list)
+ origin
+ pdu
+ room_version
+ outlier: Whether the events are outliers or not
+ include_none: Whether to include None in the returned list
+ for events that have failed their checks
Returns:
- Deferred: A deferred resolving to a 2-tuple where the first is a list of
- events and the second is a list of event ids that we failed to fetch.
+ Deferred : A list of PDUs that have valid signatures and hashes.
"""
- seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
- signed_events = list(seen_events.values())
-
- failed_to_fetch = set()
-
- missing_events = set(event_ids)
- for k in seen_events:
- missing_events.discard(k)
-
- if not missing_events:
- defer.returnValue((signed_events, failed_to_fetch))
-
- logger.debug(
- "Fetching unknown state/auth events %s for room %s",
- missing_events,
- event_ids,
- )
+ deferreds = self._check_sigs_and_hashes(room_version, pdus)
- room_version = yield self.store.get_room_version(room_id)
+ @defer.inlineCallbacks
+ def handle_check_result(pdu: EventBase, deferred: Deferred):
+ try:
+ res = yield make_deferred_yieldable(deferred)
+ except SynapseError:
+ res = None
+
+ if not res:
+ # Check local db.
+ res = yield self.store.get_event(
+ pdu.event_id, allow_rejected=True, allow_none=True
+ )
- batch_size = 20
- missing_events = list(missing_events)
- for i in range(0, len(missing_events), batch_size):
- batch = set(missing_events[i:i + batch_size])
+ if not res and pdu.origin != origin:
+ try:
+ res = yield defer.ensureDeferred(
+ self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ room_version=room_version, # type: ignore
+ outlier=outlier,
+ timeout=10000,
+ )
+ )
+ except SynapseError:
+ pass
- deferreds = [
- run_in_background(
- self.get_pdu,
- destinations=[destination],
- event_id=e_id,
- room_version=room_version,
+ if not res:
+ logger.warning(
+ "Failed to find copy of %s with valid signature", pdu.event_id
)
- for e_id in batch
- ]
- res = yield make_deferred_yieldable(
- defer.DeferredList(deferreds, consumeErrors=True)
- )
- for success, result in res:
- if success and result:
- signed_events.append(result)
- batch.discard(result.event_id)
+ return res
- # We removed all events we successfully fetched from `batch`
- failed_to_fetch.update(batch)
+ handle = preserve_fn(handle_check_result)
+ deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)]
- defer.returnValue((signed_events, failed_to_fetch))
+ valid_pdus = await make_deferred_yieldable(
+ defer.gatherResults(deferreds2, consumeErrors=True)
+ ).addErrback(unwrapFirstError)
- @defer.inlineCallbacks
- @log_function
- def get_event_auth(self, destination, room_id, event_id):
- res = yield self.transport_layer.get_event_auth(
- destination, room_id, event_id,
- )
+ if include_none:
+ return valid_pdus
+ else:
+ return [p for p in valid_pdus if p]
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
+ async def get_event_auth(self, destination, room_id, event_id):
+ res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
+
+ room_version = await self.store.get_room_version(room_id)
auth_chain = [
- event_from_pdu_json(p, format_ver, outlier=True)
+ event_from_pdu_json(p, room_version, outlier=True)
for p in res["auth_chain"]
]
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain,
- outlier=True, room_version=room_version,
+ signed_auth = await self._check_sigs_and_hash_and_fetch(
+ destination, auth_chain, outlier=True, room_version=room_version.identifier
)
signed_auth.sort(key=lambda e: e.depth)
- defer.returnValue(signed_auth)
+ return signed_auth
- @defer.inlineCallbacks
- def _try_destination_list(self, description, destinations, callback):
+ async def _try_destination_list(
+ self,
+ description: str,
+ destinations: Iterable[str],
+ callback: Callable[[str], Awaitable[T]],
+ ) -> T:
"""Try an operation on a series of servers, until it succeeds
Args:
- description (unicode): description of the operation we're doing, for logging
+ description: description of the operation we're doing, for logging
- destinations (Iterable[unicode]): list of server_names to try
+ destinations: list of server_names to try
- callback (callable): Function to run for each server. Passed a single
- argument: the server_name to try. May return a deferred.
+ callback: Function to run for each server. Passed a single
+ argument: the server_name to try.
If the callback raises a CodeMessageException with a 300/400 code,
attempts to perform the operation stop immediately and the exception is
@@ -512,43 +466,50 @@ class FederationClient(FederationBase):
suppressed if the exception is an InvalidResponseError.
Returns:
- The [Deferred] result of callback, if it succeeds
+ The result of callback, if it succeeds
Raises:
- SynapseError if the chosen remote server returns a 300/400 code.
-
- RuntimeError if no servers were reachable.
+ SynapseError if the chosen remote server returns a 300/400 code, or
+ no servers were reachable.
"""
for destination in destinations:
if destination == self.server_name:
continue
try:
- res = yield callback(destination)
- defer.returnValue(res)
+ res = await callback(destination)
+ return res
except InvalidResponseError as e:
- logger.warn(
- "Failed to %s via %s: %s",
- description, destination, e,
- )
+ logger.warning("Failed to %s via %s: %s", description, destination, e)
+ except UnsupportedRoomVersionError:
+ raise
except HttpResponseException as e:
if not 500 <= e.code < 600:
raise e.to_synapse_error()
else:
- logger.warn(
+ logger.warning(
"Failed to %s via %s: %i %s",
- description, destination, e.code, e.args[0],
+ description,
+ destination,
+ e.code,
+ e.args[0],
)
except Exception:
- logger.warn(
- "Failed to %s via %s",
- description, destination, exc_info=1,
+ logger.warning(
+ "Failed to %s via %s", description, destination, exc_info=True
)
- raise RuntimeError("Failed to %s via any server" % (description, ))
-
- def make_membership_event(self, destinations, room_id, user_id, membership,
- content, params):
+ raise SynapseError(502, "Failed to %s via any server" % (description,))
+
+ async def make_membership_event(
+ self,
+ destinations: Iterable[str],
+ room_id: str,
+ user_id: str,
+ membership: str,
+ content: dict,
+ params: Dict[str, str],
+ ) -> Tuple[str, EventBase, RoomVersion]:
"""
Creates an m.room.member event, with context, without participating in the room.
@@ -560,44 +521,47 @@ class FederationClient(FederationBase):
Note that this does not append any events to any graphs.
Args:
- destinations (str): Candidate homeservers which are probably
+ destinations: Candidate homeservers which are probably
participating in the room.
- room_id (str): The room in which the event will happen.
- user_id (str): The user whose membership is being evented.
- membership (str): The "membership" property of the event. Must be
- one of "join" or "leave".
- content (dict): Any additional data to put into the content field
- of the event.
- params (dict[str, str|Iterable[str]]): Query parameters to include in the
- request.
- Return:
- Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of
- `(origin, event, event_format)` where origin is the remote
- homeserver which generated the event, and event_format is one of
- `synapse.api.room_versions.EventFormatVersions`.
-
- Fails with a ``SynapseError`` if the chosen remote server
- returns a 300/400 code.
-
- Fails with a ``RuntimeError`` if no servers were reachable.
+ room_id: The room in which the event will happen.
+ user_id: The user whose membership is being evented.
+ membership: The "membership" property of the event. Must be one of
+ "join" or "leave".
+ content: Any additional data to put into the content field of the
+ event.
+ params: Query parameters to include in the request.
+
+ Returns:
+ `(origin, event, room_version)` where origin is the remote
+ homeserver which generated the event, and room_version is the
+ version of the room.
+
+ Raises:
+ UnsupportedRoomVersionError: if remote responds with
+ a room version we don't understand.
+
+ SynapseError: if the chosen remote server returns a 300/400 code.
+
+ RuntimeError: if no servers were reachable.
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
- "make_membership_event called with membership='%s', must be one of %s" %
- (membership, ",".join(valid_memberships))
+ "make_membership_event called with membership='%s', must be one of %s"
+ % (membership, ",".join(valid_memberships))
)
- @defer.inlineCallbacks
- def send_request(destination):
- ret = yield self.transport_layer.make_membership_event(
- destination, room_id, user_id, membership, params,
+ async def send_request(destination: str) -> Tuple[str, EventBase, RoomVersion]:
+ ret = await self.transport_layer.make_membership_event(
+ destination, room_id, user_id, membership, params
)
# Note: If not supplied, the room version may be either v1 or v2,
# however either way the event format version will be v1.
- room_version = ret.get("room_version", RoomVersions.V1.identifier)
- event_format = room_version_to_event_format(room_version)
+ room_version_id = ret.get("room_version", RoomVersions.V1.identifier)
+ room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not room_version:
+ raise UnsupportedRoomVersionError()
pdu_dict = ret.get("event", None)
if not isinstance(pdu_dict, dict):
@@ -614,111 +578,93 @@ class FederationClient(FederationBase):
pdu_dict["prev_state"] = []
ev = builder.create_local_event_from_event_dict(
- self._clock, self.hostname, self.signing_key,
- format_version=event_format, event_dict=pdu_dict,
+ self._clock,
+ self.hostname,
+ self.signing_key,
+ room_version=room_version,
+ event_dict=pdu_dict,
)
- defer.returnValue(
- (destination, ev, event_format)
- )
+ return destination, ev, room_version
- return self._try_destination_list(
- "make_" + membership, destinations, send_request,
+ return await self._try_destination_list(
+ "make_" + membership, destinations, send_request
)
- def send_join(self, destinations, pdu, event_format_version):
+ async def send_join(
+ self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
+ ) -> Dict[str, Any]:
"""Sends a join event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph,
and send the event out to the rest of the federation.
Args:
- destinations (str): Candidate homeservers which are probably
+ destinations: Candidate homeservers which are probably
participating in the room.
- pdu (BaseEvent): event to be sent
- event_format_version (int): The event format version
+ pdu: event to be sent
+ room_version: the version of the room (according to the server that
+ did the make_join)
- Return:
- Deferred: resolves to a dict with members ``origin`` (a string
- giving the serer the event was sent to, ``state`` (?) and
+ Returns:
+ a dict with members ``origin`` (a string
+ giving the server the event was sent to, ``state`` (?) and
``auth_chain``.
- Fails with a ``SynapseError`` if the chosen remote server
- returns a 300/400 code.
+ Raises:
+ SynapseError: if the chosen remote server returns a 300/400 code.
- Fails with a ``RuntimeError`` if no servers were reachable.
+ RuntimeError: if no servers were reachable.
"""
- def check_authchain_validity(signed_auth_chain):
- for e in signed_auth_chain:
- if e.type == EventTypes.Create:
- create_event = e
- break
- else:
- raise InvalidResponseError(
- "no %s in auth chain" % (EventTypes.Create,),
- )
-
- # the room version should be sane.
- room_version = create_event.content.get("room_version", "1")
- if room_version not in KNOWN_ROOM_VERSIONS:
- # This shouldn't be possible, because the remote server should have
- # rejected the join attempt during make_join.
- raise InvalidResponseError(
- "room appears to have unsupported version %s" % (
- room_version,
- ))
-
- @defer.inlineCallbacks
- def send_request(destination):
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_join(
- destination=destination,
- room_id=pdu.room_id,
- event_id=pdu.event_id,
- content=pdu.get_pdu_json(time_now),
- )
+ async def send_request(destination) -> Dict[str, Any]:
+ content = await self._do_send_join(destination, pdu)
logger.debug("Got content: %s", content)
state = [
- event_from_pdu_json(p, event_format_version, outlier=True)
+ event_from_pdu_json(p, room_version, outlier=True)
for p in content.get("state", [])
]
auth_chain = [
- event_from_pdu_json(p, event_format_version, outlier=True)
+ event_from_pdu_json(p, room_version, outlier=True)
for p in content.get("auth_chain", [])
]
- pdus = {
- p.event_id: p
- for p in itertools.chain(state, auth_chain)
- }
+ pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
- room_version = None
+ create_event = None
for e in state:
if (e.type, e.state_key) == (EventTypes.Create, ""):
- room_version = e.content.get(
- "room_version", RoomVersions.V1.identifier
- )
+ create_event = e
break
- if room_version is None:
+ if create_event is None:
# If the state doesn't have a create event then the room is
# invalid, and it would fail auth checks anyway.
raise SynapseError(400, "No create event in state")
- valid_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination, list(pdus.values()),
+ # the room version should be sane.
+ create_room_version = create_event.content.get(
+ "room_version", RoomVersions.V1.identifier
+ )
+ if create_room_version != room_version.identifier:
+ # either the server that fulfilled the make_join, or the server that is
+ # handling the send_join, is lying.
+ raise InvalidResponseError(
+ "Unexpected room version %s in create event"
+ % (create_room_version,)
+ )
+
+ valid_pdus = await self._check_sigs_and_hash_and_fetch(
+ destination,
+ list(pdus.values()),
outlier=True,
- room_version=room_version,
+ room_version=room_version.identifier,
)
- valid_pdus_map = {
- p.event_id: p
- for p in valid_pdus
- }
+ valid_pdus_map = {p.event_id: p for p in valid_pdus}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
@@ -739,64 +685,106 @@ class FederationClient(FederationBase):
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
- check_authchain_validity(signed_auth)
+ # double-check that the same create event has ended up in the auth chain
+ auth_chain_create_events = [
+ e.event_id
+ for e in signed_auth
+ if (e.type, e.state_key) == (EventTypes.Create, "")
+ ]
+ if auth_chain_create_events != [create_event.event_id]:
+ raise InvalidResponseError(
+ "Unexpected create event(s) in auth chain: %s"
+ % (auth_chain_create_events,)
+ )
- defer.returnValue({
+ return {
"state": signed_state,
"auth_chain": signed_auth,
"origin": destination,
- })
- return self._try_destination_list("send_join", destinations, send_request)
+ }
- @defer.inlineCallbacks
- def send_invite(self, destination, room_id, event_id, pdu):
- room_version = yield self.store.get_room_version(room_id)
+ return await self._try_destination_list("send_join", destinations, send_request)
+
+ async def _do_send_join(self, destination: str, pdu: EventBase):
+ time_now = self._clock.time_msec()
- content = yield self._do_send_invite(destination, pdu, room_version)
+ try:
+ content = await self.transport_layer.send_join_v2(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ return content
+ except HttpResponseException as e:
+ if e.code in [400, 404]:
+ err = e.to_synapse_error()
+
+ # If we receive an error response that isn't a generic error, or an
+ # unrecognised endpoint error, we assume that the remote understands
+ # the v2 invite API and this is a legitimate error.
+ if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+ raise err
+ else:
+ raise e.to_synapse_error()
+
+ logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
+
+ resp = await self.transport_layer.send_join_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ # We expect the v1 API to respond with [200, content], so we only return the
+ # content.
+ return resp[1]
+
+ async def send_invite(
+ self, destination: str, room_id: str, event_id: str, pdu: EventBase,
+ ) -> EventBase:
+ room_version = await self.store.get_room_version(room_id)
+
+ content = await self._do_send_invite(destination, pdu, room_version)
pdu_dict = content["event"]
logger.debug("Got response to send_invite: %s", pdu_dict)
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
-
- pdu = event_from_pdu_json(pdu_dict, format_ver)
+ pdu = event_from_pdu_json(pdu_dict, room_version)
# Check signatures are correct.
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
+ pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
# FIXME: We should handle signature failures more gracefully.
- defer.returnValue(pdu)
+ return pdu
- @defer.inlineCallbacks
- def _do_send_invite(self, destination, pdu, room_version):
+ async def _do_send_invite(
+ self, destination: str, pdu: EventBase, room_version: RoomVersion
+ ) -> JsonDict:
"""Actually sends the invite, first trying v2 API and falling back to
v1 API if necessary.
- Args:
- destination (str): Target server
- pdu (FrozenEvent)
- room_version (str)
-
Returns:
- dict: The event as a dict as returned by the remote server
+ The event as a dict as returned by the remote server
"""
time_now = self._clock.time_msec()
try:
- content = yield self.transport_layer.send_invite_v2(
+ content = await self.transport_layer.send_invite_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content={
"event": pdu.get_pdu_json(time_now),
- "room_version": room_version,
+ "room_version": room_version.identifier,
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
- defer.returnValue(content)
+ return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -810,8 +798,7 @@ class FederationClient(FederationBase):
# Otherwise, we assume that the remote server doesn't understand
# the v2 invite API. That's ok provided the room uses old-style event
# IDs.
- v = KNOWN_ROOM_VERSIONS.get(room_version)
- if v.event_format != EventFormatVersions.V1:
+ if room_version.event_format != EventFormatVersions.V1:
raise SynapseError(
400,
"User's homeserver does not support this room version",
@@ -825,15 +812,15 @@ class FederationClient(FederationBase):
# Didn't work, try v1 API.
# Note the v1 API returns a tuple of `(200, content)`
- _, content = yield self.transport_layer.send_invite_v1(
+ _, content = await self.transport_layer.send_invite_v1(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
- defer.returnValue(content)
+ return content
- def send_leave(self, destinations, pdu):
+ async def send_leave(self, destinations: Iterable[str], pdu: EventBase) -> None:
"""Sends a leave event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph,
@@ -842,108 +829,109 @@ class FederationClient(FederationBase):
This is mostly useful to reject received invites.
Args:
- destinations (str): Candidate homeservers which are probably
+ destinations: Candidate homeservers which are probably
participating in the room.
- pdu (BaseEvent): event to be sent
-
- Return:
- Deferred: resolves to None.
+ pdu: event to be sent
- Fails with a ``SynapseError`` if the chosen remote server
- returns a 300/400 code.
+ Raises:
+ SynapseError if the chosen remote server returns a 300/400 code.
- Fails with a ``RuntimeError`` if no servers were reachable.
+ RuntimeError if no servers were reachable.
"""
- @defer.inlineCallbacks
- def send_request(destination):
- time_now = self._clock.time_msec()
- _, content = yield self.transport_layer.send_leave(
+
+ async def send_request(destination: str) -> None:
+ content = await self._do_send_leave(destination, pdu)
+ logger.debug("Got content: %s", content)
+
+ return await self._try_destination_list(
+ "send_leave", destinations, send_request
+ )
+
+ async def _do_send_leave(self, destination, pdu):
+ time_now = self._clock.time_msec()
+
+ try:
+ content = await self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
- logger.debug("Got content: %s", content)
- defer.returnValue(None)
+ return content
+ except HttpResponseException as e:
+ if e.code in [400, 404]:
+ err = e.to_synapse_error()
- return self._try_destination_list("send_leave", destinations, send_request)
+ # If we receive an error response that isn't a generic error, or an
+ # unrecognised endpoint error, we assume that the remote understands
+ # the v2 invite API and this is a legitimate error.
+ if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
+ raise err
+ else:
+ raise e.to_synapse_error()
+
+ logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
- def get_public_rooms(self, destination, limit=None, since_token=None,
- search_filter=None, include_all_networks=False,
- third_party_instance_id=None):
+ resp = await self.transport_layer.send_leave_v1(
+ destination=destination,
+ room_id=pdu.room_id,
+ event_id=pdu.event_id,
+ content=pdu.get_pdu_json(time_now),
+ )
+
+ # We expect the v1 API to respond with [200, content], so we only return the
+ # content.
+ return resp[1]
+
+ def get_public_rooms(
+ self,
+ destination,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ):
if destination == self.server_name:
return
return self.transport_layer.get_public_rooms(
- destination, limit, since_token, search_filter,
+ destination,
+ limit,
+ since_token,
+ search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
- @defer.inlineCallbacks
- def query_auth(self, destination, room_id, event_id, local_auth):
- """
- Params:
- destination (str)
- event_it (str)
- local_auth (list)
- """
- time_now = self._clock.time_msec()
-
- send_content = {
- "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
- }
-
- code, content = yield self.transport_layer.send_query_auth(
- destination=destination,
- room_id=room_id,
- event_id=event_id,
- content=send_content,
- )
-
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
-
- auth_chain = [
- event_from_pdu_json(e, format_ver)
- for e in content["auth_chain"]
- ]
-
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True, room_version=room_version,
- )
-
- signed_auth.sort(key=lambda e: e.depth)
-
- ret = {
- "auth_chain": signed_auth,
- "rejects": content.get("rejects", []),
- "missing": content.get("missing", []),
- }
-
- defer.returnValue(ret)
-
- @defer.inlineCallbacks
- def get_missing_events(self, destination, room_id, earliest_events_ids,
- latest_events, limit, min_depth, timeout):
+ async def get_missing_events(
+ self,
+ destination: str,
+ room_id: str,
+ earliest_events_ids: Sequence[str],
+ latest_events: Iterable[EventBase],
+ limit: int,
+ min_depth: int,
+ timeout: int,
+ ) -> List[EventBase]:
"""Tries to fetch events we are missing. This is called when we receive
an event without having received all of its ancestors.
Args:
- destination (str)
- room_id (str)
- earliest_events_ids (list): List of event ids. Effectively the
+ destination
+ room_id
+ earliest_events_ids: List of event ids. Effectively the
events we expected to receive, but haven't. `get_missing_events`
should only return events that didn't happen before these.
- latest_events (list): List of events we have received that we don't
+ latest_events: List of events we have received that we don't
have all previous events for.
- limit (int): Maximum number of events to return.
- min_depth (int): Minimum depth of events tor return.
- timeout (int): Max time to wait in ms
+ limit: Maximum number of events to return.
+ min_depth: Minimum depth of events to return.
+ timeout: Max time to wait in ms
"""
try:
- content = yield self.transport_layer.get_missing_events(
+ content = await self.transport_layer.get_missing_events(
destination=destination,
room_id=room_id,
earliest_events=earliest_events_ids,
@@ -953,16 +941,14 @@ class FederationClient(FederationBase):
timeout=timeout,
)
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
+ room_version = await self.store.get_room_version(room_id)
events = [
- event_from_pdu_json(e, format_ver)
- for e in content.get("events", [])
+ event_from_pdu_json(e, room_version) for e in content.get("events", [])
]
- signed_events = yield self._check_sigs_and_hash_and_fetch(
- destination, events, outlier=False, room_version=room_version,
+ signed_events = await self._check_sigs_and_hash_and_fetch(
+ destination, events, outlier=False, room_version=room_version.identifier
)
except HttpResponseException as e:
if not e.code == 400:
@@ -972,7 +958,7 @@ class FederationClient(FederationBase):
# get_missing_events
signed_events = []
- defer.returnValue(signed_events)
+ return signed_events
@defer.inlineCallbacks
def forward_third_party_invite(self, destinations, room_id, event_dict):
@@ -982,17 +968,50 @@ class FederationClient(FederationBase):
try:
yield self.transport_layer.exchange_third_party_invite(
- destination=destination,
- room_id=room_id,
- event_dict=event_dict,
+ destination=destination, room_id=room_id, event_dict=event_dict
)
- defer.returnValue(None)
+ return None
except CodeMessageException:
raise
except Exception as e:
logger.exception(
- "Failed to send_third_party_invite via %s: %s",
- destination, str(e)
+ "Failed to send_third_party_invite via %s: %s", destination, str(e)
)
raise RuntimeError("Failed to send to any server.")
+
+ @defer.inlineCallbacks
+ def get_room_complexity(self, destination, room_id):
+ """
+ Fetch the complexity of a remote room from another server.
+
+ Args:
+ destination (str): The remote server
+ room_id (str): The room ID to ask about.
+
+ Returns:
+ Deferred[dict] or Deferred[None]: Dict contains the complexity
+ metric versions, while None means we could not fetch the complexity.
+ """
+ try:
+ complexity = yield self.transport_layer.get_room_complexity(
+ destination=destination, room_id=room_id
+ )
+ defer.returnValue(complexity)
+ except CodeMessageException as e:
+ # We didn't manage to get it -- probably a 404. We are okay if other
+ # servers don't give it to us.
+ logger.debug(
+ "Failed to fetch room complexity via %s for %s, got a %d",
+ destination,
+ room_id,
+ e.code,
+ )
+ except Exception:
+ logger.exception(
+ "Failed to fetch room complexity via %s for %s", destination, room_id
+ )
+
+ # If we don't manage to find it, return None. It's not an error if a
+ # server doesn't give it to us.
+ defer.returnValue(None)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 4c28c1dc3c..275b9c99d7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Dict
import six
from six import iteritems
@@ -36,22 +38,25 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
-from synapse.crypto.event_signing import compute_event_signature
-from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.http.endpoint import parse_server_name
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ nested_logging_context,
+ run_in_background,
+)
+from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace
+from synapse.logging.utils import log_function
from synapse.replication.http.federation import (
ReplicationFederationSendEduRestServlet,
ReplicationGetQueryRestServlet,
)
-from synapse.types import get_domain_from_id
-from synapse.util import glob_to_regex
+from synapse.types import JsonDict, get_domain_from_id
+from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.logcontext import nested_logging_context
-from synapse.util.logutils import log_function
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
@@ -69,12 +74,14 @@ received_queries_counter = Counter(
class FederationServer(FederationBase):
-
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
+ self.state = hs.get_state_handler()
+
+ self.device_handler = hs.get_device_handler()
self._server_linearizer = Linearizer("fed_server")
self._transaction_linearizer = Linearizer("fed_txn_handler")
@@ -87,24 +94,20 @@ class FederationServer(FederationBase):
# come in waves.
self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
- @defer.inlineCallbacks
- @log_function
- def on_backfill_request(self, origin, room_id, versions, limit):
- with (yield self._server_linearizer.queue((origin, room_id))):
+ async def on_backfill_request(self, origin, room_id, versions, limit):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- pdus = yield self.handler.on_backfill_request(
+ pdus = await self.handler.on_backfill_request(
origin, room_id, versions, limit
)
res = self._transaction_from_pdus(pdus).get_dict()
- defer.returnValue((200, res))
+ return 200, res
- @defer.inlineCallbacks
- @log_function
- def on_incoming_transaction(self, origin, transaction_data):
+ async def on_incoming_transaction(self, origin, transaction_data):
# keep this as early as possible to make the calculated origin ts as
# accurate as possible.
request_time = self._clock.time_msec()
@@ -118,17 +121,18 @@ class FederationServer(FederationBase):
# use a linearizer to ensure that we don't process the same transaction
# multiple times in parallel.
- with (yield self._transaction_linearizer.queue(
- (origin, transaction.transaction_id),
- )):
- result = yield self._handle_incoming_transaction(
- origin, transaction, request_time,
+ with (
+ await self._transaction_linearizer.queue(
+ (origin, transaction.transaction_id)
+ )
+ ):
+ result = await self._handle_incoming_transaction(
+ origin, transaction, request_time
)
- defer.returnValue(result)
+ return result
- @defer.inlineCallbacks
- def _handle_incoming_transaction(self, origin, transaction, request_time):
+ async def _handle_incoming_transaction(self, origin, transaction, request_time):
""" Process an incoming transaction and return the HTTP response
Args:
@@ -139,33 +143,66 @@ class FederationServer(FederationBase):
Returns:
Deferred[(int, object)]: http response code and body
"""
- response = yield self.transaction_actions.have_responded(origin, transaction)
+ response = await self.transaction_actions.have_responded(origin, transaction)
if response:
logger.debug(
"[%s] We've already responded to this request",
- transaction.transaction_id
+ transaction.transaction_id,
)
- defer.returnValue(response)
- return
+ return response
logger.debug("[%s] Transaction is new", transaction.transaction_id)
- # Reject if PDU count > 50 and EDU count > 100
- if (len(transaction.pdus) > 50
- or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
+ # Reject if PDU count > 50 or EDU count > 100
+ if len(transaction.pdus) > 50 or (
+ hasattr(transaction, "edus") and len(transaction.edus) > 100
+ ):
- logger.info(
- "Transaction PDU or EDU count too large. Returning 400",
- )
+ logger.info("Transaction PDU or EDU count too large. Returning 400")
response = {}
- yield self.transaction_actions.set_response(
- origin,
- transaction,
- 400, response
+ await self.transaction_actions.set_response(
+ origin, transaction, 400, response
)
- defer.returnValue((400, response))
+ return 400, response
+
+ # We process PDUs and EDUs in parallel. This is important as we don't
+ # want to block things like to device messages from reaching clients
+ # behind the potentially expensive handling of PDUs.
+ pdu_results, _ = await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self._handle_pdus_in_txn, origin, transaction, request_time
+ ),
+ run_in_background(self._handle_edus_in_txn, origin, transaction),
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ response = {"pdus": pdu_results}
+
+ logger.debug("Returning: %s", str(response))
+
+ await self.transaction_actions.set_response(origin, transaction, 200, response)
+ return 200, response
+
+ async def _handle_pdus_in_txn(
+ self, origin: str, transaction: Transaction, request_time: int
+ ) -> Dict[str, dict]:
+ """Process the PDUs in a received transaction.
+
+ Args:
+ origin: the server making the request
+ transaction: incoming transaction
+ request_time: timestamp that the HTTP request arrived at
+
+ Returns:
+ A map from event ID of a processed PDU to any errors we should
+ report back to the sending server.
+ """
received_pdus_counter.inc(len(transaction.pdus))
@@ -198,24 +235,17 @@ class FederationServer(FederationBase):
continue
try:
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version(room_id)
except NotFoundError:
logger.info("Ignoring PDU for unknown room_id: %s", room_id)
continue
-
- try:
- format_ver = room_version_to_event_format(room_version)
- except UnsupportedRoomVersionError:
+ except UnsupportedRoomVersionError as e:
# this can happen if support for a given room version is withdrawn,
# so that we still get events for said room.
- logger.info(
- "Ignoring PDU for room %s with unknown version %s",
- room_id,
- room_version,
- )
+ logger.info("Ignoring PDU: %s", e)
continue
- event = event_from_pdu_json(p, format_ver)
+ event = event_from_pdu_json(p, room_version)
pdus_by_room.setdefault(room_id, []).append(event)
pdu_results = {}
@@ -224,15 +254,12 @@ class FederationServer(FederationBase):
# require callouts to other servers to fetch missing events), but
# impose a limit to avoid going too crazy with ram/cpu.
- @defer.inlineCallbacks
- def process_pdus_for_room(room_id):
+ async def process_pdus_for_room(room_id):
logger.debug("Processing PDUs for %s", room_id)
try:
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
except AuthError as e:
- logger.warn(
- "Ignoring PDUs for room %s from banned server", room_id,
- )
+ logger.warning("Ignoring PDUs for room %s from banned server", room_id)
for pdu in pdus_by_room[room_id]:
event_id = pdu.event_id
pdu_results[event_id] = e.error_dict()
@@ -242,12 +269,10 @@ class FederationServer(FederationBase):
event_id = pdu.event_id
with nested_logging_context(event_id):
try:
- yield self._handle_received_pdu(
- origin, pdu
- )
+ await self._handle_received_pdu(origin, pdu)
pdu_results[event_id] = {}
except FederationError as e:
- logger.warn("Error handling PDU %s: %s", event_id, e)
+ logger.warning("Error handling PDU %s: %s", event_id, e)
pdu_results[event_id] = {"error": str(e)}
except Exception as e:
f = failure.Failure()
@@ -258,47 +283,38 @@ class FederationServer(FederationBase):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
- yield concurrently_execute(
- process_pdus_for_room, pdus_by_room.keys(),
- TRANSACTION_CONCURRENCY_LIMIT,
+ await concurrently_execute(
+ process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT
)
- if hasattr(transaction, "edus"):
- for edu in (Edu(**x) for x in transaction.edus):
- yield self.received_edu(
- origin,
- edu.edu_type,
- edu.content
- )
-
- response = {
- "pdus": pdu_results,
- }
+ return pdu_results
- logger.debug("Returning: %s", str(response))
+ async def _handle_edus_in_txn(self, origin: str, transaction: Transaction):
+ """Process the EDUs in a received transaction.
+ """
- yield self.transaction_actions.set_response(
- origin,
- transaction,
- 200, response
- )
- defer.returnValue((200, response))
+ async def _process_edu(edu_dict):
+ received_edus_counter.inc()
- @defer.inlineCallbacks
- def received_edu(self, origin, edu_type, content):
- received_edus_counter.inc()
- yield self.registry.on_edu(edu_type, origin, content)
+ edu = Edu(
+ origin=origin,
+ destination=self.server_name,
+ edu_type=edu_dict["edu_type"],
+ content=edu_dict["content"],
+ )
+ await self.registry.on_edu(edu.edu_type, origin, edu.content)
- @defer.inlineCallbacks
- @log_function
- def on_context_state_request(self, origin, room_id, event_id):
- if not event_id:
- raise NotImplementedError("Specify an event")
+ await concurrently_execute(
+ _process_edu,
+ getattr(transaction, "edus", []),
+ TRANSACTION_CONCURRENCY_LIMIT,
+ )
+ async def on_context_state_request(self, origin, room_id, event_id):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- in_room = yield self.auth.check_host_in_room(room_id, origin)
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -307,254 +323,170 @@ class FederationServer(FederationBase):
# in the cache so we could return it without waiting for the linearizer
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
- with (yield self._server_linearizer.queue((origin, room_id))):
- resp = yield self._state_resp_cache.wrap(
- (room_id, event_id),
- self._on_context_state_request_compute,
- room_id, event_id,
+ with (await self._server_linearizer.queue((origin, room_id))):
+ resp = dict(
+ await self._state_resp_cache.wrap(
+ (room_id, event_id),
+ self._on_context_state_request_compute,
+ room_id,
+ event_id,
+ )
)
- defer.returnValue((200, resp))
+ room_version = await self.store.get_room_version_id(room_id)
+ resp["room_version"] = room_version
- @defer.inlineCallbacks
- def on_state_ids_request(self, origin, room_id, event_id):
+ return 200, resp
+
+ async def on_state_ids_request(self, origin, room_id, event_id):
if not event_id:
raise NotImplementedError("Specify an event")
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- in_room = yield self.auth.check_host_in_room(room_id, origin)
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- state_ids = yield self.handler.get_state_ids_for_pdu(
- room_id, event_id,
- )
- auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids)
+ state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
+ auth_chain_ids = await self.store.get_auth_chain_ids(state_ids)
- defer.returnValue((200, {
- "pdu_ids": state_ids,
- "auth_chain_ids": auth_chain_ids,
- }))
+ return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
- @defer.inlineCallbacks
- def _on_context_state_request_compute(self, room_id, event_id):
- pdus = yield self.handler.get_state_for_pdu(
- room_id, event_id,
- )
- auth_chain = yield self.store.get_auth_chain(
- [pdu.event_id for pdu in pdus]
- )
+ async def _on_context_state_request_compute(self, room_id, event_id):
+ if event_id:
+ pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ else:
+ pdus = (await self.state.get_current_state(room_id)).values()
- for event in auth_chain:
- # We sign these again because there was a bug where we
- # incorrectly signed things the first time round
- if self.hs.is_mine_id(event.event_id):
- event.signatures.update(
- compute_event_signature(
- event.get_pdu_json(),
- self.hs.hostname,
- self.hs.config.signing_key[0]
- )
- )
+ auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
- defer.returnValue({
+ return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],
- })
+ }
- @defer.inlineCallbacks
- @log_function
- def on_pdu_request(self, origin, event_id):
- pdu = yield self.handler.get_persisted_pdu(origin, event_id)
+ async def on_pdu_request(self, origin, event_id):
+ pdu = await self.handler.get_persisted_pdu(origin, event_id)
if pdu:
- defer.returnValue(
- (200, self._transaction_from_pdus([pdu]).get_dict())
- )
+ return 200, self._transaction_from_pdus([pdu]).get_dict()
else:
- defer.returnValue((404, ""))
+ return 404, ""
- @defer.inlineCallbacks
- def on_query_request(self, query_type, args):
+ async def on_query_request(self, query_type, args):
received_queries_counter.labels(query_type).inc()
- resp = yield self.registry.on_query(query_type, args)
- defer.returnValue((200, resp))
+ resp = await self.registry.on_query(query_type, args)
+ return 200, resp
- @defer.inlineCallbacks
- def on_make_join_request(self, origin, room_id, user_id, supported_versions):
+ async def on_make_join_request(self, origin, room_id, user_id, supported_versions):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version_id(room_id)
if room_version not in supported_versions:
- logger.warn("Room version %s not in %s", room_version, supported_versions)
+ logger.warning(
+ "Room version %s not in %s", room_version, supported_versions
+ )
raise IncompatibleRoomVersionError(room_version=room_version)
- pdu = yield self.handler.on_make_join_request(room_id, user_id)
+ pdu = await self.handler.on_make_join_request(origin, room_id, user_id)
time_now = self._clock.time_msec()
- defer.returnValue({
- "event": pdu.get_pdu_json(time_now),
- "room_version": room_version,
- })
-
- @defer.inlineCallbacks
- def on_invite_request(self, origin, content, room_version):
- if room_version not in KNOWN_ROOM_VERSIONS:
+ return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
+
+ async def on_invite_request(
+ self, origin: str, content: JsonDict, room_version_id: str
+ ):
+ room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not room_version:
raise SynapseError(
400,
"Homeserver does not support this room version",
Codes.UNSUPPORTED_ROOM_VERSION,
)
- format_ver = room_version_to_event_format(room_version)
-
- pdu = event_from_pdu_json(content, format_ver)
+ pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
- ret_pdu = yield self.handler.on_invite_request(origin, pdu)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
+ pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+ ret_pdu = await self.handler.on_invite_request(origin, pdu, room_version)
time_now = self._clock.time_msec()
- defer.returnValue({"event": ret_pdu.get_pdu_json(time_now)})
+ return {"event": ret_pdu.get_pdu_json(time_now)}
- @defer.inlineCallbacks
- def on_send_join_request(self, origin, content, room_id):
+ async def on_send_join_request(self, origin, content, room_id):
logger.debug("on_send_join_request: content: %s", content)
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
- pdu = event_from_pdu_json(content, format_ver)
+ room_version = await self.store.get_room_version(room_id)
+ pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
- res_pdus = yield self.handler.on_send_join_request(origin, pdu)
+
+ pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
+
+ res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
- defer.returnValue((200, {
+ return {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
- "auth_chain": [
- p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
- ],
- }))
+ "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
+ }
- @defer.inlineCallbacks
- def on_make_leave_request(self, origin, room_id, user_id):
+ async def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
- pdu = yield self.handler.on_make_leave_request(room_id, user_id)
+ await self.check_server_matches_acl(origin_host, room_id)
+ pdu = await self.handler.on_make_leave_request(origin, room_id, user_id)
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version_id(room_id)
time_now = self._clock.time_msec()
- defer.returnValue({
- "event": pdu.get_pdu_json(time_now),
- "room_version": room_version,
- })
+ return {"event": pdu.get_pdu_json(time_now), "room_version": room_version}
- @defer.inlineCallbacks
- def on_send_leave_request(self, origin, content, room_id):
+ async def on_send_leave_request(self, origin, content, room_id):
logger.debug("on_send_leave_request: content: %s", content)
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
- pdu = event_from_pdu_json(content, format_ver)
+ room_version = await self.store.get_room_version(room_id)
+ pdu = event_from_pdu_json(content, room_version)
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, pdu.room_id)
+ await self.check_server_matches_acl(origin_host, pdu.room_id)
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
- yield self.handler.on_send_leave_request(origin, pdu)
- defer.returnValue((200, {}))
-
- @defer.inlineCallbacks
- def on_event_auth(self, origin, room_id, event_id):
- with (yield self._server_linearizer.queue((origin, room_id))):
- origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
- time_now = self._clock.time_msec()
- auth_pdus = yield self.handler.on_event_auth(event_id)
- res = {
- "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
- }
- defer.returnValue((200, res))
-
- @defer.inlineCallbacks
- def on_query_auth_request(self, origin, content, room_id, event_id):
- """
- Content is a dict with keys::
- auth_chain (list): A list of events that give the auth chain.
- missing (list): A list of event_ids indicating what the other
- side (`origin`) think we're missing.
- rejects (dict): A mapping from event_id to a 2-tuple of reason
- string and a proof (or None) of why the event was rejected.
- The keys of this dict give the list of events the `origin` has
- rejected.
+ pdu = await self._check_sigs_and_hash(room_version.identifier, pdu)
- Args:
- origin (str)
- content (dict)
- event_id (str)
+ await self.handler.on_send_leave_request(origin, pdu)
+ return {}
- Returns:
- Deferred: Results in `dict` with the same format as `content`
- """
- with (yield self._server_linearizer.queue((origin, room_id))):
+ async def on_event_auth(self, origin, room_id, event_id):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
-
- room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
-
- auth_chain = [
- event_from_pdu_json(e, format_ver)
- for e in content["auth_chain"]
- ]
-
- signed_auth = yield self._check_sigs_and_hash_and_fetch(
- origin, auth_chain, outlier=True, room_version=room_version,
- )
-
- ret = yield self.handler.on_query_auth(
- origin,
- event_id,
- room_id,
- signed_auth,
- content.get("rejects", []),
- content.get("missing", []),
- )
+ await self.check_server_matches_acl(origin_host, room_id)
time_now = self._clock.time_msec()
- send_content = {
- "auth_chain": [
- e.get_pdu_json(time_now)
- for e in ret["auth_chain"]
- ],
- "rejects": ret.get("rejects", []),
- "missing": ret.get("missing", []),
- }
-
- defer.returnValue(
- (200, send_content)
- )
+ auth_pdus = await self.handler.on_event_auth(event_id)
+ res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]}
+ return 200, res
@log_function
def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content)
- def on_query_user_devices(self, origin, user_id):
- return self.on_query_request("user_devices", user_id)
+ async def on_query_user_devices(self, origin: str, user_id: str):
+ keys = await self.device_handler.on_federation_query_user_devices(user_id)
+ return 200, keys
- @defer.inlineCallbacks
- @log_function
- def on_claim_client_keys(self, origin, content):
+ @trace
+ async def on_claim_client_keys(self, origin, content):
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
- results = yield self.store.claim_e2e_one_time_keys(query)
+ log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
+ results = await self.store.claim_e2e_one_time_keys(query)
json_result = {}
for user_id, device_keys in results.items():
@@ -566,46 +498,47 @@ class FederationServer(FederationBase):
logger.info(
"Claimed one-time-keys: %s",
- ",".join((
- "%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in iteritems(json_result)
- for device_id, device_keys in iteritems(user_keys)
- for key_id, _ in iteritems(device_keys)
- )),
+ ",".join(
+ (
+ "%s for %s:%s" % (key_id, user_id, device_id)
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
+ )
+ ),
)
- defer.returnValue({"one_time_keys": json_result})
+ return {"one_time_keys": json_result}
- @defer.inlineCallbacks
- @log_function
- def on_get_missing_events(self, origin, room_id, earliest_events,
- latest_events, limit):
- with (yield self._server_linearizer.queue((origin, room_id))):
+ async def on_get_missing_events(
+ self, origin, room_id, earliest_events, latest_events, limit
+ ):
+ with (await self._server_linearizer.queue((origin, room_id))):
origin_host, _ = parse_server_name(origin)
- yield self.check_server_matches_acl(origin_host, room_id)
+ await self.check_server_matches_acl(origin_host, room_id)
- logger.info(
+ logger.debug(
"on_get_missing_events: earliest_events: %r, latest_events: %r,"
" limit: %d",
- earliest_events, latest_events, limit,
+ earliest_events,
+ latest_events,
+ limit,
)
- missing_events = yield self.handler.on_get_missing_events(
- origin, room_id, earliest_events, latest_events, limit,
+ missing_events = await self.handler.on_get_missing_events(
+ origin, room_id, earliest_events, latest_events, limit
)
if len(missing_events) < 5:
- logger.info(
+ logger.debug(
"Returning %d events: %r", len(missing_events), missing_events
)
else:
- logger.info("Returning %d events", len(missing_events))
+ logger.debug("Returning %d events", len(missing_events))
time_now = self._clock.time_msec()
- defer.returnValue({
- "events": [ev.get_pdu_json(time_now) for ev in missing_events],
- })
+ return {"events": [ev.get_pdu_json(time_now) for ev in missing_events]}
@log_function
def on_openid_userinfo(self, token):
@@ -625,8 +558,7 @@ class FederationServer(FederationBase):
destination=None,
)
- @defer.inlineCallbacks
- def _handle_received_pdu(self, origin, pdu):
+ async def _handle_received_pdu(self, origin, pdu):
""" Process a PDU received in a federation /send/ transaction.
If the event is invalid, then this method throws a FederationError.
@@ -666,69 +598,47 @@ class FederationServer(FederationBase):
# origin. See bug #1893. This is also true for some third party
# invites).
if not (
- pdu.type == 'm.room.member' and
- pdu.content and
- pdu.content.get("membership", None) in (
- Membership.JOIN, Membership.INVITE,
- )
+ pdu.type == "m.room.member"
+ and pdu.content
+ and pdu.content.get("membership", None)
+ in (Membership.JOIN, Membership.INVITE)
):
logger.info(
- "Discarding PDU %s from invalid origin %s",
- pdu.event_id, origin
+ "Discarding PDU %s from invalid origin %s", pdu.event_id, origin
)
return
else:
- logger.info(
- "Accepting join PDU %s from %s",
- pdu.event_id, origin
- )
+ logger.info("Accepting join PDU %s from %s", pdu.event_id, origin)
# We've already checked that we know the room version by this point
- room_version = yield self.store.get_room_version(pdu.room_id)
+ room_version = await self.store.get_room_version_id(pdu.room_id)
# Check signature.
try:
- pdu = yield self._check_sigs_and_hash(room_version, pdu)
+ pdu = await self._check_sigs_and_hash(room_version, pdu)
except SynapseError as e:
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=pdu.event_id,
- )
+ raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id)
- yield self.handler.on_receive_pdu(
- origin, pdu, sent_to_us_directly=True,
- )
+ await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
def __str__(self):
return "<ReplicationLayer(%s)>" % self.server_name
- @defer.inlineCallbacks
- def exchange_third_party_invite(
- self,
- sender_user_id,
- target_user_id,
- room_id,
- signed,
+ async def exchange_third_party_invite(
+ self, sender_user_id, target_user_id, room_id, signed
):
- ret = yield self.handler.exchange_third_party_invite(
- sender_user_id,
- target_user_id,
- room_id,
- signed,
+ ret = await self.handler.exchange_third_party_invite(
+ sender_user_id, target_user_id, room_id, signed
)
- defer.returnValue(ret)
+ return ret
- @defer.inlineCallbacks
- def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
- ret = yield self.handler.on_exchange_third_party_invite_request(
- origin, room_id, event_dict
+ async def on_exchange_third_party_invite_request(self, room_id, event_dict):
+ ret = await self.handler.on_exchange_third_party_invite_request(
+ room_id, event_dict
)
- defer.returnValue(ret)
+ return ret
- @defer.inlineCallbacks
- def check_server_matches_acl(self, server_name, room_id):
+ async def check_server_matches_acl(self, server_name, room_id):
"""Check if the given server is allowed by the server ACLs in the room
Args:
@@ -738,13 +648,13 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
- state_ids = yield self.store.get_current_state_ids(room_id)
+ state_ids = await self.store.get_current_state_ids(room_id)
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
if not acl_event_id:
return
- acl_event = yield self.store.get_event(acl_event_id)
+ acl_event = await self.store.get_event(acl_event_id)
if server_matches_acl_event(server_name, acl_event):
return
@@ -767,11 +677,11 @@ def server_matches_acl_event(server_name, acl_event):
# server name is a literal IP
allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
if not isinstance(allow_ip_literals, bool):
- logger.warn("Ignorning non-bool allow_ip_literals flag")
+ logger.warning("Ignorning non-bool allow_ip_literals flag")
allow_ip_literals = True
if not allow_ip_literals:
# check for ipv6 literals. These start with '['.
- if server_name[0] == '[':
+ if server_name[0] == "[":
return False
# check for ipv4 literals. We can just lift the routine from twisted.
@@ -781,7 +691,7 @@ def server_matches_acl_event(server_name, acl_event):
# next, check the deny list
deny = acl_event.content.get("deny", [])
if not isinstance(deny, (list, tuple)):
- logger.warn("Ignorning non-list deny ACL %s", deny)
+ logger.warning("Ignorning non-list deny ACL %s", deny)
deny = []
for e in deny:
if _acl_entry_matches(server_name, e):
@@ -791,7 +701,7 @@ def server_matches_acl_event(server_name, acl_event):
# then the allow list.
allow = acl_event.content.get("allow", [])
if not isinstance(allow, (list, tuple)):
- logger.warn("Ignorning non-list allow ACL %s", allow)
+ logger.warning("Ignorning non-list allow ACL %s", allow)
allow = []
for e in allow:
if _acl_entry_matches(server_name, e):
@@ -805,7 +715,9 @@ def server_matches_acl_event(server_name, acl_event):
def _acl_entry_matches(server_name, acl_entry):
if not isinstance(acl_entry, six.string_types):
- logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
+ logger.warning(
+ "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)
+ )
return False
regex = glob_to_regex(acl_entry)
return regex.match(server_name)
@@ -815,6 +727,7 @@ class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
"""
+
def __init__(self):
self.edu_handlers = {}
self.query_handlers = {}
@@ -848,31 +761,29 @@ class FederationHandlerRegistry(object):
on and the result used as the response to the query request.
"""
if query_type in self.query_handlers:
- raise KeyError(
- "Already have a Query handler for %s" % (query_type,)
- )
+ raise KeyError("Already have a Query handler for %s" % (query_type,))
logger.info("Registering federation query handler for %r", query_type)
self.query_handlers[query_type] = handler
- @defer.inlineCallbacks
- def on_edu(self, edu_type, origin, content):
+ async def on_edu(self, edu_type, origin, content):
handler = self.edu_handlers.get(edu_type)
if not handler:
- logger.warn("No handler registered for EDU type %s", edu_type)
+ logger.warning("No handler registered for EDU type %s", edu_type)
- try:
- yield handler(origin, content)
- except SynapseError as e:
- logger.info("Failed to handle edu %r: %r", edu_type, e)
- except Exception:
- logger.exception("Failed to handle edu %r", edu_type)
+ with start_active_span_from_edu(content, "handle_edu"):
+ try:
+ await handler(origin, content)
+ except SynapseError as e:
+ logger.info("Failed to handle edu %r: %r", edu_type, e)
+ except Exception:
+ logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args):
handler = self.query_handlers.get(query_type)
if not handler:
- logger.warn("No handler registered for query type %s", query_type)
+ logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)
@@ -896,7 +807,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
super(ReplicationFederationHandlerRegistry, self).__init__()
- def on_edu(self, edu_type, origin, content):
+ async def on_edu(self, edu_type, origin, content):
"""Overrides FederationHandlerRegistry
"""
if not self.config.use_presence and edu_type == "m.presence":
@@ -904,24 +815,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
handler = self.edu_handlers.get(edu_type)
if handler:
- return super(ReplicationFederationHandlerRegistry, self).on_edu(
- edu_type, origin, content,
+ return await super(ReplicationFederationHandlerRegistry, self).on_edu(
+ edu_type, origin, content
)
- return self._send_edu(
- edu_type=edu_type,
- origin=origin,
- content=content,
- )
+ return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
- def on_query(self, query_type, args):
+ async def on_query(self, query_type, args):
"""Overrides FederationHandlerRegistry
"""
handler = self.query_handlers.get(query_type)
if handler:
- return handler(args)
+ return await handler(args)
- return self._get_query_client(
- query_type=query_type,
- args=args,
- )
+ return await self._get_query_client(query_type=query_type, args=args)
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 74ffd13b4f..d68b4bd670 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -21,9 +21,7 @@ These actions are mostly only used by the :py:mod:`.replication` module.
import logging
-from twisted.internet import defer
-
-from synapse.util.logutils import log_function
+from synapse.logging.utils import log_function
logger = logging.getLogger(__name__)
@@ -46,12 +44,9 @@ class TransactionActions(object):
response code and response body.
"""
if not transaction.transaction_id:
- raise RuntimeError("Cannot persist a transaction with no "
- "transaction_id")
+ raise RuntimeError("Cannot persist a transaction with no transaction_id")
- return self.store.get_received_txn_response(
- transaction.transaction_id, origin
- )
+ return self.store.get_received_txn_response(transaction.transaction_id, origin)
@log_function
def set_response(self, origin, transaction, code, response):
@@ -61,42 +56,8 @@ class TransactionActions(object):
Deferred
"""
if not transaction.transaction_id:
- raise RuntimeError("Cannot persist a transaction with no "
- "transaction_id")
+ raise RuntimeError("Cannot persist a transaction with no transaction_id")
return self.store.set_received_txn_response(
- transaction.transaction_id,
- origin,
- code,
- response,
- )
-
- @defer.inlineCallbacks
- @log_function
- def prepare_to_send(self, transaction):
- """ Persists the `Transaction` we are about to send and works out the
- correct value for the `prev_ids` key.
-
- Returns:
- Deferred
- """
- transaction.prev_ids = yield self.store.prep_send_transaction(
- transaction.transaction_id,
- transaction.destination,
- transaction.origin_server_ts,
- )
-
- @log_function
- def delivered(self, transaction, response_code, response_dict):
- """ Marks the given `Transaction` as having been successfully
- delivered to the remote homeserver, and what the response was.
-
- Returns:
- Deferred
- """
- return self.store.delivered_txn(
- transaction.transaction_id,
- transaction.destination,
- response_code,
- response_dict,
+ transaction.transaction_id, origin, code, response
)
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 0240b339b0..876fb0e245 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -36,6 +36,8 @@ from six import iteritems
from sortedcontainers import SortedDict
+from twisted.internet import defer
+
from synapse.metrics import LaterGauge
from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
@@ -67,8 +69,6 @@ class FederationRemoteSendQueue(object):
self.edus = SortedDict() # stream position -> Edu
- self.device_messages = SortedDict() # stream position -> destination
-
self.pos = 1
self.pos_time = SortedDict()
@@ -77,12 +77,21 @@ class FederationRemoteSendQueue(object):
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
def register(name, queue):
- LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,),
- "", [], lambda: len(queue))
+ LaterGauge(
+ "synapse_federation_send_queue_%s_size" % (queue_name,),
+ "",
+ [],
+ lambda: len(queue),
+ )
for queue_name in [
- "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
- "edus", "device_messages", "pos_time", "presence_destinations",
+ "presence_map",
+ "presence_changed",
+ "keyed_edu",
+ "keyed_edu_changed",
+ "edus",
+ "pos_time",
+ "presence_destinations",
]:
register(queue_name, getattr(self, queue_name))
@@ -120,11 +129,9 @@ class FederationRemoteSendQueue(object):
for key in keys[:i]:
del self.presence_changed[key]
- user_ids = set(
- user_id
- for uids in self.presence_changed.values()
- for user_id in uids
- )
+ user_ids = {
+ user_id for uids in self.presence_changed.values() for user_id in uids
+ }
keys = self.presence_destinations.keys()
i = self.presence_destinations.bisect_left(position_to_delete)
@@ -161,12 +168,6 @@ class FederationRemoteSendQueue(object):
for key in keys[:i]:
del self.edus[key]
- # Delete things out of device map
- keys = self.device_messages.keys()
- i = self.device_messages.bisect_left(position_to_delete)
- for key in keys[:i]:
- del self.device_messages[key]
-
def notify_new_events(self, current_id):
"""As per FederationSender"""
# We don't need to replicate this as it gets sent down a different
@@ -204,7 +205,7 @@ class FederationRemoteSendQueue(object):
receipt (synapse.types.ReadReceipt):
"""
# nothing to do here: the replication listener will handle it.
- pass
+ return defer.succeed(None)
def send_presence(self, states):
"""As per FederationSender
@@ -239,9 +240,8 @@ class FederationRemoteSendQueue(object):
def send_device_messages(self, destination):
"""As per FederationSender"""
- pos = self._next_pos()
- self.device_messages[pos] = destination
- self.notifier.on_new_replication_data()
+ # We don't need to replicate this as it gets sent down a different
+ # stream.
def get_current_token(self):
return self.pos - 1
@@ -249,7 +249,9 @@ class FederationRemoteSendQueue(object):
def federation_ack(self, token):
self._clear_queue_before_pos(token)
- def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
+ async def get_replication_rows(
+ self, from_token, to_token, limit, federation_ack=None
+ ):
"""Get rows to be sent over federation between the two tokens
Args:
@@ -285,19 +287,21 @@ class FederationRemoteSendQueue(object):
]
for (key, user_id) in dest_user_ids:
- rows.append((key, PresenceRow(
- state=self.presence_map[user_id],
- )))
+ rows.append((key, PresenceRow(state=self.presence_map[user_id])))
# Fetch presence to send to destinations
i = self.presence_destinations.bisect_right(from_token)
j = self.presence_destinations.bisect_right(to_token) + 1
for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
- rows.append((pos, PresenceDestinationsRow(
- state=self.presence_map[user_id],
- destinations=list(dests),
- )))
+ rows.append(
+ (
+ pos,
+ PresenceDestinationsRow(
+ state=self.presence_map[user_id], destinations=list(dests)
+ ),
+ )
+ )
# Fetch changes keyed edus
i = self.keyed_edu_changed.bisect_right(from_token)
@@ -308,10 +312,14 @@ class FederationRemoteSendQueue(object):
keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
for ((destination, edu_key), pos) in iteritems(keyed_edus):
- rows.append((pos, KeyedEduRow(
- key=edu_key,
- edu=self.keyed_edu[(destination, edu_key)],
- )))
+ rows.append(
+ (
+ pos,
+ KeyedEduRow(
+ key=edu_key, edu=self.keyed_edu[(destination, edu_key)]
+ ),
+ )
+ )
# Fetch changed edus
i = self.edus.bisect_right(from_token)
@@ -321,16 +329,6 @@ class FederationRemoteSendQueue(object):
for (pos, edu) in edus:
rows.append((pos, EduRow(edu)))
- # Fetch changed device messages
- i = self.device_messages.bisect_right(from_token)
- j = self.device_messages.bisect_right(to_token) + 1
- device_messages = {v: k for k, v in self.device_messages.items()[i:j]}
-
- for (destination, pos) in iteritems(device_messages):
- rows.append((pos, DeviceRow(
- destination=destination,
- )))
-
# Sort rows based on pos
rows.sort()
@@ -377,16 +375,14 @@ class BaseFederationRow(object):
raise NotImplementedError()
-class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
- "state", # UserPresenceState
-))):
+class PresenceRow(
+ BaseFederationRow, namedtuple("PresenceRow", ("state",)) # UserPresenceState
+):
TypeId = "p"
@staticmethod
def from_data(data):
- return PresenceRow(
- state=UserPresenceState.from_dict(data)
- )
+ return PresenceRow(state=UserPresenceState.from_dict(data))
def to_data(self):
return self.state.as_dict()
@@ -395,33 +391,35 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
buff.presence.append(self.state)
-class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", (
- "state", # UserPresenceState
- "destinations", # list[str]
-))):
+class PresenceDestinationsRow(
+ BaseFederationRow,
+ namedtuple(
+ "PresenceDestinationsRow",
+ ("state", "destinations"), # UserPresenceState # list[str]
+ ),
+):
TypeId = "pd"
@staticmethod
def from_data(data):
return PresenceDestinationsRow(
- state=UserPresenceState.from_dict(data["state"]),
- destinations=data["dests"],
+ state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"]
)
def to_data(self):
- return {
- "state": self.state.as_dict(),
- "dests": self.destinations,
- }
+ return {"state": self.state.as_dict(), "dests": self.destinations}
def add_to_buffer(self, buff):
buff.presence_destinations.append((self.state, self.destinations))
-class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
- "key", # tuple(str) - the edu key passed to send_edu
- "edu", # Edu
-))):
+class KeyedEduRow(
+ BaseFederationRow,
+ namedtuple(
+ "KeyedEduRow",
+ ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu
+ ),
+):
"""Streams EDUs that have an associated key that is ued to clobber. For example,
typing EDUs clobber based on room_id.
"""
@@ -430,28 +428,19 @@ class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
@staticmethod
def from_data(data):
- return KeyedEduRow(
- key=tuple(data["key"]),
- edu=Edu(**data["edu"]),
- )
+ return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"]))
def to_data(self):
- return {
- "key": self.key,
- "edu": self.edu.get_internal_dict(),
- }
+ return {"key": self.key, "edu": self.edu.get_internal_dict()}
def add_to_buffer(self, buff):
- buff.keyed_edus.setdefault(
- self.edu.destination, {}
- )[self.key] = self.edu
+ buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu
-class EduRow(BaseFederationRow, namedtuple("EduRow", (
- "edu", # Edu
-))):
+class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu
"""Streams EDUs that don't have keys. See KeyedEduRow
"""
+
TypeId = "e"
@staticmethod
@@ -465,45 +454,21 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", (
buff.edus.setdefault(self.edu.destination, []).append(self.edu)
-class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
- "destination", # str
-))):
- """Streams the fact that either a) there is pending to device messages for
- users on the remote, or b) a local users device has changed and needs to
- be sent to the remote.
- """
- TypeId = "d"
-
- @staticmethod
- def from_data(data):
- return DeviceRow(destination=data["destination"])
-
- def to_data(self):
- return {"destination": self.destination}
-
- def add_to_buffer(self, buff):
- buff.device_destinations.add(self.destination)
-
-
TypeToRow = {
Row.TypeId: Row
- for Row in (
- PresenceRow,
- PresenceDestinationsRow,
- KeyedEduRow,
- EduRow,
- DeviceRow,
- )
+ for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow,)
}
-ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
- "presence", # list(UserPresenceState)
- "presence_destinations", # list of tuples of UserPresenceState and destinations
- "keyed_edus", # dict of destination -> { key -> Edu }
- "edus", # dict of destination -> [Edu]
- "device_destinations", # set of destinations
-))
+ParsedFederationStreamData = namedtuple(
+ "ParsedFederationStreamData",
+ (
+ "presence", # list(UserPresenceState)
+ "presence_destinations", # list of tuples of UserPresenceState and destinations
+ "keyed_edus", # dict of destination -> { key -> Edu }
+ "edus", # dict of destination -> [Edu]
+ ),
+)
def process_rows_for_federation(transaction_queue, rows):
@@ -520,11 +485,7 @@ def process_rows_for_federation(transaction_queue, rows):
# them into the appropriate collection and then send them off.
buff = ParsedFederationStreamData(
- presence=[],
- presence_destinations=[],
- keyed_edus={},
- edus={},
- device_destinations=set(),
+ presence=[], presence_destinations=[], keyed_edus={}, edus={},
)
# Parse the rows in the stream and add to the buffer
@@ -542,7 +503,7 @@ def process_rows_for_federation(transaction_queue, rows):
for state, destinations in buff.presence_destinations:
transaction_queue.send_presence_to_destinations(
- states=[state], destinations=destinations,
+ states=[state], destinations=destinations
)
for destination, edu_map in iteritems(buff.keyed_edus):
@@ -552,6 +513,3 @@ def process_rows_for_federation(transaction_queue, rows):
for destination, edu_list in iteritems(buff.edus):
for edu in edu_list:
transaction_queue.send_edu(edu, None)
-
- for destination in buff.device_destinations:
- transaction_queue.send_device_messages(destination)
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 4f0f939102..233cb33daf 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Dict, Hashable, Iterable, List, Optional, Set
from six import itervalues
@@ -21,11 +22,18 @@ from prometheus_client import Counter
from twisted.internet import defer
+import synapse
import synapse.metrics
+from synapse.events import EventBase
from synapse.federation.sender.per_destination_queue import PerDestinationQueue
from synapse.federation.sender.transaction_manager import TransactionManager
from synapse.federation.units import Edu
from synapse.handlers.presence import get_interested_remotes
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ preserve_fn,
+ run_in_background,
+)
from synapse.metrics import (
LaterGauge,
event_processing_loop_counter,
@@ -33,8 +41,9 @@ from synapse.metrics import (
events_processed_counter,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util import logcontext
-from synapse.util.metrics import measure_func
+from synapse.storage.presence import UserPresenceState
+from synapse.types import ReadReceipt
+from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__)
@@ -44,13 +53,13 @@ sent_pdus_destination_dist_count = Counter(
)
sent_pdus_destination_dist_total = Counter(
- "synapse_federation_client_sent_pdu_destinations:total", ""
+ "synapse_federation_client_sent_pdu_destinations:total",
"Total number of PDUs queued for sending across all destinations",
)
class FederationSender(object):
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
self.server_name = hs.hostname
@@ -63,14 +72,15 @@ class FederationSender(object):
self._transaction_manager = TransactionManager(hs)
# map from destination to PerDestinationQueue
- self._per_destination_queues = {} # type: dict[str, PerDestinationQueue]
+ self._per_destination_queues = {} # type: Dict[str, PerDestinationQueue]
LaterGauge(
"synapse_federation_transaction_queue_pending_destinations",
"",
[],
lambda: sum(
- 1 for d in self._per_destination_queues.values()
+ 1
+ for d in self._per_destination_queues.values()
if d.transmission_loop_running
),
)
@@ -78,7 +88,7 @@ class FederationSender(object):
# Map of user_id -> UserPresenceState for all the pending presence
# to be sent out by user_id. Entries here get processed and put in
# pending_presence_by_dest
- self.pending_presence = {}
+ self.pending_presence = {} # type: Dict[str, UserPresenceState]
LaterGauge(
"synapse_federation_transaction_queue_pending_pdus",
@@ -108,21 +118,19 @@ class FederationSender(object):
# awaiting a call to flush_read_receipts_for_room. The presence of an entry
# here for a given room means that we are rate-limiting RR flushes to that room,
# and that there is a pending call to _flush_rrs_for_room in the system.
- self._queues_awaiting_rr_flush_by_room = {
- } # type: dict[str, set[PerDestinationQueue]]
+ self._queues_awaiting_rr_flush_by_room = (
+ {}
+ ) # type: Dict[str, Set[PerDestinationQueue]]
self._rr_txn_interval_per_room_ms = (
- 1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second
+ 1000.0 / hs.config.federation_rr_transactions_per_room_per_second
)
- def _get_per_destination_queue(self, destination):
+ def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
Args:
- destination (str): server_name of remote server
-
- Returns:
- PerDestinationQueue
+ destination: server_name of remote server
"""
queue = self._per_destination_queues.get(destination)
if not queue:
@@ -130,7 +138,7 @@ class FederationSender(object):
self._per_destination_queues[destination] = queue
return queue
- def notify_new_events(self, current_id):
+ def notify_new_events(self, current_id: int) -> None:
"""This gets called when we have some new events we might want to
send out to other servers.
"""
@@ -141,18 +149,16 @@ class FederationSender(object):
# fire off a processing loop in the background
run_as_background_process(
- "process_event_queue_for_federation",
- self._process_event_queue_loop,
+ "process_event_queue_for_federation", self._process_event_queue_loop
)
- @defer.inlineCallbacks
- def _process_event_queue_loop(self):
+ async def _process_event_queue_loop(self) -> None:
try:
self._is_processing = True
while True:
- last_token = yield self.store.get_federation_out_pos("events")
- next_token, events = yield self.store.get_all_new_events_stream(
- last_token, self._last_poked_id, limit=100,
+ last_token = await self.store.get_federation_out_pos("events")
+ next_token, events = await self.store.get_all_new_events_stream(
+ last_token, self._last_poked_id, limit=100
)
logger.debug("Handling %s -> %s", last_token, next_token)
@@ -160,14 +166,16 @@ class FederationSender(object):
if not events and next_token >= self._last_poked_id:
break
- @defer.inlineCallbacks
- def handle_event(event):
+ async def handle_event(event: EventBase) -> None:
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.sender)
if not is_mine and send_on_behalf_of is None:
return
+ if not event.internal_metadata.should_proactively_send():
+ return
+
try:
# Get the state from before the event.
# We need to make sure that this is the state from before
@@ -175,8 +183,8 @@ class FederationSender(object):
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
- destinations = yield self.state.get_current_hosts_in_room(
- event.room_id, latest_event_ids=event.prev_event_ids(),
+ destinations = await self.state.get_hosts_in_room_at_events(
+ event.room_id, event_ids=event.prev_event_ids()
)
except Exception:
logger.exception(
@@ -197,51 +205,54 @@ class FederationSender(object):
self._send_pdu(event, destinations)
- @defer.inlineCallbacks
- def handle_room_events(events):
- for event in events:
- yield handle_event(event)
+ async def handle_room_events(events: Iterable[EventBase]) -> None:
+ with Measure(self.clock, "handle_room_events"):
+ for event in events:
+ await handle_event(event)
- events_by_room = {}
+ events_by_room = {} # type: Dict[str, List[EventBase]]
for event in events:
events_by_room.setdefault(event.room_id, []).append(event)
- yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.run_in_background(handle_room_events, evs)
- for evs in itervalues(events_by_room)
- ],
- consumeErrors=True
- ))
-
- yield self.store.update_federation_out_pos(
- "events", next_token
+ await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(handle_room_events, evs)
+ for evs in itervalues(events_by_room)
+ ],
+ consumeErrors=True,
+ )
)
+ await self.store.update_federation_out_pos("events", next_token)
+
if events:
now = self.clock.time_msec()
- ts = yield self.store.get_received_ts(events[-1].event_id)
+ ts = await self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_lag.labels(
- "federation_sender").set(now - ts)
+ "federation_sender"
+ ).set(now - ts)
synapse.metrics.event_processing_last_ts.labels(
- "federation_sender").set(ts)
+ "federation_sender"
+ ).set(ts)
events_processed_counter.inc(len(events))
- event_processing_loop_room_count.labels(
- "federation_sender"
- ).inc(len(events_by_room))
+ event_processing_loop_room_count.labels("federation_sender").inc(
+ len(events_by_room)
+ )
event_processing_loop_counter.labels("federation_sender").inc()
synapse.metrics.event_processing_positions.labels(
- "federation_sender").set(next_token)
+ "federation_sender"
+ ).set(next_token)
finally:
self._is_processing = False
- def _send_pdu(self, pdu, destinations):
+ def _send_pdu(self, pdu: EventBase, destinations: Iterable[str]) -> None:
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
@@ -263,11 +274,11 @@ class FederationSender(object):
self._get_per_destination_queue(destination).send_pdu(pdu, order)
@defer.inlineCallbacks
- def send_read_receipt(self, receipt):
+ def send_read_receipt(self, receipt: ReadReceipt):
"""Send a RR to any other servers in the room
Args:
- receipt (synapse.types.ReadReceipt): receipt to be sent
+ receipt: receipt to be sent
"""
# Some background on the rate-limiting going on here.
@@ -309,9 +320,7 @@ class FederationSender(object):
if not domains:
return
- queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(
- room_id
- )
+ queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id)
# if there is no flush yet scheduled, we will send out these receipts with
# immediate flushes, and schedule the next flush for this room.
@@ -332,7 +341,7 @@ class FederationSender(object):
else:
queue.flush_read_receipts_for_room(room_id)
- def _schedule_rr_flush_for_room(self, room_id, n_domains):
+ def _schedule_rr_flush_for_room(self, room_id: str, n_domains: int) -> None:
# that is going to cause approximately len(domains) transactions, so now back
# off for that multiplied by RR_TXN_INTERVAL_PER_ROOM
backoff_ms = self._rr_txn_interval_per_room_ms * n_domains
@@ -341,7 +350,7 @@ class FederationSender(object):
self.clock.call_later(backoff_ms, self._flush_rrs_for_room, room_id)
self._queues_awaiting_rr_flush_by_room[room_id] = set()
- def _flush_rrs_for_room(self, room_id):
+ def _flush_rrs_for_room(self, room_id: str) -> None:
queues = self._queues_awaiting_rr_flush_by_room.pop(room_id)
logger.debug("Flushing RRs in %s to %s", room_id, queues)
@@ -355,16 +364,13 @@ class FederationSender(object):
for queue in queues:
queue.flush_read_receipts_for_room(room_id)
- @logcontext.preserve_fn # the caller should not yield on this
+ @preserve_fn # the caller should not yield on this
@defer.inlineCallbacks
- def send_presence(self, states):
+ def send_presence(self, states: List[UserPresenceState]):
"""Send the new presence states to the appropriate destinations.
This actually queues up the presence states ready for sending and
triggers a background task to process them and send out the transactions.
-
- Args:
- states (list(UserPresenceState))
"""
if not self.hs.config.use_presence:
# No-op if presence is disabled.
@@ -374,10 +380,9 @@ class FederationSender(object):
# updates in quick succession are correctly handled.
# We only want to send presence for our own users, so lets always just
# filter here just in case.
- self.pending_presence.update({
- state.user_id: state for state in states
- if self.is_mine_id(state.user_id)
- })
+ self.pending_presence.update(
+ {state.user_id: state for state in states if self.is_mine_id(state.user_id)}
+ )
# We then handle the new pending presence in batches, first figuring
# out the destinations we need to send each state to and then poking it
@@ -402,11 +407,10 @@ class FederationSender(object):
finally:
self._processing_pending_presence = False
- def send_presence_to_destinations(self, states, destinations):
+ def send_presence_to_destinations(
+ self, states: List[UserPresenceState], destinations: List[str]
+ ) -> None:
"""Send the given presence states to the given destinations.
-
- Args:
- states (list[UserPresenceState])
destinations (list[str])
"""
@@ -421,12 +425,9 @@ class FederationSender(object):
@measure_func("txnqueue._process_presence")
@defer.inlineCallbacks
- def _process_presence_inner(self, states):
+ def _process_presence_inner(self, states: List[UserPresenceState]):
"""Given a list of states populate self.pending_presence_by_dest and
poke to send a new transaction to each destination
-
- Args:
- states (list(UserPresenceState))
"""
hosts_and_states = yield get_interested_remotes(self.store, states, self.state)
@@ -436,14 +437,20 @@ class FederationSender(object):
continue
self._get_per_destination_queue(destination).send_presence(states)
- def build_and_send_edu(self, destination, edu_type, content, key=None):
+ def build_and_send_edu(
+ self,
+ destination: str,
+ edu_type: str,
+ content: dict,
+ key: Optional[Hashable] = None,
+ ):
"""Construct an Edu object, and queue it for sending
Args:
- destination (str): name of server to send to
- edu_type (str): type of EDU to send
- content (dict): content of EDU
- key (Any|None): clobbering key for this edu
+ destination: name of server to send to
+ edu_type: type of EDU to send
+ content: content of EDU
+ key: clobbering key for this edu
"""
if destination == self.server_name:
logger.info("Not sending EDU to ourselves")
@@ -458,12 +465,12 @@ class FederationSender(object):
self.send_edu(edu, key)
- def send_edu(self, edu, key):
+ def send_edu(self, edu: Edu, key: Optional[Hashable]):
"""Queue an EDU for sending
Args:
- edu (Edu): edu to send
- key (Any|None): clobbering key for this edu
+ edu: edu to send
+ key: clobbering key for this edu
"""
queue = self._get_per_destination_queue(edu.destination)
if key:
@@ -471,12 +478,25 @@ class FederationSender(object):
else:
queue.send_edu(edu)
- def send_device_messages(self, destination):
+ def send_device_messages(self, destination: str):
+ if destination == self.server_name:
+ logger.warning("Not sending device update to ourselves")
+ return
+
+ self._get_per_destination_queue(destination).attempt_new_transaction()
+
+ def wake_destination(self, destination: str):
+ """Called when we want to retry sending transactions to a remote.
+
+ This is mainly useful if the remote server has been down and we think it
+ might have come back.
+ """
+
if destination == self.server_name:
- logger.info("Not sending device update to ourselves")
+ logger.warning("Not waking up ourselves")
return
self._get_per_destination_queue(destination).attempt_new_transaction()
- def get_current_token(self):
+ def get_current_token(self) -> int:
return 0
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 564c57203d..e13cd20ffa 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,11 @@
# limitations under the License.
import datetime
import logging
+from typing import Dict, Hashable, Iterable, List, Tuple
from prometheus_client import Counter
-from twisted.internet import defer
-
+import synapse.server
from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
@@ -30,7 +30,8 @@ from synapse.federation.units import Edu
from synapse.handlers.presence import format_user_presence_state
from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage import UserPresenceState
+from synapse.storage.presence import UserPresenceState
+from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
# This is defined in the Matrix spec and enforced by the receiver.
@@ -55,13 +56,18 @@ class PerDestinationQueue(object):
Manages the per-destination transmission queues.
Args:
- hs (synapse.HomeServer):
- transaction_sender (TransactionManager):
- destination (str): the server_name of the destination that we are managing
+ hs
+ transaction_sender
+ destination: the server_name of the destination that we are managing
transmission for.
"""
- def __init__(self, hs, transaction_manager, destination):
+ def __init__(
+ self,
+ hs: "synapse.server.HomeServer",
+ transaction_manager: "synapse.federation.sender.TransactionManager",
+ destination: str,
+ ):
self._server_name = hs.hostname
self._clock = hs.get_clock()
self._store = hs.get_datastore()
@@ -71,20 +77,20 @@ class PerDestinationQueue(object):
self.transmission_loop_running = False
# a list of tuples of (pending pdu, order)
- self._pending_pdus = [] # type: list[tuple[EventBase, int]]
- self._pending_edus = [] # type: list[Edu]
+ self._pending_pdus = [] # type: List[Tuple[EventBase, int]]
+ self._pending_edus = [] # type: List[Edu]
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu
- self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu]
+ self._pending_edus_keyed = {} # type: Dict[Tuple[str, Hashable], Edu]
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination
- self._pending_presence = {} # type: dict[str, UserPresenceState]
+ self._pending_presence = {} # type: Dict[str, UserPresenceState]
# room_id -> receipt_type -> user_id -> receipt_dict
- self._pending_rrs = {}
+ self._pending_rrs = {} # type: Dict[str, Dict[str, Dict[str, dict]]]
self._rrs_pending_flush = False
# stream_id of last successfully sent to-device message.
@@ -94,50 +100,50 @@ class PerDestinationQueue(object):
# stream_id of last successfully sent device list update.
self._last_device_list_stream_id = 0
- def __str__(self):
+ def __str__(self) -> str:
return "PerDestinationQueue[%s]" % self._destination
- def pending_pdu_count(self):
+ def pending_pdu_count(self) -> int:
return len(self._pending_pdus)
- def pending_edu_count(self):
+ def pending_edu_count(self) -> int:
return (
len(self._pending_edus)
+ len(self._pending_presence)
+ len(self._pending_edus_keyed)
)
- def send_pdu(self, pdu, order):
+ def send_pdu(self, pdu: EventBase, order: int) -> None:
"""Add a PDU to the queue, and start the transmission loop if neccessary
Args:
- pdu (EventBase): pdu to send
- order (int):
+ pdu: pdu to send
+ order
"""
self._pending_pdus.append((pdu, order))
self.attempt_new_transaction()
- def send_presence(self, states):
+ def send_presence(self, states: Iterable[UserPresenceState]) -> None:
"""Add presence updates to the queue. Start the transmission loop if neccessary.
Args:
- states (iterable[UserPresenceState]): presence to send
+ states: presence to send
"""
self._pending_presence.update({state.user_id: state for state in states})
self.attempt_new_transaction()
- def queue_read_receipt(self, receipt):
+ def queue_read_receipt(self, receipt: ReadReceipt) -> None:
"""Add a RR to the list to be sent. Doesn't start the transmission loop yet
(see flush_read_receipts_for_room)
Args:
- receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued
+ receipt: receipt to be queued
"""
self._pending_rrs.setdefault(receipt.room_id, {}).setdefault(
receipt.receipt_type, {}
)[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data}
- def flush_read_receipts_for_room(self, room_id):
+ def flush_read_receipts_for_room(self, room_id: str) -> None:
# if we don't have any read-receipts for this room, it may be that we've already
# sent them out, so we don't need to flush.
if room_id not in self._pending_rrs:
@@ -145,15 +151,15 @@ class PerDestinationQueue(object):
self._rrs_pending_flush = True
self.attempt_new_transaction()
- def send_keyed_edu(self, edu, key):
+ def send_keyed_edu(self, edu: Edu, key: Hashable) -> None:
self._pending_edus_keyed[(edu.edu_type, key)] = edu
self.attempt_new_transaction()
- def send_edu(self, edu):
+ def send_edu(self, edu) -> None:
self._pending_edus.append(edu)
self.attempt_new_transaction()
- def attempt_new_transaction(self):
+ def attempt_new_transaction(self) -> None:
"""Try to start a new transaction to this destination
If there is already a transaction in progress to this destination,
@@ -176,24 +182,34 @@ class PerDestinationQueue(object):
self._transaction_transmission_loop,
)
- @defer.inlineCallbacks
- def _transaction_transmission_loop(self):
- pending_pdus = []
+ async def _transaction_transmission_loop(self) -> None:
+ pending_pdus = [] # type: List[Tuple[EventBase, int]]
try:
self.transmission_loop_running = True
# This will throw if we wouldn't retry. We do this here so we fail
# quickly, but we will later check this again in the http client,
# hence why we throw the result away.
- yield get_retry_limiter(self._destination, self._clock, self._store)
+ await get_retry_limiter(self._destination, self._clock, self._store)
pending_pdus = []
while True:
- device_message_edus, device_stream_id, dev_list_id = (
- # We have to keep 2 free slots for presence and rr_edus
- yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2)
+ # We have to keep 2 free slots for presence and rr_edus
+ limit = MAX_EDUS_PER_TRANSACTION - 2
+
+ device_update_edus, dev_list_id = await self._get_device_update_edus(
+ limit
)
+ limit -= len(device_update_edus)
+
+ (
+ to_device_edus,
+ device_stream_id,
+ ) = await self._get_to_device_message_edus(limit)
+
+ pending_edus = device_update_edus + to_device_edus
+
# BEGIN CRITICAL SECTION
#
# In order to avoid a race condition, we need to make sure that
@@ -208,10 +224,6 @@ class PerDestinationQueue(object):
# We can only include at most 50 PDUs per transactions
pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
- pending_edus = []
-
- # We can only include at most 100 EDUs per transactions
- # rr_edus and pending_presence take at most one slot each
pending_edus.extend(self._get_rr_edus(force_flush=False))
pending_presence = self._pending_presence
self._pending_presence = {}
@@ -232,7 +244,6 @@ class PerDestinationQueue(object):
)
)
- pending_edus.extend(device_message_edus)
pending_edus.extend(
self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
)
@@ -262,7 +273,7 @@ class PerDestinationQueue(object):
# END CRITICAL SECTION
- success = yield self._transaction_manager.send_new_transaction(
+ success = await self._transaction_manager.send_new_transaction(
self._destination, pending_pdus, pending_edus
)
if success:
@@ -272,14 +283,17 @@ class PerDestinationQueue(object):
sent_edus_by_type.labels(edu.edu_type).inc()
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
- if device_message_edus:
- yield self._store.delete_device_msgs_for_remote(
+ if to_device_edus:
+ await self._store.delete_device_msgs_for_remote(
self._destination, device_stream_id
)
+
+ # also mark the device updates as sent
+ if device_update_edus:
logger.info(
"Marking as sent %r %r", self._destination, dev_list_id
)
- yield self._store.mark_as_sent_devices_by_remote(
+ await self._store.mark_as_sent_devices_by_remote(
self._destination, dev_list_id
)
@@ -324,7 +338,7 @@ class PerDestinationQueue(object):
# We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False
- def _get_rr_edus(self, force_flush):
+ def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
if not self._pending_rrs:
return
if not force_flush and not self._rrs_pending_flush:
@@ -341,40 +355,39 @@ class PerDestinationQueue(object):
self._rrs_pending_flush = False
yield edu
- def _pop_pending_edus(self, limit):
+ def _pop_pending_edus(self, limit: int) -> List[Edu]:
pending_edus = self._pending_edus
pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:]
return pending_edus
- @defer.inlineCallbacks
- def _get_new_device_messages(self, limit):
+ async def _get_device_update_edus(self, limit: int) -> Tuple[List[Edu], int]:
last_device_list = self._last_device_list_stream_id
# Retrieve list of new device updates to send to the destination
- now_stream_id, results = yield self._store.get_devices_by_remote(
- self._destination, last_device_list, limit=limit,
+ now_stream_id, results = await self._store.get_device_updates_by_remote(
+ self._destination, last_device_list, limit=limit
)
edus = [
Edu(
origin=self._server_name,
destination=self._destination,
- edu_type="m.device_list_update",
+ edu_type=edu_type,
content=content,
)
- for content in results
+ for (edu_type, content) in results
]
- assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+ assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
+
+ return (edus, now_stream_id)
+ async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]:
last_device_stream_id = self._last_device_stream_id
to_device_stream_id = self._store.get_to_device_stream_token()
- contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
- self._destination,
- last_device_stream_id,
- to_device_stream_id,
- limit - len(edus),
+ contents, stream_id = await self._store.get_new_device_msgs_for_remote(
+ self._destination, last_device_stream_id, to_device_stream_id, limit
)
- edus.extend(
+ edus = [
Edu(
origin=self._server_name,
destination=self._destination,
@@ -382,6 +395,6 @@ class PerDestinationQueue(object):
content=content,
)
for content in contents
- )
+ ]
- defer.returnValue((edus, stream_id, now_stream_id))
+ return (edus, stream_id)
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 35e6b8ff5b..3c2a02a3b3 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -13,12 +13,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List
-from twisted.internet import defer
+from canonicaljson import json
+import synapse.server
from synapse.api.errors import HttpResponseException
+from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
-from synapse.federation.units import Transaction
+from synapse.federation.units import Edu, Transaction
+from synapse.logging.opentracing import (
+ extract_text_map,
+ set_tag,
+ start_active_span_follows_from,
+ tags,
+ whitelisted_homeserver,
+)
from synapse.util.metrics import measure_func
logger = logging.getLogger(__name__)
@@ -29,9 +39,10 @@ class TransactionManager(object):
shared between PerDestinationQueue objects
"""
- def __init__(self, hs):
+
+ def __init__(self, hs: "synapse.server.HomeServer"):
self._server_name = hs.hostname
- self.clock = hs.get_clock() # nb must be called this for @measure_func
+ self.clock = hs.get_clock() # nb must be called this for @measure_func
self._store = hs.get_datastore()
self._transaction_actions = TransactionActions(self._store)
self._transport_layer = hs.get_federation_transport_client()
@@ -40,108 +51,119 @@ class TransactionManager(object):
self._next_txn_id = int(self.clock.time_msec())
@measure_func("_send_new_transaction")
- @defer.inlineCallbacks
- def send_new_transaction(self, destination, pending_pdus, pending_edus):
-
- # Sort based on the order field
- pending_pdus.sort(key=lambda t: t[1])
- pdus = [x[0] for x in pending_pdus]
- edus = pending_edus
-
- success = True
-
- logger.debug("TX [%s] _attempt_new_transaction", destination)
-
- txn_id = str(self._next_txn_id)
-
- logger.debug(
- "TX [%s] {%s} Attempting new transaction"
- " (pdus: %d, edus: %d)",
- destination, txn_id,
- len(pdus),
- len(edus),
- )
-
- logger.debug("TX [%s] Persisting transaction...", destination)
-
- transaction = Transaction.create_new(
- origin_server_ts=int(self.clock.time_msec()),
- transaction_id=txn_id,
- origin=self._server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- )
-
- self._next_txn_id += 1
-
- yield self._transaction_actions.prepare_to_send(transaction)
-
- logger.debug("TX [%s] Persisted transaction", destination)
- logger.info(
- "TX [%s] {%s} Sending transaction [%s],"
- " (PDUs: %d, EDUs: %d)",
- destination, txn_id,
- transaction.transaction_id,
- len(pdus),
- len(edus),
- )
-
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self.clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- try:
- response = yield self._transport_layer.send_transaction(
- transaction, json_data_cb
+ async def send_new_transaction(
+ self, destination: str, pending_pdus: List[EventBase], pending_edus: List[Edu]
+ ):
+
+ # Make a transaction-sending opentracing span. This span follows on from
+ # all the edus in that transaction. This needs to be done since there is
+ # no active span here, so if the edus were not received by the remote the
+ # span would have no causality and it would be forgotten.
+ # The span_contexts is a generator so that it won't be evaluated if
+ # opentracing is disabled. (Yay speed!)
+
+ span_contexts = []
+ keep_destination = whitelisted_homeserver(destination)
+
+ for edu in pending_edus:
+ context = edu.get_context()
+ if context:
+ span_contexts.append(extract_text_map(json.loads(context)))
+ if keep_destination:
+ edu.strip_context()
+
+ with start_active_span_follows_from("send_transaction", span_contexts):
+
+ # Sort based on the order field
+ pending_pdus.sort(key=lambda t: t[1])
+ pdus = [x[0] for x in pending_pdus]
+ edus = pending_edus
+
+ success = True
+
+ logger.debug("TX [%s] _attempt_new_transaction", destination)
+
+ txn_id = str(self._next_txn_id)
+
+ logger.debug(
+ "TX [%s] {%s} Attempting new transaction (pdus: %d, edus: %d)",
+ destination,
+ txn_id,
+ len(pdus),
+ len(edus),
)
- code = 200
- except HttpResponseException as e:
- code = e.code
- response = e.response
-
- if e.code in (401, 404, 429) or 500 <= e.code:
- logger.info(
- "TX [%s] {%s} got %d response",
- destination, txn_id, code
- )
- raise e
- logger.info(
- "TX [%s] {%s} got %d response",
- destination, txn_id, code
- )
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self.clock.time_msec()),
+ transaction_id=txn_id,
+ origin=self._server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ )
- yield self._transaction_actions.delivered(
- transaction, code, response
- )
+ self._next_txn_id += 1
- logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id)
+ logger.info(
+ "TX [%s] {%s} Sending transaction [%s], (PDUs: %d, EDUs: %d)",
+ destination,
+ txn_id,
+ transaction.transaction_id,
+ len(pdus),
+ len(edus),
+ )
- if code == 200:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warn(
- "TX [%s] {%s} Remote returned error for %s: %s",
- destination, txn_id, e_id, r,
- )
- else:
- for p in pdus:
- logger.warn(
- "TX [%s] {%s} Failed to send event %s",
- destination, txn_id, p.event_id,
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self.clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
+
+ try:
+ response = await self._transport_layer.send_transaction(
+ transaction, json_data_cb
)
- success = False
+ code = 200
+ except HttpResponseException as e:
+ code = e.code
+ response = e.response
+
+ if e.code in (401, 404, 429) or 500 <= e.code:
+ logger.info(
+ "TX [%s] {%s} got %d response", destination, txn_id, code
+ )
+ raise e
+
+ logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+
+ if code == 200:
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
+ logger.warning(
+ "TX [%s] {%s} Remote returned error for %s: %s",
+ destination,
+ txn_id,
+ e_id,
+ r,
+ )
+ else:
+ for p in pdus:
+ logger.warning(
+ "TX [%s] {%s} Failed to send event %s",
+ destination,
+ txn_id,
+ p.event_id,
+ )
+ success = False
- defer.returnValue(success)
+ set_tag(tags.ERROR, not success)
+ return success
diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py
index d9fcc520a0..5db733af98 100644
--- a/synapse/federation/transport/__init__.py
+++ b/synapse/federation/transport/__init__.py
@@ -14,9 +14,9 @@
# limitations under the License.
"""The transport layer is responsible for both sending transactions to remote
-home servers and receiving a variety of requests from other home servers.
+homeservers and receiving a variety of requests from other homeservers.
-By default this is done over HTTPS (and all home servers are required to
+By default this is done over HTTPS (and all homeservers are required to
support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol.
"""
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index e424c40fdf..383e3fdc8b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,14 +15,19 @@
# limitations under the License.
import logging
+from typing import Any, Dict
from six.moves import urllib
from twisted.internet import defer
from synapse.api.constants import Membership
-from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
-from synapse.util.logutils import log_function
+from synapse.api.urls import (
+ FEDERATION_UNSTABLE_PREFIX,
+ FEDERATION_V1_PREFIX,
+ FEDERATION_V2_PREFIX,
+)
+from synapse.logging.utils import log_function
logger = logging.getLogger(__name__)
@@ -35,35 +40,12 @@ class TransportLayerClient(object):
self.client = hs.get_http_client()
@log_function
- def get_room_state(self, destination, room_id, event_id):
- """ Requests all state for a given room from the given server at the
- given event.
-
- Args:
- destination (str): The host name of the remote home server we want
- to get the state from.
- context (str): The name of the context we want the state of
- event_id (str): The event we want the context at.
-
- Returns:
- Deferred: Results in a dict received from the remote homeserver.
- """
- logger.debug("get_room_state dest=%s, room=%s",
- destination, room_id)
-
- path = _create_v1_path("/state/%s", room_id)
- return self.client.get_json(
- destination, path=path, args={"event_id": event_id},
- try_trailing_slash_on_400=True,
- )
-
- @log_function
def get_room_state_ids(self, destination, room_id, event_id):
""" Requests all state for a given room from the given server at the
given event. Returns the state's event_id's
Args:
- destination (str): The host name of the remote home server we want
+ destination (str): The host name of the remote homeserver we want
to get the state from.
context (str): The name of the context we want the state of
event_id (str): The event we want the context at.
@@ -71,12 +53,13 @@ class TransportLayerClient(object):
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
- logger.debug("get_room_state_ids dest=%s, room=%s",
- destination, room_id)
+ logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id)
path = _create_v1_path("/state_ids/%s", room_id)
return self.client.get_json(
- destination, path=path, args={"event_id": event_id},
+ destination,
+ path=path,
+ args={"event_id": event_id},
try_trailing_slash_on_400=True,
)
@@ -85,7 +68,7 @@ class TransportLayerClient(object):
""" Requests the pdu with give id and origin from the given server.
Args:
- destination (str): The host name of the remote home server we want
+ destination (str): The host name of the remote homeserver we want
to get the state from.
event_id (str): The id of the event being requested.
timeout (int): How long to try (in ms) the destination for before
@@ -94,13 +77,11 @@ class TransportLayerClient(object):
Returns:
Deferred: Results in a dict received from the remote homeserver.
"""
- logger.debug("get_pdu dest=%s, event_id=%s",
- destination, event_id)
+ logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id)
path = _create_v1_path("/event/%s", event_id)
return self.client.get_json(
- destination, path=path, timeout=timeout,
- try_trailing_slash_on_400=True,
+ destination, path=path, timeout=timeout, try_trailing_slash_on_400=True
)
@log_function
@@ -118,8 +99,11 @@ class TransportLayerClient(object):
Deferred: Results in a dict received from the remote homeserver.
"""
logger.debug(
- "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s",
- destination, room_id, repr(event_tuples), str(limit)
+ "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s",
+ destination,
+ room_id,
+ event_tuples,
+ str(limit),
)
if not event_tuples:
@@ -128,16 +112,10 @@ class TransportLayerClient(object):
path = _create_v1_path("/backfill/%s", room_id)
- args = {
- "v": event_tuples,
- "limit": [str(limit)],
- }
+ args = {"v": event_tuples, "limit": [str(limit)]}
return self.client.get_json(
- destination,
- path=path,
- args=args,
- try_trailing_slash_on_400=True,
+ destination, path=path, args=args, try_trailing_slash_on_400=True
)
@defer.inlineCallbacks
@@ -163,7 +141,8 @@ class TransportLayerClient(object):
"""
logger.debug(
"send_data dest=%s, txid=%s",
- transaction.destination, transaction.transaction_id
+ transaction.destination,
+ transaction.transaction_id,
)
if transaction.destination == self.server_name:
@@ -185,12 +164,13 @@ class TransportLayerClient(object):
try_trailing_slash_on_400=True,
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
- def make_query(self, destination, query_type, args, retry_on_dns_fail,
- ignore_backoff=False):
+ def make_query(
+ self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False
+ ):
path = _create_v1_path("/query/%s", query_type)
content = yield self.client.get_json(
@@ -202,7 +182,7 @@ class TransportLayerClient(object):
ignore_backoff=ignore_backoff,
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -235,8 +215,8 @@ class TransportLayerClient(object):
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
raise RuntimeError(
- "make_membership_event called with membership='%s', must be one of %s" %
- (membership, ",".join(valid_memberships))
+ "make_membership_event called with membership='%s', must be one of %s"
+ % (membership, ",".join(valid_memberships))
)
path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id)
@@ -260,31 +240,39 @@ class TransportLayerClient(object):
ignore_backoff=ignore_backoff,
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
- def send_join(self, destination, room_id, event_id, content):
+ def send_join_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
- destination=destination,
- path=path,
- data=content,
+ destination=destination, path=path, data=content
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
- def send_leave(self, destination, room_id, event_id, content):
+ def send_join_v2(self, destination, room_id, event_id, content):
+ path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+
+ response = yield self.client.put_json(
+ destination=destination, path=path, data=content
+ )
+
+ return response
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_leave_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
-
# we want to do our best to send this through. The problem is
# that if it fails, we won't retry it later, so if the remote
# server was just having a momentary blip, the room will be out of
@@ -292,21 +280,36 @@ class TransportLayerClient(object):
ignore_backoff=True,
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
- def send_invite_v1(self, destination, room_id, event_id, content):
- path = _create_v1_path("/invite/%s/%s", room_id, event_id)
+ def send_leave_v2(self, destination, room_id, event_id, content):
+ path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
+ # we want to do our best to send this through. The problem is
+ # that if it fails, we won't retry it later, so if the remote
+ # server was just having a momentary blip, the room will be out of
+ # sync.
ignore_backoff=True,
)
- defer.returnValue(response)
+ return response
+
+ @defer.inlineCallbacks
+ @log_function
+ def send_invite_v1(self, destination, room_id, event_id, content):
+ path = _create_v1_path("/invite/%s/%s", room_id, event_id)
+
+ response = yield self.client.put_json(
+ destination=destination, path=path, data=content, ignore_backoff=True
+ )
+
+ return response
@defer.inlineCallbacks
@log_function
@@ -314,79 +317,77 @@ class TransportLayerClient(object):
path = _create_v2_path("/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
- def get_public_rooms(self, remote_server, limit, since_token,
- search_filter=None, include_all_networks=False,
- third_party_instance_id=None):
- path = _create_v1_path("/publicRooms")
-
- args = {
- "include_all_networks": "true" if include_all_networks else "false",
- }
- if third_party_instance_id:
- args["third_party_instance_id"] = third_party_instance_id,
- if limit:
- args["limit"] = [str(limit)]
- if since_token:
- args["since"] = [since_token]
-
- # TODO(erikj): Actually send the search_filter across federation.
-
- response = yield self.client.get_json(
- destination=remote_server,
- path=path,
- args=args,
- ignore_backoff=True,
- )
+ def get_public_rooms(
+ self,
+ remote_server,
+ limit,
+ since_token,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ):
+ if search_filter:
+ # this uses MSC2197 (Search Filtering over Federation)
+ path = _create_v1_path("/publicRooms")
+
+ data = {"include_all_networks": "true" if include_all_networks else "false"}
+ if third_party_instance_id:
+ data["third_party_instance_id"] = third_party_instance_id
+ if limit:
+ data["limit"] = str(limit)
+ if since_token:
+ data["since"] = since_token
+
+ data["filter"] = search_filter
+
+ response = yield self.client.post_json(
+ destination=remote_server, path=path, data=data, ignore_backoff=True
+ )
+ else:
+ path = _create_v1_path("/publicRooms")
+
+ args = {
+ "include_all_networks": "true" if include_all_networks else "false"
+ } # type: Dict[str, Any]
+ if third_party_instance_id:
+ args["third_party_instance_id"] = (third_party_instance_id,)
+ if limit:
+ args["limit"] = [str(limit)]
+ if since_token:
+ args["since"] = [since_token]
+
+ response = yield self.client.get_json(
+ destination=remote_server, path=path, args=args, ignore_backoff=True
+ )
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
- path = _create_v1_path("/exchange_third_party_invite/%s", room_id,)
+ path = _create_v1_path("/exchange_third_party_invite/%s", room_id)
response = yield self.client.put_json(
- destination=destination,
- path=path,
- data=event_dict,
+ destination=destination, path=path, data=event_dict
)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
path = _create_v1_path("/event_auth/%s/%s", room_id, event_id)
- content = yield self.client.get_json(
- destination=destination,
- path=path,
- )
-
- defer.returnValue(content)
-
- @defer.inlineCallbacks
- @log_function
- def send_query_auth(self, destination, room_id, event_id, content):
- path = _create_v1_path("/query_auth/%s/%s", room_id, event_id)
-
- content = yield self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- )
+ content = yield self.client.get_json(destination=destination, path=path)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -398,30 +399,37 @@ class TransportLayerClient(object):
{
"device_keys": {
"<user_id>": ["<device_id>"]
- } }
+ }
+ }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {...}
- } } }
+ }
+ },
+ "master_key": {
+ "<user_id>": {...}
+ }
+ },
+ "self_signing_key": {
+ "<user_id>": {...}
+ }
+ }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
- A dict containg the device keys.
+ A dict containing device and cross-signing keys.
"""
path = _create_v1_path("/user/keys/query")
content = yield self.client.post_json(
- destination=destination,
- path=path,
- data=query_content,
- timeout=timeout,
+ destination=destination, path=path, data=query_content, timeout=timeout
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -431,23 +439,37 @@ class TransportLayerClient(object):
Response:
{
"stream_id": "...",
- "devices": [ { ... } ]
+ "devices": [ { ... } ],
+ "master_key": {
+ "user_id": "<user_id>",
+ "usage": [...],
+ "keys": {...},
+ "signatures": {
+ "<user_id>": {...}
+ }
+ },
+ "self_signing_key": {
+ "user_id": "<user_id>",
+ "usage": [...],
+ "keys": {...},
+ "signatures": {
+ "<user_id>": {...}
+ }
+ }
}
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
- A dict containg the device keys.
+ A dict containing device and cross-signing keys.
"""
path = _create_v1_path("/user/devices/%s", user_id)
content = yield self.client.get_json(
- destination=destination,
- path=path,
- timeout=timeout,
+ destination=destination, path=path, timeout=timeout
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
@@ -458,8 +480,10 @@ class TransportLayerClient(object):
{
"one_time_keys": {
"<user_id>": {
- "<device_id>": "<algorithm>"
- } } }
+ "<device_id>": "<algorithm>"
+ }
+ }
+ }
Response:
{
@@ -467,30 +491,38 @@ class TransportLayerClient(object):
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
- } } } }
+ }
+ }
+ }
+ }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
- A dict containg the one-time keys.
+ A dict containing the one-time keys.
"""
path = _create_v1_path("/user/keys/claim")
content = yield self.client.post_json(
- destination=destination,
- path=path,
- data=query_content,
- timeout=timeout,
+ destination=destination, path=path, data=query_content, timeout=timeout
)
- defer.returnValue(content)
+ return content
@defer.inlineCallbacks
@log_function
- def get_missing_events(self, destination, room_id, earliest_events,
- latest_events, limit, min_depth, timeout):
- path = _create_v1_path("/get_missing_events/%s", room_id,)
+ def get_missing_events(
+ self,
+ destination,
+ room_id,
+ earliest_events,
+ latest_events,
+ limit,
+ min_depth,
+ timeout,
+ ):
+ path = _create_v1_path("/get_missing_events/%s", room_id)
content = yield self.client.post_json(
destination=destination,
@@ -504,13 +536,13 @@ class TransportLayerClient(object):
timeout=timeout,
)
- defer.returnValue(content)
+ return content
@log_function
def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile
"""
- path = _create_v1_path("/groups/%s/profile", group_id,)
+ path = _create_v1_path("/groups/%s/profile", group_id)
return self.client.get_json(
destination=destination,
@@ -529,7 +561,7 @@ class TransportLayerClient(object):
requester_user_id (str)
content (dict): The new profile of the group
"""
- path = _create_v1_path("/groups/%s/profile", group_id,)
+ path = _create_v1_path("/groups/%s/profile", group_id)
return self.client.post_json(
destination=destination,
@@ -543,7 +575,7 @@ class TransportLayerClient(object):
def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary
"""
- path = _create_v1_path("/groups/%s/summary", group_id,)
+ path = _create_v1_path("/groups/%s/summary", group_id)
return self.client.get_json(
destination=destination,
@@ -556,7 +588,7 @@ class TransportLayerClient(object):
def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group
"""
- path = _create_v1_path("/groups/%s/rooms", group_id,)
+ path = _create_v1_path("/groups/%s/rooms", group_id)
return self.client.get_json(
destination=destination,
@@ -565,11 +597,12 @@ class TransportLayerClient(object):
ignore_backoff=True,
)
- def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
- content):
+ def add_room_to_group(
+ self, destination, group_id, requester_user_id, room_id, content
+ ):
"""Add a room to a group
"""
- path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.post_json(
destination=destination,
@@ -579,13 +612,13 @@ class TransportLayerClient(object):
ignore_backoff=True,
)
- def update_room_in_group(self, destination, group_id, requester_user_id, room_id,
- config_key, content):
+ def update_room_in_group(
+ self, destination, group_id, requester_user_id, room_id, config_key, content
+ ):
"""Update room in group
"""
path = _create_v1_path(
- "/groups/%s/room/%s/config/%s",
- group_id, room_id, config_key,
+ "/groups/%s/room/%s/config/%s", group_id, room_id, config_key
)
return self.client.post_json(
@@ -599,7 +632,7 @@ class TransportLayerClient(object):
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group
"""
- path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/room/%s", group_id, room_id)
return self.client.delete_json(
destination=destination,
@@ -612,7 +645,7 @@ class TransportLayerClient(object):
def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group
"""
- path = _create_v1_path("/groups/%s/users", group_id,)
+ path = _create_v1_path("/groups/%s/users", group_id)
return self.client.get_json(
destination=destination,
@@ -625,7 +658,7 @@ class TransportLayerClient(object):
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group
"""
- path = _create_v1_path("/groups/%s/invited_users", group_id,)
+ path = _create_v1_path("/groups/%s/invited_users", group_id)
return self.client.get_json(
destination=destination,
@@ -638,16 +671,10 @@ class TransportLayerClient(object):
def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite
"""
- path = _create_v1_path(
- "/groups/%s/users/%s/accept_invite",
- group_id, user_id,
- )
+ path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id)
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
@@ -657,14 +684,13 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
+ def invite_to_group(
+ self, destination, group_id, user_id, requester_user_id, content
+ ):
"""Invite a user to a group
"""
path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id)
@@ -686,15 +712,13 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def remove_user_from_group(self, destination, group_id, requester_user_id,
- user_id, content):
+ def remove_user_from_group(
+ self, destination, group_id, requester_user_id, user_id, content
+ ):
"""Remove a user fron a group
"""
path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id)
@@ -708,8 +732,9 @@ class TransportLayerClient(object):
)
@log_function
- def remove_user_from_group_notification(self, destination, group_id, user_id,
- content):
+ def remove_user_from_group_notification(
+ self, destination, group_id, user_id, content
+ ):
"""Sent by group server to inform a user's server that they have been
kicked from the group.
"""
@@ -717,10 +742,7 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
@@ -732,24 +754,24 @@ class TransportLayerClient(object):
path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id)
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
@log_function
- def update_group_summary_room(self, destination, group_id, user_id, room_id,
- category_id, content):
+ def update_group_summary_room(
+ self, destination, group_id, user_id, room_id, category_id, content
+ ):
"""Update a room entry in a group summary
"""
if category_id:
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
- group_id, category_id, room_id,
+ group_id,
+ category_id,
+ room_id,
)
else:
- path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
return self.client.post_json(
destination=destination,
@@ -760,17 +782,20 @@ class TransportLayerClient(object):
)
@log_function
- def delete_group_summary_room(self, destination, group_id, user_id, room_id,
- category_id):
+ def delete_group_summary_room(
+ self, destination, group_id, user_id, room_id, category_id
+ ):
"""Delete a room entry in a group summary
"""
if category_id:
path = _create_v1_path(
"/groups/%s/summary/categories/%s/rooms/%s",
- group_id, category_id, room_id,
+ group_id,
+ category_id,
+ room_id,
)
else:
- path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,)
+ path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id)
return self.client.delete_json(
destination=destination,
@@ -783,7 +808,7 @@ class TransportLayerClient(object):
def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group
"""
- path = _create_v1_path("/groups/%s/categories", group_id,)
+ path = _create_v1_path("/groups/%s/categories", group_id)
return self.client.get_json(
destination=destination,
@@ -796,7 +821,7 @@ class TransportLayerClient(object):
def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group
"""
- path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
+ path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.get_json(
destination=destination,
@@ -806,11 +831,12 @@ class TransportLayerClient(object):
)
@log_function
- def update_group_category(self, destination, group_id, requester_user_id, category_id,
- content):
+ def update_group_category(
+ self, destination, group_id, requester_user_id, category_id, content
+ ):
"""Update a category in a group
"""
- path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
+ path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.post_json(
destination=destination,
@@ -821,11 +847,12 @@ class TransportLayerClient(object):
)
@log_function
- def delete_group_category(self, destination, group_id, requester_user_id,
- category_id):
+ def delete_group_category(
+ self, destination, group_id, requester_user_id, category_id
+ ):
"""Delete a category in a group
"""
- path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,)
+ path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id)
return self.client.delete_json(
destination=destination,
@@ -838,7 +865,7 @@ class TransportLayerClient(object):
def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group
"""
- path = _create_v1_path("/groups/%s/roles", group_id,)
+ path = _create_v1_path("/groups/%s/roles", group_id)
return self.client.get_json(
destination=destination,
@@ -851,7 +878,7 @@ class TransportLayerClient(object):
def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info
"""
- path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
+ path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.get_json(
destination=destination,
@@ -861,11 +888,12 @@ class TransportLayerClient(object):
)
@log_function
- def update_group_role(self, destination, group_id, requester_user_id, role_id,
- content):
+ def update_group_role(
+ self, destination, group_id, requester_user_id, role_id, content
+ ):
"""Update a role in a group
"""
- path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
+ path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.post_json(
destination=destination,
@@ -879,7 +907,7 @@ class TransportLayerClient(object):
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group
"""
- path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,)
+ path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id)
return self.client.delete_json(
destination=destination,
@@ -889,17 +917,17 @@ class TransportLayerClient(object):
)
@log_function
- def update_group_summary_user(self, destination, group_id, requester_user_id,
- user_id, role_id, content):
+ def update_group_summary_user(
+ self, destination, group_id, requester_user_id, user_id, role_id, content
+ ):
"""Update a users entry in a group
"""
if role_id:
path = _create_v1_path(
- "/groups/%s/summary/roles/%s/users/%s",
- group_id, role_id, user_id,
+ "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
)
else:
- path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
+ path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -910,11 +938,10 @@ class TransportLayerClient(object):
)
@log_function
- def set_group_join_policy(self, destination, group_id, requester_user_id,
- content):
+ def set_group_join_policy(self, destination, group_id, requester_user_id, content):
"""Sets the join policy for a group
"""
- path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,)
+ path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id)
return self.client.put_json(
destination=destination,
@@ -925,17 +952,17 @@ class TransportLayerClient(object):
)
@log_function
- def delete_group_summary_user(self, destination, group_id, requester_user_id,
- user_id, role_id):
+ def delete_group_summary_user(
+ self, destination, group_id, requester_user_id, user_id, role_id
+ ):
"""Delete a users entry in a group
"""
if role_id:
path = _create_v1_path(
- "/groups/%s/summary/roles/%s/users/%s",
- group_id, role_id, user_id,
+ "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id
)
else:
- path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,)
+ path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id)
return self.client.delete_json(
destination=destination,
@@ -953,12 +980,26 @@ class TransportLayerClient(object):
content = {"user_ids": user_ids}
return self.client.post_json(
- destination=destination,
- path=path,
- data=content,
- ignore_backoff=True,
+ destination=destination, path=path, data=content, ignore_backoff=True
)
+ def get_room_complexity(self, destination, room_id):
+ """
+ Args:
+ destination (str): The remote server
+ room_id (str): The room ID to ask about.
+ """
+ path = _create_path(FEDERATION_UNSTABLE_PREFIX, "/rooms/%s/complexity", room_id)
+
+ return self.client.get_json(destination=destination, path=path)
+
+
+def _create_path(federation_prefix, path, *args):
+ """
+ Ensures that all args are url encoded.
+ """
+ return federation_prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
+
def _create_v1_path(path, *args):
"""Creates a path against V1 federation API from the path template and
@@ -975,10 +1016,7 @@ def _create_v1_path(path, *args):
Returns:
str
"""
- return (
- FEDERATION_V1_PREFIX
- + path % tuple(urllib.parse.quote(arg, "") for arg in args)
- )
+ return _create_path(FEDERATION_V1_PREFIX, path, *args)
def _create_v2_path(path, *args):
@@ -996,7 +1034,4 @@ def _create_v2_path(path, *args):
Returns:
str
"""
- return (
- FEDERATION_V2_PREFIX
- + path % tuple(urllib.parse.quote(arg, "") for arg in args)
- )
+ return _create_path(FEDERATION_V2_PREFIX, path, *args)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 6cf213b895..af4595498c 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,8 +18,9 @@
import functools
import logging
import re
+from typing import Optional, Tuple, Type
-from twisted.internet import defer
+from twisted.internet.defer import maybeDeferred
import synapse
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
@@ -36,8 +38,15 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string_from_args,
)
+from synapse.logging.context import run_in_background
+from synapse.logging.opentracing import (
+ start_active_span,
+ start_active_span_from_request,
+ tags,
+ whitelisted_homeserver,
+)
+from synapse.server import HomeServer
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-from synapse.util.logcontext import run_in_background
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
@@ -66,8 +75,7 @@ class TransportLayerServer(JsonResource):
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
- self.clock,
- config=hs.config.rc_federation,
+ self.clock, config=hs.config.rc_federation
)
self.register_servlets()
@@ -84,29 +92,35 @@ class TransportLayerServer(JsonResource):
class AuthenticationError(SynapseError):
"""There was a problem authenticating the request"""
+
pass
class NoAuthenticationError(AuthenticationError):
"""The request had no authentication information"""
+
pass
class Authenticator(object):
- def __init__(self, hs):
+ def __init__(self, hs: HomeServer):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+ self.notifer = hs.get_notifier()
+
+ self.replication_client = None
+ if hs.config.worker.worker_app:
+ self.replication_client = hs.get_tcp_replication()
# A method just so we can pass 'self' as the authenticator to the Servlets
- @defer.inlineCallbacks
- def authenticate_request(self, request, content):
+ async def authenticate_request(self, request, content):
now = self._clock.time_msec()
json_request = {
- "method": request.method.decode('ascii'),
- "uri": request.uri.decode('ascii'),
+ "method": request.method.decode("ascii"),
+ "uri": request.uri.decode("ascii"),
"destination": self.server_name,
"signatures": {},
}
@@ -120,7 +134,7 @@ class Authenticator(object):
if not auth_headers:
raise NoAuthenticationError(
- 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+ 401, "Missing Authorization headers", Codes.UNAUTHORIZED
)
for auth in auth_headers:
@@ -130,36 +144,46 @@ class Authenticator(object):
json_request["signatures"].setdefault(origin, {})[key] = sig
if (
- self.federation_domain_whitelist is not None and
- origin not in self.federation_domain_whitelist
+ self.federation_domain_whitelist is not None
+ and origin not in self.federation_domain_whitelist
):
raise FederationDeniedError(origin)
if not json_request["signatures"]:
raise NoAuthenticationError(
- 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
+ 401, "Missing Authorization headers", Codes.UNAUTHORIZED
)
- yield self.keyring.verify_json_for_server(
+ await self.keyring.verify_json_for_server(
origin, json_request, now, "Incoming request"
)
- logger.info("Request from %s", origin)
+ logger.debug("Request from %s", origin)
request.authenticated_entity = origin
# If we get a valid signed request from the other side, its probably
# alive
- retry_timings = yield self.store.get_destination_retry_timings(origin)
+ retry_timings = await self.store.get_destination_retry_timings(origin)
if retry_timings and retry_timings["retry_last_ts"]:
run_in_background(self._reset_retry_timings, origin)
- defer.returnValue(origin)
+ return origin
- @defer.inlineCallbacks
- def _reset_retry_timings(self, origin):
+ async def _reset_retry_timings(self, origin):
try:
logger.info("Marking origin %r as up", origin)
- yield self.store.set_destination_retry_timings(origin, 0, 0)
+ await self.store.set_destination_retry_timings(origin, None, 0, 0)
+
+ # Inform the relevant places that the remote server is back up.
+ self.notifer.notify_remote_server_up(origin)
+ if self.replication_client:
+ # If we're on a worker we try and inform master about this. The
+ # replication client doesn't hook into the notifier to avoid
+ # infinite loops where we send a `REMOTE_SERVER_UP` command to
+ # master, which then echoes it back to us which in turn pokes
+ # the notifier.
+ self.replication_client.send_remote_server_up(origin)
+
except Exception:
logger.exception("Error resetting retry timings on %s", origin)
@@ -177,12 +201,12 @@ def _parse_auth_header(header_bytes):
AuthenticationError if the header could not be parsed
"""
try:
- header_str = header_bytes.decode('utf-8')
+ header_str = header_bytes.decode("utf-8")
params = header_str.split(" ")[1].split(",")
param_dict = dict(kv.split("=") for kv in params)
def strip_quotes(value):
- if value.startswith("\""):
+ if value.startswith('"'):
return value[1:-1]
else:
return value
@@ -196,13 +220,13 @@ def _parse_auth_header(header_bytes):
sig = strip_quotes(param_dict["sig"])
return origin, key, sig
except Exception as e:
- logger.warn(
+ logger.warning(
"Error parsing auth header '%s': %s",
- header_bytes.decode('ascii', 'replace'),
+ header_bytes.decode("ascii", "replace"),
e,
)
raise AuthenticationError(
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED
)
@@ -213,7 +237,8 @@ class BaseFederationServlet(object):
match against the request path (excluding the /federation/v1 prefix).
The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
- the appropriate HTTP method. These methods have the signature:
+ the appropriate HTTP method. These methods must be *asynchronous* and have the
+ signature:
on_<METHOD>(self, origin, content, query, **kwargs)
@@ -233,7 +258,7 @@ class BaseFederationServlet(object):
components as specified in the path match regexp.
Returns:
- Deferred[(int, object)|None]: either (response code, response object) to
+ Optional[Tuple[int, object]]: either (response code, response object) to
return a JSON response, or None if the request has already been handled.
Raises:
@@ -242,6 +267,9 @@ class BaseFederationServlet(object):
Exception: other exceptions will be caught, logged, and a 500 will be
returned.
"""
+
+ PATH = "" # Overridden in subclasses, the regex to match against the path.
+
REQUIRE_AUTH = True
PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
@@ -255,10 +283,9 @@ class BaseFederationServlet(object):
authenticator = self.authenticator
ratelimiter = self.ratelimiter
- @defer.inlineCallbacks
@functools.wraps(func)
- def new_func(request, *args, **kwargs):
- """ A callback which can be passed to HttpServer.RegisterPaths
+ async def new_func(request, *args, **kwargs):
+ """A callback which can be passed to HttpServer.RegisterPaths
Args:
request (twisted.web.http.Request):
@@ -267,8 +294,8 @@ class BaseFederationServlet(object):
components as specified in the path match regexp.
Returns:
- Deferred[(int, object)|None]: (response code, response object) as returned
- by the callback method. None if the request has already been handled.
+ Tuple[int, object]|None: (response code, response object) as returned by
+ the callback method. None if the request has already been handled.
"""
content = None
if request.method in [b"PUT", b"POST"]:
@@ -276,31 +303,52 @@ class BaseFederationServlet(object):
content = parse_json_object_from_request(request)
try:
- origin = yield authenticator.authenticate_request(request, content)
+ origin = await authenticator.authenticate_request(request, content)
except NoAuthenticationError:
origin = None
if self.REQUIRE_AUTH:
- logger.warn("authenticate_request failed: missing authentication")
+ logger.warning(
+ "authenticate_request failed: missing authentication"
+ )
raise
except Exception as e:
- logger.warn("authenticate_request failed: %s", e)
+ logger.warning("authenticate_request failed: %s", e)
raise
- if origin:
- with ratelimiter.ratelimit(origin) as d:
- yield d
- response = yield func(
- origin, content, request.args, *args, **kwargs
- )
+ request_tags = {
+ "request_id": request.get_request_id(),
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+ tags.HTTP_METHOD: request.get_method(),
+ tags.HTTP_URL: request.get_redacted_uri(),
+ tags.PEER_HOST_IPV6: request.getClientIP(),
+ "authenticated_entity": origin,
+ "servlet_name": request.request_metrics.name,
+ }
+
+ # Only accept the span context if the origin is authenticated
+ # and whitelisted
+ if origin and whitelisted_homeserver(origin):
+ scope = start_active_span_from_request(
+ request, "incoming-federation-request", tags=request_tags
+ )
else:
- response = yield func(
- origin, content, request.args, *args, **kwargs
+ scope = start_active_span(
+ "incoming-federation-request", tags=request_tags
)
- defer.returnValue(response)
+ with scope:
+ if origin:
+ with ratelimiter.ratelimit(origin) as d:
+ await d
+ response = await func(
+ origin, content, request.args, *args, **kwargs
+ )
+ else:
+ response = await func(
+ origin, content, request.args, *args, **kwargs
+ )
- # Extra logic that functools.wraps() doesn't finish
- new_func.__self__ = func.__self__
+ return response
return new_func
@@ -312,7 +360,13 @@ class BaseFederationServlet(object):
if code is None:
continue
- server.register_paths(method, (pattern,), self._wrap(code))
+ server.register_paths(
+ method,
+ (pattern,),
+ self._wrap(code),
+ self.__class__.__name__,
+ trace=False,
+ )
class FederationSendServlet(BaseFederationServlet):
@@ -325,8 +379,7 @@ class FederationSendServlet(BaseFederationServlet):
self.server_name = server_name
# This is when someone is trying to send us a bunch of data.
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, transaction_id):
+ async def on_PUT(self, origin, content, query, transaction_id):
""" Called on PUT /send/<transaction_id>/
Args:
@@ -335,7 +388,7 @@ class FederationSendServlet(BaseFederationServlet):
request. This is *not* None.
Returns:
- Deferred: Results in a tuple of `(code, response)`, where
+ Tuple of `(code, response)`, where
`response` is a python dict to be converted into JSON that is
used as the response body.
"""
@@ -343,14 +396,12 @@ class FederationSendServlet(BaseFederationServlet):
try:
transaction_data = content
- logger.debug(
- "Decoded %s: %s",
- transaction_id, str(transaction_data)
- )
+ logger.debug("Decoded %s: %s", transaction_id, str(transaction_data))
logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d)",
- transaction_id, origin,
+ transaction_id,
+ origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
)
@@ -361,51 +412,49 @@ class FederationSendServlet(BaseFederationServlet):
# Add some extra data to the transaction dict that isn't included
# in the request body.
transaction_data.update(
- transaction_id=transaction_id,
- destination=self.server_name
+ transaction_id=transaction_id, destination=self.server_name
)
except Exception as e:
logger.exception(e)
- defer.returnValue((400, {"error": "Invalid transaction"}))
- return
+ return 400, {"error": "Invalid transaction"}
try:
- code, response = yield self.handler.on_incoming_transaction(
- origin, transaction_data,
+ code, response = await self.handler.on_incoming_transaction(
+ origin, transaction_data
)
except Exception:
logger.exception("on_incoming_transaction failed")
raise
- defer.returnValue((code, response))
+ return code, response
class FederationEventServlet(BaseFederationServlet):
PATH = "/event/(?P<event_id>[^/]*)/?"
# This is when someone asks for a data item for a given server data_id pair.
- def on_GET(self, origin, content, query, event_id):
- return self.handler.on_pdu_request(origin, event_id)
+ async def on_GET(self, origin, content, query, event_id):
+ return await self.handler.on_pdu_request(origin, event_id)
-class FederationStateServlet(BaseFederationServlet):
+class FederationStateV1Servlet(BaseFederationServlet):
PATH = "/state/(?P<context>[^/]*)/?"
# This is when someone asks for all data for a given context.
- def on_GET(self, origin, content, query, context):
- return self.handler.on_context_state_request(
+ async def on_GET(self, origin, content, query, context):
+ return await self.handler.on_context_state_request(
origin,
context,
- parse_string_from_args(query, "event_id", None, required=True),
+ parse_string_from_args(query, "event_id", None, required=False),
)
class FederationStateIdsServlet(BaseFederationServlet):
PATH = "/state_ids/(?P<room_id>[^/]*)/?"
- def on_GET(self, origin, content, query, room_id):
- return self.handler.on_state_ids_request(
+ async def on_GET(self, origin, content, query, room_id):
+ return await self.handler.on_state_ids_request(
origin,
room_id,
parse_string_from_args(query, "event_id", None, required=True),
@@ -415,32 +464,31 @@ class FederationStateIdsServlet(BaseFederationServlet):
class FederationBackfillServlet(BaseFederationServlet):
PATH = "/backfill/(?P<context>[^/]*)/?"
- def on_GET(self, origin, content, query, context):
- versions = [x.decode('ascii') for x in query[b"v"]]
+ async def on_GET(self, origin, content, query, context):
+ versions = [x.decode("ascii") for x in query[b"v"]]
limit = parse_integer_from_args(query, "limit", None)
if not limit:
- return defer.succeed((400, {"error": "Did not include limit param"}))
+ return 400, {"error": "Did not include limit param"}
- return self.handler.on_backfill_request(origin, context, versions, limit)
+ return await self.handler.on_backfill_request(origin, context, versions, limit)
class FederationQueryServlet(BaseFederationServlet):
PATH = "/query/(?P<query_type>[^/]*)"
# This is when we receive a server-server Query
- def on_GET(self, origin, content, query, query_type):
- return self.handler.on_query_request(
+ async def on_GET(self, origin, content, query, query_type):
+ return await self.handler.on_query_request(
query_type,
- {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()}
+ {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()},
)
class FederationMakeJoinServlet(BaseFederationServlet):
PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
- @defer.inlineCallbacks
- def on_GET(self, origin, _content, query, context, user_id):
+ async def on_GET(self, origin, _content, query, context, user_id):
"""
Args:
origin (unicode): The authenticated server_name of the calling server
@@ -453,76 +501,90 @@ class FederationMakeJoinServlet(BaseFederationServlet):
components as specified in the path match regexp.
Returns:
- Deferred[(int, object)|None]: either (response code, response object) to
- return a JSON response, or None if the request has already been handled.
+ Tuple[int, object]: (response code, response object)
"""
- versions = query.get(b'ver')
+ versions = query.get(b"ver")
if versions is not None:
supported_versions = [v.decode("utf-8") for v in versions]
else:
supported_versions = ["1"]
- content = yield self.handler.on_make_join_request(
- origin, context, user_id,
- supported_versions=supported_versions,
+ content = await self.handler.on_make_join_request(
+ origin, context, user_id, supported_versions=supported_versions
)
- defer.returnValue((200, content))
+ return 200, content
class FederationMakeLeaveServlet(BaseFederationServlet):
PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, context, user_id):
- content = yield self.handler.on_make_leave_request(
- origin, context, user_id,
- )
- defer.returnValue((200, content))
+ async def on_GET(self, origin, content, query, context, user_id):
+ content = await self.handler.on_make_leave_request(origin, context, user_id)
+ return 200, content
-class FederationSendLeaveServlet(BaseFederationServlet):
+class FederationV1SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, room_id, event_id):
- content = yield self.handler.on_send_leave_request(origin, content, room_id)
- defer.returnValue((200, content))
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ content = await self.handler.on_send_leave_request(origin, content, room_id)
+ return 200, (200, content)
+
+
+class FederationV2SendLeaveServlet(BaseFederationServlet):
+ PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+
+ PREFIX = FEDERATION_V2_PREFIX
+
+ async def on_PUT(self, origin, content, query, room_id, event_id):
+ content = await self.handler.on_send_leave_request(origin, content, room_id)
+ return 200, content
class FederationEventAuthServlet(BaseFederationServlet):
PATH = "/event_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
- def on_GET(self, origin, content, query, context, event_id):
- return self.handler.on_event_auth(origin, context, event_id)
+ async def on_GET(self, origin, content, query, context, event_id):
+ return await self.handler.on_event_auth(origin, context, event_id)
+
+
+class FederationV1SendJoinServlet(BaseFederationServlet):
+ PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
+
+ async def on_PUT(self, origin, content, query, context, event_id):
+ # TODO(paul): assert that context/event_id parsed from path actually
+ # match those given in content
+ content = await self.handler.on_send_join_request(origin, content, context)
+ return 200, (200, content)
-class FederationSendJoinServlet(BaseFederationServlet):
+class FederationV2SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, context, event_id):
+ PREFIX = FEDERATION_V2_PREFIX
+
+ async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
- content = yield self.handler.on_send_join_request(origin, content, context)
- defer.returnValue((200, content))
+ content = await self.handler.on_send_join_request(origin, content, context)
+ return 200, content
class FederationV1InviteServlet(BaseFederationServlet):
PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, context, event_id):
+ async def on_PUT(self, origin, content, query, context, event_id):
# We don't get a room version, so we have to assume its EITHER v1 or
# v2. This is "fine" as the only difference between V1 and V2 is the
# state resolution algorithm, and we don't use that for processing
# invites
- content = yield self.handler.on_invite_request(
- origin, content, room_version=RoomVersions.V1.identifier,
+ content = await self.handler.on_invite_request(
+ origin, content, room_version_id=RoomVersions.V1.identifier
)
# V1 federation API is defined to return a content of `[200, {...}]`
# due to a historical bug.
- defer.returnValue((200, (200, content)))
+ return 200, (200, content)
class FederationV2InviteServlet(BaseFederationServlet):
@@ -530,8 +592,7 @@ class FederationV2InviteServlet(BaseFederationServlet):
PREFIX = FEDERATION_V2_PREFIX
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, context, event_id):
+ async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
@@ -544,69 +605,54 @@ class FederationV2InviteServlet(BaseFederationServlet):
event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state
- content = yield self.handler.on_invite_request(
- origin, event, room_version=room_version,
+ content = await self.handler.on_invite_request(
+ origin, event, room_version_id=room_version
)
- defer.returnValue((200, content))
+ return 200, content
class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, room_id):
- content = yield self.handler.on_exchange_third_party_invite_request(
- origin, room_id, content
+ async def on_PUT(self, origin, content, query, room_id):
+ content = await self.handler.on_exchange_third_party_invite_request(
+ room_id, content
)
- defer.returnValue((200, content))
+ return 200, content
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"
- def on_POST(self, origin, content, query):
- return self.handler.on_query_client_keys(origin, content)
+ async def on_POST(self, origin, content, query):
+ return await self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
- def on_GET(self, origin, content, query, user_id):
- return self.handler.on_query_user_devices(origin, user_id)
+ async def on_GET(self, origin, content, query, user_id):
+ return await self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query):
- response = yield self.handler.on_claim_client_keys(origin, content)
- defer.returnValue((200, response))
-
-
-class FederationQueryAuthServlet(BaseFederationServlet):
- PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
-
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, context, event_id):
- new_content = yield self.handler.on_query_auth_request(
- origin, content, context, event_id
- )
-
- defer.returnValue((200, new_content))
+ async def on_POST(self, origin, content, query):
+ response = await self.handler.on_claim_client_keys(origin, content)
+ return 200, response
class FederationGetMissingEventsServlet(BaseFederationServlet):
# TODO(paul): Why does this path alone end with "/?" optional?
PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, room_id):
+ async def on_POST(self, origin, content, query, room_id):
limit = int(content.get("limit", 10))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
- content = yield self.handler.on_get_missing_events(
+ content = await self.handler.on_get_missing_events(
origin,
room_id=room_id,
earliest_events=earliest_events,
@@ -614,7 +660,7 @@ class FederationGetMissingEventsServlet(BaseFederationServlet):
limit=limit,
)
- defer.returnValue((200, content))
+ return 200, content
class On3pidBindServlet(BaseFederationServlet):
@@ -622,18 +668,19 @@ class On3pidBindServlet(BaseFederationServlet):
REQUIRE_AUTH = False
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query):
+ async def on_POST(self, origin, content, query):
if "invites" in content:
last_exception = None
for invite in content["invites"]:
try:
if "signed" not in invite or "token" not in invite["signed"]:
- message = ("Rejecting received notification of third-"
- "party invite without signed: %s" % (invite,))
+ message = (
+ "Rejecting received notification of third-"
+ "party invite without signed: %s" % (invite,)
+ )
logger.info(message)
raise SynapseError(400, message)
- yield self.handler.exchange_third_party_invite(
+ await self.handler.exchange_third_party_invite(
invite["sender"],
invite["mxid"],
invite["room_id"],
@@ -643,7 +690,7 @@ class On3pidBindServlet(BaseFederationServlet):
last_exception = e
if last_exception:
raise last_exception
- defer.returnValue((200, {}))
+ return 200, {}
class OpenIdUserInfo(BaseFederationServlet):
@@ -667,24 +714,26 @@ class OpenIdUserInfo(BaseFederationServlet):
REQUIRE_AUTH = False
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query):
+ async def on_GET(self, origin, content, query):
token = query.get(b"access_token", [None])[0]
if token is None:
- defer.returnValue((401, {
- "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
- }))
- return
+ return (
+ 401,
+ {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"},
+ )
- user_id = yield self.handler.on_openid_userinfo(token.decode('ascii'))
+ user_id = await self.handler.on_openid_userinfo(token.decode("ascii"))
if user_id is None:
- defer.returnValue((401, {
- "errcode": "M_UNKNOWN_TOKEN",
- "error": "Access Token unknown or expired"
- }))
+ return (
+ 401,
+ {
+ "errcode": "M_UNKNOWN_TOKEN",
+ "error": "Access Token unknown or expired",
+ },
+ )
- defer.returnValue((200, {"sub": user_id}))
+ return 200, {"sub": user_id}
class PublicRoomList(BaseFederationServlet):
@@ -693,7 +742,7 @@ class PublicRoomList(BaseFederationServlet):
This API returns information in the same format as /publicRooms on the
client API, but will only ever include local public rooms and hence is
- intended for consumption by other home servers.
+ intended for consumption by other homeservers.
GET /publicRooms HTTP/1.1
@@ -722,12 +771,11 @@ class PublicRoomList(BaseFederationServlet):
def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access):
super(PublicRoomList, self).__init__(
- handler, authenticator, ratelimiter, server_name,
+ handler, authenticator, ratelimiter, server_name
)
self.allow_access = allow_access
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query):
+ async def on_GET(self, origin, content, query):
if not self.allow_access:
raise FederationDeniedError(origin)
@@ -747,12 +795,58 @@ class PublicRoomList(BaseFederationServlet):
else:
network_tuple = ThirdPartyInstanceID(None, None)
- data = yield self.handler.get_local_public_room_list(
- limit, since_token,
+ if limit == 0:
+ # zero is a special value which corresponds to no limit.
+ limit = None
+
+ data = await maybeDeferred(
+ self.handler.get_local_public_room_list,
+ limit,
+ since_token,
network_tuple=network_tuple,
from_federation=True,
)
- defer.returnValue((200, data))
+ return 200, data
+
+ async def on_POST(self, origin, content, query):
+ # This implements MSC2197 (Search Filtering over Federation)
+ if not self.allow_access:
+ raise FederationDeniedError(origin)
+
+ limit = int(content.get("limit", 100)) # type: Optional[int]
+ since_token = content.get("since", None)
+ search_filter = content.get("filter", None)
+
+ include_all_networks = content.get("include_all_networks", False)
+ third_party_instance_id = content.get("third_party_instance_id", None)
+
+ if include_all_networks:
+ network_tuple = None
+ if third_party_instance_id is not None:
+ raise SynapseError(
+ 400, "Can't use include_all_networks with an explicit network"
+ )
+ elif third_party_instance_id is None:
+ network_tuple = ThirdPartyInstanceID(None, None)
+ else:
+ network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+
+ if search_filter is None:
+ logger.warning("Nonefilter")
+
+ if limit == 0:
+ # zero is a special value which corresponds to no limit.
+ limit = None
+
+ data = await self.handler.get_local_public_room_list(
+ limit=limit,
+ since_token=since_token,
+ search_filter=search_filter,
+ network_tuple=network_tuple,
+ from_federation=True,
+ )
+
+ return 200, data
class FederationVersionServlet(BaseFederationServlet):
@@ -760,284 +854,265 @@ class FederationVersionServlet(BaseFederationServlet):
REQUIRE_AUTH = False
- def on_GET(self, origin, content, query):
- return defer.succeed((200, {
- "server": {
- "name": "Synapse",
- "version": get_version_string(synapse)
- },
- }))
+ async def on_GET(self, origin, content, query):
+ return (
+ 200,
+ {"server": {"name": "Synapse", "version": get_version_string(synapse)}},
+ )
class FederationGroupsProfileServlet(BaseFederationServlet):
"""Get/set the basic profile of a group on behalf of a user
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/profile"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_group_profile(
- group_id, requester_user_id
- )
+ new_content = await self.handler.get_group_profile(group_id, requester_user_id)
- defer.returnValue((200, new_content))
+ return 200, new_content
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id):
+ async def on_POST(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.update_group_profile(
+ new_content = await self.handler.update_group_profile(
group_id, requester_user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsSummaryServlet(BaseFederationServlet):
PATH = "/groups/(?P<group_id>[^/]*)/summary"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_group_summary(
- group_id, requester_user_id
- )
+ new_content = await self.handler.get_group_summary(group_id, requester_user_id)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsRoomsServlet(BaseFederationServlet):
"""Get the rooms in a group on behalf of a user
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/rooms"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_rooms_in_group(
- group_id, requester_user_id
- )
+ new_content = await self.handler.get_rooms_in_group(group_id, requester_user_id)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
"""Add/remove room from group
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, room_id):
+ async def on_POST(self, origin, content, query, group_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.add_room_to_group(
+ new_content = await self.handler.add_room_to_group(
group_id, requester_user_id, room_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
- @defer.inlineCallbacks
- def on_DELETE(self, origin, content, query, group_id, room_id):
+ async def on_DELETE(self, origin, content, query, group_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.remove_room_from_group(
- group_id, requester_user_id, room_id,
+ new_content = await self.handler.remove_room_from_group(
+ group_id, requester_user_id, room_id
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet):
"""Update room config in group
"""
+
PATH = (
"/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)"
)
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, room_id, config_key):
+ async def on_POST(self, origin, content, query, group_id, room_id, config_key):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- result = yield self.groups_handler.update_room_in_group(
- group_id, requester_user_id, room_id, config_key, content,
+ result = await self.handler.update_room_in_group(
+ group_id, requester_user_id, room_id, config_key, content
)
- defer.returnValue((200, result))
+ return 200, result
class FederationGroupsUsersServlet(BaseFederationServlet):
"""Get the users in a group on behalf of a user
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/users"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_users_in_group(
- group_id, requester_user_id
- )
+ new_content = await self.handler.get_users_in_group(group_id, requester_user_id)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
"""Get the users that have been invited to a group
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/invited_users"
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.get_invited_users_in_group(
+ new_content = await self.handler.get_invited_users_in_group(
group_id, requester_user_id
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsInviteServlet(BaseFederationServlet):
"""Ask a group server to invite someone to the group
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.invite_to_group(
- group_id, user_id, requester_user_id, content,
+ new_content = await self.handler.invite_to_group(
+ group_id, user_id, requester_user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
"""Accept an invitation from the group server
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
- new_content = yield self.handler.accept_invite(
- group_id, user_id, content,
- )
+ new_content = await self.handler.accept_invite(group_id, user_id, content)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsJoinServlet(BaseFederationServlet):
"""Attempt to join a group
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(user_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
- new_content = yield self.handler.join_group(
- group_id, user_id, content,
- )
+ new_content = await self.handler.join_group(group_id, user_id, content)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
"""Leave or kick a user from the group
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content,
+ new_content = await self.handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
"""A group server has invited a local user
"""
+
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "group_id doesn't match origin")
- new_content = yield self.handler.on_invite(
- group_id, user_id, content,
- )
+ new_content = await self.handler.on_invite(group_id, user_id, content)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
"""A group server has removed a local user
"""
+
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
if get_domain_from_id(group_id) != origin:
raise SynapseError(403, "user_id doesn't match origin")
- new_content = yield self.handler.user_removed_from_group(
- group_id, user_id, content,
+ new_content = await self.handler.user_removed_from_group(
+ group_id, user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
"""A group or user's server renews their attestation
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)"
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, user_id):
# We don't need to check auth here as we check the attestation signatures
- new_content = yield self.handler.on_renew_attestation(
+ new_content = await self.handler.on_renew_attestation(
group_id, user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
@@ -1047,14 +1122,14 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
- /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id
"""
+
PATH = (
"/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?"
"/rooms/(?P<room_id>[^/]*)"
)
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, category_id, room_id):
+ async def on_POST(self, origin, content, query, group_id, category_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1062,17 +1137,17 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
- resp = yield self.handler.update_group_summary_room(
- group_id, requester_user_id,
+ resp = await self.handler.update_group_summary_room(
+ group_id,
+ requester_user_id,
room_id=room_id,
category_id=category_id,
content=content,
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
+ async def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1080,56 +1155,47 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
- resp = yield self.handler.delete_group_summary_room(
- group_id, requester_user_id,
- room_id=room_id,
- category_id=category_id,
+ resp = await self.handler.delete_group_summary_room(
+ group_id, requester_user_id, room_id=room_id, category_id=category_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsCategoriesServlet(BaseFederationServlet):
"""Get all categories for a group
"""
- PATH = (
- "/groups/(?P<group_id>[^/]*)/categories/?"
- )
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ PATH = "/groups/(?P<group_id>[^/]*)/categories/?"
+
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_categories(
- group_id, requester_user_id,
- )
+ resp = await self.handler.get_group_categories(group_id, requester_user_id)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsCategoryServlet(BaseFederationServlet):
"""Add/remove/get a category in a group
"""
- PATH = (
- "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
- )
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id, category_id):
+ PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)"
+
+ async def on_GET(self, origin, content, query, group_id, category_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_category(
+ resp = await self.handler.get_group_category(
group_id, requester_user_id, category_id
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, category_id):
+ async def on_POST(self, origin, content, query, group_id, category_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1137,14 +1203,13 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
- resp = yield self.handler.upsert_group_category(
- group_id, requester_user_id, category_id, content,
+ resp = await self.handler.upsert_group_category(
+ group_id, requester_user_id, category_id, content
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, origin, content, query, group_id, category_id):
+ async def on_DELETE(self, origin, content, query, group_id, category_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1152,54 +1217,45 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
if category_id == "":
raise SynapseError(400, "category_id cannot be empty string")
- resp = yield self.handler.delete_group_category(
- group_id, requester_user_id, category_id,
+ resp = await self.handler.delete_group_category(
+ group_id, requester_user_id, category_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsRolesServlet(BaseFederationServlet):
"""Get roles in a group
"""
- PATH = (
- "/groups/(?P<group_id>[^/]*)/roles/?"
- )
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id):
+ PATH = "/groups/(?P<group_id>[^/]*)/roles/?"
+
+ async def on_GET(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_roles(
- group_id, requester_user_id,
- )
+ resp = await self.handler.get_group_roles(group_id, requester_user_id)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsRoleServlet(BaseFederationServlet):
"""Add/remove/get a role in a group
"""
- PATH = (
- "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
- )
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, group_id, role_id):
+ PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)"
+
+ async def on_GET(self, origin, content, query, group_id, role_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- resp = yield self.handler.get_group_role(
- group_id, requester_user_id, role_id
- )
+ resp = await self.handler.get_group_role(group_id, requester_user_id, role_id)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, role_id):
+ async def on_POST(self, origin, content, query, group_id, role_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1207,14 +1263,13 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
- resp = yield self.handler.update_group_role(
- group_id, requester_user_id, role_id, content,
+ resp = await self.handler.update_group_role(
+ group_id, requester_user_id, role_id, content
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, origin, content, query, group_id, role_id):
+ async def on_DELETE(self, origin, content, query, group_id, role_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1222,11 +1277,11 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
- resp = yield self.handler.delete_group_role(
- group_id, requester_user_id, role_id,
+ resp = await self.handler.delete_group_role(
+ group_id, requester_user_id, role_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
@@ -1236,14 +1291,14 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
- /groups/:group/summary/users/:user_id
- /groups/:group/summary/roles/:role/users/:user_id
"""
+
PATH = (
"/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?"
"/users/(?P<user_id>[^/]*)"
)
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query, group_id, role_id, user_id):
+ async def on_POST(self, origin, content, query, group_id, role_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1251,17 +1306,17 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
- resp = yield self.handler.update_group_summary_user(
- group_id, requester_user_id,
+ resp = await self.handler.update_group_summary_user(
+ group_id,
+ requester_user_id,
user_id=user_id,
role_id=role_id,
content=content,
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
+ async def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
@@ -1269,47 +1324,43 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
if role_id == "":
raise SynapseError(400, "role_id cannot be empty string")
- resp = yield self.handler.delete_group_summary_user(
- group_id, requester_user_id,
- user_id=user_id,
- role_id=role_id,
+ resp = await self.handler.delete_group_summary_user(
+ group_id, requester_user_id, user_id=user_id, role_id=role_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
"""Get roles in a group
"""
- PATH = (
- "/get_groups_publicised"
- )
- @defer.inlineCallbacks
- def on_POST(self, origin, content, query):
- resp = yield self.handler.bulk_get_publicised_groups(
- content["user_ids"], proxy=False,
+ PATH = "/get_groups_publicised"
+
+ async def on_POST(self, origin, content, query):
+ resp = await self.handler.bulk_get_publicised_groups(
+ content["user_ids"], proxy=False
)
- defer.returnValue((200, resp))
+ return 200, resp
class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
"""Sets whether a group is joinable without an invite or knock
"""
+
PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy"
- @defer.inlineCallbacks
- def on_PUT(self, origin, content, query, group_id):
+ async def on_PUT(self, origin, content, query, group_id):
requester_user_id = parse_string_from_args(query, "requester_user_id")
if get_domain_from_id(requester_user_id) != origin:
raise SynapseError(403, "requester_user_id doesn't match origin")
- new_content = yield self.handler.set_group_join_policy(
+ new_content = await self.handler.set_group_join_policy(
group_id, requester_user_id, content
)
- defer.returnValue((200, new_content))
+ return 200, new_content
class RoomComplexityServlet(BaseFederationServlet):
@@ -1317,40 +1368,39 @@ class RoomComplexityServlet(BaseFederationServlet):
Indicates to other servers how complex (and therefore likely
resource-intensive) a public room this server knows about is.
"""
+
PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
PREFIX = FEDERATION_UNSTABLE_PREFIX
- @defer.inlineCallbacks
- def on_GET(self, origin, content, query, room_id):
+ async def on_GET(self, origin, content, query, room_id):
store = self.handler.hs.get_datastore()
- is_public = yield store.is_room_world_readable_or_publicly_joinable(
- room_id
- )
+ is_public = await store.is_room_world_readable_or_publicly_joinable(room_id)
if not is_public:
raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
- complexity = yield store.get_room_complexity(room_id)
- defer.returnValue((200, complexity))
+ complexity = await store.get_room_complexity(room_id)
+ return 200, complexity
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationEventServlet,
- FederationStateServlet,
+ FederationStateV1Servlet,
FederationStateIdsServlet,
FederationBackfillServlet,
FederationQueryServlet,
FederationMakeJoinServlet,
FederationMakeLeaveServlet,
FederationEventServlet,
- FederationSendJoinServlet,
- FederationSendLeaveServlet,
+ FederationV1SendJoinServlet,
+ FederationV2SendJoinServlet,
+ FederationV1SendLeaveServlet,
+ FederationV2SendLeaveServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
- FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,
@@ -1360,15 +1410,13 @@ FEDERATION_SERVLET_CLASSES = (
On3pidBindServlet,
FederationVersionServlet,
RoomComplexityServlet,
-)
+) # type: Tuple[Type[BaseFederationServlet], ...]
OPENID_SERVLET_CLASSES = (
OpenIdUserInfo,
-)
+) # type: Tuple[Type[BaseFederationServlet], ...]
-ROOM_LIST_CLASSES = (
- PublicRoomList,
-)
+ROOM_LIST_CLASSES = (PublicRoomList,) # type: Tuple[Type[PublicRoomList], ...]
GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsProfileServlet,
@@ -1389,19 +1437,19 @@ GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsAddRoomsServlet,
FederationGroupsAddRoomsConfigServlet,
FederationGroupsSettingJoinPolicyServlet,
-)
+) # type: Tuple[Type[BaseFederationServlet], ...]
GROUP_LOCAL_SERVLET_CLASSES = (
FederationGroupsLocalInviteServlet,
FederationGroupsRemoveLocalUserServlet,
FederationGroupsBulkPublicisedServlet,
-)
+) # type: Tuple[Type[BaseFederationServlet], ...]
GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet,
-)
+) # type: Tuple[Type[BaseFederationServlet], ...]
DEFAULT_SERVLET_GROUPS = (
"federation",
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 025a79c022..6b32e0dcbf 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -19,11 +19,15 @@ server protocol.
import logging
+import attr
+
+from synapse.types import JsonDict
from synapse.util.jsonobject import JsonEncodedObject
logger = logging.getLogger(__name__)
+@attr.s(slots=True)
class Edu(JsonEncodedObject):
""" An Edu represents a piece of data sent from one homeserver to another.
@@ -32,21 +36,30 @@ class Edu(JsonEncodedObject):
internal ID or previous references graph.
"""
- valid_keys = [
- "origin",
- "destination",
- "edu_type",
- "content",
- ]
+ edu_type = attr.ib(type=str)
+ content = attr.ib(type=dict)
+ origin = attr.ib(type=str)
+ destination = attr.ib(type=str)
- required_keys = [
- "edu_type",
- ]
+ def get_dict(self) -> JsonDict:
+ return {
+ "edu_type": self.edu_type,
+ "content": self.content,
+ }
- internal_keys = [
- "origin",
- "destination",
- ]
+ def get_internal_dict(self) -> JsonDict:
+ return {
+ "edu_type": self.edu_type,
+ "content": self.content,
+ "origin": self.origin,
+ "destination": self.destination,
+ }
+
+ def get_context(self):
+ return getattr(self, "content", {}).get("org.matrix.opentracing_context", "{}")
+
+ def strip_context(self):
+ getattr(self, "content", {})["org.matrix.opentracing_context"] = "{}"
class Transaction(JsonEncodedObject):
@@ -75,10 +88,7 @@ class Transaction(JsonEncodedObject):
"edus",
]
- internal_keys = [
- "transaction_id",
- "destination",
- ]
+ internal_keys = ["transaction_id", "destination"]
required_keys = [
"transaction_id",
@@ -98,9 +108,7 @@ class Transaction(JsonEncodedObject):
del kwargs["edus"]
super(Transaction, self).__init__(
- transaction_id=transaction_id,
- pdus=pdus,
- **kwargs
+ transaction_id=transaction_id, pdus=pdus, **kwargs
)
@staticmethod
@@ -109,13 +117,9 @@ class Transaction(JsonEncodedObject):
transaction_id and origin_server_ts keys.
"""
if "origin_server_ts" not in kwargs:
- raise KeyError(
- "Require 'origin_server_ts' to construct a Transaction"
- )
+ raise KeyError("Require 'origin_server_ts' to construct a Transaction")
if "transaction_id" not in kwargs:
- raise KeyError(
- "Require 'transaction_id' to construct a Transaction"
- )
+ raise KeyError("Require 'transaction_id' to construct a Transaction")
kwargs["pdus"] = [p.get_pdu_json() for p in pdus]
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index e5dda1975f..d950a8b246 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -42,10 +42,10 @@ from signedjson.sign import sign_json
from twisted.internet import defer
-from synapse.api.errors import RequestSendFailed, SynapseError
+from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
-from synapse.util.logcontext import run_in_background
logger = logging.getLogger(__name__)
@@ -65,6 +65,7 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
class GroupAttestationSigning(object):
"""Creates and verifies group attestations.
"""
+
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.clock = hs.get_clock()
@@ -113,11 +114,15 @@ class GroupAttestationSigning(object):
validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
valid_until_ms = int(self.clock.time_msec() + validity_period)
- return sign_json({
- "group_id": group_id,
- "user_id": user_id,
- "valid_until_ms": valid_until_ms,
- }, self.server_name, self.signing_key)
+ return sign_json(
+ {
+ "group_id": group_id,
+ "user_id": user_id,
+ "valid_until_ms": valid_until_ms,
+ },
+ self.server_name,
+ self.signing_key,
+ )
class GroupAttestionRenewer(object):
@@ -132,9 +137,10 @@ class GroupAttestionRenewer(object):
self.is_mine_id = hs.is_mine_id
self.attestations = hs.get_groups_attestation_signing()
- self._renew_attestations_loop = self.clock.looping_call(
- self._start_renew_attestations, 30 * 60 * 1000,
- )
+ if not hs.config.worker_app:
+ self._renew_attestations_loop = self.clock.looping_call(
+ self._start_renew_attestations, 30 * 60 * 1000
+ )
@defer.inlineCallbacks
def on_renew_attestation(self, group_id, user_id, content):
@@ -146,14 +152,12 @@ class GroupAttestionRenewer(object):
raise SynapseError(400, "Neither user not group are on this server")
yield self.attestations.verify_attestation(
- attestation,
- user_id=user_id,
- group_id=group_id,
+ attestation, user_id=user_id, group_id=group_id
)
yield self.store.update_remote_attestion(group_id, user_id, attestation)
- defer.returnValue({})
+ return {}
def _start_renew_attestations(self):
return run_as_background_process("renew_attestations", self._renew_attestations)
@@ -177,9 +181,10 @@ class GroupAttestionRenewer(object):
elif not self.is_mine_id(user_id):
destination = get_domain_from_id(user_id)
else:
- logger.warn(
+ logger.warning(
"Incorrectly trying to do attestations for user: %r in %r",
- user_id, group_id,
+ user_id,
+ group_id,
)
yield self.store.remove_attestation_renewal(group_id, user_id)
return
@@ -187,21 +192,20 @@ class GroupAttestionRenewer(object):
attestation = self.attestations.create_attestation(group_id, user_id)
yield self.transport_client.renew_group_attestation(
- destination, group_id, user_id,
- content={"attestation": attestation},
+ destination, group_id, user_id, content={"attestation": attestation}
)
yield self.store.update_attestation_renewal(
group_id, user_id, attestation
)
- except RequestSendFailed as e:
+ except (RequestSendFailed, HttpResponseException) as e:
logger.warning(
- "Failed to renew attestation of %r in %r: %s",
- user_id, group_id, e,
+ "Failed to renew attestation of %r in %r: %s", user_id, group_id, e
)
except Exception:
- logger.exception("Error renewing attestation of %r in %r",
- user_id, group_id)
+ logger.exception(
+ "Error renewing attestation of %r in %r", user_id, group_id
+ )
for row in rows:
group_id = row["group_id"]
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 817be40360..4f0dc0a209 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 Michael Telatynski <7t3chguy@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,22 +21,22 @@ from six import string_types
from twisted.internet import defer
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
logger = logging.getLogger(__name__)
-# TODO: Allow users to "knock" or simpkly join depending on rules
+# TODO: Allow users to "knock" or simply join depending on rules
# TODO: Federation admin APIs
-# TODO: is_priveged flag to users and is_public to users and rooms
+# TODO: is_privileged flag to users and is_public to users and rooms
# TODO: Audit log for admins (profile updates, membership changes, users who tried
# to join but were rejected, etc)
# TODO: Flairs
-class GroupsServerHandler(object):
+class GroupsServerWorkerHandler(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
@@ -50,12 +51,10 @@ class GroupsServerHandler(object):
self.transport_client = hs.get_federation_transport_client()
self.profile_handler = hs.get_profile_handler()
- # Ensure attestations get renewed
- hs.get_groups_attestation_renewer()
-
@defer.inlineCallbacks
- def check_group_is_ours(self, group_id, requester_user_id,
- and_exists=False, and_is_admin=None):
+ def check_group_is_ours(
+ self, group_id, requester_user_id, and_exists=False, and_is_admin=None
+ ):
"""Check that the group is ours, and optionally if it exists.
If group does exist then return group.
@@ -73,7 +72,9 @@ class GroupsServerHandler(object):
if and_exists and not group:
raise SynapseError(404, "Unknown group")
- is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+ is_user_in_group = yield self.store.is_user_in_group(
+ requester_user_id, group_id
+ )
if group and not is_user_in_group and not group["is_public"]:
raise SynapseError(404, "Unknown group")
@@ -82,7 +83,7 @@ class GroupsServerHandler(object):
if not is_admin:
raise SynapseError(403, "User is not admin in group")
- defer.returnValue(group)
+ return group
@defer.inlineCallbacks
def get_group_summary(self, group_id, requester_user_id):
@@ -96,25 +97,27 @@ class GroupsServerHandler(object):
"""
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
- is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
+ is_user_in_group = yield self.store.is_user_in_group(
+ requester_user_id, group_id
+ )
profile = yield self.get_group_profile(group_id, requester_user_id)
users, roles = yield self.store.get_users_for_summary_by_role(
- group_id, include_private=is_user_in_group,
+ group_id, include_private=is_user_in_group
)
# TODO: Add profiles to users
rooms, categories = yield self.store.get_rooms_for_summary_by_category(
- group_id, include_private=is_user_in_group,
+ group_id, include_private=is_user_in_group
)
for room_entry in rooms:
room_id = room_entry["room_id"]
joined_users = yield self.store.get_users_in_room(room_id)
entry = yield self.room_list_handler.generate_room_entry(
- room_id, len(joined_users), with_alias=False, allow_private=True,
+ room_id, len(joined_users), with_alias=False, allow_private=True
)
entry = dict(entry) # so we don't change whats cached
entry.pop("room_id", None)
@@ -134,7 +137,7 @@ class GroupsServerHandler(object):
entry["attestation"] = attestation
else:
entry["attestation"] = self.attestations.create_attestation(
- group_id, user_id,
+ group_id, user_id
)
user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
@@ -143,10 +146,10 @@ class GroupsServerHandler(object):
users.sort(key=lambda e: e.get("order", 0))
membership_info = yield self.store.get_users_membership_info_in_group(
- group_id, requester_user_id,
+ group_id, requester_user_id
)
- defer.returnValue({
+ return {
"profile": profile,
"users_section": {
"users": users,
@@ -159,18 +162,207 @@ class GroupsServerHandler(object):
"total_room_count_estimate": 0, # TODO
},
"user": membership_info,
- })
+ }
+
+ @defer.inlineCallbacks
+ def get_group_categories(self, group_id, requester_user_id):
+ """Get all categories in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ categories = yield self.store.get_group_categories(group_id=group_id)
+ return {"categories": categories}
+
+ @defer.inlineCallbacks
+ def get_group_category(self, group_id, requester_user_id, category_id):
+ """Get a specific category in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ res = yield self.store.get_group_category(
+ group_id=group_id, category_id=category_id
+ )
+
+ logger.info("group %s", res)
+
+ return res
+
+ @defer.inlineCallbacks
+ def get_group_roles(self, group_id, requester_user_id):
+ """Get all roles in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ roles = yield self.store.get_group_roles(group_id=group_id)
+ return {"roles": roles}
+
+ @defer.inlineCallbacks
+ def get_group_role(self, group_id, requester_user_id, role_id):
+ """Get a specific role in a group (as seen by user)
+ """
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ res = yield self.store.get_group_role(group_id=group_id, role_id=role_id)
+ return res
+
+ @defer.inlineCallbacks
+ def get_group_profile(self, group_id, requester_user_id):
+ """Get the group profile as seen by requester_user_id
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id)
+
+ group = yield self.store.get_group(group_id)
+
+ if group:
+ cols = [
+ "name",
+ "short_description",
+ "long_description",
+ "avatar_url",
+ "is_public",
+ ]
+ group_description = {key: group[key] for key in cols}
+ group_description["is_openly_joinable"] = group["join_policy"] == "open"
+
+ return group_description
+ else:
+ raise SynapseError(404, "Unknown group")
@defer.inlineCallbacks
- def update_group_summary_room(self, group_id, requester_user_id,
- room_id, category_id, content):
+ def get_users_in_group(self, group_id, requester_user_id):
+ """Get the users in group as seen by requester_user_id.
+
+ The ordering is arbitrary at the moment
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(
+ requester_user_id, group_id
+ )
+
+ user_results = yield self.store.get_users_in_group(
+ group_id, include_private=is_user_in_group
+ )
+
+ chunk = []
+ for user_result in user_results:
+ g_user_id = user_result["user_id"]
+ is_public = user_result["is_public"]
+ is_privileged = user_result["is_admin"]
+
+ entry = {"user_id": g_user_id}
+
+ profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
+ entry.update(profile)
+
+ entry["is_public"] = bool(is_public)
+ entry["is_privileged"] = bool(is_privileged)
+
+ if not self.is_mine_id(g_user_id):
+ attestation = yield self.store.get_remote_attestation(
+ group_id, g_user_id
+ )
+ if not attestation:
+ continue
+
+ entry["attestation"] = attestation
+ else:
+ entry["attestation"] = self.attestations.create_attestation(
+ group_id, g_user_id
+ )
+
+ chunk.append(entry)
+
+ # TODO: If admin add lists of users whose attestations have timed out
+
+ return {"chunk": chunk, "total_user_count_estimate": len(user_results)}
+
+ @defer.inlineCallbacks
+ def get_invited_users_in_group(self, group_id, requester_user_id):
+ """Get the users that have been invited to a group as seen by requester_user_id.
+
+ The ordering is arbitrary at the moment
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(
+ requester_user_id, group_id
+ )
+
+ if not is_user_in_group:
+ raise SynapseError(403, "User not in group")
+
+ invited_users = yield self.store.get_invited_users_in_group(group_id)
+
+ user_profiles = []
+
+ for user_id in invited_users:
+ user_profile = {"user_id": user_id}
+ try:
+ profile = yield self.profile_handler.get_profile_from_cache(user_id)
+ user_profile.update(profile)
+ except Exception as e:
+ logger.warning("Error getting profile for %s: %s", user_id, e)
+ user_profiles.append(user_profile)
+
+ return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
+
+ @defer.inlineCallbacks
+ def get_rooms_in_group(self, group_id, requester_user_id):
+ """Get the rooms in group as seen by requester_user_id
+
+ This returns rooms in order of decreasing number of joined users
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_user_in_group = yield self.store.is_user_in_group(
+ requester_user_id, group_id
+ )
+
+ room_results = yield self.store.get_rooms_in_group(
+ group_id, include_private=is_user_in_group
+ )
+
+ chunk = []
+ for room_result in room_results:
+ room_id = room_result["room_id"]
+
+ joined_users = yield self.store.get_users_in_room(room_id)
+ entry = yield self.room_list_handler.generate_room_entry(
+ room_id, len(joined_users), with_alias=False, allow_private=True
+ )
+
+ if not entry:
+ continue
+
+ entry["is_public"] = bool(room_result["is_public"])
+
+ chunk.append(entry)
+
+ chunk.sort(key=lambda e: -e["num_joined_members"])
+
+ return {"chunk": chunk, "total_room_count_estimate": len(room_results)}
+
+
+class GroupsServerHandler(GroupsServerWorkerHandler):
+ def __init__(self, hs):
+ super(GroupsServerHandler, self).__init__(hs)
+
+ # Ensure attestations get renewed
+ hs.get_groups_attestation_renewer()
+
+ @defer.inlineCallbacks
+ def update_group_summary_room(
+ self, group_id, requester_user_id, room_id, category_id, content
+ ):
"""Add/update a room to the group summary
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
RoomID.from_string(room_id) # Ensure valid room id
@@ -187,27 +379,23 @@ class GroupsServerHandler(object):
is_public=is_public,
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
- def delete_group_summary_room(self, group_id, requester_user_id,
- room_id, category_id):
+ def delete_group_summary_room(
+ self, group_id, requester_user_id, room_id, category_id
+ ):
"""Remove a room from the summary
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_room_from_summary(
- group_id=group_id,
- room_id=room_id,
- category_id=category_id,
+ group_id=group_id, room_id=room_id, category_id=category_id
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def set_group_join_policy(self, group_id, requester_user_id, content):
@@ -223,47 +411,18 @@ class GroupsServerHandler(object):
join_policy = _parse_join_policy_from_contents(content)
if join_policy is None:
- raise SynapseError(
- 400, "No value specified for 'm.join_policy'"
- )
+ raise SynapseError(400, "No value specified for 'm.join_policy'")
yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
- defer.returnValue({})
-
- @defer.inlineCallbacks
- def get_group_categories(self, group_id, requester_user_id):
- """Get all categories in a group (as seen by user)
- """
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- categories = yield self.store.get_group_categories(
- group_id=group_id,
- )
- defer.returnValue({"categories": categories})
-
- @defer.inlineCallbacks
- def get_group_category(self, group_id, requester_user_id, category_id):
- """Get a specific category in a group (as seen by user)
- """
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- res = yield self.store.get_group_category(
- group_id=group_id,
- category_id=category_id,
- )
-
- defer.returnValue(res)
+ return {}
@defer.inlineCallbacks
def update_group_category(self, group_id, requester_user_id, category_id, content):
"""Add/Update a group category
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
is_public = _parse_visibility_from_contents(content)
@@ -276,58 +435,28 @@ class GroupsServerHandler(object):
profile=profile,
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_category(self, group_id, requester_user_id, category_id):
"""Delete a group category
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_group_category(
- group_id=group_id,
- category_id=category_id,
+ group_id=group_id, category_id=category_id
)
- defer.returnValue({})
-
- @defer.inlineCallbacks
- def get_group_roles(self, group_id, requester_user_id):
- """Get all roles in a group (as seen by user)
- """
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- roles = yield self.store.get_group_roles(
- group_id=group_id,
- )
- defer.returnValue({"roles": roles})
-
- @defer.inlineCallbacks
- def get_group_role(self, group_id, requester_user_id, role_id):
- """Get a specific role in a group (as seen by user)
- """
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- res = yield self.store.get_group_role(
- group_id=group_id,
- role_id=role_id,
- )
- defer.returnValue(res)
+ return {}
@defer.inlineCallbacks
def update_group_role(self, group_id, requester_user_id, role_id, content):
"""Add/update a role in a group
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
is_public = _parse_visibility_from_contents(content)
@@ -335,39 +464,31 @@ class GroupsServerHandler(object):
profile = content.get("profile")
yield self.store.upsert_group_role(
- group_id=group_id,
- role_id=role_id,
- is_public=is_public,
- profile=profile,
+ group_id=group_id, role_id=role_id, is_public=is_public, profile=profile
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_role(self, group_id, requester_user_id, role_id):
"""Remove role from group
"""
yield self.check_group_is_ours(
- group_id,
- requester_user_id,
- and_exists=True,
- and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
- yield self.store.remove_group_role(
- group_id=group_id,
- role_id=role_id,
- )
+ yield self.store.remove_group_role(group_id=group_id, role_id=role_id)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
- def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
- content):
+ def update_group_summary_user(
+ self, group_id, requester_user_id, user_id, role_id, content
+ ):
"""Add/update a users entry in the group summary
"""
yield self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
order = content.get("order", None)
@@ -382,56 +503,32 @@ class GroupsServerHandler(object):
is_public=is_public,
)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
"""Remove a user from the group summary
"""
yield self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
yield self.store.remove_user_from_summary(
- group_id=group_id,
- user_id=user_id,
- role_id=role_id,
+ group_id=group_id, user_id=user_id, role_id=role_id
)
- defer.returnValue({})
-
- @defer.inlineCallbacks
- def get_group_profile(self, group_id, requester_user_id):
- """Get the group profile as seen by requester_user_id
- """
-
- yield self.check_group_is_ours(group_id, requester_user_id)
-
- group = yield self.store.get_group(group_id)
-
- if group:
- cols = [
- "name", "short_description", "long_description",
- "avatar_url", "is_public",
- ]
- group_description = {key: group[key] for key in cols}
- group_description["is_openly_joinable"] = group["join_policy"] == "open"
-
- defer.returnValue(group_description)
- else:
- raise SynapseError(404, "Unknown group")
+ return {}
@defer.inlineCallbacks
def update_group_profile(self, group_id, requester_user_id, content):
"""Update the group profile
"""
yield self.check_group_is_ours(
- group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id,
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
)
profile = {}
- for keyname in ("name", "avatar_url", "short_description",
- "long_description"):
+ for keyname in ("name", "avatar_url", "short_description", "long_description"):
if keyname in content:
value = content[keyname]
if not isinstance(value, string_types):
@@ -441,127 +538,6 @@ class GroupsServerHandler(object):
yield self.store.update_group_profile(group_id, profile)
@defer.inlineCallbacks
- def get_users_in_group(self, group_id, requester_user_id):
- """Get the users in group as seen by requester_user_id.
-
- The ordering is arbitrary at the moment
- """
-
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
-
- user_results = yield self.store.get_users_in_group(
- group_id, include_private=is_user_in_group,
- )
-
- chunk = []
- for user_result in user_results:
- g_user_id = user_result["user_id"]
- is_public = user_result["is_public"]
- is_privileged = user_result["is_admin"]
-
- entry = {"user_id": g_user_id}
-
- profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
- entry.update(profile)
-
- entry["is_public"] = bool(is_public)
- entry["is_privileged"] = bool(is_privileged)
-
- if not self.is_mine_id(g_user_id):
- attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
- if not attestation:
- continue
-
- entry["attestation"] = attestation
- else:
- entry["attestation"] = self.attestations.create_attestation(
- group_id, g_user_id,
- )
-
- chunk.append(entry)
-
- # TODO: If admin add lists of users whose attestations have timed out
-
- defer.returnValue({
- "chunk": chunk,
- "total_user_count_estimate": len(user_results),
- })
-
- @defer.inlineCallbacks
- def get_invited_users_in_group(self, group_id, requester_user_id):
- """Get the users that have been invited to a group as seen by requester_user_id.
-
- The ordering is arbitrary at the moment
- """
-
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
-
- if not is_user_in_group:
- raise SynapseError(403, "User not in group")
-
- invited_users = yield self.store.get_invited_users_in_group(group_id)
-
- user_profiles = []
-
- for user_id in invited_users:
- user_profile = {
- "user_id": user_id
- }
- try:
- profile = yield self.profile_handler.get_profile_from_cache(user_id)
- user_profile.update(profile)
- except Exception as e:
- logger.warn("Error getting profile for %s: %s", user_id, e)
- user_profiles.append(user_profile)
-
- defer.returnValue({
- "chunk": user_profiles,
- "total_user_count_estimate": len(invited_users),
- })
-
- @defer.inlineCallbacks
- def get_rooms_in_group(self, group_id, requester_user_id):
- """Get the rooms in group as seen by requester_user_id
-
- This returns rooms in order of decreasing number of joined users
- """
-
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
-
- room_results = yield self.store.get_rooms_in_group(
- group_id, include_private=is_user_in_group,
- )
-
- chunk = []
- for room_result in room_results:
- room_id = room_result["room_id"]
-
- joined_users = yield self.store.get_users_in_room(room_id)
- entry = yield self.room_list_handler.generate_room_entry(
- room_id, len(joined_users), with_alias=False, allow_private=True,
- )
-
- if not entry:
- continue
-
- entry["is_public"] = bool(room_result["is_public"])
-
- chunk.append(entry)
-
- chunk.sort(key=lambda e: -e["num_joined_members"])
-
- defer.returnValue({
- "chunk": chunk,
- "total_room_count_estimate": len(room_results),
- })
-
- @defer.inlineCallbacks
def add_room_to_group(self, group_id, requester_user_id, room_id, content):
"""Add room to group
"""
@@ -575,11 +551,12 @@ class GroupsServerHandler(object):
yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
- def update_room_in_group(self, group_id, requester_user_id, room_id, config_key,
- content):
+ def update_room_in_group(
+ self, group_id, requester_user_id, room_id, config_key, content
+ ):
"""Update room in group
"""
RoomID.from_string(room_id) # Ensure valid room id
@@ -592,13 +569,12 @@ class GroupsServerHandler(object):
is_public = _parse_visibility_dict(content)
yield self.store.update_room_in_group_visibility(
- group_id, room_id,
- is_public=is_public,
+ group_id, room_id, is_public=is_public
)
else:
raise SynapseError(400, "Uknown config option")
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def remove_room_from_group(self, group_id, requester_user_id, room_id):
@@ -610,7 +586,7 @@ class GroupsServerHandler(object):
yield self.store.remove_room_from_group(group_id, room_id)
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def invite_to_group(self, group_id, user_id, requester_user_id, content):
@@ -622,13 +598,21 @@ class GroupsServerHandler(object):
)
# TODO: Check if user knocked
- # TODO: Check if user is already invited
+
+ invited_users = yield self.store.get_invited_users_in_group(group_id)
+ if user_id in invited_users:
+ raise SynapseError(
+ 400, "User already invited to group", errcode=Codes.BAD_STATE
+ )
+
+ user_results = yield self.store.get_users_in_group(
+ group_id, include_private=True
+ )
+ if user_id in (user_result["user_id"] for user_result in user_results):
+ raise SynapseError(400, "User already in group")
content = {
- "profile": {
- "name": group["name"],
- "avatar_url": group["avatar_url"],
- },
+ "profile": {"name": group["name"], "avatar_url": group["avatar_url"]},
"inviter": requester_user_id,
}
@@ -638,9 +622,7 @@ class GroupsServerHandler(object):
local_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
- content.update({
- "attestation": local_attestation,
- })
+ content.update({"attestation": local_attestation})
res = yield self.transport_client.invite_to_group_notification(
get_domain_from_id(user_id), group_id, user_id, content
@@ -658,31 +640,24 @@ class GroupsServerHandler(object):
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
- remote_attestation,
- user_id=user_id,
- group_id=group_id,
+ remote_attestation, user_id=user_id, group_id=group_id
)
else:
remote_attestation = None
yield self.store.add_user_to_group(
- group_id, user_id,
+ group_id,
+ user_id,
is_admin=False,
is_public=False, # TODO
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
elif res["state"] == "invite":
- yield self.store.add_group_invite(
- group_id, user_id,
- )
- defer.returnValue({
- "state": "invite"
- })
+ yield self.store.add_group_invite(group_id, user_id)
+ return {"state": "invite"}
elif res["state"] == "reject":
- defer.returnValue({
- "state": "reject"
- })
+ return {"state": "reject"}
else:
raise SynapseError(502, "Unknown state returned by HS")
@@ -693,16 +668,12 @@ class GroupsServerHandler(object):
See accept_invite, join_group.
"""
if not self.hs.is_mine_id(user_id):
- local_attestation = self.attestations.create_attestation(
- group_id, user_id,
- )
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
- remote_attestation,
- user_id=user_id,
- group_id=group_id,
+ remote_attestation, user_id=user_id, group_id=group_id
)
else:
local_attestation = None
@@ -711,14 +682,15 @@ class GroupsServerHandler(object):
is_public = _parse_visibility_from_contents(content)
yield self.store.add_user_to_group(
- group_id, user_id,
+ group_id,
+ user_id,
is_admin=False,
is_public=is_public,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
- defer.returnValue(local_attestation)
+ return local_attestation
@defer.inlineCallbacks
def accept_invite(self, group_id, requester_user_id, content):
@@ -731,17 +703,14 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
is_invited = yield self.store.is_user_invited_to_local_group(
- group_id, requester_user_id,
+ group_id, requester_user_id
)
if not is_invited:
raise SynapseError(403, "User not invited to group")
local_attestation = yield self._add_user(group_id, requester_user_id, content)
- defer.returnValue({
- "state": "join",
- "attestation": local_attestation,
- })
+ return {"state": "join", "attestation": local_attestation}
@defer.inlineCallbacks
def join_group(self, group_id, requester_user_id, content):
@@ -753,15 +722,12 @@ class GroupsServerHandler(object):
group_info = yield self.check_group_is_ours(
group_id, requester_user_id, and_exists=True
)
- if group_info['join_policy'] != "open":
+ if group_info["join_policy"] != "open":
raise SynapseError(403, "Group is not publicly joinable")
local_attestation = yield self._add_user(group_id, requester_user_id, content)
- defer.returnValue({
- "state": "join",
- "attestation": local_attestation,
- })
+ return {"state": "join", "attestation": local_attestation}
@defer.inlineCallbacks
def knock(self, group_id, requester_user_id, content):
@@ -800,9 +766,7 @@ class GroupsServerHandler(object):
is_kick = True
- yield self.store.remove_user_from_group(
- group_id, user_id,
- )
+ yield self.store.remove_user_from_group(group_id, user_id)
if is_kick:
if self.hs.is_mine_id(user_id):
@@ -816,7 +780,12 @@ class GroupsServerHandler(object):
if not self.hs.is_mine_id(user_id):
yield self.store.maybe_delete_remote_profile_cache(user_id)
- defer.returnValue({})
+ # Delete group if the last user has left
+ users = yield self.store.get_users_in_group(group_id, include_private=True)
+ if not users:
+ yield self.store.delete_group(group_id)
+
+ return {}
@defer.inlineCallbacks
def create_group(self, group_id, requester_user_id, content):
@@ -830,19 +799,20 @@ class GroupsServerHandler(object):
if group:
raise SynapseError(400, "Group already exists")
- is_admin = yield self.auth.is_server_admin(UserID.from_string(requester_user_id))
+ is_admin = yield self.auth.is_server_admin(
+ UserID.from_string(requester_user_id)
+ )
if not is_admin:
if not self.hs.config.enable_group_creation:
raise SynapseError(
- 403, "Only a server admin can create groups on this server",
+ 403, "Only a server admin can create groups on this server"
)
localpart = group_id_obj.localpart
if not localpart.startswith(self.hs.config.group_creation_prefix):
raise SynapseError(
400,
- "Can only create groups with prefix %r on this server" % (
- self.hs.config.group_creation_prefix,
- ),
+ "Can only create groups with prefix %r on this server"
+ % (self.hs.config.group_creation_prefix,),
)
profile = content.get("profile", {})
@@ -865,21 +835,19 @@ class GroupsServerHandler(object):
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
- remote_attestation,
- user_id=requester_user_id,
- group_id=group_id,
+ remote_attestation, user_id=requester_user_id, group_id=group_id
)
local_attestation = self.attestations.create_attestation(
- group_id,
- requester_user_id,
+ group_id, requester_user_id
)
else:
local_attestation = None
remote_attestation = None
yield self.store.add_user_to_group(
- group_id, requester_user_id,
+ group_id,
+ requester_user_id,
is_admin=True,
is_public=True, # TODO
local_attestation=local_attestation,
@@ -893,9 +861,7 @@ class GroupsServerHandler(object):
avatar_url=user_profile.get("avatar_url"),
)
- defer.returnValue({
- "group_id": group_id,
- })
+ return {"group_id": group_id}
@defer.inlineCallbacks
def delete_group(self, group_id, requester_user_id):
@@ -911,29 +877,22 @@ class GroupsServerHandler(object):
Deferred
"""
- yield self.check_group_is_ours(
- group_id, requester_user_id,
- and_exists=True,
- )
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
# Only server admins or group admins can delete groups.
- is_admin = yield self.store.is_user_admin_in_group(
- group_id, requester_user_id
- )
+ is_admin = yield self.store.is_user_admin_in_group(group_id, requester_user_id)
if not is_admin:
is_admin = yield self.auth.is_server_admin(
- UserID.from_string(requester_user_id),
+ UserID.from_string(requester_user_id)
)
if not is_admin:
raise SynapseError(403, "User is not an admin")
# Before deleting the group lets kick everyone out of it
- users = yield self.store.get_users_in_group(
- group_id, include_private=True,
- )
+ users = yield self.store.get_users_in_group(group_id, include_private=True)
@defer.inlineCallbacks
def _kick_user_from_group(user_id):
@@ -989,9 +948,7 @@ def _parse_join_policy_dict(join_policy_dict):
return "invite"
if join_policy_type not in ("invite", "open"):
- raise SynapseError(
- 400, "Synapse only supports 'invite'/'open' join rule"
- )
+ raise SynapseError(400, "Synapse only supports 'invite'/'open' join rule")
return join_policy_type
@@ -1018,7 +975,5 @@ def _parse_visibility_dict(visibility):
return True
if vis_type not in ("public", "private"):
- raise SynapseError(
- 400, "Synapse only supports 'public'/'private' visibility"
- )
+ raise SynapseError(400, "Synapse only supports 'public'/'private' visibility")
return vis_type == "public"
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index dca337ec61..51413d910e 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -45,6 +45,7 @@ class BaseHandler(object):
self.state_handler = hs.get_state_handler()
self.distributor = hs.get_distributor()
self.ratelimiter = hs.get_ratelimiter()
+ self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
self.clock = hs.get_clock()
self.hs = hs
@@ -53,7 +54,7 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
- def ratelimit(self, requester, update=True):
+ def ratelimit(self, requester, update=True, is_admin_redaction=False):
"""Ratelimits requests.
Args:
@@ -62,6 +63,9 @@ class BaseHandler(object):
Set to False when doing multiple checks for one request (e.g.
to check up front if we would reject the request), and set to
True for the last call for a given request.
+ is_admin_redaction (bool): Whether this is a room admin/moderator
+ redacting an event. If so then we may apply different
+ ratelimits depending on config.
Raises:
LimitExceededError if the request should be ratelimited
@@ -90,18 +94,36 @@ class BaseHandler(object):
messages_per_second = override.messages_per_second
burst_count = override.burst_count
else:
- messages_per_second = self.hs.config.rc_message.per_second
- burst_count = self.hs.config.rc_message.burst_count
-
- allowed, time_allowed = self.ratelimiter.can_do_action(
- user_id, time_now,
- rate_hz=messages_per_second,
- burst_count=burst_count,
- update=update,
- )
+ # We default to different values if this is an admin redaction and
+ # the config is set
+ if is_admin_redaction and self.hs.config.rc_admin_redaction:
+ messages_per_second = self.hs.config.rc_admin_redaction.per_second
+ burst_count = self.hs.config.rc_admin_redaction.burst_count
+ else:
+ messages_per_second = self.hs.config.rc_message.per_second
+ burst_count = self.hs.config.rc_message.burst_count
+
+ if is_admin_redaction and self.hs.config.rc_admin_redaction:
+ # If we have separate config for admin redactions we use a separate
+ # ratelimiter
+ allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action(
+ user_id,
+ time_now,
+ rate_hz=messages_per_second,
+ burst_count=burst_count,
+ update=update,
+ )
+ else:
+ allowed, time_allowed = self.ratelimiter.can_do_action(
+ user_id,
+ time_now,
+ rate_hz=messages_per_second,
+ burst_count=burst_count,
+ update=update,
+ )
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now)),
+ retry_after_ms=int(1000 * (time_allowed - time_now))
)
@defer.inlineCallbacks
@@ -112,7 +134,7 @@ class BaseHandler(object):
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
current_state = yield self.store.get_events(
list(current_state_ids.values())
)
@@ -139,7 +161,7 @@ class BaseHandler(object):
if member_event.content["membership"] not in {
Membership.JOIN,
- Membership.INVITE
+ Membership.INVITE,
}:
continue
@@ -156,8 +178,7 @@ class BaseHandler(object):
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
- requester = synapse.types.create_requester(
- target_user, is_guest=True)
+ requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler()
yield handler.update_membership(
requester,
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 7fa5d44d29..a8d3fbc6de 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,53 +13,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
class AccountDataEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- def get_current_key(self, direction='f'):
+ def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()
- @defer.inlineCallbacks
- def get_new_events(self, user, from_key, **kwargs):
+ async def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key
- current_stream_id = yield self.store.get_max_account_data_stream_id()
+ current_stream_id = self.store.get_max_account_data_stream_id()
results = []
- tags = yield self.store.get_updated_tags(user_id, last_stream_id)
+ tags = await self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items():
- results.append({
- "type": "m.tag",
- "content": {"tags": room_tags},
- "room_id": room_id,
- })
+ results.append(
+ {"type": "m.tag", "content": {"tags": room_tags}, "room_id": room_id}
+ )
- account_data, room_account_data = (
- yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
- )
+ (
+ account_data,
+ room_account_data,
+ ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items():
- results.append({
- "type": account_data_type,
- "content": content,
- })
+ results.append({"type": account_data_type, "content": content})
for room_id, account_data in room_account_data.items():
for account_data_type, content in account_data.items():
- results.append({
- "type": account_data_type,
- "content": content,
- "room_id": room_id,
- })
-
- defer.returnValue((results, current_stream_id))
+ results.append(
+ {"type": account_data_type, "content": content, "room_id": room_id}
+ )
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
- defer.returnValue(([], config.to_id))
+ return results, current_stream_id
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 51507bde61..a6c907b9c9 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,13 +18,15 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
+from typing import List
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.logging.context import make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
-from synapse.util.logcontext import make_deferred_yieldable
try:
from synapse.push.mailer import load_jinja2_templates
@@ -37,6 +39,7 @@ logger = logging.getLogger(__name__)
class AccountValidityHandler(object):
def __init__(self, hs):
self.hs = hs
+ self.config = hs.config
self.store = self.hs.get_datastore()
self.sendmail = self.hs.get_sendmail()
self.clock = self.hs.get_clock()
@@ -55,12 +58,10 @@ class AccountValidityHandler(object):
app_name = self.hs.config.email_app_name
self._subject = self._account_validity.renew_email_subject % {
- "app": app_name,
+ "app": app_name
}
- self._from_string = self.hs.config.email_notif_from % {
- "app": app_name,
- }
+ self._from_string = self.hs.config.email_notif_from % {"app": app_name}
except Exception:
# If substitution failed, fall back to the bare strings.
self._subject = self._account_validity.renew_email_subject
@@ -69,55 +70,61 @@ class AccountValidityHandler(object):
self._raw_from = email.utils.parseaddr(self._from_string)[1]
self._template_html, self._template_text = load_jinja2_templates(
- config=self.hs.config,
- template_html_name=self.hs.config.email_expiry_template_html,
- template_text_name=self.hs.config.email_expiry_template_text,
+ self.config.email_template_dir,
+ [
+ self.config.email_expiry_template_html,
+ self.config.email_expiry_template_text,
+ ],
+ apply_format_ts_filter=True,
+ apply_mxc_to_http_filter=True,
+ public_baseurl=self.config.public_baseurl,
)
# Check the renewal emails to send and send them every 30min.
+ def send_emails():
+ # run as a background process to make sure that the database transactions
+ # have a logcontext to report to
+ return run_as_background_process(
+ "send_renewals", self._send_renewal_emails
+ )
+
+ self.clock.looping_call(send_emails, 30 * 60 * 1000)
+
+ # If account_validity is enabled,check every hour to remove expired users from
+ # the user directory
+ if self._account_validity.enabled:
self.clock.looping_call(
- self.send_renewal_emails,
- 30 * 60 * 1000,
+ self._mark_expired_users_as_inactive, 60 * 60 * 1000
)
- # Check every hour to remove expired users from the user directory
- self.clock.looping_call(
- self._mark_expired_users_as_inactive,
- 60 * 60 * 1000,
- )
-
- @defer.inlineCallbacks
- def send_renewal_emails(self):
+ async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
- expiring_users = yield self.store.get_users_expiring_soon()
+ expiring_users = await self.store.get_users_expiring_soon()
if expiring_users:
for user in expiring_users:
- yield self._send_renewal_email(
- user_id=user["user_id"],
- expiration_ts=user["expiration_ts_ms"],
+ await self._send_renewal_email(
+ user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)
- @defer.inlineCallbacks
- def send_renewal_email_to_user(self, user_id):
- expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
- yield self._send_renewal_email(user_id, expiration_ts)
+ async def send_renewal_email_to_user(self, user_id: str):
+ expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+ await self._send_renewal_email(user_id, expiration_ts)
- @defer.inlineCallbacks
- def _send_renewal_email(self, user_id, expiration_ts):
+ async def _send_renewal_email(self, user_id: str, expiration_ts: int):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.
Args:
- user_id (str): ID of the user to send email(s) to.
- expiration_ts (int): Timestamp in milliseconds for the expiration date of
+ user_id: ID of the user to send email(s) to.
+ expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
- addresses = yield self._get_email_addresses_for_user(user_id)
+ addresses = await self._get_email_addresses_for_user(user_id)
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
@@ -129,7 +136,7 @@ class AccountValidityHandler(object):
return
try:
- user_display_name = yield self.store.get_profile_displayname(
+ user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
@@ -137,7 +144,7 @@ class AccountValidityHandler(object):
except StoreError:
user_display_name = user_id
- renewal_token = yield self._get_renewal_token(user_id)
+ renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,
@@ -158,62 +165,61 @@ class AccountValidityHandler(object):
for address in addresses:
raw_to = email.utils.parseaddr(address)[1]
- multipart_msg = MIMEMultipart('alternative')
- multipart_msg['Subject'] = self._subject
- multipart_msg['From'] = self._from_string
- multipart_msg['To'] = address
- multipart_msg['Date'] = email.utils.formatdate()
- multipart_msg['Message-ID'] = email.utils.make_msgid()
+ multipart_msg = MIMEMultipart("alternative")
+ multipart_msg["Subject"] = self._subject
+ multipart_msg["From"] = self._from_string
+ multipart_msg["To"] = address
+ multipart_msg["Date"] = email.utils.formatdate()
+ multipart_msg["Message-ID"] = email.utils.make_msgid()
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)
logger.info("Sending renewal email to %s", address)
- yield make_deferred_yieldable(self.sendmail(
- self.hs.config.email_smtp_host,
- self._raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
- reactor=self.hs.get_reactor(),
- port=self.hs.config.email_smtp_port,
- requireAuthentication=self.hs.config.email_smtp_user is not None,
- username=self.hs.config.email_smtp_user,
- password=self.hs.config.email_smtp_pass,
- requireTransportSecurity=self.hs.config.require_transport_security
- ))
-
- yield self.store.set_renewal_mail_status(
- user_id=user_id,
- email_sent=True,
- )
+ await make_deferred_yieldable(
+ self.sendmail(
+ self.hs.config.email_smtp_host,
+ self._raw_from,
+ raw_to,
+ multipart_msg.as_string().encode("utf8"),
+ reactor=self.hs.get_reactor(),
+ port=self.hs.config.email_smtp_port,
+ requireAuthentication=self.hs.config.email_smtp_user is not None,
+ username=self.hs.config.email_smtp_user,
+ password=self.hs.config.email_smtp_pass,
+ requireTransportSecurity=self.hs.config.require_transport_security,
+ )
+ )
- @defer.inlineCallbacks
- def _get_email_addresses_for_user(self, user_id):
+ await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
+
+ async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
"""Retrieve the list of email addresses attached to a user's account.
Args:
- user_id (str): ID of the user to lookup email addresses for.
+ user_id: ID of the user to lookup email addresses for.
Returns:
- defer.Deferred[list[str]]: Email addresses for this account.
+ Email addresses for this account.
"""
- threepids = yield self.store.user_get_threepids(user_id)
+ threepids = await self.store.user_get_threepids(user_id)
addresses = []
for threepid in threepids:
if threepid["medium"] == "email":
addresses.append(threepid["address"])
- defer.returnValue(addresses)
+ return addresses
- @defer.inlineCallbacks
- def _get_renewal_token(self, user_id):
+ async def _get_renewal_token(self, user_id: str) -> str:
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.
Args:
- user_id (str): ID of the user to generate a string for.
+ user_id: ID of the user to generate a string for.
Returns:
- defer.Deferred[str]: The generated string.
+ The generated string.
Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -222,63 +228,63 @@ class AccountValidityHandler(object):
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
- yield self.store.set_renewal_token_for_user(user_id, renewal_token)
- defer.returnValue(renewal_token)
+ await self.store.set_renewal_token_for_user(user_id, renewal_token)
+ return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
- @defer.inlineCallbacks
- def renew_account(self, renewal_token):
+ async def renew_account(self, renewal_token: str) -> bool:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
+ renewal_token: Token sent with the renewal request.
Returns:
- bool: Whether the provided token is valid.
+ Whether the provided token is valid.
"""
try:
- user_id = yield self.store.get_user_from_renewal_token(renewal_token)
+ user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
- defer.returnValue(False)
+ return False
logger.debug("Renewing an account for user %s", user_id)
- yield self.renew_account_for_user(user_id)
+ await self.renew_account_for_user(user_id)
- defer.returnValue(True)
+ return True
- @defer.inlineCallbacks
- def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
+ async def renew_account_for_user(
+ self, user_id: str, expiration_ts: int = None, email_sent: bool = False
+ ) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
- renewal_token (str): Token sent with the renewal request.
- expiration_ts (int): New expiration date. Defaults to now + validity period.
- email_sent (bool): Whether an email has been sent for this validity period.
+ renewal_token: Token sent with the renewal request.
+ expiration_ts: New expiration date. Defaults to now + validity period.
+ email_sen: Whether an email has been sent for this validity period.
Defaults to False.
Returns:
- defer.Deferred[int]: New expiration date for this account, as a timestamp
- in milliseconds since epoch.
+ New expiration date for this account, as a timestamp in
+ milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period
- yield self.store.set_account_validity_for_user(
- user_id=user_id,
- expiration_ts=expiration_ts,
- email_sent=email_sent,
+ await self.store.set_account_validity_for_user(
+ user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
# Check if renewed users should be reintroduced to the user directory
if self._show_users_in_user_directory:
# Show the user in the directory again by setting them to active
- yield self.profile_handler.set_active(UserID.from_string(user_id), True, True)
+ await self.profile_handler.set_active(
+ UserID.from_string(user_id), True, True
+ )
- defer.returnValue(expiration_ts)
+ return expiration_ts
@defer.inlineCallbacks
def _mark_expired_users_as_inactive(self):
@@ -290,10 +296,7 @@ class AccountValidityHandler(object):
"""
# Get expired users
expired_user_ids = yield self.store.get_expired_users()
- expired_users = [
- UserID.from_string(user_id)
- for user_id in expired_user_ids
- ]
+ expired_users = [UserID.from_string(user_id) for user_id in expired_user_ids]
# Mark each one as non-active
for user in expired_users:
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 813777bf18..a2d7959abe 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -15,14 +15,9 @@
import logging
-import attr
-from zope.interface import implementer
-
import twisted
import twisted.internet.error
from twisted.internet import defer
-from twisted.python.filepath import FilePath
-from twisted.python.url import URL
from twisted.web import server, static
from twisted.web.resource import Resource
@@ -30,26 +25,14 @@ from synapse.app import check_bind_error
logger = logging.getLogger(__name__)
-try:
- from txacme.interfaces import ICertificateStore
-
- @attr.s
- @implementer(ICertificateStore)
- class ErsatzStore(object):
- """
- A store that only stores in memory.
- """
-
- certs = attr.ib(default=attr.Factory(dict))
-
- def store(self, server_name, pem_objects):
- self.certs[server_name] = [o.as_bytes() for o in pem_objects]
- return defer.succeed(None)
-
-
-except ImportError:
- # txacme is missing
- pass
+ACME_REGISTER_FAIL_ERROR = """
+--------------------------------------------------------------------------------
+Failed to register with the ACME provider. This is likely happening because the installation
+is new, and ACME v1 has been deprecated by Let's Encrypt and disabled for
+new installations since November 2019.
+At the moment, Synapse doesn't support ACME v2. For more information and alternative
+solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/ACME.md#deprecation-of-acme-v1
+--------------------------------------------------------------------------------"""
class AcmeHandler(object):
@@ -60,6 +43,7 @@ class AcmeHandler(object):
@defer.inlineCallbacks
def start_listening(self):
+ from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
# from eliot import add_destinations
@@ -67,50 +51,27 @@ class AcmeHandler(object):
#
# add_destinations(TwistedDestination())
- from txacme.challenges import HTTP01Responder
- from txacme.service import AcmeIssuingService
- from txacme.endpoint import load_or_create_client_key
- from txacme.client import Client
- from josepy.jwa import RS256
-
- self._store = ErsatzStore()
- responder = HTTP01Responder()
-
- self._issuer = AcmeIssuingService(
- cert_store=self._store,
- client_creator=(
- lambda: Client.from_url(
- reactor=self.reactor,
- url=URL.from_text(self.hs.config.acme_url),
- key=load_or_create_client_key(
- FilePath(self.hs.config.config_dir_path)
- ),
- alg=RS256,
- )
- ),
- clock=self.reactor,
- responders=[responder],
+ well_known = Resource()
+
+ self._issuer = acme_issuing_service.create_issuing_service(
+ self.reactor,
+ acme_url=self.hs.config.acme_url,
+ account_key_file=self.hs.config.acme_account_key_file,
+ well_known_resource=well_known,
)
- well_known = Resource()
- well_known.putChild(b'acme-challenge', responder.resource)
responder_resource = Resource()
- responder_resource.putChild(b'.well-known', well_known)
- responder_resource.putChild(b'check', static.Data(b'OK', b'text/plain'))
-
+ responder_resource.putChild(b".well-known", well_known)
+ responder_resource.putChild(b"check", static.Data(b"OK", b"text/plain"))
srv = server.Site(responder_resource)
bind_addresses = self.hs.config.acme_bind_addresses
for host in bind_addresses:
logger.info(
- "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port,
+ "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
)
try:
- self.reactor.listenTCP(
- self.hs.config.acme_port,
- srv,
- interface=host,
- )
+ self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)
@@ -119,7 +80,12 @@ class AcmeHandler(object):
# want it to control where we save the certificates, we have to reach in
# and trigger the registration machinery ourselves.
self._issuer._registered = False
- yield self._issuer._ensure_registered()
+
+ try:
+ yield self._issuer._ensure_registered()
+ except Exception:
+ logger.error(ACME_REGISTER_FAIL_ERROR)
+ raise
@defer.inlineCallbacks
def provision_certificate(self):
@@ -132,7 +98,7 @@ class AcmeHandler(object):
logger.exception("Fail!")
raise
logger.warning("Reprovisioned %s, saving.", self._acme_domain)
- cert_chain = self._store.certs[self._acme_domain]
+ cert_chain = self._issuer.cert_store.certs[self._acme_domain]
try:
with open(self.hs.config.tls_private_key_file, "wb") as private_key_file:
@@ -148,4 +114,4 @@ class AcmeHandler(object):
logger.exception("Failed saving!")
raise
- defer.returnValue(True)
+ return True
diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py
new file mode 100644
index 0000000000..e1d4224e74
--- /dev/null
+++ b/synapse/handlers/acme_issuing_service.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Utility function to create an ACME issuing service.
+
+This file contains the unconditional imports on the acme and cryptography bits that we
+only need (and may only have available) if we are doing ACME, so is designed to be
+imported conditionally.
+"""
+import logging
+
+import attr
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import serialization
+from josepy import JWKRSA
+from josepy.jwa import RS256
+from txacme.challenges import HTTP01Responder
+from txacme.client import Client
+from txacme.interfaces import ICertificateStore
+from txacme.service import AcmeIssuingService
+from txacme.util import generate_private_key
+from zope.interface import implementer
+
+from twisted.internet import defer
+from twisted.python.filepath import FilePath
+from twisted.python.url import URL
+
+logger = logging.getLogger(__name__)
+
+
+def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
+ """Create an ACME issuing service, and attach it to a web Resource
+
+ Args:
+ reactor: twisted reactor
+ acme_url (str): URL to use to request certificates
+ account_key_file (str): where to store the account key
+ well_known_resource (twisted.web.IResource): web resource for .well-known.
+ we will attach a child resource for "acme-challenge".
+
+ Returns:
+ AcmeIssuingService
+ """
+ responder = HTTP01Responder()
+
+ well_known_resource.putChild(b"acme-challenge", responder.resource)
+
+ store = ErsatzStore()
+
+ return AcmeIssuingService(
+ cert_store=store,
+ client_creator=(
+ lambda: Client.from_url(
+ reactor=reactor,
+ url=URL.from_text(acme_url),
+ key=load_or_create_client_key(account_key_file),
+ alg=RS256,
+ )
+ ),
+ clock=reactor,
+ responders=[responder],
+ )
+
+
+@attr.s
+@implementer(ICertificateStore)
+class ErsatzStore(object):
+ """
+ A store that only stores in memory.
+ """
+
+ certs = attr.ib(default=attr.Factory(dict))
+
+ def store(self, server_name, pem_objects):
+ self.certs[server_name] = [o.as_bytes() for o in pem_objects]
+ return defer.succeed(None)
+
+
+def load_or_create_client_key(key_file):
+ """Load the ACME account key from a file, creating it if it does not exist.
+
+ Args:
+ key_file (str): name of the file to use as the account key
+ """
+ # this is based on txacme.endpoint.load_or_create_client_key, but doesn't
+ # hardcode the 'client.key' filename
+ acme_key_file = FilePath(key_file)
+ if acme_key_file.exists():
+ logger.info("Loading ACME account key from '%s'", acme_key_file)
+ key = serialization.load_pem_private_key(
+ acme_key_file.getContent(), password=None, backend=default_backend()
+ )
+ else:
+ logger.info("Saving new ACME account key to '%s'", acme_key_file)
+ key = generate_private_key("rsa")
+ acme_key_file.setContent(
+ key.private_bytes(
+ encoding=serialization.Encoding.PEM,
+ format=serialization.PrivateFormat.TraditionalOpenSSL,
+ encryption_algorithm=serialization.NoEncryption(),
+ )
+ )
+ return JWKRSA(key=key)
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 5d629126fc..f3c0aeceb6 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,8 +14,12 @@
# limitations under the License.
import logging
+from typing import List
-from twisted.internet import defer
+from synapse.api.constants import Membership
+from synapse.events import FrozenEvent
+from synapse.types import RoomStreamToken, StateMap
+from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -23,76 +27,208 @@ logger = logging.getLogger(__name__)
class AdminHandler(BaseHandler):
-
def __init__(self, hs):
super(AdminHandler, self).__init__(hs)
- @defer.inlineCallbacks
- def get_whois(self, user):
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
+
+ async def get_whois(self, user):
connections = []
- sessions = yield self.store.get_user_ip_and_agents(user)
+ sessions = await self.store.get_user_ip_and_agents(user)
for session in sessions:
- connections.append({
- "ip": session["ip"],
- "last_seen": session["last_seen"],
- "user_agent": session["user_agent"],
- })
+ connections.append(
+ {
+ "ip": session["ip"],
+ "last_seen": session["last_seen"],
+ "user_agent": session["user_agent"],
+ }
+ )
ret = {
"user_id": user.to_string(),
- "devices": {
- "": {
- "sessions": [
- {
- "connections": connections,
- }
- ]
- },
- },
+ "devices": {"": {"sessions": [{"connections": connections}]}},
}
- defer.returnValue(ret)
+ return ret
+
+ async def get_user(self, user):
+ """Function to get user details"""
+ ret = await self.store.get_user_by_id(user.to_string())
+ if ret:
+ profile = await self.store.get_profileinfo(user.localpart)
+ threepids = await self.store.user_get_threepids(user.to_string())
+ ret["displayname"] = profile.display_name
+ ret["avatar_url"] = profile.avatar_url
+ ret["threepids"] = threepids
+ return ret
- @defer.inlineCallbacks
- def get_users(self):
- """Function to reterive a list of users in users table.
+ async def export_user_data(self, user_id, writer):
+ """Write all data we have on the user to the given writer.
Args:
+ user_id (str)
+ writer (ExfiltrationWriter)
+
Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ defer.Deferred: Resolves when all data for a user has been written.
+ The returned value is that returned by `writer.finished()`.
"""
- ret = yield self.store.get_users()
-
- defer.returnValue(ret)
+ # Get all rooms the user is in or has been in
+ rooms = await self.store.get_rooms_for_local_user_where_membership_is(
+ user_id,
+ membership_list=(
+ Membership.JOIN,
+ Membership.LEAVE,
+ Membership.BAN,
+ Membership.INVITE,
+ ),
+ )
+
+ # We only try and fetch events for rooms the user has been in. If
+ # they've been e.g. invited to a room without joining then we handle
+ # those seperately.
+ rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
+
+ for index, room in enumerate(rooms):
+ room_id = room.room_id
+
+ logger.info(
+ "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
+ )
+
+ forgotten = await self.store.did_forget(user_id, room_id)
+ if forgotten:
+ logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
+ continue
+
+ if room_id not in rooms_user_has_been_in:
+ # If we haven't been in the rooms then the filtering code below
+ # won't return anything, so we need to handle these cases
+ # explicitly.
+
+ if room.membership == Membership.INVITE:
+ event_id = room.event_id
+ invite = await self.store.get_event(event_id, allow_none=True)
+ if invite:
+ invited_state = invite.unsigned["invite_room_state"]
+ writer.write_invite(room_id, invite, invited_state)
+
+ continue
+
+ # We only want to bother fetching events up to the last time they
+ # were joined. We estimate that point by looking at the
+ # stream_ordering of the last membership if it wasn't a join.
+ if room.membership == Membership.JOIN:
+ stream_ordering = self.store.get_room_max_stream_ordering()
+ else:
+ stream_ordering = room.stream_ordering
+
+ from_key = str(RoomStreamToken(0, 0))
+ to_key = str(RoomStreamToken(None, stream_ordering))
+
+ written_events = set() # Events that we've processed in this room
+
+ # We need to track gaps in the events stream so that we can then
+ # write out the state at those events. We do this by keeping track
+ # of events whose prev events we haven't seen.
+
+ # Map from event ID to prev events that haven't been processed,
+ # dict[str, set[str]].
+ event_to_unseen_prevs = {}
+
+ # The reverse mapping to above, i.e. map from unseen event to events
+ # that have the unseen event in their prev_events, i.e. the unseen
+ # events "children". dict[str, set[str]]
+ unseen_to_child_events = {}
+
+ # We fetch events in the room the user could see by fetching *all*
+ # events that we have and then filtering, this isn't the most
+ # efficient method perhaps but it does guarantee we get everything.
+ while True:
+ events, _ = await self.store.paginate_room_events(
+ room_id, from_key, to_key, limit=100, direction="f"
+ )
+ if not events:
+ break
+
+ from_key = events[-1].internal_metadata.after
+
+ events = await filter_events_for_client(self.storage, user_id, events)
+
+ writer.write_events(room_id, events)
+
+ # Update the extremity tracking dicts
+ for event in events:
+ # Check if we have any prev events that haven't been
+ # processed yet, and add those to the appropriate dicts.
+ unseen_events = set(event.prev_event_ids()) - written_events
+ if unseen_events:
+ event_to_unseen_prevs[event.event_id] = unseen_events
+ for unseen in unseen_events:
+ unseen_to_child_events.setdefault(unseen, set()).add(
+ event.event_id
+ )
+
+ # Now check if this event is an unseen prev event, if so
+ # then we remove this event from the appropriate dicts.
+ for child_id in unseen_to_child_events.pop(event.event_id, []):
+ event_to_unseen_prevs[child_id].discard(event.event_id)
+
+ written_events.add(event.event_id)
+
+ logger.info(
+ "Written %d events in room %s", len(written_events), room_id
+ )
+
+ # Extremities are the events who have at least one unseen prev event.
+ extremities = (
+ event_id
+ for event_id, unseen_prevs in event_to_unseen_prevs.items()
+ if unseen_prevs
+ )
+ for event_id in extremities:
+ if not event_to_unseen_prevs[event_id]:
+ continue
+ state = await self.state_store.get_state_for_event(event_id)
+ writer.write_state(room_id, event_id, state)
+
+ return writer.finished()
+
+
+class ExfiltrationWriter(object):
+ """Interface used to specify how to write exported data.
+ """
+
+ def write_events(self, room_id: str, events: List[FrozenEvent]):
+ """Write a batch of events for a room.
+ """
+ pass
- @defer.inlineCallbacks
- def get_users_paginate(self, order, start, limit):
- """Function to reterive a paginated list of users from
- users list. This will return a json object, which contains
- list of users and the total number of users in users table.
+ def write_state(self, room_id: str, event_id: str, state: StateMap[FrozenEvent]):
+ """Write the state at the given event in the room.
- Args:
- order (str): column name to order the select by this column
- start (int): start number to begin the query from
- limit (int): number of rows to reterive
- Returns:
- defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+ This only gets called for backward extremities rather than for each
+ event.
"""
- ret = yield self.store.get_users_paginate(order, start, limit)
-
- defer.returnValue(ret)
+ pass
- @defer.inlineCallbacks
- def search_users(self, term):
- """Function to search users list for one or more users with
- the matched term.
+ def write_invite(self, room_id: str, event: FrozenEvent, state: StateMap[dict]):
+ """Write an invite for the room, with associated invite state.
Args:
- term (str): search term
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
+ room_id
+ event
+ state: A subset of the state at the
+ invite, with a subset of the event keys (type, state_key
+ content and sender)
"""
- ret = yield self.store.search_users(term)
- defer.returnValue(ret)
+ def finished(self):
+ """Called when all data has succesfully been exported and written.
+
+ This functions return value is passed to the caller of
+ `export_user_data`.
+ """
+ pass
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 17eedf4dbf..fe62f78e67 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -23,13 +23,13 @@ from twisted.internet import defer
import synapse
from synapse.api.constants import EventTypes
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import (
event_processing_loop_counter,
event_processing_loop_room_count,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import log_failure
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -38,7 +38,6 @@ events_processed_counter = Counter("synapse_handlers_appservice_events_processed
class ApplicationServicesHandler(object):
-
def __init__(self, hs):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
@@ -74,7 +73,10 @@ class ApplicationServicesHandler(object):
try:
limit = 100
while True:
- upper_bound, events = yield self.store.get_new_events_for_appservice(
+ (
+ upper_bound,
+ events,
+ ) = yield self.store.get_new_events_for_appservice(
self.current_max, limit
)
@@ -101,9 +103,10 @@ class ApplicationServicesHandler(object):
yield self._check_user_exists(event.state_key)
if not self.started_scheduler:
+
def start_scheduler():
return self.scheduler.start().addErrback(
- log_failure, "Application Services Failure",
+ log_failure, "Application Services Failure"
)
run_as_background_process("as_scheduler", start_scheduler)
@@ -118,10 +121,15 @@ class ApplicationServicesHandler(object):
for event in events:
yield handle_event(event)
- yield make_deferred_yieldable(defer.gatherResults([
- run_in_background(handle_room_events, evs)
- for evs in itervalues(events_by_room)
- ], consumeErrors=True))
+ yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(handle_room_events, evs)
+ for evs in itervalues(events_by_room)
+ ],
+ consumeErrors=True,
+ )
+ )
yield self.store.set_appservice_last_pos(upper_bound)
@@ -129,20 +137,23 @@ class ApplicationServicesHandler(object):
ts = yield self.store.get_received_ts(events[-1].event_id)
synapse.metrics.event_processing_positions.labels(
- "appservice_sender").set(upper_bound)
+ "appservice_sender"
+ ).set(upper_bound)
events_processed_counter.inc(len(events))
- event_processing_loop_room_count.labels(
- "appservice_sender"
- ).inc(len(events_by_room))
+ event_processing_loop_room_count.labels("appservice_sender").inc(
+ len(events_by_room)
+ )
event_processing_loop_counter.labels("appservice_sender").inc()
synapse.metrics.event_processing_lag.labels(
- "appservice_sender").set(now - ts)
+ "appservice_sender"
+ ).set(now - ts)
synapse.metrics.event_processing_last_ts.labels(
- "appservice_sender").set(ts)
+ "appservice_sender"
+ ).set(ts)
finally:
self.is_processing = False
@@ -155,16 +166,12 @@ class ApplicationServicesHandler(object):
Returns:
True if this user exists on at least one application service.
"""
- user_query_services = yield self._get_services_for_user(
- user_id=user_id
- )
+ user_query_services = yield self._get_services_for_user(user_id=user_id)
for user_service in user_query_services:
- is_known_user = yield self.appservice_api.query_user(
- user_service, user_id
- )
+ is_known_user = yield self.appservice_api.query_user(user_service, user_id)
if is_known_user:
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
@defer.inlineCallbacks
def query_room_alias_exists(self, room_alias):
@@ -179,9 +186,7 @@ class ApplicationServicesHandler(object):
room_alias_str = room_alias.to_string()
services = self.store.get_app_services()
alias_query_services = [
- s for s in services if (
- s.is_interested_in_alias(room_alias_str)
- )
+ s for s in services if (s.is_interested_in_alias(room_alias_str))
]
for alias_service in alias_query_services:
is_known_alias = yield self.appservice_api.query_alias(
@@ -189,29 +194,31 @@ class ApplicationServicesHandler(object):
)
if is_known_alias:
# the alias exists now so don't query more ASes.
- result = yield self.store.get_association_from_room_alias(
- room_alias
- )
- defer.returnValue(result)
+ result = yield self.store.get_association_from_room_alias(room_alias)
+ return result
@defer.inlineCallbacks
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
- results = yield make_deferred_yieldable(defer.DeferredList([
- run_in_background(
- self.appservice_api.query_3pe,
- service, kind, protocol, fields,
+ results = yield make_deferred_yieldable(
+ defer.DeferredList(
+ [
+ run_in_background(
+ self.appservice_api.query_3pe, service, kind, protocol, fields
+ )
+ for service in services
+ ],
+ consumeErrors=True,
)
- for service in services
- ], consumeErrors=True))
+ )
ret = []
for (success, result) in results:
if success:
ret.extend(result)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_3pe_protocols(self, only_protocol=None):
@@ -250,7 +257,7 @@ class ApplicationServicesHandler(object):
for p in protocols.keys():
protocols[p] = _merge_instances(protocols[p])
- defer.returnValue(protocols)
+ return protocols
@defer.inlineCallbacks
def _get_services_for_event(self, event):
@@ -272,22 +279,16 @@ class ApplicationServicesHandler(object):
if (yield s.is_interested(event, self.store)):
interested_list.append(s)
- defer.returnValue(interested_list)
+ return interested_list
def _get_services_for_user(self, user_id):
services = self.store.get_app_services()
- interested_list = [
- s for s in services if (
- s.is_interested_in_user(user_id)
- )
- ]
+ interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
return defer.succeed(interested_list)
def _get_services_for_3pn(self, protocol):
services = self.store.get_app_services()
- interested_list = [
- s for s in services if s.is_interested_in_protocol(protocol)
- ]
+ interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
return defer.succeed(interested_list)
@defer.inlineCallbacks
@@ -295,23 +296,21 @@ class ApplicationServicesHandler(object):
if not self.is_mine_id(user_id):
# we don't know if they are unknown or not since it isn't one of our
# users. We can't poke ASes.
- defer.returnValue(False)
- return
+ return False
user_info = yield self.store.get_user_by_id(user_id)
if user_info:
- defer.returnValue(False)
- return
+ return False
# user not found; could be the AS though, so check.
services = self.store.get_app_services()
service_list = [s for s in services if s.sender == user_id]
- defer.returnValue(len(service_list) == 0)
+ return len(service_list) == 0
@defer.inlineCallbacks
def _check_user_exists(self, user_id):
unknown_user = yield self._is_unknown_user(user_id)
if unknown_user:
exists = yield self.query_user_exists(user_id)
- defer.returnValue(exists)
- defer.returnValue(True)
+ return exists
+ return True
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 9a2ff177a6..7860f9625e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -15,15 +15,16 @@
# limitations under the License.
import logging
+import time
import unicodedata
+import urllib.parse
+from typing import Any, Dict, Iterable, List, Optional
import attr
-import bcrypt
+import bcrypt # type: ignore[import]
import pymacaroons
-from canonicaljson import json
from twisted.internet import defer
-from twisted.web.client import PartialDownloadError
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
@@ -34,11 +35,17 @@ from synapse.api.errors import (
LoginError,
StoreError,
SynapseError,
+ UserDeactivatedError,
)
from synapse.api.ratelimiting import Ratelimiter
+from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
+from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http.server import finish_request
+from synapse.http.site import SynapseRequest
+from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi
-from synapse.types import UserID
-from synapse.util import logcontext
+from synapse.push.mailer import load_jinja2_templates
+from synapse.types import Requester, UserID
from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@@ -55,13 +62,13 @@ class AuthHandler(BaseHandler):
hs (synapse.server.HomeServer):
"""
super(AuthHandler, self).__init__(hs)
- self.checkers = {
- LoginType.RECAPTCHA: self._check_recaptcha,
- LoginType.EMAIL_IDENTITY: self._check_email_identity,
- LoginType.MSISDN: self._check_msisdn,
- LoginType.DUMMY: self._check_dummy_auth,
- LoginType.TERMS: self._check_terms_auth,
- }
+
+ self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
+ for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
+ inst = auth_checker_class(hs)
+ if inst.is_enabled():
+ self.checkers[inst.AUTH_TYPE] = inst # type: ignore
+
self.bcrypt_rounds = hs.config.bcrypt_rounds
# This is not a cache per se, but a store of all current sessions that
@@ -100,13 +107,26 @@ class AuthHandler(BaseHandler):
login_types.append(t)
self._supported_login_types = login_types
- self._account_ratelimiter = Ratelimiter()
- self._failed_attempts_ratelimiter = Ratelimiter()
+ # Ratelimiter for failed auth during UIA. Uses same ratelimit config
+ # as per `rc_login.failed_attempts`.
+ self._failed_uia_attempts_ratelimiter = Ratelimiter()
self._clock = self.hs.get_clock()
+ # Load the SSO redirect confirmation page HTML template
+ self._sso_redirect_confirm_template = load_jinja2_templates(
+ hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+ )[0]
+
+ self._server_name = hs.config.server_name
+
+ # cast to tuple for use with str.startswith
+ self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+
@defer.inlineCallbacks
- def validate_user_via_ui_auth(self, requester, request_body, clientip):
+ def validate_user_via_ui_auth(
+ self, requester: Requester, request_body: Dict[str, Any], clientip: str
+ ):
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -115,11 +135,11 @@ class AuthHandler(BaseHandler):
that it isn't stolen by re-authenticating them.
Args:
- requester (Requester): The user, as given by the access token
+ requester: The user, as given by the access token
- request_body (dict): The body of the request sent by the client
+ request_body: The body of the request sent by the client
- clientip (str): The IP address of the client.
+ clientip: The IP address of the client.
Returns:
defer.Deferred[dict]: the parameters for this request (which may
@@ -131,17 +151,39 @@ class AuthHandler(BaseHandler):
AuthError if the client has completed a login flow, and it gives
a different user to `requester`
+
+ LimitExceededError if the ratelimiter's failed request count for this
+ user is too high to proceed
+
"""
- # build a list of supported flows
- flows = [
- [login_type] for login_type in self._supported_login_types
- ]
+ user_id = requester.user.to_string()
- result, params, _ = yield self.check_auth(
- flows, request_body, clientip,
+ # Check if we should be ratelimited due to too many previous failed attempts
+ self._failed_uia_attempts_ratelimiter.ratelimit(
+ user_id,
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=False,
)
+ # build a list of supported flows
+ flows = [[login_type] for login_type in self._supported_login_types]
+
+ try:
+ result, params, _ = yield self.check_auth(flows, request_body, clientip)
+ except LoginError:
+ # Update the ratelimite to say we failed (`can_do_action` doesn't raise).
+ self._failed_uia_attempts_ratelimiter.can_do_action(
+ user_id,
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=True,
+ )
+ raise
+
# find the completed login type
for login_type in self._supported_login_types:
if login_type not in result:
@@ -151,18 +193,26 @@ class AuthHandler(BaseHandler):
break
else:
# this can't happen
- raise Exception(
- "check_auth returned True but no successful login type",
- )
+ raise Exception("check_auth returned True but no successful login type")
# check that the UI auth matched the access token
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
- defer.returnValue(params)
+ return params
+
+ def get_enabled_auth_types(self):
+ """Return the enabled user-interactive authentication types
+
+ Returns the UI-Auth types which are supported by the homeserver's current
+ config.
+ """
+ return self.checkers.keys()
@defer.inlineCallbacks
- def check_auth(self, flows, clientdict, clientip, password_servlet=False):
+ def check_auth(
+ self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
+ ):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
@@ -177,24 +227,14 @@ class AuthHandler(BaseHandler):
decorator.
Args:
- flows (list): A list of login flows. Each flow is an ordered list of
- strings representing auth-types. At least one full
- flow must be completed in order for auth to be successful.
+ flows: A list of login flows. Each flow is an ordered list of
+ strings representing auth-types. At least one full
+ flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
- clientip (str): The IP address of the client.
-
- password_servlet (bool): Whether the request originated from
- PasswordRestServlet.
- XXX: This is a temporary hack to distinguish between checking
- for threepid validations locally (in the case of password
- resets) and using the identity server (in the case of binding
- a 3PID during registration). Once we start using the
- homeserver for both tasks, this distinction will no longer be
- necessary.
-
+ clientip: The IP address of the client.
Returns:
defer.Deferred[dict, dict, str]: a deferred tuple of
@@ -214,12 +254,12 @@ class AuthHandler(BaseHandler):
"""
authdict = None
- sid = None
- if clientdict and 'auth' in clientdict:
- authdict = clientdict['auth']
- del clientdict['auth']
- if 'session' in authdict:
- sid = authdict['session']
+ sid = None # type: Optional[str]
+ if clientdict and "auth" in clientdict:
+ authdict = clientdict["auth"]
+ del clientdict["auth"]
+ if "session" in authdict:
+ sid = authdict["session"]
session = self._get_session_info(sid)
if len(clientdict) > 0:
@@ -229,31 +269,29 @@ class AuthHandler(BaseHandler):
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
- # on a home server.
+ # on a homeserver.
# Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary.
- session['clientdict'] = clientdict
+ session["clientdict"] = clientdict
self._save_session(session)
- elif 'clientdict' in session:
- clientdict = session['clientdict']
+ elif "clientdict" in session:
+ clientdict = session["clientdict"]
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session),
+ self._auth_dict_for_flows(flows, session)
)
- if 'creds' not in session:
- session['creds'] = {}
- creds = session['creds']
+ if "creds" not in session:
+ session["creds"] = {}
+ creds = session["creds"]
# check auth type currently being presented
- errordict = {}
- if 'type' in authdict:
- login_type = authdict['type']
+ errordict = {} # type: Dict[str, Any]
+ if "type" in authdict:
+ login_type = authdict["type"] # type: str
try:
- result = yield self._check_auth_dict(
- authdict, clientip, password_servlet=password_servlet,
- )
+ result = yield self._check_auth_dict(authdict, clientip)
if result:
creds[login_type] = result
self._save_session(session)
@@ -281,43 +319,40 @@ class AuthHandler(BaseHandler):
# and is not sensitive).
logger.info(
"Auth completed with creds: %r. Client dict has keys: %r",
- creds, list(clientdict)
+ creds,
+ list(clientdict),
)
- defer.returnValue((creds, clientdict, session['id']))
+ return creds, clientdict, session["id"]
ret = self._auth_dict_for_flows(flows, session)
- ret['completed'] = list(creds)
+ ret["completed"] = list(creds)
ret.update(errordict)
- raise InteractiveAuthIncompleteError(
- ret,
- )
+ raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks
- def add_oob_auth(self, stagetype, authdict, clientip):
+ def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
- if 'session' not in authdict:
+ if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
- sess = self._get_session_info(
- authdict['session']
- )
- if 'creds' not in sess:
- sess['creds'] = {}
- creds = sess['creds']
+ sess = self._get_session_info(authdict["session"])
+ if "creds" not in sess:
+ sess["creds"] = {}
+ creds = sess["creds"]
- result = yield self.checkers[stagetype](authdict, clientip)
+ result = yield self.checkers[stagetype].check_auth(authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
- def get_session_id(self, clientdict):
+ def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
"""
Gets the session ID for a client given the client dictionary
@@ -325,50 +360,52 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary sent by the client in the request
Returns:
- str|None: The string session ID the client sent. If the client did
+ The string session ID the client sent. If the client did
not send a session ID, returns None.
"""
sid = None
- if clientdict and 'auth' in clientdict:
- authdict = clientdict['auth']
- if 'session' in authdict:
- sid = authdict['session']
+ if clientdict and "auth" in clientdict:
+ authdict = clientdict["auth"]
+ if "session" in authdict:
+ sid = authdict["session"]
return sid
- def set_session_data(self, session_id, key, value):
+ def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
Args:
- session_id (string): The ID of this session as returned from check_auth
- key (string): The key to store the data under
- value (any): The data to store
+ session_id: The ID of this session as returned from check_auth
+ key: The key to store the data under
+ value: The data to store
"""
sess = self._get_session_info(session_id)
- sess.setdefault('serverdict', {})[key] = value
+ sess.setdefault("serverdict", {})[key] = value
self._save_session(sess)
- def get_session_data(self, session_id, key, default=None):
+ def get_session_data(
+ self, session_id: str, key: str, default: Optional[Any] = None
+ ) -> Any:
"""
Retrieve data stored with set_session_data
Args:
- session_id (string): The ID of this session as returned from check_auth
- key (string): The key to store the data under
- default (any): Value to return if the key has not been set
+ session_id: The ID of this session as returned from check_auth
+ key: The key to store the data under
+ default: Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
- return sess.setdefault('serverdict', {}).get(key, default)
+ return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks
- def _check_auth_dict(self, authdict, clientip, password_servlet=False):
+ def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
"""Attempt to validate the auth dict provided by a client
Args:
- authdict (object): auth dict provided by the client
- clientip (str): IP address of the client
+ authdict: auth dict provided by the client
+ clientip: IP address of the client
Returns:
Deferred: result of the stage verification.
@@ -378,17 +415,11 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
- login_type = authdict['type']
+ login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
- # XXX: Temporary workaround for having Synapse handle password resets
- # See AuthHandler.check_auth for further details
- res = yield checker(
- authdict,
- clientip=clientip,
- password_servlet=password_servlet,
- )
- defer.returnValue(res)
+ res = yield checker.check_auth(authdict, clientip=clientip)
+ return res
# build a v1-login-style dict out of the authdict and fall back to the
# v1 code
@@ -398,138 +429,31 @@ class AuthHandler(BaseHandler):
raise SynapseError(400, "", Codes.MISSING_PARAM)
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
- defer.returnValue(canonical_id)
-
- @defer.inlineCallbacks
- def _check_recaptcha(self, authdict, clientip, **kwargs):
- try:
- user_response = authdict["response"]
- except KeyError:
- # Client tried to provide captcha but didn't give the parameter:
- # bad request.
- raise LoginError(
- 400, "Captcha response is required",
- errcode=Codes.CAPTCHA_NEEDED
- )
-
- logger.info(
- "Submitting recaptcha response %s with remoteip %s",
- user_response, clientip
- )
-
- # TODO: get this from the homeserver rather than creating a new one for
- # each request
- try:
- client = self.hs.get_proxied_http_client()
- resp_body = yield client.post_urlencoded_get_json(
- self.hs.config.recaptcha_siteverify_api,
- args={
- 'secret': self.hs.config.recaptcha_private_key,
- 'response': user_response,
- 'remoteip': clientip,
- }
- )
- except PartialDownloadError as pde:
- # Twisted is silly
- data = pde.response
- resp_body = json.loads(data)
-
- if 'success' in resp_body:
- # Note that we do NOT check the hostname here: we explicitly
- # intend the CAPTCHA to be presented by whatever client the
- # user is using, we just care that they have completed a CAPTCHA.
- logger.info(
- "%s reCAPTCHA from hostname %s",
- "Successful" if resp_body['success'] else "Failed",
- resp_body.get('hostname')
- )
- if resp_body['success']:
- defer.returnValue(True)
- raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
-
- def _check_email_identity(self, authdict, **kwargs):
- return self._check_threepid('email', authdict, **kwargs)
-
- def _check_msisdn(self, authdict, **kwargs):
- return self._check_threepid('msisdn', authdict)
-
- def _check_dummy_auth(self, authdict, **kwargs):
- return defer.succeed(True)
-
- def _check_terms_auth(self, authdict, **kwargs):
- return defer.succeed(True)
-
- @defer.inlineCallbacks
- def _check_threepid(self, medium, authdict, password_servlet=False, **kwargs):
- if 'threepid_creds' not in authdict:
- raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
-
- threepid_creds = authdict['threepid_creds']
-
- identity_handler = self.hs.get_handlers().identity_handler
-
- logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
- if (
- not password_servlet
- or self.hs.config.email_password_reset_behaviour == "remote"
- ):
- threepid = yield identity_handler.threepid_from_creds(threepid_creds)
- elif self.hs.config.email_password_reset_behaviour == "local":
- row = yield self.store.get_threepid_validation_session(
- medium,
- threepid_creds["client_secret"],
- sid=threepid_creds["sid"],
- validated=True,
- )
-
- threepid = {
- "medium": row["medium"],
- "address": row["address"],
- "validated_at": row["validated_at"],
- } if row else None
-
- if row:
- # Valid threepid returned, delete from the db
- yield self.store.delete_threepid_session(threepid_creds["sid"])
- else:
- raise SynapseError(400, "Password resets are not enabled on this homeserver")
-
- if not threepid:
- raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
-
- if threepid['medium'] != medium:
- raise LoginError(
- 401,
- "Expecting threepid of type '%s', got '%s'" % (
- medium, threepid['medium'],
- ),
- errcode=Codes.UNAUTHORIZED
- )
-
- threepid['threepid_creds'] = authdict['threepid_creds']
+ return canonical_id
- defer.returnValue(threepid)
-
- def _get_params_recaptcha(self):
+ def _get_params_recaptcha(self) -> dict:
return {"public_key": self.hs.config.recaptcha_public_key}
- def _get_params_terms(self):
+ def _get_params_terms(self) -> dict:
return {
"policies": {
"privacy_policy": {
"version": self.hs.config.user_consent_version,
"en": {
"name": self.hs.config.user_consent_policy_name,
- "url": "%s_matrix/consent?v=%s" % (
+ "url": "%s_matrix/consent?v=%s"
+ % (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
},
- },
- },
+ }
+ }
}
- def _auth_dict_for_flows(self, flows, session):
+ def _auth_dict_for_flows(
+ self, flows: List[List[str]], session: Dict[str, Any]
+ ) -> Dict[str, Any]:
public_flows = []
for f in flows:
public_flows.append(f)
@@ -539,7 +463,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms,
}
- params = {}
+ params = {} # type: Dict[str, Any]
for f in public_flows:
for stage in f:
@@ -547,12 +471,18 @@ class AuthHandler(BaseHandler):
params[stage] = get_params[stage]()
return {
- "session": session['id'],
+ "session": session["id"],
"flows": [{"stages": f} for f in public_flows],
- "params": params
+ "params": params,
}
- def _get_session_info(self, session_id):
+ def _get_session_info(self, session_id: Optional[str]) -> dict:
+ """
+ Gets or creates a session given a session ID.
+
+ The session can be used to track data across multiple requests, e.g. for
+ interactive authentication.
+ """
if session_id not in self.sessions:
session_id = None
@@ -560,14 +490,14 @@ class AuthHandler(BaseHandler):
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
- self.sessions[session_id] = {
- "id": session_id,
- }
+ self.sessions[session_id] = {"id": session_id}
return self.sessions[session_id]
@defer.inlineCallbacks
- def get_access_token_for_user_id(self, user_id, device_id=None):
+ def get_access_token_for_user_id(
+ self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
+ ):
"""
Creates a new access token for the user with the given user ID.
@@ -577,19 +507,31 @@ class AuthHandler(BaseHandler):
The device will be recorded in the table if it is not there already.
Args:
- user_id (str): canonical User ID
- device_id (str|None): the device ID to associate with the tokens.
+ user_id: canonical User ID
+ device_id: the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
+ valid_until_ms: when the token is valid until. None for
+ no expiry.
Returns:
The access token for the user's session.
Raises:
StoreError if there was a problem storing the token.
"""
- logger.info("Logging in user %s on device %s", user_id, device_id)
- access_token = yield self.issue_access_token(user_id, device_id)
+ fmt_expiry = ""
+ if valid_until_ms is not None:
+ fmt_expiry = time.strftime(
+ " until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
+ )
+ logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
+
yield self.auth.check_auth_blocking(user_id)
+ access_token = self.macaroon_gen.generate_access_token(user_id)
+ yield self.store.add_access_token_to_user(
+ user_id, access_token, device_id, valid_until_ms
+ )
+
# the device *should* have been registered before we got here; however,
# it's possible we raced against a DELETE operation. The thing we
# really don't want is active access_tokens without a record of the
@@ -601,33 +543,31 @@ class AuthHandler(BaseHandler):
yield self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion")
- defer.returnValue(access_token)
+ return access_token
@defer.inlineCallbacks
- def check_user_exists(self, user_id):
+ def check_user_exists(self, user_id: str):
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
Args:
- (unicode|bytes) user_id: complete @user:id
+ user_id: complete @user:id
Returns:
defer.Deferred: (unicode) canonical_user_id, or None if zero or
multiple matches
Raises:
- LimitExceededError if the ratelimiter's login requests count for this
- user is too high too proceed.
+ UserDeactivatedError if a user is found but is deactivated.
"""
- self.ratelimit_login_per_account(user_id)
res = yield self._find_user_id_and_pwd_hash(user_id)
if res is not None:
- defer.returnValue(res[0])
- defer.returnValue(None)
+ return res[0]
+ return None
@defer.inlineCallbacks
- def _find_user_id_and_pwd_hash(self, user_id):
+ def _find_user_id_and_pwd_hash(self, user_id: str):
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
matches.
@@ -640,7 +580,7 @@ class AuthHandler(BaseHandler):
result = None
if not user_infos:
- logger.warn("Attempted to login as %s but they do not exist", user_id)
+ logger.warning("Attempted to login as %s but they do not exist", user_id)
elif len(user_infos) == 1:
# a single match (possibly not exact)
result = user_infos.popitem()
@@ -649,14 +589,15 @@ class AuthHandler(BaseHandler):
result = (user_id, user_infos[user_id])
else:
# multiple matches, none of them exact
- logger.warn(
+ logger.warning(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
- user_id, user_infos.keys()
+ user_id,
+ user_infos.keys(),
)
- defer.returnValue(result)
+ return result
- def get_supported_login_types(self):
+ def get_supported_login_types(self) -> Iterable[str]:
"""Get a the login types supported for the /login API
By default this is just 'm.login.password' (unless password_enabled is
@@ -664,20 +605,20 @@ class AuthHandler(BaseHandler):
other login types.
Returns:
- Iterable[str]: login types
+ login types
"""
return self._supported_login_types
@defer.inlineCallbacks
- def validate_login(self, username, login_submission):
+ def validate_login(self, username: str, login_submission: Dict[str, Any]):
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Args:
- username (str): username supplied by the user
- login_submission (dict): the whole of the login submission
+ username: username supplied by the user
+ login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
Deferred[str, func]: canonical user id, and optional callback
@@ -686,18 +627,12 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem accessing the database
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
- LimitExceededError if the ratelimiter's login requests count for this
- user is too high too proceed.
"""
- if username.startswith('@'):
+ if username.startswith("@"):
qualified_user_id = username
else:
- qualified_user_id = UserID(
- username, self.hs.hostname
- ).to_string()
-
- self.ratelimit_login_per_account(qualified_user_id)
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
login_type = login_submission.get("type")
known_login_type = False
@@ -713,17 +648,15 @@ class AuthHandler(BaseHandler):
raise SynapseError(400, "Missing parameter: password")
for provider in self.password_providers:
- if (hasattr(provider, "check_password")
- and login_type == LoginType.PASSWORD):
+ if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
- is_valid = yield provider.check_password(
- qualified_user_id, password,
- )
+ is_valid = yield provider.check_password(qualified_user_id, password)
if is_valid:
- defer.returnValue((qualified_user_id, None))
+ return qualified_user_id, None
- if (not hasattr(provider, "get_supported_login_types")
- or not hasattr(provider, "check_auth")):
+ if not hasattr(provider, "get_supported_login_types") or not hasattr(
+ provider, "check_auth"
+ ):
# this password provider doesn't understand custom login types
continue
@@ -744,56 +677,42 @@ class AuthHandler(BaseHandler):
login_dict[f] = login_submission[f]
if missing_fields:
raise SynapseError(
- 400, "Missing parameters for login type %s: %s" % (
- login_type,
- missing_fields,
- ),
+ 400,
+ "Missing parameters for login type %s: %s"
+ % (login_type, missing_fields),
)
- result = yield provider.check_auth(
- username, login_type, login_dict,
- )
+ result = yield provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
- defer.returnValue(result)
+ return result
- if login_type == LoginType.PASSWORD:
+ if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
canonical_user_id = yield self._check_local_password(
- qualified_user_id, password,
+ qualified_user_id, password
)
if canonical_user_id:
- defer.returnValue((canonical_user_id, None))
+ return canonical_user_id, None
if not known_login_type:
raise SynapseError(400, "Unknown login type %s" % login_type)
- # unknown username or invalid password.
- self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(), time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=True,
- )
-
# We raise a 403 here, but note that if we're doing user-interactive
# login, it turns all LoginErrors into a 401 anyway.
- raise LoginError(
- 403, "Invalid password",
- errcode=Codes.FORBIDDEN
- )
+ raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
- def check_password_provider_3pid(self, medium, address, password):
+ def check_password_provider_3pid(self, medium: str, address: str, password: str):
"""Check if a password provider is able to validate a thirdparty login
Args:
- medium (str): The medium of the 3pid (ex. email).
- address (str): The address of the 3pid (ex. jdoe@example.com).
- password (str): The password of the user.
+ medium: The medium of the 3pid (ex. email).
+ address: The address of the 3pid (ex. jdoe@example.com).
+ password: The password of the user.
Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id,
@@ -810,73 +729,67 @@ class AuthHandler(BaseHandler):
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
- result = yield provider.check_3pid_auth(
- medium, address, password,
- )
+ result = yield provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
# If it's a str, set callback function to None
result = (result, None)
- defer.returnValue(result)
+ return result
- defer.returnValue((None, None))
+ return None, None
@defer.inlineCallbacks
- def _check_local_password(self, user_id, password):
+ def _check_local_password(self, user_id: str, password: str):
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are
multiple inexact matches.
Args:
- user_id (unicode): complete @user:id
- password (unicode): the provided password
+ user_id: complete @user:id
+ password: the provided password
Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
-
- Raises:
- LimitExceededError if the ratelimiter's login requests count for this
- user is too high too proceed.
"""
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
- defer.returnValue(None)
+ return None
(user_id, password_hash) = lookupres
+
+ # If the password hash is None, the account has likely been deactivated
+ if not password_hash:
+ deactivated = yield self.store.get_user_deactivated_status(user_id)
+ if deactivated:
+ raise UserDeactivatedError("This account has been deactivated")
+
result = yield self.validate_hash(password, password_hash)
if not result:
- logger.warn("Failed password login for user %s", user_id)
- defer.returnValue(None)
- defer.returnValue(user_id)
-
- @defer.inlineCallbacks
- def issue_access_token(self, user_id, device_id=None):
- access_token = self.macaroon_gen.generate_access_token(user_id)
- yield self.store.add_access_token_to_user(user_id, access_token,
- device_id)
- defer.returnValue(access_token)
+ logger.warning("Failed password login for user %s", user_id)
+ return None
+ return user_id
@defer.inlineCallbacks
- def validate_short_term_login_token_and_get_user_id(self, login_token):
+ def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth()
user_id = None
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
user_id = auth_api.get_user_id_from_macaroon(macaroon)
- auth_api.validate_macaroon(macaroon, "login", True, user_id)
+ auth_api.validate_macaroon(macaroon, "login", user_id)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
- self.ratelimit_login_per_account(user_id)
+
yield self.auth.check_auth_blocking(user_id)
- defer.returnValue(user_id)
+ return user_id
@defer.inlineCallbacks
- def delete_access_token(self, access_token):
+ def delete_access_token(self, access_token: str):
"""Invalidate a single access token
Args:
- access_token (str): access token to be deleted
+ access_token: access token to be deleted
Returns:
Deferred
@@ -896,26 +809,29 @@ class AuthHandler(BaseHandler):
# delete pushers associated with this access token
if user_info["token_id"] is not None:
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
- str(user_info["user"]), (user_info["token_id"], )
+ str(user_info["user"]), (user_info["token_id"],)
)
@defer.inlineCallbacks
- def delete_access_tokens_for_user(self, user_id, except_token_id=None,
- device_id=None):
+ def delete_access_tokens_for_user(
+ self,
+ user_id: str,
+ except_token_id: Optional[str] = None,
+ device_id: Optional[str] = None,
+ ):
"""Invalidate access tokens belonging to a user
Args:
- user_id (str): ID of user the tokens belong to
- except_token_id (str|None): access_token ID which should *not* be
- deleted
- device_id (str|None): ID of device the tokens are associated with.
+ user_id: ID of user the tokens belong to
+ except_token_id: access_token ID which should *not* be deleted
+ device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
Deferred
"""
tokens_and_devices = yield self.store.user_delete_access_tokens(
- user_id, except_token_id=except_token_id, device_id=device_id,
+ user_id, except_token_id=except_token_id, device_id=device_id
)
# see if any of our auth providers want to know about this
@@ -923,20 +839,26 @@ class AuthHandler(BaseHandler):
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
yield provider.on_logged_out(
- user_id=user_id,
- device_id=device_id,
- access_token=token,
+ user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
yield self.hs.get_pusherpool().remove_pushers_by_access_token(
- user_id, (token_id for _, token_id, _ in tokens_and_devices),
+ user_id, (token_id for _, token_id, _ in tokens_and_devices)
)
@defer.inlineCallbacks
- def add_threepid(self, user_id, medium, address, validated_at):
+ def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
+ # check if medium has a valid value
+ if medium not in ["email", "msisdn"]:
+ raise SynapseError(
+ code=400,
+ msg=("'%s' is not a valid value for 'medium'" % (medium,)),
+ errcode=Codes.INVALID_PARAM,
+ )
+
# 'Canonicalise' email addresses down to lower case.
- # We've now moving towards the Home Server being the entity that
+ # We've now moving towards the homeserver being the entity that
# is responsible for validating threepids used for resetting passwords
# on accounts, so in future Synapse will gain knowledge of specific
# types (mediums) of threepid. For now, we still use the existing
@@ -944,28 +866,28 @@ class AuthHandler(BaseHandler):
# of specific types of threepid (and fixes the fact that checking
# for the presence of an email address during password reset was
# case sensitive).
- if medium == 'email':
+ if medium == "email":
address = address.lower()
yield self.store.user_add_threepid(
- user_id, medium, address, validated_at,
- self.hs.get_clock().time_msec()
+ user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
@defer.inlineCallbacks
- def delete_threepid(self, user_id, medium, address, id_server=None):
+ def delete_threepid(
+ self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
+ ):
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
Args:
- user_id (str)
- medium (str)
- address (str)
- id_server (str|None): Use the given identity server when unbinding
+ user_id: ID of user to remove the 3pid from.
+ medium: The medium of the 3pid being removed: "email" or "msisdn".
+ address: The 3pid address to remove.
+ id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
-
Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the
@@ -973,133 +895,156 @@ class AuthHandler(BaseHandler):
"""
# 'Canonicalise' email addresses as per above
- if medium == 'email':
+ if medium == "email":
address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler
result = yield identity_handler.try_unbind_threepid(
- user_id,
- {
- 'medium': medium,
- 'address': address,
- 'id_server': id_server,
- },
+ user_id, {"medium": medium, "address": address, "id_server": id_server}
)
- yield self.store.user_delete_threepid(
- user_id, medium, address,
- )
- defer.returnValue(result)
+ yield self.store.user_delete_threepid(user_id, medium, address)
+ return result
- def _save_session(self, session):
+ def _save_session(self, session: Dict[str, Any]) -> None:
+ """Update the last used time on the session to now and add it back to the session store."""
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
- def hash(self, password):
+ def hash(self, password: str):
"""Computes a secure hash of password.
Args:
- password (unicode): Password to hash.
+ password: Password to hash.
Returns:
Deferred(unicode): Hashed password.
"""
+
def _do_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.hashpw(
- pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
+ pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
bcrypt.gensalt(self.bcrypt_rounds),
- ).decode('ascii')
+ ).decode("ascii")
- return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash)
+ return defer_to_thread(self.hs.get_reactor(), _do_hash)
- def validate_hash(self, password, stored_hash):
+ def validate_hash(self, password: str, stored_hash: bytes):
"""Validates that self.hash(password) == stored_hash.
Args:
- password (unicode): Password to hash.
- stored_hash (bytes): Expected hash value.
+ password: Password to hash.
+ stored_hash: Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
"""
+
def _do_validate_hash():
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
- pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
- stored_hash
+ pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
+ stored_hash,
)
if stored_hash:
if not isinstance(stored_hash, bytes):
- stored_hash = stored_hash.encode('ascii')
+ stored_hash = stored_hash.encode("ascii")
- return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
+ return defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else:
return defer.succeed(False)
- def ratelimit_login_per_account(self, user_id):
- """Checks whether the process must be stopped because of ratelimiting.
-
- Checks against two ratelimiters: the generic one for login attempts per
- account and the one specific to failed attempts.
+ def complete_sso_login(
+ self,
+ registered_user_id: str,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ):
+ """Having figured out a mxid for this user, complete the HTTP request
Args:
- user_id (unicode): complete @user:id
-
- Raises:
- LimitExceededError if one of the ratelimiters' login requests count
- for this user is too high too proceed.
+ registered_user_id: The registered user ID to complete SSO login for.
+ request: The request to complete.
+ client_redirect_url: The URL to which to redirect the user at the end of the
+ process.
"""
- self._failed_attempts_ratelimiter.ratelimit(
- user_id.lower(), time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=False,
+ # Create a login token
+ login_token = self.macaroon_gen.generate_short_term_login_token(
+ registered_user_id
)
- self._account_ratelimiter.ratelimit(
- user_id.lower(), time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_account.per_second,
- burst_count=self.hs.config.rc_login_account.burst_count,
- update=True,
+ # Append the login token to the original redirect URL (i.e. with its query
+ # parameters kept intact) to build the URL to which the template needs to
+ # redirect the users once they have clicked on the confirmation link.
+ redirect_url = self.add_query_param_to_url(
+ client_redirect_url, "loginToken", login_token
)
+ # if the client is whitelisted, we can redirect straight to it
+ if client_redirect_url.startswith(self._whitelisted_sso_clients):
+ request.redirect(redirect_url)
+ finish_request(request)
+ return
+
+ # Otherwise, serve the redirect confirmation page.
+
+ # Remove the query parameters from the redirect URL to get a shorter version of
+ # it. This is only to display a human-readable URL in the template, but not the
+ # URL we redirect users to.
+ redirect_url_no_params = client_redirect_url.split("?")[0]
+
+ html = self._sso_redirect_confirm_template.render(
+ display_url=redirect_url_no_params,
+ redirect_url=redirect_url,
+ server_name=self._server_name,
+ ).encode("utf-8")
+
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(html),))
+ request.write(html)
+ finish_request(request)
+
+ @staticmethod
+ def add_query_param_to_url(url: str, param_name: str, param: Any):
+ url_parts = list(urllib.parse.urlparse(url))
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
+ query.update({param_name: param})
+ url_parts[4] = urllib.parse.urlencode(query)
+ return urllib.parse.urlunparse(url_parts)
+
@attr.s
class MacaroonGenerator(object):
hs = attr.ib()
- def generate_access_token(self, user_id, extra_caveats=None):
+ def generate_access_token(
+ self, user_id: str, extra_caveats: Optional[List[str]] = None
+ ) -> str:
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
# Include a nonce, to make sure that each login gets a different
# access token.
- macaroon.add_first_party_caveat("nonce = %s" % (
- stringutils.random_string_with_symbols(16),
- ))
+ macaroon.add_first_party_caveat(
+ "nonce = %s" % (stringutils.random_string_with_symbols(16),)
+ )
for caveat in extra_caveats:
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
- def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
- """
-
- Args:
- user_id (unicode):
- duration_in_ms (int):
-
- Returns:
- unicode
- """
+ def generate_short_term_login_token(
+ self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ ) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
@@ -1107,16 +1052,17 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
- def generate_delete_pusher_token(self, user_id):
+ def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
- def _generate_base_macaroon(self, user_id):
+ def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
- key=self.hs.config.macaroon_secret_key)
+ key=self.hs.config.macaroon_secret_key,
+ )
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 32e004e53e..f624c2a3f9 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -15,8 +15,6 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester
@@ -28,6 +26,7 @@ logger = logging.getLogger(__name__)
class DeactivateAccountHandler(BaseHandler):
"""Handler which deals with deactivating user accounts."""
+
def __init__(self, hs):
super(DeactivateAccountHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
@@ -46,8 +45,7 @@ class DeactivateAccountHandler(BaseHandler):
self._account_validity_enabled = hs.config.account_validity.enabled
- @defer.inlineCallbacks
- def deactivate_account(self, user_id, erase_data, id_server=None):
+ async def deactivate_account(self, user_id, erase_data, id_server=None):
"""Deactivate a user's account
Args:
@@ -73,15 +71,17 @@ class DeactivateAccountHandler(BaseHandler):
# unbinding
identity_server_supports_unbinding = True
- threepids = yield self.store.user_get_threepids(user_id)
+ # Retrieve the 3PIDs this user has bound to an identity server
+ threepids = await self.store.user_get_bound_threepids(user_id)
+
for threepid in threepids:
try:
- result = yield self._identity_handler.try_unbind_threepid(
+ result = await self._identity_handler.try_unbind_threepid(
user_id,
{
- 'medium': threepid['medium'],
- 'address': threepid['address'],
- 'id_server': id_server,
+ "medium": threepid["medium"],
+ "address": threepid["address"],
+ "id_server": id_server,
},
)
identity_server_supports_unbinding &= result
@@ -89,33 +89,36 @@ class DeactivateAccountHandler(BaseHandler):
# Do we want this to be a fatal error or should we carry on?
logger.exception("Failed to remove threepid from ID server")
raise SynapseError(400, "Failed to remove threepid from ID server")
- yield self.store.user_delete_threepid(
- user_id, threepid['medium'], threepid['address'],
+ await self.store.user_delete_threepid(
+ user_id, threepid["medium"], threepid["address"]
)
+ # Remove all 3PIDs this user has bound to the homeserver
+ await self.store.user_delete_threepids(user_id)
+
# delete any devices belonging to the user, which will also
# delete corresponding access tokens.
- yield self._device_handler.delete_all_devices_for_user(user_id)
+ await self._device_handler.delete_all_devices_for_user(user_id)
# then delete any remaining access tokens which weren't associated with
# a device.
- yield self._auth_handler.delete_access_tokens_for_user(user_id)
+ await self._auth_handler.delete_access_tokens_for_user(user_id)
- yield self.store.user_set_password_hash(user_id, None)
+ await self.store.user_set_password_hash(user_id, None)
user = UserID.from_string(user_id)
- yield self._profile_handler.set_active(user, False, False)
+ await self._profile_handler.set_active(user, False, False)
# Add the user to a table of users pending deactivation (ie.
# removal from all the rooms they're a member of)
- yield self.store.add_user_pending_deactivation(user_id)
+ await self.store.add_user_pending_deactivation(user_id)
# delete from user directory
- yield self.user_directory_handler.handle_user_deactivated(user_id)
+ await self.user_directory_handler.handle_user_deactivated(user_id)
# Mark the user as erased, if they asked for that
if erase_data:
logger.info("Marking %s as erased", user_id)
- yield self.store.mark_user_erased(user_id)
+ await self.store.mark_user_erased(user_id)
# Now start the process that goes through that list and
# parts users from rooms (if it isn't already running)
@@ -123,30 +126,29 @@ class DeactivateAccountHandler(BaseHandler):
# Reject all pending invites for the user, so that the user doesn't show up in the
# "invited" section of rooms' members list.
- yield self._reject_pending_invites_for_user(user_id)
+ await self._reject_pending_invites_for_user(user_id)
# Remove all information on the user from the account_validity table.
if self._account_validity_enabled:
- yield self.store.delete_account_validity_for_user(user_id)
+ await self.store.delete_account_validity_for_user(user_id)
# Mark the user as deactivated.
- yield self.store.set_user_deactivated_status(user_id, True)
+ await self.store.set_user_deactivated_status(user_id, True)
- defer.returnValue(identity_server_supports_unbinding)
+ return identity_server_supports_unbinding
- @defer.inlineCallbacks
- def _reject_pending_invites_for_user(self, user_id):
+ async def _reject_pending_invites_for_user(self, user_id):
"""Reject pending invites addressed to a given user ID.
Args:
user_id (str): The user ID to reject pending invites for.
"""
user = UserID.from_string(user_id)
- pending_invites = yield self.store.get_invited_rooms_for_user(user_id)
+ pending_invites = await self.store.get_invited_rooms_for_local_user(user_id)
for room in pending_invites:
try:
- yield self._room_member_handler.update_membership(
+ await self._room_member_handler.update_membership(
create_requester(user),
user,
room.room_id,
@@ -178,8 +180,7 @@ class DeactivateAccountHandler(BaseHandler):
if not self._user_parter_running:
run_as_background_process("user_parter_loop", self._user_parter_loop)
- @defer.inlineCallbacks
- def _user_parter_loop(self):
+ async def _user_parter_loop(self):
"""Loop that parts deactivated users from rooms
Returns:
@@ -189,19 +190,18 @@ class DeactivateAccountHandler(BaseHandler):
logger.info("Starting user parter")
try:
while True:
- user_id = yield self.store.get_user_pending_deactivation()
+ user_id = await self.store.get_user_pending_deactivation()
if user_id is None:
break
logger.info("User parter parting %r", user_id)
- yield self._part_user(user_id)
- yield self.store.del_user_pending_deactivation(user_id)
+ await self._part_user(user_id)
+ await self.store.del_user_pending_deactivation(user_id)
logger.info("User parter finished parting %r", user_id)
logger.info("User parter finished: stopping")
finally:
self._user_parter_running = False
- @defer.inlineCallbacks
- def _part_user(self, user_id):
+ async def _part_user(self, user_id):
"""Causes the given user_id to leave all the rooms they're joined to
Returns:
@@ -209,11 +209,11 @@ class DeactivateAccountHandler(BaseHandler):
"""
user = UserID.from_string(user_id)
- rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+ rooms_for_user = await self.store.get_rooms_for_user(user_id)
for room_id in rooms_for_user:
logger.info("User parter parting %r from %r", user_id, room_id)
try:
- yield self._room_member_handler.update_membership(
+ await self._room_member_handler.update_membership(
create_requester(user),
user,
room_id,
@@ -224,5 +224,6 @@ class DeactivateAccountHandler(BaseHandler):
except Exception:
logger.exception(
"Failed to part user %r from room %r: ignoring and continuing",
- user_id, room_id,
+ user_id,
+ room_id,
)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d69fc8b061..993499f446 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +26,9 @@ from synapse.api.errors import (
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
+ SynapseError,
)
+from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
@@ -36,6 +40,8 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
+MAX_DEVICE_DISPLAY_NAME_LEN = 100
+
class DeviceWorkerHandler(BaseHandler):
def __init__(self, hs):
@@ -43,8 +49,10 @@ class DeviceWorkerHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
+ self.state_store = hs.get_storage().state
self._auth_handler = hs.get_auth_handler()
+ @trace
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
"""
@@ -56,18 +64,19 @@ class DeviceWorkerHandler(BaseHandler):
defer.Deferred: list[dict[str, X]]: info on each device
"""
+ set_tag("user_id", user_id)
device_map = yield self.store.get_devices_by_user(user_id)
- ips = yield self.store.get_last_client_ip_by_device(
- user_id, device_id=None
- )
+ ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None)
devices = list(device_map.values())
for device in devices:
_update_device_from_client_ips(device, ips)
- defer.returnValue(devices)
+ log_kv(device_map)
+ return devices
+ @trace
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
""" Retrieve the given device
@@ -85,13 +94,16 @@ class DeviceWorkerHandler(BaseHandler):
device = yield self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
- ips = yield self.store.get_last_client_ip_by_device(
- user_id, device_id,
- )
+ ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
- defer.returnValue(device)
+
+ set_tag("device", device)
+ set_tag("ips", ips)
+
+ return device
@measure_func("device.get_user_ids_changed")
+ @trace
@defer.inlineCallbacks
def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
@@ -101,26 +113,37 @@ class DeviceWorkerHandler(BaseHandler):
user_id (str)
from_token (StreamToken)
"""
+
+ set_tag("user_id", user_id)
+ set_tag("from_token", from_token)
now_room_key = yield self.store.get_room_events_max_id()
room_ids = yield self.store.get_rooms_for_user(user_id)
- # First we check if any devices have changed
- changed = yield self.store.get_user_whose_devices_changed(
- from_token.device_list_key
+ # First we check if any devices have changed for users that we share
+ # rooms with.
+ users_who_share_room = yield self.store.get_users_who_share_room_with_user(
+ user_id
+ )
+
+ tracked_users = set(users_who_share_room)
+
+ # Always tell the user about their own devices
+ tracked_users.add(user_id)
+
+ changed = yield self.store.get_users_whose_devices_changed(
+ from_token.device_list_key, tracked_users
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user(
- user_id, from_token.room_key, now_room_key,
+ user_id, from_token.room_key, now_room_key
)
rooms_changed.update(event.room_id for event in member_events)
- stream_ordering = RoomStreamToken.parse_stream_token(
- from_token.room_key
- ).stream
+ stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
possibly_changed = set(changed)
possibly_left = set()
@@ -150,6 +173,9 @@ class DeviceWorkerHandler(BaseHandler):
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
+ log_kv(
+ {"event": "encountered empty previous state", "room_id": room_id}
+ )
for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
@@ -162,7 +188,7 @@ class DeviceWorkerHandler(BaseHandler):
continue
# mapping from event_id -> state_dict
- prev_state_ids = yield self.store.get_state_ids_for_events(event_ids)
+ prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
@@ -194,10 +220,6 @@ class DeviceWorkerHandler(BaseHandler):
break
if possibly_changed or possibly_left:
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id
- )
-
# Take the intersection of the users whose devices may have changed
# and those that actually still share a room with the user
possibly_joined = possibly_changed & users_who_share_room
@@ -206,10 +228,27 @@ class DeviceWorkerHandler(BaseHandler):
possibly_joined = []
possibly_left = []
- defer.returnValue({
- "changed": list(possibly_joined),
- "left": list(possibly_left),
- })
+ result = {"changed": list(possibly_joined), "left": list(possibly_left)}
+
+ log_kv(result)
+
+ return result
+
+ @defer.inlineCallbacks
+ def on_federation_query_user_devices(self, user_id):
+ stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
+ master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+ self_signing_key = yield self.store.get_e2e_cross_signing_key(
+ user_id, "self_signing"
+ )
+
+ return {
+ "user_id": user_id,
+ "stream_id": stream_id,
+ "devices": devices,
+ "master_key": master_key,
+ "self_signing_key": self_signing_key,
+ }
class DeviceHandler(DeviceWorkerHandler):
@@ -218,22 +257,20 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender = hs.get_federation_sender()
- self._edu_updater = DeviceListEduUpdater(hs, self)
+ self.device_list_updater = DeviceListUpdater(hs, self)
federation_registry = hs.get_federation_registry()
federation_registry.register_edu_handler(
- "m.device_list_update", self._edu_updater.incoming_device_list_update,
- )
- federation_registry.register_query_handler(
- "user_devices", self.on_federation_query_user_devices,
+ "m.device_list_update", self.device_list_updater.incoming_device_list_update
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks
- def check_device_registered(self, user_id, device_id,
- initial_device_display_name=None):
+ def check_device_registered(
+ self, user_id, device_id, initial_device_display_name=None
+ ):
"""
If the given device has not been registered, register it with the
supplied display name.
@@ -256,7 +293,7 @@ class DeviceHandler(DeviceWorkerHandler):
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
- defer.returnValue(device_id)
+ return device_id
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
@@ -270,11 +307,12 @@ class DeviceHandler(DeviceWorkerHandler):
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
- defer.returnValue(device_id)
+ return device_id
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
+ @trace
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
""" Delete the given device
@@ -292,20 +330,23 @@ class DeviceHandler(DeviceWorkerHandler):
except errors.StoreError as e:
if e.code == 404:
# no match
+ set_tag("error", True)
+ log_kv(
+ {"reason": "User doesn't have device id.", "device_id": device_id}
+ )
pass
else:
raise
yield self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id,
+ user_id, device_id=device_id
)
- yield self.store.delete_e2e_keys_by_device(
- user_id=user_id, device_id=device_id
- )
+ yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
yield self.notify_device_update(user_id, [device_id])
+ @trace
@defer.inlineCallbacks
def delete_all_devices_for_user(self, user_id, except_device_id=None):
"""Delete all of the user's devices
@@ -341,6 +382,8 @@ class DeviceHandler(DeviceWorkerHandler):
except errors.StoreError as e:
if e.code == 404:
# no match
+ set_tag("error", True)
+ set_tag("reason", "User doesn't have that device id.")
pass
else:
raise
@@ -349,7 +392,7 @@ class DeviceHandler(DeviceWorkerHandler):
# considered as part of a critical path.
for device_id in device_ids:
yield self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id,
+ user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
@@ -370,11 +413,18 @@ class DeviceHandler(DeviceWorkerHandler):
defer.Deferred:
"""
+ # Reject a new displayname which is too long.
+ new_display_name = content.get("display_name")
+ if new_display_name and len(new_display_name) > MAX_DEVICE_DISPLAY_NAME_LEN:
+ raise SynapseError(
+ 400,
+ "Device display name is too long (max %i)"
+ % (MAX_DEVICE_DISPLAY_NAME_LEN,),
+ )
+
try:
yield self.store.update_device(
- user_id,
- device_id,
- new_display_name=content.get("display_name")
+ user_id, device_id, new_display_name=new_display_name
)
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError as e:
@@ -383,6 +433,7 @@ class DeviceHandler(DeviceWorkerHandler):
else:
raise
+ @trace
@measure_func("notify_device_update")
@defer.inlineCallbacks
def notify_device_update(self, user_id, device_ids):
@@ -398,35 +449,47 @@ class DeviceHandler(DeviceWorkerHandler):
hosts.update(get_domain_from_id(u) for u in users_who_share_room)
hosts.discard(self.server_name)
+ set_tag("target_hosts", hosts)
+
position = yield self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
)
for device_id in device_ids:
logger.debug(
- "Notifying about update %r/%r, ID: %r", user_id, device_id,
- position,
+ "Notifying about update %r/%r, ID: %r", user_id, device_id, position
)
room_ids = yield self.store.get_rooms_for_user(user_id)
+ # specify the user ID too since the user should always get their own device list
+ # updates, even if they aren't in any rooms.
yield self.notifier.on_new_event(
- "device_list_key", position, rooms=room_ids,
+ "device_list_key", position, users=[user_id], rooms=room_ids
)
if hosts:
- logger.info("Sending device list update notif for %r to: %r", user_id, hosts)
+ logger.info(
+ "Sending device list update notif for %r to: %r", user_id, hosts
+ )
for host in hosts:
self.federation_sender.send_device_messages(host)
+ log_kv({"message": "sent device update to host", "host": host})
@defer.inlineCallbacks
- def on_federation_query_user_devices(self, user_id):
- stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
- defer.returnValue({
- "user_id": user_id,
- "stream_id": stream_id,
- "devices": devices,
- })
+ def notify_user_signature_update(self, from_user_id, user_ids):
+ """Notify a user that they have made new signatures of other users.
+
+ Args:
+ from_user_id (str): the user who made the signature
+ user_ids (list[str]): the users IDs that have new signatures
+ """
+
+ position = yield self.store.add_user_signature_change_to_streams(
+ from_user_id, user_ids
+ )
+
+ self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
@@ -440,13 +503,10 @@ class DeviceHandler(DeviceWorkerHandler):
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})
- device.update({
- "last_seen_ts": ip.get("last_seen"),
- "last_seen_ip": ip.get("ip"),
- })
+ device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
-class DeviceListEduUpdater(object):
+class DeviceListUpdater(object):
"Handles incoming device list updates from federation and updates the DB"
def __init__(self, hs, device_handler):
@@ -471,23 +531,38 @@ class DeviceListEduUpdater(object):
iterable=True,
)
+ @trace
@defer.inlineCallbacks
def incoming_device_list_update(self, origin, edu_content):
"""Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list.
"""
+ set_tag("origin", origin)
+ set_tag("edu_content", edu_content)
user_id = edu_content.pop("user_id")
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints
prev_ids = edu_content.pop("prev_id", [])
- prev_ids = [str(p) for p in prev_ids] # They may come as ints
+ prev_ids = [str(p) for p in prev_ids] # They may come as ints
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning(
"Got device list update edu for %r/%r from %r",
- user_id, device_id, origin,
+ user_id,
+ device_id,
+ origin,
+ )
+
+ set_tag("error", True)
+ log_kv(
+ {
+ "message": "Got a device list update edu from a user and "
+ "device which does not match the origin of the request.",
+ "user_id": user_id,
+ "device_id": device_id,
+ }
)
return
@@ -495,15 +570,22 @@ class DeviceListEduUpdater(object):
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
+ set_tag("error", True)
+ log_kv(
+ {
+ "message": "Got an update from a user for which "
+ "we don't share any rooms",
+ "other user_id": user_id,
+ }
+ )
logger.warning(
"Got device list update edu for %r/%r, but don't share a room",
- user_id, device_id,
+ user_id,
+ device_id,
)
return
- logger.debug(
- "Received device list update for %r/%r", user_id, device_id,
- )
+ logger.debug("Received device list update for %r/%r", user_id, device_id)
self._pending_updates.setdefault(user_id, []).append(
(device_id, stream_id, prev_ids, edu_content)
@@ -525,94 +607,32 @@ class DeviceListEduUpdater(object):
for device_id, stream_id, prev_ids, content in pending_updates:
logger.debug(
"Handling update %r/%r, ID: %r, prev: %r ",
- user_id, device_id, stream_id, prev_ids,
+ user_id,
+ device_id,
+ stream_id,
+ prev_ids,
)
# Given a list of updates we check if we need to resync. This
# happens if we've missed updates.
resync = yield self._need_to_do_resync(user_id, pending_updates)
- logger.debug("Need to re-sync devices for %r? %r", user_id, resync)
-
- if resync:
- # Fetch all devices for the user.
- origin = get_domain_from_id(user_id)
- try:
- result = yield self.federation.query_user_devices(origin, user_id)
- except (
- NotRetryingDestination, RequestSendFailed, HttpResponseException,
- ):
- # TODO: Remember that we are now out of sync and try again
- # later
- logger.warn(
- "Failed to handle device list update for %s", user_id,
- )
- # We abort on exceptions rather than accepting the update
- # as otherwise synapse will 'forget' that its device list
- # is out of date. If we bail then we will retry the resync
- # next time we get a device list update for this user_id.
- # This makes it more likely that the device lists will
- # eventually become consistent.
- return
- except FederationDeniedError as e:
- logger.info(e)
- return
- except Exception:
- # TODO: Remember that we are now out of sync and try again
- # later
- logger.exception(
- "Failed to handle device list update for %s", user_id
- )
- return
-
- stream_id = result["stream_id"]
- devices = result["devices"]
-
- for device in devices:
- logger.debug(
- "Handling resync update %r/%r, ID: %r",
- user_id, device["device_id"], stream_id,
- )
-
- # If the remote server has more than ~1000 devices for this user
- # we assume that something is going horribly wrong (e.g. a bot
- # that logs in and creates a new device every time it tries to
- # send a message). Maintaining lots of devices per user in the
- # cache can cause serious performance issues as if this request
- # takes more than 60s to complete, internal replication from the
- # inbound federation worker to the synapse master may time out
- # causing the inbound federation to fail and causing the remote
- # server to retry, causing a DoS. So in this scenario we give
- # up on storing the total list of devices and only handle the
- # delta instead.
- if len(devices) > 1000:
- logger.warn(
- "Ignoring device list snapshot for %s as it has >1K devs (%d)",
- user_id, len(devices)
- )
- devices = []
-
- for device in devices:
- logger.debug(
- "Handling resync update %r/%r, ID: %r",
- user_id, device["device_id"], stream_id,
- )
-
- yield self.store.update_remote_device_list_cache(
- user_id, devices, stream_id,
+ if logger.isEnabledFor(logging.INFO):
+ logger.info(
+ "Received device list update for %s, requiring resync: %s. Devices: %s",
+ user_id,
+ resync,
+ ", ".join(u[0] for u in pending_updates),
)
- device_ids = [device["device_id"] for device in devices]
- yield self.device_handler.notify_device_update(user_id, device_ids)
- # We clobber the seen updates since we've re-synced from a given
- # point.
- self._seen_updates[user_id] = set([stream_id])
+ if resync:
+ yield self.user_device_resync(user_id)
else:
# Simply update the single device, since we know that is the only
# change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
- user_id, device_id, content, stream_id,
+ user_id, device_id, content, stream_id
)
yield self.device_handler.notify_device_update(
@@ -630,20 +650,15 @@ class DeviceListEduUpdater(object):
"""
seen_updates = self._seen_updates.get(user_id, set())
- extremity = yield self.store.get_device_list_last_stream_id_for_remote(
- user_id
- )
+ extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)
- logger.debug(
- "Current extremity for %r: %r",
- user_id, extremity,
- )
+ logger.debug("Current extremity for %r: %r", user_id, extremity)
stream_id_in_updates = set() # stream_ids in updates list
for _, stream_id, prev_ids, _ in updates:
if not prev_ids:
# We always do a resync if there are no previous IDs
- defer.returnValue(True)
+ return True
for prev_id in prev_ids:
if prev_id == extremity:
@@ -653,8 +668,90 @@ class DeviceListEduUpdater(object):
elif prev_id in stream_id_in_updates:
continue
else:
- defer.returnValue(True)
+ return True
stream_id_in_updates.add(stream_id)
- defer.returnValue(False)
+ return False
+
+ @defer.inlineCallbacks
+ def user_device_resync(self, user_id):
+ """Fetches all devices for a user and updates the device cache with them.
+
+ Args:
+ user_id (str): The user's id whose device_list will be updated.
+ Returns:
+ Deferred[dict]: a dict with device info as under the "devices" in the result of this
+ request:
+ https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+ """
+ log_kv({"message": "Doing resync to update device list."})
+ # Fetch all devices for the user.
+ origin = get_domain_from_id(user_id)
+ try:
+ result = yield self.federation.query_user_devices(origin, user_id)
+ except (NotRetryingDestination, RequestSendFailed, HttpResponseException):
+ # TODO: Remember that we are now out of sync and try again
+ # later
+ logger.warning("Failed to handle device list update for %s", user_id)
+ # We abort on exceptions rather than accepting the update
+ # as otherwise synapse will 'forget' that its device list
+ # is out of date. If we bail then we will retry the resync
+ # next time we get a device list update for this user_id.
+ # This makes it more likely that the device lists will
+ # eventually become consistent.
+ return
+ except FederationDeniedError as e:
+ set_tag("error", True)
+ log_kv({"reason": "FederationDeniedError"})
+ logger.info(e)
+ return
+ except Exception as e:
+ # TODO: Remember that we are now out of sync and try again
+ # later
+ set_tag("error", True)
+ log_kv(
+ {"message": "Exception raised by federation request", "exception": e}
+ )
+ logger.exception("Failed to handle device list update for %s", user_id)
+ return
+ log_kv({"result": result})
+ stream_id = result["stream_id"]
+ devices = result["devices"]
+
+ # If the remote server has more than ~1000 devices for this user
+ # we assume that something is going horribly wrong (e.g. a bot
+ # that logs in and creates a new device every time it tries to
+ # send a message). Maintaining lots of devices per user in the
+ # cache can cause serious performance issues as if this request
+ # takes more than 60s to complete, internal replication from the
+ # inbound federation worker to the synapse master may time out
+ # causing the inbound federation to fail and causing the remote
+ # server to retry, causing a DoS. So in this scenario we give
+ # up on storing the total list of devices and only handle the
+ # delta instead.
+ if len(devices) > 1000:
+ logger.warning(
+ "Ignoring device list snapshot for %s as it has >1K devs (%d)",
+ user_id,
+ len(devices),
+ )
+ devices = []
+
+ for device in devices:
+ logger.debug(
+ "Handling resync update %r/%r, ID: %r",
+ user_id,
+ device["device_id"],
+ stream_id,
+ )
+
+ yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
+ device_ids = [device["device_id"] for device in devices]
+ yield self.device_handler.notify_device_update(user_id, device_ids)
+
+ # We clobber the seen updates since we've re-synced from a given
+ # point.
+ self._seen_updates[user_id] = {stream_id}
+
+ defer.returnValue(result)
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 2e2e5261de..05c4b3eec0 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -14,10 +14,20 @@
# limitations under the License.
import logging
+from typing import Any, Dict
+
+from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.logging.context import run_in_background
+from synapse.logging.opentracing import (
+ get_active_span_text_map,
+ log_kv,
+ set_tag,
+ start_active_span,
+)
from synapse.types import UserID, get_domain_from_id
from synapse.util.stringutils import random_string
@@ -25,7 +35,6 @@ logger = logging.getLogger(__name__)
class DeviceMessageHandler(object):
-
def __init__(self, hs):
"""
Args:
@@ -40,24 +49,29 @@ class DeviceMessageHandler(object):
"m.direct_to_device", self.on_direct_to_device_edu
)
+ self._device_list_updater = hs.get_device_handler().device_list_updater
+
@defer.inlineCallbacks
def on_direct_to_device_edu(self, origin, content):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
- logger.warn(
+ logger.warning(
"Dropping device message from %r with spoofed sender %r",
- origin, sender_user_id
+ origin,
+ sender_user_id,
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
- logger.warning("Request for keys for non-local user %s",
- user_id)
+ logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
+ if not by_device:
+ continue
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -66,8 +80,11 @@ class DeviceMessageHandler(object):
}
for device_id, message_content in by_device.items()
}
- if messages_by_device:
- local_messages[user_id] = messages_by_device
+ local_messages[user_id] = messages_by_device
+
+ yield self._check_for_unknown_devices(
+ message_type, sender_user_id, by_device
+ )
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
@@ -78,8 +95,58 @@ class DeviceMessageHandler(object):
)
@defer.inlineCallbacks
- def send_device_message(self, sender_user_id, message_type, messages):
+ def _check_for_unknown_devices(
+ self,
+ message_type: str,
+ sender_user_id: str,
+ by_device: Dict[str, Dict[str, Any]],
+ ):
+ """Checks inbound device messages for unkown remote devices, and if
+ found marks the remote cache for the user as stale.
+ """
+
+ if message_type != "m.room_key_request":
+ return
+
+ # Get the sending device IDs
+ requesting_device_ids = set()
+ for message_content in by_device.values():
+ device_id = message_content.get("requesting_device_id")
+ requesting_device_ids.add(device_id)
+
+ # Check if we are tracking the devices of the remote user.
+ room_ids = yield self.store.get_rooms_for_user(sender_user_id)
+ if not room_ids:
+ logger.info(
+ "Received device message from remote device we don't"
+ " share a room with: %s %s",
+ sender_user_id,
+ requesting_device_ids,
+ )
+ return
+
+ # If we are tracking check that we know about the sending
+ # devices.
+ cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
+
+ unknown_devices = requesting_device_ids - set(cached_devices)
+ if unknown_devices:
+ logger.info(
+ "Received device message from remote device not in our cache: %s %s",
+ sender_user_id,
+ unknown_devices,
+ )
+ yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+
+ # Immediately attempt a resync in the background
+ run_in_background(
+ self._device_list_updater.user_device_resync, sender_user_id
+ )
+ @defer.inlineCallbacks
+ def send_device_message(self, sender_user_id, message_type, messages):
+ set_tag("number_of_messages", len(messages))
+ set_tag("sender", sender_user_id)
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
@@ -101,15 +168,21 @@ class DeviceMessageHandler(object):
message_id = random_string(16)
+ context = get_active_span_text_map()
+
remote_edu_contents = {}
for destination, messages in remote_messages.items():
- remote_edu_contents[destination] = {
- "messages": messages,
- "sender": sender_user_id,
- "type": message_type,
- "message_id": message_id,
- }
+ with start_active_span("to_device_for_user"):
+ set_tag("destination", destination)
+ remote_edu_contents[destination] = {
+ "messages": messages,
+ "sender": sender_user_id,
+ "type": message_type,
+ "message_id": message_id,
+ "org.matrix.opentracing_context": json.dumps(context),
+ }
+ log_kv({"local_messages": local_messages})
stream_id = yield self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)
@@ -118,6 +191,7 @@ class DeviceMessageHandler(object):
"to_device_key", stream_id, users=local_messages.keys()
)
+ log_kv({"remote_messages": remote_messages})
for destination in remote_messages.keys():
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a12f9508d8..1d842c369b 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -13,9 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import string
+from typing import Iterable, List, Optional
from twisted.internet import defer
@@ -28,7 +28,8 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
-from synapse.types import RoomAlias, UserID, get_domain_from_id
+from synapse.appservice import ApplicationService
+from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id
from ._base import BaseHandler
@@ -36,7 +37,6 @@ logger = logging.getLogger(__name__)
class DirectoryHandler(BaseHandler):
-
def __init__(self, hs):
super(DirectoryHandler, self).__init__(hs)
@@ -56,7 +56,13 @@ class DirectoryHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks
- def _create_association(self, room_alias, room_id, servers=None, creator=None):
+ def _create_association(
+ self,
+ room_alias: RoomAlias,
+ room_id: str,
+ servers: Optional[Iterable[str]] = None,
+ creator: Optional[str] = None,
+ ):
# general association creation for both human users and app services
for wchar in string.whitespace:
@@ -71,31 +77,32 @@ class DirectoryHandler(BaseHandler):
# TODO(erikj): Check if there is a current association.
if not servers:
users = yield self.state.get_current_users_in_room(room_id)
- servers = set(get_domain_from_id(u) for u in users)
+ servers = {get_domain_from_id(u) for u in users}
if not servers:
raise SynapseError(400, "Failed to get server list")
yield self.store.create_room_alias_association(
- room_alias,
- room_id,
- servers,
- creator=creator,
+ room_alias, room_id, servers, creator=creator
)
@defer.inlineCallbacks
- def create_association(self, requester, room_alias, room_id, servers=None,
- send_event=True, check_membership=True):
+ def create_association(
+ self,
+ requester: Requester,
+ room_alias: RoomAlias,
+ room_id: str,
+ servers: Optional[List[str]] = None,
+ check_membership: bool = True,
+ ):
"""Attempt to create a new alias
Args:
- requester (Requester)
- room_alias (RoomAlias)
- room_id (str)
- servers (list[str]|None): List of servers that others servers
- should try and join via
- send_event (bool): Whether to send an updated m.room.aliases event
- check_membership (bool): Whether to check if the user is in the room
+ requester
+ room_alias
+ room_id
+ servers: Iterable of servers that others servers should try and join via
+ check_membership: Whether to check if the user is in the room
before the alias can be set (if the server's config requires it).
Returns:
@@ -115,63 +122,49 @@ class DirectoryHandler(BaseHandler):
if service:
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
- 400, "This application service has not reserved"
- " this kind of alias.", errcode=Codes.EXCLUSIVE
+ 400,
+ "This application service has not reserved this kind of alias.",
+ errcode=Codes.EXCLUSIVE,
)
else:
if self.require_membership and check_membership:
rooms_for_user = yield self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user:
raise AuthError(
- 403,
- "You must be in the room to create an alias for it",
+ 403, "You must be in the room to create an alias for it"
)
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
- raise AuthError(
- 403, "This user is not permitted to create this alias",
- )
+ raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
- user_id, room_id, room_alias.to_string(),
+ user_id, room_id, room_alias.to_string()
):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
- raise SynapseError(
- 403, "Not allowed to create alias",
- )
+ raise SynapseError(403, "Not allowed to create alias")
- can_create = yield self.can_modify_alias(
- room_alias,
- user_id=user_id
- )
+ can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
- 400, "This alias is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "This alias is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE,
)
yield self._create_association(room_alias, room_id, servers, creator=user_id)
- if send_event:
- yield self.send_room_alias_update_event(
- requester,
- room_id
- )
@defer.inlineCallbacks
- def delete_association(self, requester, room_alias, send_event=True):
+ def delete_association(self, requester: Requester, room_alias: RoomAlias):
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
delete_appservice_association)
Args:
- requester (Requester):
- room_alias (RoomAlias):
- send_event (bool): Whether to send an updated m.room.aliases event.
- Note that, if we delete the canonical alias, we will always attempt
- to send an m.room.canonical_alias event
+ requester
+ room_alias
Returns:
Deferred[unicode]: room id that the alias used to point to
@@ -194,66 +187,51 @@ class DirectoryHandler(BaseHandler):
raise
if not can_delete:
- raise AuthError(
- 403, "You don't have permission to delete the alias.",
- )
+ raise AuthError(403, "You don't have permission to delete the alias.")
- can_delete = yield self.can_modify_alias(
- room_alias,
- user_id=user_id
- )
+ can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
- 400, "This alias is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "This alias is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE,
)
room_id = yield self._delete_association(room_alias)
try:
- if send_event:
- yield self.send_room_alias_update_event(
- requester,
- room_id
- )
-
- yield self._update_canonical_alias(
- requester,
- requester.user.to_string(),
- room_id,
- room_alias,
- )
+ yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
- defer.returnValue(room_id)
+ return room_id
@defer.inlineCallbacks
- def delete_appservice_association(self, service, room_alias):
+ def delete_appservice_association(
+ self, service: ApplicationService, room_alias: RoomAlias
+ ):
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400,
"This application service has not reserved this kind of alias",
- errcode=Codes.EXCLUSIVE
+ errcode=Codes.EXCLUSIVE,
)
yield self._delete_association(room_alias)
@defer.inlineCallbacks
- def _delete_association(self, room_alias):
+ def _delete_association(self, room_alias: RoomAlias):
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
room_id = yield self.store.delete_room_alias(room_alias)
- defer.returnValue(room_id)
+ return room_id
@defer.inlineCallbacks
- def get_association(self, room_alias):
+ def get_association(self, room_alias: RoomAlias):
room_id = None
if self.hs.is_mine(room_alias):
- result = yield self.get_association_from_room_alias(
- room_alias
- )
+ result = yield self.get_association_from_room_alias(room_alias)
if result:
room_id = result.room_id
@@ -263,14 +241,12 @@ class DirectoryHandler(BaseHandler):
result = yield self.federation.make_query(
destination=room_alias.domain,
query_type="directory",
- args={
- "room_alias": room_alias.to_string(),
- },
+ args={"room_alias": room_alias.to_string()},
retry_on_dns_fail=False,
ignore_backoff=True,
)
except CodeMessageException as e:
- logging.warn("Error retrieving alias")
+ logging.warning("Error retrieving alias")
if e.code == 404:
result = None
else:
@@ -284,102 +260,99 @@ class DirectoryHandler(BaseHandler):
raise SynapseError(
404,
"Room alias %s not found" % (room_alias.to_string(),),
- Codes.NOT_FOUND
+ Codes.NOT_FOUND,
)
users = yield self.state.get_current_users_in_room(room_id)
- extra_servers = set(get_domain_from_id(u) for u in users)
+ extra_servers = {get_domain_from_id(u) for u in users}
servers = set(extra_servers) | set(servers)
# If this server is in the list of servers, return it first.
if self.server_name in servers:
- servers = (
- [self.server_name] +
- [s for s in servers if s != self.server_name]
- )
+ servers = [self.server_name] + [s for s in servers if s != self.server_name]
else:
servers = list(servers)
- defer.returnValue({
- "room_id": room_id,
- "servers": servers,
- })
- return
+ return {"room_id": room_id, "servers": servers}
@defer.inlineCallbacks
def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias):
- raise SynapseError(
- 400, "Room Alias is not hosted on this Home Server"
- )
+ raise SynapseError(400, "Room Alias is not hosted on this homeserver")
- result = yield self.get_association_from_room_alias(
- room_alias
- )
+ result = yield self.get_association_from_room_alias(room_alias)
if result is not None:
- defer.returnValue({
- "room_id": result.room_id,
- "servers": result.servers,
- })
+ return {"room_id": result.room_id, "servers": result.servers}
else:
raise SynapseError(
404,
"Room alias %r not found" % (room_alias.to_string(),),
- Codes.NOT_FOUND
+ Codes.NOT_FOUND,
)
@defer.inlineCallbacks
- def send_room_alias_update_event(self, requester, room_id):
- aliases = yield self.store.get_aliases_for_room(room_id)
-
- yield self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.Aliases,
- "state_key": self.hs.hostname,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- "content": {"aliases": aliases},
- },
- ratelimit=False
- )
-
- @defer.inlineCallbacks
- def _update_canonical_alias(self, requester, user_id, room_id, room_alias):
+ def _update_canonical_alias(
+ self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
+ ):
+ """
+ Send an updated canonical alias event if the removed alias was set as
+ the canonical alias or listed in the alt_aliases field.
+ """
alias_event = yield self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
)
- alias_str = room_alias.to_string()
- if not alias_event or alias_event.content.get("alias", "") != alias_str:
+ # There is no canonical alias, nothing to do.
+ if not alias_event:
return
- yield self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.CanonicalAlias,
- "state_key": "",
- "room_id": room_id,
- "sender": user_id,
- "content": {},
- },
- ratelimit=False
- )
+ # Obtain a mutable version of the event content.
+ content = dict(alias_event.content)
+ send_update = False
+
+ # Remove the alias property if it matches the removed alias.
+ alias_str = room_alias.to_string()
+ if alias_event.content.get("alias", "") == alias_str:
+ send_update = True
+ content.pop("alias", "")
+
+ # Filter the alt_aliases property for the removed alias. Note that the
+ # value is not modified if alt_aliases is of an unexpected form.
+ alt_aliases = content.get("alt_aliases")
+ if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases:
+ send_update = True
+ alt_aliases = [alias for alias in alt_aliases if alias != alias_str]
+
+ if alt_aliases:
+ content["alt_aliases"] = alt_aliases
+ else:
+ del content["alt_aliases"]
+
+ if send_update:
+ yield self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.CanonicalAlias,
+ "state_key": "",
+ "room_id": room_id,
+ "sender": user_id,
+ "content": content,
+ },
+ ratelimit=False,
+ )
@defer.inlineCallbacks
- def get_association_from_room_alias(self, room_alias):
- result = yield self.store.get_association_from_room_alias(
- room_alias
- )
+ def get_association_from_room_alias(self, room_alias: RoomAlias):
+ result = yield self.store.get_association_from_room_alias(room_alias)
if not result:
# Query AS to see if it exists
as_handler = self.appservice_handler
result = yield as_handler.query_room_alias_exists(room_alias)
- defer.returnValue(result)
+ return result
- def can_modify_alias(self, alias, user_id=None):
+ def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None):
# Any application service "interested" in an alias they are regexing on
# can modify the alias.
# Users can only modify the alias if ALL the interested services have
@@ -400,29 +373,48 @@ class DirectoryHandler(BaseHandler):
return defer.succeed(True)
@defer.inlineCallbacks
- def _user_can_delete_alias(self, alias, user_id):
+ def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
+ """Determine whether a user can delete an alias.
+
+ One of the following must be true:
+
+ 1. The user created the alias.
+ 2. The user is a server administrator.
+ 3. The user has a power-level sufficient to send a canonical alias event
+ for the current room.
+
+ """
creator = yield self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id:
- defer.returnValue(True)
+ return True
- is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
- defer.returnValue(is_admin)
+ # Resolve the alias to the corresponding room.
+ room_mapping = yield self.get_association(alias)
+ room_id = room_mapping["room_id"]
+ if not room_id:
+ return False
+
+ res = yield self.auth.check_can_change_room_list(
+ room_id, UserID.from_string(user_id)
+ )
+ return res
@defer.inlineCallbacks
- def edit_published_room_list(self, requester, room_id, visibility):
+ def edit_published_room_list(
+ self, requester: Requester, room_id: str, visibility: str
+ ):
"""Edit the entry of the room in the published room list.
requester
- room_id (str)
- visibility (str): "public" or "private"
+ room_id
+ visibility: "public" or "private"
"""
user_id = requester.user.to_string()
if not self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
- 403,
- "This user is not permitted to publish rooms to the room list"
+ 403, "This user is not permitted to publish rooms to the room list"
)
if requester.is_guest:
@@ -434,15 +426,22 @@ class DirectoryHandler(BaseHandler):
if visibility == "public" and not self.enable_room_list_search:
# The room list has been disabled.
raise AuthError(
- 403,
- "This user is not permitted to publish rooms to the room list"
+ 403, "This user is not permitted to publish rooms to the room list"
)
room = yield self.store.get_room(room_id)
if room is None:
raise SynapseError(400, "Unknown room")
- yield self.auth.check_can_change_room_list(room_id, requester.user)
+ can_change_room_list = yield self.auth.check_can_change_room_list(
+ room_id, requester.user
+ )
+ if not can_change_room_list:
+ raise AuthError(
+ 403,
+ "This server requires you to be a moderator in the room to"
+ " edit its room list entry",
+ )
making_public = visibility == "public"
if making_public:
@@ -452,28 +451,27 @@ class DirectoryHandler(BaseHandler):
room_aliases.append(canonical_alias)
if not self.config.is_publishing_room_allowed(
- user_id, room_id, room_aliases,
+ user_id, room_id, room_aliases
):
# Lets just return a generic message, as there may be all sorts of
# reasons why we said no. TODO: Allow configurable error messages
# per alias creation rule?
- raise SynapseError(
- 403, "Not allowed to publish room",
- )
+ raise SynapseError(403, "Not allowed to publish room")
yield self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks
- def edit_published_appservice_room_list(self, appservice_id, network_id,
- room_id, visibility):
+ def edit_published_appservice_room_list(
+ self, appservice_id: str, network_id: str, room_id: str, visibility: str
+ ):
"""Add or remove a room from the appservice/network specific public
room list.
Args:
- appservice_id (str): ID of the appservice that owns the list
- network_id (str): The ID of the network the list is associated with
- room_id (str)
- visibility (str): either "public" or "private"
+ appservice_id: ID of the appservice that owns the list
+ network_id: The ID of the network the list is associated with
+ room_id
+ visibility: either "public" or "private"
"""
if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting")
@@ -481,3 +479,19 @@ class DirectoryHandler(BaseHandler):
yield self.store.set_room_is_public_appservice(
room_id, appservice_id, network_id, visibility == "public"
)
+
+ async def get_aliases_for_room(
+ self, requester: Requester, room_id: str
+ ) -> List[str]:
+ """
+ Get a list of the aliases that currently point to this room on this server
+ """
+ # allow access to server admins and current members of the room
+ is_admin = await self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ await self.auth.check_user_in_room_or_world_readable(
+ room_id, requester.user.to_string()
+ )
+
+ aliases = await self.store.get_aliases_for_room(room_id)
+ return aliases
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 9dc46aa15f..8f1bc0323c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,13 +19,26 @@ import logging
from six import iteritems
+import attr
from canonicaljson import encode_canonical_json, json
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import SignatureVerifyException, verify_signed_json
+from unpaddedbase64 import decode_base64
from twisted.internet import defer
-from synapse.api.errors import CodeMessageException, FederationDeniedError, SynapseError
-from synapse.types import UserID, get_domain_from_id
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
+from synapse.types import (
+ UserID,
+ get_domain_from_id,
+ get_verify_key_from_cross_signing_key,
+)
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -38,15 +52,35 @@ class E2eKeysHandler(object):
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
+ self._edu_updater = SigningKeyEduUpdater(hs, self)
+
+ federation_registry = hs.get_federation_registry()
+
+ self._is_master = hs.config.worker_app is None
+ if not self._is_master:
+ self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
+ hs
+ )
+ else:
+ # Only register this edu handler on master as it requires writing
+ # device updates to the db
+ #
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ federation_registry.register_edu_handler(
+ "org.matrix.signing_key_update",
+ self._edu_updater.incoming_signing_key_update,
+ )
+
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
- hs.get_federation_registry().register_query_handler(
+ federation_registry.register_query_handler(
"client_keys", self.on_federation_query_client_keys
)
+ @trace
@defer.inlineCallbacks
- def query_devices(self, query_body, timeout):
+ def query_devices(self, query_body, timeout, from_user_id):
""" Handle a device key query from a client
{
@@ -64,7 +98,13 @@ class E2eKeysHandler(object):
}
}
}
+
+ Args:
+ from_user_id (str): the user making the query. This is used when
+ adding cross-signing signatures to limit what signatures users
+ can see.
"""
+
device_keys_query = query_body.get("device_keys", {})
# separate users by domain.
@@ -79,6 +119,9 @@ class E2eKeysHandler(object):
else:
remote_queries[user_id] = device_ids
+ set_tag("local_key_query", local_query)
+ set_tag("remote_key_query", remote_queries)
+
# First get local devices.
failures = {}
results = {}
@@ -98,11 +141,10 @@ class E2eKeysHandler(object):
else:
query_list.append((user_id, None))
- user_ids_not_in_cache, remote_results = (
- yield self.store.get_user_devices_from_cache(
- query_list
- )
- )
+ (
+ user_ids_not_in_cache,
+ remote_results,
+ ) = yield self.store.get_user_devices_from_cache(query_list)
for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
for device_id, device in iteritems(devices):
@@ -120,33 +162,161 @@ class E2eKeysHandler(object):
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
+ # Get cached cross-signing keys
+ cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ device_keys_query, from_user_id
+ )
+
# Now fetch any devices that we don't have in our cache
+ @trace
@defer.inlineCallbacks
def do_remote_query(destination):
+ """This is called when we are querying the device list of a user on
+ a remote homeserver and their device list is not in the device list
+ cache. If we share a room with this user and we're not querying for
+ specific user we will update the cache with their device list.
+ """
+
destination_query = remote_queries_not_in_cache[destination]
+
+ # We first consider whether we wish to update the device list cache with
+ # the users device list. We want to track a user's devices when the
+ # authenticated user shares a room with the queried user and the query
+ # has not specified a particular device.
+ # If we update the cache for the queried user we remove them from further
+ # queries. We use the more efficient batched query_client_keys for all
+ # remaining users
+ user_ids_updated = []
+ for (user_id, device_list) in destination_query.items():
+ if user_id in user_ids_updated:
+ continue
+
+ if device_list:
+ continue
+
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ continue
+
+ # We've decided we're sharing a room with this user and should
+ # probably be tracking their device lists. However, we haven't
+ # done an initial sync on the device list so we do it now.
+ try:
+ if self._is_master:
+ user_devices = yield self.device_handler.device_list_updater.user_device_resync(
+ user_id
+ )
+ else:
+ user_devices = yield self._user_device_resync_client(
+ user_id=user_id
+ )
+
+ user_devices = user_devices["devices"]
+ user_results = results.setdefault(user_id, {})
+ for device in user_devices:
+ user_results[device["device_id"]] = device["keys"]
+ user_ids_updated.append(user_id)
+ except Exception as e:
+ failures[destination] = _exception_to_failure(e)
+
+ if len(destination_query) == len(user_ids_updated):
+ # We've updated all the users in the query and we do not need to
+ # make any further remote calls.
+ return
+
+ # Remove all the users from the query which we have updated
+ for user_id in user_ids_updated:
+ destination_query.pop(user_id)
+
try:
remote_result = yield self.federation.query_client_keys(
- destination,
- {"device_keys": destination_query},
- timeout=timeout
+ destination, {"device_keys": destination_query}, timeout=timeout
)
for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
results[user_id] = keys
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
+
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
+
except Exception as e:
- failures[destination] = _exception_to_failure(e)
+ failure = _exception_to_failure(e)
+ failures[destination] = failure
+ set_tag("error", True)
+ set_tag("reason", failure)
+
+ yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(do_remote_query, destination)
+ for destination in remote_queries_not_in_cache
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ ret = {"device_keys": results, "failures": failures}
+
+ ret.update(cross_signing_keys)
+
+ return ret
+
+ @defer.inlineCallbacks
+ def get_cross_signing_keys_from_cache(self, query, from_user_id):
+ """Get cross-signing keys for users from the database
+
+ Args:
+ query (Iterable[string]) an iterable of user IDs. A dict whose keys
+ are user IDs satisfies this, so the query format used for
+ query_devices can be used here.
+ from_user_id (str): the user making the query. This is used when
+ adding cross-signing signatures to limit what signatures users
+ can see.
+
+ Returns:
+ defer.Deferred[dict[str, dict[str, dict]]]: map from
+ (master_keys|self_signing_keys|user_signing_keys) -> user_id -> key
+ """
+ master_keys = {}
+ self_signing_keys = {}
+ user_signing_keys = {}
+
+ user_ids = list(query)
- yield make_deferred_yieldable(defer.gatherResults([
- run_in_background(do_remote_query, destination)
- for destination in remote_queries_not_in_cache
- ], consumeErrors=True))
+ keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
- defer.returnValue({
- "device_keys": results, "failures": failures,
- })
+ for user_id, user_info in keys.items():
+ if user_info is None:
+ continue
+ if "master" in user_info:
+ master_keys[user_id] = user_info["master"]
+ if "self_signing" in user_info:
+ self_signing_keys[user_id] = user_info["self_signing"]
+ if (
+ from_user_id in keys
+ and keys[from_user_id] is not None
+ and "user_signing" in keys[from_user_id]
+ ):
+ # users can see other users' master and self-signing keys, but can
+ # only see their own user-signing keys
+ user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
+
+ return {
+ "master_keys": master_keys,
+ "self_signing_keys": self_signing_keys,
+ "user_signing_keys": user_signing_keys,
+ }
+
+ @trace
@defer.inlineCallbacks
def query_local_devices(self, query):
"""Get E2E device keys for local users
@@ -159,14 +329,22 @@ class E2eKeysHandler(object):
defer.Deferred: (resolves to dict[string, dict[string, dict]]):
map from user_id -> device_id -> device details
"""
+ set_tag("local_query", query)
local_query = []
result_dict = {}
for user_id, device_ids in query.items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
- logger.warning("Request for keys for non-local user %s",
- user_id)
+ logger.warning("Request for keys for non-local user %s", user_id)
+ log_kv(
+ {
+ "message": "Requested a local key for a user which"
+ " was not local to the homeserver",
+ "user_id": user_id,
+ }
+ )
+ set_tag("error", True)
raise SynapseError(400, "Not a user here")
if not device_ids:
@@ -180,18 +358,13 @@ class E2eKeysHandler(object):
results = yield self.store.get_e2e_device_keys(local_query)
- # Build the result structure, un-jsonify the results, and add the
- # "unsigned" section
+ # Build the result structure
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
- r = dict(device_info["keys"])
- r["unsigned"] = {}
- display_name = device_info["device_display_name"]
- if display_name is not None:
- r["unsigned"]["device_display_name"] = display_name
- result_dict[user_id][device_id] = r
+ result_dict[user_id][device_id] = device_info
- defer.returnValue(result_dict)
+ log_kv(results)
+ return result_dict
@defer.inlineCallbacks
def on_federation_query_client_keys(self, query_body):
@@ -199,8 +372,18 @@ class E2eKeysHandler(object):
"""
device_keys_query = query_body.get("device_keys", {})
res = yield self.query_local_devices(device_keys_query)
- defer.returnValue({"device_keys": res})
+ ret = {"device_keys": res}
+ # add in the cross-signing keys
+ cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
+ device_keys_query, None
+ )
+
+ ret.update(cross_signing_keys)
+
+ return ret
+
+ @trace
@defer.inlineCallbacks
def claim_one_time_keys(self, query, timeout):
local_query = []
@@ -215,6 +398,9 @@ class E2eKeysHandler(object):
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_keys
+ set_tag("local_key_query", local_query)
+ set_tag("remote_key_query", remote_queries)
+
results = yield self.store.claim_e2e_one_time_keys(local_query)
json_result = {}
@@ -226,43 +412,54 @@ class E2eKeysHandler(object):
key_id: json.loads(json_bytes)
}
+ @trace
@defer.inlineCallbacks
def claim_client_keys(destination):
+ set_tag("destination", destination)
device_keys = remote_queries[destination]
try:
remote_result = yield self.federation.claim_client_keys(
- destination,
- {"one_time_keys": device_keys},
- timeout=timeout
+ destination, {"one_time_keys": device_keys}, timeout=timeout
)
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
+
except Exception as e:
- failures[destination] = _exception_to_failure(e)
+ failure = _exception_to_failure(e)
+ failures[destination] = failure
+ set_tag("error", True)
+ set_tag("reason", failure)
- yield make_deferred_yieldable(defer.gatherResults([
- run_in_background(claim_client_keys, destination)
- for destination in remote_queries
- ], consumeErrors=True))
+ yield make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(claim_client_keys, destination)
+ for destination in remote_queries
+ ],
+ consumeErrors=True,
+ )
+ )
logger.info(
"Claimed one-time-keys: %s",
- ",".join((
- "%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in iteritems(json_result)
- for device_id, device_keys in iteritems(user_keys)
- for key_id, _ in iteritems(device_keys)
- )),
+ ",".join(
+ (
+ "%s for %s:%s" % (key_id, user_id, device_id)
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
+ )
+ ),
)
- defer.returnValue({
- "one_time_keys": json_result,
- "failures": failures
- })
+ log_kv({"one_time_keys": json_result, "failures": failures})
+ return {"one_time_keys": json_result, "failures": failures}
@defer.inlineCallbacks
+ @tag_args
def upload_keys_for_user(self, user_id, device_id, keys):
+
time_now = self.clock.time_msec()
# TODO: Validate the JSON to make sure it has the right keys.
@@ -270,20 +467,41 @@ class E2eKeysHandler(object):
if device_keys:
logger.info(
"Updating device_keys for device %r for user %s at %d",
- device_id, user_id, time_now
+ device_id,
+ user_id,
+ time_now,
+ )
+ log_kv(
+ {
+ "message": "Updating device_keys for user.",
+ "user_id": user_id,
+ "device_id": device_id,
+ }
)
# TODO: Sign the JSON with the server key
changed = yield self.store.set_e2e_device_keys(
- user_id, device_id, time_now, device_keys,
+ user_id, device_id, time_now, device_keys
)
if changed:
# Only notify about device updates *if* the keys actually changed
yield self.device_handler.notify_device_update(user_id, [device_id])
-
+ else:
+ log_kv({"message": "Not updating device_keys for user", "user_id": user_id})
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
+ log_kv(
+ {
+ "message": "Updating one_time_keys for device.",
+ "user_id": user_id,
+ "device_id": device_id,
+ }
+ )
yield self._upload_one_time_keys_for_user(
- user_id, device_id, time_now, one_time_keys,
+ user_id, device_id, time_now, one_time_keys
+ )
+ else:
+ log_kv(
+ {"message": "Did not update one_time_keys", "reason": "no keys given"}
)
# the device should have been registered already, but it may have been
@@ -295,23 +513,26 @@ class E2eKeysHandler(object):
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
- defer.returnValue({"one_time_key_counts": result})
+ set_tag("one_time_key_counts", result)
+ return {"one_time_key_counts": result}
@defer.inlineCallbacks
- def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
- one_time_keys):
+ def _upload_one_time_keys_for_user(
+ self, user_id, device_id, time_now, one_time_keys
+ ):
logger.info(
"Adding one_time_keys %r for device %r for user %r at %d",
- one_time_keys.keys(), device_id, user_id, time_now,
+ one_time_keys.keys(),
+ device_id,
+ user_id,
+ time_now,
)
# make a list of (alg, id, key) tuples
key_list = []
for key_id, key_obj in one_time_keys.items():
algorithm, key_id = key_id.split(":")
- key_list.append((
- algorithm, key_id, key_obj
- ))
+ key_list.append((algorithm, key_id, key_obj))
# First we check if we have already persisted any of the keys.
existing_key_map = yield self.store.get_e2e_one_time_keys(
@@ -325,42 +546,658 @@ class E2eKeysHandler(object):
if not _one_time_keys_match(ex_json, key):
raise SynapseError(
400,
- ("One time key %s:%s already exists. "
- "Old key: %s; new key: %r") %
- (algorithm, key_id, ex_json, key)
+ (
+ "One time key %s:%s already exists. "
+ "Old key: %s; new key: %r"
+ )
+ % (algorithm, key_id, ex_json, key),
)
else:
- new_keys.append((
- algorithm, key_id, encode_canonical_json(key).decode('ascii')))
+ new_keys.append(
+ (algorithm, key_id, encode_canonical_json(key).decode("ascii"))
+ )
+
+ log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys})
+ yield self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys)
+
+ @defer.inlineCallbacks
+ def upload_signing_keys_for_user(self, user_id, keys):
+ """Upload signing keys for cross-signing
+
+ Args:
+ user_id (string): the user uploading the keys
+ keys (dict[string, dict]): the signing keys
+ """
+
+ # if a master key is uploaded, then check it. Otherwise, load the
+ # stored master key, to check signatures on other keys
+ if "master_key" in keys:
+ master_key = keys["master_key"]
+
+ _check_cross_signing_key(master_key, user_id, "master")
+ else:
+ master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
+
+ # if there is no master key, then we can't do anything, because all the
+ # other cross-signing keys need to be signed by the master key
+ if not master_key:
+ raise SynapseError(400, "No master key available", Codes.MISSING_PARAM)
+
+ try:
+ master_key_id, master_verify_key = get_verify_key_from_cross_signing_key(
+ master_key
+ )
+ except ValueError:
+ if "master_key" in keys:
+ # the invalid key came from the request
+ raise SynapseError(400, "Invalid master key", Codes.INVALID_PARAM)
+ else:
+ # the invalid key came from the database
+ logger.error("Invalid master key found for user %s", user_id)
+ raise SynapseError(500, "Invalid master key")
+
+ # for the other cross-signing keys, make sure that they have valid
+ # signatures from the master key
+ if "self_signing_key" in keys:
+ self_signing_key = keys["self_signing_key"]
+
+ _check_cross_signing_key(
+ self_signing_key, user_id, "self_signing", master_verify_key
+ )
+
+ if "user_signing_key" in keys:
+ user_signing_key = keys["user_signing_key"]
+
+ _check_cross_signing_key(
+ user_signing_key, user_id, "user_signing", master_verify_key
+ )
+
+ # if everything checks out, then store the keys and send notifications
+ deviceids = []
+ if "master_key" in keys:
+ yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+ deviceids.append(master_verify_key.version)
+ if "self_signing_key" in keys:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "self_signing", self_signing_key
+ )
+ try:
+ deviceids.append(
+ get_verify_key_from_cross_signing_key(self_signing_key)[1].version
+ )
+ except ValueError:
+ raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM)
+ if "user_signing_key" in keys:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "user_signing", user_signing_key
+ )
+ # the signature stream matches the semantics that we want for
+ # user-signing key updates: only the user themselves is notified of
+ # their own user-signing key updates
+ yield self.device_handler.notify_user_signature_update(user_id, [user_id])
+
+ # master key and self-signing key updates match the semantics of device
+ # list updates: all users who share an encrypted room are notified
+ if len(deviceids):
+ yield self.device_handler.notify_device_update(user_id, deviceids)
+
+ return {}
+
+ @defer.inlineCallbacks
+ def upload_signatures_for_device_keys(self, user_id, signatures):
+ """Upload device signatures for cross-signing
+
+ Args:
+ user_id (string): the user uploading the signatures
+ signatures (dict[string, dict[string, dict]]): map of users to
+ devices to signed keys. This is the submission from the user; an
+ exception will be raised if it is malformed.
+ Returns:
+ dict: response to be sent back to the client. The response will have
+ a "failures" key, which will be a dict mapping users to devices
+ to errors for the signatures that failed.
+ Raises:
+ SynapseError: if the signatures dict is not valid.
+ """
+ failures = {}
+
+ # signatures to be stored. Each item will be a SignatureListItem
+ signature_list = []
+
+ # split between checking signatures for own user and signatures for
+ # other users, since we verify them with different keys
+ self_signatures = signatures.get(user_id, {})
+ other_signatures = {k: v for k, v in signatures.items() if k != user_id}
+
+ self_signature_list, self_failures = yield self._process_self_signatures(
+ user_id, self_signatures
+ )
+ signature_list.extend(self_signature_list)
+ failures.update(self_failures)
+
+ other_signature_list, other_failures = yield self._process_other_signatures(
+ user_id, other_signatures
+ )
+ signature_list.extend(other_signature_list)
+ failures.update(other_failures)
+
+ # store the signature, and send the appropriate notifications for sync
+ logger.debug("upload signature failures: %r", failures)
+ yield self.store.store_e2e_cross_signing_signatures(user_id, signature_list)
+
+ self_device_ids = [item.target_device_id for item in self_signature_list]
+ if self_device_ids:
+ yield self.device_handler.notify_device_update(user_id, self_device_ids)
+ signed_users = [item.target_user_id for item in other_signature_list]
+ if signed_users:
+ yield self.device_handler.notify_user_signature_update(
+ user_id, signed_users
+ )
+
+ return {"failures": failures}
+
+ @defer.inlineCallbacks
+ def _process_self_signatures(self, user_id, signatures):
+ """Process uploaded signatures of the user's own keys.
+
+ Signatures of the user's own keys from this API come in two forms:
+ - signatures of the user's devices by the user's self-signing key,
+ - signatures of the user's master key by the user's devices.
+
+ Args:
+ user_id (string): the user uploading the keys
+ signatures (dict[string, dict]): map of devices to signed keys
+
+ Returns:
+ (list[SignatureListItem], dict[string, dict[string, dict]]):
+ a list of signatures to store, and a map of users to devices to failure
+ reasons
+
+ Raises:
+ SynapseError: if the input is malformed
+ """
+ signature_list = []
+ failures = {}
+ if not signatures:
+ return signature_list, failures
+
+ if not isinstance(signatures, dict):
+ raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM)
+
+ try:
+ # get our self-signing key to verify the signatures
+ (
+ _,
+ self_signing_key_id,
+ self_signing_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "self_signing")
+
+ # get our master key, since we may have received a signature of it.
+ # We need to fetch it here so that we know what its key ID is, so
+ # that we can check if a signature that was sent is a signature of
+ # the master key or of a device
+ (
+ master_key,
+ _,
+ master_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "master")
+
+ # fetch our stored devices. This is used to 1. verify
+ # signatures on the master key, and 2. to compare with what
+ # was sent if the device was signed
+ devices = yield self.store.get_e2e_device_keys([(user_id, None)])
+
+ if user_id not in devices:
+ raise NotFoundError("No device keys found")
+
+ devices = devices[user_id]
+ except SynapseError as e:
+ failure = _exception_to_failure(e)
+ failures[user_id] = {device: failure for device in signatures.keys()}
+ return signature_list, failures
+
+ for device_id, device in signatures.items():
+ # make sure submitted data is in the right form
+ if not isinstance(device, dict):
+ raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM)
+
+ try:
+ if "signatures" not in device or user_id not in device["signatures"]:
+ # no signature was sent
+ raise SynapseError(
+ 400, "Invalid signature", Codes.INVALID_SIGNATURE
+ )
+
+ if device_id == master_verify_key.version:
+ # The signature is of the master key. This needs to be
+ # handled differently from signatures of normal devices.
+ master_key_signature_list = self._check_master_key_signature(
+ user_id, device_id, device, master_key, devices
+ )
+ signature_list.extend(master_key_signature_list)
+ continue
+
+ # at this point, we have a device that should be signed
+ # by the self-signing key
+ if self_signing_key_id not in device["signatures"][user_id]:
+ # no signature was sent
+ raise SynapseError(
+ 400, "Invalid signature", Codes.INVALID_SIGNATURE
+ )
+
+ try:
+ stored_device = devices[device_id]
+ except KeyError:
+ raise NotFoundError("Unknown device")
+ if self_signing_key_id in stored_device.get("signatures", {}).get(
+ user_id, {}
+ ):
+ # we already have a signature on this device, so we
+ # can skip it, since it should be exactly the same
+ continue
+
+ _check_device_signature(
+ user_id, self_signing_verify_key, device, stored_device
+ )
+
+ signature = device["signatures"][user_id][self_signing_key_id]
+ signature_list.append(
+ SignatureListItem(
+ self_signing_key_id, user_id, device_id, signature
+ )
+ )
+ except SynapseError as e:
+ failures.setdefault(user_id, {})[device_id] = _exception_to_failure(e)
+
+ return signature_list, failures
+
+ def _check_master_key_signature(
+ self, user_id, master_key_id, signed_master_key, stored_master_key, devices
+ ):
+ """Check signatures of a user's master key made by their devices.
- yield self.store.add_e2e_one_time_keys(
- user_id, device_id, time_now, new_keys
+ Args:
+ user_id (string): the user whose master key is being checked
+ master_key_id (string): the ID of the user's master key
+ signed_master_key (dict): the user's signed master key that was uploaded
+ stored_master_key (dict): our previously-stored copy of the user's master key
+ devices (iterable(dict)): the user's devices
+
+ Returns:
+ list[SignatureListItem]: a list of signatures to store
+
+ Raises:
+ SynapseError: if a signature is invalid
+ """
+ # for each device that signed the master key, check the signature.
+ master_key_signature_list = []
+ sigs = signed_master_key["signatures"]
+ for signing_key_id, signature in sigs[user_id].items():
+ _, signing_device_id = signing_key_id.split(":", 1)
+ if (
+ signing_device_id not in devices
+ or signing_key_id not in devices[signing_device_id]["keys"]
+ ):
+ # signed by an unknown device, or the
+ # device does not have the key
+ raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
+
+ # get the key and check the signature
+ pubkey = devices[signing_device_id]["keys"][signing_key_id]
+ verify_key = decode_verify_key_bytes(signing_key_id, decode_base64(pubkey))
+ _check_device_signature(
+ user_id, verify_key, signed_master_key, stored_master_key
+ )
+
+ master_key_signature_list.append(
+ SignatureListItem(signing_key_id, user_id, master_key_id, signature)
+ )
+
+ return master_key_signature_list
+
+ @defer.inlineCallbacks
+ def _process_other_signatures(self, user_id, signatures):
+ """Process uploaded signatures of other users' keys. These will be the
+ target user's master keys, signed by the uploading user's user-signing
+ key.
+
+ Args:
+ user_id (string): the user uploading the keys
+ signatures (dict[string, dict]): map of users to devices to signed keys
+
+ Returns:
+ (list[SignatureListItem], dict[string, dict[string, dict]]):
+ a list of signatures to store, and a map of users to devices to failure
+ reasons
+
+ Raises:
+ SynapseError: if the input is malformed
+ """
+ signature_list = []
+ failures = {}
+ if not signatures:
+ return signature_list, failures
+
+ try:
+ # get our user-signing key to verify the signatures
+ (
+ user_signing_key,
+ user_signing_key_id,
+ user_signing_verify_key,
+ ) = yield self._get_e2e_cross_signing_verify_key(user_id, "user_signing")
+ except SynapseError as e:
+ failure = _exception_to_failure(e)
+ for user, devicemap in signatures.items():
+ failures[user] = {device_id: failure for device_id in devicemap.keys()}
+ return signature_list, failures
+
+ for target_user, devicemap in signatures.items():
+ # make sure submitted data is in the right form
+ if not isinstance(devicemap, dict):
+ raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM)
+ for device in devicemap.values():
+ if not isinstance(device, dict):
+ raise SynapseError(400, "Invalid parameter", Codes.INVALID_PARAM)
+
+ device_id = None
+ try:
+ # get the target user's master key, to make sure it matches
+ # what was sent
+ (
+ master_key,
+ master_key_id,
+ _,
+ ) = yield self._get_e2e_cross_signing_verify_key(
+ target_user, "master", user_id
+ )
+
+ # make sure that the target user's master key is the one that
+ # was signed (and no others)
+ device_id = master_key_id.split(":", 1)[1]
+ if device_id not in devicemap:
+ logger.debug(
+ "upload signature: could not find signature for device %s",
+ device_id,
+ )
+ # set device to None so that the failure gets
+ # marked on all the signatures
+ device_id = None
+ raise NotFoundError("Unknown device")
+ key = devicemap[device_id]
+ other_devices = [k for k in devicemap.keys() if k != device_id]
+ if other_devices:
+ # other devices were signed -- mark those as failures
+ logger.debug("upload signature: too many devices specified")
+ failure = _exception_to_failure(NotFoundError("Unknown device"))
+ failures[target_user] = {
+ device: failure for device in other_devices
+ }
+
+ if user_signing_key_id in master_key.get("signatures", {}).get(
+ user_id, {}
+ ):
+ # we already have the signature, so we can skip it
+ continue
+
+ _check_device_signature(
+ user_id, user_signing_verify_key, key, master_key
+ )
+
+ signature = key["signatures"][user_id][user_signing_key_id]
+ signature_list.append(
+ SignatureListItem(
+ user_signing_key_id, target_user, device_id, signature
+ )
+ )
+ except SynapseError as e:
+ failure = _exception_to_failure(e)
+ if device_id is None:
+ failures[target_user] = {
+ device_id: failure for device_id in devicemap.keys()
+ }
+ else:
+ failures.setdefault(target_user, {})[device_id] = failure
+
+ return signature_list, failures
+
+ @defer.inlineCallbacks
+ def _get_e2e_cross_signing_verify_key(
+ self, user_id: str, key_type: str, from_user_id: str = None
+ ):
+ """Fetch locally or remotely query for a cross-signing public key.
+
+ First, attempt to fetch the cross-signing public key from storage.
+ If that fails, query the keys from the homeserver they belong to
+ and update our local copy.
+
+ Args:
+ user_id: the user whose key should be fetched
+ key_type: the type of key to fetch
+ from_user_id: the user that we are fetching the keys for.
+ This affects what signatures are fetched.
+
+ Returns:
+ dict, str, VerifyKey: the raw key data, the key ID, and the
+ signedjson verify key
+
+ Raises:
+ NotFoundError: if the key is not found
+ SynapseError: if `user_id` is invalid
+ """
+ user = UserID.from_string(user_id)
+ key = yield self.store.get_e2e_cross_signing_key(
+ user_id, key_type, from_user_id
+ )
+
+ if key:
+ # We found a copy of this key in our database. Decode and return it
+ key_id, verify_key = get_verify_key_from_cross_signing_key(key)
+ return key, key_id, verify_key
+
+ # If we couldn't find the key locally, and we're looking for keys of
+ # another user then attempt to fetch the missing key from the remote
+ # user's server.
+ #
+ # We may run into this in possible edge cases where a user tries to
+ # cross-sign a remote user, but does not share any rooms with them yet.
+ # Thus, we would not have their key list yet. We instead fetch the key,
+ # store it and notify clients of new, associated device IDs.
+ if self.is_mine(user) or key_type not in ["master", "self_signing"]:
+ # Note that master and self_signing keys are the only cross-signing keys we
+ # can request over federation
+ raise NotFoundError("No %s key found for %s" % (key_type, user_id))
+
+ (
+ key,
+ key_id,
+ verify_key,
+ ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
+
+ if key is None:
+ raise NotFoundError("No %s key found for %s" % (key_type, user_id))
+
+ return key, key_id, verify_key
+
+ @defer.inlineCallbacks
+ def _retrieve_cross_signing_keys_for_remote_user(
+ self, user: UserID, desired_key_type: str,
+ ):
+ """Queries cross-signing keys for a remote user and saves them to the database
+
+ Only the key specified by `key_type` will be returned, while all retrieved keys
+ will be saved regardless
+
+ Args:
+ user: The user to query remote keys for
+ desired_key_type: The type of key to receive. One of "master", "self_signing"
+
+ Returns:
+ Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple
+ of the retrieved key content, the key's ID and the matching VerifyKey.
+ If the key cannot be retrieved, all values in the tuple will instead be None.
+ """
+ try:
+ remote_result = yield self.federation.query_user_devices(
+ user.domain, user.to_string()
+ )
+ except Exception as e:
+ logger.warning(
+ "Unable to query %s for cross-signing keys of user %s: %s %s",
+ user.domain,
+ user.to_string(),
+ type(e),
+ e,
+ )
+ return None, None, None
+
+ # Process each of the retrieved cross-signing keys
+ desired_key = None
+ desired_key_id = None
+ desired_verify_key = None
+ retrieved_device_ids = []
+ for key_type in ["master", "self_signing"]:
+ key_content = remote_result.get(key_type + "_key")
+ if not key_content:
+ continue
+
+ # Ensure these keys belong to the correct user
+ if "user_id" not in key_content:
+ logger.warning(
+ "Invalid %s key retrieved, missing user_id field: %s",
+ key_type,
+ key_content,
+ )
+ continue
+ if user.to_string() != key_content["user_id"]:
+ logger.warning(
+ "Found %s key of user %s when querying for keys of user %s",
+ key_type,
+ key_content["user_id"],
+ user.to_string(),
+ )
+ continue
+
+ # Validate the key contents
+ try:
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ key_id, verify_key = get_verify_key_from_cross_signing_key(key_content)
+ except ValueError as e:
+ logger.warning(
+ "Invalid %s key retrieved: %s - %s %s",
+ key_type,
+ key_content,
+ type(e),
+ e,
+ )
+ continue
+
+ # Note down the device ID attached to this key
+ retrieved_device_ids.append(verify_key.version)
+
+ # If this is the desired key type, save it and its ID/VerifyKey
+ if key_type == desired_key_type:
+ desired_key = key_content
+ desired_verify_key = verify_key
+ desired_key_id = key_id
+
+ # At the same time, store this key in the db for subsequent queries
+ yield self.store.set_e2e_cross_signing_key(
+ user.to_string(), key_type, key_content
+ )
+
+ # Notify clients that new devices for this user have been discovered
+ if retrieved_device_ids:
+ # XXX is this necessary?
+ yield self.device_handler.notify_device_update(
+ user.to_string(), retrieved_device_ids
+ )
+
+ return desired_key, desired_key_id, desired_verify_key
+
+
+def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
+ """Check a cross-signing key uploaded by a user. Performs some basic sanity
+ checking, and ensures that it is signed, if a signature is required.
+
+ Args:
+ key (dict): the key data to verify
+ user_id (str): the user whose key is being checked
+ key_type (str): the type of key that the key should be
+ signing_key (VerifyKey): (optional) the signing key that the key should
+ be signed with. If omitted, signatures will not be checked.
+ """
+ if (
+ key.get("user_id") != user_id
+ or key_type not in key.get("usage", [])
+ or len(key.get("keys", {})) != 1
+ ):
+ raise SynapseError(400, ("Invalid %s key" % (key_type,)), Codes.INVALID_PARAM)
+
+ if signing_key:
+ try:
+ verify_signed_json(key, user_id, signing_key)
+ except SignatureVerifyException:
+ raise SynapseError(
+ 400, ("Invalid signature on %s key" % key_type), Codes.INVALID_SIGNATURE
+ )
+
+
+def _check_device_signature(user_id, verify_key, signed_device, stored_device):
+ """Check that a signature on a device or cross-signing key is correct and
+ matches the copy of the device/key that we have stored. Throws an
+ exception if an error is detected.
+
+ Args:
+ user_id (str): the user ID whose signature is being checked
+ verify_key (VerifyKey): the key to verify the device with
+ signed_device (dict): the uploaded signed device data
+ stored_device (dict): our previously stored copy of the device
+
+ Raises:
+ SynapseError: if the signature was invalid or the sent device is not the
+ same as the stored device
+
+ """
+
+ # make sure that the device submitted matches what we have stored
+ stripped_signed_device = {
+ k: v for k, v in signed_device.items() if k not in ["signatures", "unsigned"]
+ }
+ stripped_stored_device = {
+ k: v for k, v in stored_device.items() if k not in ["signatures", "unsigned"]
+ }
+ if stripped_signed_device != stripped_stored_device:
+ logger.debug(
+ "upload signatures: key does not match %s vs %s",
+ signed_device,
+ stored_device,
)
+ raise SynapseError(400, "Key does not match")
+
+ try:
+ verify_signed_json(signed_device, user_id, verify_key)
+ except SignatureVerifyException:
+ logger.debug("invalid signature on key")
+ raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE)
def _exception_to_failure(e):
+ if isinstance(e, SynapseError):
+ return {"status": e.code, "errcode": e.errcode, "message": str(e)}
+
if isinstance(e, CodeMessageException):
- return {
- "status": e.code, "message": str(e),
- }
+ return {"status": e.code, "message": str(e)}
if isinstance(e, NotRetryingDestination):
- return {
- "status": 503, "message": "Not ready for retry",
- }
-
- if isinstance(e, FederationDeniedError):
- return {
- "status": 403, "message": "Federation Denied",
- }
+ return {"status": 503, "message": "Not ready for retry"}
# include ConnectionRefused and other errors
#
# Note that some Exceptions (notably twisted's ResponseFailed etc) don't
# give a string for e.message, which json then fails to serialize.
- return {
- "status": 503, "message": str(e),
- }
+ return {"status": 503, "message": str(e)}
def _one_time_keys_match(old_key_json, new_key):
@@ -377,3 +1214,111 @@ def _one_time_keys_match(old_key_json, new_key):
new_key_copy.pop("signatures", None)
return old_key == new_key_copy
+
+
+@attr.s
+class SignatureListItem:
+ """An item in the signature list as used by upload_signatures_for_device_keys.
+ """
+
+ signing_key_id = attr.ib()
+ target_user_id = attr.ib()
+ target_device_id = attr.ib()
+ signature = attr.ib()
+
+
+class SigningKeyEduUpdater(object):
+ """Handles incoming signing key updates from federation and updates the DB"""
+
+ def __init__(self, hs, e2e_keys_handler):
+ self.store = hs.get_datastore()
+ self.federation = hs.get_federation_client()
+ self.clock = hs.get_clock()
+ self.e2e_keys_handler = e2e_keys_handler
+
+ self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
+
+ # user_id -> list of updates waiting to be handled.
+ self._pending_updates = {}
+
+ # Recently seen stream ids. We don't bother keeping these in the DB,
+ # but they're useful to have them about to reduce the number of spurious
+ # resyncs.
+ self._seen_updates = ExpiringCache(
+ cache_name="signing_key_update_edu",
+ clock=self.clock,
+ max_len=10000,
+ expiry_ms=30 * 60 * 1000,
+ iterable=True,
+ )
+
+ @defer.inlineCallbacks
+ def incoming_signing_key_update(self, origin, edu_content):
+ """Called on incoming signing key update from federation. Responsible for
+ parsing the EDU and adding to pending updates list.
+
+ Args:
+ origin (string): the server that sent the EDU
+ edu_content (dict): the contents of the EDU
+ """
+
+ user_id = edu_content.pop("user_id")
+ master_key = edu_content.pop("master_key", None)
+ self_signing_key = edu_content.pop("self_signing_key", None)
+
+ if get_domain_from_id(user_id) != origin:
+ logger.warning("Got signing key update edu for %r from %r", user_id, origin)
+ return
+
+ room_ids = yield self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ # We don't share any rooms with this user. Ignore update, as we
+ # probably won't get any further updates.
+ return
+
+ self._pending_updates.setdefault(user_id, []).append(
+ (master_key, self_signing_key)
+ )
+
+ yield self._handle_signing_key_updates(user_id)
+
+ @defer.inlineCallbacks
+ def _handle_signing_key_updates(self, user_id):
+ """Actually handle pending updates.
+
+ Args:
+ user_id (string): the user whose updates we are processing
+ """
+
+ device_handler = self.e2e_keys_handler.device_handler
+
+ with (yield self._remote_edu_linearizer.queue(user_id)):
+ pending_updates = self._pending_updates.pop(user_id, [])
+ if not pending_updates:
+ # This can happen since we batch updates
+ return
+
+ device_ids = []
+
+ logger.info("pending updates: %r", pending_updates)
+
+ for master_key, self_signing_key in pending_updates:
+ if master_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "master", master_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(master_key)
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ device_ids.append(verify_key.version)
+ if self_signing_key:
+ yield self.store.set_e2e_cross_signing_key(
+ user_id, "self_signing", self_signing_key
+ )
+ _, verify_key = get_verify_key_from_cross_signing_key(
+ self_signing_key
+ )
+ device_ids.append(verify_key.version)
+
+ yield device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 7bc174070e..9abaf13b8f 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,6 +27,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.logging.opentracing import log_kv, trace
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -49,6 +51,7 @@ class E2eRoomKeysHandler(object):
# changed.
self._upload_linearizer = Linearizer("upload_room_keys_lock")
+ @trace
@defer.inlineCallbacks
def get_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
@@ -84,8 +87,10 @@ class E2eRoomKeysHandler(object):
user_id, version, room_id, session_id
)
- defer.returnValue(results)
+ log_kv(results)
+ return results
+ @trace
@defer.inlineCallbacks
def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
@@ -99,14 +104,36 @@ class E2eRoomKeysHandler(object):
rooms
session_id(string): session ID to delete keys for, for None to delete keys
for all sessions
+ Raises:
+ NotFoundError: if the backup version does not exist
Returns:
- A deferred of the deletion transaction
+ A dict containing the count and etag for the backup version
"""
# lock for consistency with uploading
with (yield self._upload_linearizer.queue(user_id)):
+ # make sure the backup version exists
+ try:
+ version_info = yield self.store.get_e2e_room_keys_version_info(
+ user_id, version
+ )
+ except StoreError as e:
+ if e.code == 404:
+ raise NotFoundError("Unknown backup version")
+ else:
+ raise
+
yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
+ version_etag = version_info["etag"] + 1
+ yield self.store.update_e2e_room_keys_version(
+ user_id, version, None, version_etag
+ )
+
+ count = yield self.store.count_e2e_room_keys(user_id, version)
+ return {"etag": str(version_etag), "count": count}
+
+ @trace
@defer.inlineCallbacks
def upload_room_keys(self, user_id, version, room_keys):
"""Bulk upload a list of room keys into a given backup version, asserting
@@ -133,6 +160,9 @@ class E2eRoomKeysHandler(object):
}
}
+ Returns:
+ A dict containing the count and etag for the backup version
+
Raises:
NotFoundError: if there are no versions defined
RoomKeysVersionError: if the uploaded version is not the current version
@@ -152,57 +182,83 @@ class E2eRoomKeysHandler(object):
else:
raise
- if version_info['version'] != version:
+ if version_info["version"] != version:
# Check that the version we're trying to upload actually exists
try:
version_info = yield self.store.get_e2e_room_keys_version_info(
- user_id, version,
+ user_id, version
)
# if we get this far, the version must exist
- raise RoomKeysVersionError(current_version=version_info['version'])
+ raise RoomKeysVersionError(current_version=version_info["version"])
except StoreError as e:
if e.code == 404:
raise NotFoundError("Version '%s' not found" % (version,))
else:
raise
- # go through the room_keys.
- # XXX: this should/could be done concurrently, given we're in a lock.
- for room_id, room in iteritems(room_keys['rooms']):
- for session_id, session in iteritems(room['sessions']):
- yield self._upload_room_key(
- user_id, version, room_id, session_id, session
- )
-
- @defer.inlineCallbacks
- def _upload_room_key(self, user_id, version, room_id, session_id, room_key):
- """Upload a given room_key for a given room and session into a given
- version of the backup. Merges the key with any which might already exist.
-
- Args:
- user_id(str): the user whose backup we're setting
- version(str): the version ID of the backup we're updating
- room_id(str): the ID of the room whose keys we're setting
- session_id(str): the session whose room_key we're setting
- room_key(dict): the room_key being set
- """
-
- # get the room_key for this particular row
- current_room_key = None
- try:
- current_room_key = yield self.store.get_e2e_room_key(
- user_id, version, room_id, session_id
+ # Fetch any existing room keys for the sessions that have been
+ # submitted. Then compare them with the submitted keys. If the
+ # key is new, insert it; if the key should be updated, then update
+ # it; otherwise, drop it.
+ existing_keys = yield self.store.get_e2e_room_keys_multi(
+ user_id, version, room_keys["rooms"]
)
- except StoreError as e:
- if e.code == 404:
- pass
- else:
- raise
+ to_insert = [] # batch the inserts together
+ changed = False # if anything has changed, we need to update the etag
+ for room_id, room in iteritems(room_keys["rooms"]):
+ for session_id, room_key in iteritems(room["sessions"]):
+ if not isinstance(room_key["is_verified"], bool):
+ msg = (
+ "is_verified must be a boolean in keys for session %s in"
+ "room %s" % (session_id, room_id)
+ )
+ raise SynapseError(400, msg, Codes.INVALID_PARAM)
+
+ log_kv(
+ {
+ "message": "Trying to upload room key",
+ "room_id": room_id,
+ "session_id": session_id,
+ "user_id": user_id,
+ }
+ )
+ current_room_key = existing_keys.get(room_id, {}).get(session_id)
+ if current_room_key:
+ if self._should_replace_room_key(current_room_key, room_key):
+ log_kv({"message": "Replacing room key."})
+ # updates are done one at a time in the DB, so send
+ # updates right away rather than batching them up,
+ # like we do with the inserts
+ yield self.store.update_e2e_room_key(
+ user_id, version, room_id, session_id, room_key
+ )
+ changed = True
+ else:
+ log_kv({"message": "Not replacing room_key."})
+ else:
+ log_kv(
+ {
+ "message": "Room key not found.",
+ "room_id": room_id,
+ "user_id": user_id,
+ }
+ )
+ log_kv({"message": "Replacing room key."})
+ to_insert.append((room_id, session_id, room_key))
+ changed = True
+
+ if len(to_insert):
+ yield self.store.add_e2e_room_keys(user_id, version, to_insert)
+
+ version_etag = version_info["etag"]
+ if changed:
+ version_etag = version_etag + 1
+ yield self.store.update_e2e_room_keys_version(
+ user_id, version, None, version_etag
+ )
- if self._should_replace_room_key(current_room_key, room_key):
- yield self.store.set_e2e_room_key(
- user_id, version, room_id, session_id, room_key
- )
+ count = yield self.store.count_e2e_room_keys(user_id, version)
+ return {"etag": str(version_etag), "count": count}
@staticmethod
def _should_replace_room_key(current_room_key, room_key):
@@ -223,19 +279,20 @@ class E2eRoomKeysHandler(object):
# spelt out with if/elifs rather than nested boolean expressions
# purely for legibility.
- if room_key['is_verified'] and not current_room_key['is_verified']:
+ if room_key["is_verified"] and not current_room_key["is_verified"]:
return True
elif (
- room_key['first_message_index'] <
- current_room_key['first_message_index']
+ room_key["first_message_index"]
+ < current_room_key["first_message_index"]
):
return True
- elif room_key['forwarded_count'] < current_room_key['forwarded_count']:
+ elif room_key["forwarded_count"] < current_room_key["forwarded_count"]:
return True
else:
return False
return True
+ @trace
@defer.inlineCallbacks
def create_version(self, user_id, version_info):
"""Create a new backup version. This automatically becomes the new
@@ -262,7 +319,7 @@ class E2eRoomKeysHandler(object):
new_version = yield self.store.create_e2e_room_keys_version(
user_id, version_info
)
- defer.returnValue(new_version)
+ return new_version
@defer.inlineCallbacks
def get_version_info(self, user_id, version=None):
@@ -292,8 +349,11 @@ class E2eRoomKeysHandler(object):
raise NotFoundError("Unknown backup version")
else:
raise
- defer.returnValue(res)
+ res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
+ return res
+
+ @trace
@defer.inlineCallbacks
def delete_version(self, user_id, version=None):
"""Deletes a given version of the user's e2e_room_keys backup
@@ -314,6 +374,7 @@ class E2eRoomKeysHandler(object):
else:
raise
+ @trace
@defer.inlineCallbacks
def update_version(self, user_id, version, version_info):
"""Update the info about a given version of the user's backup
@@ -328,16 +389,10 @@ class E2eRoomKeysHandler(object):
A deferred of an empty dict.
"""
if "version" not in version_info:
+ version_info["version"] = version
+ elif version_info["version"] != version:
raise SynapseError(
- 400,
- "Missing version in body",
- Codes.MISSING_PARAM
- )
- if version_info["version"] != version:
- raise SynapseError(
- 400,
- "Version in body does not match",
- Codes.INVALID_PARAM
+ 400, "Version in body does not match", Codes.INVALID_PARAM
)
with (yield self._upload_linearizer.queue(user_id)):
try:
@@ -350,12 +405,10 @@ class E2eRoomKeysHandler(object):
else:
raise
if old_info["algorithm"] != version_info["algorithm"]:
- raise SynapseError(
- 400,
- "Algorithm does not match",
- Codes.INVALID_PARAM
- )
+ raise SynapseError(400, "Algorithm does not match", Codes.INVALID_PARAM)
- yield self.store.update_e2e_room_keys_version(user_id, version, version_info)
+ yield self.store.update_e2e_room_keys_version(
+ user_id, version, version_info
+ )
- defer.returnValue({})
+ return {}
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index eb525070cf..ec18a42a68 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -16,13 +16,11 @@
import logging
import random
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
+from synapse.logging.utils import log_function
from synapse.types import UserID
-from synapse.util.logutils import log_function
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -31,7 +29,6 @@ logger = logging.getLogger(__name__)
class EventStreamHandler(BaseHandler):
-
def __init__(self, hs):
super(EventStreamHandler, self).__init__(hs)
@@ -51,29 +48,36 @@ class EventStreamHandler(BaseHandler):
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
- @defer.inlineCallbacks
@log_function
- def get_stream(self, auth_user_id, pagin_config, timeout=0,
- as_client_event=True, affect_presence=True,
- only_keys=None, room_id=None, is_guest=False):
+ async def get_stream(
+ self,
+ auth_user_id,
+ pagin_config,
+ timeout=0,
+ as_client_event=True,
+ affect_presence=True,
+ only_keys=None,
+ room_id=None,
+ is_guest=False,
+ ):
"""Fetches the events stream for a given user.
If `only_keys` is not None, events from keys will be sent down.
"""
if room_id:
- blocked = yield self.store.is_room_blocked(room_id)
+ blocked = await self.store.is_room_blocked(room_id)
if blocked:
raise SynapseError(403, "This room has been blocked on this server")
# send any outstanding server notices to the user.
- yield self._server_notices_sender.on_user_syncing(auth_user_id)
+ await self._server_notices_sender.on_user_syncing(auth_user_id)
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler()
- context = yield presence_handler.user_syncing(
- auth_user_id, affect_presence=affect_presence,
+ context = await presence_handler.user_syncing(
+ auth_user_id, affect_presence=affect_presence
)
with context:
if timeout:
@@ -84,10 +88,13 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart.
timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
- events, tokens = yield self.notifier.get_events_for(
- auth_user, pagin_config, timeout,
+ events, tokens = await self.notifier.get_events_for(
+ auth_user,
+ pagin_config,
+ timeout,
only_keys=only_keys,
- is_guest=is_guest, explicit_room_id=room_id
+ is_guest=is_guest,
+ explicit_room_id=room_id,
)
# When the user joins a new room, or another user joins a currently
@@ -102,17 +109,15 @@ class EventStreamHandler(BaseHandler):
# Send down presence.
if event.state_key == auth_user_id:
# Send down presence for everyone in the room.
- users = yield self.state.get_current_users_in_room(event.room_id)
- states = yield presence_handler.get_states(
- users,
- as_event=True,
+ users = await self.state.get_current_users_in_room(
+ event.room_id
)
+ states = await presence_handler.get_states(users, as_event=True)
to_add.extend(states)
else:
- ev = yield presence_handler.get_state(
- UserID.from_string(event.state_key),
- as_event=True,
+ ev = await presence_handler.get_state(
+ UserID.from_string(event.state_key), as_event=True
)
to_add.append(ev)
@@ -120,8 +125,10 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec()
- chunks = yield self._event_serializer.serialize_events(
- events, time_now, as_client_event=as_client_event,
+ chunks = await self._event_serializer.serialize_events(
+ events,
+ time_now,
+ as_client_event=as_client_event,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_aggregations=False,
@@ -133,13 +140,15 @@ class EventStreamHandler(BaseHandler):
"end": tokens[1].to_string(),
}
- defer.returnValue(chunk)
+ return chunk
class EventHandler(BaseHandler):
+ def __init__(self, hs):
+ super(EventHandler, self).__init__(hs)
+ self.storage = hs.get_storage()
- @defer.inlineCallbacks
- def get_event(self, user, room_id, event_id):
+ async def get_event(self, user, room_id, event_id):
"""Retrieve a single specified event.
Args:
@@ -154,26 +163,19 @@ class EventHandler(BaseHandler):
AuthError if the user does not have the rights to inspect this
event.
"""
- event = yield self.store.get_event(event_id, check_room_id=room_id)
+ event = await self.store.get_event(event_id, check_room_id=room_id)
if not event:
- defer.returnValue(None)
- return
+ return None
- users = yield self.store.get_users_in_room(event.room_id)
+ users = await self.store.get_users_in_room(event.room_id)
is_peeking = user.to_string() not in users
- filtered = yield filter_events_for_client(
- self.store,
- user.to_string(),
- [event],
- is_peeking=is_peeking
+ filtered = await filter_events_for_client(
+ self.storage, user.to_string(), [event], is_peeking=is_peeking
)
if not filtered:
- raise AuthError(
- 403,
- "You don't have permission to access that event."
- )
+ raise AuthError(403, "You don't have permission to access that event.")
- defer.returnValue(event)
+ return event
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 35528eb48a..ebdc239fff 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,17 +19,20 @@
import itertools
import logging
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
from six.moves import http_client, zip
+import attr
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from twisted.internet import defer
+from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.api.errors import (
AuthError,
@@ -38,24 +41,33 @@ from synapse.api.errors import (
FederationDeniedError,
FederationError,
RequestSendFailed,
- StoreError,
SynapseError,
)
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
+from synapse.logging.context import (
+ make_deferred_yieldable,
+ nested_logging_context,
+ preserve_fn,
+ run_in_background,
+)
+from synapse.logging.utils import log_function
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationFederationSendEventsRestServlet,
+ ReplicationStoreRoomOnInviteRestServlet,
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
-from synapse.types import UserID, get_domain_from_id
-from synapse.util import logcontext, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
+from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
-from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server
@@ -64,6 +76,23 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
+@attr.s
+class _NewEventInfo:
+ """Holds information about a received event, ready for passing to _handle_new_events
+
+ Attributes:
+ event: the received event
+
+ state: the state at that event
+
+ auth_events: the auth_event map for that event
+ """
+
+ event = attr.ib(type=EventBase)
+ state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
+ auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
+
+
def shortstr(iterable, maxitems=5):
"""If iterable has maxitems or fewer, return the stringification of a list
containing those items.
@@ -82,16 +111,16 @@ def shortstr(iterable, maxitems=5):
items = list(itertools.islice(iterable, maxitems + 1))
if len(items) <= maxitems:
return str(items)
- return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]"
+ return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
Responsible for:
a) handling received Pdus before handing them on as Events to the rest
- of the home server (including auth and state conflict resoultion)
+ of the homeserver (including auth and state conflict resoultion)
b) converting events that were produced by local clients that may need
- to be sent to remote home servers.
+ to be sent to remote homeservers.
c) doing the necessary dances to invite remote users and join remote
rooms.
"""
@@ -102,6 +131,8 @@ class FederationHandler(BaseHandler):
self.hs = hs
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
self.federation_client = hs.get_federation_client()
self.state_handler = hs.get_state_handler()
self.server_name = hs.hostname
@@ -111,30 +142,41 @@ class FederationHandler(BaseHandler):
self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
+ self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
- self._send_events_to_master = (
- ReplicationFederationSendEventsRestServlet.make_client(hs)
+ self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
+ hs
)
- self._notify_user_membership_change = (
- ReplicationUserJoinedLeftRoomRestServlet.make_client(hs)
+ self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
+ hs
)
- self._clean_room_for_join_client = (
- ReplicationCleanRoomRestServlet.make_client(hs)
+ self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
+ hs
)
+ if hs.config.worker_app:
+ self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
+ hs
+ )
+ self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client(
+ hs
+ )
+ else:
+ self._device_list_updater = hs.get_device_handler().device_list_updater
+ self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
+
# When joining a room we need to queue any events for that room up
self.room_queues = {}
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self.third_party_event_rules = hs.get_third_party_event_rules()
- @defer.inlineCallbacks
- def on_receive_pdu(
- self, origin, pdu, sent_to_us_directly=False,
- ):
+ self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
+ async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -144,33 +186,24 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
-
- Returns (Deferred): completes with None
"""
room_id = pdu.room_id
event_id = pdu.event_id
- logger.info(
- "[%s %s] handling received PDU: %s",
- room_id, event_id, pdu,
- )
+ logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
# We reprocess pdus when we have seen them only as outliers
- existing = yield self.store.get_event(
- event_id,
- allow_none=True,
- allow_rejected=True,
+ existing = await self.store.get_event(
+ event_id, allow_none=True, allow_rejected=True
)
# FIXME: Currently we fetch an event again when we already have it
# if it has been marked as an outlier.
- already_seen = (
- existing and (
- not existing.internal_metadata.is_outlier()
- or pdu.internal_metadata.is_outlier()
- )
+ already_seen = existing and (
+ not existing.internal_metadata.is_outlier()
+ or pdu.internal_metadata.is_outlier()
)
if already_seen:
logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
@@ -182,20 +215,19 @@ class FederationHandler(BaseHandler):
try:
self._sanity_check_event(pdu)
except SynapseError as err:
- logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id)
- raise FederationError(
- "ERROR",
- err.code,
- err.msg,
- affected=pdu.event_id,
+ logger.warning(
+ "[%s %s] Received event failed sanity checks", room_id, event_id
)
+ raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if room_id in self.room_queues:
logger.info(
"[%s %s] Queuing PDU from %s for now: join in progress",
- room_id, event_id, origin,
+ room_id,
+ event_id,
+ origin,
)
self.room_queues[room_id].append((pdu, origin))
return
@@ -206,73 +238,81 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
- is_in_room = yield self.auth.check_host_in_room(
- room_id,
- self.server_name
- )
+ is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
- room_id, event_id, origin,
+ room_id,
+ event_id,
+ origin,
)
- defer.returnValue(None)
+ return None
state = None
- auth_chain = []
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
- min_depth = yield self.get_min_depth_for_context(
- pdu.room_id
- )
+ min_depth = await self.get_min_depth_for_context(pdu.room_id)
- logger.debug(
- "[%s %s] min_depth: %d",
- room_id, event_id, min_depth,
- )
+ logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
- if min_depth and pdu.depth < min_depth:
+ if min_depth is not None and pdu.depth < min_depth:
# This is so that we don't notify the user about this
# message, to work around the fact that some events will
# reference really really old events we really don't want to
# send to the clients.
pdu.internal_metadata.outlier = True
- elif min_depth and pdu.depth > min_depth:
+ elif min_depth is not None and pdu.depth > min_depth:
missing_prevs = prevs - seen
if sent_to_us_directly and missing_prevs:
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
"[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
- room_id, event_id, len(missing_prevs), shortstr(missing_prevs),
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
)
- with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+ with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
- room_id, event_id, len(missing_prevs),
+ room_id,
+ event_id,
+ len(missing_prevs),
)
- yield self._get_missing_events_for_pdu(
- origin, pdu, prevs, min_depth
- )
+ try:
+ await self._get_missing_events_for_pdu(
+ origin, pdu, prevs, min_depth
+ )
+ except Exception as e:
+ raise Exception(
+ "Error fetching missing prev_events for %s: %s"
+ % (event_id, e)
+ )
# Update the set of things we've seen after trying to
# fetch the missing stuff
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
"[%s %s] Found all missing prev_events",
- room_id, event_id,
+ room_id,
+ event_id,
)
elif missing_prevs:
logger.info(
"[%s %s] Not recursively fetching %d missing prev_events: %s",
- room_id, event_id, len(missing_prevs), shortstr(missing_prevs),
+ room_id,
+ event_id,
+ len(missing_prevs),
+ shortstr(missing_prevs),
)
if prevs - seen:
@@ -301,9 +341,12 @@ class FederationHandler(BaseHandler):
# following.
if sent_to_us_directly:
- logger.warn(
+ logger.warning(
"[%s %s] Rejecting: failed to fetch %d prev events: %s",
- room_id, event_id, len(prevs - seen), shortstr(prevs - seen)
+ room_id,
+ event_id,
+ len(prevs - seen),
+ shortstr(prevs - seen),
)
raise FederationError(
"ERROR",
@@ -317,17 +360,13 @@ class FederationHandler(BaseHandler):
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
- auth_chains = set()
- event_map = {
- event_id: pdu,
- }
+ event_map = {event_id: pdu}
try:
# Get the state of the events we know about
- ours = yield self.store.get_state_groups_ids(room_id, seen)
+ ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
- # type: list[dict[tuple[str, str], str]]
- state_maps = list(ours.values())
+ state_maps = list(ours.values()) # type: list[StateMap[str]]
# we don't need this any more, let's delete it.
del ours
@@ -337,40 +376,19 @@ class FederationHandler(BaseHandler):
for p in prevs - seen:
logger.info(
"[%s %s] Requesting state at missing prev_event %s",
- room_id, event_id, p,
+ room_id,
+ event_id,
+ p,
)
- room_version = yield self.store.get_room_version(room_id)
-
- with logcontext.nested_logging_context(p):
+ with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
- remote_state, got_auth_chain = (
- yield self.federation_client.get_state_for_room(
- origin, room_id, p,
- )
- )
-
- # we want the state *after* p; get_state_for_room returns the
- # state *before* p.
- remote_event = yield self.federation_client.get_pdu(
- [origin], p, room_version, outlier=True,
+ (remote_state, _,) = await self._get_state_for_room(
+ origin, room_id, p, include_event_in_state=True
)
- if remote_event is None:
- raise Exception(
- "Unable to get missing prev_event %s" % (p, )
- )
-
- if remote_event.is_state():
- remote_state.append(remote_event)
-
- # XXX hrm I'm not convinced that duplicate events will compare
- # for equality, so I'm not sure this does what the author
- # hoped.
- auth_chains.update(got_auth_chain)
-
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
@@ -379,8 +397,12 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- state_map = yield resolve_events_with_store(
- room_version, state_maps, event_map,
+ room_version = await self.store.get_room_version_id(room_id)
+ state_map = await resolve_events_with_store(
+ room_id,
+ room_version,
+ state_maps,
+ event_map,
state_res_store=StateResolutionStore(self.store),
)
@@ -389,22 +411,19 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
- evs = yield self.store.get_events(
- list(state_map.values()),
- get_prev_content=False,
- check_redacted=False,
+ evs = await self.store.get_events(
+ list(state_map.values()), get_prev_content=False,
)
event_map.update(evs)
- state = [
- event_map[e] for e in six.itervalues(state_map)
- ]
- auth_chain = list(auth_chains)
+ state = [event_map[e] for e in six.itervalues(state_map)]
except Exception:
- logger.warn(
+ logger.warning(
"[%s %s] Error attempting to resolve state at missing "
"prev_events",
- room_id, event_id, exc_info=True,
+ room_id,
+ event_id,
+ exc_info=True,
)
raise FederationError(
"ERROR",
@@ -413,15 +432,9 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
- yield self._process_received_pdu(
- origin,
- pdu,
- state=state,
- auth_chain=auth_chain,
- )
+ await self._process_received_pdu(origin, pdu, state=state)
- @defer.inlineCallbacks
- def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@@ -433,12 +446,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
- seen = yield self.store.have_seen_events(prevs)
+ seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
return
- latest = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -447,7 +460,10 @@ class FederationHandler(BaseHandler):
logger.info(
"[%s %s]: Requesting missing events between %s and %s",
- room_id, event_id, shortstr(latest), event_id,
+ room_id,
+ event_id,
+ shortstr(latest),
+ event_id,
)
# XXX: we set timeout to 10s to help workaround
@@ -498,19 +514,31 @@ class FederationHandler(BaseHandler):
#
# All that said: Let's try increasing the timout to 60s and see what happens.
- missing_events = yield self.federation_client.get_missing_events(
- origin,
- room_id,
- earliest_events_ids=list(latest),
- latest_events=[pdu],
- limit=10,
- min_depth=min_depth,
- timeout=60000,
- )
+ try:
+ missing_events = await self.federation_client.get_missing_events(
+ origin,
+ room_id,
+ earliest_events_ids=list(latest),
+ latest_events=[pdu],
+ limit=10,
+ min_depth=min_depth,
+ timeout=60000,
+ )
+ except RequestSendFailed as e:
+ # We failed to get the missing events, but since we need to handle
+ # the case of `get_missing_events` not returning the necessary
+ # events anyway, it is safe to simply log the error and continue.
+ logger.warning(
+ "[%s %s]: Failed to get prev_events: %s", room_id, event_id, e
+ )
+ return
logger.info(
"[%s %s]: Got %d prev_events: %s",
- room_id, event_id, len(missing_events), shortstr(missing_events),
+ room_id,
+ event_id,
+ len(missing_events),
+ shortstr(missing_events),
)
# We want to sort these by depth so we process them and
@@ -520,97 +548,171 @@ class FederationHandler(BaseHandler):
for ev in missing_events:
logger.info(
"[%s %s] Handling received prev_event %s",
- room_id, event_id, ev.event_id,
+ room_id,
+ event_id,
+ ev.event_id,
)
- with logcontext.nested_logging_context(ev.event_id):
+ with nested_logging_context(ev.event_id):
try:
- yield self.on_receive_pdu(
- origin,
- ev,
- sent_to_us_directly=False,
- )
+ await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
- logger.warn(
+ logger.warning(
"[%s %s] Received prev_event %s failed history check.",
- room_id, event_id, ev.event_id,
+ room_id,
+ event_id,
+ ev.event_id,
)
else:
raise
- @defer.inlineCallbacks
- def _process_received_pdu(self, origin, event, state, auth_chain):
- """ Called when we have a new pdu. We need to do auth checks and put it
- through the StateHandler.
+ async def _get_state_for_room(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ include_event_in_state: bool = False,
+ ) -> Tuple[List[EventBase], List[EventBase]]:
+ """Requests all of the room state at a given event from a remote homeserver.
+
+ Args:
+ destination: The remote homeserver to query for the state.
+ room_id: The id of the room we're interested in.
+ event_id: The id of the event we want the state at.
+ include_event_in_state: if true, the event itself will be included in the
+ returned state event list.
+
+ Returns:
+ A list of events in the state, possibly including the event itself, and
+ a list of events in the auth chain for the given event.
"""
- room_id = event.room_id
- event_id = event.event_id
+ (
+ state_event_ids,
+ auth_event_ids,
+ ) = await self.federation_client.get_room_state_ids(
+ destination, room_id, event_id=event_id
+ )
- logger.debug(
- "[%s %s] Processing event: %s",
- room_id, event_id, event,
+ desired_events = set(state_event_ids + auth_event_ids)
+
+ if include_event_in_state:
+ desired_events.add(event_id)
+
+ event_map = await self._get_events_from_store_or_dest(
+ destination, room_id, desired_events
)
- event_ids = set()
- if state:
- event_ids |= {e.event_id for e in state}
- if auth_chain:
- event_ids |= {e.event_id for e in auth_chain}
+ failed_to_fetch = desired_events - event_map.keys()
+ if failed_to_fetch:
+ logger.warning(
+ "Failed to fetch missing state/auth events for %s %s",
+ event_id,
+ failed_to_fetch,
+ )
+
+ remote_state = [
+ event_map[e_id] for e_id in state_event_ids if e_id in event_map
+ ]
- seen_ids = yield self.store.have_seen_events(event_ids)
+ if include_event_in_state:
+ remote_event = event_map.get(event_id)
+ if not remote_event:
+ raise Exception("Unable to get missing prev_event %s" % (event_id,))
+ if remote_event.is_state() and remote_event.rejected_reason is None:
+ remote_state.append(remote_event)
- if state and auth_chain is not None:
- # If we have any state or auth_chain given to us by the replication
- # layer, then we should handle them (if we haven't before.)
+ auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
+ auth_chain.sort(key=lambda e: e.depth)
- event_infos = []
+ return remote_state, auth_chain
- for e in itertools.chain(auth_chain, state):
- if e.event_id in seen_ids:
- continue
- e.internal_metadata.outlier = True
- auth_ids = e.auth_event_ids()
- auth = {
- (e.type, e.state_key): e for e in auth_chain
- if e.event_id in auth_ids or e.type == EventTypes.Create
- }
- event_infos.append({
- "event": e,
- "auth_events": auth,
- })
- seen_ids.add(e.event_id)
+ async def _get_events_from_store_or_dest(
+ self, destination: str, room_id: str, event_ids: Iterable[str]
+ ) -> Dict[str, EventBase]:
+ """Fetch events from a remote destination, checking if we already have them.
- logger.info(
- "[%s %s] persisting newly-received auth/state events %s",
- room_id, event_id, [e["event"].event_id for e in event_infos]
+ Persists any events we don't already have as outliers.
+
+ If we fail to fetch any of the events, a warning will be logged, and the event
+ will be omitted from the result. Likewise, any events which turn out not to
+ be in the given room.
+
+ Returns:
+ map from event_id to event
+ """
+ fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
+
+ missing_events = set(event_ids) - fetched_events.keys()
+
+ if missing_events:
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
+ room_id,
)
- yield self._handle_new_events(origin, event_infos)
- try:
- context = yield self._handle_new_event(
- origin,
- event,
- state=state,
+ await self._get_events_and_persist(
+ destination=destination, room_id=room_id, events=missing_events
)
- except AuthError as e:
- raise FederationError(
- "ERROR",
- e.code,
- e.msg,
- affected=event.event_id,
+
+ # we need to make sure we re-load from the database to get the rejected
+ # state correct.
+ fetched_events.update(
+ (await self.store.get_events(missing_events, allow_rejected=True))
)
- room = yield self.store.get_room(room_id)
+ # check for events which were in the wrong room.
+ #
+ # this can happen if a remote server claims that the state or
+ # auth_events at an event in room A are actually events in room B
- if not room:
- try:
- yield self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False,
- )
- except StoreError:
- logger.exception("Failed to store room.")
+ bad_events = [
+ (event_id, event.room_id)
+ for event_id, event in fetched_events.items()
+ if event.room_id != room_id
+ ]
+
+ for bad_event_id, bad_room_id in bad_events:
+ # This is a bogus situation, but since we may only discover it a long time
+ # after it happened, we try our best to carry on, by just omitting the
+ # bad events from the returned auth/state set.
+ logger.warning(
+ "Remote server %s claims event %s in room %s is an auth/state "
+ "event in room %s",
+ destination,
+ bad_event_id,
+ bad_room_id,
+ room_id,
+ )
+
+ del fetched_events[bad_event_id]
+
+ return fetched_events
+
+ async def _process_received_pdu(
+ self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+ ):
+ """ Called when we have a new pdu. We need to do auth checks and put it
+ through the StateHandler.
+
+ Args:
+ origin: server sending the event
+
+ event: event to be persisted
+
+ state: Normally None, but if we are handling a gap in the graph
+ (ie, we are missing one or more prev_events), the resolved state at the
+ event
+ """
+ room_id = event.room_id
+ event_id = event.event_id
+
+ logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
+
+ try:
+ context = await self._handle_new_event(origin, event, state=state)
+ except AuthError as e:
+ raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
@@ -619,25 +721,94 @@ class FederationHandler(BaseHandler):
# changing their profile info.
newly_joined = True
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = await context.get_prev_state_ids()
- prev_state_id = prev_state_ids.get(
- (event.type, event.state_key)
- )
+ prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
- prev_state = yield self.store.get_event(
- prev_state_id, allow_none=True,
+ prev_state = await self.store.get_event(
+ prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
newly_joined = False
if newly_joined:
user = UserID.from_string(event.state_key)
- yield self.user_joined_room(user, room_id)
+ await self.user_joined_room(user, room_id)
+
+ # For encrypted messages we check that we know about the sending device,
+ # if we don't then we mark the device cache for that user as stale.
+ if event.type == EventTypes.Encrypted:
+ device_id = event.content.get("device_id")
+ sender_key = event.content.get("sender_key")
+
+ cached_devices = await self.store.get_cached_devices_for_user(event.sender)
+
+ resync = False # Whether we should resync device lists.
+
+ device = None
+ if device_id is not None:
+ device = cached_devices.get(device_id)
+ if device is None:
+ logger.info(
+ "Received event from remote device not in our cache: %s %s",
+ event.sender,
+ device_id,
+ )
+ resync = True
+
+ # We also check if the `sender_key` matches what we expect.
+ if sender_key is not None:
+ # Figure out what sender key we're expecting. If we know the
+ # device and recognize the algorithm then we can work out the
+ # exact key to expect. Otherwise check it matches any key we
+ # have for that device.
+ if device:
+ keys = device.get("keys", {}).get("keys", {})
+
+ if event.content.get("algorithm") == "m.megolm.v1.aes-sha2":
+ # For this algorithm we expect a curve25519 key.
+ key_name = "curve25519:%s" % (device_id,)
+ current_keys = [keys.get(key_name)]
+ else:
+ # We don't know understand the algorithm, so we just
+ # check it matches a key for the device.
+ current_keys = keys.values()
+ elif device_id:
+ # We don't have any keys for the device ID.
+ current_keys = []
+ else:
+ # The event didn't include a device ID, so we just look for
+ # keys across all devices.
+ current_keys = (
+ key
+ for device in cached_devices
+ for key in device.get("keys", {}).get("keys", {}).values()
+ )
+
+ # We now check that the sender key matches (one of) the expected
+ # keys.
+ if sender_key not in current_keys:
+ logger.info(
+ "Received event from remote device with unexpected sender key: %s %s: %s",
+ event.sender,
+ device_id or "<no device_id>",
+ sender_key,
+ )
+ resync = True
+
+ if resync:
+ await self.store.mark_remote_user_device_cache_as_stale(event.sender)
+
+ # Immediately attempt a resync in the background
+ if self.config.worker_app:
+ return run_in_background(self._user_device_resync, event.sender)
+ else:
+ return run_in_background(
+ self._device_list_updater.user_device_resync, event.sender
+ )
@log_function
- @defer.inlineCallbacks
- def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -654,13 +825,8 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
- room_version = yield self.store.get_room_version(room_id)
-
- events = yield self.federation_client.backfill(
- dest,
- room_id,
- limit=limit,
- extremities=extremities,
+ events = await self.federation_client.backfill(
+ dest, room_id, limit=limit, extremities=extremities
)
# ideally we'd sanity check the events here for excess prev_events etc,
@@ -674,29 +840,25 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev)
# Don't bother processing events we already have.
- seen_events = yield self.store.have_events_in_timeline(
- set(e.event_id for e in events)
+ seen_events = await self.store.have_events_in_timeline(
+ {e.event_id for e in events}
)
events = [e for e in events if e.event_id not in seen_events]
if not events:
- defer.returnValue([])
+ return []
event_map = {e.event_id: e for e in events}
- event_ids = set(e.event_id for e in events)
+ event_ids = {e.event_id for e in events}
- edges = [
- ev.event_id
- for ev in events
- if set(ev.prev_event_ids()) - event_ids
- ]
+ # build a list of events whose prev_events weren't in the batch.
+ # (XXX: this will include events whose prev_events we already have; that doesn't
+ # sound right?)
+ edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
- logger.info(
- "backfill: Got %d events with %d edges",
- len(events), len(edges),
- )
+ logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
# For each edge get the current state.
@@ -704,127 +866,54 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
- state, auth = yield self.federation_client.get_state_for_room(
+ state, auth = await self._get_state_for_room(
destination=dest,
room_id=room_id,
- event_id=e_id
+ event_id=e_id,
+ include_event_in_state=False,
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
- required_auth = set(
+ required_auth = {
a_id
- for event in events + list(state_events.values()) + list(auth_events.values())
+ for event in events
+ + list(state_events.values())
+ + list(auth_events.values())
for a_id in event.auth_event_ids()
+ }
+ auth_events.update(
+ {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
- auth_events.update({
- e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
- })
- missing_auth = required_auth - set(auth_events)
- failed_to_fetch = set()
-
- # Try and fetch any missing auth events from both DB and remote servers.
- # We repeatedly do this until we stop finding new auth events.
- while missing_auth - failed_to_fetch:
- logger.info("Missing auth for backfill: %r", missing_auth)
- ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
- auth_events.update(ret_events)
-
- required_auth.update(
- a_id for event in ret_events.values() for a_id in event.auth_event_ids()
- )
- missing_auth = required_auth - set(auth_events)
-
- if missing_auth - failed_to_fetch:
- logger.info(
- "Fetching missing auth for backfill: %r",
- missing_auth - failed_to_fetch
- )
-
- results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.run_in_background(
- self.federation_client.get_pdu,
- [dest],
- event_id,
- room_version=room_version,
- outlier=True,
- timeout=10000,
- )
- for event_id in missing_auth - failed_to_fetch
- ],
- consumeErrors=True
- )).addErrback(unwrapFirstError)
- auth_events.update({a.event_id: a for a in results if a})
- required_auth.update(
- a_id
- for event in results if event
- for a_id in event.auth_event_ids()
- )
- missing_auth = required_auth - set(auth_events)
-
- failed_to_fetch = missing_auth - set(auth_events)
-
- seen_events = yield self.store.have_seen_events(
- set(auth_events.keys()) | set(state_events.keys())
- )
-
- # We now have a chunk of events plus associated state and auth chain to
- # persist. We do the persistence in two steps:
- # 1. Auth events and state get persisted as outliers, plus the
- # backward extremities get persisted (as non-outliers).
- # 2. The rest of the events in the chunk get persisted one by one, as
- # each one depends on the previous event for its state.
- #
- # The important thing is that events in the chunk get persisted as
- # non-outliers, including when those events are also in the state or
- # auth chain. Caution must therefore be taken to ensure that they are
- # not accidentally marked as outliers.
- # Step 1a: persist auth events that *don't* appear in the chunk
ev_infos = []
- for a in auth_events.values():
- # We only want to persist auth events as outliers that we haven't
- # seen and aren't about to persist as part of the backfilled chunk.
- if a.event_id in seen_events or a.event_id in event_map:
- continue
-
- a.internal_metadata.outlier = True
- ev_infos.append({
- "event": a,
- "auth_events": {
- (auth_events[a_id].type, auth_events[a_id].state_key):
- auth_events[a_id]
- for a_id in a.auth_event_ids()
- if a_id in auth_events
- }
- })
- # Step 1b: persist the events in the chunk we fetched state for (i.e.
- # the backwards extremities) as non-outliers.
+ # Step 1: persist the events in the chunk we fetched state for (i.e.
+ # the backwards extremities), with custom auth events and state
for e_id in events_to_state:
# For paranoia we ensure that these events are marked as
# non-outliers
ev = event_map[e_id]
- assert(not ev.internal_metadata.is_outlier())
-
- ev_infos.append({
- "event": ev,
- "state": events_to_state[e_id],
- "auth_events": {
- (auth_events[a_id].type, auth_events[a_id].state_key):
- auth_events[a_id]
- for a_id in ev.auth_event_ids()
- if a_id in auth_events
- }
- })
+ assert not ev.internal_metadata.is_outlier()
- yield self._handle_new_events(
- dest, ev_infos,
- backfilled=True,
- )
+ ev_infos.append(
+ _NewEventInfo(
+ event=ev,
+ state=events_to_state[e_id],
+ auth_events={
+ (
+ auth_events[a_id].type,
+ auth_events[a_id].state_key,
+ ): auth_events[a_id]
+ for a_id in ev.auth_event_ids()
+ if a_id in auth_events
+ },
+ )
+ )
+
+ await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -835,25 +924,20 @@ class FederationHandler(BaseHandler):
# For paranoia we ensure that these events are marked as
# non-outliers
- assert(not event.internal_metadata.is_outlier())
+ assert not event.internal_metadata.is_outlier()
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
- yield self._handle_new_event(
- dest, event, backfilled=True,
- )
+ await self._handle_new_event(dest, event, backfilled=True)
- defer.returnValue(events)
+ return events
- @defer.inlineCallbacks
- def maybe_backfill(self, room_id, current_depth):
+ async def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
- extremities = yield self.store.get_oldest_events_with_depth_in_room(
- room_id
- )
+ extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
@@ -885,31 +969,29 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
- forward_events = yield self.store.get_successor_events(
- list(extremities),
- )
+ forward_events = await self.store.get_successor_events(list(extremities))
- extremities_events = yield self.store.get_events(
+ extremities_events = await self.store.get_events(
forward_events,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
- filtered_extremities = yield filter_events_for_server(
- self.store, self.server_name, list(extremities_events.values()),
- redact=False, check_history_visibility_only=True,
+ filtered_extremities = await filter_events_for_server(
+ self.storage,
+ self.server_name,
+ list(extremities_events.values()),
+ redact=False,
+ check_history_visibility_only=True,
)
if not filtered_extremities:
- defer.returnValue(False)
+ return False
# Check if we reached a point where we should start backfilling.
- sorted_extremeties_tuple = sorted(
- extremities.items(),
- key=lambda e: -int(e[1])
- )
+ sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
max_depth = sorted_extremeties_tuple[0][1]
# We don't want to specify too many extremities as it causes the backfill
@@ -918,8 +1000,7 @@ class FederationHandler(BaseHandler):
if current_depth > max_depth:
logger.debug(
- "Not backfilling as we don't need to. %d < %d",
- max_depth, current_depth,
+ "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
)
return
@@ -928,7 +1009,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
- curr_state = yield self.state_handler.get_current_state(room_id)
+ curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
"""Get joined domains from state
@@ -944,8 +1025,7 @@ class FederationHandler(BaseHandler):
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in iteritems(state)
- if e_type == EventTypes.Member
- and event.membership == Membership.JOIN
+ if e_type == EventTypes.Member and event.membership == Membership.JOIN
]
joined_domains = {}
@@ -965,57 +1045,47 @@ class FederationHandler(BaseHandler):
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
- domain for domain, depth in curr_domains
- if domain != self.server_name
+ domain for domain, depth in curr_domains if domain != self.server_name
]
- @defer.inlineCallbacks
- def try_backfill(domains):
+ async def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
- yield self.backfill(
- dom, room_id,
- limit=100,
- extremities=extremities,
+ await self.backfill(
+ dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
- defer.returnValue(True)
+ return True
except SynapseError as e:
- logger.info(
- "Failed to backfill from %s because %s",
- dom, e,
- )
+ logger.info("Failed to backfill from %s because %s", dom, e)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
- logger.info(
- "Failed to backfill from %s because %s",
- dom, e,
- )
+ logger.info("Failed to backfill from %s because %s", dom, e)
continue
except NotRetryingDestination as e:
logger.info(str(e))
continue
+ except RequestSendFailed as e:
+ logger.info("Falied to get backfill from %s because %s", dom, e)
+ continue
except FederationDeniedError as e:
logger.info(e)
continue
except Exception as e:
- logger.exception(
- "Failed to backfill from %s because %s",
- dom, e,
- )
+ logger.exception("Failed to backfill from %s because %s", dom, e)
continue
- defer.returnValue(False)
+ return False
- success = yield try_backfill(likely_domains)
+ success = await try_backfill(likely_domains)
if success:
- defer.returnValue(True)
+ return True
# Huh, well *those* domains didn't work out. Lets try some domains
# from the time.
@@ -1026,43 +1096,92 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
- resolve = logcontext.preserve_fn(
- self.state_handler.resolve_state_groups_for_events
+ resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
+ states = await make_deferred_yieldable(
+ defer.gatherResults(
+ [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
+ )
)
- states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [resolve(room_id, [e]) for e in event_ids],
- consumeErrors=True,
- ))
# dict[str, dict[tuple, str]], a map from event_id to state map of
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
- get_prev_content=False
+ get_prev_content=False,
)
states = {
key: {
k: state_map[e_id]
for k, e_id in iteritems(state_dict)
if e_id in state_map
- } for key, state_dict in iteritems(states)
+ }
+ for key, state_dict in iteritems(states)
}
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
- success = yield try_backfill([
- dom for dom, _ in likely_domains
- if dom not in tried_domains
- ])
+ success = await try_backfill(
+ [dom for dom, _ in likely_domains if dom not in tried_domains]
+ )
if success:
- defer.returnValue(True)
+ return True
tried_domains.update(dom for dom, _ in likely_domains)
- defer.returnValue(False)
+ return False
+
+ async def _get_events_and_persist(
+ self, destination: str, room_id: str, events: Iterable[str]
+ ):
+ """Fetch the given events from a server, and persist them as outliers.
+
+ Logs a warning if we can't find the given event.
+ """
+
+ room_version = await self.store.get_room_version(room_id)
+
+ event_infos = []
+
+ async def get_event(event_id: str):
+ with nested_logging_context(event_id):
+ try:
+ event = await self.federation_client.get_pdu(
+ [destination], event_id, room_version, outlier=True,
+ )
+ if event is None:
+ logger.warning(
+ "Server %s didn't return event %s", destination, event_id,
+ )
+ return
+
+ # recursively fetch the auth events for this event
+ auth_events = await self._get_events_from_store_or_dest(
+ destination, room_id, event.auth_event_ids()
+ )
+ auth = {}
+ for auth_event_id in event.auth_event_ids():
+ ae = auth_events.get(auth_event_id)
+ if ae:
+ auth[(ae.type, ae.state_key)] = ae
+
+ event_infos.append(_NewEventInfo(event, None, auth))
+
+ except Exception as e:
+ logger.warning(
+ "Error fetching missing state/auth event %s: %s %s",
+ event_id,
+ type(e),
+ e,
+ )
+
+ await concurrently_execute(get_event, events, 5)
+
+ await self._handle_new_events(
+ destination, event_infos,
+ )
def _sanity_check_event(self, ev):
"""
@@ -1081,50 +1200,47 @@ class FederationHandler(BaseHandler):
SynapseError if the event does not pass muster
"""
if len(ev.prev_event_ids()) > 20:
- logger.warn("Rejecting event %s which has %i prev_events",
- ev.event_id, len(ev.prev_event_ids()))
- raise SynapseError(
- http_client.BAD_REQUEST,
- "Too many prev_events",
+ logger.warning(
+ "Rejecting event %s which has %i prev_events",
+ ev.event_id,
+ len(ev.prev_event_ids()),
)
+ raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events")
if len(ev.auth_event_ids()) > 10:
- logger.warn("Rejecting event %s which has %i auth_events",
- ev.event_id, len(ev.auth_event_ids()))
- raise SynapseError(
- http_client.BAD_REQUEST,
- "Too many auth_events",
+ logger.warning(
+ "Rejecting event %s which has %i auth_events",
+ ev.event_id,
+ len(ev.auth_event_ids()),
)
+ raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events")
- @defer.inlineCallbacks
- def send_invite(self, target_host, event):
+ async def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution.
"""
- pdu = yield self.federation_client.send_invite(
+ pdu = await self.federation_client.send_invite(
destination=target_host,
room_id=event.room_id,
event_id=event.event_id,
- pdu=event
+ pdu=event,
)
- defer.returnValue(pdu)
+ return pdu
- @defer.inlineCallbacks
- def on_event_auth(self, event_id):
- event = yield self.store.get_event(event_id)
- auth = yield self.store.get_auth_chain(
- [auth_id for auth_id in event.auth_event_ids()],
- include_given=True
+ async def on_event_auth(self, event_id: str) -> List[EventBase]:
+ event = await self.store.get_event(event_id)
+ auth = await self.store.get_auth_chain(
+ list(event.auth_event_ids()), include_given=True
)
- defer.returnValue([e for e in auth])
+ return list(auth)
- @log_function
- @defer.inlineCallbacks
- def do_invite_join(self, target_hosts, room_id, joinee, content):
+ async def do_invite_join(
+ self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
+ ) -> None:
""" Attempts to join the `joinee` to the room `room_id` via the
- server `target_host`.
+ servers contained in `target_hosts`.
This first triggers a /make_join/ request that returns a partial
event that we can fill out and sign. This is then sent to the
@@ -1133,28 +1249,35 @@ class FederationHandler(BaseHandler):
We suspend processing of any received events from this room until we
have finished processing the join.
+
+ Args:
+ target_hosts: List of servers to attempt to join the room with.
+
+ room_id: The ID of the room to join.
+
+ joinee: The User ID of the joining user.
+
+ content: The event content to use for the join event.
"""
logger.debug("Joining %s to %s", joinee, room_id)
- origin, event, event_format_version = yield self._make_and_verify_event(
+ origin, event, room_version_obj = await self._make_and_verify_event(
target_hosts,
room_id,
joinee,
"join",
content,
- params={
- "ver": KNOWN_ROOM_VERSIONS,
- },
+ params={"ver": KNOWN_ROOM_VERSIONS},
)
# This shouldn't happen, because the RoomMemberHandler has a
# linearizer lock which only allows one operation per user per room
# at a time - so this is just paranoia.
- assert (room_id not in self.room_queues)
+ assert room_id not in self.room_queues
self.room_queues[room_id] = []
- yield self._clean_room_for_join(room_id)
+ await self._clean_room_for_join(room_id)
handled_events = set()
@@ -1166,8 +1289,9 @@ class FederationHandler(BaseHandler):
target_hosts.insert(0, origin)
except ValueError:
pass
- ret = yield self.federation_client.send_join(
- target_hosts, event, event_format_version,
+
+ ret = await self.federation_client.send_join(
+ target_hosts, event, room_version_obj
)
origin = ret["origin"]
@@ -1184,18 +1308,37 @@ class FederationHandler(BaseHandler):
logger.debug("do_invite_join event: %s", event)
- try:
- yield self.store.store_room(
- room_id=room_id,
- room_creator_user_id="",
- is_public=False
- )
- except Exception:
- # FIXME
- pass
+ # if this is the first time we've joined this room, it's time to add
+ # a row to `rooms` with the correct room version. If there's already a
+ # row there, we should override it, since it may have been populated
+ # based on an invite request which lied about the room version.
+ #
+ # federation_client.send_join has already checked that the room
+ # version in the received create event is the same as room_version_obj,
+ # so we can rely on it now.
+ #
+ await self.store.upsert_room_on_join(
+ room_id=room_id, room_version=room_version_obj,
+ )
+
+ await self._persist_auth_tree(
+ origin, auth_chain, state, event, room_version_obj
+ )
+
+ # Check whether this room is the result of an upgrade of a room we already know
+ # about. If so, migrate over user information
+ predecessor = await self.store.get_room_predecessor(room_id)
+ if not predecessor or not isinstance(predecessor.get("room_id"), str):
+ return
+ old_room_id = predecessor["room_id"]
+ logger.debug(
+ "Found predecessor for %s during remote join: %s", room_id, old_room_id
+ )
- yield self._persist_auth_tree(
- origin, auth_chain, state, event
+ # We retrieve the room member handler here as to not cause a cyclic dependency
+ member_handler = self.hs.get_room_member_handler()
+ await member_handler.transfer_room_state_on_room_upgrade(
+ old_room_id, room_id
)
logger.debug("Finished joining %s to %s", joinee, room_id)
@@ -1209,12 +1352,9 @@ class FederationHandler(BaseHandler):
# lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it.
- logcontext.run_in_background(self._handle_queued_pdus, room_queue)
+ run_in_background(self._handle_queued_pdus, room_queue)
- defer.returnValue(True)
-
- @defer.inlineCallbacks
- def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
@@ -1223,25 +1363,42 @@ class FederationHandler(BaseHandler):
"""
for p, origin in room_queue:
try:
- logger.info("Processing queued PDU %s which was received "
- "while we were joining %s", p.event_id, p.room_id)
- with logcontext.nested_logging_context(p.event_id):
- yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+ logger.info(
+ "Processing queued PDU %s which was received "
+ "while we were joining %s",
+ p.event_id,
+ p.room_id,
+ )
+ with nested_logging_context(p.event_id):
+ await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
- logger.warn(
- "Error handling queued PDU %s from %s: %s",
- p.event_id, origin, e)
+ logger.warning(
+ "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
+ )
- @defer.inlineCallbacks
- @log_function
- def on_make_join_request(self, room_id, user_id):
+ async def on_make_join_request(
+ self, origin: str, room_id: str, user_id: str
+ ) -> EventBase:
""" We've received a /make_join/ request, so we create a partial
join event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
+
+ Args:
+ origin: The (verified) server name of the requesting server.
+ room_id: Room to create join event in
+ user_id: The user to create the join for
"""
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Got /make_join request for user %r from different origin %s, ignoring",
+ user_id,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
event_content = {"membership": Membership.JOIN}
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(
room_version,
@@ -1251,48 +1408,55 @@ class FederationHandler(BaseHandler):
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
- }
+ },
)
try:
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ event, context = await self.event_creation_handler.create_new_client_event(
+ builder=builder
)
except AuthError as e:
- logger.warn("Failed to create join %r because %s", event, e)
+ logger.warning("Failed to create join to %s because %s", room_id, e)
raise e
- event_allowed = yield self.third_party_event_rules.check_event_allowed(
- event, context,
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
)
if not event_allowed:
logger.info("Creation of join %s forbidden by third-party rules", event)
raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN,
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
)
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_join_request`
- yield self.auth.check_from_context(
- room_version, event, context, do_sig_check=False,
+ await self.auth.check_from_context(
+ room_version, event, context, do_sig_check=False
)
- defer.returnValue(event)
+ return event
- @defer.inlineCallbacks
- @log_function
- def on_send_join_request(self, origin, pdu):
+ async def on_send_join_request(self, origin, pdu):
""" We have received a join event for a room. Fully process it and
respond with the current state and auth chains.
"""
event = pdu
logger.debug(
- "on_send_join_request: Got event: %s, signatures: %s",
+ "on_send_join_request from %s: Got event: %s, signatures: %s",
+ origin,
event.event_id,
event.signatures,
)
+ if get_domain_from_id(event.sender) != origin:
+ logger.info(
+ "Got /send_join request for user %r from different origin %s",
+ event.sender,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
event.internal_metadata.outlier = False
# Send this event on behalf of the origin server.
#
@@ -1309,17 +1473,15 @@ class FederationHandler(BaseHandler):
# would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin
- context = yield self._handle_new_event(
- origin, event
- )
+ context = await self._handle_new_event(origin, event)
- event_allowed = yield self.third_party_event_rules.check_event_allowed(
- event, context,
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
)
if not event_allowed:
logger.info("Sending of join %s forbidden by third-party rules", event)
raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN,
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
)
logger.debug(
@@ -1331,43 +1493,42 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
user = UserID.from_string(event.state_key)
- yield self.user_joined_room(user, event.room_id)
+ await self.user_joined_room(user, event.room_id)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = await context.get_prev_state_ids()
state_ids = list(prev_state_ids.values())
- auth_chain = yield self.store.get_auth_chain(state_ids)
+ auth_chain = await self.store.get_auth_chain(state_ids)
- state = yield self.store.get_events(list(prev_state_ids.values()))
+ state = await self.store.get_events(list(prev_state_ids.values()))
- defer.returnValue({
- "state": list(state.values()),
- "auth_chain": auth_chain,
- })
+ return {"state": list(state.values()), "auth_chain": auth_chain}
- @defer.inlineCallbacks
- def on_invite_request(self, origin, pdu):
+ async def on_invite_request(
+ self, origin: str, event: EventBase, room_version: RoomVersion
+ ):
""" We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event.
"""
- event = pdu
-
if event.state_key is None:
raise SynapseError(400, "The invite event did not have a state key")
- is_blocked = yield self.store.is_room_blocked(event.room_id)
+ is_blocked = await self.store.is_room_blocked(event.room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
- is_published = yield self.store.is_room_published(event.room_id)
+ is_published = await self.store.is_room_published(event.room_id)
if not self.spam_checker.user_may_invite(
- event.sender, event.state_key, None,
- room_id=event.room_id, new_room=False,
+ event.sender,
+ event.state_key,
+ None,
+ room_id=event.room_id,
+ new_room=False,
published_room=is_published,
):
raise SynapseError(
@@ -1380,41 +1541,46 @@ class FederationHandler(BaseHandler):
sender_domain = get_domain_from_id(event.sender)
if sender_domain != origin:
- raise SynapseError(400, "The invite event was not from the server sending it")
+ raise SynapseError(
+ 400, "The invite event was not from the server sending it"
+ )
if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server")
# block any attempts to invite the server notices mxid
if event.state_key == self._server_notices_mxid:
- raise SynapseError(
- http_client.FORBIDDEN,
- "Cannot invite this user",
- )
+ raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
+
+ # keep a record of the room version, if we don't yet know it.
+ # (this may get overwritten if we later get a different room version in a
+ # join dance).
+ await self._maybe_store_room_on_invite(
+ room_id=event.room_id, room_version=room_version
+ )
event.internal_metadata.outlier = True
event.internal_metadata.out_of_band_membership = True
event.signatures.update(
compute_event_signature(
+ room_version,
event.get_pdu_json(),
self.hs.hostname,
- self.hs.config.signing_key[0]
+ self.hs.config.signing_key[0],
)
)
- context = yield self.state_handler.compute_event_context(event)
- yield self.persist_events_and_notify([(event, context)])
+ context = await self.state_handler.compute_event_context(event)
+ await self.persist_events_and_notify([(event, context)])
- defer.returnValue(event)
+ return event
- @defer.inlineCallbacks
- def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
- origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts,
- room_id,
- user_id,
- "leave"
+ async def do_remotely_reject_invite(
+ self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict
+ ) -> EventBase:
+ origin, event, room_version = await self._make_and_verify_event(
+ target_hosts, room_id, user_id, "leave", content=content
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -1429,46 +1595,61 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
- yield self.federation_client.send_leave(
- target_hosts,
- event
- )
+ await self.federation_client.send_leave(target_hosts, event)
- context = yield self.state_handler.compute_event_context(event)
- yield self.persist_events_and_notify([(event, context)])
+ context = await self.state_handler.compute_event_context(event)
+ await self.persist_events_and_notify([(event, context)])
- defer.returnValue(event)
+ return event
- @defer.inlineCallbacks
- def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
- content={}, params=None):
- origin, event, format_ver = yield self.federation_client.make_membership_event(
- target_hosts,
- room_id,
- user_id,
- membership,
- content,
- params=params,
+ async def _make_and_verify_event(
+ self,
+ target_hosts: Iterable[str],
+ room_id: str,
+ user_id: str,
+ membership: str,
+ content: JsonDict = {},
+ params: Optional[Dict[str, str]] = None,
+ ) -> Tuple[str, EventBase, RoomVersion]:
+ (
+ origin,
+ event,
+ room_version,
+ ) = await self.federation_client.make_membership_event(
+ target_hosts, room_id, user_id, membership, content, params=params
)
logger.debug("Got response to make_%s: %s", membership, event)
# We should assert some things.
# FIXME: Do this in a nicer way
- assert(event.type == EventTypes.Member)
- assert(event.user_id == user_id)
- assert(event.state_key == user_id)
- assert(event.room_id == room_id)
- defer.returnValue((origin, event, format_ver))
-
- @defer.inlineCallbacks
- @log_function
- def on_make_leave_request(self, room_id, user_id):
+ assert event.type == EventTypes.Member
+ assert event.user_id == user_id
+ assert event.state_key == user_id
+ assert event.room_id == room_id
+ return origin, event, room_version
+
+ async def on_make_leave_request(
+ self, origin: str, room_id: str, user_id: str
+ ) -> EventBase:
""" We've received a /make_leave/ request, so we create a partial
leave event for the room and return that. We do *not* persist or
process it until the other server has signed it and sent it back.
+
+ Args:
+ origin: The (verified) server name of the requesting server.
+ room_id: Room to create leave event in
+ user_id: The user to create the leave for
"""
- room_version = yield self.store.get_room_version(room_id)
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Got /make_leave request for user %r from different origin %s, ignoring",
+ user_id,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
+ room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(
room_version,
{
@@ -1477,37 +1658,35 @@ class FederationHandler(BaseHandler):
"room_id": room_id,
"sender": user_id,
"state_key": user_id,
- }
+ },
)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ event, context = await self.event_creation_handler.create_new_client_event(
+ builder=builder
)
- event_allowed = yield self.third_party_event_rules.check_event_allowed(
- event, context,
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
)
if not event_allowed:
logger.warning("Creation of leave %s forbidden by third-party rules", event)
raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN,
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
)
try:
# The remote hasn't signed it yet, obviously. We'll do the full checks
# when we get the event back in `on_send_leave_request`
- yield self.auth.check_from_context(
- room_version, event, context, do_sig_check=False,
+ await self.auth.check_from_context(
+ room_version, event, context, do_sig_check=False
)
except AuthError as e:
- logger.warn("Failed to create new leave %r because %s", event, e)
+ logger.warning("Failed to create new leave %r because %s", event, e)
raise e
- defer.returnValue(event)
+ return event
- @defer.inlineCallbacks
- @log_function
- def on_send_leave_request(self, origin, pdu):
+ async def on_send_leave_request(self, origin, pdu):
""" We have received a leave event for a room. Fully process it."""
event = pdu
@@ -1517,19 +1696,25 @@ class FederationHandler(BaseHandler):
event.signatures,
)
+ if get_domain_from_id(event.sender) != origin:
+ logger.info(
+ "Got /send_leave request for user %r from different origin %s",
+ event.sender,
+ origin,
+ )
+ raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
+
event.internal_metadata.outlier = False
- context = yield self._handle_new_event(
- origin, event
- )
+ context = await self._handle_new_event(origin, event)
- event_allowed = yield self.third_party_event_rules.check_event_allowed(
- event, context,
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
)
if not event_allowed:
logger.info("Sending of leave %s forbidden by third-party rules", event)
raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN,
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
)
logger.debug(
@@ -1538,7 +1723,7 @@ class FederationHandler(BaseHandler):
event.signatures,
)
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def get_state_for_pdu(self, room_id, event_id):
@@ -1546,18 +1731,14 @@ class FederationHandler(BaseHandler):
"""
event = yield self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id,
+ event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.store.get_state_groups(
- room_id, [event_id]
- )
+ state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
- results = {
- (e.type, e.state_key): e for e in state
- }
+ results = {(e.type, e.state_key): e for e in state}
if event.is_state():
# Get previous state
@@ -1570,21 +1751,19 @@ class FederationHandler(BaseHandler):
del results[(event.type, event.state_key)]
res = list(results.values())
- defer.returnValue(res)
+ return res
else:
- defer.returnValue([])
+ return []
@defer.inlineCallbacks
def get_state_ids_for_pdu(self, room_id, event_id):
"""Returns the state at the event. i.e. not including said event.
"""
event = yield self.store.get_event(
- event_id, allow_none=False, check_room_id=room_id,
+ event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.store.get_state_groups_ids(
- room_id, [event_id]
- )
+ state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@@ -1599,9 +1778,9 @@ class FederationHandler(BaseHandler):
else:
results.pop((event.type, event.state_key), None)
- defer.returnValue(list(results.values()))
+ return list(results.values())
else:
- defer.returnValue([])
+ return []
@defer.inlineCallbacks
@log_function
@@ -1610,15 +1789,14 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield self.store.get_backfill_events(
- room_id,
- pdu_list,
- limit
- )
+ # Synapse asks for 100 events per backfill request. Do not allow more.
+ limit = min(limit, 100)
- events = yield filter_events_for_server(self.store, origin, events)
+ events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
- defer.returnValue(events)
+ events = yield filter_events_for_server(self.storage, origin, events)
+
+ return events
@defer.inlineCallbacks
@log_function
@@ -1638,65 +1816,61 @@ class FederationHandler(BaseHandler):
AuthError if the server is not currently in the room
"""
event = yield self.store.get_event(
- event_id,
- allow_none=True,
- allow_rejected=True,
+ event_id, allow_none=True, allow_rejected=True
)
if event:
- in_room = yield self.auth.check_host_in_room(
- event.room_id,
- origin
- )
+ in_room = yield self.auth.check_host_in_room(event.room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield filter_events_for_server(
- self.store, origin, [event],
- )
+ events = yield filter_events_for_server(self.storage, origin, [event])
event = events[0]
- defer.returnValue(event)
+ return event
else:
- defer.returnValue(None)
+ return None
def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context)
- @defer.inlineCallbacks
- def _handle_new_event(self, origin, event, state=None, auth_events=None,
- backfilled=False):
- context = yield self._prep_event(
- origin, event,
- state=state,
- auth_events=auth_events,
- backfilled=backfilled,
+ async def _handle_new_event(
+ self, origin, event, state=None, auth_events=None, backfilled=False
+ ):
+ context = await self._prep_event(
+ origin, event, state=state, auth_events=auth_events, backfilled=backfilled
)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
success = False
try:
- if not event.internal_metadata.is_outlier() and not backfilled:
- yield self.action_generator.handle_push_actions_for_event(
+ if (
+ not event.internal_metadata.is_outlier()
+ and not backfilled
+ and not context.rejected
+ ):
+ await self.action_generator.handle_push_actions_for_event(
event, context
)
- yield self.persist_events_and_notify(
- [(event, context)],
- backfilled=backfilled,
+ await self.persist_events_and_notify(
+ [(event, context)], backfilled=backfilled
)
success = True
finally:
if not success:
- logcontext.run_in_background(
- self.store.remove_push_actions_from_staging,
- event.event_id,
+ run_in_background(
+ self.store.remove_push_actions_from_staging, event.event_id
)
- defer.returnValue(context)
+ return context
- @defer.inlineCallbacks
- def _handle_new_events(self, origin, event_infos, backfilled=False):
+ async def _handle_new_events(
+ self,
+ origin: str,
+ event_infos: Iterable[_NewEventInfo],
+ backfilled: bool = False,
+ ) -> None:
"""Creates the appropriate contexts and persists events. The events
should not depend on one another, e.g. this should be used to persist
a bunch of outliers, but not a chunk of individual events that depend
@@ -1705,36 +1879,41 @@ class FederationHandler(BaseHandler):
Notifies about the events where appropriate.
"""
- @defer.inlineCallbacks
- def prep(ev_info):
- event = ev_info["event"]
- with logcontext.nested_logging_context(suffix=event.event_id):
- res = yield self._prep_event(
+ async def prep(ev_info: _NewEventInfo):
+ event = ev_info.event
+ with nested_logging_context(suffix=event.event_id):
+ res = await self._prep_event(
origin,
event,
- state=ev_info.get("state"),
- auth_events=ev_info.get("auth_events"),
+ state=ev_info.state,
+ auth_events=ev_info.auth_events,
backfilled=backfilled,
)
- defer.returnValue(res)
+ return res
- contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.run_in_background(prep, ev_info)
- for ev_info in event_infos
- ], consumeErrors=True,
- ))
+ contexts = await make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(prep, ev_info) for ev_info in event_infos],
+ consumeErrors=True,
+ )
+ )
- yield self.persist_events_and_notify(
+ await self.persist_events_and_notify(
[
- (ev_info["event"], context)
+ (ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
],
backfilled=backfilled,
)
- @defer.inlineCallbacks
- def _persist_auth_tree(self, origin, auth_events, state, event):
+ async def _persist_auth_tree(
+ self,
+ origin: str,
+ auth_events: List[EventBase],
+ state: List[EventBase],
+ event: EventBase,
+ room_version: RoomVersion,
+ ) -> None:
"""Checks the auth chain is valid (and passes auth checks) for the
state and event. Then persists the auth chain and state atomically.
Persists the event separately. Notifies about the persisted events
@@ -1743,23 +1922,21 @@ class FederationHandler(BaseHandler):
Will attempt to fetch missing auth events.
Args:
- origin (str): Where the events came from
- auth_events (list)
- state (list)
- event (Event)
-
- Returns:
- Deferred
+ origin: Where the events came from
+ auth_events
+ state
+ event
+ room_version: The room version we expect this room to have, and
+ will raise if it doesn't match the version in the create event.
"""
events_to_context = {}
for e in itertools.chain(auth_events, state):
e.internal_metadata.outlier = True
- ctx = yield self.state_handler.compute_event_context(e)
+ ctx = await self.state_handler.compute_event_context(e)
events_to_context[e.event_id] = ctx
event_map = {
- e.event_id: e
- for e in itertools.chain(auth_events, state, [event])
+ e.event_id: e for e in itertools.chain(auth_events, state, [event])
}
create_event = None
@@ -1773,10 +1950,13 @@ class FederationHandler(BaseHandler):
# invalid, and it would fail auth checks anyway.
raise SynapseError(400, "No create event in state")
- room_version = create_event.content.get(
- "room_version", RoomVersions.V1.identifier,
+ room_version_id = create_event.content.get(
+ "room_version", RoomVersions.V1.identifier
)
+ if room_version.identifier != room_version_id:
+ raise SynapseError(400, "Room version mismatch")
+
missing_auth_events = set()
for e in itertools.chain(auth_events, state, [event]):
for e_id in e.auth_event_ids():
@@ -1784,12 +1964,8 @@ class FederationHandler(BaseHandler):
missing_auth_events.add(e_id)
for e_id in missing_auth_events:
- m_ev = yield self.federation_client.get_pdu(
- [origin],
- e_id,
- room_version=room_version,
- outlier=True,
- timeout=10000,
+ m_ev = await self.federation_client.get_pdu(
+ [origin], e_id, room_version=room_version, outlier=True, timeout=10000,
)
if m_ev and m_ev.event_id == e_id:
event_map[e_id] = m_ev
@@ -1806,7 +1982,7 @@ class FederationHandler(BaseHandler):
auth_for_e[(EventTypes.Create, "")] = create_event
try:
- self.auth.check(room_version, e, auth_events=auth_for_e)
+ event_auth.check(room_version, e, auth_events=auth_for_e)
except SynapseError as err:
# we may get SynapseErrors here as well as AuthErrors. For
# instance, there are a couple of (ancient) events in some
@@ -1814,111 +1990,80 @@ class FederationHandler(BaseHandler):
# cause SynapseErrors in auth.check. We don't want to give up
# the attempt to federate altogether in such cases.
- logger.warn(
- "Rejecting %s because %s",
- e.event_id, err.msg
- )
+ logger.warning("Rejecting %s because %s", e.event_id, err.msg)
if e == event:
raise
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
- yield self.persist_events_and_notify(
+ await self.persist_events_and_notify(
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
- ],
+ ]
)
- new_event_context = yield self.state_handler.compute_event_context(
+ new_event_context = await self.state_handler.compute_event_context(
event, old_state=state
)
- yield self.persist_events_and_notify(
- [(event, new_event_context)],
- )
-
- @defer.inlineCallbacks
- def _prep_event(self, origin, event, state, auth_events, backfilled):
- """
-
- Args:
- origin:
- event:
- state:
- auth_events:
- backfilled (bool)
+ await self.persist_events_and_notify([(event, new_event_context)])
- Returns:
- Deferred, which resolves to synapse.events.snapshot.EventContext
- """
- context = yield self.state_handler.compute_event_context(
- event, old_state=state,
- )
+ async def _prep_event(
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
+ auth_events: Optional[StateMap[EventBase]],
+ backfilled: bool,
+ ) -> EventContext:
+ context = await self.state_handler.compute_event_context(event, old_state=state)
if not auth_events:
- prev_state_ids = yield context.get_prev_state_ids(self.store)
- auth_events_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids, for_verification=True,
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = await self.auth.compute_auth_events(
+ event, prev_state_ids, for_verification=True
)
- auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {
- (e.type, e.state_key): e for e in auth_events.values()
- }
+ auth_events = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
- c = yield self.store.get_event(
- event.prev_event_ids()[0],
- allow_none=True,
+ c = await self.store.get_event(
+ event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
- try:
- yield self.do_auth(
- origin, event, context, auth_events=auth_events
- )
- except AuthError as e:
- logger.warn(
- "[%s %s] Rejecting: %s",
- event.room_id, event.event_id, e.msg
- )
-
- context.rejected = RejectedReason.AUTH_ERROR
+ context = await self.do_auth(origin, event, context, auth_events=auth_events)
if not context.rejected:
- yield self._check_for_soft_fail(event, state, backfilled)
+ await self._check_for_soft_fail(event, state, backfilled)
if event.type == EventTypes.GuestAccess and not context.rejected:
- yield self.maybe_kick_guest_users(event)
+ await self.maybe_kick_guest_users(event)
- defer.returnValue(context)
+ return context
- @defer.inlineCallbacks
- def _check_for_soft_fail(self, event, state, backfilled):
- """Checks if we should soft fail the event, if so marks the event as
+ async def _check_for_soft_fail(
+ self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
+ ) -> None:
+ """Checks if we should soft fail the event; if so, marks the event as
such.
Args:
- event (FrozenEvent)
- state (dict|None): The state at the event if we don't have all the
- event's prev events
- backfilled (bool): Whether the event is from backfill
-
- Returns:
- Deferred
+ event
+ state: The state at the event if we don't have all the event's prev events
+ backfilled: Whether the event is from backfill
"""
# For new (non-backfilled and non-outlier) events we check if the event
# passes auth based on the current state. If it doesn't then we
# "soft-fail" the event.
do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
if do_soft_fail_check:
- extrem_ids = yield self.store.get_latest_event_ids_in_room(
- event.room_id,
- )
+ extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids)
prev_event_ids = set(event.prev_event_ids())
@@ -1929,7 +2074,8 @@ class FederationHandler(BaseHandler):
do_soft_fail_check = False
if do_soft_fail_check:
- room_version = yield self.store.get_room_version(event.room_id)
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# Calculate the "current state".
if state is not None:
@@ -1945,59 +2091,55 @@ class FederationHandler(BaseHandler):
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
- state_sets = yield self.store.get_state_groups(
- event.room_id, extrem_ids,
+ state_sets = await self.state_store.get_state_groups(
+ event.room_id, extrem_ids
)
state_sets = list(state_sets.values())
state_sets.append(state)
- current_state_ids = yield self.state_handler.resolve_events(
- room_version, state_sets, event,
+ current_state_ids = await self.state_handler.resolve_events(
+ room_version, state_sets, event
)
current_state_ids = {
k: e.event_id for k, e in iteritems(current_state_ids)
}
else:
- current_state_ids = yield self.state_handler.get_current_state_ids(
- event.room_id, latest_event_ids=extrem_ids,
+ current_state_ids = await self.state_handler.get_current_state_ids(
+ event.room_id, latest_event_ids=extrem_ids
)
logger.debug(
"Doing soft-fail check for %s: state %s",
- event.event_id, current_state_ids,
+ event.event_id,
+ current_state_ids,
)
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [
- e for k, e in iteritems(current_state_ids)
- if k in auth_types
+ e for k, e in iteritems(current_state_ids) if k in auth_types
]
- current_auth_events = yield self.store.get_events(current_state_ids)
+ current_auth_events = await self.store.get_events(current_state_ids)
current_auth_events = {
(e.type, e.state_key): e for e in current_auth_events.values()
}
try:
- self.auth.check(room_version, event, auth_events=current_auth_events)
- except AuthError as e:
- logger.warn(
- "Soft-failing %r because %s",
- event, e,
+ event_auth.check(
+ room_version_obj, event, auth_events=current_auth_events
)
+ except AuthError as e:
+ logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
- @defer.inlineCallbacks
- def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects,
- missing):
- in_room = yield self.auth.check_host_in_room(
- room_id,
- origin
- )
+ async def on_query_auth(
+ self, origin, event_id, room_id, remote_auth_chain, rejects, missing
+ ):
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
@@ -2005,75 +2147,77 @@ class FederationHandler(BaseHandler):
# don't want to fall into the trap of `missing` being wrong.
for e in remote_auth_chain:
try:
- yield self._handle_new_event(origin, e)
+ await self._handle_new_event(origin, e)
except AuthError:
pass
# Now get the current auth_chain for the event.
- local_auth_chain = yield self.store.get_auth_chain(
- [auth_id for auth_id in event.auth_event_ids()],
- include_given=True
+ local_auth_chain = await self.store.get_auth_chain(
+ list(event.auth_event_ids()), include_given=True
)
# TODO: Check if we would now reject event_id. If so we need to tell
# everyone.
- ret = yield self.construct_auth_difference(
- local_auth_chain, remote_auth_chain
- )
+ ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
logger.debug("on_query_auth returning: %s", ret)
- defer.returnValue(ret)
+ return ret
- @defer.inlineCallbacks
- def on_get_missing_events(self, origin, room_id, earliest_events,
- latest_events, limit):
- in_room = yield self.auth.check_host_in_room(
- room_id,
- origin
- )
+ async def on_get_missing_events(
+ self, origin, room_id, earliest_events, latest_events, limit
+ ):
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
+ # Only allow up to 20 events to be retrieved per request.
limit = min(limit, 20)
- missing_events = yield self.store.get_missing_events(
+ missing_events = await self.store.get_missing_events(
room_id=room_id,
earliest_events=earliest_events,
latest_events=latest_events,
limit=limit,
)
- missing_events = yield filter_events_for_server(
- self.store, origin, missing_events,
+ missing_events = await filter_events_for_server(
+ self.storage, origin, missing_events
)
- defer.returnValue(missing_events)
+ return missing_events
- @defer.inlineCallbacks
- @log_function
- def do_auth(self, origin, event, context, auth_events):
+ async def do_auth(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ auth_events: StateMap[EventBase],
+ ) -> EventContext:
"""
Args:
- origin (str):
- event (synapse.events.EventBase):
- context (synapse.events.snapshot.EventContext):
- auth_events (dict[(str, str)->synapse.events.EventBase]):
+ origin:
+ event:
+ context:
+ auth_events:
Map from (event_type, state_key) to event
- What we expect the event's auth_events to be, based on the event's
- position in the dag. I think? maybe??
+ Normally, our calculated auth_events based on the state of the room
+ at the event's position in the DAG, though occasionally (eg if the
+ event is an outlier), may be the auth events claimed by the remote
+ server.
Also NB that this function adds entries to it.
Returns:
- defer.Deferred[None]
+ updated context object
"""
- room_version = yield self.store.get_room_version(event.room_id)
+ room_version = await self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
try:
- yield self._update_auth_events_and_context_for_auth(
+ context = await self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events
)
except Exception:
@@ -2088,15 +2232,20 @@ class FederationHandler(BaseHandler):
)
try:
- self.auth.check(room_version, event, auth_events=auth_events)
+ event_auth.check(room_version_obj, event, auth_events=auth_events)
except AuthError as e:
- logger.warn("Failed auth resolution for %r because %s", event, e)
- raise e
+ logger.warning("Failed auth resolution for %r because %s", event, e)
+ context.rejected = RejectedReason.AUTH_ERROR
- @defer.inlineCallbacks
- def _update_auth_events_and_context_for_auth(
- self, origin, event, context, auth_events
- ):
+ return context
+
+ async def _update_auth_events_and_context_for_auth(
+ self,
+ origin: str,
+ event: EventBase,
+ context: EventContext,
+ auth_events: StateMap[EventBase],
+ ) -> EventContext:
"""Helper for do_auth. See there for docs.
Checks whether a given event has the expected auth events. If it
@@ -2104,67 +2253,59 @@ class FederationHandler(BaseHandler):
we can come to a consensus (e.g. if one server missed some valid
state).
- This attempts to resovle any potential divergence of state between
+ This attempts to resolve any potential divergence of state between
servers, but is not essential and so failures should not block further
processing of the event.
Args:
- origin (str):
- event (synapse.events.EventBase):
- context (synapse.events.snapshot.EventContext):
- auth_events (dict[(str, str)->synapse.events.EventBase]):
+ origin:
+ event:
+ context:
+
+ auth_events:
+ Map from (event_type, state_key) to event
+
+ Normally, our calculated auth_events based on the state of the room
+ at the event's position in the DAG, though occasionally (eg if the
+ event is an outlier), may be the auth events claimed by the remote
+ server.
+
+ Also NB that this function adds entries to it.
Returns:
- defer.Deferred[None]
+ updated context
"""
event_auth_events = set(event.auth_event_ids())
- if event.is_state():
- event_key = (event.type, event.state_key)
- else:
- event_key = None
-
- # if the event's auth_events refers to events which are not in our
- # calculated auth_events, we need to fetch those events from somewhere.
- #
- # we start by fetching them from the store, and then try calling /event_auth/.
+ # missing_auth is the set of the event's auth_events which we don't yet have
+ # in auth_events.
missing_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
+ # if we have missing events, we need to fetch those events from somewhere.
+ #
+ # we start by checking if they are in the store, and then try calling /event_auth/.
if missing_auth:
- # TODO: can we use store.have_seen_events here instead?
- have_events = yield self.store.get_seen_events_with_rejections(
- missing_auth
- )
- logger.debug("Got events %s from store", have_events)
- missing_auth.difference_update(have_events.keys())
- else:
- have_events = {}
-
- have_events.update({
- e.event_id: ""
- for e in auth_events.values()
- })
+ have_events = await self.store.have_seen_events(missing_auth)
+ logger.debug("Events %s are in the store", have_events)
+ missing_auth.difference_update(have_events)
if missing_auth:
# If we don't have all the auth events, we need to get them.
- logger.info(
- "auth_events contains unknown events: %s",
- missing_auth,
- )
+ logger.info("auth_events contains unknown events: %s", missing_auth)
try:
try:
- remote_auth_chain = yield self.federation_client.get_event_auth(
+ remote_auth_chain = await self.federation_client.get_event_auth(
origin, event.room_id, event.event_id
)
except RequestSendFailed as e:
# The other side isn't around or doesn't implement the
# endpoint, so lets just bail out.
logger.info("Failed to get event auth from remote: %s", e)
- return
+ return context
- seen_remotes = yield self.store.have_seen_events(
+ seen_remotes = await self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
)
@@ -2178,43 +2319,40 @@ class FederationHandler(BaseHandler):
try:
auth_ids = e.auth_event_ids()
auth = {
- (e.type, e.state_key): e for e in remote_auth_chain
+ (e.type, e.state_key): e
+ for e in remote_auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
e.internal_metadata.outlier = True
logger.debug(
- "do_auth %s missing_auth: %s",
- event.event_id, e.event_id
- )
- yield self._handle_new_event(
- origin, e, auth_events=auth
+ "do_auth %s missing_auth: %s", event.event_id, e.event_id
)
+ await self._handle_new_event(origin, e, auth_events=auth)
if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e
except AuthError:
pass
- have_events = yield self.store.get_seen_events_with_rejections(
- event.auth_event_ids()
- )
except Exception:
- # FIXME:
logger.exception("Failed to get auth chain")
if event.internal_metadata.is_outlier():
+ # XXX: given that, for an outlier, we'll be working with the
+ # event's *claimed* auth events rather than those we calculated:
+ # (a) is there any point in this test, since different_auth below will
+ # obviously be empty
+ # (b) alternatively, why don't we do it earlier?
logger.info("Skipping auth_event fetch for outlier")
- return
+ return context
- # FIXME: Assumes we have and stored all the state for all the
- # prev_events
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
if not different_auth:
- return
+ return context
logger.info(
"auth_events refers to events which are not in our calculated auth "
@@ -2222,189 +2360,94 @@ class FederationHandler(BaseHandler):
different_auth,
)
- room_version = yield self.store.get_room_version(event.room_id)
+ # XXX: currently this checks for redactions but I'm not convinced that is
+ # necessary?
+ different_events = await self.store.get_events_as_list(different_auth)
- different_events = yield logcontext.make_deferred_yieldable(
- defer.gatherResults([
- logcontext.run_in_background(
- self.store.get_event,
- d,
- allow_none=True,
- allow_rejected=False,
+ for d in different_events:
+ if d.room_id != event.room_id:
+ logger.warning(
+ "Event %s refers to auth_event %s which is in a different room",
+ event.event_id,
+ d.event_id,
)
- for d in different_auth
- if d in have_events and not have_events[d]
- ], consumeErrors=True)
- ).addErrback(unwrapFirstError)
-
- if different_events:
- local_view = dict(auth_events)
- remote_view = dict(auth_events)
- remote_view.update({
- (d.type, d.state_key): d for d in different_events if d
- })
-
- new_state = yield self.state_handler.resolve_events(
- room_version,
- [list(local_view.values()), list(remote_view.values())],
- event
- )
-
- logger.info(
- "After state res: updating auth_events with new state %s",
- {
- (d.type, d.state_key): d.event_id for d in new_state.values()
- if auth_events.get((d.type, d.state_key)) != d
- },
- )
- auth_events.update(new_state)
+ # don't attempt to resolve the claimed auth events against our own
+ # in this case: just use our own auth events.
+ #
+ # XXX: should we reject the event in this case? It feels like we should,
+ # but then shouldn't we also do so if we've failed to fetch any of the
+ # auth events?
+ return context
- different_auth = event_auth_events.difference(
- e.event_id for e in auth_events.values()
- )
+ # now we state-resolve between our own idea of the auth events, and the remote's
+ # idea of them.
- yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
- )
+ local_state = auth_events.values()
+ remote_auth_events = dict(auth_events)
+ remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
+ remote_state = remote_auth_events.values()
- if not different_auth:
- # we're done
- return
+ room_version = await self.store.get_room_version_id(event.room_id)
+ new_state = await self.state_handler.resolve_events(
+ room_version, (local_state, remote_state), event
+ )
logger.info(
- "auth_events still refers to events which are not in the calculated auth "
- "chain after state resolution: %s",
- different_auth,
+ "After state res: updating auth_events with new state %s",
+ {
+ (d.type, d.state_key): d.event_id
+ for d in new_state.values()
+ if auth_events.get((d.type, d.state_key)) != d
+ },
)
- # Only do auth resolution if we have something new to say.
- # We can't prove an auth failure.
- do_resolution = False
-
- for e_id in different_auth:
- if e_id in have_events:
- if have_events[e_id] == RejectedReason.NOT_ANCESTOR:
- do_resolution = True
- break
-
- if not do_resolution:
- logger.info(
- "Skipping auth resolution due to lack of provable rejection reasons"
- )
- return
-
- logger.info("Doing auth resolution")
+ auth_events.update(new_state)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
-
- # 1. Get what we think is the auth chain.
- auth_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids
- )
- local_auth_chain = yield self.store.get_auth_chain(
- auth_ids, include_given=True
+ context = await self._update_context_for_auth_events(
+ event, context, auth_events
)
- try:
- # 2. Get remote difference.
- try:
- result = yield self.federation_client.query_auth(
- origin,
- event.room_id,
- event.event_id,
- local_auth_chain,
- )
- except RequestSendFailed as e:
- # The other side isn't around or doesn't implement the
- # endpoint, so lets just bail out.
- logger.info("Failed to query auth from remote: %s", e)
- return
-
- seen_remotes = yield self.store.have_seen_events(
- [e.event_id for e in result["auth_chain"]]
- )
-
- # 3. Process any remote auth chain events we haven't seen.
- for ev in result["auth_chain"]:
- if ev.event_id in seen_remotes:
- continue
-
- if ev.event_id == event.event_id:
- continue
-
- try:
- auth_ids = ev.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in result["auth_chain"]
- if e.event_id in auth_ids
- or event.type == EventTypes.Create
- }
- ev.internal_metadata.outlier = True
-
- logger.debug(
- "do_auth %s different_auth: %s",
- event.event_id, e.event_id
- )
-
- yield self._handle_new_event(
- origin, ev, auth_events=auth
- )
-
- if ev.event_id in event_auth_events:
- auth_events[(ev.type, ev.state_key)] = ev
- except AuthError:
- pass
-
- except Exception:
- # FIXME:
- logger.exception("Failed to query auth chain")
-
- # 4. Look at rejects and their proofs.
- # TODO.
+ return context
- yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
- )
-
- @defer.inlineCallbacks
- def _update_context_for_auth_events(self, event, context, auth_events,
- event_key):
+ async def _update_context_for_auth_events(
+ self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
+ ) -> EventContext:
"""Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
Args:
- event (Event): The event we're handling the context for
+ event: The event we're handling the context for
- context (synapse.events.snapshot.EventContext): event context
- to be updated
+ context: initial event context
- auth_events (dict[(str, str)->str]): Events to update in the event
- context.
+ auth_events: Events to update in the event context.
- event_key ((str, str)): (type, state_key) for the current event.
- this will not be included in the current_state in the context.
+ Returns:
+ new event context
"""
+ # exclude the state key of the new event from the current_state in the context.
+ if event.is_state():
+ event_key = (event.type, event.state_key)
+ else:
+ event_key = None
state_updates = {
- k: a.event_id for k, a in iteritems(auth_events)
- if k != event_key
+ k: a.event_id for k, a in iteritems(auth_events) if k != event_key
}
- current_state_ids = yield context.get_current_state_ids(self.store)
+
+ current_state_ids = await context.get_current_state_ids()
current_state_ids = dict(current_state_ids)
current_state_ids.update(state_updates)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = dict(prev_state_ids)
- prev_state_ids.update({
- k: a.event_id for k, a in iteritems(auth_events)
- })
+ prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
# create a new state group as a delta from the existing one.
prev_group = context.state_group
- state_group = yield self.store.store_state_group(
+ state_group = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=prev_group,
@@ -2412,16 +2455,18 @@ class FederationHandler(BaseHandler):
current_state_ids=current_state_ids,
)
- yield context.update_state(
+ return EventContext.with_state(
state_group=state_group,
+ state_group_before_event=context.state_group_before_event,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
prev_group=prev_group,
delta_ids=state_updates,
)
- @defer.inlineCallbacks
- def construct_auth_difference(self, local_auth, remote_auth):
+ async def construct_auth_difference(
+ self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
+ ) -> Dict:
""" Given a local and remote auth chain, find the differences. This
assumes that we have already processed all events in remote_auth
@@ -2530,7 +2575,7 @@ class FederationHandler(BaseHandler):
reason_map = {}
for e in base_remote_rejected:
- reason = yield self.store.get_rejection_reason(e.event_id)
+ reason = await self.store.get_rejection_reason(e.event_id)
if reason is None:
# TODO: e is not in the current state, so we should
# construct some proof of that.
@@ -2538,41 +2583,23 @@ class FederationHandler(BaseHandler):
reason_map[e.event_id] = reason
- if reason == RejectedReason.AUTH_ERROR:
- pass
- elif reason == RejectedReason.REPLACED:
- # TODO: Get proof
- pass
- elif reason == RejectedReason.NOT_ANCESTOR:
- # TODO: Get proof.
- pass
-
logger.debug("construct_auth_difference returning")
- defer.returnValue({
+ return {
"auth_chain": local_auth,
"rejects": {
- e.event_id: {
- "reason": reason_map[e.event_id],
- "proof": None,
- }
+ e.event_id: {"reason": reason_map[e.event_id], "proof": None}
for e in base_remote_rejected
},
"missing": [e.event_id for e in missing_locals],
- })
+ }
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(
- self,
- sender_user_id,
- target_user_id,
- room_id,
- signed,
+ self, sender_user_id, target_user_id, room_id, signed
):
- third_party_invite = {
- "signed": signed,
- }
+ third_party_invite = {"signed": signed}
event_dict = {
"type": EventTypes.Member,
@@ -2586,7 +2613,7 @@ class FederationHandler(BaseHandler):
}
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
- room_version = yield self.store.get_room_version(room_id)
+ room_version = yield self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
@@ -2595,7 +2622,7 @@ class FederationHandler(BaseHandler):
)
event_allowed = yield self.third_party_event_rules.check_event_allowed(
- event, context,
+ event, context
)
if not event_allowed:
logger.info(
@@ -2603,7 +2630,7 @@ class FederationHandler(BaseHandler):
event,
)
raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN,
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
)
event, context = yield self.add_display_name_to_third_party_invite(
@@ -2619,80 +2646,84 @@ class FederationHandler(BaseHandler):
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as e:
- logger.warn("Denying new third party invite %r because %s", event, e)
+ logger.warning("Denying new third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, context)
+
+ # We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
yield member_handler.send_membership_event(None, event, context)
else:
- destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
+ destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
yield self.federation_client.forward_third_party_invite(
- destinations,
- room_id,
- event_dict,
+ destinations, room_id, event_dict
)
- @defer.inlineCallbacks
- @log_function
- def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
+ async def on_exchange_third_party_invite_request(
+ self, room_id: str, event_dict: JsonDict
+ ) -> None:
"""Handle an exchange_third_party_invite request from a remote server
The remote server will call this when it wants to turn a 3pid invite
into a normal m.room.member invite.
- Returns:
- Deferred: resolves (to None)
+ Args:
+ room_id: The ID of the room.
+
+ event_dict (dict[str, Any]): Dictionary containing the event body.
+
"""
- room_version = yield self.store.get_room_version(room_id)
+ room_version = await self.store.get_room_version_id(room_id)
# NB: event_dict has a particular specced format we might need to fudge
# if we change event formats too much.
builder = self.event_builder_factory.new(room_version, event_dict)
- event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ event, context = await self.event_creation_handler.create_new_client_event(
+ builder=builder
)
- event_allowed = yield self.third_party_event_rules.check_event_allowed(
- event, context,
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
+ event, context
)
if not event_allowed:
logger.warning(
- "Exchange of threepid invite %s forbidden by third-party rules",
- event,
+ "Exchange of threepid invite %s forbidden by third-party rules", event
)
raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN,
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
)
- event, context = yield self.add_display_name_to_third_party_invite(
+ event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
try:
- yield self.auth.check_from_context(room_version, event, context)
+ await self.auth.check_from_context(room_version, event, context)
except AuthError as e:
- logger.warn("Denying third party invite %r because %s", event, e)
+ logger.warning("Denying third party invite %r because %s", event, e)
raise e
- yield self._check_signature(event, context)
+ await self._check_signature(event, context)
# We need to tell the transaction queue to send this out, even
# though the sender isn't a local user.
event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
+ # We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
- yield member_handler.send_membership_event(None, event, context)
+ await member_handler.send_membership_event(None, event, context)
@defer.inlineCallbacks
- def add_display_name_to_third_party_invite(self, room_version, event_dict,
- event, context):
+ def add_display_name_to_third_party_invite(
+ self, room_version, event_dict, event, context
+ ):
key = (
EventTypes.ThirdPartyInvite,
- event.content["third_party_invite"]["signed"]["token"]
+ event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
@@ -2708,8 +2739,7 @@ class FederationHandler(BaseHandler):
event_dict["content"]["third_party_invite"]["display_name"] = display_name
else:
logger.info(
- "Could not find invite event for third_party_invite: %r",
- event_dict
+ "Could not find invite event for third_party_invite: %r", event_dict
)
# We don't discard here as this is not the appropriate place to do
# auth checks. If we need the invite and don't have it then the
@@ -2718,10 +2748,10 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
event, context = yield self.event_creation_handler.create_new_client_event(
- builder=builder,
+ builder=builder
)
EventValidator().validate_new(event, self.config)
- defer.returnValue((event, context))
+ return (event, context)
@defer.inlineCallbacks
def _check_signature(self, event, context):
@@ -2741,10 +2771,8 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- prev_state_ids = yield context.get_prev_state_ids(self.store)
- invite_event_id = prev_state_ids.get(
- (EventTypes.ThirdPartyInvite, token,)
- )
+ prev_state_ids = yield context.get_prev_state_ids()
+ invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None
if invite_event_id:
@@ -2753,25 +2781,59 @@ class FederationHandler(BaseHandler):
if not invite_event:
raise AuthError(403, "Could not find invite")
+ logger.debug("Checking auth on event %r", event.content)
+
last_exception = None
+ # for each public key in the 3pid invite event
for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
try:
+ # for each sig on the third_party_invite block of the actual invite
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
- public_key = public_key_object["public_key"]
- verify_key = decode_verify_key_bytes(
+ logger.debug(
+ "Attempting to verify sig with key %s from %r "
+ "against pubkey %r",
key_name,
- decode_base64(public_key)
+ server,
+ public_key_object,
)
- verify_signed_json(signed, server, verify_key)
- if "key_validity_url" in public_key_object:
- yield self._check_key_revocation(
- public_key,
- public_key_object["key_validity_url"]
+
+ try:
+ public_key = public_key_object["public_key"]
+ verify_key = decode_verify_key_bytes(
+ key_name, decode_base64(public_key)
)
+ verify_signed_json(signed, server, verify_key)
+ logger.debug(
+ "Successfully verified sig with key %s from %r "
+ "against pubkey %r",
+ key_name,
+ server,
+ public_key_object,
+ )
+ except Exception:
+ logger.info(
+ "Failed to verify sig with key %s from %r "
+ "against pubkey %r",
+ key_name,
+ server,
+ public_key_object,
+ )
+ raise
+ try:
+ if "key_validity_url" in public_key_object:
+ yield self._check_key_revocation(
+ public_key, public_key_object["key_validity_url"]
+ )
+ except Exception:
+ logger.info(
+ "Failed to query key_validity_url %s",
+ public_key_object["key_validity_url"],
+ )
+ raise
return
except Exception as e:
last_exception = e
@@ -2792,54 +2854,54 @@ class FederationHandler(BaseHandler):
for revocation.
"""
try:
- response = yield self.http_client.get_json(
- url,
- {"public_key": public_key}
- )
+ response = yield self.http_client.get_json(url, {"public_key": public_key})
except Exception:
- raise SynapseError(
- 502,
- "Third party certificate could not be checked"
- )
+ raise SynapseError(502, "Third party certificate could not be checked")
if "valid" not in response or not response["valid"]:
raise AuthError(403, "Third party certificate was invalid")
- @defer.inlineCallbacks
- def persist_events_and_notify(self, event_and_contexts, backfilled=False):
+ async def persist_events_and_notify(
+ self,
+ event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
+ backfilled: bool = False,
+ ) -> None:
"""Persists events and tells the notifier/pushers about them, if
necessary.
Args:
- event_and_contexts(list[tuple[FrozenEvent, EventContext]])
- backfilled (bool): Whether these events are a result of
+ event_and_contexts:
+ backfilled: Whether these events are a result of
backfilling or not
-
- Returns:
- Deferred
"""
if self.config.worker_app:
- yield self._send_events_to_master(
+ await self._send_events_to_master(
store=self.store,
event_and_contexts=event_and_contexts,
- backfilled=backfilled
+ backfilled=backfilled,
)
else:
- max_stream_id = yield self.store.persist_events(
- event_and_contexts,
- backfilled=backfilled,
+ max_stream_id = await self.storage.persistence.persist_events(
+ event_and_contexts, backfilled=backfilled
)
+ if self._ephemeral_messages_enabled:
+ for (event, context) in event_and_contexts:
+ # If there's an expiry timestamp on the event, schedule its expiry.
+ self._message_handler.maybe_schedule_expiry(event)
+
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
- yield self._notify_persisted_event(event, max_stream_id)
+ await self._notify_persisted_event(event, max_stream_id)
- def _notify_persisted_event(self, event, max_stream_id):
+ async def _notify_persisted_event(
+ self, event: EventBase, max_stream_id: int
+ ) -> None:
"""Checks to see if notifier/pushers should be notified about the
event or not.
Args:
- event (FrozenEvent)
- max_stream_id (int): The max_stream_id returned by persist_events
+ event:
+ max_stream_id: The max_stream_id returned by persist_events
"""
extra_users = []
@@ -2860,34 +2922,54 @@ class FederationHandler(BaseHandler):
event_stream_id = event.internal_metadata.stream_ordering
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
+ event, event_stream_id, max_stream_id, extra_users=extra_users
)
- return self.pusher_pool.on_new_notifications(
- event_stream_id, max_stream_id,
- )
+ await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
- def _clean_room_for_join(self, room_id):
+ async def _clean_room_for_join(self, room_id: str) -> None:
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
Args:
- room_id (str)
+ room_id
"""
if self.config.worker_app:
- return self._clean_room_for_join_client(room_id)
+ await self._clean_room_for_join_client(room_id)
else:
- return self.store.clean_room_for_join(room_id)
+ await self.store.clean_room_for_join(room_id)
- def user_joined_room(self, user, room_id):
+ async def user_joined_room(self, user: UserID, room_id: str) -> None:
"""Called when a new user has joined the room
"""
if self.config.worker_app:
- return self._notify_user_membership_change(
- room_id=room_id,
- user_id=user.to_string(),
- change="joined",
+ await self._notify_user_membership_change(
+ room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
- return user_joined_room(self.distributor, user, room_id)
+ user_joined_room(self.distributor, user, room_id)
+
+ @defer.inlineCallbacks
+ def get_room_complexity(self, remote_room_hosts, room_id):
+ """
+ Fetch the complexity of a remote room over federation.
+
+ Args:
+ remote_room_hosts (list[str]): The remote servers to ask.
+ room_id (str): The room ID to ask about.
+
+ Returns:
+ Deferred[dict] or Deferred[None]: Dict contains the complexity
+ metric versions, while None means we could not fetch the complexity.
+ """
+
+ for host in remote_room_hosts:
+ res = yield self.federation_client.get_room_complexity(host, room_id)
+
+ # We got a result, return it.
+ if res:
+ defer.returnValue(res)
+
+ # We fell off the bottom, couldn't get the complexity from anyone. Oh
+ # well.
+ defer.returnValue(None)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 02c508acec..ad22415782 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -30,6 +30,7 @@ def _create_rerouter(func_name):
"""Returns a function that looks at the group id and calls the function
on federation or the local group server if the group is local
"""
+
def f(self, group_id, *args, **kwargs):
if self.is_mine_id(group_id):
return getattr(self.groups_server_handler, func_name)(
@@ -49,9 +50,7 @@ def _create_rerouter(func_name):
def http_response_errback(failure):
failure.trap(HttpResponseException)
e = failure.value
- if e.code == 403:
- raise e.to_synapse_error()
- return failure
+ raise e.to_synapse_error()
def request_failed_errback(failure):
failure.trap(RequestSendFailed)
@@ -60,10 +59,11 @@ def _create_rerouter(func_name):
d.addErrback(http_response_errback)
d.addErrback(request_failed_errback)
return d
+
return f
-class GroupsLocalHandler(object):
+class GroupsLocalWorkerHandler(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
@@ -81,40 +81,17 @@ class GroupsLocalHandler(object):
self.profile_handler = hs.get_profile_handler()
- # Ensure attestations get renewed
- hs.get_groups_attestation_renewer()
-
# The following functions merely route the query to the local groups server
# or federation depending on if the group is local or remote
get_group_profile = _create_rerouter("get_group_profile")
- update_group_profile = _create_rerouter("update_group_profile")
get_rooms_in_group = _create_rerouter("get_rooms_in_group")
-
get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
-
- add_room_to_group = _create_rerouter("add_room_to_group")
- update_room_in_group = _create_rerouter("update_room_in_group")
- remove_room_from_group = _create_rerouter("remove_room_from_group")
-
- update_group_summary_room = _create_rerouter("update_group_summary_room")
- delete_group_summary_room = _create_rerouter("delete_group_summary_room")
-
- update_group_category = _create_rerouter("update_group_category")
- delete_group_category = _create_rerouter("delete_group_category")
get_group_category = _create_rerouter("get_group_category")
get_group_categories = _create_rerouter("get_group_categories")
-
- update_group_summary_user = _create_rerouter("update_group_summary_user")
- delete_group_summary_user = _create_rerouter("delete_group_summary_user")
-
- update_group_role = _create_rerouter("update_group_role")
- delete_group_role = _create_rerouter("delete_group_role")
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
- set_group_join_policy = _create_rerouter("set_group_join_policy")
-
@defer.inlineCallbacks
def get_group_summary(self, group_id, requester_user_id):
"""Get the group summary for a group.
@@ -126,9 +103,14 @@ class GroupsLocalHandler(object):
group_id, requester_user_id
)
else:
- res = yield self.transport_client.get_group_summary(
- get_domain_from_id(group_id), group_id, requester_user_id,
- )
+ try:
+ res = yield self.transport_client.get_group_summary(
+ get_domain_from_id(group_id), group_id, requester_user_id
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+ except RequestSendFailed:
+ raise SynapseError(502, "Failed to contact group server")
group_server_name = get_domain_from_id(group_id)
@@ -162,7 +144,145 @@ class GroupsLocalHandler(object):
res.setdefault("user", {})["is_publicised"] = is_publicised
- defer.returnValue(res)
+ return res
+
+ @defer.inlineCallbacks
+ def get_users_in_group(self, group_id, requester_user_id):
+ """Get users in a group
+ """
+ if self.is_mine_id(group_id):
+ res = yield self.groups_server_handler.get_users_in_group(
+ group_id, requester_user_id
+ )
+ return res
+
+ group_server_name = get_domain_from_id(group_id)
+
+ try:
+ res = yield self.transport_client.get_users_in_group(
+ get_domain_from_id(group_id), group_id, requester_user_id
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+ except RequestSendFailed:
+ raise SynapseError(502, "Failed to contact group server")
+
+ chunk = res["chunk"]
+ valid_entries = []
+ for entry in chunk:
+ g_user_id = entry["user_id"]
+ attestation = entry.pop("attestation", {})
+ try:
+ if get_domain_from_id(g_user_id) != group_server_name:
+ yield self.attestations.verify_attestation(
+ attestation,
+ group_id=group_id,
+ user_id=g_user_id,
+ server_name=get_domain_from_id(g_user_id),
+ )
+ valid_entries.append(entry)
+ except Exception as e:
+ logger.info("Failed to verify user is in group: %s", e)
+
+ res["chunk"] = valid_entries
+
+ return res
+
+ @defer.inlineCallbacks
+ def get_joined_groups(self, user_id):
+ group_ids = yield self.store.get_joined_groups(user_id)
+ return {"groups": group_ids}
+
+ @defer.inlineCallbacks
+ def get_publicised_groups_for_user(self, user_id):
+ if self.hs.is_mine_id(user_id):
+ result = yield self.store.get_publicised_groups_for_user(user_id)
+
+ # Check AS associated groups for this user - this depends on the
+ # RegExps in the AS registration file (under `users`)
+ for app_service in self.store.get_app_services():
+ result.extend(app_service.get_groups_for_user(user_id))
+
+ return {"groups": result}
+ else:
+ try:
+ bulk_result = yield self.transport_client.bulk_get_publicised_groups(
+ get_domain_from_id(user_id), [user_id]
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+ except RequestSendFailed:
+ raise SynapseError(502, "Failed to contact group server")
+
+ result = bulk_result.get("users", {}).get(user_id)
+ # TODO: Verify attestations
+ return {"groups": result}
+
+ @defer.inlineCallbacks
+ def bulk_get_publicised_groups(self, user_ids, proxy=True):
+ destinations = {}
+ local_users = set()
+
+ for user_id in user_ids:
+ if self.hs.is_mine_id(user_id):
+ local_users.add(user_id)
+ else:
+ destinations.setdefault(get_domain_from_id(user_id), set()).add(user_id)
+
+ if not proxy and destinations:
+ raise SynapseError(400, "Some user_ids are not local")
+
+ results = {}
+ failed_results = []
+ for destination, dest_user_ids in iteritems(destinations):
+ try:
+ r = yield self.transport_client.bulk_get_publicised_groups(
+ destination, list(dest_user_ids)
+ )
+ results.update(r["users"])
+ except Exception:
+ failed_results.extend(dest_user_ids)
+
+ for uid in local_users:
+ results[uid] = yield self.store.get_publicised_groups_for_user(uid)
+
+ # Check AS associated groups for this user - this depends on the
+ # RegExps in the AS registration file (under `users`)
+ for app_service in self.store.get_app_services():
+ results[uid].extend(app_service.get_groups_for_user(uid))
+
+ return {"users": results}
+
+
+class GroupsLocalHandler(GroupsLocalWorkerHandler):
+ def __init__(self, hs):
+ super(GroupsLocalHandler, self).__init__(hs)
+
+ # Ensure attestations get renewed
+ hs.get_groups_attestation_renewer()
+
+ # The following functions merely route the query to the local groups server
+ # or federation depending on if the group is local or remote
+
+ update_group_profile = _create_rerouter("update_group_profile")
+
+ add_room_to_group = _create_rerouter("add_room_to_group")
+ update_room_in_group = _create_rerouter("update_room_in_group")
+ remove_room_from_group = _create_rerouter("remove_room_from_group")
+
+ update_group_summary_room = _create_rerouter("update_group_summary_room")
+ delete_group_summary_room = _create_rerouter("delete_group_summary_room")
+
+ update_group_category = _create_rerouter("update_group_category")
+ delete_group_category = _create_rerouter("delete_group_category")
+
+ update_group_summary_user = _create_rerouter("update_group_summary_user")
+ delete_group_summary_user = _create_rerouter("delete_group_summary_user")
+
+ update_group_role = _create_rerouter("update_group_role")
+ delete_group_role = _create_rerouter("delete_group_role")
+
+ set_group_join_policy = _create_rerouter("set_group_join_policy")
@defer.inlineCallbacks
def create_group(self, group_id, user_id, content):
@@ -183,9 +303,14 @@ class GroupsLocalHandler(object):
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
- res = yield self.transport_client.create_group(
- get_domain_from_id(group_id), group_id, user_id, content,
- )
+ try:
+ res = yield self.transport_client.create_group(
+ get_domain_from_id(group_id), group_id, user_id, content
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+ except RequestSendFailed:
+ raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
yield self.attestations.verify_attestation(
@@ -197,73 +322,38 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="join",
is_admin=True,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
- defer.returnValue(res)
-
- @defer.inlineCallbacks
- def get_users_in_group(self, group_id, requester_user_id):
- """Get users in a group
- """
- if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.get_users_in_group(
- group_id, requester_user_id
- )
- defer.returnValue(res)
-
- group_server_name = get_domain_from_id(group_id)
-
- res = yield self.transport_client.get_users_in_group(
- get_domain_from_id(group_id), group_id, requester_user_id,
- )
-
- chunk = res["chunk"]
- valid_entries = []
- for entry in chunk:
- g_user_id = entry["user_id"]
- attestation = entry.pop("attestation", {})
- try:
- if get_domain_from_id(g_user_id) != group_server_name:
- yield self.attestations.verify_attestation(
- attestation,
- group_id=group_id,
- user_id=g_user_id,
- server_name=get_domain_from_id(g_user_id),
- )
- valid_entries.append(entry)
- except Exception as e:
- logger.info("Failed to verify user is in group: %s", e)
-
- res["chunk"] = valid_entries
-
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def join_group(self, group_id, user_id, content):
"""Request to join a group
"""
if self.is_mine_id(group_id):
- yield self.groups_server_handler.join_group(
- group_id, user_id, content
- )
+ yield self.groups_server_handler.join_group(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
- res = yield self.transport_client.join_group(
- get_domain_from_id(group_id), group_id, user_id, content,
- )
+ try:
+ res = yield self.transport_client.join_group(
+ get_domain_from_id(group_id), group_id, user_id, content
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+ except RequestSendFailed:
+ raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
@@ -278,36 +368,38 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def accept_invite(self, group_id, user_id, content):
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
- yield self.groups_server_handler.accept_invite(
- group_id, user_id, content
- )
+ yield self.groups_server_handler.accept_invite(group_id, user_id, content)
local_attestation = None
remote_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
- res = yield self.transport_client.accept_group_invite(
- get_domain_from_id(group_id), group_id, user_id, content,
- )
+ try:
+ res = yield self.transport_client.accept_group_invite(
+ get_domain_from_id(group_id), group_id, user_id, content
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+ except RequestSendFailed:
+ raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
@@ -322,38 +414,42 @@ class GroupsLocalHandler(object):
is_publicised = content.get("publicise", False)
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="join",
is_admin=False,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
is_publicised=is_publicised,
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
- defer.returnValue({})
+ return {}
@defer.inlineCallbacks
def invite(self, group_id, user_id, requester_user_id, config):
"""Invite a user to a group
"""
- content = {
- "requester_user_id": requester_user_id,
- "config": config,
- }
+ content = {"requester_user_id": requester_user_id, "config": config}
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.invite_to_group(
- group_id, user_id, requester_user_id, content,
+ group_id, user_id, requester_user_id, content
)
else:
- res = yield self.transport_client.invite_to_group(
- get_domain_from_id(group_id), group_id, user_id, requester_user_id,
- content,
- )
+ try:
+ res = yield self.transport_client.invite_to_group(
+ get_domain_from_id(group_id),
+ group_id,
+ user_id,
+ requester_user_id,
+ content,
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+ except RequestSendFailed:
+ raise SynapseError(502, "Failed to contact group server")
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def on_invite(self, group_id, user_id, content):
@@ -372,20 +468,19 @@ class GroupsLocalHandler(object):
local_profile["avatar_url"] = content["profile"]["avatar_url"]
token = yield self.store.register_user_group_membership(
- group_id, user_id,
+ group_id,
+ user_id,
membership="invite",
content={"profile": local_profile, "inviter": content["inviter"]},
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
try:
user_profile = yield self.profile_handler.get_profile(user_id)
except Exception as e:
- logger.warn("No profile for user %s: %s", user_id, e)
+ logger.warning("No profile for user %s: %s", user_id, e)
user_profile = {}
- defer.returnValue({"state": "invite", "user_profile": user_profile})
+ return {"state": "invite", "user_profile": user_profile}
@defer.inlineCallbacks
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
@@ -393,28 +488,33 @@ class GroupsLocalHandler(object):
"""
if user_id == requester_user_id:
token = yield self.store.register_user_group_membership(
- group_id, user_id,
- membership="leave",
- )
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
+ group_id, user_id, membership="leave"
)
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
# TODO: Should probably remember that we tried to leave so that we can
# retry if the group server is currently down.
if self.is_mine_id(group_id):
res = yield self.groups_server_handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content,
+ group_id, user_id, requester_user_id, content
)
else:
content["requester_user_id"] = requester_user_id
- res = yield self.transport_client.remove_user_from_group(
- get_domain_from_id(group_id), group_id, requester_user_id,
- user_id, content,
- )
+ try:
+ res = yield self.transport_client.remove_user_from_group(
+ get_domain_from_id(group_id),
+ group_id,
+ requester_user_id,
+ user_id,
+ content,
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
+ except RequestSendFailed:
+ raise SynapseError(502, "Failed to contact group server")
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def user_removed_from_group(self, group_id, user_id, content):
@@ -422,72 +522,6 @@ class GroupsLocalHandler(object):
"""
# TODO: Check if user in group
token = yield self.store.register_user_group_membership(
- group_id, user_id,
- membership="leave",
+ group_id, user_id, membership="leave"
)
- self.notifier.on_new_event(
- "groups_key", token, users=[user_id],
- )
-
- @defer.inlineCallbacks
- def get_joined_groups(self, user_id):
- group_ids = yield self.store.get_joined_groups(user_id)
- defer.returnValue({"groups": group_ids})
-
- @defer.inlineCallbacks
- def get_publicised_groups_for_user(self, user_id):
- if self.hs.is_mine_id(user_id):
- result = yield self.store.get_publicised_groups_for_user(user_id)
-
- # Check AS associated groups for this user - this depends on the
- # RegExps in the AS registration file (under `users`)
- for app_service in self.store.get_app_services():
- result.extend(app_service.get_groups_for_user(user_id))
-
- defer.returnValue({"groups": result})
- else:
- bulk_result = yield self.transport_client.bulk_get_publicised_groups(
- get_domain_from_id(user_id), [user_id],
- )
- result = bulk_result.get("users", {}).get(user_id)
- # TODO: Verify attestations
- defer.returnValue({"groups": result})
-
- @defer.inlineCallbacks
- def bulk_get_publicised_groups(self, user_ids, proxy=True):
- destinations = {}
- local_users = set()
-
- for user_id in user_ids:
- if self.hs.is_mine_id(user_id):
- local_users.add(user_id)
- else:
- destinations.setdefault(
- get_domain_from_id(user_id), set()
- ).add(user_id)
-
- if not proxy and destinations:
- raise SynapseError(400, "Some user_ids are not local")
-
- results = {}
- failed_results = []
- for destination, dest_user_ids in iteritems(destinations):
- try:
- r = yield self.transport_client.bulk_get_publicised_groups(
- destination, list(dest_user_ids),
- )
- results.update(r["users"])
- except Exception:
- failed_results.extend(dest_user_ids)
-
- for uid in local_users:
- results[uid] = yield self.store.get_publicised_groups_for_user(
- uid
- )
-
- # Check AS associated groups for this user - this depends on the
- # RegExps in the AS registration file (under `users`)
- for app_service in self.store.get_app_services():
- results[uid].extend(app_service.get_groups_for_user(uid))
-
- defer.returnValue({"users": results})
+ self.notifier.on_new_event("groups_key", token, users=[user_id])
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index dfc03f51e7..1d07361661 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -18,6 +18,7 @@
"""Utilities for interacting with Identity Servers"""
import logging
+import urllib
from canonicaljson import json
from signedjson.key import decode_verify_key_bytes
@@ -25,6 +26,7 @@ from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
from twisted.internet import defer
+from twisted.internet.error import TimeoutError
from synapse.api.errors import (
AuthError,
@@ -34,6 +36,10 @@ from synapse.api.errors import (
ProxiedRequestError,
SynapseError,
)
+from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.http.client import SimpleHttpClient
+from synapse.util.hash import sha256_and_url_safe_base64
+from synapse.util.stringutils import assert_valid_client_secret, random_string
from ._base import BaseHandler
@@ -41,11 +47,16 @@ logger = logging.getLogger(__name__)
class IdentityHandler(BaseHandler):
-
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
+ self.hs = hs
self.http_client = hs.get_simple_http_client()
+ # We create a blacklisting instance of SimpleHttpClient for contacting identity
+ # servers specified by clients
+ self.blacklisting_http_client = SimpleHttpClient(
+ hs, ip_blacklist=hs.config.federation_ip_range_blacklist
+ )
self.federation_http_client = hs.get_http_client()
self.trusted_id_servers = set(hs.config.trusted_third_party_id_servers)
@@ -55,101 +66,122 @@ class IdentityHandler(BaseHandler):
self.rewrite_identity_server_urls = hs.config.rewrite_identity_server_urls
self._enable_lookup = hs.config.enable_3pid_lookup
- def _should_trust_id_server(self, id_server):
- if id_server not in self.trusted_id_servers:
- if self.trust_any_id_server_just_for_testing_do_not_use:
- logger.warn(
- "Trusting untrustworthy ID server %r even though it isn't"
- " in the trusted id list for testing because"
- " 'use_insecure_ssl_client_just_for_testing_do_not_use'"
- " is set in the config",
- id_server,
- )
- else:
- return False
- return True
-
@defer.inlineCallbacks
- def threepid_from_creds(self, creds):
- if 'id_server' in creds:
- id_server = creds['id_server']
- elif 'idServer' in creds:
- id_server = creds['idServer']
- else:
- raise SynapseError(400, "No id_server in creds")
+ def threepid_from_creds(self, id_server_url, creds):
+ """
+ Retrieve and validate a threepid identifier from a "credentials" dictionary against a
+ given identity server
- if 'client_secret' in creds:
- client_secret = creds['client_secret']
- elif 'clientSecret' in creds:
- client_secret = creds['clientSecret']
- else:
- raise SynapseError(400, "No client_secret in creds")
+ Args:
+ id_server_url (str): The identity server to validate 3PIDs against. Must be a
+ complete URL including the protocol (http(s)://)
+
+ creds (dict[str, str]): Dictionary containing the following keys:
+ * client_secret|clientSecret: A unique secret str provided by the client
+ * sid: The ID of the validation session
+
+ Returns:
+ Deferred[dict[str,str|int]|None]: A dictionary consisting of response params to
+ the /getValidated3pid endpoint of the Identity Service API, or None if the
+ threepid was not found
+ """
+ client_secret = creds.get("client_secret") or creds.get("clientSecret")
+ if not client_secret:
+ raise SynapseError(
+ 400, "Missing param client_secret in creds", errcode=Codes.MISSING_PARAM
+ )
+ assert_valid_client_secret(client_secret)
- if not self._should_trust_id_server(id_server):
- logger.warn(
- '%s is not a trusted ID server: rejecting 3pid ' +
- 'credentials', id_server
+ session_id = creds.get("sid")
+ if not session_id:
+ raise SynapseError(
+ 400, "Missing param session_id in creds", errcode=Codes.MISSING_PARAM
)
- defer.returnValue(None)
+
+ query_params = {"sid": session_id, "client_secret": client_secret}
+
# if we have a rewrite rule set for the identity server,
# apply it now.
- if id_server in self.rewrite_identity_server_urls:
- id_server = self.rewrite_identity_server_urls[id_server]
+ id_server_url = self.rewrite_id_server_url(id_server_url)
+
+ url = "%s%s" % (
+ id_server_url,
+ "/_matrix/identity/api/v1/3pid/getValidated3pid",
+ )
+
try:
- data = yield self.http_client.get_json(
- "https://%s%s" % (
- id_server,
- "/_matrix/identity/api/v1/3pid/getValidated3pid"
- ),
- {'sid': creds['sid'], 'client_secret': client_secret}
- )
+ data = yield self.http_client.get_json(url, query_params)
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
except HttpResponseException as e:
- logger.info("getValidated3pid failed with Matrix error: %r", e)
- raise e.to_synapse_error()
+ logger.info(
+ "%s returned %i for threepid validation for: %s",
+ id_server_url,
+ e.code,
+ creds,
+ )
+ return None
- if 'medium' in data:
- defer.returnValue(data)
- defer.returnValue(None)
+ # Old versions of Sydent return a 200 http code even on a failed validation
+ # check. Thus, in addition to the HttpResponseException check above (which
+ # checks for non-200 errors), we need to make sure validation_session isn't
+ # actually an error, identified by the absence of a "medium" key
+ # See https://github.com/matrix-org/sydent/issues/215 for details
+ if "medium" in data:
+ return data
+
+ logger.info("%s reported non-validated threepid: %s", id_server_url, creds)
+ return None
@defer.inlineCallbacks
- def bind_threepid(self, creds, mxid):
- logger.debug("binding threepid %r to %s", creds, mxid)
- data = None
+ def bind_threepid(
+ self, client_secret, sid, mxid, id_server, id_access_token=None, use_v2=True
+ ):
+ """Bind a 3PID to an identity server
- if 'id_server' in creds:
- id_server = creds['id_server']
- elif 'idServer' in creds:
- id_server = creds['idServer']
- else:
- raise SynapseError(400, "No id_server in creds")
+ Args:
+ client_secret (str): A unique secret provided by the client
- if 'client_secret' in creds:
- client_secret = creds['client_secret']
- elif 'clientSecret' in creds:
- client_secret = creds['clientSecret']
- else:
- raise SynapseError(400, "No client_secret in creds")
+ sid (str): The ID of the validation session
+
+ mxid (str): The MXID to bind the 3PID to
+
+ id_server (str): The domain of the identity server to query
+
+ id_access_token (str): The access token to authenticate to the identity
+ server with, if necessary. Required if use_v2 is true
+
+ use_v2 (bool): Whether to use v2 Identity Service API endpoints. Defaults to True
+
+ Returns:
+ Deferred[dict]: The response from the identity server
+ """
+ logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
+
+ # If an id_access_token is not supplied, force usage of v1
+ if id_access_token is None:
+ use_v2 = False
# if we have a rewrite rule set for the identity server,
# apply it now, but only for sending the request (not
# storing in the database).
- if id_server in self.rewrite_identity_server_urls:
- id_server_host = self.rewrite_identity_server_urls[id_server]
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ # Decide which API endpoint URLs to use
+ headers = {}
+ bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
+ if use_v2:
+ bind_url = "%s/_matrix/identity/v2/3pid/bind" % (id_server_url,)
+ headers["Authorization"] = create_id_access_token_header(id_access_token)
else:
- id_server_host = id_server
+ bind_url = "%s/_matrix/identity/api/v1/3pid/bind" % (id_server_url,)
try:
- data = yield self.http_client.post_urlencoded_get_json(
- "https://%s%s" % (
- id_server_host, "/_matrix/identity/api/v1/3pid/bind"
- ),
- {
- 'sid': creds['sid'],
- 'client_secret': client_secret,
- 'mxid': mxid,
- }
+ # Use the blacklisting http client as this call is only to identity servers
+ # provided by a client
+ data = yield self.blacklisting_http_client.post_json_get_json(
+ bind_url, bind_data, headers=headers
)
- logger.debug("bound threepid %r to %s", creds, mxid)
# Remember where we bound the threepid
yield self.store.add_user_bound_threepid(
@@ -158,13 +190,28 @@ class IdentityHandler(BaseHandler):
address=data["address"],
id_server=id_server,
)
+
+ return data
+ except HttpResponseException as e:
+ if e.code != 404 or not use_v2:
+ logger.error("3PID bind failed with Matrix error: %r", e)
+ raise e.to_synapse_error()
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
except CodeMessageException as e:
data = json.loads(e.msg) # XXX WAT?
- defer.returnValue(data)
+ return data
+
+ logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
+ res = yield self.bind_threepid(
+ client_secret, sid, mxid, id_server, id_access_token, use_v2=False
+ )
+ return res
@defer.inlineCallbacks
def try_unbind_threepid(self, mxid, threepid):
- """Removes a binding from an identity server
+ """Attempt to remove a 3PID from an identity server, or if one is not provided, all
+ identity servers we're aware the binding is present on
Args:
mxid (str): Matrix user ID of binding to be removed
@@ -183,22 +230,20 @@ class IdentityHandler(BaseHandler):
id_servers = [threepid["id_server"]]
else:
id_servers = yield self.store.get_id_servers_user_bound(
- user_id=mxid,
- medium=threepid["medium"],
- address=threepid["address"],
+ user_id=mxid, medium=threepid["medium"], address=threepid["address"]
)
# We don't know where to unbind, so we don't have a choice but to return
if not id_servers:
- defer.returnValue(False)
+ return False
changed = True
for id_server in id_servers:
changed &= yield self.try_unbind_threepid_with_id_server(
- mxid, threepid, id_server,
+ mxid, threepid, id_server
)
- defer.returnValue(changed)
+ return changed
@defer.inlineCallbacks
def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
@@ -216,54 +261,50 @@ class IdentityHandler(BaseHandler):
Deferred[bool]: True on success, otherwise False if the identity
server doesn't support unbinding
"""
- url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
content = {
"mxid": mxid,
- "threepid": {
- "medium": threepid["medium"],
- "address": threepid["address"],
- },
+ "threepid": {"medium": threepid["medium"], "address": threepid["address"]},
}
# we abuse the federation http client to sign the request, but we have to send it
# using the normal http client since we don't want the SRV lookup and want normal
# 'browser-like' HTTPS.
+ url_bytes = "/_matrix/identity/api/v1/3pid/unbind".encode("ascii")
auth_headers = self.federation_http_client.build_auth_headers(
destination=None,
- method='POST',
- url_bytes='/_matrix/identity/api/v1/3pid/unbind'.encode('ascii'),
+ method="POST",
+ url_bytes=url_bytes,
content=content,
destination_is=id_server,
)
- headers = {
- b"Authorization": auth_headers,
- }
+ headers = {b"Authorization": auth_headers}
# if we have a rewrite rule set for the identity server,
# apply it now.
#
# Note that destination_is has to be the real id_server, not
# the server we connect to.
- if id_server in self.rewrite_identity_server_urls:
- id_server = self.rewrite_identity_server_urls[id_server]
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
- url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
+ url = "%s/_matrix/identity/api/v1/3pid/unbind" % (id_server_url,)
try:
- yield self.http_client.post_json_get_json(
- url,
- content,
- headers,
+ # Use the blacklisting http client as this call is only to identity servers
+ # provided by a client
+ yield self.blacklisting_http_client.post_json_get_json(
+ url, content, headers
)
changed = True
except HttpResponseException as e:
changed = False
- if e.code in (400, 404, 501,):
+ if e.code in (400, 404, 501):
# The remote server probably doesn't support unbinding (yet)
- logger.warn("Received %d response while unbinding threepid", e.code)
+ logger.warning("Received %d response while unbinding threepid", e.code)
else:
logger.error("Failed to unbind threepid on identity server: %s", e)
- raise SynapseError(502, "Failed to contact identity server")
+ raise SynapseError(500, "Failed to contact identity server")
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
yield self.store.remove_user_bound_threepid(
user_id=mxid,
@@ -272,87 +313,319 @@ class IdentityHandler(BaseHandler):
id_server=id_server,
)
- defer.returnValue(changed)
+ return changed
@defer.inlineCallbacks
- def requestEmailToken(
+ def send_threepid_validation(
self,
- id_server,
- email,
+ email_address,
client_secret,
send_attempt,
+ send_email_func,
next_link=None,
):
- if not self._should_trust_id_server(id_server):
- raise SynapseError(
- 400, "Untrusted ID server '%s'" % id_server,
- Codes.SERVER_NOT_TRUSTED
+ """Send a threepid validation email for password reset or
+ registration purposes
+
+ Args:
+ email_address (str): The user's email address
+ client_secret (str): The provided client secret
+ send_attempt (int): Which send attempt this is
+ send_email_func (func): A function that takes an email address, token,
+ client_secret and session_id, sends an email
+ and returns a Deferred.
+ next_link (str|None): The URL to redirect the user to after validation
+
+ Returns:
+ The new session_id upon success
+
+ Raises:
+ SynapseError is an error occurred when sending the email
+ """
+ # Check that this email/client_secret/send_attempt combo is new or
+ # greater than what we've seen previously
+ session = yield self.store.get_threepid_validation_session(
+ "email", client_secret, address=email_address, validated=False
+ )
+
+ # Check to see if a session already exists and that it is not yet
+ # marked as validated
+ if session and session.get("validated_at") is None:
+ session_id = session["session_id"]
+ last_send_attempt = session["last_send_attempt"]
+
+ # Check that the send_attempt is higher than previous attempts
+ if send_attempt <= last_send_attempt:
+ # If not, just return a success without sending an email
+ return session_id
+ else:
+ # An non-validated session does not exist yet.
+ # Generate a session id
+ session_id = random_string(16)
+
+ if next_link:
+ # Manipulate the next_link to add the sid, because the caller won't get
+ # it until we send a response, by which time we've sent the mail.
+ if "?" in next_link:
+ next_link += "&"
+ else:
+ next_link += "?"
+ next_link += "sid=" + urllib.parse.quote(session_id)
+
+ # Generate a new validation token
+ token = random_string(32)
+
+ # Send the mail with the link containing the token, client_secret
+ # and session_id
+ try:
+ yield send_email_func(email_address, token, client_secret, session_id)
+ except Exception:
+ logger.exception(
+ "Error sending threepid validation email to %s", email_address
)
+ raise SynapseError(500, "An error was encountered when sending the email")
+
+ token_expires = (
+ self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
+ )
+
+ yield self.store.start_or_continue_validation_session(
+ "email",
+ email_address,
+ session_id,
+ client_secret,
+ send_attempt,
+ next_link,
+ token,
+ token_expires,
+ )
+
+ return session_id
+
+ def rewrite_id_server_url(self, url: str, add_https=False) -> str:
+ """Given an identity server URL, optionally add a protocol scheme
+ before rewriting it according to the rewrite_identity_server_urls
+ config option
+
+ Adds https:// to the URL if specified, then tries to rewrite the
+ url. Returns either the rewritten URL or the URL with optional
+ protocol scheme additions.
+ """
+ rewritten_url = url
+ if add_https:
+ rewritten_url = "https://" + rewritten_url
+
+ rewritten_url = self.rewrite_identity_server_urls.get(
+ rewritten_url, rewritten_url
+ )
+ logger.debug("Rewriting identity server rule from %s to %s", url, rewritten_url)
+ return rewritten_url
+ @defer.inlineCallbacks
+ def requestEmailToken(
+ self, id_server_url, email, client_secret, send_attempt, next_link=None
+ ):
+ """
+ Request an external server send an email on our behalf for the purposes of threepid
+ validation.
+
+ Args:
+ id_server_url (str): The identity server to proxy to
+ email (str): The email to send the message to
+ client_secret (str): The unique client_secret sends by the user
+ send_attempt (int): Which attempt this is
+ next_link: A link to redirect the user to once they submit the token
+
+ Returns:
+ The json response body from the server
+ """
params = {
- 'email': email,
- 'client_secret': client_secret,
- 'send_attempt': send_attempt,
+ "email": email,
+ "client_secret": client_secret,
+ "send_attempt": send_attempt,
}
# if we have a rewrite rule set for the identity server,
# apply it now.
- if id_server in self.rewrite_identity_server_urls:
- id_server = self.rewrite_identity_server_urls[id_server]
+ id_server_url = self.rewrite_id_server_url(id_server_url)
if next_link:
- params.update({'next_link': next_link})
+ params["next_link"] = next_link
+
+ if self.hs.config.using_identity_server_from_trusted_list:
+ # Warn that a deprecated config option is in use
+ logger.warning(
+ 'The config option "trust_identity_server_for_password_resets" '
+ 'has been replaced by "account_threepid_delegate". '
+ "Please consult the sample config at docs/sample_config.yaml for "
+ "details and update your config file."
+ )
try:
data = yield self.http_client.post_json_get_json(
- "https://%s%s" % (
- id_server,
- "/_matrix/identity/api/v1/validate/email/requestToken"
- ),
- params
+ "%s/_matrix/identity/api/v1/validate/email/requestToken"
+ % (id_server_url,),
+ params,
)
- defer.returnValue(data)
+ return data
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
@defer.inlineCallbacks
def requestMsisdnToken(
- self, id_server, country, phone_number,
- client_secret, send_attempt, **kwargs
+ self,
+ id_server_url,
+ country,
+ phone_number,
+ client_secret,
+ send_attempt,
+ next_link=None,
):
- if not self._should_trust_id_server(id_server):
- raise SynapseError(
- 400, "Untrusted ID server '%s'" % id_server,
- Codes.SERVER_NOT_TRUSTED
- )
+ """
+ Request an external server send an SMS message on our behalf for the purposes of
+ threepid validation.
+ Args:
+ id_server_url (str): The identity server to proxy to
+ country (str): The country code of the phone number
+ phone_number (str): The number to send the message to
+ client_secret (str): The unique client_secret sends by the user
+ send_attempt (int): Which attempt this is
+ next_link: A link to redirect the user to once they submit the token
+ Returns:
+ The json response body from the server
+ """
params = {
- 'country': country,
- 'phone_number': phone_number,
- 'client_secret': client_secret,
- 'send_attempt': send_attempt,
+ "country": country,
+ "phone_number": phone_number,
+ "client_secret": client_secret,
+ "send_attempt": send_attempt,
}
- params.update(kwargs)
+ if next_link:
+ params["next_link"] = next_link
+
+ if self.hs.config.using_identity_server_from_trusted_list:
+ # Warn that a deprecated config option is in use
+ logger.warning(
+ 'The config option "trust_identity_server_for_password_resets" '
+ 'has been replaced by "account_threepid_delegate". '
+ "Please consult the sample config at docs/sample_config.yaml for "
+ "details and update your config file."
+ )
+
# if we have a rewrite rule set for the identity server,
# apply it now.
- if id_server in self.rewrite_identity_server_urls:
- id_server = self.rewrite_identity_server_urls[id_server]
+ id_server_url = self.rewrite_id_server_url(id_server_url)
try:
data = yield self.http_client.post_json_get_json(
- "https://%s%s" % (
- id_server,
- "/_matrix/identity/api/v1/validate/msisdn/requestToken"
- ),
- params
+ "%s/_matrix/identity/api/v1/validate/msisdn/requestToken"
+ % (id_server_url,),
+ params,
)
- defer.returnValue(data)
except HttpResponseException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e.to_synapse_error()
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+
+ assert self.hs.config.public_baseurl
+
+ # we need to tell the client to send the token back to us, since it doesn't
+ # otherwise know where to send it, so add submit_url response parameter
+ # (see also MSC2078)
+ data["submit_url"] = (
+ self.hs.config.public_baseurl
+ + "_matrix/client/unstable/add_threepid/msisdn/submit_token"
+ )
+ return data
+
+ @defer.inlineCallbacks
+ def validate_threepid_session(self, client_secret, sid):
+ """Validates a threepid session with only the client secret and session ID
+ Tries validating against any configured account_threepid_delegates as well as locally.
+
+ Args:
+ client_secret (str): A secret provided by the client
+
+ sid (str): The ID of the session
+
+ Returns:
+ Dict[str, str|int] if validation was successful, otherwise None
+ """
+ # XXX: We shouldn't need to keep wrapping and unwrapping this value
+ threepid_creds = {"client_secret": client_secret, "sid": sid}
+
+ # We don't actually know which medium this 3PID is. Thus we first assume it's email,
+ # and if validation fails we try msisdn
+ validation_session = None
+
+ # Try to validate as email
+ if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ # Ask our delegated email identity server
+ validation_session = yield self.threepid_from_creds(
+ self.hs.config.account_threepid_delegate_email, threepid_creds
+ )
+ elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ # Get a validated session matching these details
+ validation_session = yield self.store.get_threepid_validation_session(
+ "email", client_secret, sid=sid, validated=True
+ )
+
+ if validation_session:
+ return validation_session
+
+ # Try to validate as msisdn
+ if self.hs.config.account_threepid_delegate_msisdn:
+ # Ask our delegated msisdn identity server
+ validation_session = yield self.threepid_from_creds(
+ self.hs.config.account_threepid_delegate_msisdn, threepid_creds
+ )
+
+ return validation_session
+
+ @defer.inlineCallbacks
+ def proxy_msisdn_submit_token(self, id_server, client_secret, sid, token):
+ """Proxy a POST submitToken request to an identity server for verification purposes
+
+ Args:
+ id_server (str): The identity server URL to contact
+
+ client_secret (str): Secret provided by the client
+
+ sid (str): The ID of the session
+
+ token (str): The verification token
+
+ Raises:
+ SynapseError: If we failed to contact the identity server
+
+ Returns:
+ Deferred[dict]: The response dict from the identity server
+ """
+ body = {"client_secret": client_secret, "sid": sid, "token": token}
+
+ try:
+ return (
+ yield self.http_client.post_json_get_json(
+ id_server + "/_matrix/identity/api/v1/validate/msisdn/submitToken",
+ body,
+ )
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except HttpResponseException as e:
+ logger.warning("Error contacting msisdn account_threepid_delegate: %s", e)
+ raise SynapseError(400, "Error contacting the identity server")
+
+ # TODO: The following two methods are used for proxying IS requests using
+ # the CS API. They should be consolidated with those in RoomMemberHandler
+ # https://github.com/matrix-org/synapse-dinsic/issues/25
@defer.inlineCallbacks
- def lookup_3pid(self, id_server, medium, address):
+ def proxy_lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
@@ -366,26 +639,17 @@ class IdentityHandler(BaseHandler):
https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
for details
"""
- if not self._should_trust_id_server(id_server):
- raise SynapseError(
- 400, "Untrusted ID server '%s'" % id_server,
- Codes.SERVER_NOT_TRUSTED
- )
-
if not self._enable_lookup:
raise AuthError(
- 403, "Looking up third-party identifiers is denied from this server",
+ 403, "Looking up third-party identifiers is denied from this server"
)
- target = self.rewrite_identity_server_urls.get(id_server, id_server)
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
try:
data = yield self.http_client.get_json(
- "https://%s/_matrix/identity/api/v1/lookup" % (target,),
- {
- "medium": medium,
- "address": address,
- }
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
+ {"medium": medium, "address": address},
)
if "mxid" in data:
@@ -397,13 +661,13 @@ class IdentityHandler(BaseHandler):
logger.info("Proxied lookup failed: %r", e)
raise e.to_synapse_error()
except IOError as e:
- logger.info("Failed to contact %r: %s", id_server, e)
+ logger.info("Failed to contact %s: %s", id_server, e)
raise ProxiedRequestError(503, "Failed to contact identity server")
defer.returnValue(data)
@defer.inlineCallbacks
- def bulk_lookup_3pid(self, id_server, threepids):
+ def proxy_bulk_lookup_3pid(self, id_server, threepids):
"""Looks up given 3pids in the passed identity server.
Args:
@@ -417,58 +681,403 @@ class IdentityHandler(BaseHandler):
https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
for details
"""
- if not self._should_trust_id_server(id_server):
- raise SynapseError(
- 400, "Untrusted ID server '%s'" % id_server,
- Codes.SERVER_NOT_TRUSTED
- )
-
if not self._enable_lookup:
raise AuthError(
- 403, "Looking up third-party identifiers is denied from this server",
+ 403, "Looking up third-party identifiers is denied from this server"
)
- target = self.rewrite_identity_server_urls.get(id_server, id_server)
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
try:
data = yield self.http_client.post_json_get_json(
- "https://%s/_matrix/identity/api/v1/bulk_lookup" % (target,),
- {
- "threepids": threepids,
- }
+ "%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,),
+ {"threepids": threepids},
)
except HttpResponseException as e:
logger.info("Proxied lookup failed: %r", e)
raise e.to_synapse_error()
except IOError as e:
- logger.info("Failed to contact %r: %s", id_server, e)
+ logger.info("Failed to contact %s: %s", id_server, e)
raise ProxiedRequestError(503, "Failed to contact identity server")
defer.returnValue(data)
@defer.inlineCallbacks
- def _verify_any_signature(self, data, server_hostname):
- if server_hostname not in data["signatures"]:
- raise AuthError(401, "No signature from server %s" % (server_hostname,))
+ def lookup_3pid(self, id_server, medium, address, id_access_token=None):
+ """Looks up a 3pid in the passed identity server.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+ id_access_token (str|None): The access token to authenticate to the identity
+ server with
+
+ Returns:
+ str|None: the matrix ID of the 3pid, or None if it is not recognized.
+ """
+ # Rewrite id_server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ if id_access_token is not None:
+ try:
+ results = yield self._lookup_3pid_v2(
+ id_server_url, id_access_token, medium, address
+ )
+ return results
+
+ except Exception as e:
+ # Catch HttpResponseExcept for a non-200 response code
+ # Check if this identity server does not know about v2 lookups
+ if isinstance(e, HttpResponseException) and e.code == 404:
+ # This is an old identity server that does not yet support v2 lookups
+ logger.warning(
+ "Attempted v2 lookup on v1 identity server %s. Falling "
+ "back to v1",
+ id_server,
+ )
+ else:
+ logger.warning("Error when looking up hashing details: %s", e)
+ return None
+
+ return (yield self._lookup_3pid_v1(id_server, id_server_url, medium, address))
+
+ @defer.inlineCallbacks
+ def _lookup_3pid_v1(self, id_server, id_server_url, medium, address):
+ """Looks up a 3pid in the passed identity server using v1 lookup.
+
+ Args:
+ id_server (str): The server name (including port, if required)
+ of the identity server to use.
+ id_server_url (str): The actual, reachable domain of the id server
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ str: the matrix ID of the 3pid, or None if it is not recognized.
+ """
+ try:
+ data = yield self.http_client.get_json(
+ "%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
+ {"medium": medium, "address": address},
+ )
+
+ if "mxid" in data:
+ if "signatures" not in data:
+ raise AuthError(401, "No signatures on 3pid binding")
+ yield self._verify_any_signature(data, id_server)
+ return data["mxid"]
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except IOError as e:
+ logger.warning("Error from v1 identity server lookup: %s" % (e,))
+
+ return None
+
+ @defer.inlineCallbacks
+ def _lookup_3pid_v2(self, id_server_url, id_access_token, medium, address):
+ """Looks up a 3pid in the passed identity server using v2 lookup.
+
+ Args:
+ id_server_url (str): The protocol scheme and domain of the id server
+ id_access_token (str): The access token to authenticate to the identity server with
+ medium (str): The type of the third party identifier (e.g. "email").
+ address (str): The third party identifier (e.g. "foo@example.com").
+
+ Returns:
+ Deferred[str|None]: the matrix ID of the 3pid, or None if it is not recognised.
+ """
+ # Check what hashing details are supported by this identity server
+ try:
+ hash_details = yield self.http_client.get_json(
+ "%s/_matrix/identity/v2/hash_details" % (id_server_url,),
+ {"access_token": id_access_token},
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+
+ if not isinstance(hash_details, dict):
+ logger.warning(
+ "Got non-dict object when checking hash details of %s: %s",
+ id_server_url,
+ hash_details,
+ )
+ raise SynapseError(
+ 400,
+ "Non-dict object from %s during v2 hash_details request: %s"
+ % (id_server_url, hash_details),
+ )
+
+ # Extract information from hash_details
+ supported_lookup_algorithms = hash_details.get("algorithms")
+ lookup_pepper = hash_details.get("lookup_pepper")
+ if (
+ not supported_lookup_algorithms
+ or not isinstance(supported_lookup_algorithms, list)
+ or not lookup_pepper
+ or not isinstance(lookup_pepper, str)
+ ):
+ raise SynapseError(
+ 400,
+ "Invalid hash details received from identity server %s: %s"
+ % (id_server_url, hash_details),
+ )
+
+ # Check if any of the supported lookup algorithms are present
+ if LookupAlgorithm.SHA256 in supported_lookup_algorithms:
+ # Perform a hashed lookup
+ lookup_algorithm = LookupAlgorithm.SHA256
+
+ # Hash address, medium and the pepper with sha256
+ to_hash = "%s %s %s" % (address, medium, lookup_pepper)
+ lookup_value = sha256_and_url_safe_base64(to_hash)
+
+ elif LookupAlgorithm.NONE in supported_lookup_algorithms:
+ # Perform a non-hashed lookup
+ lookup_algorithm = LookupAlgorithm.NONE
+
+ # Combine together plaintext address and medium
+ lookup_value = "%s %s" % (address, medium)
+
+ else:
+ logger.warning(
+ "None of the provided lookup algorithms of %s are supported: %s",
+ id_server_url,
+ supported_lookup_algorithms,
+ )
+ raise SynapseError(
+ 400,
+ "Provided identity server does not support any v2 lookup "
+ "algorithms that this homeserver supports.",
+ )
- for key_name, signature in data["signatures"][server_hostname].items():
- target = self.rewrite_identity_server_urls.get(
- server_hostname, server_hostname,
+ # Authenticate with identity server given the access token from the client
+ headers = {"Authorization": create_id_access_token_header(id_access_token)}
+
+ try:
+ lookup_results = yield self.http_client.post_json_get_json(
+ "%s/_matrix/identity/v2/lookup" % (id_server_url,),
+ {
+ "addresses": [lookup_value],
+ "algorithm": lookup_algorithm,
+ "pepper": lookup_pepper,
+ },
+ headers=headers,
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except Exception as e:
+ logger.warning("Error when performing a v2 3pid lookup: %s", e)
+ raise SynapseError(
+ 500, "Unknown error occurred during identity server lookup"
)
+ # Check for a mapping from what we looked up to an MXID
+ if "mappings" not in lookup_results or not isinstance(
+ lookup_results["mappings"], dict
+ ):
+ logger.warning("No results from 3pid lookup")
+ return None
+
+ # Return the MXID if it's available, or None otherwise
+ mxid = lookup_results["mappings"].get(lookup_value)
+ return mxid
+
+ @defer.inlineCallbacks
+ def _verify_any_signature(self, data, id_server):
+ if id_server not in data["signatures"]:
+ raise AuthError(401, "No signature from server %s" % (id_server,))
+
+ for key_name, signature in data["signatures"][id_server].items():
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
key_data = yield self.http_client.get_json(
- "https://%s/_matrix/identity/api/v1/pubkey/%s" %
- (target, key_name,),
+ "%s/_matrix/identity/api/v1/pubkey/%s" % (id_server_url, key_name)
)
if "public_key" not in key_data:
- raise AuthError(401, "No public key named %s from %s" %
- (key_name, server_hostname,))
+ raise AuthError(
+ 401, "No public key named %s from %s" % (key_name, id_server)
+ )
verify_signed_json(
data,
- server_hostname,
- decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
+ id_server,
+ decode_verify_key_bytes(
+ key_name, decode_base64(key_data["public_key"])
+ ),
)
return
- raise AuthError(401, "No signature from server %s" % (server_hostname,))
+ raise AuthError(401, "No signature from server %s" % (id_server,))
+
+ @defer.inlineCallbacks
+ def ask_id_server_for_third_party_invite(
+ self,
+ requester,
+ id_server,
+ medium,
+ address,
+ room_id,
+ inviter_user_id,
+ room_alias,
+ room_avatar_url,
+ room_join_rules,
+ room_name,
+ inviter_display_name,
+ inviter_avatar_url,
+ id_access_token=None,
+ ):
+ """
+ Asks an identity server for a third party invite.
+
+ Args:
+ requester (Requester)
+ id_server (str): hostname + optional port for the identity server.
+ medium (str): The literal string "email".
+ address (str): The third party address being invited.
+ room_id (str): The ID of the room to which the user is invited.
+ inviter_user_id (str): The user ID of the inviter.
+ room_alias (str): An alias for the room, for cosmetic notifications.
+ room_avatar_url (str): The URL of the room's avatar, for cosmetic
+ notifications.
+ room_join_rules (str): The join rules of the email (e.g. "public").
+ room_name (str): The m.room.name of the room.
+ inviter_display_name (str): The current display name of the
+ inviter.
+ inviter_avatar_url (str): The URL of the inviter's avatar.
+ id_access_token (str|None): The access token to authenticate to the identity
+ server with
+
+ Returns:
+ A deferred tuple containing:
+ token (str): The token which must be signed to prove authenticity.
+ public_keys ([{"public_key": str, "key_validity_url": str}]):
+ public_key is a base64-encoded ed25519 public key.
+ fallback_public_key: One element from public_keys.
+ display_name (str): A user-friendly name to represent the invited
+ user.
+ """
+ invite_config = {
+ "medium": medium,
+ "address": address,
+ "room_id": room_id,
+ "room_alias": room_alias,
+ "room_avatar_url": room_avatar_url,
+ "room_join_rules": room_join_rules,
+ "room_name": room_name,
+ "sender": inviter_user_id,
+ "sender_display_name": inviter_display_name,
+ "sender_avatar_url": inviter_avatar_url,
+ }
+
+ # Rewrite the identity server URL if necessary
+ id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
+
+ # Add the identity service access token to the JSON body and use the v2
+ # Identity Service endpoints if id_access_token is present
+ data = None
+ base_url = "%s/_matrix/identity" % (id_server_url,)
+
+ if id_access_token:
+ key_validity_url = "%s/_matrix/identity/v2/pubkey/isvalid" % (
+ id_server_url,
+ )
+
+ # Attempt a v2 lookup
+ url = base_url + "/v2/store-invite"
+ try:
+ data = yield self.blacklisting_http_client.post_json_get_json(
+ url,
+ invite_config,
+ {"Authorization": create_id_access_token_header(id_access_token)},
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except HttpResponseException as e:
+ if e.code != 404:
+ logger.info("Failed to POST %s with JSON: %s", url, e)
+ raise e
+
+ if data is None:
+ key_validity_url = "%s/_matrix/identity/api/v1/pubkey/isvalid" % (
+ id_server_url,
+ )
+ url = base_url + "/api/v1/store-invite"
+
+ try:
+ data = yield self.blacklisting_http_client.post_json_get_json(
+ url, invite_config
+ )
+ except TimeoutError:
+ raise SynapseError(500, "Timed out contacting identity server")
+ except HttpResponseException as e:
+ logger.warning(
+ "Error trying to call /store-invite on %s: %s", id_server_url, e,
+ )
+
+ if data is None:
+ # Some identity servers may only support application/x-www-form-urlencoded
+ # types. This is especially true with old instances of Sydent, see
+ # https://github.com/matrix-org/sydent/pull/170
+ try:
+ data = yield self.blacklisting_http_client.post_urlencoded_get_json(
+ url, invite_config
+ )
+ except HttpResponseException as e:
+ logger.warning(
+ "Error calling /store-invite on %s with fallback "
+ "encoding: %s",
+ id_server_url,
+ e,
+ )
+ raise e
+
+ # TODO: Check for success
+ token = data["token"]
+ public_keys = data.get("public_keys", [])
+ if "public_key" in data:
+ fallback_public_key = {
+ "public_key": data["public_key"],
+ "key_validity_url": key_validity_url,
+ }
+ else:
+ fallback_public_key = public_keys[0]
+
+ if not public_keys:
+ public_keys.append(fallback_public_key)
+ display_name = data["display_name"]
+ return token, public_keys, fallback_public_key, display_name
+
+
+def create_id_access_token_header(id_access_token):
+ """Create an Authorization header for passing to SimpleHttpClient as the header value
+ of an HTTP request.
+
+ Args:
+ id_access_token (str): An identity server access token.
+
+ Returns:
+ list[str]: The ascii-encoded bearer token encased in a list.
+ """
+ # Prefix with Bearer
+ bearer_token = "Bearer %s" % id_access_token
+
+ # Encode headers to standard ascii
+ bearer_token.encode("ascii")
+
+ # Return as a list as that's how SimpleHttpClient takes header values
+ return [bearer_token]
+
+
+class LookupAlgorithm:
+ """
+ Supported hashing algorithms when performing a 3PID lookup.
+
+ SHA256 - Hashing an (address, medium, pepper) combo with sha256, then url-safe base64
+ encoding
+ NONE - Not performing any hashing. Simply sending an (address, medium) combo in plaintext
+ """
+
+ SHA256 = "sha256"
+ NONE = "none"
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index aaee5db0b7..b116500c7d 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -18,15 +18,15 @@ import logging
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import SynapseError
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
-from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -41,11 +41,18 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
- self.snapshot_cache = SnapshotCache()
+ self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
-
- def snapshot_all_rooms(self, user_id=None, pagin_config=None,
- as_client_event=True, include_archived=False):
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
+
+ def snapshot_all_rooms(
+ self,
+ user_id=None,
+ pagin_config=None,
+ as_client_event=True,
+ include_archived=False,
+ ):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
@@ -72,24 +79,29 @@ class InitialSyncHandler(BaseHandler):
as_client_event,
include_archived,
)
- now_ms = self.clock.time_msec()
- result = self.snapshot_cache.get(now_ms, key)
- if result is not None:
- return result
- return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
- user_id, pagin_config, as_client_event, include_archived
- ))
+ return self.snapshot_cache.wrap(
+ key,
+ self._snapshot_all_rooms,
+ user_id,
+ pagin_config,
+ as_client_event,
+ include_archived,
+ )
- @defer.inlineCallbacks
- def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
- as_client_event=True, include_archived=False):
+ async def _snapshot_all_rooms(
+ self,
+ user_id=None,
+ pagin_config=None,
+ as_client_event=True,
+ include_archived=False,
+ ):
memberships = [Membership.INVITE, Membership.JOIN]
if include_archived:
memberships.append(Membership.LEAVE)
- room_list = yield self.store.get_rooms_for_user_where_membership_is(
+ room_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
@@ -97,39 +109,37 @@ class InitialSyncHandler(BaseHandler):
rooms_ret = []
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
- presence, _ = yield presence_stream.get_pagination_rows(
+ presence, _ = await presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
- receipt, _ = yield receipt_stream.get_pagination_rows(
+ receipt, _ = await receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
- tags_by_room = yield self.store.get_tags_for_user(user_id)
+ tags_by_room = await self.store.get_tags_for_user(user_id)
- account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(user_id)
+ account_data, account_data_by_room = await self.store.get_account_data_for_user(
+ user_id
)
- public_room_ids = yield self.store.get_public_room_ids()
+ public_room_ids = await self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
- @defer.inlineCallbacks
- def handle_room(event):
+ async def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
"visibility": (
- "public" if event.room_id in public_room_ids
- else "private"
+ "public" if event.room_id in public_room_ids else "private"
),
}
@@ -137,9 +147,9 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
d["inviter"] = event.sender
- invite_event = yield self.store.get_event(event.event_id)
- d["invite"] = yield self._event_serializer.serialize_event(
- invite_event, time_now, as_client_event,
+ invite_event = await self.store.get_event(event.event_id)
+ d["invite"] = await self._event_serializer.serialize_event(
+ invite_event, time_now, as_client_event
)
rooms_ret.append(d)
@@ -151,20 +161,18 @@ class InitialSyncHandler(BaseHandler):
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = run_in_background(
- self.state_handler.get_current_state,
- event.room_id,
+ self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background(
- self.store.get_state_for_events,
- [event.event_id],
+ self.state_store.get_state_for_events, [event.event_id]
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
- (messages, token), current_state = yield make_deferred_yieldable(
+ (messages, token), current_state = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -178,8 +186,8 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
- messages = yield filter_events_for_client(
- self.store, user_id, messages
+ messages = await filter_events_for_client(
+ self.storage, user_id, messages
)
start_token = now_token.copy_and_replace("room_key", token)
@@ -188,48 +196,42 @@ class InitialSyncHandler(BaseHandler):
d["messages"] = {
"chunk": (
- yield self._event_serializer.serialize_events(
- messages, time_now=time_now,
- as_client_event=as_client_event,
+ await self._event_serializer.serialize_events(
+ messages, time_now=time_now, as_client_event=as_client_event
)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
}
- d["state"] = yield self._event_serializer.serialize_events(
+ d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
- as_client_event=as_client_event
+ as_client_event=as_client_event,
)
account_data_events = []
tags = tags_by_room.get(event.room_id)
if tags:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
+ account_data_events.append(
+ {"type": "m.tag", "content": {"tags": tags}}
+ )
account_data = account_data_by_room.get(event.room_id, {})
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append(
+ {"type": account_data_type, "content": content}
+ )
d["account_data"] = account_data_events
except Exception:
logger.exception("Failed to get snapshot")
- yield concurrently_execute(handle_room, room_list, 10)
+ await concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append({"type": account_data_type, "content": content})
now = self.clock.time_msec()
@@ -247,10 +249,9 @@ class InitialSyncHandler(BaseHandler):
"end": now_token.to_string(),
}
- defer.returnValue(ret)
+ return ret
- @defer.inlineCallbacks
- def room_initial_sync(self, requester, room_id, pagin_config=None):
+ async def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
@@ -267,51 +268,46 @@ class InitialSyncHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room.
"""
- blocked = yield self.store.is_room_blocked(room_id)
+ blocked = await self.store.is_room_blocked(room_id)
if blocked:
raise SynapseError(403, "This room has been blocked on this server")
user_id = requester.user.to_string()
- membership, member_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id,
+ (
+ membership,
+ member_event_id,
+ ) = await self.auth.check_user_in_room_or_world_readable(
+ room_id, user_id, allow_departed_users=True,
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
- result = yield self._room_initial_sync_joined(
+ result = await self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
- result = yield self._room_initial_sync_parted(
+ result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
- tags = yield self.store.get_tags_for_room(user_id, room_id)
+ tags = await self.store.get_tags_for_room(user_id, room_id)
if tags:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
+ account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
- account_data = yield self.store.get_account_data_for_room(user_id, room_id)
+ account_data = await self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append({"type": account_data_type, "content": content})
result["account_data"] = account_data_events
- defer.returnValue(result)
+ return result
- @defer.inlineCallbacks
- def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
- membership, member_event_id, is_peeking):
- room_state = yield self.store.get_state_for_events(
- [member_event_id],
- )
+ async def _room_initial_sync_parted(
+ self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
+ ):
+ room_state = await self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@@ -319,18 +315,14 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
- stream_token = yield self.store.get_stream_token_for_event(
- member_event_id
- )
+ stream_token = await self.store.get_stream_token_for_event(member_event_id)
- messages, token = yield self.store.get_recent_events_for_room(
- room_id,
- limit=limit,
- end_token=stream_token
+ messages, token = await self.store.get_recent_events_for_room(
+ room_id, limit=limit, end_token=stream_token
)
- messages = yield filter_events_for_client(
- self.store, user_id, messages, is_peeking=is_peeking
+ messages = await filter_events_for_client(
+ self.storage, user_id, messages, is_peeking=is_peeking
)
start_token = StreamToken.START.copy_and_replace("room_key", token)
@@ -338,74 +330,71 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
- defer.returnValue({
+ return {
"membership": membership,
"room_id": room_id,
"messages": {
- "chunk": (yield self._event_serializer.serialize_events(
- messages, time_now,
- )),
+ "chunk": (
+ await self._event_serializer.serialize_events(messages, time_now)
+ ),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
- "state": (yield self._event_serializer.serialize_events(
- room_state.values(), time_now,
- )),
+ "state": (
+ await self._event_serializer.serialize_events(
+ room_state.values(), time_now
+ )
+ ),
"presence": [],
"receipts": [],
- })
+ }
- @defer.inlineCallbacks
- def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
- membership, is_peeking):
- current_state = yield self.state.get_current_state(
- room_id=room_id,
- )
+ async def _room_initial_sync_joined(
+ self, user_id, room_id, pagin_config, membership, is_peeking
+ ):
+ current_state = await self.state.get_current_state(room_id=room_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
- state = yield self._event_serializer.serialize_events(
- current_state.values(), time_now,
+ state = await self._event_serializer.serialize_events(
+ current_state.values(), time_now
)
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
room_members = [
- m for m in current_state.values()
+ m
+ for m in current_state.values()
if m.type == EventTypes.Member
and m.content["membership"] == Membership.JOIN
]
presence_handler = self.hs.get_presence_handler()
- @defer.inlineCallbacks
- def get_presence():
+ async def get_presence():
# If presence is disabled, return an empty list
if not self.hs.config.use_presence:
- defer.returnValue([])
+ return []
- states = yield presence_handler.get_states(
- [m.user_id for m in room_members],
- as_event=True,
+ states = await presence_handler.get_states(
+ [m.user_id for m in room_members], as_event=True
)
- defer.returnValue(states)
+ return states
- @defer.inlineCallbacks
- def get_receipts():
- receipts = yield self.store.get_linearized_receipts_for_room(
- room_id,
- to_key=now_token.receipt_key,
+ async def get_receipts():
+ receipts = await self.store.get_linearized_receipts_for_room(
+ room_id, to_key=now_token.receipt_key
)
if not receipts:
receipts = []
- defer.returnValue(receipts)
+ return receipts
- presence, receipts, (messages, token) = yield make_deferred_yieldable(
+ presence, receipts, (messages, token) = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(get_presence),
@@ -415,14 +404,14 @@ class InitialSyncHandler(BaseHandler):
room_id,
limit=limit,
end_token=now_token.room_key,
- )
+ ),
],
consumeErrors=True,
- ).addErrback(unwrapFirstError),
+ ).addErrback(unwrapFirstError)
)
- messages = yield filter_events_for_client(
- self.store, user_id, messages, is_peeking=is_peeking,
+ messages = await filter_events_for_client(
+ self.storage, user_id, messages, is_peeking=is_peeking
)
start_token = now_token.copy_and_replace("room_key", token)
@@ -433,9 +422,9 @@ class InitialSyncHandler(BaseHandler):
ret = {
"room_id": room_id,
"messages": {
- "chunk": (yield self._event_serializer.serialize_events(
- messages, time_now,
- )),
+ "chunk": (
+ await self._event_serializer.serialize_events(messages, time_now)
+ ),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
@@ -446,29 +435,4 @@ class InitialSyncHandler(BaseHandler):
if not is_peeking:
ret["membership"] = membership
- defer.returnValue(ret)
-
- @defer.inlineCallbacks
- def _check_in_room_or_world_readable(self, room_id, user_id):
- try:
- # check_user_was_in_room will return the most recent membership
- # event for the user if:
- # * The user is a non-guest user, and was ever in the room
- # * The user is a guest user, and has joined the room
- # else it will throw.
- member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
- defer.returnValue((member_event.membership, member_event.event_id))
- return
- except AuthError:
- visibility = yield self.state_handler.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility, ""
- )
- if (
- visibility and
- visibility.content["history_visibility"] == "world_readable"
- ):
- defer.returnValue((Membership.JOIN, None))
- return
- raise AuthError(
- 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
- )
+ return ret
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index eb750d65d8..b743fc2dcc 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Optional
from six import iteritems, itervalues, string_types
@@ -22,8 +23,16 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
-
-from synapse.api.constants import EventTypes, Membership, RelationTypes
+from twisted.internet.interfaces import IDelayedCall
+
+from synapse import event_auth
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RelationTypes,
+ UserTypes,
+)
from synapse.api.errors import (
AuthError,
Codes,
@@ -31,15 +40,17 @@ from synapse.api.errors import (
NotFoundError,
SynapseError,
)
-from synapse.api.room_versions import RoomVersions
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.api.urls import ConsentURIBuilder
from synapse.events.validator import EventValidator
+from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, UserID
+from synapse.types import Collection, RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@@ -57,11 +68,25 @@ class MessageHandler(object):
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
+ self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._is_worker_app = bool(hs.config.worker_app)
+
+ # The scheduled call to self._expire_event. None if no call is currently
+ # scheduled.
+ self._scheduled_expiry = None # type: Optional[IDelayedCall]
+
+ if not hs.config.worker_app:
+ run_as_background_process(
+ "_schedule_next_expiry", self._schedule_next_expiry
+ )
@defer.inlineCallbacks
- def get_room_data(self, user_id=None, room_id=None,
- event_type=None, state_key="", is_guest=False):
+ def get_room_data(
+ self, user_id=None, room_id=None, event_type=None, state_key="", is_guest=False
+ ):
""" Get data from a room.
Args:
@@ -71,27 +96,32 @@ class MessageHandler(object):
Raises:
SynapseError if something went wrong.
"""
- membership, membership_event_id = yield self.auth.check_in_room_or_world_readable(
- room_id, user_id
+ (
+ membership,
+ membership_event_id,
+ ) = yield self.auth.check_user_in_room_or_world_readable(
+ room_id, user_id, allow_departed_users=True
)
if membership == Membership.JOIN:
- data = yield self.state.get_current_state(
- room_id, event_type, state_key
- )
+ data = yield self.state.get_current_state(room_id, event_type, state_key)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
- room_state = yield self.store.get_state_for_events(
+ room_state = yield self.state_store.get_state_for_events(
[membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
- defer.returnValue(data)
+ return data
@defer.inlineCallbacks
def get_state_events(
- self, user_id, room_id, state_filter=StateFilter.all(),
- at_token=None, is_guest=False,
+ self,
+ user_id,
+ room_id,
+ state_filter=StateFilter.all(),
+ at_token=None,
+ is_guest=False,
):
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
@@ -123,55 +153,56 @@ class MessageHandler(object):
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
last_events, _ = yield self.store.get_recent_events_for_room(
- room_id, end_token=at_token.room_key, limit=1,
+ room_id, end_token=at_token.room_key, limit=1
)
if not last_events:
- raise NotFoundError("Can't find event for token %s" % (at_token, ))
+ raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
- self.store, user_id, last_events, apply_retention_policies=False
+ self.storage, user_id, last_events, filter_send_to_client=False
)
event = last_events[0]
if visible_events:
- room_state = yield self.store.get_state_for_events(
- [event.event_id], state_filter=state_filter,
+ room_state = yield self.state_store.get_state_for_events(
+ [event.event_id], state_filter=state_filter
)
room_state = room_state[event.event_id]
else:
raise AuthError(
403,
- "User %s not allowed to view events in room %s at token %s" % (
- user_id, room_id, at_token,
- )
+ "User %s not allowed to view events in room %s at token %s"
+ % (user_id, room_id, at_token),
)
else:
- membership, membership_event_id = (
- yield self.auth.check_in_room_or_world_readable(
- room_id, user_id,
- )
+ (
+ membership,
+ membership_event_id,
+ ) = yield self.auth.check_user_in_room_or_world_readable(
+ room_id, user_id, allow_departed_users=True
)
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
- room_id, state_filter=state_filter,
+ room_id, state_filter=state_filter
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
- room_state = yield self.store.get_state_for_events(
- [membership_event_id], state_filter=state_filter,
+ room_state = yield self.state_store.get_state_for_events(
+ [membership_event_id], state_filter=state_filter
)
room_state = room_state[membership_event_id]
now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(
- room_state.values(), now,
+ room_state.values(),
+ now,
# We don't bother bundling aggregations in when asked for state
# events, as clients won't use them.
bundle_aggregations=False,
)
- defer.returnValue(events)
+ return events
@defer.inlineCallbacks
def get_joined_members(self, requester, room_id):
@@ -189,8 +220,8 @@ class MessageHandler(object):
if not requester.app_service:
# We check AS auth after fetching the room membership, as it
# requires us to pull out all joined members anyway.
- membership, _ = yield self.auth.check_in_room_or_world_readable(
- room_id, user_id
+ membership, _ = yield self.auth.check_user_in_room_or_world_readable(
+ room_id, user_id, allow_departed_users=True
)
if membership != Membership.JOIN:
raise NotImplementedError(
@@ -210,13 +241,114 @@ class MessageHandler(object):
# Loop fell through, AS has no interested users in room
raise AuthError(403, "Appservice not in room")
- defer.returnValue({
+ return {
user_id: {
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
}
for user_id, profile in iteritems(users_with_profile)
- })
+ }
+
+ def maybe_schedule_expiry(self, event):
+ """Schedule the expiry of an event if there's not already one scheduled,
+ or if the one running is for an event that will expire after the provided
+ timestamp.
+
+ This function needs to invalidate the event cache, which is only possible on
+ the master process, and therefore needs to be run on there.
+
+ Args:
+ event (EventBase): The event to schedule the expiry of.
+ """
+ assert not self._is_worker_app
+
+ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+ if not isinstance(expiry_ts, int) or event.is_state():
+ return
+
+ # _schedule_expiry_for_event won't actually schedule anything if there's already
+ # a task scheduled for a timestamp that's sooner than the provided one.
+ self._schedule_expiry_for_event(event.event_id, expiry_ts)
+
+ @defer.inlineCallbacks
+ def _schedule_next_expiry(self):
+ """Retrieve the ID and the expiry timestamp of the next event to be expired,
+ and schedule an expiry task for it.
+
+ If there's no event left to expire, set _expiry_scheduled to None so that a
+ future call to save_expiry_ts can schedule a new expiry task.
+ """
+ # Try to get the expiry timestamp of the next event to expire.
+ res = yield self.store.get_next_event_to_expire()
+ if res:
+ event_id, expiry_ts = res
+ self._schedule_expiry_for_event(event_id, expiry_ts)
+
+ def _schedule_expiry_for_event(self, event_id, expiry_ts):
+ """Schedule an expiry task for the provided event if there's not already one
+ scheduled at a timestamp that's sooner than the provided one.
+
+ Args:
+ event_id (str): The ID of the event to expire.
+ expiry_ts (int): The timestamp at which to expire the event.
+ """
+ if self._scheduled_expiry:
+ # If the provided timestamp refers to a time before the scheduled time of the
+ # next expiry task, cancel that task and reschedule it for this timestamp.
+ next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000
+ if expiry_ts < next_scheduled_expiry_ts:
+ self._scheduled_expiry.cancel()
+ else:
+ return
+
+ # Figure out how many seconds we need to wait before expiring the event.
+ now_ms = self.clock.time_msec()
+ delay = (expiry_ts - now_ms) / 1000
+
+ # callLater doesn't support negative delays, so trim the delay to 0 if we're
+ # in that case.
+ if delay < 0:
+ delay = 0
+
+ logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay)
+
+ self._scheduled_expiry = self.clock.call_later(
+ delay,
+ run_as_background_process,
+ "_expire_event",
+ self._expire_event,
+ event_id,
+ )
+
+ @defer.inlineCallbacks
+ def _expire_event(self, event_id):
+ """Retrieve and expire an event that needs to be expired from the database.
+
+ If the event doesn't exist in the database, log it and delete the expiry date
+ from the database (so that we don't try to expire it again).
+ """
+ assert self._ephemeral_events_enabled
+
+ self._scheduled_expiry = None
+
+ logger.info("Expiring event %s", event_id)
+
+ try:
+ # Expire the event if we know about it. This function also deletes the expiry
+ # date from the database in the same database transaction.
+ yield self.store.expire_event(event_id)
+ except Exception as e:
+ logger.error("Could not expire event %s: %r", event_id, e)
+
+ # Schedule the expiry of the next event to expire.
+ yield self._schedule_next_expiry()
+
+
+# The duration (in ms) after which rooms should be removed
+# `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try
+# to generate a dummy event for them once more)
+#
+_DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY = 7 * 24 * 60 * 60 * 1000
class EventCreationHandler(object):
@@ -224,6 +356,7 @@ class EventCreationHandler(object):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
@@ -235,6 +368,8 @@ class EventCreationHandler(object):
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
+ self.room_invite_state_types = self.hs.config.room_invite_state_types
+
self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
# This is only used to get at ratelimit function, and maybe_kick_guest_users
@@ -255,15 +390,45 @@ class EventCreationHandler(object):
self.config.block_events_without_consent_error
)
+ # Rooms which should be excluded from dummy insertion. (For instance,
+ # those without local users who can send events into the room).
+ #
+ # map from room id to time-of-last-attempt.
+ #
+ self._rooms_to_exclude_from_dummy_event_insertion = {} # type: dict[str, int]
+
# we need to construct a ConsentURIBuilder here, as it checks that the necessary
# config options, but *only* if we have a configuration for which we are
# going to need it.
if self._block_events_without_consent_error:
self._consent_uri_builder = ConsentURIBuilder(self.config)
+ if (
+ not self.config.worker_app
+ and self.config.cleanup_extremities_with_dummy_events
+ ):
+ self.clock.looping_call(
+ lambda: run_as_background_process(
+ "send_dummy_events_to_fill_extremities",
+ self._send_dummy_events_to_fill_extremities,
+ ),
+ 5 * 60 * 1000,
+ )
+
+ self._message_handler = hs.get_message_handler()
+
+ self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+
@defer.inlineCallbacks
- def create_event(self, requester, event_dict, token_id=None, txn_id=None,
- prev_events_and_hashes=None, require_consent=True):
+ def create_event(
+ self,
+ requester,
+ event_dict,
+ token_id=None,
+ txn_id=None,
+ prev_event_ids: Optional[Collection[str]] = None,
+ require_consent=True,
+ ):
"""
Given a dict from a client, create a new event.
@@ -278,10 +443,9 @@ class EventCreationHandler(object):
token_id (str)
txn_id (str)
- prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+ prev_event_ids:
the forward extremities to use as the prev_events for the
- new event. For each event, a tuple of (event_id, hashes, depth)
- where *hashes* is a map from algorithm to hash.
+ new event.
If None, they will be requested from the database.
@@ -299,7 +463,9 @@ class EventCreationHandler(object):
room_version = event_dict["content"]["room_version"]
else:
try:
- room_version = yield self.store.get_room_version(event_dict["room_id"])
+ room_version = yield self.store.get_room_version_id(
+ event_dict["room_id"]
+ )
except NotFoundError:
raise AuthError(403, "Unknown room")
@@ -323,8 +489,7 @@ class EventCreationHandler(object):
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
- "Failed to get profile information for %r: %s",
- target, e
+ "Failed to get profile information for %r: %s", target, e
)
is_exempt = yield self._is_exempt_from_privacy_policy(builder, requester)
@@ -338,9 +503,7 @@ class EventCreationHandler(object):
builder.internal_metadata.txn_id = txn_id
event, context = yield self.create_new_client_event(
- builder=builder,
- requester=requester,
- prev_events_and_hashes=prev_events_and_hashes,
+ builder=builder, requester=requester, prev_event_ids=prev_event_ids,
)
# In an ideal world we wouldn't need the second part of this condition. However,
@@ -355,26 +518,31 @@ class EventCreationHandler(object):
# federation as well as those created locally. As of room v3, aliases events
# can be created by users that are not in the room, therefore we have to
# tolerate them in event_auth.check().
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ prev_event = (
+ yield self.store.get_event(prev_event_id, allow_none=True)
+ if prev_event_id
+ else None
+ )
if not prev_event or prev_event.membership != Membership.JOIN:
logger.warning(
- ("Attempt to send `m.room.aliases` in room %s by user %s but"
- " membership is %s"),
+ (
+ "Attempt to send `m.room.aliases` in room %s by user %s but"
+ " membership is %s"
+ ),
event.room_id,
event.sender,
prev_event.membership if prev_event else None,
)
raise AuthError(
- 403,
- "You must be in the room to create an alias for it",
+ 403, "You must be in the room to create an alias for it"
)
self.validator.validate_new(event, self.config)
- defer.returnValue((event, context))
+ return (event, context)
def _is_exempt_from_privacy_policy(self, builder, requester):
""""Determine if an event to be sent is exempt from having to consent
@@ -401,9 +569,9 @@ class EventCreationHandler(object):
@defer.inlineCallbacks
def _is_server_notices_room(self, room_id):
if self.config.server_notices_mxid is None:
- defer.returnValue(False)
+ return False
user_ids = yield self.store.get_users_in_room(room_id)
- defer.returnValue(self.config.server_notices_mxid in user_ids)
+ return self.config.server_notices_mxid in user_ids
@defer.inlineCallbacks
def assert_accepted_privacy_policy(self, requester):
@@ -436,13 +604,16 @@ class EventCreationHandler(object):
# exempt the system notices user
if (
- self.config.server_notices_mxid is not None and
- user_id == self.config.server_notices_mxid
+ self.config.server_notices_mxid is not None
+ and user_id == self.config.server_notices_mxid
):
return
u = yield self.store.get_user_by_id(user_id)
assert u is not None
+ if u["user_type"] in (UserTypes.SUPPORT, UserTypes.BOT):
+ # support and bot users are not required to consent
+ return
if u["appservice_id"] is not None:
# users registered by an appservice are exempt
return
@@ -450,15 +621,10 @@ class EventCreationHandler(object):
return
consent_uri = self._consent_uri_builder.build_user_consent_uri(
- requester.user.localpart,
- )
- msg = self._block_events_without_consent_error % {
- 'consent_uri': consent_uri,
- }
- raise ConsentNotGivenError(
- msg=msg,
- consent_uri=consent_uri,
+ requester.user.localpart
)
+ msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
+ raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
@@ -473,8 +639,7 @@ class EventCreationHandler(object):
"""
if event.type == EventTypes.Member:
raise SynapseError(
- 500,
- "Tried to send member event through non-member codepath"
+ 500, "Tried to send member event through non-member codepath"
)
user = UserID.from_string(event.sender)
@@ -486,15 +651,13 @@ class EventCreationHandler(object):
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
- event.event_id, prev_state.event_id,
+ event.event_id,
+ prev_state.event_id,
)
- defer.returnValue(prev_state)
+ return prev_state
yield self.handle_new_client_event(
- requester=requester,
- event=event,
- context=context,
- ratelimit=ratelimit,
+ requester=requester, event=event, context=context, ratelimit=ratelimit
)
@defer.inlineCallbacks
@@ -505,8 +668,10 @@ class EventCreationHandler(object):
If so, returns the version of the event in context.
Otherwise, returns None.
"""
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
prev_event_id = prev_state_ids.get((event.type, event.state_key))
+ if not prev_event_id:
+ return
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
@@ -515,16 +680,12 @@ class EventCreationHandler(object):
prev_content = encode_canonical_json(prev_event.content)
next_content = encode_canonical_json(event.content)
if prev_content == next_content:
- defer.returnValue(prev_event)
+ return prev_event
return
@defer.inlineCallbacks
def create_and_send_nonmember_event(
- self,
- requester,
- event_dict,
- ratelimit=True,
- txn_id=None
+ self, requester, event_dict, ratelimit=True, txn_id=None
):
"""
Creates an event, then sends it.
@@ -539,32 +700,25 @@ class EventCreationHandler(object):
# taking longer.
with (yield self.limiter.queue(event_dict["room_id"])):
event, context = yield self.create_event(
- requester,
- event_dict,
- token_id=requester.access_token_id,
- txn_id=txn_id
+ requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, string_types):
spam_error = "Spam is not permitted here"
- raise SynapseError(
- 403, spam_error, Codes.FORBIDDEN
- )
+ raise SynapseError(403, spam_error, Codes.FORBIDDEN)
yield self.send_nonmember_event(
- requester,
- event,
- context,
- ratelimit=ratelimit,
+ requester, event, context, ratelimit=ratelimit
)
- defer.returnValue(event)
+ return event
@measure_func("create_new_client_event")
@defer.inlineCallbacks
- def create_new_client_event(self, builder, requester=None,
- prev_events_and_hashes=None):
+ def create_new_client_event(
+ self, builder, requester=None, prev_event_ids: Optional[Collection[str]] = None
+ ):
"""Create a new event for a local client
Args:
@@ -572,10 +726,9 @@ class EventCreationHandler(object):
requester (synapse.types.Requester|None):
- prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+ prev_event_ids:
the forward extremities to use as the prev_events for the
- new event. For each event, a tuple of (event_id, hashes, depth)
- where *hashes* is a map from algorithm to hash.
+ new event.
If None, they will be requested from the database.
@@ -583,23 +736,15 @@ class EventCreationHandler(object):
Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
"""
- if prev_events_and_hashes is not None:
- assert len(prev_events_and_hashes) <= 10, \
- "Attempting to create an event with %i prev_events" % (
- len(prev_events_and_hashes),
+ if prev_event_ids is not None:
+ assert len(prev_event_ids) <= 10, (
+ "Attempting to create an event with %i prev_events"
+ % (len(prev_event_ids),)
)
else:
- prev_events_and_hashes = \
- yield self.store.get_prev_events_for_room(builder.room_id)
+ prev_event_ids = yield self.store.get_prev_events_for_room(builder.room_id)
- prev_events = [
- (event_id, prev_hashes)
- for event_id, prev_hashes, _ in prev_events_and_hashes
- ]
-
- event = yield builder.build(
- prev_event_ids=[p for p, _ in prev_events],
- )
+ event = yield builder.build(prev_event_ids=prev_event_ids)
context = yield self.state.compute_event_context(event)
if requester:
context.app_service = requester.app_service
@@ -615,29 +760,19 @@ class EventCreationHandler(object):
aggregation_key = relation["key"]
already_exists = yield self.store.has_user_annotated_event(
- relates_to, event.type, aggregation_key, event.sender,
+ relates_to, event.type, aggregation_key, event.sender
)
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")
- logger.debug(
- "Created event %s",
- event.event_id,
- )
+ logger.debug("Created event %s", event.event_id)
- defer.returnValue(
- (event, context,)
- )
+ return (event, context)
@measure_func("handle_new_client_event")
@defer.inlineCallbacks
def handle_new_client_event(
- self,
- requester,
- event,
- context,
- ratelimit=True,
- extra_users=[],
+ self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Processes a new event. This includes checking auth, persisting it,
notifying users, sending to remote servers, etc.
@@ -653,25 +788,26 @@ class EventCreationHandler(object):
extra_users (list(UserID)): Any extra users to notify about event
"""
- if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
- room_version = event.content.get(
- "room_version", RoomVersions.V1.identifier
- )
+ if event.is_state() and (event.type, event.state_key) == (
+ EventTypes.Create,
+ "",
+ ):
+ room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else:
- room_version = yield self.store.get_room_version(event.room_id)
+ room_version = yield self.store.get_room_version_id(event.room_id)
event_allowed = yield self.third_party_event_rules.check_event_allowed(
- event, context,
+ event, context
)
if not event_allowed:
raise SynapseError(
- 403, "This event is not allowed in this context", Codes.FORBIDDEN,
+ 403, "This event is not allowed in this context", Codes.FORBIDDEN
)
try:
yield self.auth.check_from_context(room_version, event, context)
except AuthError as err:
- logger.warn("Denying new event %r because %s", event, err)
+ logger.warning("Denying new event %r because %s", event, err)
raise err
# Ensure that we can round trip before trying to persist in db
@@ -682,9 +818,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content)
raise
- yield self.action_generator.handle_push_actions_for_event(
- event, context
- )
+ yield self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
@@ -705,11 +839,7 @@ class EventCreationHandler(object):
return
yield self.persist_and_notify_client_event(
- requester,
- event,
- context,
- ratelimit=ratelimit,
- extra_users=extra_users,
+ requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
success = True
@@ -718,18 +848,12 @@ class EventCreationHandler(object):
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
run_in_background(
- self.store.remove_push_actions_from_staging,
- event.event_id,
+ self.store.remove_push_actions_from_staging, event.event_id
)
@defer.inlineCallbacks
def persist_and_notify_client_event(
- self,
- requester,
- event,
- context,
- ratelimit=True,
- extra_users=[],
+ self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
@@ -739,42 +863,99 @@ class EventCreationHandler(object):
assert not self.config.worker_app
if ratelimit:
- yield self.base_handler.ratelimit(requester)
+ # We check if this is a room admin redacting an event so that we
+ # can apply different ratelimiting. We do this by simply checking
+ # it's not a self-redaction (to avoid having to look up whether the
+ # user is actually admin or not).
+ is_admin_redaction = False
+ if event.type == EventTypes.Redaction:
+ original_event = yield self.store.get_event(
+ event.redacts,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ get_prev_content=False,
+ allow_rejected=False,
+ allow_none=True,
+ )
+
+ is_admin_redaction = (
+ original_event and event.sender != original_event.sender
+ )
+
+ yield self.base_handler.ratelimit(
+ requester, is_admin_redaction=is_admin_redaction
+ )
yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
- # Check the alias is acually valid (at this time at least)
+ # Validate a newly added alias or newly added alt_aliases.
+
+ original_alias = None
+ original_alt_aliases = set()
+
+ original_event_id = event.unsigned.get("replaces_state")
+ if original_event_id:
+ original_event = yield self.store.get_event(original_event_id)
+
+ if original_event:
+ original_alias = original_event.content.get("alias", None)
+ original_alt_aliases = original_event.content.get("alt_aliases", [])
+
+ # Check the alias is currently valid (if it has changed).
room_alias_str = event.content.get("alias", None)
- if room_alias_str:
+ directory_handler = self.hs.get_handlers().directory_handler
+ if room_alias_str and room_alias_str != original_alias:
room_alias = RoomAlias.from_string(room_alias_str)
- directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
- "Room alias %s does not point to the room" % (
- room_alias_str,
- )
+ "Room alias %s does not point to the room" % (room_alias_str,),
+ Codes.BAD_ALIAS,
)
+ # Check that alt_aliases is the proper form.
+ alt_aliases = event.content.get("alt_aliases", [])
+ if not isinstance(alt_aliases, (list, tuple)):
+ raise SynapseError(
+ 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM
+ )
+
+ # If the old version of alt_aliases is of an unknown form,
+ # completely replace it.
+ if not isinstance(original_alt_aliases, (list, tuple)):
+ original_alt_aliases = []
+
+ # Check that each alias is currently valid.
+ new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
+ if new_alt_aliases:
+ for alias_str in new_alt_aliases:
+ room_alias = RoomAlias.from_string(alias_str)
+ mapping = yield directory_handler.get_association(room_alias)
+
+ if mapping["room_id"] != event.room_id:
+ raise SynapseError(
+ 400,
+ "Room alias %s does not point to the room"
+ % (room_alias_str,),
+ Codes.BAD_ALIAS,
+ )
+
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
+
def is_inviter_member_event(e):
- return (
- e.type == EventTypes.Member and
- e.sender == event.sender
- )
+ return e.type == EventTypes.Member and e.sender == event.sender
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
state_to_include_ids = [
e_id
for k, e_id in iteritems(current_state_ids)
- if k[0] in self.hs.config.room_invite_state_types
+ if k[0] in self.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
@@ -796,66 +977,75 @@ class EventCreationHandler(object):
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
- returned_invite = yield federation_handler.send_invite(
- invitee.domain,
- event,
+ returned_invite = yield defer.ensureDeferred(
+ federation_handler.send_invite(invitee.domain, event)
)
-
event.unsigned.pop("room_state", None)
# TODO: Make sure the signatures actually are correct.
- event.signatures.update(
- returned_invite.signatures
- )
+ event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ original_event = yield self.store.get_event(
+ event.redacts,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
+ get_prev_content=False,
+ allow_rejected=False,
+ allow_none=True,
+ )
+
+ # we can make some additional checks now if we have the original event.
+ if original_event:
+ if original_event.type == EventTypes.Create:
+ raise AuthError(403, "Redacting create events is not permitted")
+
+ if original_event.room_id != event.room_id:
+ raise SynapseError(400, "Cannot redact event from a different room")
+
+ prev_state_ids = yield context.get_prev_state_ids()
auth_events_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids, for_verification=True,
+ event, prev_state_ids, for_verification=True
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {
- (e.type, e.state_key): e for e in auth_events.values()
- }
- room_version = yield self.store.get_room_version(event.room_id)
- if self.auth.check_redaction(room_version, event, auth_events=auth_events):
- original_event = yield self.store.get_event(
- event.redacts,
- check_redacted=False,
- get_prev_content=False,
- allow_rejected=False,
- allow_none=False
- )
+ auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+
+ room_version = yield self.store.get_room_version_id(event.room_id)
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+ if event_auth.check_redaction(
+ room_version_obj, event, auth_events=auth_events
+ ):
+ # this user doesn't have 'redact' rights, so we need to do some more
+ # checks on the original event. Let's start by checking the original
+ # event exists.
+ if not original_event:
+ raise NotFoundError("Could not find event %s" % (event.redacts,))
+
if event.user_id != original_event.user_id:
- raise AuthError(
- 403,
- "You don't have permission to redact events"
- )
+ raise AuthError(403, "You don't have permission to redact events")
- # We've already checked.
+ # all the checks are done.
event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create:
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
if prev_state_ids:
- raise AuthError(
- 403,
- "Changing the room create event is forbidden",
- )
+ raise AuthError(403, "Changing the room create event is forbidden")
- (event_stream_id, max_stream_id) = yield self.store.persist_event(
+ event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
event, context=context
)
- yield self.pusher_pool.on_new_notifications(
- event_stream_id, max_stream_id,
- )
+ if self._ephemeral_events_enabled:
+ # If there's an expiry timestamp on the event, schedule its expiry.
+ self._message_handler.maybe_schedule_expiry(event)
+
+ yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():
try:
self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
+ event, event_stream_id, max_stream_id, extra_users=extra_users
)
except Exception:
logger.exception("Error notifying about new room event")
@@ -867,10 +1057,89 @@ class EventCreationHandler(object):
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
- @defer.inlineCallbacks
- def _bump_active_time(self, user):
+ async def _bump_active_time(self, user):
try:
presence = self.hs.get_presence_handler()
- yield presence.bump_presence_active_time(user)
+ await presence.bump_presence_active_time(user)
except Exception:
logger.exception("Error bumping presence active time")
+
+ @defer.inlineCallbacks
+ def _send_dummy_events_to_fill_extremities(self):
+ """Background task to send dummy events into rooms that have a large
+ number of extremities
+ """
+ self._expire_rooms_to_exclude_from_dummy_event_insertion()
+ room_ids = yield self.store.get_rooms_with_many_extremities(
+ min_count=10,
+ limit=5,
+ room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
+ )
+
+ for room_id in room_ids:
+ # For each room we need to find a joined member we can use to send
+ # the dummy event with.
+
+ latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
+
+ members = yield self.state.get_current_users_in_room(
+ room_id, latest_event_ids=latest_event_ids
+ )
+ dummy_event_sent = False
+ for user_id in members:
+ if not self.hs.is_mine_id(user_id):
+ continue
+ requester = create_requester(user_id)
+ try:
+ event, context = yield self.create_event(
+ requester,
+ {
+ "type": "org.matrix.dummy_event",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+
+ event.internal_metadata.proactively_send = False
+
+ yield self.send_nonmember_event(
+ requester, event, context, ratelimit=False
+ )
+ dummy_event_sent = True
+ break
+ except ConsentNotGivenError:
+ logger.info(
+ "Failed to send dummy event into room %s for user %s due to "
+ "lack of consent. Will try another user" % (room_id, user_id)
+ )
+ except AuthError:
+ logger.info(
+ "Failed to send dummy event into room %s for user %s due to "
+ "lack of power. Will try another user" % (room_id, user_id)
+ )
+
+ if not dummy_event_sent:
+ # Did not find a valid user in the room, so remove from future attempts
+ # Exclusion is time limited, so the room will be rechecked in the future
+ # dependent on _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
+ logger.info(
+ "Failed to send dummy event into room %s. Will exclude it from "
+ "future attempts until cache expires" % (room_id,)
+ )
+ now = self.clock.time_msec()
+ self._rooms_to_exclude_from_dummy_event_insertion[room_id] = now
+
+ def _expire_rooms_to_exclude_from_dummy_event_insertion(self):
+ expire_before = self.clock.time_msec() - _DUMMY_EVENT_ROOM_EXCLUSION_EXPIRY
+ to_expire = set()
+ for room_id, time in self._rooms_to_exclude_from_dummy_event_insertion.items():
+ if time < expire_before:
+ to_expire.add(room_id)
+ for room_id in to_expire:
+ logger.debug(
+ "Expiring room id %s from dummy event insertion exclusion cache",
+ room_id,
+ )
+ del self._rooms_to_exclude_from_dummy_event_insertion[room_id]
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 3cf783e3bd..d7442c62a7 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -22,11 +22,11 @@ from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
-from synapse.util.logcontext import run_in_background
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
@@ -58,9 +58,7 @@ class PurgeStatus(object):
self.status = PurgeStatus.STATUS_ACTIVE
def asdict(self):
- return {
- "status": PurgeStatus.STATUS_TEXT[self.status]
- }
+ return {"status": PurgeStatus.STATUS_TEXT[self.status]}
class PaginationHandler(object):
@@ -74,7 +72,10 @@ class PaginationHandler(object):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
self.clock = hs.get_clock()
+ self._server_name = hs.hostname
self.pagination_lock = ReadWriteLock()
self._purges_in_progress_by_room = set()
@@ -87,6 +88,8 @@ class PaginationHandler(object):
if hs.config.retention_enabled:
# Run the purge jobs described in the configuration file.
for job in hs.config.retention_purge_jobs:
+ logger.info("Setting up purge job with config: %s", job)
+
self.clock.looping_call(
run_as_background_process,
job["interval"],
@@ -129,11 +132,22 @@ class PaginationHandler(object):
else:
include_null = False
+ logger.info(
+ "[purge] Running purge job for %s < max_lifetime <= %s (include NULLs = %s)",
+ min_ms,
+ max_ms,
+ include_null,
+ )
+
rooms = yield self.store.get_rooms_for_retention_period_in_range(
min_ms, max_ms, include_null
)
+ logger.debug("[purge] Rooms to purge: %s", rooms)
+
for room_id, retention_policy in iteritems(rooms):
+ logger.info("[purge] Attempting to purge messages in room %s", room_id)
+
if room_id in self._purges_in_progress_by_room:
logger.warning(
"[purge] not purging room %s as there's an ongoing purge running"
@@ -153,20 +167,17 @@ class PaginationHandler(object):
# Figure out what token we should start purging at.
ts = self.clock.time_msec() - max_lifetime
- stream_ordering = (
- yield self.store.find_first_stream_ordering_after_ts(ts)
- )
+ stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts)
- r = (
- yield self.store.get_room_event_after_stream_ordering(
- room_id, stream_ordering,
- )
+ r = yield self.store.get_room_event_before_stream_ordering(
+ room_id, stream_ordering,
)
if not r:
logger.warning(
"[purge] purging events not possible: No event found "
"(ts %i => stream_ordering %i)",
- ts, stream_ordering,
+ ts,
+ stream_ordering,
)
continue
@@ -185,13 +196,10 @@ class PaginationHandler(object):
# the background so that it's not blocking any other operation apart from
# other purges in the same room.
run_as_background_process(
- "_purge_history",
- self._purge_history,
- purge_id, room_id, token, True,
+ "_purge_history", self._purge_history, purge_id, room_id, token, True,
)
- def start_purge_history(self, room_id, token,
- delete_local_events=False):
+ def start_purge_history(self, room_id, token, delete_local_events=False):
"""Start off a history purge on a room.
Args:
@@ -206,8 +214,7 @@ class PaginationHandler(object):
"""
if room_id in self._purges_in_progress_by_room:
raise SynapseError(
- 400,
- "History purge already in progress for %s" % (room_id, ),
+ 400, "History purge already in progress for %s" % (room_id,)
)
purge_id = random_string(16)
@@ -218,14 +225,12 @@ class PaginationHandler(object):
self._purges_by_id[purge_id] = PurgeStatus()
run_in_background(
- self._purge_history,
- purge_id, room_id, token, delete_local_events,
+ self._purge_history, purge_id, room_id, token, delete_local_events
)
return purge_id
@defer.inlineCallbacks
- def _purge_history(self, purge_id, room_id, token,
- delete_local_events):
+ def _purge_history(self, purge_id, room_id, token, delete_local_events):
"""Carry out a history purge on a room.
Args:
@@ -241,16 +246,15 @@ class PaginationHandler(object):
self._purges_in_progress_by_room.add(room_id)
try:
with (yield self.pagination_lock.write(room_id)):
- yield self.store.purge_history(
- room_id, token, delete_local_events,
+ yield self.storage.purge_events.purge_history(
+ room_id, token, delete_local_events
)
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
except Exception:
f = Failure()
logger.error(
- "[purge] failed",
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
)
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
@@ -259,6 +263,7 @@ class PaginationHandler(object):
# remove the purge from the list 24 hours after it completes
def clear_purge():
del self._purges_by_id[purge_id]
+
self.hs.get_reactor().callLater(24 * 3600, clear_purge)
def get_purge_status(self, purge_id):
@@ -272,9 +277,30 @@ class PaginationHandler(object):
"""
return self._purges_by_id.get(purge_id)
- @defer.inlineCallbacks
- def get_messages(self, requester, room_id=None, pagin_config=None,
- as_client_event=True, event_filter=None):
+ async def purge_room(self, room_id):
+ """Purge the given room from the database"""
+ with (await self.pagination_lock.write(room_id)):
+ # check we know about the room
+ await self.store.get_room_version_id(room_id)
+
+ # first check that we have no users in this room
+ joined = await defer.maybeDeferred(
+ self.store.is_host_joined, room_id, self._server_name
+ )
+
+ if joined:
+ raise SynapseError(400, "Users are still joined to this room")
+
+ await self.storage.purge_events.purge_room(room_id)
+
+ async def get_messages(
+ self,
+ requester,
+ room_id=None,
+ pagin_config=None,
+ as_client_event=True,
+ event_filter=None,
+ ):
"""Get messages in a room.
Args:
@@ -293,9 +319,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
- yield self.hs.get_event_sources().get_current_token_for_room(
- room_id=room_id
- )
+ await self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
@@ -307,18 +331,21 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
- with (yield self.pagination_lock.read(room_id)):
- membership, member_event_id = yield self.auth.check_in_room_or_world_readable(
- room_id, user_id
+ with (await self.pagination_lock.read(room_id)):
+ (
+ membership,
+ member_event_id,
+ ) = await self.auth.check_user_in_room_or_world_readable(
+ room_id, user_id, allow_departed_users=True
)
- if source_config.direction == 'b':
+ if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
# requires that we have a topo token.
if room_token.topological:
max_topo = room_token.topological
else:
- max_topo = yield self.store.get_max_topological_token(
+ max_topo = await self.store.get_max_topological_token(
room_id, room_token.stream
)
@@ -326,18 +353,18 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
- leave_token = yield self.store.get_topological_token_for_event(
+ leave_token = await self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
- yield self.hs.get_handlers().federation_handler.maybe_backfill(
+ await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
- events, next_key = yield self.store.paginate_room_events(
+ events, next_key = await self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
@@ -346,27 +373,22 @@ class PaginationHandler(object):
event_filter=event_filter,
)
- next_token = pagin_config.from_token.copy_and_replace(
- "room_key", next_key
- )
+ next_token = pagin_config.from_token.copy_and_replace("room_key", next_key)
if events:
if event_filter:
events = event_filter.filter(events)
- events = yield filter_events_for_client(
- self.store,
- user_id,
- events,
- is_peeking=(member_event_id is None),
+ events = await filter_events_for_client(
+ self.storage, user_id, events, is_peeking=(member_event_id is None)
)
if not events:
- defer.returnValue({
+ return {
"chunk": [],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
- })
+ }
state = None
if event_filter and event_filter.lazy_load_members() and len(events) > 0:
@@ -374,25 +396,23 @@ class PaginationHandler(object):
# FIXME: we also care about invite targets etc.
state_filter = StateFilter.from_types(
- (EventTypes.Member, event.sender)
- for event in events
+ (EventTypes.Member, event.sender) for event in events
)
- state_ids = yield self.store.get_state_ids_for_event(
- events[0].event_id, state_filter=state_filter,
+ state_ids = await self.state_store.get_state_ids_for_event(
+ events[0].event_id, state_filter=state_filter
)
if state_ids:
- state = yield self.store.get_events(list(state_ids.values()))
+ state = await self.store.get_events(list(state_ids.values()))
state = state.values()
time_now = self.clock.time_msec()
chunk = {
"chunk": (
- yield self._event_serializer.serialize_events(
- events, time_now,
- as_client_event=as_client_event,
+ await self._event_serializer.serialize_events(
+ events, time_now, as_client_event=as_client_event
)
),
"start": pagin_config.from_token.to_string(),
@@ -400,11 +420,8 @@ class PaginationHandler(object):
}
if state:
- chunk["state"] = (
- yield self._event_serializer.serialize_events(
- state, time_now,
- as_client_event=as_client_event,
- )
+ chunk["state"] = await self._event_serializer.serialize_events(
+ state, time_now, as_client_event=as_client_event
)
- defer.returnValue(chunk)
+ return chunk
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
index 9994b44455..d06b110269 100644
--- a/synapse/handlers/password_policy.py
+++ b/synapse/handlers/password_policy.py
@@ -57,8 +57,8 @@ class PasswordPolicyHandler(object):
)
if (
- self.policy.get("require_digit", False) and
- self.regexp_digit.search(password) is None
+ self.policy.get("require_digit", False)
+ and self.regexp_digit.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one digit",
@@ -66,8 +66,8 @@ class PasswordPolicyHandler(object):
)
if (
- self.policy.get("require_symbol", False) and
- self.regexp_symbol.search(password) is None
+ self.policy.get("require_symbol", False)
+ and self.regexp_symbol.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one symbol",
@@ -75,8 +75,8 @@ class PasswordPolicyHandler(object):
)
if (
- self.policy.get("require_uppercase", False) and
- self.regexp_uppercase.search(password) is None
+ self.policy.get("require_uppercase", False)
+ and self.regexp_uppercase.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one uppercase letter",
@@ -84,8 +84,8 @@ class PasswordPolicyHandler(object):
)
if (
- self.policy.get("require_lowercase", False) and
- self.regexp_lowercase.search(password) is None
+ self.policy.get("require_lowercase", False)
+ and self.regexp_lowercase.search(password) is None
):
raise PasswordRefusedError(
msg="The password must include at least one lowercase letter",
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 557fb5f83d..5526015ddb 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -24,42 +24,52 @@ The methods that define policy are:
import logging
from contextlib import contextmanager
+from typing import Dict, List, Set
from six import iteritems, itervalues
from prometheus_client import Counter
+from typing_extensions import ContextManager
from twisted.internet import defer
import synapse.metrics
from synapse.api.constants import EventTypes, Membership, PresenceState
from synapse.api.errors import SynapseError
+from synapse.logging.context import run_in_background
+from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import cachedInlineCallbacks
-from synapse.util.logcontext import run_in_background
-from synapse.util.logutils import log_function
+from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
+MYPY = False
+if MYPY:
+ import synapse.server
+
logger = logging.getLogger(__name__)
notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "")
federation_presence_out_counter = Counter(
- "synapse_handler_presence_federation_presence_out", "")
+ "synapse_handler_presence_federation_presence_out", ""
+)
presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "")
timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "")
-federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "")
+federation_presence_counter = Counter(
+ "synapse_handler_presence_federation_presence", ""
+)
bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "")
get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"])
notify_reason_counter = Counter(
- "synapse_handler_presence_notify_reason", "", ["reason"])
+ "synapse_handler_presence_notify_reason", "", ["reason"]
+)
state_transition_counter = Counter(
"synapse_handler_presence_state_transition", "", ["from", "to"]
)
@@ -90,15 +100,8 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object):
-
- def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.hs = hs
- self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
self.clock = hs.get_clock()
@@ -110,31 +113,26 @@ class PresenceHandler(object):
federation_registry = hs.get_federation_registry()
- federation_registry.register_edu_handler(
- "m.presence", self.incoming_presence
- )
+ federation_registry.register_edu_handler("m.presence", self.incoming_presence)
active_presence = self.store.take_presence_startup_info()
# A dictionary of the current state of users. This is prefilled with
# non-offline presence from the DB. We should fetch from the DB if
# we can't find a users presence in here.
- self.user_to_current_state = {
- state.user_id: state
- for state in active_presence
- }
+ self.user_to_current_state = {state.user_id: state for state in active_presence}
LaterGauge(
- "synapse_handlers_presence_user_to_current_state_size", "", [],
- lambda: len(self.user_to_current_state)
+ "synapse_handlers_presence_user_to_current_state_size",
+ "",
+ [],
+ lambda: len(self.user_to_current_state),
)
now = self.clock.time_msec()
for state in active_presence:
self.wheel_timer.insert(
- now=now,
- obj=state.user_id,
- then=state.last_active_ts + IDLE_TIMER,
+ now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
)
self.wheel_timer.insert(
now=now,
@@ -156,7 +154,7 @@ class PresenceHandler(object):
# Set of users who have presence in the `user_to_current_state` that
# have not yet been persisted
- self.unpersisted_users_changes = set()
+ self.unpersisted_users_changes = set() # type: Set[str]
hs.get_reactor().addSystemEventTrigger(
"before",
@@ -166,12 +164,11 @@ class PresenceHandler(object):
self._on_shutdown,
)
- self.serial_to_user = {}
self._next_serial = 1
# Keeps track of the number of *ongoing* syncs on this process. While
# this is non zero a user will never go offline.
- self.user_to_num_current_syncs = {}
+ self.user_to_num_current_syncs = {} # type: Dict[str, int]
# Keeps track of the number of *ongoing* syncs on other processes.
# While any sync is ongoing on another process the user will never
@@ -181,8 +178,9 @@ class PresenceHandler(object):
# we assume that all the sync requests on that process have stopped.
# Stored as a dict from process_id to set of user_id, and a dict of
# process_id to millisecond timestamp last updated.
- self.external_process_to_current_syncs = {}
- self.external_process_last_updated_ms = {}
+ self.external_process_to_current_syncs = {} # type: Dict[int, Set[str]]
+ self.external_process_last_updated_ms = {} # type: Dict[int, int]
+
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
# Start a LoopingCall in 30s that fires every 5s.
@@ -193,27 +191,21 @@ class PresenceHandler(object):
"handle_presence_timeouts", self._handle_timeouts
)
- self.clock.call_later(
- 30,
- self.clock.looping_call,
- run_timeout_handler,
- 5000,
- )
+ self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000)
def run_persister():
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
- self.clock.call_later(
- 60,
- self.clock.looping_call,
- run_persister,
- 60 * 1000,
- )
+ self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
- LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
- lambda: len(self.wheel_timer))
+ LaterGauge(
+ "synapse_handlers_presence_wheel_timer_size",
+ "",
+ [],
+ lambda: len(self.wheel_timer),
+ )
# Used to handle sending of presence to newly joined users/servers
if hs.config.use_presence:
@@ -224,8 +216,7 @@ class PresenceHandler(object):
self._event_pos = self.store.get_current_events_token()
self._event_processing = False
- @defer.inlineCallbacks
- def _on_shutdown(self):
+ async def _on_shutdown(self):
"""Gets called when shutting down. This lets us persist any updates that
we haven't yet persisted, e.g. updates that only changes some internal
timers. This allows changes to persist across startup without having to
@@ -236,24 +227,25 @@ class PresenceHandler(object):
is some spurious presence changes that will self-correct.
"""
# If the DB pool has already terminated, don't try updating
- if not self.hs.get_db_pool().running:
+ if not self.store.db.is_running():
return
logger.info(
"Performing _on_shutdown. Persisting %d unpersisted changes",
- len(self.user_to_current_state)
+ len(self.user_to_current_state),
)
if self.unpersisted_users_changes:
- yield self.store.update_presence([
- self.user_to_current_state[user_id]
- for user_id in self.unpersisted_users_changes
- ])
+ await self.store.update_presence(
+ [
+ self.user_to_current_state[user_id]
+ for user_id in self.unpersisted_users_changes
+ ]
+ )
logger.info("Finished _on_shutdown")
- @defer.inlineCallbacks
- def _persist_unpersisted_changes(self):
+ async def _persist_unpersisted_changes(self):
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
"""
@@ -261,16 +253,12 @@ class PresenceHandler(object):
self.unpersisted_users_changes = set()
if unpersisted:
- logger.info(
- "Persisting %d upersisted presence updates", len(unpersisted)
+ logger.info("Persisting %d unpersisted presence updates", len(unpersisted))
+ await self.store.update_presence(
+ [self.user_to_current_state[user_id] for user_id in unpersisted]
)
- yield self.store.update_presence([
- self.user_to_current_state[user_id]
- for user_id in unpersisted
- ])
- @defer.inlineCallbacks
- def _update_states(self, new_states):
+ async def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
should be sent to clients/servers.
@@ -279,7 +267,7 @@ class PresenceHandler(object):
with Measure(self.clock, "presence_update_states"):
- # NOTE: We purposefully don't yield between now and when we've
+ # NOTE: We purposefully don't await between now and when we've
# calculated what we want to do with the new states, to avoid races.
to_notify = {} # Changes we want to notify everyone about
@@ -303,10 +291,11 @@ class PresenceHandler(object):
)
new_state, should_notify, should_ping = handle_update(
- prev_state, new_state,
+ prev_state,
+ new_state,
is_mine=self.is_mine_id(user_id),
wheel_timer=self.wheel_timer,
- now=now
+ now=now,
)
self.user_to_current_state[user_id] = new_state
@@ -322,13 +311,14 @@ class PresenceHandler(object):
if to_notify:
notified_presence_counter.inc(len(to_notify))
- yield self._persist_and_notify(list(to_notify.values()))
+ await self._persist_and_notify(list(to_notify.values()))
- self.unpersisted_users_changes |= set(s.user_id for s in new_states)
+ self.unpersisted_users_changes |= {s.user_id for s in new_states}
self.unpersisted_users_changes -= set(to_notify.keys())
to_federation_ping = {
- user_id: state for user_id, state in to_federation_ping.items()
+ user_id: state
+ for user_id, state in to_federation_ping.items()
if user_id not in to_notify
}
if to_federation_ping:
@@ -336,11 +326,11 @@ class PresenceHandler(object):
self._push_to_remotes(to_federation_ping.values())
- def _handle_timeouts(self):
+ async def _handle_timeouts(self):
"""Checks the presence of users that have timed out and updates as
appropriate.
"""
- logger.info("Handling presence timeouts")
+ logger.debug("Handling presence timeouts")
now = self.clock.time_msec()
# Fetch the list of users that *may* have timed out. Things may have
@@ -351,20 +341,21 @@ class PresenceHandler(object):
# Check whether the lists of syncing processes from an external
# process have expired.
expired_process_ids = [
- process_id for process_id, last_update
- in self.external_process_last_updated_ms.items()
+ process_id
+ for process_id, last_update in self.external_process_last_updated_ms.items()
if now - last_update > EXTERNAL_PROCESS_EXPIRY
]
for process_id in expired_process_ids:
+ # For each expired process drop tracking info and check the users
+ # that were syncing on that process to see if they need to be timed
+ # out.
users_to_check.update(
- self.external_process_last_updated_ms.pop(process_id, ())
+ self.external_process_to_current_syncs.pop(process_id, ())
)
- self.external_process_last_update.pop(process_id)
+ self.external_process_last_updated_ms.pop(process_id)
states = [
- self.user_to_current_state.get(
- user_id, UserPresenceState.default(user_id)
- )
+ self.user_to_current_state.get(user_id, UserPresenceState.default(user_id))
for user_id in users_to_check
]
@@ -377,10 +368,9 @@ class PresenceHandler(object):
now=now,
)
- return self._update_states(changes)
+ return await self._update_states(changes)
- @defer.inlineCallbacks
- def bump_presence_active_time(self, user):
+ async def bump_presence_active_time(self, user):
"""We've seen the user do something that indicates they're interacting
with the app.
"""
@@ -392,18 +382,17 @@ class PresenceHandler(object):
bump_active_time_counter.inc()
- prev_state = yield self.current_state_for_user(user_id)
+ prev_state = await self.current_state_for_user(user_id)
- new_fields = {
- "last_active_ts": self.clock.time_msec(),
- }
+ new_fields = {"last_active_ts": self.clock.time_msec()}
if prev_state.state == PresenceState.UNAVAILABLE:
new_fields["state"] = PresenceState.ONLINE
- yield self._update_states([prev_state.copy_and_replace(**new_fields)])
+ await self._update_states([prev_state.copy_and_replace(**new_fields)])
- @defer.inlineCallbacks
- def user_syncing(self, user_id, affect_presence=True):
+ async def user_syncing(
+ self, user_id: str, affect_presence: bool = True
+ ) -> ContextManager[None]:
"""Returns a context manager that should surround any stream requests
from the user.
@@ -426,29 +415,40 @@ class PresenceHandler(object):
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
self.user_to_num_current_syncs[user_id] = curr_sync + 1
- prev_state = yield self.current_state_for_user(user_id)
+ prev_state = await self.current_state_for_user(user_id)
if prev_state.state == PresenceState.OFFLINE:
# If they're currently offline then bring them online, otherwise
# just update the last sync times.
- yield self._update_states([prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=self.clock.time_msec(),
- last_user_sync_ts=self.clock.time_msec(),
- )])
+ await self._update_states(
+ [
+ prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=self.clock.time_msec(),
+ last_user_sync_ts=self.clock.time_msec(),
+ )
+ ]
+ )
else:
- yield self._update_states([prev_state.copy_and_replace(
- last_user_sync_ts=self.clock.time_msec(),
- )])
+ await self._update_states(
+ [
+ prev_state.copy_and_replace(
+ last_user_sync_ts=self.clock.time_msec()
+ )
+ ]
+ )
- @defer.inlineCallbacks
- def _end():
+ async def _end():
try:
self.user_to_num_current_syncs[user_id] -= 1
- prev_state = yield self.current_state_for_user(user_id)
- yield self._update_states([prev_state.copy_and_replace(
- last_user_sync_ts=self.clock.time_msec(),
- )])
+ prev_state = await self.current_state_for_user(user_id)
+ await self._update_states(
+ [
+ prev_state.copy_and_replace(
+ last_user_sync_ts=self.clock.time_msec()
+ )
+ ]
+ )
except Exception:
logger.exception("Error updating presence after sync")
@@ -460,7 +460,7 @@ class PresenceHandler(object):
if affect_presence:
run_in_background(_end)
- defer.returnValue(_user_syncing())
+ return _user_syncing()
def get_currently_syncing_users(self):
"""Get the set of user ids that are currently syncing on this HS.
@@ -469,7 +469,8 @@ class PresenceHandler(object):
"""
if self.hs.config.use_presence:
syncing_user_ids = {
- user_id for user_id, count in self.user_to_num_current_syncs.items()
+ user_id
+ for user_id, count in self.user_to_num_current_syncs.items()
if count
}
for user_ids in self.external_process_to_current_syncs.values():
@@ -478,8 +479,9 @@ class PresenceHandler(object):
else:
return set()
- @defer.inlineCallbacks
- def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
+ async def update_external_syncs_row(
+ self, process_id, user_id, is_syncing, sync_time_msec
+ ):
"""Update the syncing users for an external process as a delta.
Args:
@@ -490,8 +492,8 @@ class PresenceHandler(object):
is_syncing (bool): Whether or not the user is now syncing
sync_time_msec(int): Time in ms when the user was last syncing
"""
- with (yield self.external_sync_linearizer.queue(process_id)):
- prev_state = yield self.current_state_for_user(user_id)
+ with (await self.external_sync_linearizer.queue(process_id)):
+ prev_state = await self.current_state_for_user(user_id)
process_presence = self.external_process_to_current_syncs.setdefault(
process_id, set()
@@ -500,60 +502,59 @@ class PresenceHandler(object):
updates = []
if is_syncing and user_id not in process_presence:
if prev_state.state == PresenceState.OFFLINE:
- updates.append(prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=sync_time_msec,
- last_user_sync_ts=sync_time_msec,
- ))
+ updates.append(
+ prev_state.copy_and_replace(
+ state=PresenceState.ONLINE,
+ last_active_ts=sync_time_msec,
+ last_user_sync_ts=sync_time_msec,
+ )
+ )
else:
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=sync_time_msec,
- ))
+ updates.append(
+ prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec)
+ )
process_presence.add(user_id)
elif user_id in process_presence:
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=sync_time_msec,
- ))
+ updates.append(
+ prev_state.copy_and_replace(last_user_sync_ts=sync_time_msec)
+ )
if not is_syncing:
process_presence.discard(user_id)
if updates:
- yield self._update_states(updates)
+ await self._update_states(updates)
self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
- @defer.inlineCallbacks
- def update_external_syncs_clear(self, process_id):
+ async def update_external_syncs_clear(self, process_id):
"""Marks all users that had been marked as syncing by a given process
as offline.
Used when the process has stopped/disappeared.
"""
- with (yield self.external_sync_linearizer.queue(process_id)):
+ with (await self.external_sync_linearizer.queue(process_id)):
process_presence = self.external_process_to_current_syncs.pop(
process_id, set()
)
- prev_states = yield self.current_state_for_users(process_presence)
+ prev_states = await self.current_state_for_users(process_presence)
time_now_ms = self.clock.time_msec()
- yield self._update_states([
- prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- )
- for prev_state in itervalues(prev_states)
- ])
+ await self._update_states(
+ [
+ prev_state.copy_and_replace(last_user_sync_ts=time_now_ms)
+ for prev_state in itervalues(prev_states)
+ ]
+ )
self.external_process_last_updated_ms.pop(process_id, None)
- @defer.inlineCallbacks
- def current_state_for_user(self, user_id):
+ async def current_state_for_user(self, user_id):
"""Get the current presence state for a user.
"""
- res = yield self.current_state_for_users([user_id])
- defer.returnValue(res[user_id])
+ res = await self.current_state_for_users([user_id])
+ return res[user_id]
- @defer.inlineCallbacks
- def current_state_for_users(self, user_ids):
+ async def current_state_for_users(self, user_ids):
"""Get the current presence state for multiple users.
Returns:
@@ -568,45 +569,46 @@ class PresenceHandler(object):
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
- res = yield self.store.get_presence_for_users(missing)
+ res = await self.store.get_presence_for_users(missing)
states.update(res)
missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
new = {
- user_id: UserPresenceState.default(user_id)
- for user_id in missing
+ user_id: UserPresenceState.default(user_id) for user_id in missing
}
states.update(new)
self.user_to_current_state.update(new)
- defer.returnValue(states)
+ return states
- @defer.inlineCallbacks
- def _persist_and_notify(self, states):
+ async def _persist_and_notify(self, states):
"""Persist states in the database, poke the notifier and send to
interested remote servers
"""
- stream_id, max_token = yield self.store.update_presence(states)
+ stream_id, max_token = await self.store.update_presence(states)
- parties = yield get_interested_parties(self.store, states)
+ parties = await get_interested_parties(self.store, states)
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key", stream_id, rooms=room_ids_to_states.keys(),
- users=[UserID.from_string(u) for u in users_to_states]
+ "presence_key",
+ stream_id,
+ rooms=room_ids_to_states.keys(),
+ users=[UserID.from_string(u) for u in users_to_states],
)
self._push_to_remotes(states)
- @defer.inlineCallbacks
- def notify_for_states(self, state, stream_id):
- parties = yield get_interested_parties(self.store, [state])
+ async def notify_for_states(self, state, stream_id):
+ parties = await get_interested_parties(self.store, [state])
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
- "presence_key", stream_id, rooms=room_ids_to_states.keys(),
- users=[UserID.from_string(u) for u in users_to_states]
+ "presence_key",
+ stream_id,
+ rooms=room_ids_to_states.keys(),
+ users=[UserID.from_string(u) for u in users_to_states],
)
def _push_to_remotes(self, states):
@@ -617,8 +619,7 @@ class PresenceHandler(object):
"""
self.federation.send_presence(states)
- @defer.inlineCallbacks
- def incoming_presence(self, origin, content):
+ async def incoming_presence(self, origin, content):
"""Called when we receive a `m.presence` EDU from a remote server.
"""
now = self.clock.time_msec()
@@ -631,15 +632,15 @@ class PresenceHandler(object):
user_id = push.get("user_id", None)
if not user_id:
logger.info(
- "Got presence update from %r with no 'user_id': %r",
- origin, push,
+ "Got presence update from %r with no 'user_id': %r", origin, push
)
continue
if get_domain_from_id(user_id) != origin:
logger.info(
"Got presence update from %r with bad 'user_id': %r",
- origin, user_id,
+ origin,
+ user_id,
)
continue
@@ -647,14 +648,12 @@ class PresenceHandler(object):
if not presence_state:
logger.info(
"Got presence update from %r with no 'presence_state': %r",
- origin, push,
+ origin,
+ push,
)
continue
- new_fields = {
- "state": presence_state,
- "last_federation_update_ts": now,
- }
+ new_fields = {"state": presence_state, "last_federation_update_ts": now}
last_active_ago = push.get("last_active_ago", None)
if last_active_ago is not None:
@@ -663,24 +662,19 @@ class PresenceHandler(object):
new_fields["status_msg"] = push.get("status_msg", None)
new_fields["currently_active"] = push.get("currently_active", False)
- prev_state = yield self.current_state_for_user(user_id)
+ prev_state = await self.current_state_for_user(user_id)
updates.append(prev_state.copy_and_replace(**new_fields))
if updates:
federation_presence_counter.inc(len(updates))
- yield self._update_states(updates)
+ await self._update_states(updates)
- @defer.inlineCallbacks
- def get_state(self, target_user, as_event=False):
- results = yield self.get_states(
- [target_user.to_string()],
- as_event=as_event,
- )
+ async def get_state(self, target_user, as_event=False):
+ results = await self.get_states([target_user.to_string()], as_event=as_event)
- defer.returnValue(results[0])
+ return results[0]
- @defer.inlineCallbacks
- def get_states(self, target_user_ids, as_event=False):
+ async def get_states(self, target_user_ids, as_event=False):
"""Get the presence state for users.
Args:
@@ -691,44 +685,43 @@ class PresenceHandler(object):
list
"""
- updates = yield self.current_state_for_users(target_user_ids)
+ updates = await self.current_state_for_users(target_user_ids)
updates = list(updates.values())
- for user_id in set(target_user_ids) - set(u.user_id for u in updates):
+ for user_id in set(target_user_ids) - {u.user_id for u in updates}:
updates.append(UserPresenceState.default(user_id))
now = self.clock.time_msec()
if as_event:
- defer.returnValue([
+ return [
{
"type": "m.presence",
"content": format_user_presence_state(state, now),
}
for state in updates
- ])
+ ]
else:
- defer.returnValue(updates)
+ return updates
- @defer.inlineCallbacks
- def set_state(self, target_user, state, ignore_status_msg=False):
+ async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
status_msg = state.get("status_msg", None)
presence = state["presence"]
valid_presence = (
- PresenceState.ONLINE, PresenceState.UNAVAILABLE, PresenceState.OFFLINE
+ PresenceState.ONLINE,
+ PresenceState.UNAVAILABLE,
+ PresenceState.OFFLINE,
)
if presence not in valid_presence:
raise SynapseError(400, "Invalid presence state")
user_id = target_user.to_string()
- prev_state = yield self.current_state_for_user(user_id)
+ prev_state = await self.current_state_for_user(user_id)
- new_fields = {
- "state": presence
- }
+ new_fields = {"state": presence}
if not ignore_status_msg:
msg = status_msg if presence != PresenceState.OFFLINE else None
@@ -737,26 +730,24 @@ class PresenceHandler(object):
if presence == PresenceState.ONLINE:
new_fields["last_active_ts"] = self.clock.time_msec()
- yield self._update_states([prev_state.copy_and_replace(**new_fields)])
+ await self._update_states([prev_state.copy_and_replace(**new_fields)])
- @defer.inlineCallbacks
- def is_visible(self, observed_user, observer_user):
+ async def is_visible(self, observed_user, observer_user):
"""Returns whether a user can see another user's presence.
"""
- observer_room_ids = yield self.store.get_rooms_for_user(
+ observer_room_ids = await self.store.get_rooms_for_user(
observer_user.to_string()
)
- observed_room_ids = yield self.store.get_rooms_for_user(
+ observed_room_ids = await self.store.get_rooms_for_user(
observed_user.to_string()
)
if observer_room_ids & observed_room_ids:
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
- @defer.inlineCallbacks
- def get_all_presence_updates(self, last_id, current_id):
+ async def get_all_presence_updates(self, last_id, current_id):
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@@ -771,8 +762,8 @@ class PresenceHandler(object):
"""
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
- rows = yield self.store.get_all_presence_updates(last_id, current_id)
- defer.returnValue(rows)
+ rows = await self.store.get_all_presence_updates(last_id, current_id)
+ return rows
def notify_new_event(self):
"""Called when new events have happened. Handles users and servers
@@ -782,38 +773,43 @@ class PresenceHandler(object):
if self._event_processing:
return
- @defer.inlineCallbacks
- def _process_presence():
+ async def _process_presence():
assert not self._event_processing
self._event_processing = True
try:
- yield self._unsafe_process()
+ await self._unsafe_process()
finally:
self._event_processing = False
run_as_background_process("presence.notify_new_event", _process_presence)
- @defer.inlineCallbacks
- def _unsafe_process(self):
+ async def _unsafe_process(self):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "presence_delta"):
- deltas = yield self.store.get_current_state_deltas(self._event_pos)
- if not deltas:
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+ if self._event_pos == room_max_stream_ordering:
return
- yield self._handle_state_delta(deltas)
+ logger.debug(
+ "Processing presence stats %s->%s",
+ self._event_pos,
+ room_max_stream_ordering,
+ )
+ max_pos, deltas = await self.store.get_current_state_deltas(
+ self._event_pos, room_max_stream_ordering
+ )
+ await self._handle_state_delta(deltas)
- self._event_pos = deltas[-1]["stream_id"]
+ self._event_pos = max_pos
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("presence").set(
- self._event_pos
+ max_pos
)
- @defer.inlineCallbacks
- def _handle_state_delta(self, deltas):
+ async def _handle_state_delta(self, deltas):
"""Process current state deltas to find new joins that need to be
handled.
"""
@@ -834,13 +830,13 @@ class PresenceHandler(object):
# joins.
continue
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
if not event or event.content.get("membership") != Membership.JOIN:
# We only care about joins
continue
if prev_event_id:
- prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ prev_event = await self.store.get_event(prev_event_id, allow_none=True)
if (
prev_event
and prev_event.content.get("membership") == Membership.JOIN
@@ -848,10 +844,9 @@ class PresenceHandler(object):
# Ignore changes to join events.
continue
- yield self._on_user_joined_room(room_id, state_key)
+ await self._on_user_joined_room(room_id, state_key)
- @defer.inlineCallbacks
- def _on_user_joined_room(self, room_id, user_id):
+ async def _on_user_joined_room(self, room_id, user_id):
"""Called when we detect a user joining the room via the current state
delta stream.
@@ -870,15 +865,14 @@ class PresenceHandler(object):
# TODO: We should be able to filter the hosts down to those that
# haven't previously seen the user
- state = yield self.current_state_for_user(user_id)
- hosts = yield self.state.get_current_hosts_in_room(room_id)
+ state = await self.current_state_for_user(user_id)
+ hosts = await self.state.get_current_hosts_in_room(room_id)
# Filter out ourselves.
- hosts = set(host for host in hosts if host != self.server_name)
+ hosts = {host for host in hosts if host != self.server_name}
self.federation.send_presence_to_destinations(
- states=[state],
- destinations=hosts,
+ states=[state], destinations=hosts
)
else:
# A remote user has joined the room, so we need to:
@@ -892,10 +886,10 @@ class PresenceHandler(object):
# TODO: Check that this is actually a new server joining the
# room.
- user_ids = yield self.state.get_current_users_in_room(room_id)
+ user_ids = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
- states = yield self.current_state_for_users(user_ids)
+ states = await self.current_state_for_users(user_ids)
# Filter out old presence, i.e. offline presence states where
# the user hasn't been active for a week. We can change this
@@ -904,7 +898,8 @@ class PresenceHandler(object):
# default state.
now = self.clock.time_msec()
states = [
- state for state in states.values()
+ state
+ for state in states.values()
if state.state != PresenceState.OFFLINE
or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
or state.status_msg is not None
@@ -912,8 +907,7 @@ class PresenceHandler(object):
if states:
self.federation.send_presence_to_destinations(
- states=states,
- destinations=[get_domain_from_id(user_id)],
+ states=states, destinations=[get_domain_from_id(user_id)]
)
@@ -937,7 +931,10 @@ def should_notify(old_state, new_state):
notify_reason_counter.labels("current_active_change").inc()
return True
- if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
+ if (
+ new_state.last_active_ts - old_state.last_active_ts
+ > LAST_ACTIVE_GRANULARITY
+ ):
# Only notify about last active bumps if we're not currently acive
if not new_state.currently_active:
notify_reason_counter.labels("last_active_change_online").inc()
@@ -958,9 +955,7 @@ def format_user_presence_state(state, now, include_user_id=True):
The "user_id" is optional so that this function can be used to format presence
updates for client /sync responses and for federation /send requests.
"""
- content = {
- "presence": state.state,
- }
+ content = {"presence": state.state}
if include_user_id:
content["user_id"] = state.user_id
if state.last_active_ts:
@@ -984,10 +979,16 @@ class PresenceEventSource(object):
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
- @defer.inlineCallbacks
@log_function
- def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
- explicit_room_id=None, **kwargs):
+ async def get_new_events(
+ self,
+ user,
+ from_key,
+ room_ids=None,
+ include_offline=True,
+ explicit_room_id=None,
+ **kwargs
+ ):
# The process for getting presence events are:
# 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms.
@@ -1004,12 +1005,29 @@ class PresenceEventSource(object):
if from_key is not None:
from_key = int(from_key)
+ max_token = self.store.get_current_presence_token()
+ if from_key == max_token:
+ # This is necessary as due to the way stream ID generators work
+ # we may get updates that have a stream ID greater than the max
+ # token (e.g. max_token is N but stream generator may return
+ # results for N+2, due to N+1 not having finished being
+ # persisted yet).
+ #
+ # This is usually fine, as it just means that we may send down
+ # some presence updates multiple times. However, we need to be
+ # careful that the sync stream either actually does make some
+ # progress or doesn't return, otherwise clients will end up
+ # tight looping calling /sync due to it immediately returning
+ # the same token repeatedly.
+ #
+ # Hence this guard where we just return nothing so that the sync
+ # doesn't return. C.f. #5503.
+ return [], max_token
+
presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache
- max_token = self.store.get_current_presence_token()
-
- users_interested_in = yield self._get_interested_in(user, explicit_room_id)
+ users_interested_in = await self._get_interested_in(user, explicit_room_id)
user_ids_changed = set()
changed = None
@@ -1030,29 +1048,29 @@ class PresenceEventSource(object):
if from_key:
user_ids_changed = stream_change_cache.get_entities_changed(
- users_interested_in, from_key,
+ users_interested_in, from_key
)
else:
user_ids_changed = users_interested_in
- updates = yield presence.current_state_for_users(user_ids_changed)
+ updates = await presence.current_state_for_users(user_ids_changed)
if include_offline:
- defer.returnValue((list(updates.values()), max_token))
+ return (list(updates.values()), max_token)
else:
- defer.returnValue(([
- s for s in itervalues(updates)
- if s.state != PresenceState.OFFLINE
- ], max_token))
+ return (
+ [s for s in itervalues(updates) if s.state != PresenceState.OFFLINE],
+ max_token,
+ )
def get_current_key(self):
return self.store.get_current_presence_token()
- def get_pagination_rows(self, user, pagination_config, key):
- return self.get_new_events(user, from_key=None, include_offline=False)
+ async def get_pagination_rows(self, user, pagination_config, key):
+ return await self.get_new_events(user, from_key=None, include_offline=False)
- @cachedInlineCallbacks(num_args=2, cache_context=True)
- def _get_interested_in(self, user, explicit_room_id, cache_context):
+ @cached(num_args=2, cache_context=True)
+ async def _get_interested_in(self, user, explicit_room_id, cache_context):
"""Returns the set of users that the given user should see presence
updates for
"""
@@ -1060,18 +1078,18 @@ class PresenceEventSource(object):
users_interested_in = set()
users_interested_in.add(user_id) # So that we receive our own presence
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id, on_invalidate=cache_context.invalidate,
+ users_who_share_room = await self.store.get_users_who_share_room_with_user(
+ user_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
- user_ids = yield self.store.get_users_in_room(
- explicit_room_id, on_invalidate=cache_context.invalidate,
+ user_ids = await self.store.get_users_in_room(
+ explicit_room_id, on_invalidate=cache_context.invalidate
)
users_interested_in.update(user_ids)
- defer.returnValue(users_interested_in)
+ return users_interested_in
def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
@@ -1123,9 +1141,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
if now - state.last_active_ts > IDLE_TIMER:
# Currently online, but last activity ages ago so auto
# idle
- state = state.copy_and_replace(
- state=PresenceState.UNAVAILABLE,
- )
+ state = state.copy_and_replace(state=PresenceState.UNAVAILABLE)
changed = True
elif now - state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# So that we send down a notification that we've
@@ -1145,8 +1161,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
sync_or_active = max(state.last_user_sync_ts, state.last_active_ts)
if now - sync_or_active > SYNC_ONLINE_TIMEOUT:
state = state.copy_and_replace(
- state=PresenceState.OFFLINE,
- status_msg=None,
+ state=PresenceState.OFFLINE, status_msg=None
)
changed = True
else:
@@ -1155,10 +1170,7 @@ def handle_timeout(state, is_mine, syncing_user_ids, now):
# no one gets stuck online forever.
if now - state.last_federation_update_ts > FEDERATION_TIMEOUT:
# The other side seems to have disappeared.
- state = state.copy_and_replace(
- state=PresenceState.OFFLINE,
- status_msg=None,
- )
+ state = state.copy_and_replace(state=PresenceState.OFFLINE, status_msg=None)
changed = True
return state if changed else None
@@ -1193,21 +1205,17 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
if new_state.state == PresenceState.ONLINE:
# Idle timer
wheel_timer.insert(
- now=now,
- obj=user_id,
- then=new_state.last_active_ts + IDLE_TIMER
+ now=now, obj=user_id, then=new_state.last_active_ts + IDLE_TIMER
)
active = now - new_state.last_active_ts < LAST_ACTIVE_GRANULARITY
- new_state = new_state.copy_and_replace(
- currently_active=active,
- )
+ new_state = new_state.copy_and_replace(currently_active=active)
if active:
wheel_timer.insert(
now=now,
obj=user_id,
- then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY
+ then=new_state.last_active_ts + LAST_ACTIVE_GRANULARITY,
)
if new_state.state != PresenceState.OFFLINE:
@@ -1215,29 +1223,25 @@ def handle_update(prev_state, new_state, is_mine, wheel_timer, now):
wheel_timer.insert(
now=now,
obj=user_id,
- then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT
+ then=new_state.last_user_sync_ts + SYNC_ONLINE_TIMEOUT,
)
last_federate = new_state.last_federation_update_ts
if now - last_federate > FEDERATION_PING_INTERVAL:
# Been a while since we've poked remote servers
- new_state = new_state.copy_and_replace(
- last_federation_update_ts=now,
- )
+ new_state = new_state.copy_and_replace(last_federation_update_ts=now)
federation_ping = True
else:
wheel_timer.insert(
now=now,
obj=user_id,
- then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT
+ then=new_state.last_federation_update_ts + FEDERATION_TIMEOUT,
)
# Check whether the change was something worth notifying about
if should_notify(prev_state, new_state):
- new_state = new_state.copy_and_replace(
- last_federation_update_ts=now,
- )
+ new_state = new_state.copy_and_replace(last_federation_update_ts=now)
persist_and_notify = True
return new_state, persist_and_notify, federation_ping
@@ -1255,8 +1259,8 @@ def get_interested_parties(store, states):
2-tuple: `(room_ids_to_states, users_to_states)`,
with each item being a dict of `entity_name` -> `[UserPresenceState]`
"""
- room_ids_to_states = {}
- users_to_states = {}
+ room_ids_to_states = {} # type: Dict[str, List[UserPresenceState]]
+ users_to_states = {} # type: Dict[str, List[UserPresenceState]]
for state in states:
room_ids = yield store.get_rooms_for_user(state.user_id)
for room_id in room_ids:
@@ -1265,7 +1269,7 @@ def get_interested_parties(store, states):
# Always notify self
users_to_states.setdefault(state.user_id, []).append(state)
- defer.returnValue((room_ids_to_states, users_to_states))
+ return room_ids_to_states, users_to_states
@defer.inlineCallbacks
@@ -1299,4 +1303,4 @@ def get_interested_remotes(store, states, state_handler):
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
- defer.returnValue(hosts_and_states)
+ return hosts_and_states
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 584f804986..e800504ea6 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -16,6 +16,7 @@
import logging
+from six import raise_from
from six.moves import range
from signedjson.sign import sign_json
@@ -24,20 +25,21 @@ from twisted.internet import defer, reactor
from synapse.api.errors import (
AuthError,
- CodeMessageException,
Codes,
+ HttpResponseException,
+ RequestSendFailed,
StoreError,
SynapseError,
)
+from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, get_domain_from_id
-from synapse.util.logcontext import run_in_background
+from synapse.types import UserID, create_requester, get_domain_from_id
from ._base import BaseHandler
logger = logging.getLogger(__name__)
-MAX_DISPLAYNAME_LEN = 100
+MAX_DISPLAYNAME_LEN = 256
MAX_AVATAR_URL_LEN = 1000
@@ -68,7 +70,7 @@ class BaseProfileHandler(BaseHandler):
if hs.config.worker_app is None:
self.clock.looping_call(
- self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
if len(self.hs.config.replicate_user_profiles_to) > 0:
@@ -112,7 +114,7 @@ class BaseProfileHandler(BaseHandler):
yield self._replicate_host_profile_batch(repl_host, i)
except Exception:
logger.exception(
- "Exception while replicating to %s: aborting for now", repl_host,
+ "Exception while replicating to %s: aborting for now", repl_host
)
@defer.inlineCallbacks
@@ -120,18 +122,16 @@ class BaseProfileHandler(BaseHandler):
logger.info("Replicating profile batch %d to %s", batchnum, host)
batch_rows = yield self.store.get_profile_batch(batchnum)
batch = {
- UserID(r["user_id"], self.hs.hostname).to_string(): ({
- "display_name": r["displayname"],
- "avatar_url": r["avatar_url"],
- } if r["active"] else None) for r in batch_rows
+ UserID(r["user_id"], self.hs.hostname).to_string(): (
+ {"display_name": r["displayname"], "avatar_url": r["avatar_url"]}
+ if r["active"]
+ else None
+ )
+ for r in batch_rows
}
url = "https://%s/_matrix/identity/api/v1/replicate_profiles" % (host,)
- body = {
- "batchnum": batchnum,
- "batch": batch,
- "origin_server": self.hs.hostname,
- }
+ body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
try:
yield self.http_client.post_json_get_json(url, signed_body)
@@ -139,7 +139,9 @@ class BaseProfileHandler(BaseHandler):
logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
except Exception:
# This will get retried when the looping call next comes around
- logger.exception("Failed to replicate profile batch %d to %s", batchnum, host)
+ logger.exception(
+ "Failed to replicate profile batch %d to %s", batchnum, host
+ )
raise
@defer.inlineCallbacks
@@ -159,25 +161,20 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue({
- "displayname": displayname,
- "avatar_url": avatar_url,
- })
+ return {"displayname": displayname, "avatar_url": avatar_url}
else:
try:
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
- args={
- "user_id": user_id,
- },
+ args={"user_id": user_id},
ignore_backoff=True,
)
- defer.returnValue(result)
- except CodeMessageException as e:
- if e.code != 404:
- logger.exception("Failed to get displayname")
- raise
+ return result
+ except RequestSendFailed as e:
+ raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
@defer.inlineCallbacks
def get_profile_from_cache(self, user_id):
@@ -199,13 +196,10 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue({
- "displayname": displayname,
- "avatar_url": avatar_url,
- })
+ return {"displayname": displayname, "avatar_url": avatar_url}
else:
profile = yield self.store.get_from_remote_profile_cache(user_id)
- defer.returnValue(profile or {})
+ return profile or {}
@defer.inlineCallbacks
def get_displayname(self, target_user):
@@ -219,24 +213,21 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue(displayname)
+ return displayname
else:
try:
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
- args={
- "user_id": target_user.to_string(),
- "field": "displayname",
- },
+ args={"user_id": target_user.to_string(), "field": "displayname"},
ignore_backoff=True,
)
- except CodeMessageException as e:
- if e.code != 404:
- logger.exception("Failed to get displayname")
- raise
+ except RequestSendFailed as e:
+ raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
- defer.returnValue(result["displayname"])
+ return result["displayname"]
@defer.inlineCallbacks
def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
@@ -249,7 +240,7 @@ class BaseProfileHandler(BaseHandler):
by_admin (bool): Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if not by_admin and requester and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
@@ -257,22 +248,32 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and self.hs.config.disable_set_displayname:
profile = yield self.store.get_profileinfo(target_user.localpart)
if profile.display_name:
- raise SynapseError(400, "Changing displayname is disabled on this server")
+ raise SynapseError(
+ 400, "Changing displayname is disabled on this server"
+ )
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
- 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ),
+ 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
)
- if new_displayname == '':
+ if new_displayname == "":
new_displayname = None
if len(self.hs.config.replicate_user_profiles_to) > 0:
- cur_batchnum = yield self.store.get_latest_profile_replication_batch_number()
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
else:
new_batchnum = None
+ # If the admin changes the display name of a user, the requesting user cannot send
+ # the join event to update the displayname in the rooms.
+ # This must be done by the target user himself.
+ if by_admin:
+ requester = create_requester(target_user)
+
yield self.store.set_profile_displayname(
target_user.localpart, new_displayname, new_batchnum
)
@@ -304,7 +305,9 @@ class BaseProfileHandler(BaseHandler):
where we've already done these checks anyway.
"""
if len(self.hs.config.replicate_user_profiles_to) > 0:
- cur_batchnum = yield self.store.get_latest_profile_replication_batch_number()
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
else:
new_batchnum = None
@@ -326,31 +329,28 @@ class BaseProfileHandler(BaseHandler):
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue(avatar_url)
+ return avatar_url
else:
try:
result = yield self.federation.make_query(
destination=target_user.domain,
query_type="profile",
- args={
- "user_id": target_user.to_string(),
- "field": "avatar_url",
- },
+ args={"user_id": target_user.to_string(), "field": "avatar_url"},
ignore_backoff=True,
)
- except CodeMessageException as e:
- if e.code != 404:
- logger.exception("Failed to get avatar_url")
- raise
+ except RequestSendFailed as e:
+ raise_from(SynapseError(502, "Failed to fetch profile"), e)
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
- defer.returnValue(result["avatar_url"])
+ return result["avatar_url"]
@defer.inlineCallbacks
def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
@@ -358,17 +358,21 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and self.hs.config.disable_set_avatar_url:
profile = yield self.store.get_profileinfo(target_user.localpart)
if profile.avatar_url:
- raise SynapseError(400, "Changing avatar url is disabled on this server")
+ raise SynapseError(
+ 400, "Changing avatar url is disabled on this server"
+ )
if len(self.hs.config.replicate_user_profiles_to) > 0:
- cur_batchnum = yield self.store.get_latest_profile_replication_batch_number()
+ cur_batchnum = (
+ yield self.store.get_latest_profile_replication_batch_number()
+ )
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
else:
new_batchnum = None
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
- 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ),
+ 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
)
# Enforce a max avatar size if one is defined
@@ -386,8 +390,10 @@ class BaseProfileHandler(BaseHandler):
media_size = media_info["media_length"]
if self.max_avatar_size and media_size > self.max_avatar_size:
raise SynapseError(
- 400, "Avatars must be less than %s bytes in size" %
- (self.max_avatar_size,), errcode=Codes.TOO_LARGE,
+ 400,
+ "Avatars must be less than %s bytes in size"
+ % (self.max_avatar_size,),
+ errcode=Codes.TOO_LARGE,
)
# Ensure the avatar's file type is allowed
@@ -396,12 +402,15 @@ class BaseProfileHandler(BaseHandler):
and media_info["media_type"] not in self.allowed_avatar_mimetypes
):
raise SynapseError(
- 400, "Avatar file type '%s' not allowed" %
- media_info["media_type"],
+ 400, "Avatar file type '%s' not allowed" % media_info["media_type"]
)
+ # Same like set_displayname
+ if by_admin:
+ requester = create_requester(target_user)
+
yield self.store.set_profile_avatar_url(
- target_user.localpart, new_avatar_url, new_batchnum,
+ target_user.localpart, new_avatar_url, new_batchnum
)
if self.hs.config.user_directory_search_all_users:
@@ -433,7 +442,7 @@ class BaseProfileHandler(BaseHandler):
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
just_field = args.get("field", None)
@@ -453,7 +462,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
def _update_join_states(self, requester, target_user):
@@ -462,9 +471,7 @@ class BaseProfileHandler(BaseHandler):
yield self.ratelimit(requester)
- room_ids = yield self.store.get_rooms_for_user(
- target_user.to_string(),
- )
+ room_ids = yield self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
handler = self.hs.get_room_member_handler()
@@ -479,15 +486,14 @@ class BaseProfileHandler(BaseHandler):
ratelimit=False, # Try to hide that these events aren't atomic.
)
except Exception as e:
- logger.warn(
- "Failed to update join event for room %s - %s",
- room_id, str(e)
+ logger.warning(
+ "Failed to update join event for room %s - %s", room_id, str(e)
)
@defer.inlineCallbacks
def check_profile_query_allowed(self, target_user, requester=None):
"""Checks whether a profile query is allowed. If the
- 'limit_profile_requests_to_known_users' config flag is set to True and a
+ 'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
share a room.
@@ -500,12 +506,16 @@ class BaseProfileHandler(BaseHandler):
be found to be in any room the server is in, and therefore the query
is denied.
"""
+
# Implementation of MSC1301: don't allow looking up profiles if the
# requester isn't in the same room as the target. We expect requester to
# be None when this function is called outside of a profile query, e.g.
# when building a membership event. In this case, we must allow the
# lookup.
- if not self.hs.config.limit_profile_requests_to_known_users or not requester:
+ if (
+ not self.hs.config.limit_profile_requests_to_users_who_share_rooms
+ or not requester
+ ):
return
# Always allow the user to query their own profile.
@@ -513,11 +523,9 @@ class BaseProfileHandler(BaseHandler):
return
try:
- requester_rooms = yield self.store.get_rooms_for_user(
- requester.to_string()
- )
+ requester_rooms = yield self.store.get_rooms_for_user(requester.to_string())
target_user_rooms = yield self.store.get_rooms_for_user(
- target_user.to_string(),
+ target_user.to_string()
)
# Check if the room lists have no elements in common.
@@ -541,12 +549,12 @@ class MasterProfileHandler(BaseProfileHandler):
assert hs.config.worker_app is None
self.clock.looping_call(
- self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS,
+ self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
def _start_update_remote_profile_cache(self):
return run_as_background_process(
- "Update remote profile", self._update_remote_profile_cache,
+ "Update remote profile", self._update_remote_profile_cache
)
@defer.inlineCallbacks
@@ -560,7 +568,7 @@ class MasterProfileHandler(BaseProfileHandler):
for user_id, displayname, avatar_url in entries:
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
- user_id,
+ user_id
)
if not is_subscribed:
yield self.store.maybe_delete_remote_profile_cache(user_id)
@@ -570,9 +578,7 @@ class MasterProfileHandler(BaseProfileHandler):
profile = yield self.federation.make_query(
destination=get_domain_from_id(user_id),
query_type="profile",
- args={
- "user_id": user_id,
- },
+ args={"user_id": user_id},
ignore_backoff=True,
)
except Exception:
@@ -587,6 +593,4 @@ class MasterProfileHandler(BaseProfileHandler):
new_avatar = profile.get("avatar_url")
# We always hit update to update the last_check timestamp
- yield self.store.update_remote_profile_cache(
- user_id, new_name, new_avatar
- )
+ yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index 32108568c6..e3b528d271 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
@@ -32,8 +30,7 @@ class ReadMarkerHandler(BaseHandler):
self.read_marker_linearizer = Linearizer(name="read_marker")
self.notifier = hs.get_notifier()
- @defer.inlineCallbacks
- def received_client_read_marker(self, room_id, user_id, event_id):
+ async def received_client_read_marker(self, room_id, user_id, event_id):
"""Updates the read marker for a given user in a given room if the event ID given
is ahead in the stream relative to the current read marker.
@@ -41,25 +38,22 @@ class ReadMarkerHandler(BaseHandler):
the read marker has changed.
"""
- with (yield self.read_marker_linearizer.queue((room_id, user_id))):
- existing_read_marker = yield self.store.get_account_data_for_room_and_type(
- user_id, room_id, "m.fully_read",
+ with await self.read_marker_linearizer.queue((room_id, user_id)):
+ existing_read_marker = await self.store.get_account_data_for_room_and_type(
+ user_id, room_id, "m.fully_read"
)
should_update = True
if existing_read_marker:
# Only update if the new marker is ahead in the stream
- should_update = yield self.store.is_event_after(
- event_id,
- existing_read_marker['event_id']
+ should_update = await self.store.is_event_after(
+ event_id, existing_read_marker["event_id"]
)
if should_update:
- content = {
- "event_id": event_id
- }
- max_id = yield self.store.add_account_data_to_room(
+ content = {"event_id": event_id}
+ max_id = await self.store.add_account_data_to_room(
user_id, room_id, "m.fully_read", content
)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 274d2946ad..8bc100db42 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
from synapse.handlers._base import BaseHandler
-from synapse.types import ReadReceipt
+from synapse.types import ReadReceipt, get_domain_from_id
+from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -36,34 +37,41 @@ class ReceiptsHandler(BaseHandler):
self.clock = self.hs.get_clock()
self.state = hs.get_state_handler()
- @defer.inlineCallbacks
- def _received_remote_receipt(self, origin, content):
+ async def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
- receipts = [
- ReadReceipt(
- room_id=room_id,
- receipt_type=receipt_type,
- user_id=user_id,
- event_ids=user_values["event_ids"],
- data=user_values.get("data", {}),
- )
- for room_id, room_values in content.items()
- for receipt_type, users in room_values.items()
- for user_id, user_values in users.items()
- ]
-
- yield self._handle_new_receipts(receipts)
-
- @defer.inlineCallbacks
- def _handle_new_receipts(self, receipts):
+ receipts = []
+ for room_id, room_values in content.items():
+ for receipt_type, users in room_values.items():
+ for user_id, user_values in users.items():
+ if get_domain_from_id(user_id) != origin:
+ logger.info(
+ "Received receipt for user %r from server %s, ignoring",
+ user_id,
+ origin,
+ )
+ continue
+
+ receipts.append(
+ ReadReceipt(
+ room_id=room_id,
+ receipt_type=receipt_type,
+ user_id=user_id,
+ event_ids=user_values["event_ids"],
+ data=user_values.get("data", {}),
+ )
+ )
+
+ await self._handle_new_receipts(receipts)
+
+ async def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier.
"""
min_batch_id = None
max_batch_id = None
for receipt in receipts:
- res = yield self.store.insert_receipt(
+ res = await self.store.insert_receipt(
receipt.room_id,
receipt.receipt_type,
receipt.user_id,
@@ -84,23 +92,21 @@ class ReceiptsHandler(BaseHandler):
if min_batch_id is None:
# no new receipts
- defer.returnValue(False)
+ return False
- affected_room_ids = list(set([r.room_id for r in receipts]))
+ affected_room_ids = list({r.room_id for r in receipts})
- self.notifier.on_new_event(
- "receipt_key", max_batch_id, rooms=affected_room_ids
- )
+ self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
- yield self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids,
+ await maybe_awaitable(
+ self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids
+ )
)
- defer.returnValue(True)
+ return True
- @defer.inlineCallbacks
- def received_client_receipt(self, room_id, receipt_type, user_id,
- event_id):
+ async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
@@ -109,30 +115,14 @@ class ReceiptsHandler(BaseHandler):
receipt_type=receipt_type,
user_id=user_id,
event_ids=[event_id],
- data={
- "ts": int(self.clock.time_msec()),
- },
+ data={"ts": int(self.clock.time_msec())},
)
- is_new = yield self._handle_new_receipts([receipt])
+ is_new = await self._handle_new_receipts([receipt])
if not is_new:
return
- yield self.federation.send_read_receipt(receipt)
-
- @defer.inlineCallbacks
- def get_receipts_for_room(self, room_id, to_key):
- """Gets all receipts for a room, upto the given key.
- """
- result = yield self.store.get_linearized_receipts_for_room(
- room_id,
- to_key=to_key,
- )
-
- if not result:
- defer.returnValue([])
-
- defer.returnValue(result)
+ await self.federation.send_read_receipt(receipt)
class ReceiptEventSource(object):
@@ -145,17 +135,15 @@ class ReceiptEventSource(object):
to_key = yield self.get_current_key()
if from_key == to_key:
- defer.returnValue(([], to_key))
+ return [], to_key
events = yield self.store.get_linearized_receipts_for_rooms(
- room_ids,
- from_key=from_key,
- to_key=to_key,
+ room_ids, from_key=from_key, to_key=to_key
)
- defer.returnValue((events, to_key))
+ return (events, to_key)
- def get_current_key(self, direction='f'):
+ def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
@defer.inlineCallbacks
@@ -169,9 +157,7 @@ class ReceiptEventSource(object):
room_ids = yield self.store.get_rooms_for_user(user.to_string())
events = yield self.store.get_linearized_receipts_for_rooms(
- room_ids,
- from_key=from_key,
- to_key=to_key,
+ room_ids, from_key=from_key, to_key=to_key
)
- defer.returnValue((events, to_key))
+ return (events, to_key)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7747964352..34b876b6af 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -20,17 +20,8 @@ from twisted.internet import defer
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, LoginType
-from synapse.api.errors import (
- AuthError,
- Codes,
- ConsentNotGivenError,
- InvalidCaptchaError,
- LimitExceededError,
- RegistrationError,
- SynapseError,
-)
+from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
from synapse.config.server import is_threepid_reserved
-from synapse.http.client import CaptchaServerHttpClient
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
from synapse.replication.http.register import (
@@ -39,7 +30,6 @@ from synapse.replication.http.register import (
)
from synapse.types import RoomAlias, RoomID, UserID, create_requester
from synapse.util.async_helpers import Linearizer
-from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler
@@ -47,7 +37,6 @@ logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
-
def __init__(self, hs):
"""
@@ -60,7 +49,6 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
- self.captcha_client = CaptchaServerHttpClient(hs)
self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
@@ -70,7 +58,7 @@ class RegistrationHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator()
self._generate_user_id_linearizer = Linearizer(
- name="_generate_user_id_linearizer",
+ name="_generate_user_id_linearizer"
)
self._server_notices_mxid = hs.config.server_notices_mxid
@@ -78,38 +66,33 @@ class RegistrationHandler(BaseHandler):
if hs.config.worker_app:
self._register_client = ReplicationRegisterServlet.make_client(hs)
- self._register_device_client = (
- RegisterDeviceReplicationServlet.make_client(hs)
+ self._register_device_client = RegisterDeviceReplicationServlet.make_client(
+ hs
)
- self._post_registration_client = (
- ReplicationPostRegisterActionsServlet.make_client(hs)
+ self._post_registration_client = ReplicationPostRegisterActionsServlet.make_client(
+ hs
)
else:
self.device_handler = hs.get_device_handler()
self.pusher_pool = hs.get_pusherpool()
+ self.session_lifetime = hs.config.session_lifetime
+
@defer.inlineCallbacks
- def check_username(self, localpart, guest_access_token=None,
- assigned_user_id=None):
+ def check_username(self, localpart, guest_access_token=None, assigned_user_id=None):
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
400,
"User ID can only contain characters a-z, 0-9, or '=_-./'",
- Codes.INVALID_USERNAME
+ Codes.INVALID_USERNAME,
)
if not localpart:
- raise SynapseError(
- 400,
- "User ID cannot be empty",
- Codes.INVALID_USERNAME
- )
+ raise SynapseError(400, "User ID cannot be empty", Codes.INVALID_USERNAME)
- if localpart[0] == '_':
+ if localpart[0] == "_":
raise SynapseError(
- 400,
- "User ID may not begin with _",
- Codes.INVALID_USERNAME
+ 400, "User ID may not begin with _", Codes.INVALID_USERNAME
)
user = UserID(localpart, self.hs.hostname)
@@ -129,19 +112,15 @@ class RegistrationHandler(BaseHandler):
if len(user_id) > MAX_USERID_LENGTH:
raise SynapseError(
400,
- "User ID may not be longer than %s characters" % (
- MAX_USERID_LENGTH,
- ),
- Codes.INVALID_USERNAME
+ "User ID may not be longer than %s characters" % (MAX_USERID_LENGTH,),
+ Codes.INVALID_USERNAME,
)
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
if not guest_access_token:
raise SynapseError(
- 400,
- "User ID already taken.",
- errcode=Codes.USER_IN_USE,
+ 400, "User ID already taken.", errcode=Codes.USER_IN_USE
)
user_data = yield self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
@@ -153,11 +132,10 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def register(
+ def register_user(
self,
localpart=None,
password=None,
- generate_token=True,
guest_access_token=None,
make_guest=False,
admin=False,
@@ -175,11 +153,6 @@ class RegistrationHandler(BaseHandler):
password (unicode) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
- generate_token (bool): Whether a new access token should be
- generated. Having this be True should be considered deprecated,
- since it offers no means of associating a device_id with the
- access_token. Instead you should call auth_handler.issue_access_token
- after registration.
user_type (str|None): type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
default_display_name (unicode|None): if set, the new user's displayname
@@ -187,17 +160,18 @@ class RegistrationHandler(BaseHandler):
address (str|None): the IP address used to perform the registration.
bind_emails (List[str]): list of emails to bind to this account.
Returns:
- A tuple of (user_id, access_token).
+ Deferred[str]: user_id
Raises:
- RegistrationError if there was a problem registering.
+ SynapseError if there was a problem registering.
"""
+ yield self.check_registration_ratelimit(address)
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
if password:
password_hash = yield self._auth_handler.hash(password)
- if localpart:
+ if localpart is not None:
yield self.check_username(localpart, guest_access_token=guest_access_token)
was_guest = guest_access_token is not None
@@ -205,9 +179,8 @@ class RegistrationHandler(BaseHandler):
if not was_guest:
try:
int(localpart)
- raise RegistrationError(
- 400,
- "Numeric user IDs are reserved for guest users."
+ raise SynapseError(
+ 400, "Numeric user IDs are reserved for guest users."
)
except ValueError:
pass
@@ -222,12 +195,8 @@ class RegistrationHandler(BaseHandler):
elif default_display_name is None:
default_display_name = localpart
- token = None
- if generate_token:
- token = self.macaroon_gen.generate_access_token(user_id)
yield self.register_with_store(
user_id=user_id,
- token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
@@ -239,7 +208,7 @@ class RegistrationHandler(BaseHandler):
if default_display_name:
yield self.profile_handler.set_displayname(
- user, None, default_display_name, by_admin=True,
+ user, None, default_display_name, by_admin=True
)
if self.hs.config.user_directory_search_all_users:
@@ -250,22 +219,22 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
- attempts = 0
- token = None
+ fail_count = 0
user = None
while not user:
- localpart = yield self._generate_user_id(attempts > 0)
+ # Fail after being unable to find a suitable ID a few times
+ if fail_count > 10:
+ raise SynapseError(500, "Unable to find a suitable guest user ID")
+
+ localpart = yield self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id)
- if generate_token:
- token = self.macaroon_gen.generate_access_token(user_id)
if default_display_name is None:
default_display_name = localpart
try:
yield self.register_with_store(
user_id=user_id,
- token=token,
password_hash=password_hash,
make_guest=make_guest,
create_profile_with_displayname=default_display_name,
@@ -273,17 +242,24 @@ class RegistrationHandler(BaseHandler):
)
yield self.profile_handler.set_displayname(
- user, None, default_display_name, by_admin=True,
+ user, None, default_display_name, by_admin=True
)
+ # Successfully registered
+ break
except SynapseError:
# if user id is taken, just generate another
user = None
user_id = None
- token = None
- attempts += 1
+ fail_count += 1
+
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
+ else:
+ logger.info(
+ "Skipping auto-join for %s because consent is required at registration",
+ user_id,
+ )
# Bind any specified emails to this account
current_time = self.hs.get_clock().time_msec()
@@ -296,19 +272,17 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- yield self._register_email_threepid(
- user_id, threepid_dict, None, False,
- )
+ yield self.register_email_threepid(user_id, threepid_dict, None)
# Prevent the new user from showing up in the user directory if the server
# mandates it.
if not self._show_in_user_directory:
yield self.store.add_account_data_for_user(
- user_id, "im.vector.hide_profile", {'hide_profile': True},
+ user_id, "im.vector.hide_profile", {"hide_profile": True}
)
yield self.profile_handler.set_active(user, False, True)
- defer.returnValue((user_id, token))
+ return user_id
@defer.inlineCallbacks
def _auto_join_rooms(self, user_id):
@@ -322,25 +296,22 @@ class RegistrationHandler(BaseHandler):
fake_requester = create_requester(user_id)
# try to create the room if we're the first real user on the server. Note
- # that an auto-generated support user is not a real user and will never be
+ # that an auto-generated support or bot user is not a real user and will never be
# the user to create the room
should_auto_create_rooms = False
- is_support = yield self.store.is_support_user(user_id)
- # There is an edge case where the first user is the support user, then
- # the room is never created, though this seems unlikely and
- # recoverable from given the support user being involved in the first
- # place.
- if self.hs.config.autocreate_auto_join_rooms and not is_support:
- count = yield self.store.count_all_users()
+ is_real_user = yield self.store.is_real_user(user_id)
+ if self.hs.config.autocreate_auto_join_rooms and is_real_user:
+ count = yield self.store.count_real_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms:
+ logger.info("Auto-joining %s to %s", user_id, r)
try:
if should_auto_create_rooms:
room_alias = RoomAlias.from_string(r)
if self.hs.hostname != room_alias.domain:
logger.warning(
- 'Cannot create room alias %s, '
- 'it does not match server domain',
+ "Cannot create room alias %s, "
+ "it does not match server domain",
r,
)
else:
@@ -353,7 +324,7 @@ class RegistrationHandler(BaseHandler):
fake_requester,
config={
"preset": "public_chat",
- "room_alias_name": room_alias_localpart
+ "room_alias_name": room_alias_localpart,
},
ratelimit=False,
)
@@ -387,8 +358,9 @@ class RegistrationHandler(BaseHandler):
raise AuthError(403, "Invalid application service token.")
if not service.is_interested_in_user(user_id):
raise SynapseError(
- 400, "Invalid user localpart for this application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "Invalid user localpart for this application service.",
+ errcode=Codes.EXCLUSIVE,
)
service_id = service.id if service.is_exclusive_user(user_id) else None
@@ -399,7 +371,7 @@ class RegistrationHandler(BaseHandler):
password_hash = ""
if password:
- password_hash = yield self.auth_handler().hash(password)
+ password_hash = yield self._auth_handler().hash(password)
display_name = display_name or user.localpart
@@ -411,7 +383,7 @@ class RegistrationHandler(BaseHandler):
)
yield self.profile_handler.set_displayname(
- user, None, display_name, by_admin=True,
+ user, None, display_name, by_admin=True
)
if self.hs.config.user_directory_search_all_users:
@@ -420,127 +392,30 @@ class RegistrationHandler(BaseHandler):
user_id, profile
)
- defer.returnValue(user_id)
-
- @defer.inlineCallbacks
- def check_recaptcha(self, ip, private_key, challenge, response):
- """
- Checks a recaptcha is correct.
-
- Used only by c/s api v1
- """
-
- captcha_response = yield self._validate_captcha(
- ip,
- private_key,
- challenge,
- response
- )
- if not captcha_response["valid"]:
- logger.info("Invalid captcha entered from %s. Error: %s",
- ip, captcha_response["error_url"])
- raise InvalidCaptchaError(
- error_url=captcha_response["error_url"]
- )
- else:
- logger.info("Valid captcha entered from %s", ip)
-
- @defer.inlineCallbacks
- def register_saml2(self, localpart):
- """
- Registers email_id as SAML2 Based Auth.
- """
- if types.contains_invalid_mxid_characters(localpart):
- raise SynapseError(
- 400,
- "User ID can only contain characters a-z, 0-9, or '=_-./'",
- )
- yield self.auth.check_auth_blocking()
- user = UserID(localpart, self.hs.hostname)
- user_id = user.to_string()
-
- yield self.check_user_id_not_appservice_exclusive(user_id)
- token = self.macaroon_gen.generate_access_token(user_id)
- try:
- yield self.register_with_store(
- user_id=user_id,
- token=token,
- password_hash=None,
- create_profile_with_displayname=user.localpart,
- )
-
- yield self.profile_handler.set_displayname(
- user, None, user.localpart, by_admin=True,
- )
- except Exception as e:
- yield self.store.add_access_token_to_user(user_id, token)
- # Ignore Registration errors
- logger.exception(e)
- defer.returnValue((user_id, token))
-
- @defer.inlineCallbacks
- def register_email(self, threepidCreds):
- """
- Registers emails with an identity server.
-
- Used only by c/s api v1
- """
-
- for c in threepidCreds:
- logger.info("validating threepidcred sid %s on id server %s",
- c['sid'], c['idServer'])
- try:
- threepid = yield self.identity_handler.threepid_from_creds(c)
- except Exception:
- logger.exception("Couldn't validate 3pid")
- raise RegistrationError(400, "Couldn't validate 3pid")
-
- if not threepid:
- raise RegistrationError(400, "Couldn't validate 3pid")
- logger.info("got threepid with medium '%s' and address '%s'",
- threepid['medium'], threepid['address'])
-
- if not (
- yield check_3pid_allowed(self.hs, threepid['medium'], threepid['address'])
- ):
- raise RegistrationError(
- 403, "Third party identifier is not allowed"
- )
-
- @defer.inlineCallbacks
- def bind_emails(self, user_id, threepidCreds):
- """Links emails with a user ID and informs an identity server.
-
- Used only by c/s api v1
- """
-
- # Now we have a matrix ID, bind it to the threepids we were given
- for c in threepidCreds:
- # XXX: This should be a deferred list, shouldn't it?
- yield self.identity_handler.bind_threepid(c, user_id)
+ return user_id
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
# don't allow people to register the server notices mxid
if self._server_notices_mxid is not None:
if user_id == self._server_notices_mxid:
raise SynapseError(
- 400, "This user ID is reserved.",
- errcode=Codes.EXCLUSIVE
+ 400, "This user ID is reserved.", errcode=Codes.EXCLUSIVE
)
# valid user IDs must not clash with any user ID namespaces claimed by
# application services.
services = self.store.get_app_services()
interested_services = [
- s for s in services
- if s.is_interested_in_user(user_id)
- and s != allowed_appservice
+ s
+ for s in services
+ if s.is_interested_in_user(user_id) and s != allowed_appservice
]
for service in interested_services:
if service.is_exclusive_user(user_id):
raise SynapseError(
- 400, "This user ID is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
+ 400,
+ "This user ID is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE,
)
@defer.inlineCallbacks
@@ -556,175 +431,53 @@ class RegistrationHandler(BaseHandler):
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.post_json_get_json(
- "%s/_matrix/client/r0/register?access_token=%s" % (
- shadow_hs_url, as_token,
- ),
+ "%s/_matrix/client/r0/register?access_token=%s" % (shadow_hs_url, as_token),
{
# XXX: auth_result is an unspecified extension for shadow registration
- 'auth_result': auth_result,
+ "auth_result": auth_result,
# XXX: another unspecified extension for shadow registration to ensure
# that the displayname is correctly set by the masters erver
- 'display_name': display_name,
- 'username': localpart,
- 'password': params.get("password"),
- 'bind_email': params.get("bind_email"),
- 'bind_msisdn': params.get("bind_msisdn"),
- 'device_id': params.get("device_id"),
- 'initial_device_display_name': params.get("initial_device_display_name"),
- 'inhibit_login': False,
- 'access_token': as_token,
- }
+ "display_name": display_name,
+ "username": localpart,
+ "password": params.get("password"),
+ "bind_msisdn": params.get("bind_msisdn"),
+ "device_id": params.get("device_id"),
+ "initial_device_display_name": params.get(
+ "initial_device_display_name"
+ ),
+ "inhibit_login": False,
+ "access_token": as_token,
+ },
)
@defer.inlineCallbacks
- def _generate_user_id(self, reseed=False):
- if reseed or self._next_generated_user_id is None:
+ def _generate_user_id(self):
+ if self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())):
- if reseed or self._next_generated_user_id is None:
+ if self._next_generated_user_id is None:
self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart()
)
id = self._next_generated_user_id
self._next_generated_user_id += 1
- defer.returnValue(str(id))
-
- @defer.inlineCallbacks
- def _validate_captcha(self, ip_addr, private_key, challenge, response):
- """Validates the captcha provided.
-
- Used only by c/s api v1
-
- Returns:
- dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.
-
- """
- response = yield self._submit_captcha(ip_addr, private_key, challenge,
- response)
- # parse Google's response. Lovely format..
- lines = response.split('\n')
- json = {
- "valid": lines[0] == 'true',
- "error_url": "http://www.recaptcha.net/recaptcha/api/challenge?" +
- "error=%s" % lines[1]
- }
- defer.returnValue(json)
-
- @defer.inlineCallbacks
- def _submit_captcha(self, ip_addr, private_key, challenge, response):
- """
- Used only by c/s api v1
- """
- data = yield self.captcha_client.post_urlencoded_get_raw(
- "http://www.recaptcha.net:80/recaptcha/api/verify",
- args={
- 'privatekey': private_key,
- 'remoteip': ip_addr,
- 'challenge': challenge,
- 'response': response
- }
- )
- defer.returnValue(data)
-
- @defer.inlineCallbacks
- def get_or_create_user(self, requester, localpart, displayname,
- password_hash=None):
- """Creates a new user if the user does not exist,
- else revokes all previous access tokens and generates a new one.
-
- Args:
- localpart : The local part of the user ID to register. If None,
- one will be randomly generated.
- Returns:
- A tuple of (user_id, access_token).
- Raises:
- RegistrationError if there was a problem registering.
-
- NB this is only used in tests. TODO: move it to the test package!
- """
- if localpart is None:
- raise SynapseError(400, "Request must include user id")
- yield self.auth.check_auth_blocking()
- need_register = True
-
- try:
- yield self.check_username(localpart)
- except SynapseError as e:
- if e.errcode == Codes.USER_IN_USE:
- need_register = False
- else:
- raise
-
- user = UserID(localpart, self.hs.hostname)
- user_id = user.to_string()
- token = self.macaroon_gen.generate_access_token(user_id)
-
- if need_register:
- yield self.register_with_store(
- user_id=user_id,
- token=token,
- password_hash=password_hash,
- create_profile_with_displayname=displayname or user.localpart,
- )
- if displayname is not None:
- yield self.profile_handler.set_displayname(
- user, None, displayname or user.localpart, by_admin=True,
- )
- else:
- yield self._auth_handler.delete_access_tokens_for_user(user_id)
- yield self.store.add_access_token_to_user(user_id=user_id, token=token)
-
- defer.returnValue((user_id, token))
-
- @defer.inlineCallbacks
- def get_or_register_3pid_guest(self, medium, address, inviter_user_id):
- """Get a guest access token for a 3PID, creating a guest account if
- one doesn't already exist.
-
- Args:
- medium (str)
- address (str)
- inviter_user_id (str): The user ID who is trying to invite the
- 3PID
-
- Returns:
- Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
- 3PID guest account.
- """
- access_token = yield self.store.get_3pid_guest_access_token(medium, address)
- if access_token:
- user_info = yield self.auth.get_user_by_access_token(
- access_token
- )
-
- defer.returnValue((user_info["user"].to_string(), access_token))
-
- user_id, access_token = yield self.register(
- generate_token=True,
- make_guest=True
- )
- access_token = yield self.store.save_or_get_3pid_guest_access_token(
- medium, address, access_token, inviter_user_id
- )
-
- defer.returnValue((user_id, access_token))
+ return str(id)
@defer.inlineCallbacks
def _join_user_to_room(self, requester, room_identifier):
- room_id = None
room_member_handler = self.hs.get_room_member_handler()
if RoomID.is_valid(room_identifier):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = (
- yield room_member_handler.lookup_room_alias(room_alias)
+ room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
+ room_alias
)
room_id = room_id.to_string()
else:
- raise SynapseError(400, "%s was not legal room ID or room alias" % (
- room_identifier,
- ))
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
yield room_member_handler.update_membership(
requester=requester,
@@ -735,17 +488,45 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
- def register_with_store(self, user_id, token=None, password_hash=None,
- was_guest=False, make_guest=False, appservice_id=None,
- create_profile_with_displayname=None, admin=False,
- user_type=None, address=None):
+ def check_registration_ratelimit(self, address):
+ """A simple helper method to check whether the registration rate limit has been hit
+ for a given IP address
+
+ Args:
+ address (str|None): the IP address used to perform the registration. If this is
+ None, no ratelimiting will be performed.
+
+ Raises:
+ LimitExceededError: If the rate limit has been exceeded.
+ """
+ if not address:
+ return
+
+ time_now = self.clock.time()
+
+ self.ratelimiter.ratelimit(
+ address,
+ time_now_s=time_now,
+ rate_hz=self.hs.config.rc_registration.per_second,
+ burst_count=self.hs.config.rc_registration.burst_count,
+ )
+
+ def register_with_store(
+ self,
+ user_id,
+ password_hash=None,
+ was_guest=False,
+ make_guest=False,
+ appservice_id=None,
+ create_profile_with_displayname=None,
+ admin=False,
+ user_type=None,
+ address=None,
+ ):
"""Register user in the datastore.
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user. If this
- is not None, the given access token is associated with the user
- id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -762,25 +543,9 @@ class RegistrationHandler(BaseHandler):
Returns:
Deferred
"""
- # Don't rate limit for app services
- if appservice_id is None and address is not None:
- time_now = self.clock.time()
-
- allowed, time_allowed = self.ratelimiter.can_do_action(
- address, time_now_s=time_now,
- rate_hz=self.hs.config.rc_registration.per_second,
- burst_count=self.hs.config.rc_registration.burst_count,
- )
-
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now)),
- )
-
if self.hs.config.worker_app:
return self._register_client(
user_id=user_id,
- token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
@@ -791,9 +556,8 @@ class RegistrationHandler(BaseHandler):
address=address,
)
else:
- return self.store.register(
+ return self.store.register_user(
user_id=user_id,
- token=token,
password_hash=password_hash,
was_guest=was_guest,
make_guest=make_guest,
@@ -804,10 +568,11 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
- def register_device(self, user_id, device_id, initial_display_name,
- is_guest=False):
+ def register_device(self, user_id, device_id, initial_display_name, is_guest=False):
"""Register a device for a user and generate an access token.
+ The access token will be limited by the homeserver's session_lifetime config.
+
Args:
user_id (str): full canonical @user:id
device_id (str|None): The device ID to check, or None to generate
@@ -827,25 +592,35 @@ class RegistrationHandler(BaseHandler):
initial_display_name=initial_display_name,
is_guest=is_guest,
)
- defer.returnValue((r["device_id"], r["access_token"]))
- else:
- device_id = yield self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
- )
+ return r["device_id"], r["access_token"]
+
+ valid_until_ms = None
+ if self.session_lifetime is not None:
if is_guest:
- access_token = self.macaroon_gen.generate_access_token(
- user_id, ["guest = true"]
- )
- else:
- access_token = yield self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id,
+ raise Exception(
+ "session_lifetime is not currently implemented for guest access"
)
+ valid_until_ms = self.clock.time_msec() + self.session_lifetime
+
+ device_id = yield self.device_handler.check_device_registered(
+ user_id, device_id, initial_display_name
+ )
+ if is_guest:
+ assert valid_until_ms is None
+ access_token = self.macaroon_gen.generate_access_token(
+ user_id, ["guest = true"]
+ )
+ else:
+ access_token = yield self._auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id, valid_until_ms=valid_until_ms
+ )
- defer.returnValue((device_id, access_token))
+ return (device_id, access_token)
@defer.inlineCallbacks
- def post_registration_actions(self, user_id, auth_result, access_token,
- bind_email, bind_msisdn):
+ def post_registration_actions(
+ self, user_id, auth_result, access_token,
+ ):
"""A user has completed registration
Args:
@@ -854,18 +629,10 @@ class RegistrationHandler(BaseHandler):
registered user.
access_token (str|None): The access token of the newly logged in
device, or None if `inhibit_login` enabled.
- bind_email (bool): Whether to bind the email with the identity
- server.
- bind_msisdn (bool): Whether to bind the msisdn with the identity
- server.
"""
if self.hs.config.worker_app:
yield self._post_registration_client(
- user_id=user_id,
- auth_result=auth_result,
- access_token=access_token,
- bind_email=bind_email,
- bind_msisdn=bind_msisdn,
+ user_id=user_id, auth_result=auth_result, access_token=access_token
)
return
@@ -878,21 +645,43 @@ class RegistrationHandler(BaseHandler):
):
yield self.store.upsert_monthly_active_user(user_id)
- yield self._register_email_threepid(
- user_id, threepid, access_token,
- bind_email,
- )
+ yield self.register_email_threepid(user_id, threepid, access_token)
+
+ if self.hs.config.account_threepid_delegate_email:
+ # Bind the 3PID to the identity server
+ logger.debug(
+ "Binding email to %s on id_server %s",
+ user_id,
+ self.hs.config.account_threepid_delegate_email,
+ )
+ threepid_creds = threepid["threepid_creds"]
+
+ # Remove the protocol scheme before handling to `bind_threepid`
+ # `bind_threepid` will add https:// to it, so this restricts
+ # account_threepid_delegate.email to https:// addresses only
+ # We assume this is always the case for dinsic however.
+ if self.hs.config.account_threepid_delegate_email.startswith(
+ "https://"
+ ):
+ id_server = self.hs.config.account_threepid_delegate_email[8:]
+ else:
+ # Must start with http:// instead
+ id_server = self.hs.config.account_threepid_delegate_email[7:]
+
+ yield self.identity_handler.bind_threepid(
+ threepid_creds["client_secret"],
+ threepid_creds["sid"],
+ user_id,
+ id_server,
+ threepid_creds.get("id_access_token"),
+ )
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
- yield self._register_msisdn_threepid(
- user_id, threepid, bind_msisdn,
- )
+ yield self._register_msisdn_threepid(user_id, threepid)
if auth_result and LoginType.TERMS in auth_result:
- yield self._on_user_consented(
- user_id, self.hs.config.user_consent_version,
- )
+ yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
@defer.inlineCallbacks
def _on_user_consented(self, user_id, consent_version):
@@ -904,20 +693,16 @@ class RegistrationHandler(BaseHandler):
consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
- yield self.store.user_set_consent_version(
- user_id, consent_version,
- )
+ yield self.store.user_set_consent_version(user_id, consent_version)
yield self.post_consent_actions(user_id)
@defer.inlineCallbacks
- def _register_email_threepid(self, user_id, threepid, token, bind_email):
+ def register_email_threepid(self, user_id, threepid, token):
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
HS config
- Also optionally binds emails to the given user_id on the identity server
-
Must be called on master.
Args:
@@ -925,38 +710,33 @@ class RegistrationHandler(BaseHandler):
threepid (object): m.login.email.identity auth response
token (str|None): access_token for the user, or None if not logged
in.
- bind_email (bool): true if the client requested the email to be
- bound at the identity server
Returns:
defer.Deferred:
"""
- reqd = ('medium', 'address', 'validated_at')
+ reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
return
yield self._auth_handler.add_threepid(
- user_id,
- threepid['medium'],
- threepid['address'],
- threepid['validated_at'],
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
)
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
- # notifs are set up on a home server)
- if (self.hs.config.email_enable_notifs and
- self.hs.config.email_notif_for_new_users
- and token):
+ # notifs are set up on a homeserver)
+ if (
+ self.hs.config.email_enable_notifs
+ and self.hs.config.email_notif_for_new_users
+ and token
+ ):
# Pull the ID of the access token back out of the db
# It would really make more sense for this to be passed
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
- user_tuple = yield self.store.get_user_by_access_token(
- token
- )
+ user_tuple = yield self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"]
yield self.pusher_pool.add_pusher(
@@ -971,55 +751,27 @@ class RegistrationHandler(BaseHandler):
data={},
)
- if bind_email:
- logger.info("bind_email specified: binding")
- logger.debug("Binding emails %s to %s" % (
- threepid, user_id
- ))
- yield self.identity_handler.bind_threepid(
- threepid['threepid_creds'], user_id
- )
- else:
- logger.info("bind_email not specified: not binding email")
-
@defer.inlineCallbacks
- def _register_msisdn_threepid(self, user_id, threepid, bind_msisdn):
+ def _register_msisdn_threepid(self, user_id, threepid):
"""Add a phone number as a 3pid identifier
- Also optionally binds msisdn to the given user_id on the identity server
-
Must be called on master.
Args:
user_id (str): id of user
threepid (object): m.login.msisdn auth response
- token (str): access_token for the user
- bind_email (bool): true if the client requested the email to be
- bound at the identity server
Returns:
defer.Deferred:
"""
try:
- assert_params_in_dict(threepid, ['medium', 'address', 'validated_at'])
+ assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
except SynapseError as ex:
if ex.errcode == Codes.MISSING_PARAM:
# This will only happen if the ID server returns a malformed response
logger.info("Can't add incomplete 3pid")
- defer.returnValue(None)
+ return None
raise
yield self._auth_handler.add_threepid(
- user_id,
- threepid['medium'],
- threepid['address'],
- threepid['validated_at'],
+ user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
)
-
- if bind_msisdn:
- logger.info("bind_msisdn specified: binding")
- logger.debug("Binding msisdn %s to %s", threepid, user_id)
- yield self.identity_handler.bind_threepid(
- threepid['threepid_creds'], user_id
- )
- else:
- logger.info("bind_msisdn not specified: not binding msisdn")
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7c24f9aac3..ee9fc296e1 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
# limitations under the License.
"""Contains functions for performing events on rooms."""
+
import itertools
import logging
import math
@@ -27,11 +29,22 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
-from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events.utils import copy_power_levels_contents
+from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
-from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
+from synapse.types import (
+ Requester,
+ RoomAlias,
+ RoomID,
+ RoomStreamToken,
+ StateMap,
+ StreamToken,
+ UserID,
+)
from synapse.util import stringutils
from synapse.util.async_helpers import Linearizer
+from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -40,6 +53,8 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://"
+FIVE_MINUTES_IN_MS = 5 * 60 * 1000
+
class RoomCreationHandler(BaseHandler):
@@ -50,6 +65,7 @@ class RoomCreationHandler(BaseHandler):
"original_invitees_have_ops": False,
"guest_can_join": True,
"encryption_alg": "m.megolm.v1.aes-sha2",
+ "power_level_content_override": {"invite": 0},
},
RoomCreationPreset.TRUSTED_PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
@@ -57,12 +73,14 @@ class RoomCreationHandler(BaseHandler):
"original_invitees_have_ops": True,
"guest_can_join": True,
"encryption_alg": "m.megolm.v1.aes-sha2",
+ "power_level_content_override": {"invite": 0},
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
"history_visibility": "shared",
"original_invitees_have_ops": False,
"guest_can_join": False,
+ "power_level_content_override": {},
},
}
@@ -77,18 +95,26 @@ class RoomCreationHandler(BaseHandler):
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
+ # If a user tries to update the same room multiple times in quick
+ # succession, only process the first attempt and return its result to
+ # subsequent requests
+ self._upgrade_response_cache = ResponseCache(
+ hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
+ )
self._server_notices_mxid = hs.config.server_notices_mxid
self.third_party_event_rules = hs.get_third_party_event_rules()
@defer.inlineCallbacks
- def upgrade_room(self, requester, old_room_id, new_version):
+ def upgrade_room(
+ self, requester: Requester, old_room_id: str, new_version: RoomVersion
+ ):
"""Replace a room with a new room with a different version
Args:
- requester (synapse.types.Requester): the user requesting the upgrade
- old_room_id (unicode): the id of the room to be replaced
- new_version (unicode): the new room version to use
+ requester: the user requesting the upgrade
+ old_room_id: the id of the room to be replaced
+ new_version: the new room version to use
Returns:
Deferred[unicode]: the new room id
@@ -97,78 +123,121 @@ class RoomCreationHandler(BaseHandler):
user_id = requester.user.to_string()
- with (yield self._upgrade_linearizer.queue(old_room_id)):
- # start by allocating a new room id
- r = yield self.store.get_room(old_room_id)
- if r is None:
- raise NotFoundError("Unknown room id %s" % (old_room_id,))
- new_room_id = yield self._generate_room_id(
- creator_id=user_id, is_public=r["is_public"],
- )
+ # Check if this room is already being upgraded by another person
+ for key in self._upgrade_response_cache.pending_result_cache:
+ if key[0] == old_room_id and key[1] != user_id:
+ # Two different people are trying to upgrade the same room.
+ # Send the second an error.
+ #
+ # Note that this of course only gets caught if both users are
+ # on the same homeserver.
+ raise SynapseError(
+ 400, "An upgrade for this room is currently in progress"
+ )
- logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
+ # Upgrade the room
+ #
+ # If this user has sent multiple upgrade requests for the same room
+ # and one of them is not complete yet, cache the response and
+ # return it to all subsequent requests
+ ret = yield self._upgrade_response_cache.wrap(
+ (old_room_id, user_id),
+ self._upgrade_room,
+ requester,
+ old_room_id,
+ new_version, # args for _upgrade_room
+ )
- # we create and auth the tombstone event before properly creating the new
- # room, to check our user has perms in the old room.
- tombstone_event, tombstone_context = (
- yield self.event_creation_handler.create_event(
- requester, {
- "type": EventTypes.Tombstone,
- "state_key": "",
- "room_id": old_room_id,
- "sender": user_id,
- "content": {
- "body": "This room has been replaced",
- "replacement_room": new_room_id,
- }
- },
- token_id=requester.access_token_id,
- )
- )
- old_room_version = yield self.store.get_room_version(old_room_id)
- yield self.auth.check_from_context(
- old_room_version, tombstone_event, tombstone_context,
- )
+ return ret
- yield self.clone_existing_room(
- requester,
- old_room_id=old_room_id,
- new_room_id=new_room_id,
- new_room_version=new_version,
- tombstone_event_id=tombstone_event.event_id,
- )
+ @defer.inlineCallbacks
+ def _upgrade_room(
+ self, requester: Requester, old_room_id: str, new_version: RoomVersion
+ ):
+ user_id = requester.user.to_string()
- # now send the tombstone
- yield self.event_creation_handler.send_nonmember_event(
- requester, tombstone_event, tombstone_context,
- )
+ # start by allocating a new room id
+ r = yield self.store.get_room(old_room_id)
+ if r is None:
+ raise NotFoundError("Unknown room id %s" % (old_room_id,))
+ new_room_id = yield self._generate_room_id(
+ creator_id=user_id, is_public=r["is_public"], room_version=new_version,
+ )
- old_room_state = yield tombstone_context.get_current_state_ids(self.store)
+ logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
- # update any aliases
- yield self._move_aliases_to_new_room(
- requester, old_room_id, new_room_id, old_room_state,
- )
+ # we create and auth the tombstone event before properly creating the new
+ # room, to check our user has perms in the old room.
+ (
+ tombstone_event,
+ tombstone_context,
+ ) = yield self.event_creation_handler.create_event(
+ requester,
+ {
+ "type": EventTypes.Tombstone,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": user_id,
+ "content": {
+ "body": "This room has been replaced",
+ "replacement_room": new_room_id,
+ },
+ },
+ token_id=requester.access_token_id,
+ )
+ old_room_version = yield self.store.get_room_version_id(old_room_id)
+ yield self.auth.check_from_context(
+ old_room_version, tombstone_event, tombstone_context
+ )
- # and finally, shut down the PLs in the old room, and update them in the new
- # room.
- yield self._update_upgraded_room_pls(
- requester, old_room_id, new_room_id, old_room_state,
- )
+ yield self.clone_existing_room(
+ requester,
+ old_room_id=old_room_id,
+ new_room_id=new_room_id,
+ new_room_version=new_version,
+ tombstone_event_id=tombstone_event.event_id,
+ )
+
+ # now send the tombstone
+ yield self.event_creation_handler.send_nonmember_event(
+ requester, tombstone_event, tombstone_context
+ )
- defer.returnValue(new_room_id)
+ old_room_state = yield tombstone_context.get_current_state_ids()
+
+ # update any aliases
+ yield self._move_aliases_to_new_room(
+ requester, old_room_id, new_room_id, old_room_state
+ )
+
+ # Copy over user push rules, tags and migrate room directory state
+ yield self.room_member_handler.transfer_room_state_on_room_upgrade(
+ old_room_id, new_room_id
+ )
+
+ # finally, shut down the PLs in the old room, and update them in the new
+ # room.
+ yield self._update_upgraded_room_pls(
+ requester, old_room_id, new_room_id, old_room_state,
+ )
+
+ return new_room_id
@defer.inlineCallbacks
def _update_upgraded_room_pls(
- self, requester, old_room_id, new_room_id, old_room_state,
+ self,
+ requester: Requester,
+ old_room_id: str,
+ new_room_id: str,
+ old_room_state: StateMap[str],
):
"""Send updated power levels in both rooms after an upgrade
Args:
- requester (synapse.types.Requester): the user requesting the upgrade
- old_room_id (unicode): the id of the room to be replaced
- new_room_id (unicode): the id of the replacement room
- old_room_state (dict[tuple[str, str], str]): the state map for the old room
+ requester: the user requesting the upgrade
+ old_room_id: the id of the room to be replaced
+ new_room_id: the id of the replacement room
+ old_room_state: the state map for the old room
Returns:
Deferred
@@ -178,7 +247,7 @@ class RoomCreationHandler(BaseHandler):
if old_room_pl_event_id is None:
logger.warning(
"Not supported: upgrading a room with no PL event. Not setting PLs "
- "in old room.",
+ "in old room."
)
return
@@ -197,87 +266,87 @@ class RoomCreationHandler(BaseHandler):
for v in ("invite", "events_default"):
current = int(pl_content.get(v, 0))
if current < restricted_level:
- logger.info(
+ logger.debug(
"Setting level for %s in %s to %i (was %i)",
- v, old_room_id, restricted_level, current,
+ v,
+ old_room_id,
+ restricted_level,
+ current,
)
pl_content[v] = restricted_level
updated = True
else:
- logger.info(
- "Not setting level for %s (already %i)",
- v, current,
- )
+ logger.debug("Not setting level for %s (already %i)", v, current)
if updated:
try:
yield self.event_creation_handler.create_and_send_nonmember_event(
- requester, {
+ requester,
+ {
"type": EventTypes.PowerLevels,
- "state_key": '',
+ "state_key": "",
"room_id": old_room_id,
"sender": requester.user.to_string(),
"content": pl_content,
- }, ratelimit=False,
+ },
+ ratelimit=False,
)
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
- logger.info("Setting correct PLs in new room")
yield self.event_creation_handler.create_and_send_nonmember_event(
- requester, {
+ requester,
+ {
"type": EventTypes.PowerLevels,
- "state_key": '',
+ "state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
"content": old_room_pl_state.content,
- }, ratelimit=False,
+ },
+ ratelimit=False,
)
@defer.inlineCallbacks
def clone_existing_room(
- self, requester, old_room_id, new_room_id, new_room_version,
- tombstone_event_id,
+ self,
+ requester: Requester,
+ old_room_id: str,
+ new_room_id: str,
+ new_room_version: RoomVersion,
+ tombstone_event_id: str,
):
"""Populate a new room based on an old room
Args:
- requester (synapse.types.Requester): the user requesting the upgrade
- old_room_id (unicode): the id of the room to be replaced
- new_room_id (unicode): the id to give the new room (should already have been
+ requester: the user requesting the upgrade
+ old_room_id : the id of the room to be replaced
+ new_room_id: the id to give the new room (should already have been
created with _gemerate_room_id())
- new_room_version (unicode): the new room version to use
- tombstone_event_id (unicode|str): the ID of the tombstone event in the old
- room.
+ new_room_version: the new room version to use
+ tombstone_event_id: the ID of the tombstone event in the old room.
Returns:
- Deferred[None]
+ Deferred
"""
user_id = requester.user.to_string()
- if (self._server_notices_mxid is not None and
- requester.user.to_string() == self._server_notices_mxid):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
# allow the server notices mxid to create rooms
is_requester_admin = True
else:
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
if not is_requester_admin and not self.spam_checker.user_may_create_room(
- user_id,
- invite_list=[],
- third_party_invite_list=[],
- cloning=True,
+ user_id, invite_list=[], third_party_invite_list=[], cloning=True
):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
- "room_version": new_room_version,
- "predecessor": {
- "room_id": old_room_id,
- "event_id": tombstone_event_id,
- }
+ "room_version": new_room_version.identifier,
+ "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
}
# Check if old room was non-federatable
@@ -290,7 +359,7 @@ class RoomCreationHandler(BaseHandler):
# If so, mark the new room as non-federatable as well
creation_content["m.federate"] = False
- initial_state = dict()
+ initial_state = {}
# Replicate relevant room events
types_to_copy = (
@@ -300,13 +369,14 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.GuestAccess, ""),
(EventTypes.RoomAvatar, ""),
- (EventTypes.Encryption, ""),
+ (EventTypes.RoomEncryption, ""),
(EventTypes.ServerACL, ""),
(EventTypes.RelatedGroups, ""),
+ (EventTypes.PowerLevels, ""),
)
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
- old_room_id, StateFilter.from_types(types_to_copy),
+ old_room_id, StateFilter.from_types(types_to_copy)
)
# map from event_id to BaseEvent
old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
@@ -316,14 +386,40 @@ class RoomCreationHandler(BaseHandler):
if old_event:
initial_state[k] = old_event.content
+ # deep-copy the power-levels event before we start modifying it
+ # note that if frozen_dicts are enabled, `power_levels` will be a frozen
+ # dict so we can't just copy.deepcopy it.
+ initial_state[
+ (EventTypes.PowerLevels, "")
+ ] = power_levels = copy_power_levels_contents(
+ initial_state[(EventTypes.PowerLevels, "")]
+ )
+
+ # Resolve the minimum power level required to send any state event
+ # We will give the upgrading user this power level temporarily (if necessary) such that
+ # they are able to copy all of the state events over, then revert them back to their
+ # original power level afterwards in _update_upgraded_room_pls
+
+ # Copy over user power levels now as this will not be possible with >100PL users once
+ # the room has been created
+
+ # Calculate the minimum power level needed to clone the room
+ event_power_levels = power_levels.get("events", {})
+ state_default = power_levels.get("state_default", 0)
+ ban = power_levels.get("ban")
+ needed_power_level = max(state_default, ban, max(event_power_levels.values()))
+
+ # Raise the requester's power level in the new room if necessary
+ current_power_level = power_levels["users"][user_id]
+ if current_power_level < needed_power_level:
+ power_levels["users"][user_id] = needed_power_level
+
yield self._send_events_for_new_room(
requester,
new_room_id,
-
# we expect to override all the presets with initial_state, so this is
# somewhat arbitrary.
preset_config=RoomCreationPreset.PRIVATE_CHAT,
-
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
@@ -331,20 +427,22 @@ class RoomCreationHandler(BaseHandler):
# Transfer membership events
old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
- old_room_id, StateFilter.from_types([(EventTypes.Member, None)]),
+ old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
)
# map from event_id to BaseEvent
old_room_member_state_events = yield self.store.get_events(
- old_room_member_state_ids.values(),
+ old_room_member_state_ids.values()
)
for k, old_event in iteritems(old_room_member_state_events):
# Only transfer ban events
- if ("membership" in old_event.content and
- old_event.content["membership"] == "ban"):
+ if (
+ "membership" in old_event.content
+ and old_event.content["membership"] == "ban"
+ ):
yield self.room_member_handler.update_membership(
requester,
- UserID.from_string(old_event['state_key']),
+ UserID.from_string(old_event["state_key"]),
new_room_id,
"ban",
ratelimit=False,
@@ -356,19 +454,21 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def _move_aliases_to_new_room(
- self, requester, old_room_id, new_room_id, old_room_state,
+ self,
+ requester: Requester,
+ old_room_id: str,
+ new_room_id: str,
+ old_room_state: StateMap[str],
):
directory_handler = self.hs.get_handlers().directory_handler
aliases = yield self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias.
- canonical_alias = None
+ canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id:
canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
- if canonical_alias_event:
- canonical_alias = canonical_alias_event.content.get("alias", "")
# first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end).
@@ -386,57 +486,36 @@ class RoomCreationHandler(BaseHandler):
for alias_str in aliases:
alias = RoomAlias.from_string(alias_str)
try:
- yield directory_handler.delete_association(
- requester, alias, send_event=False,
- )
+ yield directory_handler.delete_association(requester, alias)
removed_aliases.append(alias_str)
except SynapseError as e:
- logger.warning(
- "Unable to remove alias %s from old room: %s",
- alias, e,
- )
+ logger.warning("Unable to remove alias %s from old room: %s", alias, e)
# if we didn't find any aliases, or couldn't remove anyway, we can skip the rest
# of this.
if not removed_aliases:
return
- try:
- # this can fail if, for some reason, our user doesn't have perms to send
- # m.room.aliases events in the old room (note that we've already checked that
- # they have perms to send a tombstone event, so that's not terribly likely).
- #
- # If that happens, it's regrettable, but we should carry on: it's the same
- # as when you remove an alias from the directory normally - it just means that
- # the aliases event gets out of sync with the directory
- # (cf https://github.com/vector-im/riot-web/issues/2369)
- yield directory_handler.send_room_alias_update_event(
- requester, old_room_id,
- )
- except AuthError as e:
- logger.warning(
- "Failed to send updated alias event on old room: %s", e,
- )
-
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
yield directory_handler.create_association(
- requester, RoomAlias.from_string(alias),
- new_room_id, servers=(self.hs.hostname, ),
- send_event=False, check_membership=False,
+ requester,
+ RoomAlias.from_string(alias),
+ new_room_id,
+ servers=(self.hs.hostname,),
+ check_membership=False,
)
logger.info("Moved alias %s to new room", alias)
except SynapseError as e:
# I'm not really expecting this to happen, but it could if the spam
# checking module decides it shouldn't, or similar.
- logger.error(
- "Error adding alias %s to new room: %s",
- alias, e,
- )
+ logger.error("Error adding alias %s to new room: %s", alias, e)
+ # If a canonical alias event existed for the old room, fire a canonical
+ # alias event for the new room with a copy of the information.
try:
- if canonical_alias and (canonical_alias in removed_aliases):
+ if canonical_alias_event:
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
@@ -444,24 +523,17 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
- "content": {"alias": canonical_alias, },
+ "content": canonical_alias_event.content,
},
- ratelimit=False
+ ratelimit=False,
)
-
- yield directory_handler.send_room_alias_update_event(
- requester, new_room_id,
- )
except SynapseError as e:
# again I'm not really expecting this to fail, but if it does, I'd rather
# we returned the new room to the client at this point.
- logger.error(
- "Unable to send updated alias events in new room: %s", e,
- )
+ logger.error("Unable to send updated alias events in new room: %s", e)
@defer.inlineCallbacks
- def create_room(self, requester, config, ratelimit=True,
- creator_join_profile=None):
+ def create_room(self, requester, config, ratelimit=True, creator_join_profile=None):
""" Creates a new room.
Args:
@@ -491,22 +563,24 @@ class RoomCreationHandler(BaseHandler):
yield self.auth.check_auth_blocking(user_id)
- if (self._server_notices_mxid is not None and
- requester.user.to_string() == self._server_notices_mxid):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
# allow the server notices mxid to create rooms
is_requester_admin = True
else:
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
# Check whether the third party rules allows/changes the room create
# request.
- yield self.third_party_event_rules.on_create_room(
- requester,
- config,
- is_requester_admin=is_requester_admin,
+ event_allowed = yield self.third_party_event_rules.on_create_room(
+ requester, config, is_requester_admin=is_requester_admin
)
+ if not event_allowed:
+ raise SynapseError(
+ 403, "You are not permitted to create rooms", Codes.FORBIDDEN
+ )
invite_list = config.get("invite", [])
invite_3pid_list = config.get("invite_3pid", [])
@@ -522,19 +596,15 @@ class RoomCreationHandler(BaseHandler):
if ratelimit:
yield self.ratelimit(requester)
- room_version = config.get(
- "room_version",
- self.config.default_room_version.identifier,
+ room_version_id = config.get(
+ "room_version", self.config.default_room_version.identifier
)
- if not isinstance(room_version, string_types):
- raise SynapseError(
- 400,
- "room_version must be a string",
- Codes.BAD_JSON,
- )
+ if not isinstance(room_version_id, string_types):
+ raise SynapseError(400, "room_version must be a string", Codes.BAD_JSON)
- if room_version not in KNOWN_ROOM_VERSIONS:
+ room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if room_version is None:
raise SynapseError(
400,
"Your homeserver does not support this room version",
@@ -546,46 +616,49 @@ class RoomCreationHandler(BaseHandler):
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
- room_alias = RoomAlias(
- config["room_alias_name"],
- self.hs.hostname,
- )
- mapping = yield self.store.get_association_from_room_alias(
- room_alias
- )
+ room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
+ mapping = yield self.store.get_association_from_room_alias(room_alias)
if mapping:
- raise SynapseError(
- 400,
- "Room alias already taken",
- Codes.ROOM_IN_USE
- )
+ raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
else:
room_alias = None
for i in invite_list:
try:
- UserID.from_string(i)
+ uid = UserID.from_string(i)
+ parse_and_validate_server_name(uid.domain)
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
- yield self.event_creation_handler.assert_accepted_privacy_policy(
- requester,
- )
+ yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
+
+ power_level_content_override = config.get("power_level_content_override")
+ if (
+ power_level_content_override
+ and "users" in power_level_content_override
+ and user_id not in power_level_content_override["users"]
+ ):
+ raise SynapseError(
+ 400,
+ "Not a valid power_level_content_override: 'users' did not contain %s"
+ % (user_id,),
+ )
visibility = config.get("visibility", None)
is_public = visibility == "public"
- room_id = yield self._generate_room_id(creator_id=user_id, is_public=is_public)
+ room_id = yield self._generate_room_id(
+ creator_id=user_id, is_public=is_public, room_version=room_version,
+ )
+ directory_handler = self.hs.get_handlers().directory_handler
if room_alias:
- directory_handler = self.hs.get_handlers().directory_handler
yield directory_handler.create_association(
requester=requester,
room_id=room_id,
room_alias=room_alias,
servers=[self.hs.hostname],
- send_event=False,
check_membership=False,
)
@@ -593,7 +666,7 @@ class RoomCreationHandler(BaseHandler):
"preset",
RoomCreationPreset.PRIVATE_CHAT
if visibility == "private"
- else RoomCreationPreset.PUBLIC_CHAT
+ else RoomCreationPreset.PUBLIC_CHAT,
)
raw_initial_state = config.get("initial_state", [])
@@ -605,7 +678,7 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {})
# override any attempt to set room versions via the creation_content
- creation_content["room_version"] = room_version
+ creation_content["room_version"] = room_version.identifier
yield self._send_events_for_new_room(
requester,
@@ -615,7 +688,7 @@ class RoomCreationHandler(BaseHandler):
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
- power_level_content_override=config.get("power_level_content_override"),
+ power_level_content_override=power_level_content_override,
creator_join_profile=creator_join_profile,
)
@@ -630,7 +703,8 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"content": {"name": name},
},
- ratelimit=False)
+ ratelimit=False,
+ )
if "topic" in config:
topic = config["topic"]
@@ -643,7 +717,8 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"content": {"topic": topic},
},
- ratelimit=False)
+ ratelimit=False,
+ )
for invitee in invite_list:
content = {}
@@ -663,6 +738,7 @@ class RoomCreationHandler(BaseHandler):
for invite_3pid in invite_3pid_list:
id_server = invite_3pid["id_server"]
+ id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
yield self.hs.get_room_member_handler().do_3pid_invite(
@@ -674,36 +750,31 @@ class RoomCreationHandler(BaseHandler):
requester,
txn_id=None,
new_room=True,
+ id_access_token=id_access_token,
)
result = {"room_id": room_id}
if room_alias:
result["room_alias"] = room_alias.to_string()
- yield directory_handler.send_room_alias_update_event(
- requester, room_id
- )
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _send_events_for_new_room(
- self,
- creator, # A Requester object.
- room_id,
- preset_config,
- invite_list,
- initial_state,
- creation_content,
- room_alias=None,
- power_level_content_override=None,
- creator_join_profile=None,
+ self,
+ creator, # A Requester object.
+ room_id,
+ preset_config,
+ invite_list,
+ initial_state,
+ creation_content,
+ room_alias=None,
+ power_level_content_override=None, # Doesn't apply when initial state has power level state event content
+ creator_join_profile=None,
):
def create(etype, content, **kwargs):
- e = {
- "type": etype,
- "content": content,
- }
+ e = {"type": etype, "content": content}
e.update(event_keys)
e.update(kwargs)
@@ -713,30 +784,21 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
- logger.info("Sending %s in new room", etype)
+ logger.debug("Sending %s in new room", etype)
yield self.event_creation_handler.create_and_send_nonmember_event(
- creator,
- event,
- ratelimit=False
+ creator, event, ratelimit=False
)
config = RoomCreationHandler.PRESETS_DICT[preset_config]
creator_id = creator.user.to_string()
- event_keys = {
- "room_id": room_id,
- "sender": creator_id,
- "state_key": "",
- }
+ event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
- yield send(
- etype=EventTypes.Create,
- content=creation_content,
- )
+ yield send(etype=EventTypes.Create, content=creation_content)
- logger.info("Sending %s in new room", EventTypes.Member)
+ logger.debug("Sending %s in new room", EventTypes.Member)
yield self.room_member_handler.update_membership(
creator,
creator.user,
@@ -749,17 +811,12 @@ class RoomCreationHandler(BaseHandler):
# We treat the power levels override specially as this needs to be one
# of the first events that get sent into a room.
- pl_content = initial_state.pop((EventTypes.PowerLevels, ''), None)
+ pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
- yield send(
- etype=EventTypes.PowerLevels,
- content=pl_content,
- )
+ yield send(etype=EventTypes.PowerLevels, content=pl_content)
else:
power_level_content = {
- "users": {
- creator_id: 100,
- },
+ "users": {creator_id: 100},
"users_default": 0,
"events": {
EventTypes.Name: 50,
@@ -767,88 +824,82 @@ class RoomCreationHandler(BaseHandler):
EventTypes.RoomHistoryVisibility: 100,
EventTypes.CanonicalAlias: 50,
EventTypes.RoomAvatar: 50,
+ EventTypes.Tombstone: 100,
+ EventTypes.ServerACL: 100,
},
"events_default": 0,
"state_default": 50,
"ban": 50,
"kick": 50,
"redact": 50,
- "invite": 0,
+ "invite": 50,
}
if config["original_invitees_have_ops"]:
for invitee in invite_list:
power_level_content["users"][invitee] = 100
+ # Power levels overrides are defined per chat preset
+ power_level_content.update(config["power_level_content_override"])
+
if power_level_content_override:
power_level_content.update(power_level_content_override)
- yield send(
- etype=EventTypes.PowerLevels,
- content=power_level_content,
- )
+ yield send(etype=EventTypes.PowerLevels, content=power_level_content)
- if room_alias and (EventTypes.CanonicalAlias, '') not in initial_state:
+ if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
yield send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
- if (EventTypes.JoinRules, '') not in initial_state:
+ if (EventTypes.JoinRules, "") not in initial_state:
yield send(
- etype=EventTypes.JoinRules,
- content={"join_rule": config["join_rules"]},
+ etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
- if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
+ if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
yield send(
etype=EventTypes.RoomHistoryVisibility,
- content={"history_visibility": config["history_visibility"]}
+ content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
- if (EventTypes.GuestAccess, '') not in initial_state:
+ if (EventTypes.GuestAccess, "") not in initial_state:
yield send(
- etype=EventTypes.GuestAccess,
- content={"guest_access": "can_join"}
+ etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
- yield send(
- etype=etype,
- state_key=state_key,
- content=content,
- )
+ yield send(etype=etype, state_key=state_key, content=content)
if "encryption_alg" in config:
yield send(
- etype=EventTypes.Encryption,
+ etype=EventTypes.RoomEncryption,
state_key="",
- content={
- 'algorithm': config["encryption_alg"],
- }
+ content={"algorithm": config["encryption_alg"]},
)
@defer.inlineCallbacks
- def _generate_room_id(self, creator_id, is_public):
+ def _generate_room_id(
+ self, creator_id: str, is_public: str, room_version: RoomVersion,
+ ):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
attempts = 0
while attempts < 5:
try:
random_string = stringutils.random_string(18)
- gen_room_id = RoomID(
- random_string,
- self.hs.hostname,
- ).to_string()
+ gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
if isinstance(gen_room_id, bytes):
- gen_room_id = gen_room_id.decode('utf-8')
+ gen_room_id = gen_room_id.decode("utf-8")
yield self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
is_public=is_public,
+ room_version=room_version,
)
- defer.returnValue(gen_room_id)
+ return gen_room_id
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a room ID.")
@@ -858,6 +909,8 @@ class RoomContextHandler(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
@defer.inlineCallbacks
def get_event_context(self, user, room_id, event_id, limit, event_filter):
@@ -876,7 +929,7 @@ class RoomContextHandler(object):
Returns:
dict, or None if the event isn't found
"""
- before_limit = math.floor(limit / 2.)
+ before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
users = yield self.store.get_users_in_room(room_id)
@@ -884,32 +937,33 @@ class RoomContextHandler(object):
def filter_evts(events):
return filter_events_for_client(
- self.store,
- user.to_string(),
- events,
- is_peeking=is_peeking
+ self.storage, user.to_string(), events, is_peeking=is_peeking
)
- event = yield self.store.get_event(event_id, get_prev_content=True,
- allow_none=True)
+ event = yield self.store.get_event(
+ event_id, get_prev_content=True, allow_none=True
+ )
if not event:
- defer.returnValue(None)
- return
+ return None
- filtered = yield(filter_evts([event]))
+ filtered = yield (filter_evts([event]))
if not filtered:
- raise AuthError(
- 403,
- "You don't have permission to access that event."
- )
+ raise AuthError(403, "You don't have permission to access that event.")
results = yield self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
+ if event_filter:
+ results["events_before"] = event_filter.filter(results["events_before"])
+ results["events_after"] = event_filter.filter(results["events_after"])
+
results["events_before"] = yield filter_evts(results["events_before"])
results["events_after"] = yield filter_evts(results["events_after"])
- results["event"] = event
+ # filter_evts can return a pruned event in case the user is allowed to see that
+ # there's something there but not see the content, so use the event that's in
+ # `filtered` rather than the event we retrieved from the datastore.
+ results["event"] = filtered[0]
if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
@@ -932,10 +986,15 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = yield self.store.get_state_for_events(
- [last_event_id], state_filter=state_filter,
+ state = yield self.state_store.get_state_for_events(
+ [last_event_id], state_filter=state_filter
)
- results["state"] = list(state[last_event_id].values())
+
+ state_events = list(state[last_event_id].values())
+ if event_filter:
+ state_events = event_filter.filter(state_events)
+
+ results["state"] = yield filter_evts(state_events)
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
@@ -945,11 +1004,9 @@ class RoomContextHandler(object):
"room_key", results["start"]
).to_string()
- results["end"] = token.copy_and_replace(
- "room_key", results["end"]
- ).to_string()
+ results["end"] = token.copy_and_replace("room_key", results["end"]).to_string()
- defer.returnValue(results)
+ return results
class RoomEventSource(object):
@@ -958,13 +1015,7 @@ class RoomEventSource(object):
@defer.inlineCallbacks
def get_new_events(
- self,
- user,
- from_key,
- limit,
- room_ids,
- is_guest,
- explicit_room_id=None,
+ self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
@@ -972,12 +1023,10 @@ class RoomEventSource(object):
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
- logger.warn("Stream has topological part!!!! %r", from_key)
+ logger.warning("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
- app_service = self.store.get_app_service_by_user_id(
- user.to_string()
- )
+ app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service:
# We no longer support AS users using /sync directly.
# See https://github.com/matrix-org/matrix-doc/issues/1144
@@ -992,7 +1041,7 @@ class RoomEventSource(object):
from_key=from_key,
to_key=to_key,
limit=limit or 10,
- order='ASC',
+ order="ASC",
)
events = list(room_events)
@@ -1008,22 +1057,10 @@ class RoomEventSource(object):
else:
end_key = to_key
- defer.returnValue((events, end_key))
+ return (events, end_key)
def get_current_key(self):
return self.store.get_room_events_max_id()
def get_current_key_for_room(self, room_id):
return self.store.get_room_events_max_id(room_id)
-
- @defer.inlineCallbacks
- def get_pagination_rows(self, user, config, key):
- events, next_key = yield self.store.paginate_room_events(
- room_id=key,
- from_key=config.from_key,
- to_key=config.to_key,
- direction=config.direction,
- limit=config.limit,
- )
-
- defer.returnValue((events, next_key))
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 617d1c9ef8..4469d51c52 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -16,8 +16,7 @@
import logging
from collections import namedtuple
-from six import PY3, iteritems
-from six.moves import range
+from six import iteritems
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
@@ -25,8 +24,8 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
+from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID
-from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
@@ -36,7 +35,6 @@ logger = logging.getLogger(__name__)
REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000
-
# This is used to indicate we should only return rooms published to the main list.
EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
@@ -46,13 +44,18 @@ class RoomListHandler(BaseHandler):
super(RoomListHandler, self).__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(hs, "room_list")
- self.remote_response_cache = ResponseCache(hs, "remote_room_list",
- timeout_ms=30 * 1000)
+ self.remote_response_cache = ResponseCache(
+ hs, "remote_room_list", timeout_ms=30 * 1000
+ )
- def get_local_public_room_list(self, limit=None, since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False):
+ def get_local_public_room_list(
+ self,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ network_tuple=EMPTY_THIRD_PARTY_ID,
+ from_federation=False,
+ ):
"""Generate a local public room list.
There are multiple different lists: the main one plus one per third
@@ -66,16 +69,18 @@ class RoomListHandler(BaseHandler):
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
+ from_federation (bool): true iff the request comes from the federation
+ API
"""
if not self.enable_room_list_search:
- return defer.succeed({
- "chunk": [],
- "total_room_count_estimate": 0,
- })
+ return defer.succeed({"chunk": [], "total_room_count_estimate": 0})
logger.info(
"Getting public room list: limit=%r, since=%r, search=%r, network=%r",
- limit, since_token, bool(search_filter), network_tuple,
+ limit,
+ since_token,
+ bool(search_filter),
+ network_tuple,
)
if search_filter:
@@ -83,29 +88,33 @@ class RoomListHandler(BaseHandler):
# appservice specific lists.
logger.info("Bypassing cache as search request.")
- # XXX: Quick hack to stop room directory queries taking too long.
- # Timeout request after 60s. Probably want a more fundamental
- # solution at some point
- timeout = self.clock.time() + 60
return self._get_public_room_list(
- limit, since_token, search_filter,
- network_tuple=network_tuple, timeout=timeout,
+ limit,
+ since_token,
+ search_filter,
+ network_tuple=network_tuple,
+ from_federation=from_federation,
)
key = (limit, since_token, network_tuple)
return self.response_cache.wrap(
key,
self._get_public_room_list,
- limit, since_token,
- network_tuple=network_tuple, from_federation=from_federation,
+ limit,
+ since_token,
+ network_tuple=network_tuple,
+ from_federation=from_federation,
)
@defer.inlineCallbacks
- def _get_public_room_list(self, limit=None, since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False,
- timeout=None,):
+ def _get_public_room_list(
+ self,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ network_tuple=EMPTY_THIRD_PARTY_ID,
+ from_federation=False,
+ ):
"""Generate a public room list.
Args:
limit (int|None): Maximum amount of rooms to return.
@@ -117,236 +126,117 @@ class RoomListHandler(BaseHandler):
Setting to None returns all public rooms across all lists.
from_federation (bool): Whether this request originated from a
federating server or a client. Used for room filtering.
- timeout (int|None): Amount of seconds to wait for a response before
- timing out.
"""
- if since_token and since_token != "END":
- since_token = RoomListNextBatch.from_token(since_token)
- else:
- since_token = None
- rooms_to_order_value = {}
- rooms_to_num_joined = {}
+ # Pagination tokens work by storing the room ID sent in the last batch,
+ # plus the direction (forwards or backwards). Next batch tokens always
+ # go forwards, prev batch tokens always go backwards.
- newly_visible = []
- newly_unpublished = []
if since_token:
- stream_token = since_token.stream_ordering
- current_public_id = yield self.store.get_current_public_room_stream_id()
- public_room_stream_id = since_token.public_room_stream_id
- newly_visible, newly_unpublished = yield self.store.get_public_room_changes(
- public_room_stream_id, current_public_id,
- network_tuple=network_tuple,
- )
- else:
- stream_token = yield self.store.get_room_max_stream_ordering()
- public_room_stream_id = yield self.store.get_current_public_room_stream_id()
-
- room_ids = yield self.store.get_public_room_ids_at_stream_id(
- public_room_stream_id, network_tuple=network_tuple,
- )
-
- # We want to return rooms in a particular order: the number of joined
- # users. We then arbitrarily use the room_id as a tie breaker.
-
- @defer.inlineCallbacks
- def get_order_for_room(room_id):
- # Most of the rooms won't have changed between the since token and
- # now (especially if the since token is "now"). So, we can ask what
- # the current users are in a room (that will hit a cache) and then
- # check if the room has changed since the since token. (We have to
- # do it in that order to avoid races).
- # If things have changed then fall back to getting the current state
- # at the since token.
- joined_users = yield self.store.get_users_in_room(room_id)
- if self.store.has_room_changed_since(room_id, stream_token):
- latest_event_ids = yield self.store.get_forward_extremeties_for_room(
- room_id, stream_token
- )
-
- if not latest_event_ids:
- return
+ batch_token = RoomListNextBatch.from_token(since_token)
- joined_users = yield self.state_handler.get_current_users_in_room(
- room_id, latest_event_ids,
- )
-
- num_joined_users = len(joined_users)
- rooms_to_num_joined[room_id] = num_joined_users
-
- if num_joined_users == 0:
- return
-
- # We want larger rooms to be first, hence negating num_joined_users
- rooms_to_order_value[room_id] = (-num_joined_users, room_id)
-
- logger.info("Getting ordering for %i rooms since %s",
- len(room_ids), stream_token)
- yield concurrently_execute(get_order_for_room, room_ids, 10)
-
- sorted_entries = sorted(rooms_to_order_value.items(), key=lambda e: e[1])
- sorted_rooms = [room_id for room_id, _ in sorted_entries]
+ bounds = (batch_token.last_joined_members, batch_token.last_room_id)
+ forwards = batch_token.direction_is_forward
+ else:
+ batch_token = None
+ bounds = None
- # `sorted_rooms` should now be a list of all public room ids that is
- # stable across pagination. Therefore, we can use indices into this
- # list as our pagination tokens.
+ forwards = True
- # Filter out rooms that we don't want to return
- rooms_to_scan = [
- r for r in sorted_rooms
- if r not in newly_unpublished and rooms_to_num_joined[r] > 0
- ]
+ # we request one more than wanted to see if there are more pages to come
+ probing_limit = limit + 1 if limit is not None else None
- total_room_count = len(rooms_to_scan)
+ results = yield self.store.get_largest_public_rooms(
+ network_tuple,
+ search_filter,
+ probing_limit,
+ bounds=bounds,
+ forwards=forwards,
+ ignore_non_federatable=from_federation,
+ )
- if since_token:
- # Filter out rooms we've already returned previously
- # `since_token.current_limit` is the index of the last room we
- # sent down, so we exclude it and everything before/after it.
- if since_token.direction_is_forward:
- rooms_to_scan = rooms_to_scan[since_token.current_limit + 1:]
+ def build_room_entry(room):
+ entry = {
+ "room_id": room["room_id"],
+ "name": room["name"],
+ "topic": room["topic"],
+ "canonical_alias": room["canonical_alias"],
+ "num_joined_members": room["joined_members"],
+ "avatar_url": room["avatar"],
+ "world_readable": room["history_visibility"] == "world_readable",
+ "guest_can_join": room["guest_access"] == "can_join",
+ }
+
+ # Filter out Nones – rather omit the field altogether
+ return {k: v for k, v in entry.items() if v is not None}
+
+ results = [build_room_entry(r) for r in results]
+
+ response = {}
+ num_results = len(results)
+ if limit is not None:
+ more_to_come = num_results == probing_limit
+
+ # Depending on direction we trim either the front or back.
+ if forwards:
+ results = results[:limit]
else:
- rooms_to_scan = rooms_to_scan[:since_token.current_limit]
- rooms_to_scan.reverse()
-
- logger.info("After sorting and filtering, %i rooms remain",
- len(rooms_to_scan))
-
- # _append_room_entry_to_chunk will append to chunk but will stop if
- # len(chunk) > limit
- #
- # Normally we will generate enough results on the first iteration here,
- # but if there is a search filter, _append_room_entry_to_chunk may
- # filter some results out, in which case we loop again.
- #
- # We don't want to scan over the entire range either as that
- # would potentially waste a lot of work.
- #
- # XXX if there is no limit, we may end up DoSing the server with
- # calls to get_current_state_ids for every single room on the
- # server. Surely we should cap this somehow?
- #
- if limit:
- step = limit + 1
+ results = results[-limit:]
else:
- # step cannot be zero
- step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
-
- chunk = []
- for i in range(0, len(rooms_to_scan), step):
- if timeout and self.clock.time() > timeout:
- raise Exception("Timed out searching room directory")
-
- batch = rooms_to_scan[i:i + step]
- logger.info("Processing %i rooms for result", len(batch))
- yield concurrently_execute(
- lambda r: self._append_room_entry_to_chunk(
- r, rooms_to_num_joined[r],
- chunk, limit, search_filter,
- from_federation=from_federation,
- ),
- batch, 5,
- )
- logger.info("Now %i rooms in result", len(chunk))
- if len(chunk) >= limit + 1:
- break
-
- chunk.sort(key=lambda e: (-e["num_joined_members"], e["room_id"]))
-
- # Work out the new limit of the batch for pagination, or None if we
- # know there are no more results that would be returned.
- # i.e., [since_token.current_limit..new_limit] is the batch of rooms
- # we've returned (or the reverse if we paginated backwards)
- # We tried to pull out limit + 1 rooms above, so if we have <= limit
- # then we know there are no more results to return
- new_limit = None
- if chunk and (not limit or len(chunk) > limit):
-
- if not since_token or since_token.direction_is_forward:
- if limit:
- chunk = chunk[:limit]
- last_room_id = chunk[-1]["room_id"]
+ more_to_come = False
+
+ if num_results > 0:
+ final_entry = results[-1]
+ initial_entry = results[0]
+
+ if forwards:
+ if batch_token:
+ # If there was a token given then we assume that there
+ # must be previous results.
+ response["prev_batch"] = RoomListNextBatch(
+ last_joined_members=initial_entry["num_joined_members"],
+ last_room_id=initial_entry["room_id"],
+ direction_is_forward=False,
+ ).to_token()
+
+ if more_to_come:
+ response["next_batch"] = RoomListNextBatch(
+ last_joined_members=final_entry["num_joined_members"],
+ last_room_id=final_entry["room_id"],
+ direction_is_forward=True,
+ ).to_token()
else:
- if limit:
- chunk = chunk[-limit:]
- last_room_id = chunk[0]["room_id"]
-
- new_limit = sorted_rooms.index(last_room_id)
-
- results = {
- "chunk": chunk,
- "total_room_count_estimate": total_room_count,
- }
-
- if since_token:
- results["new_rooms"] = bool(newly_visible)
-
- if not since_token or since_token.direction_is_forward:
- if new_limit is not None:
- results["next_batch"] = RoomListNextBatch(
- stream_ordering=stream_token,
- public_room_stream_id=public_room_stream_id,
- current_limit=new_limit,
- direction_is_forward=True,
- ).to_token()
-
- if since_token:
- results["prev_batch"] = since_token.copy_and_replace(
- direction_is_forward=False,
- current_limit=since_token.current_limit + 1,
- ).to_token()
- else:
- if new_limit is not None:
- results["prev_batch"] = RoomListNextBatch(
- stream_ordering=stream_token,
- public_room_stream_id=public_room_stream_id,
- current_limit=new_limit,
- direction_is_forward=False,
- ).to_token()
-
- if since_token:
- results["next_batch"] = since_token.copy_and_replace(
- direction_is_forward=True,
- current_limit=since_token.current_limit - 1,
- ).to_token()
-
- defer.returnValue(results)
-
- @defer.inlineCallbacks
- def _append_room_entry_to_chunk(self, room_id, num_joined_users, chunk, limit,
- search_filter, from_federation=False):
- """Generate the entry for a room in the public room list and append it
- to the `chunk` if it matches the search filter
-
- Args:
- room_id (str): The ID of the room.
- num_joined_users (int): The number of joined users in the room.
- chunk (list)
- limit (int|None): Maximum amount of rooms to display. Function will
- return if length of chunk is greater than limit + 1.
- search_filter (dict|None)
- from_federation (bool): Whether this request originated from a
- federating server or a client. Used for room filtering.
- """
- if limit and len(chunk) > limit + 1:
- # We've already got enough, so lets just drop it.
- return
-
- result = yield self.generate_room_entry(room_id, num_joined_users)
- if not result:
- return
-
- if from_federation and not result.get("m.federate", True):
- # This is a room that other servers cannot join. Do not show them
- # this room.
- return
+ if batch_token:
+ response["next_batch"] = RoomListNextBatch(
+ last_joined_members=final_entry["num_joined_members"],
+ last_room_id=final_entry["room_id"],
+ direction_is_forward=True,
+ ).to_token()
+
+ if more_to_come:
+ response["prev_batch"] = RoomListNextBatch(
+ last_joined_members=initial_entry["num_joined_members"],
+ last_room_id=initial_entry["room_id"],
+ direction_is_forward=False,
+ ).to_token()
+
+ response["chunk"] = results
+
+ response["total_room_count_estimate"] = yield self.store.count_public_rooms(
+ network_tuple, ignore_non_federatable=from_federation
+ )
- if _matches_room_entry(result, search_filter):
- chunk.append(result)
+ return response
@cachedInlineCallbacks(num_args=1, cache_context=True)
- def generate_room_entry(self, room_id, num_joined_users, cache_context,
- with_alias=True, allow_private=False):
+ def generate_room_entry(
+ self,
+ room_id,
+ num_joined_users,
+ cache_context,
+ with_alias=True,
+ allow_private=False,
+ ):
"""Returns the entry for a room
Args:
@@ -360,33 +250,31 @@ class RoomListHandler(BaseHandler):
Deferred[dict|None]: Returns a room entry as a dictionary, or None if this
room was determined not to be shown publicly.
"""
- result = {
- "room_id": room_id,
- "num_joined_members": num_joined_users,
- }
+ result = {"room_id": room_id, "num_joined_members": num_joined_users}
current_state_ids = yield self.store.get_current_state_ids(
- room_id, on_invalidate=cache_context.invalidate,
+ room_id, on_invalidate=cache_context.invalidate
)
- event_map = yield self.store.get_events([
- event_id for key, event_id in iteritems(current_state_ids)
- if key[0] in (
- EventTypes.Create,
- EventTypes.JoinRules,
- EventTypes.Name,
- EventTypes.Topic,
- EventTypes.CanonicalAlias,
- EventTypes.RoomHistoryVisibility,
- EventTypes.GuestAccess,
- "m.room.avatar",
- )
- ])
+ event_map = yield self.store.get_events(
+ [
+ event_id
+ for key, event_id in iteritems(current_state_ids)
+ if key[0]
+ in (
+ EventTypes.Create,
+ EventTypes.JoinRules,
+ EventTypes.Name,
+ EventTypes.Topic,
+ EventTypes.CanonicalAlias,
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.GuestAccess,
+ "m.room.avatar",
+ )
+ ]
+ )
- current_state = {
- (ev.type, ev.state_key): ev
- for ev in event_map.values()
- }
+ current_state = {(ev.type, ev.state_key): ev for ev in event_map.values()}
# Double check that this is actually a public room.
@@ -394,7 +282,7 @@ class RoomListHandler(BaseHandler):
if join_rules_event:
join_rule = join_rules_event.content.get("join_rule", None)
if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
- defer.returnValue(None)
+ return None
# Return whether this room is open to federation users or not
create_event = current_state.get((EventTypes.Create, ""))
@@ -443,76 +331,125 @@ class RoomListHandler(BaseHandler):
if avatar_url:
result["avatar_url"] = avatar_url
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
- def get_remote_public_room_list(self, server_name, limit=None, since_token=None,
- search_filter=None, include_all_networks=False,
- third_party_instance_id=None,):
+ def get_remote_public_room_list(
+ self,
+ server_name,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ):
if not self.enable_room_list_search:
- defer.returnValue({
- "chunk": [],
- "total_room_count_estimate": 0,
- })
+ return {"chunk": [], "total_room_count_estimate": 0}
if search_filter:
- # We currently don't support searching across federation, so we have
+ # Searching across federation is defined in MSC2197.
+ # However, the remote homeserver may or may not actually support it.
+ # So we first try an MSC2197 remote-filtered search, then fall back
+ # to a locally-filtered search if we must.
+
+ try:
+ res = yield self._get_remote_list_cached(
+ server_name,
+ limit=limit,
+ since_token=since_token,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ search_filter=search_filter,
+ )
+ return res
+ except HttpResponseException as hre:
+ syn_err = hre.to_synapse_error()
+ if hre.code in (404, 405) or syn_err.errcode in (
+ Codes.UNRECOGNIZED,
+ Codes.NOT_FOUND,
+ ):
+ logger.debug("Falling back to locally-filtered /publicRooms")
+ else:
+ raise # Not an error that should trigger a fallback.
+
+ # if we reach this point, then we fall back to the situation where
+ # we currently don't support searching across federation, so we have
# to do it manually without pagination
limit = None
since_token = None
res = yield self._get_remote_list_cached(
- server_name, limit=limit, since_token=since_token,
+ server_name,
+ limit=limit,
+ since_token=since_token,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
if search_filter:
- res = {"chunk": [
- entry
- for entry in list(res.get("chunk", []))
- if _matches_room_entry(entry, search_filter)
- ]}
-
- defer.returnValue(res)
-
- def _get_remote_list_cached(self, server_name, limit=None, since_token=None,
- search_filter=None, include_all_networks=False,
- third_party_instance_id=None,):
+ res = {
+ "chunk": [
+ entry
+ for entry in list(res.get("chunk", []))
+ if _matches_room_entry(entry, search_filter)
+ ]
+ }
+
+ return res
+
+ def _get_remote_list_cached(
+ self,
+ server_name,
+ limit=None,
+ since_token=None,
+ search_filter=None,
+ include_all_networks=False,
+ third_party_instance_id=None,
+ ):
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
- server_name, limit=limit, since_token=since_token,
- search_filter=search_filter, include_all_networks=include_all_networks,
+ server_name,
+ limit=limit,
+ since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
key = (
- server_name, limit, since_token, include_all_networks,
+ server_name,
+ limit,
+ since_token,
+ include_all_networks,
third_party_instance_id,
)
return self.remote_response_cache.wrap(
key,
repl_layer.get_public_rooms,
- server_name, limit=limit, since_token=since_token,
+ server_name,
+ limit=limit,
+ since_token=since_token,
search_filter=search_filter,
include_all_networks=include_all_networks,
third_party_instance_id=third_party_instance_id,
)
-class RoomListNextBatch(namedtuple("RoomListNextBatch", (
- "stream_ordering", # stream_ordering of the first public room list
- "public_room_stream_id", # public room stream id for first public room list
- "current_limit", # The number of previous rooms returned
- "direction_is_forward", # Bool if this is a next_batch, false if prev_batch
-))):
-
+class RoomListNextBatch(
+ namedtuple(
+ "RoomListNextBatch",
+ (
+ "last_joined_members", # The count to get rooms after/before
+ "last_room_id", # The room_id to get rooms after/before
+ "direction_is_forward", # Bool if this is a next_batch, false if prev_batch
+ ),
+ )
+):
KEY_DICT = {
- "stream_ordering": "s",
- "public_room_stream_id": "p",
- "current_limit": "n",
+ "last_joined_members": "m",
+ "last_room_id": "r",
"direction_is_forward": "d",
}
@@ -520,28 +457,20 @@ class RoomListNextBatch(namedtuple("RoomListNextBatch", (
@classmethod
def from_token(cls, token):
- if PY3:
- # The argument raw=False is only available on new versions of
- # msgpack, and only really needed on Python 3. Gate it behind
- # a PY3 check to avoid causing issues on Debian-packaged versions.
- decoded = msgpack.loads(decode_base64(token), raw=False)
- else:
- decoded = msgpack.loads(decode_base64(token))
- return RoomListNextBatch(**{
- cls.REVERSE_KEY_DICT[key]: val
- for key, val in decoded.items()
- })
+ decoded = msgpack.loads(decode_base64(token), raw=False)
+ return RoomListNextBatch(
+ **{cls.REVERSE_KEY_DICT[key]: val for key, val in decoded.items()}
+ )
def to_token(self):
- return encode_base64(msgpack.dumps({
- self.KEY_DICT[key]: val
- for key, val in self._asdict().items()
- }))
+ return encode_base64(
+ msgpack.dumps(
+ {self.KEY_DICT[key]: val for key, val in self._asdict().items()}
+ )
+ )
def copy_and_replace(self, **kwds):
- return self._replace(
- **kwds
- )
+ return self._replace(**kwds)
def _matches_room_entry(room_entry, search_filter):
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 790aeba9f5..cddc95413a 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -22,19 +22,16 @@ from six.moves import http_client
from twisted.internet import defer
-import synapse.server
-import synapse.types
+from synapse import types
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, ProxiedRequestError, SynapseError
+from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.ratelimiting import Ratelimiter
-from synapse.types import RoomID, UserID
+from synapse.types import Collection, RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
logger = logging.getLogger(__name__)
-id_server_scheme = "https://"
-
class RoomMemberHandler(object):
# TODO(paul): This handler currently contains a messy conflation of
@@ -55,10 +52,10 @@ class RoomMemberHandler(object):
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
- self.simple_http_client = hs.get_simple_http_client()
self.federation_handler = hs.get_handlers().federation_handler
self.directory_handler = hs.get_handlers().directory_handler
+ self.identity_handler = hs.get_handlers().identity_handler
self.registration_handler = hs.get_registration_handler()
self.profile_handler = hs.get_profile_handler()
self.event_creation_handler = hs.get_event_creation_handler()
@@ -70,7 +67,6 @@ class RoomMemberHandler(object):
self.spam_checker = hs.get_spam_checker()
self.third_party_event_rules = hs.get_third_party_event_rules()
self._server_notices_mxid = self.config.server_notices_mxid
- self.rewrite_identity_server_urls = self.config.rewrite_identity_server_urls
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
self.ratelimiter = Ratelimiter()
@@ -94,7 +90,9 @@ class RoomMemberHandler(object):
raise NotImplementedError()
@abc.abstractmethod
- def _remote_reject_invite(self, remote_room_hosts, room_id, target):
+ def _remote_reject_invite(
+ self, requester, remote_room_hosts, room_id, target, content
+ ):
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
@@ -104,6 +102,7 @@ class RoomMemberHandler(object):
reject invite
room_id (str)
target (UserID): The user rejecting the invite
+ content (dict): The content for the rejection event
Returns:
Deferred[dict]: A dictionary to be returned to the client, may
@@ -112,24 +111,6 @@ class RoomMemberHandler(object):
raise NotImplementedError()
@abc.abstractmethod
- def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
- """Get a guest access token for a 3PID, creating a guest account if
- one doesn't already exist.
-
- Args:
- requester (Requester)
- medium (str)
- address (str)
- inviter_user_id (str): The user ID who is trying to invite the
- 3PID
-
- Returns:
- Deferred[(str, str)]: A 2-tuple of `(user_id, access_token)` of the
- 3PID guest account.
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
def _user_joined_room(self, target, room_id):
"""Notifies distributor on master process that the user has joined the
room.
@@ -159,8 +140,12 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _local_membership_update(
- self, requester, target, room_id, membership,
- prev_events_and_hashes,
+ self,
+ requester,
+ target,
+ room_id,
+ membership,
+ prev_event_ids: Collection[str],
txn_id=None,
ratelimit=True,
content=None,
@@ -183,38 +168,30 @@ class RoomMemberHandler(object):
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": user_id,
-
# For backwards compatibility:
"membership": membership,
},
token_id=requester.access_token_id,
txn_id=txn_id,
- prev_events_and_hashes=prev_events_and_hashes,
+ prev_event_ids=prev_event_ids,
require_consent=require_consent,
)
# Check if this event matches the previous membership event for the user.
duplicate = yield self.event_creation_handler.deduplicate_state_event(
- event, context,
+ event, context
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
- defer.returnValue(duplicate)
+ return duplicate
yield self.event_creation_handler.handle_new_client_event(
- requester,
- event,
- context,
- extra_users=[target],
- ratelimit=ratelimit,
+ requester, event, context, extra_users=[target], ratelimit=ratelimit
)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
- prev_member_event_id = prev_state_ids.get(
- (EventTypes.Member, user_id),
- None
- )
+ prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
if event.membership == Membership.JOIN:
# Only fire user_joined_room if the user has actually joined the
@@ -226,37 +203,16 @@ class RoomMemberHandler(object):
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
yield self._user_joined_room(target, room_id)
-
- # Copy over direct message status and room tags if this is a join
- # on an upgraded room
-
- # Check if this is an upgraded room
- predecessor = yield self.store.get_room_predecessor(room_id)
-
- if predecessor:
- # It is an upgraded room. Copy over old tags
- self.copy_room_tags_and_direct_to_room(
- predecessor["room_id"], room_id, user_id,
- )
- # Move over old push rules
- self.store.move_push_rules_from_room_to_room_for_user(
- predecessor["room_id"], room_id, user_id,
- )
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
yield self._user_left_room(target, room_id)
- defer.returnValue(event)
+ return event
@defer.inlineCallbacks
- def copy_room_tags_and_direct_to_room(
- self,
- old_room_id,
- new_room_id,
- user_id,
- ):
+ def copy_room_tags_and_direct_to_room(self, old_room_id, new_room_id, user_id):
"""Copies the tags and direct room state from one room to another.
Args:
@@ -268,9 +224,7 @@ class RoomMemberHandler(object):
Deferred[None]
"""
# Retrieve user account data for predecessor room
- user_account_data, _ = yield self.store.get_account_data_for_user(
- user_id,
- )
+ user_account_data, _ = yield self.store.get_account_data_for_user(user_id)
# Copy direct message state if applicable
direct_rooms = user_account_data.get("m.direct", {})
@@ -284,35 +238,31 @@ class RoomMemberHandler(object):
# Save back to user's m.direct account data
yield self.store.add_account_data_for_user(
- user_id, "m.direct", direct_rooms,
+ user_id, "m.direct", direct_rooms
)
break
# Copy room tags if applicable
- room_tags = yield self.store.get_tags_for_room(
- user_id, old_room_id,
- )
+ room_tags = yield self.store.get_tags_for_room(user_id, old_room_id)
# Copy each room tag to the new room
for tag, tag_content in room_tags.items():
- yield self.store.add_tag_to_room(
- user_id, new_room_id, tag, tag_content
- )
+ yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
@defer.inlineCallbacks
def update_membership(
- self,
- requester,
- target,
- room_id,
- action,
- txn_id=None,
- remote_room_hosts=None,
- third_party_signed=None,
- ratelimit=True,
- content=None,
- new_room=False,
- require_consent=True,
+ self,
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=None,
+ remote_room_hosts=None,
+ third_party_signed=None,
+ ratelimit=True,
+ content=None,
+ new_room=False,
+ require_consent=True,
):
"""Update a users membership in a room
@@ -353,22 +303,22 @@ class RoomMemberHandler(object):
require_consent=require_consent,
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _update_membership(
- self,
- requester,
- target,
- room_id,
- action,
- txn_id=None,
- remote_room_hosts=None,
- third_party_signed=None,
- ratelimit=True,
- content=None,
- new_room=False,
- require_consent=True,
+ self,
+ requester,
+ target,
+ room_id,
+ action,
+ txn_id=None,
+ remote_room_hosts=None,
+ third_party_signed=None,
+ ratelimit=True,
+ content=None,
+ new_room=False,
+ require_consent=True,
):
content_specified = bool(content)
if content is None:
@@ -402,7 +352,7 @@ class RoomMemberHandler(object):
if not remote_room_hosts:
remote_room_hosts = []
- if effective_membership_state not in ("leave", "ban",):
+ if effective_membership_state not in ("leave", "ban"):
is_blocked = yield self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
@@ -410,22 +360,19 @@ class RoomMemberHandler(object):
if effective_membership_state == Membership.INVITE:
# block any attempts to invite the server notices mxid
if target.to_string() == self._server_notices_mxid:
- raise SynapseError(
- http_client.FORBIDDEN,
- "Cannot invite this user",
- )
+ raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
block_invite = False
- if (self._server_notices_mxid is not None and
- requester.user.to_string() == self._server_notices_mxid):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
# allow the server notices mxid to send invites
is_requester_admin = True
else:
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
if not is_requester_admin:
if self.config.block_non_admin_invites:
@@ -438,7 +385,8 @@ class RoomMemberHandler(object):
is_published = yield self.store.is_room_published(room_id)
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), target.to_string(),
+ requester.user.to_string(),
+ target.to_string(),
third_party_invite=None,
room_id=room_id,
new_room=new_room,
@@ -448,19 +396,12 @@ class RoomMemberHandler(object):
block_invite = True
if block_invite:
- raise SynapseError(
- 403, "Invites have been disabled on this server",
- )
+ raise SynapseError(403, "Invites have been disabled on this server")
- prev_events_and_hashes = yield self.store.get_prev_events_for_room(
- room_id,
- )
- latest_event_ids = (
- event_id for (event_id, _, _) in prev_events_and_hashes
- )
+ latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids(
- room_id, latest_event_ids=latest_event_ids,
+ room_id, latest_event_ids=latest_event_ids
)
# TODO: Refactor into dictionary of explicitly allowed transitions
@@ -475,13 +416,13 @@ class RoomMemberHandler(object):
403,
"Cannot unban user who was not banned"
" (membership=%s)" % old_membership,
- errcode=Codes.BAD_STATE
+ errcode=Codes.BAD_STATE,
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was banned" % (action,),
- errcode=Codes.BAD_STATE
+ errcode=Codes.BAD_STATE,
)
if old_state:
@@ -489,7 +430,7 @@ class RoomMemberHandler(object):
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
- defer.returnValue(old_state)
+ return old_state
if old_membership in ["ban", "leave"] and action == "kick":
raise AuthError(403, "The target user is not in the room")
@@ -497,8 +438,8 @@ class RoomMemberHandler(object):
# we don't allow people to reject invites to the server notice
# room, but they can leave it once they are joined.
if (
- old_membership == Membership.INVITE and
- effective_membership_state == Membership.LEAVE
+ old_membership == Membership.INVITE
+ and effective_membership_state == Membership.LEAVE
):
is_blocked = yield self._is_server_notice_room(room_id)
if is_blocked:
@@ -521,27 +462,24 @@ class RoomMemberHandler(object):
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
- if (self._server_notices_mxid is not None and
- requester.user.to_string() == self._server_notices_mxid):
+ if (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ ):
# allow the server notices mxid to join rooms
is_requester_admin = True
else:
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
inviter = yield self._get_inviter(target.to_string(), room_id)
if not is_requester_admin:
# We assume that if the spam checker allowed the user to create
# a room then they're allowed to join it.
if not new_room and not self.spam_checker.user_may_join_room(
- target.to_string(), room_id,
- is_invited=inviter is not None,
+ target.to_string(), room_id, is_invited=inviter is not None
):
- raise SynapseError(
- 403, "Not allowed to join this room",
- )
+ raise SynapseError(403, "Not allowed to join this room")
if not is_host_in_room:
if inviter and not self.hs.is_mine(inviter):
@@ -557,10 +495,11 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
- ret = yield self._remote_join(
+ remote_join_response = yield self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
- defer.returnValue(ret)
+
+ return remote_join_response
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
@@ -580,9 +519,9 @@ class RoomMemberHandler(object):
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
res = yield self._remote_reject_invite(
- requester, remote_room_hosts, room_id, target,
+ requester, remote_room_hosts, room_id, target, content,
)
- defer.returnValue(res)
+ return res
res = yield self._local_membership_update(
requester=requester,
@@ -591,21 +530,93 @@ class RoomMemberHandler(object):
membership=effective_membership_state,
txn_id=txn_id,
ratelimit=ratelimit,
- prev_events_and_hashes=prev_events_and_hashes,
+ prev_event_ids=latest_event_ids,
content=content,
require_consent=require_consent,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
- def send_membership_event(
- self,
- requester,
- event,
- context,
- remote_room_hosts=None,
- ratelimit=True,
- ):
+ def transfer_room_state_on_room_upgrade(self, old_room_id, room_id):
+ """Upon our server becoming aware of an upgraded room, either by upgrading a room
+ ourselves or joining one, we can transfer over information from the previous room.
+
+ Copies user state (tags/push rules) for every local user that was in the old room, as
+ well as migrating the room directory state.
+
+ Args:
+ old_room_id (str): The ID of the old room
+
+ room_id (str): The ID of the new room
+
+ Returns:
+ Deferred
+ """
+ logger.info("Transferring room state from %s to %s", old_room_id, room_id)
+
+ # Find all local users that were in the old room and copy over each user's state
+ users = yield self.store.get_users_in_room(old_room_id)
+ yield self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
+
+ # Add new room to the room directory if the old room was there
+ # Remove old room from the room directory
+ old_room = yield self.store.get_room(old_room_id)
+ if old_room and old_room["is_public"]:
+ yield self.store.set_room_is_public(old_room_id, False)
+ yield self.store.set_room_is_public(room_id, True)
+
+ # Check if any groups we own contain the predecessor room
+ local_group_ids = yield self.store.get_local_groups_for_room(old_room_id)
+ for group_id in local_group_ids:
+ # Add new the new room to those groups
+ yield self.store.add_room_to_group(group_id, room_id, old_room["is_public"])
+
+ # Remove the old room from those groups
+ yield self.store.remove_room_from_group(group_id, old_room_id)
+
+ @defer.inlineCallbacks
+ def copy_user_state_on_room_upgrade(self, old_room_id, new_room_id, user_ids):
+ """Copy user-specific information when they join a new room when that new room is the
+ result of a room upgrade
+
+ Args:
+ old_room_id (str): The ID of upgraded room
+ new_room_id (str): The ID of the new room
+ user_ids (Iterable[str]): User IDs to copy state for
+
+ Returns:
+ Deferred
+ """
+
+ logger.debug(
+ "Copying over room tags and push rules from %s to %s for users %s",
+ old_room_id,
+ new_room_id,
+ user_ids,
+ )
+
+ for user_id in user_ids:
+ try:
+ # It is an upgraded room. Copy over old tags
+ yield self.copy_room_tags_and_direct_to_room(
+ old_room_id, new_room_id, user_id
+ )
+ # Copy over push rules
+ yield self.store.copy_push_rules_from_room_to_room_for_user(
+ old_room_id, new_room_id, user_id
+ )
+ except Exception:
+ logger.exception(
+ "Error copying tags and/or push rules from rooms %s to %s for user %s. "
+ "Skipping...",
+ old_room_id,
+ new_room_id,
+ user_id,
+ )
+ continue
+
+ @defer.inlineCallbacks
+ def send_membership_event(self, requester, event, context, ratelimit=True):
"""
Change the membership status of a user in a room.
@@ -615,36 +626,29 @@ class RoomMemberHandler(object):
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
- is_guest (bool): Whether the sender is a guest.
- room_hosts ([str]): Homeservers which are likely to already be in
- the room, and could be danced with in order to join this
- homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
- remote_room_hosts = remote_room_hosts or []
-
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
- assert sender == requester.user, (
- "Sender (%s) must be same as requester (%s)" %
- (sender, requester.user)
- )
+ assert (
+ sender == requester.user
+ ), "Sender (%s) must be same as requester (%s)" % (sender, requester.user)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
- requester = synapse.types.create_requester(target_user)
+ requester = types.create_requester(target_user)
prev_event = yield self.event_creation_handler.deduplicate_state_event(
- event, context,
+ event, context
)
if prev_event is not None:
return
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(prev_state_ids)
@@ -659,16 +663,11 @@ class RoomMemberHandler(object):
raise SynapseError(403, "This room has been blocked on this server")
yield self.event_creation_handler.handle_new_client_event(
- requester,
- event,
- context,
- extra_users=[target_user],
- ratelimit=ratelimit,
+ requester, event, context, extra_users=[target_user], ratelimit=ratelimit
)
prev_member_event_id = prev_state_ids.get(
- (EventTypes.Member, event.state_key),
- None
+ (EventTypes.Member, event.state_key), None
)
if event.membership == Membership.JOIN:
@@ -694,11 +693,11 @@ class RoomMemberHandler(object):
"""
guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
if not guest_access_id:
- defer.returnValue(False)
+ return False
guest_access = yield self.store.get_event(guest_access_id)
- defer.returnValue(
+ return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
@@ -733,85 +732,91 @@ class RoomMemberHandler(object):
servers.remove(room_alias.domain)
servers.insert(0, room_alias.domain)
- defer.returnValue((RoomID.from_string(room_id), servers))
+ return RoomID.from_string(room_id), servers
@defer.inlineCallbacks
def _get_inviter(self, user_id, room_id):
- invite = yield self.store.get_invite_for_user_in_room(
- user_id=user_id,
- room_id=room_id,
+ invite = yield self.store.get_invite_for_local_user_in_room(
+ user_id=user_id, room_id=room_id
)
if invite:
- defer.returnValue(UserID.from_string(invite.sender))
+ return UserID.from_string(invite.sender)
@defer.inlineCallbacks
def do_3pid_invite(
- self,
- room_id,
- inviter,
- medium,
- address,
- id_server,
- requester,
- txn_id,
- new_room=False,
+ self,
+ room_id,
+ inviter,
+ medium,
+ address,
+ id_server,
+ requester,
+ txn_id,
+ new_room=False,
+ id_access_token=None,
):
if self.config.block_non_admin_invites:
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+ is_requester_admin = yield self.auth.is_server_admin(requester.user)
if not is_requester_admin:
raise SynapseError(
- 403, "Invites have been disabled on this server",
- Codes.FORBIDDEN,
+ 403, "Invites have been disabled on this server", Codes.FORBIDDEN
)
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
self.ratelimiter.ratelimit(
- requester.user.to_string(), time_now_s=self.hs.clock.time(),
+ requester.user.to_string(),
+ time_now_s=self.hs.clock.time(),
rate_hz=self.hs.config.rc_third_party_invite.per_second,
burst_count=self.hs.config.rc_third_party_invite.burst_count,
update=True,
)
can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
- medium, address, room_id,
+ medium, address, room_id
)
if not can_invite:
raise SynapseError(
- 403, "This third-party identifier can not be invited in this room",
+ 403,
+ "This third-party identifier can not be invited in this room",
Codes.FORBIDDEN,
)
- invitee = yield self._lookup_3pid(
- id_server, medium, address
+ can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
+ medium, address, room_id
+ )
+ if not can_invite:
+ raise SynapseError(
+ 403,
+ "This third-party identifier can not be invited in this room",
+ Codes.FORBIDDEN,
+ )
+
+ if not self._enable_lookup:
+ raise SynapseError(
+ 403, "Looking up third-party identifiers is denied from this server"
+ )
+
+ invitee = yield self.identity_handler.lookup_3pid(
+ id_server, medium, address, id_access_token
)
is_published = yield self.store.is_room_published(room_id)
if not self.spam_checker.user_may_invite(
- requester.user.to_string(), invitee,
- third_party_invite={
- "medium": medium,
- "address": address,
- },
+ requester.user.to_string(),
+ invitee,
+ third_party_invite={"medium": medium, "address": address},
room_id=room_id,
new_room=new_room,
published_room=is_published,
):
logger.info("Blocking invite due to spam checker")
- raise SynapseError(
- 403, "Invites have been disabled on this server",
- )
+ raise SynapseError(403, "Invites have been disabled on this server")
if invitee:
yield self.update_membership(
- requester,
- UserID.from_string(invitee),
- room_id,
- "invite",
- txn_id=txn_id,
+ requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
yield self._make_and_store_3pid_invite(
@@ -821,53 +826,21 @@ class RoomMemberHandler(object):
address,
room_id,
inviter,
- txn_id=txn_id
+ txn_id=txn_id,
+ id_access_token=id_access_token,
)
- def _get_id_server_target(self, id_server):
- """Looks up an id_server's actual http endpoint
-
- Args:
- id_server (str): the server name to lookup.
-
- Returns:
- the http endpoint to connect to.
- """
- if id_server in self.rewrite_identity_server_urls:
- return self.rewrite_identity_server_urls[id_server]
-
- return id_server
-
- @defer.inlineCallbacks
- def _lookup_3pid(self, id_server, medium, address):
- """Looks up a 3pid in the passed identity server.
-
- Args:
- id_server (str): The server name (including port, if required)
- of the identity server to use.
- medium (str): The type of the third party identifier (e.g. "email").
- address (str): The third party identifier (e.g. "foo@example.com").
-
- Returns:
- str: the matrix ID of the 3pid, or None if it is not recognized.
- """
- try:
- data = yield self.identity_handler.lookup_3pid(id_server, medium, address)
- defer.returnValue(data.get("mxid"))
- except ProxiedRequestError as e:
- logger.warn("Error from identity server lookup: %s" % (e,))
- defer.returnValue(None)
-
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
- self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- user,
- txn_id
+ self,
+ requester,
+ id_server,
+ medium,
+ address,
+ room_id,
+ user,
+ txn_id,
+ id_access_token=None,
):
room_state = yield self.state_handler.get_current_state(room_id)
@@ -902,21 +875,25 @@ class RoomMemberHandler(object):
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
- token, public_keys, fallback_public_key, display_name = (
- yield self._ask_id_server_for_third_party_invite(
- requester=requester,
- id_server=id_server,
- medium=medium,
- address=address,
- room_id=room_id,
- inviter_user_id=user.to_string(),
- room_alias=canonical_room_alias,
- room_avatar_url=room_avatar_url,
- room_join_rules=room_join_rules,
- room_name=room_name,
- inviter_display_name=inviter_display_name,
- inviter_avatar_url=inviter_avatar_url
- )
+ (
+ token,
+ public_keys,
+ fallback_public_key,
+ display_name,
+ ) = yield self.identity_handler.ask_id_server_for_third_party_invite(
+ requester=requester,
+ id_server=id_server,
+ medium=medium,
+ address=address,
+ room_id=room_id,
+ inviter_user_id=user.to_string(),
+ room_alias=canonical_room_alias,
+ room_avatar_url=room_avatar_url,
+ room_join_rules=room_join_rules,
+ room_name=room_name,
+ inviter_display_name=inviter_display_name,
+ inviter_avatar_url=inviter_avatar_url,
+ id_access_token=id_access_token,
)
yield self.event_creation_handler.create_and_send_nonmember_event(
@@ -926,7 +903,6 @@ class RoomMemberHandler(object):
"content": {
"display_name": display_name,
"public_keys": public_keys,
-
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
@@ -940,111 +916,13 @@ class RoomMemberHandler(object):
)
@defer.inlineCallbacks
- def _ask_id_server_for_third_party_invite(
- self,
- requester,
- id_server,
- medium,
- address,
- room_id,
- inviter_user_id,
- room_alias,
- room_avatar_url,
- room_join_rules,
- room_name,
- inviter_display_name,
- inviter_avatar_url
- ):
- """
- Asks an identity server for a third party invite.
-
- Args:
- requester (Requester)
- id_server (str): hostname + optional port for the identity server.
- medium (str): The literal string "email".
- address (str): The third party address being invited.
- room_id (str): The ID of the room to which the user is invited.
- inviter_user_id (str): The user ID of the inviter.
- room_alias (str): An alias for the room, for cosmetic notifications.
- room_avatar_url (str): The URL of the room's avatar, for cosmetic
- notifications.
- room_join_rules (str): The join rules of the email (e.g. "public").
- room_name (str): The m.room.name of the room.
- inviter_display_name (str): The current display name of the
- inviter.
- inviter_avatar_url (str): The URL of the inviter's avatar.
-
- Returns:
- A deferred tuple containing:
- token (str): The token which must be signed to prove authenticity.
- public_keys ([{"public_key": str, "key_validity_url": str}]):
- public_key is a base64-encoded ed25519 public key.
- fallback_public_key: One element from public_keys.
- display_name (str): A user-friendly name to represent the invited
- user.
- """
-
- target = self._get_id_server_target(id_server)
- is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
- id_server_scheme, target,
- )
-
- invite_config = {
- "medium": medium,
- "address": address,
- "room_id": room_id,
- "room_alias": room_alias,
- "room_avatar_url": room_avatar_url,
- "room_join_rules": room_join_rules,
- "room_name": room_name,
- "sender": inviter_user_id,
- "sender_display_name": inviter_display_name,
- "sender_avatar_url": inviter_avatar_url,
- }
-
- if self.config.invite_3pid_guest:
- guest_user_id, guest_access_token = yield self.get_or_register_3pid_guest(
- requester=requester,
- medium=medium,
- address=address,
- inviter_user_id=inviter_user_id,
- )
-
- invite_config.update({
- "guest_access_token": guest_access_token,
- "guest_user_id": guest_user_id,
- })
-
- data = yield self.simple_http_client.post_urlencoded_get_json(
- is_url,
- invite_config
- )
- # TODO: Check for success
- token = data["token"]
- public_keys = data.get("public_keys", [])
- if "public_key" in data:
- fallback_public_key = {
- "public_key": data["public_key"],
- "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
- id_server_scheme, target,
- ),
- }
- else:
- fallback_public_key = public_keys[0]
-
- if not public_keys:
- public_keys.append(fallback_public_key)
- display_name = data["display_name"]
- defer.returnValue((token, public_keys, fallback_public_key, display_name))
-
- @defer.inlineCallbacks
def _is_host_in_room(self, current_state_ids):
# Have we just created the room, and is this about to be the very
# first member event?
create_event_id = current_state_ids.get(("m.room.create", ""))
if len(current_state_ids) == 1 and create_event_id:
# We can only get here if we're in the process of creating the room
- defer.returnValue(True)
+ return True
for etype, state_key in current_state_ids:
if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
@@ -1056,16 +934,16 @@ class RoomMemberHandler(object):
continue
if event.membership == Membership.JOIN:
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
@defer.inlineCallbacks
def _is_server_notice_room(self, room_id):
if self._server_notices_mxid is None:
- defer.returnValue(False)
+ return False
user_ids = yield self.store.get_users_in_room(room_id)
- defer.returnValue(self._server_notices_mxid in user_ids)
+ return self._server_notices_mxid in user_ids
class RoomMemberMasterHandler(RoomMemberHandler):
@@ -1077,13 +955,48 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
+ def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
+ """
+ Check if complexity of a remote room is too great.
+
+ Args:
+ room_id (str)
+ remote_room_hosts (list[str])
+
+ Returns: bool of whether the complexity is too great, or None
+ if unable to be fetched
+ """
+ max_complexity = self.hs.config.limit_remote_rooms.complexity
+ complexity = yield self.federation_handler.get_room_complexity(
+ remote_room_hosts, room_id
+ )
+
+ if complexity:
+ return complexity["v1"] > max_complexity
+ return None
+
+ @defer.inlineCallbacks
+ def _is_local_room_too_complex(self, room_id):
+ """
+ Check if the complexity of a local room is too great.
+
+ Args:
+ room_id (str)
+
+ Returns: bool
+ """
+ max_complexity = self.hs.config.limit_remote_rooms.complexity
+ complexity = yield self.store.get_room_complexity(room_id)
+
+ return complexity["v1"] > max_complexity
+
+ @defer.inlineCallbacks
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
# and if it is the only entry we'd like to return a 404 rather than a
# 500.
-
remote_room_hosts = [
host for host in remote_room_hosts if host != self.hs.hostname
]
@@ -1091,30 +1004,68 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
+ if self.hs.config.limit_remote_rooms.enabled:
+ # Fetch the room complexity
+ too_complex = yield self._is_remote_room_too_complex(
+ room_id, remote_room_hosts
+ )
+ if too_complex is True:
+ raise SynapseError(
+ code=400,
+ msg=self.hs.config.limit_remote_rooms.complexity_error,
+ errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+ )
+
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
- yield self.federation_handler.do_invite_join(
- remote_room_hosts,
- room_id,
- user.to_string(),
- content,
+ yield defer.ensureDeferred(
+ self.federation_handler.do_invite_join(
+ remote_room_hosts, room_id, user.to_string(), content
+ )
)
yield self._user_joined_room(user, room_id)
+ # Check the room we just joined wasn't too large, if we didn't fetch the
+ # complexity of it before.
+ if self.hs.config.limit_remote_rooms.enabled:
+ if too_complex is False:
+ # We checked, and we're under the limit.
+ return
+
+ # Check again, but with the local state events
+ too_complex = yield self._is_local_room_too_complex(room_id)
+
+ if too_complex is False:
+ # We're under the limit.
+ return
+
+ # The room is too large. Leave.
+ requester = types.create_requester(user, None, False, None)
+ yield self.update_membership(
+ requester=requester, target=user, room_id=room_id, action="leave"
+ )
+ raise SynapseError(
+ code=400,
+ msg=self.hs.config.limit_remote_rooms.complexity_error,
+ errcode=Codes.RESOURCE_LIMIT_EXCEEDED,
+ )
+
@defer.inlineCallbacks
- def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ def _remote_reject_invite(
+ self, requester, remote_room_hosts, room_id, target, content
+ ):
"""Implements RoomMemberHandler._remote_reject_invite
"""
fed_handler = self.federation_handler
try:
- ret = yield fed_handler.do_remotely_reject_invite(
- remote_room_hosts,
- room_id,
- target.to_string(),
+ ret = yield defer.ensureDeferred(
+ fed_handler.do_remotely_reject_invite(
+ remote_room_hosts, room_id, target.to_string(), content=content,
+ )
)
- defer.returnValue(ret)
+ return ret
except Exception as e:
# if we were unable to reject the exception, just mark
# it as rejected on our end and plough ahead.
@@ -1122,18 +1073,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
- logger.warn("Failed to reject invite: %s", e)
-
- yield self.store.locally_reject_invite(
- target.to_string(), room_id
- )
- defer.returnValue({})
+ logger.warning("Failed to reject invite: %s", e)
- def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
- """Implements RoomMemberHandler.get_or_register_3pid_guest
- """
- rg = self.registration_handler
- return rg.get_or_register_3pid_guest(medium, address, inviter_user_id)
+ yield self.store.locally_reject_invite(target.to_string(), room_id)
+ return {}
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
@@ -1150,18 +1093,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
- room_id=room_id,
- event_type=EventTypes.Member,
- state_key=user_id
+ room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership not in [
- Membership.LEAVE, Membership.BAN
+ Membership.LEAVE,
+ Membership.BAN,
]:
- raise SynapseError(400, "User %s in room %s" % (
- user_id, room_id
- ))
+ raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
if membership:
yield self.store.forget(user_id, room_id)
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index acc6eb8099..69be86893b 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -20,7 +20,6 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.handlers.room_member import RoomMemberHandler
from synapse.replication.http.membership import (
- ReplicationRegister3PIDGuestRestServlet as Repl3PID,
ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
ReplicationRemoteRejectInviteRestServlet as ReplRejectInvite,
ReplicationUserJoinedLeftRoomRestServlet as ReplJoinedLeft,
@@ -33,7 +32,6 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
def __init__(self, hs):
super(RoomMemberWorkerHandler, self).__init__(hs)
- self._get_register_3pid_client = Repl3PID.make_client(hs)
self._remote_join_client = ReplRemoteJoin.make_client(hs)
self._remote_reject_client = ReplRejectInvite.make_client(hs)
self._notify_change_client = ReplJoinedLeft.make_client(hs)
@@ -55,9 +53,11 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
yield self._user_joined_room(user, room_id)
- defer.returnValue(ret)
+ return ret
- def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ def _remote_reject_invite(
+ self, requester, remote_room_hosts, room_id, target, content
+ ):
"""Implements RoomMemberHandler._remote_reject_invite
"""
return self._remote_reject_client(
@@ -65,32 +65,19 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
remote_room_hosts=remote_room_hosts,
room_id=room_id,
user_id=target.to_string(),
+ content=content,
)
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
"""
return self._notify_change_client(
- user_id=target.to_string(),
- room_id=room_id,
- change="joined",
+ user_id=target.to_string(), room_id=room_id, change="joined"
)
def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room
"""
return self._notify_change_client(
- user_id=target.to_string(),
- room_id=room_id,
- change="left",
- )
-
- def get_or_register_3pid_guest(self, requester, medium, address, inviter_user_id):
- """Implements RoomMemberHandler.get_or_register_3pid_guest
- """
- return self._get_register_3pid_client(
- requester=requester,
- medium=medium,
- address=address,
- inviter_user_id=inviter_user_id,
+ user_id=target.to_string(), room_id=room_id, change="left"
)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
new file mode 100644
index 0000000000..72c109981b
--- /dev/null
+++ b/synapse/handlers/saml_handler.py
@@ -0,0 +1,408 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import re
+from typing import Tuple
+
+import attr
+import saml2
+import saml2.response
+from saml2.client import Saml2Client
+
+from synapse.api.errors import SynapseError
+from synapse.config import ConfigError
+from synapse.http.server import finish_request
+from synapse.http.servlet import parse_string
+from synapse.module_api import ModuleApi
+from synapse.types import (
+ UserID,
+ map_username_to_mxid_localpart,
+ mxid_localpart_allowed_characters,
+)
+from synapse.util.async_helpers import Linearizer
+from synapse.util.iterutils import chunk_seq
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s
+class Saml2SessionData:
+ """Data we track about SAML2 sessions"""
+
+ # time the session was created, in milliseconds
+ creation_time = attr.ib()
+
+
+class SamlHandler:
+ def __init__(self, hs):
+ self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+ self._auth_handler = hs.get_auth_handler()
+ self._registration_handler = hs.get_registration_handler()
+
+ self._clock = hs.get_clock()
+ self._datastore = hs.get_datastore()
+ self._hostname = hs.hostname
+ self._saml2_session_lifetime = hs.config.saml2_session_lifetime
+ self._grandfathered_mxid_source_attribute = (
+ hs.config.saml2_grandfathered_mxid_source_attribute
+ )
+
+ # plugin to do custom mapping from saml response to mxid
+ self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
+ hs.config.saml2_user_mapping_provider_config,
+ ModuleApi(hs, hs.get_auth_handler()),
+ )
+
+ # identifier for the external_ids table
+ self._auth_provider_id = "saml"
+
+ # a map from saml session id to Saml2SessionData object
+ self._outstanding_requests_dict = {}
+
+ # a lock on the mappings
+ self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
+
+ self._error_html_content = hs.config.saml2_error_html_content
+
+ def handle_redirect_request(self, client_redirect_url):
+ """Handle an incoming request to /login/sso/redirect
+
+ Args:
+ client_redirect_url (bytes): the URL that we should redirect the
+ client to when everything is done
+
+ Returns:
+ bytes: URL to redirect to
+ """
+ reqid, info = self._saml_client.prepare_for_authenticate(
+ relay_state=client_redirect_url
+ )
+
+ now = self._clock.time_msec()
+ self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now)
+
+ for key, value in info["headers"]:
+ if key == "Location":
+ return value
+
+ # this shouldn't happen!
+ raise Exception("prepare_for_authenticate didn't return a Location header")
+
+ async def handle_saml_response(self, request):
+ """Handle an incoming request to /_matrix/saml2/authn_response
+
+ Args:
+ request (SynapseRequest): the incoming request from the browser. We'll
+ respond to it with a redirect.
+
+ Returns:
+ Deferred[none]: Completes once we have handled the request.
+ """
+ resp_bytes = parse_string(request, "SAMLResponse", required=True)
+ relay_state = parse_string(request, "RelayState", required=True)
+
+ # expire outstanding sessions before parse_authn_request_response checks
+ # the dict.
+ self.expire_sessions()
+
+ try:
+ user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
+ except Exception as e:
+ # If decoding the response or mapping it to a user failed, then log the
+ # error and tell the user that something went wrong.
+ logger.error(e)
+
+ request.setResponseCode(400)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(
+ b"Content-Length", b"%d" % (len(self._error_html_content),)
+ )
+ request.write(self._error_html_content.encode("utf8"))
+ finish_request(request)
+ return
+
+ self._auth_handler.complete_sso_login(user_id, request, relay_state)
+
+ async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
+ try:
+ saml2_auth = self._saml_client.parse_authn_request_response(
+ resp_bytes,
+ saml2.BINDING_HTTP_POST,
+ outstanding=self._outstanding_requests_dict,
+ )
+ except Exception as e:
+ logger.warning("Exception parsing SAML2 response: %s", e)
+ raise SynapseError(400, "Unable to parse SAML2 response: %s" % (e,))
+
+ if saml2_auth.not_signed:
+ logger.warning("SAML2 response was not signed")
+ raise SynapseError(400, "SAML2 response was not signed")
+
+ logger.debug("SAML2 response: %s", saml2_auth.origxml)
+ for assertion in saml2_auth.assertions:
+ # kibana limits the length of a log field, whereas this is all rather
+ # useful, so split it up.
+ count = 0
+ for part in chunk_seq(str(assertion), 10000):
+ logger.info(
+ "SAML2 assertion: %s%s", "(%i)..." % (count,) if count else "", part
+ )
+ count += 1
+
+ logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
+
+ self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
+
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ saml2_auth, client_redirect_url
+ )
+
+ if not remote_user_id:
+ raise Exception("Failed to extract remote user id from SAML response")
+
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
+ # first of all, check if we already have a mapping for this user
+ logger.info(
+ "Looking for existing mapping for user %s:%s",
+ self._auth_provider_id,
+ remote_user_id,
+ )
+ registered_user_id = await self._datastore.get_user_by_external_id(
+ self._auth_provider_id, remote_user_id
+ )
+ if registered_user_id is not None:
+ logger.info("Found existing mapping %s", registered_user_id)
+ return registered_user_id
+
+ # backwards-compatibility hack: see if there is an existing user with a
+ # suitable mapping from the uid
+ if (
+ self._grandfathered_mxid_source_attribute
+ and self._grandfathered_mxid_source_attribute in saml2_auth.ava
+ ):
+ attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
+ user_id = UserID(
+ map_username_to_mxid_localpart(attrval), self._hostname
+ ).to_string()
+ logger.info(
+ "Looking for existing account based on mapped %s %s",
+ self._grandfathered_mxid_source_attribute,
+ user_id,
+ )
+
+ users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ if users:
+ registered_user_id = list(users.keys())[0]
+ logger.info("Grandfathering mapping to %s", registered_user_id)
+ await self._datastore.record_user_external_id(
+ self._auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
+
+ # Map saml response to user attributes using the configured mapping provider
+ for i in range(1000):
+ attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, i, client_redirect_url=client_redirect_url,
+ )
+
+ logger.debug(
+ "Retrieved SAML attributes from user mapping provider: %s "
+ "(attempt %d)",
+ attribute_dict,
+ i,
+ )
+
+ localpart = attribute_dict.get("mxid_localpart")
+ if not localpart:
+ logger.error(
+ "SAML mapping provider plugin did not return a "
+ "mxid_localpart object"
+ )
+ raise SynapseError(500, "Error parsing SAML2 response")
+
+ displayname = attribute_dict.get("displayname")
+
+ # Check if this mxid already exists
+ if not await self._datastore.get_users_by_id_case_insensitive(
+ UserID(localpart, self._hostname).to_string()
+ ):
+ # This mxid is free
+ break
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise SynapseError(
+ 500, "Unable to generate a Matrix ID from the SAML response"
+ )
+
+ logger.info("Mapped SAML user to local part %s", localpart)
+
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart, default_display_name=displayname
+ )
+
+ await self._datastore.record_user_external_id(
+ self._auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
+
+ def expire_sessions(self):
+ expire_before = self._clock.time_msec() - self._saml2_session_lifetime
+ to_expire = set()
+ for reqid, data in self._outstanding_requests_dict.items():
+ if data.creation_time < expire_before:
+ to_expire.add(reqid)
+ for reqid in to_expire:
+ logger.debug("Expiring session id %s", reqid)
+ del self._outstanding_requests_dict[reqid]
+
+
+DOT_REPLACE_PATTERN = re.compile(
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
+)
+
+
+def dot_replace_for_mxid(username: str) -> str:
+ username = username.lower()
+ username = DOT_REPLACE_PATTERN.sub(".", username)
+
+ # regular mxids aren't allowed to start with an underscore either
+ username = re.sub("^_", "", username)
+ return username
+
+
+MXID_MAPPER_MAP = {
+ "hexencode": map_username_to_mxid_localpart,
+ "dotreplace": dot_replace_for_mxid,
+}
+
+
+@attr.s
+class SamlConfig(object):
+ mxid_source_attribute = attr.ib()
+ mxid_mapper = attr.ib()
+
+
+class DefaultSamlMappingProvider(object):
+ __version__ = "0.0.1"
+
+ def __init__(self, parsed_config: SamlConfig, module_api: ModuleApi):
+ """The default SAML user mapping provider
+
+ Args:
+ parsed_config: Module configuration
+ module_api: module api proxy
+ """
+ self._mxid_source_attribute = parsed_config.mxid_source_attribute
+ self._mxid_mapper = parsed_config.mxid_mapper
+
+ self._grandfathered_mxid_source_attribute = (
+ module_api._hs.config.saml2_grandfathered_mxid_source_attribute
+ )
+
+ def get_remote_user_id(
+ self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
+ ):
+ """Extracts the remote user id from the SAML response"""
+ try:
+ return saml_response.ava["uid"][0]
+ except KeyError:
+ logger.warning("SAML2 response lacks a 'uid' attestation")
+ raise SynapseError(400, "'uid' not in SAML2 response")
+
+ def saml_response_to_user_attributes(
+ self,
+ saml_response: saml2.response.AuthnResponse,
+ failures: int,
+ client_redirect_url: str,
+ ) -> dict:
+ """Maps some text from a SAML response to attributes of a new user
+
+ Args:
+ saml_response: A SAML auth response object
+
+ failures: How many times a call to this function with this
+ saml_response has resulted in a failure
+
+ client_redirect_url: where the client wants to redirect to
+
+ Returns:
+ dict: A dict containing new user attributes. Possible keys:
+ * mxid_localpart (str): Required. The localpart of the user's mxid
+ * displayname (str): The displayname of the user
+ """
+ try:
+ mxid_source = saml_response.ava[self._mxid_source_attribute][0]
+ except KeyError:
+ logger.warning(
+ "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
+ )
+ raise SynapseError(
+ 400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
+ )
+
+ # Use the configured mapper for this mxid_source
+ base_mxid_localpart = self._mxid_mapper(mxid_source)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid
+ localpart = base_mxid_localpart + (str(failures) if failures else "")
+
+ # Retrieve the display name from the saml response
+ # If displayname is None, the mxid_localpart will be used instead
+ displayname = saml_response.ava.get("displayName", [None])[0]
+
+ return {
+ "mxid_localpart": localpart,
+ "displayname": displayname,
+ }
+
+ @staticmethod
+ def parse_config(config: dict) -> SamlConfig:
+ """Parse the dict provided by the homeserver's config
+ Args:
+ config: A dictionary containing configuration options for this provider
+ Returns:
+ SamlConfig: A custom config object for this module
+ """
+ # Parse config options and use defaults where necessary
+ mxid_source_attribute = config.get("mxid_source_attribute", "uid")
+ mapping_type = config.get("mxid_mapping", "hexencode")
+
+ # Retrieve the associating mapping function
+ try:
+ mxid_mapper = MXID_MAPPER_MAP[mapping_type]
+ except KeyError:
+ raise ConfigError(
+ "saml2_config.user_mapping_provider.config: '%s' is not a valid "
+ "mxid_mapping value" % (mapping_type,)
+ )
+
+ return SamlConfig(mxid_source_attribute, mxid_mapper)
+
+ @staticmethod
+ def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+ """Returns the required attributes of a SAML
+
+ Args:
+ config: A SamlConfig object containing configuration params for this provider
+
+ Returns:
+ tuple[set,set]: The first set equates to the saml auth response
+ attributes that are required for the module to function, whereas the
+ second set consists of those attributes which can be used if
+ available, but are not necessary
+ """
+ return {"uid", config.mxid_source_attribute}, {"displayName"}
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 9bba74d6c9..ec1542d416 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -32,10 +32,12 @@ logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
-
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -52,25 +54,40 @@ class SearchHandler(BaseHandler):
room_id (str): id of the room to search through.
Returns:
- Deferred[iterable[unicode]]: predecessor room ids
+ Deferred[iterable[str]]: predecessor room ids
"""
historical_room_ids = []
- while True:
- predecessor = yield self.store.get_room_predecessor(room_id)
+ # The initial room must have been known for us to get this far
+ predecessor = yield self.store.get_room_predecessor(room_id)
- # If no predecessor, assume we've hit a dead end
+ while True:
if not predecessor:
+ # We have reached the end of the chain of predecessors
+ break
+
+ if not isinstance(predecessor.get("room_id"), str):
+ # This predecessor object is malformed. Exit here
+ break
+
+ predecessor_room_id = predecessor["room_id"]
+
+ # Don't add it to the list until we have checked that we are in the room
+ try:
+ next_predecessor_room = yield self.store.get_room_predecessor(
+ predecessor_room_id
+ )
+ except NotFoundError:
+ # The predecessor is not a known room, so we are done here
break
- # Add predecessor's room ID
- historical_room_ids.append(predecessor["room_id"])
+ historical_room_ids.append(predecessor_room_id)
- # Scan through the old room for further predecessors
- room_id = predecessor["room_id"]
+ # And repeat
+ predecessor = next_predecessor_room
- defer.returnValue(historical_room_ids)
+ return historical_room_ids
@defer.inlineCallbacks
def search(self, user, content, batch=None):
@@ -93,7 +110,7 @@ class SearchHandler(BaseHandler):
batch_token = None
if batch:
try:
- b = decode_base64(batch).decode('ascii')
+ b = decode_base64(batch).decode("ascii")
batch_group, batch_group_key, batch_token = b.split("\n")
assert batch_group is not None
@@ -104,7 +121,9 @@ class SearchHandler(BaseHandler):
logger.info(
"Search batch properties: %r, %r, %r",
- batch_group, batch_group_key, batch_token,
+ batch_group,
+ batch_group_key,
+ batch_token,
)
logger.info("Search content: %s", content)
@@ -116,9 +135,9 @@ class SearchHandler(BaseHandler):
search_term = room_cat["search_term"]
# Which "keys" to search over in FTS query
- keys = room_cat.get("keys", [
- "content.body", "content.name", "content.topic",
- ])
+ keys = room_cat.get(
+ "keys", ["content.body", "content.name", "content.topic"]
+ )
# Filter to apply to results
filter_dict = room_cat.get("filter", {})
@@ -130,9 +149,7 @@ class SearchHandler(BaseHandler):
include_state = room_cat.get("include_state", False)
# Include context around each event?
- event_context = room_cat.get(
- "event_context", None
- )
+ event_context = room_cat.get("event_context", None)
# Group results together? May allow clients to paginate within a
# group
@@ -140,12 +157,8 @@ class SearchHandler(BaseHandler):
group_keys = [g["key"] for g in group_by]
if event_context is not None:
- before_limit = int(event_context.get(
- "before_limit", 5
- ))
- after_limit = int(event_context.get(
- "after_limit", 5
- ))
+ before_limit = int(event_context.get("before_limit", 5))
+ after_limit = int(event_context.get("after_limit", 5))
# Return the historic display name and avatar for the senders
# of the events?
@@ -159,18 +172,19 @@ class SearchHandler(BaseHandler):
if set(group_keys) - {"room_id", "sender"}:
raise SynapseError(
400,
- "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
+ "Invalid group by keys: %r"
+ % (set(group_keys) - {"room_id", "sender"},),
)
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
- rooms = yield self.store.get_rooms_for_user_where_membership_is(
+ rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
)
- room_ids = set(r.room_id for r in rooms)
+ room_ids = {r.room_id for r in rooms}
# If doing a subset of all rooms seearch, check if any of the rooms
# are from an upgraded room, and search their contents as well
@@ -190,15 +204,11 @@ class SearchHandler(BaseHandler):
room_ids.intersection_update({batch_group_key})
if not room_ids:
- defer.returnValue({
+ return {
"search_categories": {
- "room_events": {
- "results": [],
- "count": 0,
- "highlights": [],
- }
+ "room_events": {"results": [], "count": 0, "highlights": []}
}
- })
+ }
rank_map = {} # event_id -> rank of event
allowed_events = []
@@ -213,9 +223,7 @@ class SearchHandler(BaseHandler):
count = None
if order_by == "rank":
- search_result = yield self.store.search_msgs(
- room_ids, search_term, keys
- )
+ search_result = yield self.store.search_msgs(room_ids, search_term, keys)
count = search_result["count"]
@@ -231,23 +239,21 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
- self.store, user.to_string(), filtered_events
+ self.storage, user.to_string(), filtered_events
)
events.sort(key=lambda e: -rank_map[e.event_id])
- allowed_events = events[:search_filter.limit()]
+ allowed_events = events[: search_filter.limit()]
for e in allowed_events:
- rm = room_groups.setdefault(e.room_id, {
- "results": [],
- "order": rank_map[e.event_id],
- })
+ rm = room_groups.setdefault(
+ e.room_id, {"results": [], "order": rank_map[e.event_id]}
+ )
rm["results"].append(e.event_id)
- s = sender_group.setdefault(e.sender, {
- "results": [],
- "order": rank_map[e.event_id],
- })
+ s = sender_group.setdefault(
+ e.sender, {"results": [], "order": rank_map[e.event_id]}
+ )
s["results"].append(e.event_id)
elif order_by == "recent":
@@ -262,7 +268,10 @@ class SearchHandler(BaseHandler):
while len(room_events) < search_filter.limit() and i < 5:
i += 1
search_result = yield self.store.search_rooms(
- room_ids, search_term, keys, search_filter.limit() * 2,
+ room_ids,
+ search_term,
+ keys,
+ search_filter.limit() * 2,
pagination_token=pagination_token,
)
@@ -277,16 +286,14 @@ class SearchHandler(BaseHandler):
rank_map.update({r["event"].event_id: r["rank"] for r in results})
- filtered_events = search_filter.filter([
- r["event"] for r in results
- ])
+ filtered_events = search_filter.filter([r["event"] for r in results])
events = yield filter_events_for_client(
- self.store, user.to_string(), filtered_events
+ self.storage, user.to_string(), filtered_events
)
room_events.extend(events)
- room_events = room_events[:search_filter.limit()]
+ room_events = room_events[: search_filter.limit()]
if len(results) < search_filter.limit() * 2:
pagination_token = None
@@ -295,9 +302,7 @@ class SearchHandler(BaseHandler):
pagination_token = results[-1]["pagination_token"]
for event in room_events:
- group = room_groups.setdefault(event.room_id, {
- "results": [],
- })
+ group = room_groups.setdefault(event.room_id, {"results": []})
group["results"].append(event.event_id)
if room_events and len(room_events) >= search_filter.limit():
@@ -309,18 +314,23 @@ class SearchHandler(BaseHandler):
# it returns more from the same group (if applicable) rather
# than reverting to searching all results again.
if batch_group and batch_group_key:
- global_next_batch = encode_base64(("%s\n%s\n%s" % (
- batch_group, batch_group_key, pagination_token
- )).encode('ascii'))
+ global_next_batch = encode_base64(
+ (
+ "%s\n%s\n%s"
+ % (batch_group, batch_group_key, pagination_token)
+ ).encode("ascii")
+ )
else:
- global_next_batch = encode_base64(("%s\n%s\n%s" % (
- "all", "", pagination_token
- )).encode('ascii'))
+ global_next_batch = encode_base64(
+ ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
+ )
for room_id, group in room_groups.items():
- group["next_batch"] = encode_base64(("%s\n%s\n%s" % (
- "room_id", room_id, pagination_token
- )).encode('ascii'))
+ group["next_batch"] = encode_base64(
+ ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
+ "ascii"
+ )
+ )
allowed_events.extend(room_events)
@@ -338,20 +348,21 @@ class SearchHandler(BaseHandler):
contexts = {}
for event in allowed_events:
res = yield self.store.get_events_around(
- event.room_id, event.event_id, before_limit, after_limit,
+ event.room_id, event.event_id, before_limit, after_limit
)
logger.info(
"Context for search returned %d and %d events",
- len(res["events_before"]), len(res["events_after"]),
+ len(res["events_before"]),
+ len(res["events_after"]),
)
res["events_before"] = yield filter_events_for_client(
- self.store, user.to_string(), res["events_before"]
+ self.storage, user.to_string(), res["events_before"]
)
res["events_after"] = yield filter_events_for_client(
- self.store, user.to_string(), res["events_after"]
+ self.storage, user.to_string(), res["events_after"]
)
res["start"] = now_token.copy_and_replace(
@@ -363,12 +374,12 @@ class SearchHandler(BaseHandler):
).to_string()
if include_profile:
- senders = set(
+ senders = {
ev.sender
for ev in itertools.chain(
res["events_before"], [event], res["events_after"]
)
- )
+ }
if res["events_after"]:
last_event_id = res["events_after"][-1].event_id
@@ -379,7 +390,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders]
)
- state = yield self.store.get_state_for_event(
+ state = yield self.state_store.get_state_for_event(
last_event_id, state_filter
)
@@ -401,20 +412,16 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = (
- yield self._event_serializer.serialize_events(
- context["events_before"], time_now,
- )
+ context["events_before"] = yield self._event_serializer.serialize_events(
+ context["events_before"], time_now
)
- context["events_after"] = (
- yield self._event_serializer.serialize_events(
- context["events_after"], time_now,
- )
+ context["events_after"] = yield self._event_serializer.serialize_events(
+ context["events_after"], time_now
)
state_results = {}
if include_state:
- rooms = set(e.room_id for e in allowed_events)
+ rooms = {e.room_id for e in allowed_events}
for room_id in rooms:
state = yield self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
@@ -426,11 +433,15 @@ class SearchHandler(BaseHandler):
results = []
for e in allowed_events:
- results.append({
- "rank": rank_map[e.event_id],
- "result": (yield self._event_serializer.serialize_event(e, time_now)),
- "context": contexts.get(e.event_id, {}),
- })
+ results.append(
+ {
+ "rank": rank_map[e.event_id],
+ "result": (
+ yield self._event_serializer.serialize_event(e, time_now)
+ ),
+ "context": contexts.get(e.event_id, {}),
+ }
+ )
rooms_cat_res = {
"results": results,
@@ -442,7 +453,7 @@ class SearchHandler(BaseHandler):
s = {}
for room_id, state in state_results.items():
s[room_id] = yield self._event_serializer.serialize_events(
- state, time_now,
+ state, time_now
)
rooms_cat_res["state"] = s
@@ -456,8 +467,4 @@ class SearchHandler(BaseHandler):
if global_next_batch:
rooms_cat_res["next_batch"] = global_next_batch
- defer.returnValue({
- "search_categories": {
- "room_events": rooms_cat_res
- }
- })
+ return {"search_categories": {"room_events": rooms_cat_res}}
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index b556d23173..1c826b9407 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -14,10 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Optional
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.types import Requester
from ._base import BaseHandler
@@ -26,6 +28,7 @@ logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
+
def __init__(self, hs):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
@@ -33,13 +36,19 @@ class SetPasswordHandler(BaseHandler):
self._password_policy_handler = hs.get_password_policy_handler()
@defer.inlineCallbacks
- def set_password(self, user_id, newpassword, requester=None):
- self._password_policy_handler.validate_password(newpassword)
+ def set_password(
+ self,
+ user_id: str,
+ new_password: str,
+ logout_devices: bool,
+ requester: Optional[Requester] = None,
+ ):
+ if not self.hs.config.password_localdb_enabled:
+ raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
- password_hash = yield self._auth_handler.hash(newpassword)
+ self._password_policy_handler.validate_password(new_password)
- except_device_id = requester.device_id if requester else None
- except_access_token_id = requester.access_token_id if requester else None
+ password_hash = yield self._auth_handler.hash(new_password)
try:
yield self.store.user_set_password_hash(user_id, password_hash)
@@ -48,14 +57,18 @@ class SetPasswordHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
- # we want to log out all of the user's other sessions. First delete
- # all his other devices.
- yield self._device_handler.delete_all_devices_for_user(
- user_id, except_device_id=except_device_id,
- )
-
- # and now delete any access tokens which weren't associated with
- # devices (or were associated with this device).
- yield self._auth_handler.delete_access_tokens_for_user(
- user_id, except_token_id=except_access_token_id,
- )
+ # Optionally, log out all of the user's other sessions.
+ if logout_devices:
+ except_device_id = requester.device_id if requester else None
+ except_access_token_id = requester.access_token_id if requester else None
+
+ # First delete all of their other devices.
+ yield self._device_handler.delete_all_devices_for_user(
+ user_id, except_device_id=except_device_id
+ )
+
+ # and now delete any access tokens which weren't associated with
+ # devices (or were associated with this device).
+ yield self._auth_handler.delete_access_tokens_for_user(
+ user_id, except_token_id=except_access_token_id
+ )
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index b268bbcb2c..f065970c40 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -21,7 +21,6 @@ logger = logging.getLogger(__name__)
class StateDeltasHandler(object):
-
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -49,7 +48,7 @@ class StateDeltasHandler(object):
if not event and not prev_event:
logger.debug("Neither event exists: %r %r", prev_event_id, event_id)
- defer.returnValue(None)
+ return None
prev_value = None
value = None
@@ -63,8 +62,8 @@ class StateDeltasHandler(object):
logger.debug("prev_value: %r -> value: %r", prev_value, value)
if value == public_value and prev_value != public_value:
- defer.returnValue(True)
+ return True
elif value != public_value and prev_value == public_value:
- defer.returnValue(False)
+ return False
else:
- defer.returnValue(None)
+ return None
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 7ad16c8566..d93a276693 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -14,15 +14,14 @@
# limitations under the License.
import logging
+from collections import Counter
from twisted.internet import defer
-from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.constants import EventTypes, Membership
from synapse.handlers.state_deltas import StateDeltasHandler
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID
-from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -46,6 +45,8 @@ class StatsHandler(StateDeltasHandler):
self.is_mine_id = hs.is_mine_id
self.stats_bucket_size = hs.config.stats_bucket_size
+ self.stats_enabled = hs.config.stats_enabled
+
# The current position in the current_state_delta stream
self.pos = None
@@ -62,11 +63,10 @@ class StatsHandler(StateDeltasHandler):
def notify_new_event(self):
"""Called when there may be more deltas to process
"""
- if not self.hs.config.stats_enabled:
+ if not self.stats_enabled or self._is_processing:
return
- if self._is_processing:
- return
+ self._is_processing = True
@defer.inlineCallbacks
def process():
@@ -75,39 +75,83 @@ class StatsHandler(StateDeltasHandler):
finally:
self._is_processing = False
- self._is_processing = True
run_as_background_process("stats.notify_new_event", process)
@defer.inlineCallbacks
def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
- self.pos = yield self.store.get_stats_stream_pos()
-
- # If still None then the initial background update hasn't happened yet
- if self.pos is None:
- defer.returnValue(None)
+ self.pos = yield self.store.get_stats_positions()
# Loop round handling deltas until we're up to date
+
while True:
- with Measure(self.clock, "stats_delta"):
- deltas = yield self.store.get_current_state_deltas(self.pos)
- if not deltas:
- return
+ # Be sure to read the max stream_ordering *before* checking if there are any outstanding
+ # deltas, since there is otherwise a chance that we could miss updates which arrive
+ # after we check the deltas.
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+ if self.pos == room_max_stream_ordering:
+ break
+
+ logger.debug(
+ "Processing room stats %s->%s", self.pos, room_max_stream_ordering
+ )
+ max_pos, deltas = yield self.store.get_current_state_deltas(
+ self.pos, room_max_stream_ordering
+ )
+
+ if deltas:
+ logger.debug("Handling %d state deltas", len(deltas))
+ room_deltas, user_deltas = yield self._handle_deltas(deltas)
+ else:
+ room_deltas = {}
+ user_deltas = {}
+
+ # Then count deltas for total_events and total_event_bytes.
+ (
+ room_count,
+ user_count,
+ ) = yield self.store.get_changes_room_total_events_and_bytes(
+ self.pos, max_pos
+ )
+
+ for room_id, fields in room_count.items():
+ room_deltas.setdefault(room_id, {}).update(fields)
+
+ for user_id, fields in user_count.items():
+ user_deltas.setdefault(user_id, {}).update(fields)
+
+ logger.debug("room_deltas: %s", room_deltas)
+ logger.debug("user_deltas: %s", user_deltas)
+
+ # Always call this so that we update the stats position.
+ yield self.store.bulk_update_stats_delta(
+ self.clock.time_msec(),
+ updates={"room": room_deltas, "user": user_deltas},
+ stream_id=max_pos,
+ )
- logger.info("Handling %d state deltas", len(deltas))
- yield self._handle_deltas(deltas)
+ logger.debug("Handled room stats to %s -> %s", self.pos, max_pos)
- self.pos = deltas[-1]["stream_id"]
- yield self.store.update_stats_stream_pos(self.pos)
+ event_processing_positions.labels("stats").set(max_pos)
- event_processing_positions.labels("stats").set(self.pos)
+ self.pos = max_pos
@defer.inlineCallbacks
def _handle_deltas(self, deltas):
+ """Called with the state deltas to process
+
+ Returns:
+ Deferred[tuple[dict[str, Counter], dict[str, counter]]]
+ Resovles to two dicts, the room deltas and the user deltas,
+ mapping from room/user ID to changes in the various fields.
"""
- Called with the state deltas to process
- """
+
+ room_to_stats_deltas = {}
+ user_to_stats_deltas = {}
+
+ room_to_state_updates = {}
+
for delta in deltas:
typ = delta["type"]
state_key = delta["state_key"]
@@ -115,11 +159,10 @@ class StatsHandler(StateDeltasHandler):
event_id = delta["event_id"]
stream_id = delta["stream_id"]
prev_event_id = delta["prev_event_id"]
- stream_pos = delta["stream_id"]
- logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+ logger.debug("Handling: %r, %r %r, %s", room_id, typ, state_key, event_id)
- token = yield self.store.get_earliest_token_for_room_stats(room_id)
+ token = yield self.store.get_earliest_token_for_stats("room", room_id)
# If the earliest token to begin from is larger than our current
# stream ID, skip processing this delta.
@@ -131,203 +174,133 @@ class StatsHandler(StateDeltasHandler):
continue
if event_id is None and prev_event_id is None:
- # Errr...
+ logger.error(
+ "event ID is None and so is the previous event ID. stream_id: %s",
+ stream_id,
+ )
continue
event_content = {}
+ sender = None
if event_id is not None:
event = yield self.store.get_event(event_id, allow_none=True)
if event:
event_content = event.content or {}
+ sender = event.sender
- # We use stream_pos here rather than fetch by event_id as event_id
- # may be None
- now = yield self.store.get_received_ts_by_stream_pos(stream_pos)
+ # All the values in this dict are deltas (RELATIVE changes)
+ room_stats_delta = room_to_stats_deltas.setdefault(room_id, Counter())
- # quantise time to the nearest bucket
- now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size
+ room_state = room_to_state_updates.setdefault(room_id, {})
+
+ if prev_event_id is None:
+ # this state event doesn't overwrite another,
+ # so it is a new effective/current state event
+ room_stats_delta["current_state_events"] += 1
if typ == EventTypes.Member:
# we could use _get_key_change here but it's a bit inefficient
# given we're not testing for a specific result; might as well
# just grab the prev_membership and membership strings and
# compare them.
- prev_event_content = {}
+ # We take None rather than leave as a previous membership
+ # in the absence of a previous event because we do not want to
+ # reduce the leave count when a new-to-the-room user joins.
+ prev_membership = None
if prev_event_id is not None:
prev_event = yield self.store.get_event(
- prev_event_id, allow_none=True,
+ prev_event_id, allow_none=True
)
if prev_event:
prev_event_content = prev_event.content
+ prev_membership = prev_event_content.get(
+ "membership", Membership.LEAVE
+ )
membership = event_content.get("membership", Membership.LEAVE)
- prev_membership = prev_event_content.get("membership", Membership.LEAVE)
- if prev_membership == membership:
- continue
-
- if prev_membership == Membership.JOIN:
- yield self.store.update_stats_delta(
- now, "room", room_id, "joined_members", -1
- )
+ if prev_membership is None:
+ logger.debug("No previous membership for this user.")
+ elif membership == prev_membership:
+ pass # noop
+ elif prev_membership == Membership.JOIN:
+ room_stats_delta["joined_members"] -= 1
elif prev_membership == Membership.INVITE:
- yield self.store.update_stats_delta(
- now, "room", room_id, "invited_members", -1
- )
+ room_stats_delta["invited_members"] -= 1
elif prev_membership == Membership.LEAVE:
- yield self.store.update_stats_delta(
- now, "room", room_id, "left_members", -1
- )
+ room_stats_delta["left_members"] -= 1
elif prev_membership == Membership.BAN:
- yield self.store.update_stats_delta(
- now, "room", room_id, "banned_members", -1
- )
+ room_stats_delta["banned_members"] -= 1
else:
- err = "%s is not a valid prev_membership" % (repr(prev_membership),)
- logger.error(err)
- raise ValueError(err)
+ raise ValueError(
+ "%r is not a valid prev_membership" % (prev_membership,)
+ )
+ if membership == prev_membership:
+ pass # noop
if membership == Membership.JOIN:
- yield self.store.update_stats_delta(
- now, "room", room_id, "joined_members", +1
- )
+ room_stats_delta["joined_members"] += 1
elif membership == Membership.INVITE:
- yield self.store.update_stats_delta(
- now, "room", room_id, "invited_members", +1
- )
+ room_stats_delta["invited_members"] += 1
+
+ if sender and self.is_mine_id(sender):
+ user_to_stats_deltas.setdefault(sender, Counter())[
+ "invites_sent"
+ ] += 1
+
elif membership == Membership.LEAVE:
- yield self.store.update_stats_delta(
- now, "room", room_id, "left_members", +1
- )
+ room_stats_delta["left_members"] += 1
elif membership == Membership.BAN:
- yield self.store.update_stats_delta(
- now, "room", room_id, "banned_members", +1
- )
+ room_stats_delta["banned_members"] += 1
else:
- err = "%s is not a valid membership" % (repr(membership),)
- logger.error(err)
- raise ValueError(err)
+ raise ValueError("%r is not a valid membership" % (membership,))
user_id = state_key
if self.is_mine_id(user_id):
- # update user_stats as it's one of our users
- public = yield self._is_public_room(room_id)
-
- if membership == Membership.LEAVE:
- yield self.store.update_stats_delta(
- now,
- "user",
- user_id,
- "public_rooms" if public else "private_rooms",
- -1,
- )
- elif membership == Membership.JOIN:
- yield self.store.update_stats_delta(
- now,
- "user",
- user_id,
- "public_rooms" if public else "private_rooms",
- +1,
- )
-
- elif typ == EventTypes.Create:
- # Newly created room. Add it with all blank portions.
- yield self.store.update_room_state(
- room_id,
- {
- "join_rules": None,
- "history_visibility": None,
- "encryption": None,
- "name": None,
- "topic": None,
- "avatar": None,
- "canonical_alias": None,
- },
- )
+ # this accounts for transitions like leave → ban and so on.
+ has_changed_joinedness = (prev_membership == Membership.JOIN) != (
+ membership == Membership.JOIN
+ )
- elif typ == EventTypes.JoinRules:
- yield self.store.update_room_state(
- room_id, {"join_rules": event_content.get("join_rule")}
- )
+ if has_changed_joinedness:
+ delta = +1 if membership == Membership.JOIN else -1
- is_public = yield self._get_key_change(
- prev_event_id, event_id, "join_rule", JoinRules.PUBLIC
- )
- if is_public is not None:
- yield self.update_public_room_stats(now, room_id, is_public)
+ user_to_stats_deltas.setdefault(user_id, Counter())[
+ "joined_rooms"
+ ] += delta
- elif typ == EventTypes.RoomHistoryVisibility:
- yield self.store.update_room_state(
- room_id,
- {"history_visibility": event_content.get("history_visibility")},
- )
+ room_stats_delta["local_users_in_room"] += delta
- is_public = yield self._get_key_change(
- prev_event_id, event_id, "history_visibility", "world_readable"
+ elif typ == EventTypes.Create:
+ room_state["is_federatable"] = (
+ event_content.get("m.federate", True) is True
)
- if is_public is not None:
- yield self.update_public_room_stats(now, room_id, is_public)
-
- elif typ == EventTypes.Encryption:
- yield self.store.update_room_state(
- room_id, {"encryption": event_content.get("algorithm")}
+ if sender and self.is_mine_id(sender):
+ user_to_stats_deltas.setdefault(sender, Counter())[
+ "rooms_created"
+ ] += 1
+ elif typ == EventTypes.JoinRules:
+ room_state["join_rules"] = event_content.get("join_rule")
+ elif typ == EventTypes.RoomHistoryVisibility:
+ room_state["history_visibility"] = event_content.get(
+ "history_visibility"
)
+ elif typ == EventTypes.RoomEncryption:
+ room_state["encryption"] = event_content.get("algorithm")
elif typ == EventTypes.Name:
- yield self.store.update_room_state(
- room_id, {"name": event_content.get("name")}
- )
+ room_state["name"] = event_content.get("name")
elif typ == EventTypes.Topic:
- yield self.store.update_room_state(
- room_id, {"topic": event_content.get("topic")}
- )
+ room_state["topic"] = event_content.get("topic")
elif typ == EventTypes.RoomAvatar:
- yield self.store.update_room_state(
- room_id, {"avatar": event_content.get("url")}
- )
+ room_state["avatar"] = event_content.get("url")
elif typ == EventTypes.CanonicalAlias:
- yield self.store.update_room_state(
- room_id, {"canonical_alias": event_content.get("alias")}
- )
-
- @defer.inlineCallbacks
- def update_public_room_stats(self, ts, room_id, is_public):
- """
- Increment/decrement a user's number of public rooms when a room they are
- in changes to/from public visibility.
-
- Args:
- ts (int): Timestamp in seconds
- room_id (str)
- is_public (bool)
- """
- # For now, blindly iterate over all local users in the room so that
- # we can handle the whole problem of copying buckets over as needed
- user_ids = yield self.store.get_users_in_room(room_id)
-
- for user_id in user_ids:
- if self.hs.is_mine(UserID.from_string(user_id)):
- yield self.store.update_stats_delta(
- ts, "user", user_id, "public_rooms", +1 if is_public else -1
- )
- yield self.store.update_stats_delta(
- ts, "user", user_id, "private_rooms", -1 if is_public else +1
- )
+ room_state["canonical_alias"] = event_content.get("alias")
+ elif typ == EventTypes.GuestAccess:
+ room_state["guest_access"] = event_content.get("guest_access")
- @defer.inlineCallbacks
- def _is_public_room(self, room_id):
- join_rules = yield self.state.get_current_state(room_id, EventTypes.JoinRules)
- history_visibility = yield self.state.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility
- )
+ for room_id, state in room_to_state_updates.items():
+ logger.debug("Updating room_stats_state for %s: %s", room_id, state)
+ yield self.store.update_room_state(room_id, state)
- if (join_rules and join_rules.content.get("join_rule") == JoinRules.PUBLIC) or (
- (
- history_visibility
- and history_visibility.content.get("history_visibility")
- == "world_readable"
- )
- ):
- defer.returnValue(True)
- else:
- defer.returnValue(False)
+ return room_to_stats_deltas, user_to_stats_deltas
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 62fda0c664..cfd5dfc9e5 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018, 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,26 +14,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections
import itertools
import logging
+from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
from six import iteritems, itervalues
+import attr
from prometheus_client import Counter
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
+from synapse.api.filtering import FilterCollection
+from synapse.events import EventBase
+from synapse.logging.context import LoggingContext
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
-from synapse.types import RoomStreamToken
+from synapse.types import (
+ Collection,
+ JsonDict,
+ RoomStreamToken,
+ StateMap,
+ StreamToken,
+ UserID,
+)
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure, measure_func
from synapse.visibility import filter_events_for_client
@@ -64,42 +72,41 @@ LAZY_LOADED_MEMBERS_CACHE_MAX_AGE = 30 * 60 * 1000
LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE = 100
-SyncConfig = collections.namedtuple("SyncConfig", [
- "user",
- "filter_collection",
- "is_guest",
- "request_key",
- "device_id",
-])
+@attr.s(slots=True, frozen=True)
+class SyncConfig:
+ user = attr.ib(type=UserID)
+ filter_collection = attr.ib(type=FilterCollection)
+ is_guest = attr.ib(type=bool)
+ request_key = attr.ib(type=Tuple[Any, ...])
+ device_id = attr.ib(type=str)
-class TimelineBatch(collections.namedtuple("TimelineBatch", [
- "prev_batch",
- "events",
- "limited",
-])):
- __slots__ = []
+@attr.s(slots=True, frozen=True)
+class TimelineBatch:
+ prev_batch = attr.ib(type=StreamToken)
+ events = attr.ib(type=List[EventBase])
+ limited = attr.ib(bool)
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.events)
+
__bool__ = __nonzero__ # python3
-class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
- "room_id", # str
- "timeline", # TimelineBatch
- "state", # dict[(str, str), FrozenEvent]
- "ephemeral",
- "account_data",
- "unread_notifications",
- "summary",
-])):
- __slots__ = []
+@attr.s(slots=True, frozen=True)
+class JoinedSyncResult:
+ room_id = attr.ib(type=str)
+ timeline = attr.ib(type=TimelineBatch)
+ state = attr.ib(type=StateMap[EventBase])
+ ephemeral = attr.ib(type=List[JsonDict])
+ account_data = attr.ib(type=List[JsonDict])
+ unread_notifications = attr.ib(type=JsonDict)
+ summary = attr.ib(type=Optional[JsonDict])
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
@@ -111,99 +118,127 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
# nb the notification count does not, er, count: if there's nothing
# else in the result, we don't need to send it.
)
+
__bool__ = __nonzero__ # python3
-class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
- "room_id", # str
- "timeline", # TimelineBatch
- "state", # dict[(str, str), FrozenEvent]
- "account_data",
-])):
- __slots__ = []
+@attr.s(slots=True, frozen=True)
+class ArchivedSyncResult:
+ room_id = attr.ib(type=str)
+ timeline = attr.ib(type=TimelineBatch)
+ state = attr.ib(type=StateMap[EventBase])
+ account_data = attr.ib(type=List[JsonDict])
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
- return bool(
- self.timeline
- or self.state
- or self.account_data
- )
+ return bool(self.timeline or self.state or self.account_data)
+
__bool__ = __nonzero__ # python3
-class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
- "room_id", # str
- "invite", # FrozenEvent: the invite event
-])):
- __slots__ = []
+@attr.s(slots=True, frozen=True)
+class InvitedSyncResult:
+ room_id = attr.ib(type=str)
+ invite = attr.ib(type=EventBase)
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
"""Invited rooms should always be reported to the client"""
return True
+
__bool__ = __nonzero__ # python3
-class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
- "join",
- "invite",
- "leave",
-])):
- __slots__ = []
+@attr.s(slots=True, frozen=True)
+class GroupsSyncResult:
+ join = attr.ib(type=JsonDict)
+ invite = attr.ib(type=JsonDict)
+ leave = attr.ib(type=JsonDict)
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
return bool(self.join or self.invite or self.leave)
+
__bool__ = __nonzero__ # python3
-class DeviceLists(collections.namedtuple("DeviceLists", [
- "changed", # list of user_ids whose devices may have changed
- "left", # list of user_ids whose devices we no longer track
-])):
- __slots__ = []
+@attr.s(slots=True, frozen=True)
+class DeviceLists:
+ """
+ Attributes:
+ changed: List of user_ids whose devices may have changed
+ left: List of user_ids whose devices we no longer track
+ """
+
+ changed = attr.ib(type=Collection[str])
+ left = attr.ib(type=Collection[str])
- def __nonzero__(self):
+ def __nonzero__(self) -> bool:
return bool(self.changed or self.left)
+
__bool__ = __nonzero__ # python3
-class SyncResult(collections.namedtuple("SyncResult", [
- "next_batch", # Token for the next sync
- "presence", # List of presence events for the user.
- "account_data", # List of account_data events for the user.
- "joined", # JoinedSyncResult for each joined room.
- "invited", # InvitedSyncResult for each invited room.
- "archived", # ArchivedSyncResult for each archived room.
- "to_device", # List of direct messages for the device.
- "device_lists", # List of user_ids whose devices have changed
- "device_one_time_keys_count", # Dict of algorithm to count for one time keys
- # for this device
- "groups",
-])):
- __slots__ = []
-
- def __nonzero__(self):
+@attr.s
+class _RoomChanges:
+ """The set of room entries to include in the sync, plus the set of joined
+ and left room IDs since last sync.
+ """
+
+ room_entries = attr.ib(type=List["RoomSyncResultBuilder"])
+ invited = attr.ib(type=List[InvitedSyncResult])
+ newly_joined_rooms = attr.ib(type=List[str])
+ newly_left_rooms = attr.ib(type=List[str])
+
+
+@attr.s(slots=True, frozen=True)
+class SyncResult:
+ """
+ Attributes:
+ next_batch: Token for the next sync
+ presence: List of presence events for the user.
+ account_data: List of account_data events for the user.
+ joined: JoinedSyncResult for each joined room.
+ invited: InvitedSyncResult for each invited room.
+ archived: ArchivedSyncResult for each archived room.
+ to_device: List of direct messages for the device.
+ device_lists: List of user_ids whose devices have changed
+ device_one_time_keys_count: Dict of algorithm to count for one time keys
+ for this device
+ groups: Group updates, if any
+ """
+
+ next_batch = attr.ib(type=StreamToken)
+ presence = attr.ib(type=List[JsonDict])
+ account_data = attr.ib(type=List[JsonDict])
+ joined = attr.ib(type=List[JoinedSyncResult])
+ invited = attr.ib(type=List[InvitedSyncResult])
+ archived = attr.ib(type=List[ArchivedSyncResult])
+ to_device = attr.ib(type=List[JsonDict])
+ device_lists = attr.ib(type=DeviceLists)
+ device_one_time_keys_count = attr.ib(type=JsonDict)
+ groups = attr.ib(type=Optional[GroupsSyncResult])
+
+ def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
to tell if the notifier needs to wait for more events when polling for
events.
"""
return bool(
- self.presence or
- self.joined or
- self.invited or
- self.archived or
- self.account_data or
- self.to_device or
- self.device_lists or
- self.groups
+ self.presence
+ or self.joined
+ or self.invited
+ or self.archived
+ or self.account_data
+ or self.to_device
+ or self.device_lists
+ or self.groups
)
+
__bool__ = __nonzero__ # python3
class SyncHandler(object):
-
def __init__(self, hs):
self.hs_config = hs.config
self.store = hs.get_datastore()
@@ -214,38 +249,51 @@ class SyncHandler(object):
self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
+ self.storage = hs.get_storage()
+ self.state_store = self.storage.state
# ExpiringCache((User, Device)) -> LruCache(state_key => event_id)
self.lazy_loaded_members_cache = ExpiringCache(
- "lazy_loaded_members_cache", self.clock,
- max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
+ "lazy_loaded_members_cache",
+ self.clock,
+ max_len=0,
+ expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE,
)
- @defer.inlineCallbacks
- def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
- full_state=False):
+ async def wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> SyncResult:
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
- Returns:
- Deferred[SyncResult]
"""
# If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
- yield self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(user_id)
- res = yield self.response_cache.wrap(
+ res = await self.response_cache.wrap(
sync_config.request_key,
self._wait_for_sync_for_user,
- sync_config, since_token, timeout, full_state,
+ sync_config,
+ since_token,
+ timeout,
+ full_state,
)
- defer.returnValue(res)
-
- @defer.inlineCallbacks
- def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
- full_state):
+ return res
+
+ async def _wait_for_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ since_token: Optional[StreamToken] = None,
+ timeout: int = 0,
+ full_state: bool = False,
+ ) -> SyncResult:
if since_token is None:
sync_type = "initial_sync"
elif full_state:
@@ -260,15 +308,18 @@ class SyncHandler(object):
if timeout == 0 or since_token is None or full_state:
# we are going to return immediately, so don't bother calling
# notifier.wait_for_events.
- result = yield self.current_sync_for_user(
- sync_config, since_token, full_state=full_state,
+ result = await self.current_sync_for_user(
+ sync_config, since_token, full_state=full_state
)
else:
+
def current_sync_callback(before_token, after_token):
return self.current_sync_for_user(sync_config, since_token)
- result = yield self.notifier.wait_for_events(
- sync_config.user.to_string(), timeout, current_sync_callback,
+ result = await self.notifier.wait_for_events(
+ sync_config.user.to_string(),
+ timeout,
+ current_sync_callback,
from_token=since_token,
)
@@ -279,30 +330,35 @@ class SyncHandler(object):
lazy_loaded = "false"
non_empty_sync_counter.labels(sync_type, lazy_loaded).inc()
- defer.returnValue(result)
+ return result
- def current_sync_for_user(self, sync_config, since_token=None,
- full_state=False):
+ async def current_sync_for_user(
+ self,
+ sync_config: SyncConfig,
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> SyncResult:
"""Get the sync for client needed to match what the server has now.
- Returns:
- A Deferred SyncResult.
"""
- return self.generate_sync_result(sync_config, since_token, full_state)
+ return await self.generate_sync_result(sync_config, since_token, full_state)
- @defer.inlineCallbacks
- def push_rules_for_user(self, user):
+ async def push_rules_for_user(self, user: UserID) -> JsonDict:
user_id = user.to_string()
- rules = yield self.store.get_push_rules_for_user(user_id)
+ rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
- defer.returnValue(rules)
-
- @defer.inlineCallbacks
- def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
+ return rules
+
+ async def ephemeral_by_room(
+ self,
+ sync_result_builder: "SyncResultBuilder",
+ now_token: StreamToken,
+ since_token: Optional[StreamToken] = None,
+ ) -> Tuple[StreamToken, Dict[str, List[JsonDict]]]:
"""Get the ephemeral events for each room the user is in
Args:
- sync_result_builder(SyncResultBuilder)
- now_token (StreamToken): Where the server is currently up to.
- since_token (StreamToken): Where the server was when the client
+ sync_result_builder
+ now_token: Where the server is currently up to.
+ since_token: Where the server was when the client
last synced.
Returns:
A tuple of the now StreamToken, updated to reflect the which typing
@@ -318,7 +374,7 @@ class SyncHandler(object):
room_ids = sync_result_builder.joined_room_ids
typing_source = self.event_sources.sources["typing"]
- typing, typing_key = yield typing_source.get_new_events(
+ typing, typing_key = await typing_source.get_new_events(
user=sync_config.user,
from_key=typing_key,
limit=sync_config.filter_collection.ephemeral_limit(),
@@ -327,21 +383,20 @@ class SyncHandler(object):
)
now_token = now_token.copy_and_replace("typing_key", typing_key)
- ephemeral_by_room = {}
+ ephemeral_by_room = {} # type: JsonDict
for event in typing:
# we want to exclude the room_id from the event, but modifying the
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
- event_copy = {k: v for (k, v) in iteritems(event)
- if k != "room_id"}
+ event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0"
receipt_source = self.event_sources.sources["receipt"]
- receipts, receipt_key = yield receipt_source.get_new_events(
+ receipts, receipt_key = await receipt_source.get_new_events(
user=sync_config.user,
from_key=receipt_key,
limit=sync_config.filter_collection.ephemeral_limit(),
@@ -353,41 +408,56 @@ class SyncHandler(object):
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
- event_copy = {k: v for (k, v) in iteritems(event)
- if k != "room_id"}
+ event_copy = {k: v for (k, v) in iteritems(event) if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
- defer.returnValue((now_token, ephemeral_by_room))
-
- @defer.inlineCallbacks
- def _load_filtered_recents(self, room_id, sync_config, now_token,
- since_token=None, recents=None, newly_joined_room=False):
+ return now_token, ephemeral_by_room
+
+ async def _load_filtered_recents(
+ self,
+ room_id: str,
+ sync_config: SyncConfig,
+ now_token: StreamToken,
+ since_token: Optional[StreamToken] = None,
+ potential_recents: Optional[List[EventBase]] = None,
+ newly_joined_room: bool = False,
+ ) -> TimelineBatch:
"""
Returns:
a Deferred TimelineBatch
"""
with Measure(self.clock, "load_filtered_recents"):
timeline_limit = sync_config.filter_collection.timeline_limit()
- block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline()
+ block_all_timeline = (
+ sync_config.filter_collection.blocks_all_room_timeline()
+ )
- if recents is None or newly_joined_room or timeline_limit < len(recents):
+ if (
+ potential_recents is None
+ or newly_joined_room
+ or timeline_limit < len(potential_recents)
+ ):
limited = True
else:
limited = False
- if recents:
- recents = sync_config.filter_collection.filter_room_timeline(recents)
+ if potential_recents:
+ recents = sync_config.filter_collection.filter_room_timeline(
+ potential_recents
+ )
# We check if there are any state events, if there are then we pass
# all current state events to the filter_events function. This is to
# ensure that we always include current state in the timeline
- current_state_ids = frozenset()
+ current_state_ids = frozenset() # type: FrozenSet[str]
if any(e.is_state() for e in recents):
- current_state_ids = yield self.state.get_current_state_ids(room_id)
- current_state_ids = frozenset(itervalues(current_state_ids))
+ current_state_ids_map = await self.state.get_current_state_ids(
+ room_id
+ )
+ current_state_ids = frozenset(itervalues(current_state_ids_map))
- recents = yield filter_events_for_client(
- self.store,
+ recents = await filter_events_for_client(
+ self.storage,
sync_config.user.to_string(),
recents,
always_include_ids=current_state_ids,
@@ -396,11 +466,9 @@ class SyncHandler(object):
recents = []
if not limited or block_all_timeline:
- defer.returnValue(TimelineBatch(
- events=recents,
- prev_batch=now_token,
- limited=False
- ))
+ return TimelineBatch(
+ events=recents, prev_batch=now_token, limited=False
+ )
filtering_factor = 2
load_limit = max(timeline_limit * filtering_factor, 10)
@@ -419,17 +487,15 @@ class SyncHandler(object):
# Otherwise, we want to return the last N events in the room
# in toplogical ordering.
if since_key:
- events, end_key = yield self.store.get_room_events_stream_for_room(
+ events, end_key = await self.store.get_room_events_stream_for_room(
room_id,
limit=load_limit + 1,
from_key=since_key,
to_key=end_key,
)
else:
- events, end_key = yield self.store.get_recent_events_for_room(
- room_id,
- limit=load_limit + 1,
- end_token=end_key,
+ events, end_key = await self.store.get_recent_events_for_room(
+ room_id, limit=load_limit + 1, end_token=end_key
)
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
@@ -440,11 +506,13 @@ class SyncHandler(object):
# ensure that we always include current state in the timeline
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
- current_state_ids = yield self.state.get_current_state_ids(room_id)
- current_state_ids = frozenset(itervalues(current_state_ids))
+ current_state_ids_map = await self.state.get_current_state_ids(
+ room_id
+ )
+ current_state_ids = frozenset(itervalues(current_state_ids_map))
- loaded_recents = yield filter_events_for_client(
- self.store,
+ loaded_recents = await filter_events_for_client(
+ self.storage,
sync_config.user.to_string(),
loaded_recents,
always_include_ids=current_state_ids,
@@ -462,164 +530,144 @@ class SyncHandler(object):
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
- prev_batch_token = now_token.copy_and_replace(
- "room_key", room_key
- )
+ prev_batch_token = now_token.copy_and_replace("room_key", room_key)
- defer.returnValue(TimelineBatch(
+ return TimelineBatch(
events=recents,
prev_batch=prev_batch_token,
- limited=limited or newly_joined_room
- ))
+ limited=limited or newly_joined_room,
+ )
- @defer.inlineCallbacks
- def get_state_after_event(self, event, state_filter=StateFilter.all()):
+ async def get_state_after_event(
+ self, event: EventBase, state_filter: StateFilter = StateFilter.all()
+ ) -> StateMap[str]:
"""
Get the room state after the given event
Args:
- event(synapse.events.EventBase): event of interest
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- A Deferred map from ((type, state_key)->Event)
+ event: event of interest
+ state_filter: The state filter used to fetch state from the database.
"""
- state_ids = yield self.store.get_state_ids_for_event(
- event.event_id, state_filter=state_filter,
+ state_ids = await self.state_store.get_state_ids_for_event(
+ event.event_id, state_filter=state_filter
)
if event.is_state():
state_ids = state_ids.copy()
state_ids[(event.type, event.state_key)] = event.event_id
- defer.returnValue(state_ids)
-
- @defer.inlineCallbacks
- def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
+ return state_ids
+
+ async def get_state_at(
+ self,
+ room_id: str,
+ stream_position: StreamToken,
+ state_filter: StateFilter = StateFilter.all(),
+ ) -> StateMap[str]:
""" Get the room state at a particular stream position
Args:
- room_id(str): room for which to get state
- stream_position(StreamToken): point at which to get state
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- A Deferred map from ((type, state_key)->Event)
+ room_id: room for which to get state
+ stream_position: point at which to get state
+ state_filter: The state filter used to fetch state from the database.
"""
# FIXME this claims to get the state at a stream position, but
# get_recent_events_for_room operates by topo ordering. This therefore
# does not reliably give you the state at the given stream position.
# (https://github.com/matrix-org/synapse/issues/3305)
- last_events, _ = yield self.store.get_recent_events_for_room(
- room_id, end_token=stream_position.room_key, limit=1,
+ last_events, _ = await self.store.get_recent_events_for_room(
+ room_id, end_token=stream_position.room_key, limit=1
)
if last_events:
last_event = last_events[-1]
- state = yield self.get_state_after_event(
- last_event, state_filter=state_filter,
+ state = await self.get_state_after_event(
+ last_event, state_filter=state_filter
)
else:
# no events in this room - so presumably no state
state = {}
- defer.returnValue(state)
-
- @defer.inlineCallbacks
- def compute_summary(self, room_id, sync_config, batch, state, now_token):
+ return state
+
+ async def compute_summary(
+ self,
+ room_id: str,
+ sync_config: SyncConfig,
+ batch: TimelineBatch,
+ state: StateMap[EventBase],
+ now_token: StreamToken,
+ ) -> Optional[JsonDict]:
""" Works out a room summary block for this room, summarising the number
of joined members in the room, and providing the 'hero' members if the
room has no name so clients can consistently name rooms. Also adds
state events to 'state' if needed to describe the heroes.
- Args:
- room_id(str):
- sync_config(synapse.handlers.sync.SyncConfig):
- batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
- the room that will be sent to the user.
- state(dict): dict of (type, state_key) -> Event as returned by
- compute_state_delta
- now_token(str): Token of the end of the current batch.
-
- Returns:
- A deferred dict describing the room summary
+ Args
+ room_id
+ sync_config
+ batch: The timeline batch for the room that will be sent to the user.
+ state: State as returned by compute_state_delta
+ now_token: Token of the end of the current batch.
"""
# FIXME: we could/should get this from room_stats when matthew/stats lands
# FIXME: this promulgates https://github.com/matrix-org/synapse/issues/3305
- last_events, _ = yield self.store.get_recent_event_ids_for_room(
- room_id, end_token=now_token.room_key, limit=1,
+ last_events, _ = await self.store.get_recent_event_ids_for_room(
+ room_id, end_token=now_token.room_key, limit=1
)
if not last_events:
- defer.returnValue(None)
- return
+ return None
last_event = last_events[-1]
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = await self.state_store.get_state_ids_for_event(
last_event.event_id,
- state_filter=StateFilter.from_types([
- (EventTypes.Name, ''),
- (EventTypes.CanonicalAlias, ''),
- ]),
+ state_filter=StateFilter.from_types(
+ [(EventTypes.Name, ""), (EventTypes.CanonicalAlias, "")]
+ ),
)
# this is heavily cached, thus: fast.
- details = yield self.store.get_room_summary(room_id)
+ details = await self.store.get_room_summary(room_id)
- name_id = state_ids.get((EventTypes.Name, ''))
- canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ''))
+ name_id = state_ids.get((EventTypes.Name, ""))
+ canonical_alias_id = state_ids.get((EventTypes.CanonicalAlias, ""))
summary = {}
empty_ms = MemberSummary([], 0)
# TODO: only send these when they change.
- summary["m.joined_member_count"] = (
- details.get(Membership.JOIN, empty_ms).count
- )
- summary["m.invited_member_count"] = (
- details.get(Membership.INVITE, empty_ms).count
- )
+ summary["m.joined_member_count"] = details.get(Membership.JOIN, empty_ms).count
+ summary["m.invited_member_count"] = details.get(
+ Membership.INVITE, empty_ms
+ ).count
# if the room has a name or canonical_alias set, we can skip
# calculating heroes. Empty strings are falsey, so we check
# for the "name" value and default to an empty string.
if name_id:
- name = yield self.store.get_event(name_id, allow_none=True)
+ name = await self.store.get_event(name_id, allow_none=True)
if name and name.content.get("name"):
- defer.returnValue(summary)
+ return summary
if canonical_alias_id:
- canonical_alias = yield self.store.get_event(
- canonical_alias_id, allow_none=True,
+ canonical_alias = await self.store.get_event(
+ canonical_alias_id, allow_none=True
)
if canonical_alias and canonical_alias.content.get("alias"):
- defer.returnValue(summary)
+ return summary
me = sync_config.user.to_string()
joined_user_ids = [
- r[0]
- for r in details.get(Membership.JOIN, empty_ms).members
- if r[0] != me
+ r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me
]
invited_user_ids = [
- r[0]
- for r in details.get(Membership.INVITE, empty_ms).members
- if r[0] != me
+ r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me
]
- gone_user_ids = (
- [
- r[0]
- for r in details.get(Membership.LEAVE, empty_ms).members
- if r[0] != me
- ] + [
- r[0]
- for r in details.get(Membership.BAN, empty_ms).members
- if r[0] != me
- ]
- )
+ gone_user_ids = [
+ r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
+ ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
# FIXME: only build up a member_ids list for our heroes
member_ids = {}
@@ -627,23 +675,19 @@ class SyncHandler(object):
Membership.JOIN,
Membership.INVITE,
Membership.LEAVE,
- Membership.BAN
+ Membership.BAN,
):
for user_id, event_id in details.get(membership, empty_ms).members:
member_ids[user_id] = event_id
# FIXME: order by stream ordering rather than as returned by SQL
- if (joined_user_ids or invited_user_ids):
- summary['m.heroes'] = sorted(
- [user_id for user_id in (joined_user_ids + invited_user_ids)]
- )[0:5]
+ if joined_user_ids or invited_user_ids:
+ summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5]
else:
- summary['m.heroes'] = sorted(
- [user_id for user_id in gone_user_ids]
- )[0:5]
+ summary["m.heroes"] = sorted(gone_user_ids)[0:5]
if not sync_config.filter_collection.lazy_load_members():
- defer.returnValue(summary)
+ return summary
# ensure we send membership events for heroes if needed
cache_key = (sync_config.user.to_string(), sync_config.device_id)
@@ -651,10 +695,9 @@ class SyncHandler(object):
# track which members the client should already know about via LL:
# Ones which are already in state...
- existing_members = set(
- user_id for (typ, user_id) in state.keys()
- if typ == EventTypes.Member
- )
+ existing_members = {
+ user_id for (typ, user_id) in state.keys() if typ == EventTypes.Member
+ }
# ...or ones which are in the timeline...
for ev in batch.events:
@@ -664,23 +707,23 @@ class SyncHandler(object):
# ...and then ensure any missing ones get included in state.
missing_hero_event_ids = [
member_ids[hero_id]
- for hero_id in summary['m.heroes']
+ for hero_id in summary["m.heroes"]
if (
- cache.get(hero_id) != member_ids[hero_id] and
- hero_id not in existing_members
+ cache.get(hero_id) != member_ids[hero_id]
+ and hero_id not in existing_members
)
]
- missing_hero_state = yield self.store.get_events(missing_hero_event_ids)
+ missing_hero_state = await self.store.get_events(missing_hero_event_ids)
missing_hero_state = missing_hero_state.values()
for s in missing_hero_state:
cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s
- defer.returnValue(summary)
+ return summary
- def get_lazy_loaded_members_cache(self, cache_key):
+ def get_lazy_loaded_members_cache(self, cache_key: Tuple[str, str]) -> LruCache:
cache = self.lazy_loaded_members_cache.get(cache_key)
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
@@ -690,24 +733,25 @@ class SyncHandler(object):
logger.debug("found LruCache for %r", cache_key)
return cache
- @defer.inlineCallbacks
- def compute_state_delta(self, room_id, batch, sync_config, since_token, now_token,
- full_state):
+ async def compute_state_delta(
+ self,
+ room_id: str,
+ batch: TimelineBatch,
+ sync_config: SyncConfig,
+ since_token: Optional[StreamToken],
+ now_token: StreamToken,
+ full_state: bool,
+ ) -> StateMap[EventBase]:
""" Works out the difference in state between the start of the timeline
and the previous sync.
Args:
- room_id(str):
- batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
- the room that will be sent to the user.
- sync_config(synapse.handlers.sync.SyncConfig):
- since_token(str|None): Token of the end of the previous batch. May
- be None.
- now_token(str): Token of the end of the current batch.
- full_state(bool): Whether to force returning the full state.
-
- Returns:
- A deferred dict of (type, state_key) -> Event
+ room_id:
+ batch: The timeline batch for the room that will be sent to the user.
+ sync_config:
+ since_token: Token of the end of the previous batch. May be None.
+ now_token: Token of the end of the current batch.
+ full_state: Whether to force returning the full state.
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
@@ -727,10 +771,10 @@ class SyncHandler(object):
# We only request state for the members needed to display the
# timeline:
- members_to_fetch = set(
+ members_to_fetch = {
event.sender # FIXME: we also care about invite targets etc.
for event in batch.events
- )
+ }
if full_state:
# always make sure we LL ourselves so we know we're in the room
@@ -745,23 +789,23 @@ class SyncHandler(object):
timeline_state = {
(event.type, event.state_key): event.event_id
- for event in batch.events if event.is_state()
+ for event in batch.events
+ if event.is_state()
}
if full_state:
if batch:
- current_state_ids = yield self.store.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter,
+ current_state_ids = await self.state_store.get_state_ids_for_event(
+ batch.events[-1].event_id, state_filter=state_filter
)
- state_ids = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter,
+ state_ids = await self.state_store.get_state_ids_for_event(
+ batch.events[0].event_id, state_filter=state_filter
)
else:
- current_state_ids = yield self.get_state_at(
- room_id, stream_position=now_token,
- state_filter=state_filter,
+ current_state_ids = await self.get_state_at(
+ room_id, stream_position=now_token, state_filter=state_filter
)
state_ids = current_state_ids
@@ -774,9 +818,16 @@ class SyncHandler(object):
lazy_load_members=lazy_load_members,
)
elif batch.limited:
- state_at_timeline_start = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, state_filter=state_filter,
- )
+ if batch:
+ state_at_timeline_start = await self.state_store.get_state_ids_for_event(
+ batch.events[0].event_id, state_filter=state_filter
+ )
+ else:
+ # We can get here if the user has ignored the senders of all
+ # the recent events.
+ state_at_timeline_start = await self.get_state_at(
+ room_id, stream_position=now_token, state_filter=state_filter
+ )
# for now, we disable LL for gappy syncs - see
# https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346
@@ -792,14 +843,25 @@ class SyncHandler(object):
# about them).
state_filter = StateFilter.all()
- state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token,
- state_filter=state_filter,
+ # If this is an initial sync then full_state should be set, and
+ # that case is handled above. We assert here to ensure that this
+ # is indeed the case.
+ assert since_token is not None
+ state_at_previous_sync = await self.get_state_at(
+ room_id, stream_position=since_token, state_filter=state_filter
)
- current_state_ids = yield self.store.get_state_ids_for_event(
- batch.events[-1].event_id, state_filter=state_filter,
- )
+ if batch:
+ current_state_ids = await self.state_store.get_state_ids_for_event(
+ batch.events[-1].event_id, state_filter=state_filter
+ )
+ else:
+ # Its not clear how we get here, but empirically we do
+ # (#5407). Logging has been added elsewhere to try and
+ # figure out where this state comes from.
+ current_state_ids = await self.get_state_at(
+ room_id, stream_position=now_token, state_filter=state_filter
+ )
state_ids = _calculate_state(
timeline_contains=timeline_state,
@@ -821,7 +883,7 @@ class SyncHandler(object):
# So we fish out all the member events corresponding to the
# timeline here, and then dedupe any redundant ones below.
- state_ids = yield self.store.get_state_ids_for_event(
+ state_ids = await self.state_store.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
@@ -854,62 +916,62 @@ class SyncHandler(object):
# add any member IDs we are about to send into our LruCache
for t, event_id in itertools.chain(
- state_ids.items(),
- timeline_state.items(),
+ state_ids.items(), timeline_state.items()
):
if t[0] == EventTypes.Member:
cache.set(t[1], event_id)
- state = {}
+ state = {} # type: Dict[str, EventBase]
if state_ids:
- state = yield self.store.get_events(list(state_ids.values()))
+ state = await self.store.get_events(list(state_ids.values()))
- defer.returnValue({
+ return {
(e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(list(state.values()))
- })
+ for e in sync_config.filter_collection.filter_room_state(
+ list(state.values())
+ )
+ if e.type != EventTypes.Aliases # until MSC2261 or alternative solution
+ }
- @defer.inlineCallbacks
- def unread_notifs_for_room_id(self, room_id, sync_config):
+ async def unread_notifs_for_room_id(
+ self, room_id: str, sync_config: SyncConfig
+ ) -> Optional[Dict[str, str]]:
with Measure(self.clock, "unread_notifs_for_room_id"):
- last_unread_event_id = yield self.store.get_last_receipt_event_id_for_user(
+ last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
room_id=room_id,
- receipt_type="m.read"
+ receipt_type="m.read",
)
- notifs = []
if last_unread_event_id:
- notifs = yield self.store.get_unread_event_push_actions_by_room_for_user(
+ notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
- defer.returnValue(notifs)
+ return notifs
# There is no new information in this period, so your notification
# count is whatever it was last time.
- defer.returnValue(None)
-
- @defer.inlineCallbacks
- def generate_sync_result(self, sync_config, since_token=None, full_state=False):
+ return None
+
+ async def generate_sync_result(
+ self,
+ sync_config: SyncConfig,
+ since_token: Optional[StreamToken] = None,
+ full_state: bool = False,
+ ) -> SyncResult:
"""Generates a sync result.
-
- Args:
- sync_config (SyncConfig)
- since_token (StreamToken)
- full_state (bool)
-
- Returns:
- Deferred(SyncResult)
"""
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
# to query up to a given point.
# Always use the `now_token` in `SyncResultBuilder`
- now_token = yield self.event_sources.get_current_token()
+ now_token = await self.event_sources.get_current_token()
- logger.info(
+ logger.debug(
"Calculating sync response for %r between %s and %s",
- sync_config.user, since_token, now_token,
+ sync_config.user,
+ since_token,
+ now_token,
)
user_id = sync_config.user.to_string()
@@ -919,39 +981,38 @@ class SyncHandler(object):
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
- joined_room_ids = yield self.get_rooms_for_user_at(
- user_id, now_token.room_stream_id,
+ joined_room_ids = await self.get_rooms_for_user_at(
+ user_id, now_token.room_stream_id
)
-
sync_result_builder = SyncResultBuilder(
- sync_config, full_state,
+ sync_config,
+ full_state,
since_token=since_token,
now_token=now_token,
joined_room_ids=joined_room_ids,
)
- account_data_by_room = yield self._generate_sync_entry_for_account_data(
+ account_data_by_room = await self._generate_sync_entry_for_account_data(
sync_result_builder
)
- res = yield self._generate_sync_entry_for_rooms(
+ res = await self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
- since_token is None and
- sync_config.filter_collection.blocks_all_presence()
+ since_token is None and sync_config.filter_collection.blocks_all_presence()
)
if self.hs_config.use_presence and not block_all_presence_data:
- yield self._generate_sync_entry_for_presence(
+ await self._generate_sync_entry_for_presence(
sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
)
- yield self._generate_sync_entry_for_to_device(sync_result_builder)
+ await self._generate_sync_entry_for_to_device(sync_result_builder)
- device_lists = yield self._generate_sync_entry_for_device_list(
+ device_lists = await self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
newly_joined_or_invited_users=newly_joined_or_invited_users,
@@ -960,24 +1021,23 @@ class SyncHandler(object):
)
device_id = sync_config.device_id
- one_time_key_counts = {}
+ one_time_key_counts = {} # type: JsonDict
if device_id:
- one_time_key_counts = yield self.store.count_e2e_one_time_keys(
+ one_time_key_counts = await self.store.count_e2e_one_time_keys(
user_id, device_id
)
- yield self._generate_sync_entry_for_groups(sync_result_builder)
+ await self._generate_sync_entry_for_groups(sync_result_builder)
# debug for https://github.com/matrix-org/synapse/issues/4422
for joined_room in sync_result_builder.joined:
room_id = joined_room.room_id
if room_id in newly_joined_rooms:
issue4422_logger.debug(
- "Sync result for newly joined room %s: %r",
- room_id, joined_room,
+ "Sync result for newly joined room %s: %r", room_id, joined_room
)
- defer.returnValue(SyncResult(
+ return SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
joined=sync_result_builder.joined,
@@ -988,22 +1048,23 @@ class SyncHandler(object):
groups=sync_result_builder.groups,
device_one_time_keys_count=one_time_key_counts,
next_batch=sync_result_builder.now_token,
- ))
+ )
@measure_func("_generate_sync_entry_for_groups")
- @defer.inlineCallbacks
- def _generate_sync_entry_for_groups(self, sync_result_builder):
+ async def _generate_sync_entry_for_groups(
+ self, sync_result_builder: "SyncResultBuilder"
+ ) -> None:
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
if since_token and since_token.groups_key:
- results = yield self.store.get_groups_changes_for_user(
- user_id, since_token.groups_key, now_token.groups_key,
+ results = await self.store.get_groups_changes_for_user(
+ user_id, since_token.groups_key, now_token.groups_key
)
else:
- results = yield self.store.get_all_groups_for_user(
- user_id, now_token.groups_key,
+ results = await self.store.get_all_groups_for_user(
+ user_id, now_token.groups_key
)
invited = {}
@@ -1031,69 +1092,98 @@ class SyncHandler(object):
left[group_id] = content["content"]
sync_result_builder.groups = GroupsSyncResult(
- join=joined,
- invite=invited,
- leave=left,
+ join=joined, invite=invited, leave=left
)
@measure_func("_generate_sync_entry_for_device_list")
- @defer.inlineCallbacks
- def _generate_sync_entry_for_device_list(self, sync_result_builder,
- newly_joined_rooms,
- newly_joined_or_invited_users,
- newly_left_rooms, newly_left_users):
+ async def _generate_sync_entry_for_device_list(
+ self,
+ sync_result_builder: "SyncResultBuilder",
+ newly_joined_rooms: Set[str],
+ newly_joined_or_invited_users: Set[str],
+ newly_left_rooms: Set[str],
+ newly_left_users: Set[str],
+ ) -> DeviceLists:
+ """Generate the DeviceLists section of sync
+
+ Args:
+ sync_result_builder
+ newly_joined_rooms: Set of rooms user has joined since previous sync
+ newly_joined_or_invited_users: Set of users that have joined or
+ been invited to a room since previous sync.
+ newly_left_rooms: Set of rooms user has left since previous sync
+ newly_left_users: Set of users that have left a room we're in since
+ previous sync
+ """
+
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
+ # We're going to mutate these fields, so lets copy them rather than
+ # assume they won't get used later.
+ newly_joined_or_invited_users = set(newly_joined_or_invited_users)
+ newly_left_users = set(newly_left_users)
+
if since_token and since_token.device_list_key:
- changed = yield self.store.get_user_whose_devices_changed(
- since_token.device_list_key
+ # We want to figure out what user IDs the client should refetch
+ # device keys for, and which users we aren't going to track changes
+ # for anymore.
+ #
+ # For the first step we check:
+ # a. if any users we share a room with have updated their devices,
+ # and
+ # b. we also check if we've joined any new rooms, or if a user has
+ # joined a room we're in.
+ #
+ # For the second step we just find any users we no longer share a
+ # room with by looking at all users that have left a room plus users
+ # that were in a room we've left.
+
+ users_who_share_room = await self.store.get_users_who_share_room_with_user(
+ user_id
+ )
+
+ tracked_users = set(users_who_share_room)
+
+ # Always tell the user about their own devices
+ tracked_users.add(user_id)
+
+ # Step 1a, check for changes in devices of users we share a room with
+ users_that_have_changed = await self.store.get_users_whose_devices_changed(
+ since_token.device_list_key, tracked_users
)
- # TODO: Be more clever than this, i.e. remove users who we already
- # share a room with?
+ # Step 1b, check for newly joined rooms
for room_id in newly_joined_rooms:
- joined_users = yield self.state.get_current_users_in_room(room_id)
+ joined_users = await self.state.get_current_users_in_room(room_id)
newly_joined_or_invited_users.update(joined_users)
- for room_id in newly_left_rooms:
- left_users = yield self.state.get_current_users_in_room(room_id)
- newly_left_users.update(left_users)
-
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
- changed.update(newly_joined_or_invited_users)
+ users_that_have_changed.update(newly_joined_or_invited_users)
- if not changed and not newly_left_users:
- defer.returnValue(DeviceLists(
- changed=[],
- left=newly_left_users,
- ))
-
- users_who_share_room = yield self.store.get_users_who_share_room_with_user(
- user_id
+ user_signatures_changed = await self.store.get_users_whose_signatures_changed(
+ user_id, since_token.device_list_key
)
+ users_that_have_changed.update(user_signatures_changed)
- defer.returnValue(DeviceLists(
- changed=users_who_share_room & changed,
- left=set(newly_left_users) - users_who_share_room,
- ))
+ # Now find users that we no longer track
+ for room_id in newly_left_rooms:
+ left_users = await self.state.get_current_users_in_room(room_id)
+ newly_left_users.update(left_users)
+
+ # Remove any users that we still share a room with.
+ newly_left_users -= users_who_share_room
+
+ return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
else:
- defer.returnValue(DeviceLists(
- changed=[],
- left=[],
- ))
+ return DeviceLists(changed=[], left=[])
- @defer.inlineCallbacks
- def _generate_sync_entry_for_to_device(self, sync_result_builder):
+ async def _generate_sync_entry_for_to_device(
+ self, sync_result_builder: "SyncResultBuilder"
+ ) -> None:
"""Generates the portion of the sync response. Populates
`sync_result_builder` with the result.
-
- Args:
- sync_result_builder(SyncResultBuilder)
-
- Returns:
- Deferred(dict): A dictionary containing the per room account data.
"""
user_id = sync_result_builder.sync_config.user.to_string()
device_id = sync_result_builder.sync_config.device_id
@@ -1106,19 +1196,23 @@ class SyncHandler(object):
# We only delete messages when a new message comes in, but that's
# fine so long as we delete them at some point.
- deleted = yield self.store.delete_messages_for_device(
+ deleted = await self.store.delete_messages_for_device(
user_id, device_id, since_stream_id
)
- logger.debug("Deleted %d to-device messages up to %d",
- deleted, since_stream_id)
+ logger.debug(
+ "Deleted %d to-device messages up to %d", deleted, since_stream_id
+ )
- messages, stream_id = yield self.store.get_new_messages_for_device(
+ messages, stream_id = await self.store.get_new_messages_for_device(
user_id, device_id, since_stream_id, now_token.to_device_key
)
logger.debug(
"Returning %d to-device messages between %d and %d (current token: %d)",
- len(messages), since_stream_id, stream_id, now_token.to_device_key
+ len(messages),
+ since_stream_id,
+ stream_id,
+ now_token.to_device_key,
)
sync_result_builder.now_token = now_token.copy_and_replace(
"to_device_key", stream_id
@@ -1127,70 +1221,75 @@ class SyncHandler(object):
else:
sync_result_builder.to_device = []
- @defer.inlineCallbacks
- def _generate_sync_entry_for_account_data(self, sync_result_builder):
+ async def _generate_sync_entry_for_account_data(
+ self, sync_result_builder: "SyncResultBuilder"
+ ) -> Dict[str, Dict[str, JsonDict]]:
"""Generates the account data portion of the sync response. Populates
`sync_result_builder` with the result.
Args:
- sync_result_builder(SyncResultBuilder)
+ sync_result_builder
Returns:
- Deferred(dict): A dictionary containing the per room account data.
+ A dictionary containing the per room account data.
"""
sync_config = sync_result_builder.sync_config
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and not sync_result_builder.full_state:
- account_data, account_data_by_room = (
- yield self.store.get_updated_account_data_for_user(
- user_id,
- since_token.account_data_key,
- )
+ (
+ account_data,
+ account_data_by_room,
+ ) = await self.store.get_updated_account_data_for_user(
+ user_id, since_token.account_data_key
)
- push_rules_changed = yield self.store.have_push_rules_changed_for_user(
+ push_rules_changed = await self.store.have_push_rules_changed_for_user(
user_id, int(since_token.push_rules_key)
)
if push_rules_changed:
- account_data["m.push_rules"] = yield self.push_rules_for_user(
+ account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
else:
- account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(
- sync_config.user.to_string()
- )
- )
+ (
+ account_data,
+ account_data_by_room,
+ ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
- account_data['m.push_rules'] = yield self.push_rules_for_user(
+ account_data["m.push_rules"] = await self.push_rules_for_user(
sync_config.user
)
- account_data_for_user = sync_config.filter_collection.filter_account_data([
- {"type": account_data_type, "content": content}
- for account_data_type, content in account_data.items()
- ])
+ account_data_for_user = sync_config.filter_collection.filter_account_data(
+ [
+ {"type": account_data_type, "content": content}
+ for account_data_type, content in account_data.items()
+ ]
+ )
sync_result_builder.account_data = account_data_for_user
- defer.returnValue(account_data_by_room)
+ return account_data_by_room
- @defer.inlineCallbacks
- def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms,
- newly_joined_or_invited_users):
+ async def _generate_sync_entry_for_presence(
+ self,
+ sync_result_builder: "SyncResultBuilder",
+ newly_joined_rooms: Set[str],
+ newly_joined_or_invited_users: Set[str],
+ ) -> None:
"""Generates the presence portion of the sync response. Populates the
`sync_result_builder` with the result.
Args:
- sync_result_builder(SyncResultBuilder)
- newly_joined_rooms(list): List of rooms that the user has joined
- since the last sync (or empty if an initial sync)
- newly_joined_or_invited_users(list): List of users that have joined
- or been invited to rooms since the last sync (or empty if an initial
- sync)
+ sync_result_builder
+ newly_joined_rooms: Set of rooms that the user has joined since
+ the last sync (or empty if an initial sync)
+ newly_joined_or_invited_users: Set of users that have joined or
+ been invited to rooms since the last sync (or empty if an
+ initial sync)
"""
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
@@ -1206,7 +1305,7 @@ class SyncHandler(object):
presence_key = None
include_offline = False
- presence, presence_key = yield presence_source.get_new_events(
+ presence, presence_key = await presence_source.get_new_events(
user=user,
from_key=presence_key,
is_guest=sync_config.is_guest,
@@ -1218,49 +1317,48 @@ class SyncHandler(object):
extra_users_ids = set(newly_joined_or_invited_users)
for room_id in newly_joined_rooms:
- users = yield self.state.get_current_users_in_room(room_id)
+ users = await self.state.get_current_users_in_room(room_id)
extra_users_ids.update(users)
extra_users_ids.discard(user.to_string())
if extra_users_ids:
- states = yield self.presence_handler.get_states(
- extra_users_ids,
- )
+ states = await self.presence_handler.get_states(extra_users_ids)
presence.extend(states)
# Deduplicate the presence entries so that there's at most one per user
presence = list({p.user_id: p for p in presence}.values())
- presence = sync_config.filter_collection.filter_presence(
- presence
- )
+ presence = sync_config.filter_collection.filter_presence(presence)
sync_result_builder.presence = presence
- @defer.inlineCallbacks
- def _generate_sync_entry_for_rooms(self, sync_result_builder, account_data_by_room):
+ async def _generate_sync_entry_for_rooms(
+ self,
+ sync_result_builder: "SyncResultBuilder",
+ account_data_by_room: Dict[str, Dict[str, JsonDict]],
+ ) -> Tuple[Set[str], Set[str], Set[str], Set[str]]:
"""Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result.
Args:
- sync_result_builder(SyncResultBuilder)
- account_data_by_room(dict): Dictionary of per room account data
+ sync_result_builder
+ account_data_by_room: Dictionary of per room account data
Returns:
- Deferred(tuple): Returns a 4-tuple of
+ Returns a 4-tuple of
`(newly_joined_rooms, newly_joined_or_invited_users,
newly_left_rooms, newly_left_users)`
"""
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
- sync_result_builder.since_token is None and
- sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
+ sync_result_builder.since_token is None
+ and sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
)
if block_all_room_ephemeral:
- ephemeral_by_room = {}
+ ephemeral_by_room = {} # type: Dict[str, List[JsonDict]]
else:
- now_token, ephemeral_by_room = yield self.ephemeral_by_room(
+ now_token, ephemeral_by_room = await self.ephemeral_by_room(
sync_result_builder,
now_token=sync_result_builder.now_token,
since_token=sync_result_builder.since_token,
@@ -1272,18 +1370,17 @@ class SyncHandler(object):
since_token = sync_result_builder.since_token
if not sync_result_builder.full_state:
if since_token and not ephemeral_by_room and not account_data_by_room:
- have_changed = yield self._have_rooms_changed(sync_result_builder)
+ have_changed = await self._have_rooms_changed(sync_result_builder)
if not have_changed:
- tags_by_room = yield self.store.get_updated_tags(
- user_id,
- since_token.account_data_key,
+ tags_by_room = await self.store.get_updated_tags(
+ user_id, since_token.account_data_key
)
if not tags_by_room:
logger.debug("no-oping sync")
- defer.returnValue(([], [], [], []))
+ return set(), set(), set(), set()
- ignored_account_data = yield self.store.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", user_id=user_id,
+ ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
+ "m.ignored_user_list", user_id=user_id
)
if ignored_account_data:
@@ -1292,18 +1389,21 @@ class SyncHandler(object):
ignored_users = frozenset()
if since_token:
- res = yield self._get_rooms_changed(sync_result_builder, ignored_users)
- room_entries, invited, newly_joined_rooms, newly_left_rooms = res
-
- tags_by_room = yield self.store.get_updated_tags(
- user_id, since_token.account_data_key,
+ room_changes = await self._get_rooms_changed(
+ sync_result_builder, ignored_users
+ )
+ tags_by_room = await self.store.get_updated_tags(
+ user_id, since_token.account_data_key
)
else:
- res = yield self._get_all_rooms(sync_result_builder, ignored_users)
- room_entries, invited, newly_joined_rooms = res
- newly_left_rooms = []
+ room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
- tags_by_room = yield self.store.get_tags_for_user(user_id)
+ tags_by_room = await self.store.get_tags_for_user(user_id)
+
+ room_entries = room_changes.room_entries
+ invited = room_changes.invited
+ newly_joined_rooms = room_changes.newly_joined_rooms
+ newly_left_rooms = room_changes.newly_left_rooms
def handle_room_entries(room_entry):
return self._generate_room_entry(
@@ -1316,7 +1416,7 @@ class SyncHandler(object):
always_include=sync_result_builder.full_state,
)
- yield concurrently_execute(handle_room_entries, room_entries, 10)
+ await concurrently_execute(handle_room_entries, room_entries, 10)
sync_result_builder.invited.extend(invited)
@@ -1331,8 +1431,8 @@ class SyncHandler(object):
for event in it:
if event.type == EventTypes.Member:
if (
- event.membership == Membership.JOIN or
- event.membership == Membership.INVITE
+ event.membership == Membership.JOIN
+ or event.membership == Membership.INVITE
):
newly_joined_or_invited_users.add(event.state_key)
else:
@@ -1343,15 +1443,16 @@ class SyncHandler(object):
newly_left_users -= newly_joined_or_invited_users
- defer.returnValue((
- newly_joined_rooms,
+ return (
+ set(newly_joined_rooms),
newly_joined_or_invited_users,
- newly_left_rooms,
+ set(newly_left_rooms),
newly_left_users,
- ))
+ )
- @defer.inlineCallbacks
- def _have_rooms_changed(self, sync_result_builder):
+ async def _have_rooms_changed(
+ self, sync_result_builder: "SyncResultBuilder"
+ ) -> bool:
"""Returns whether there may be any new events that should be sent down
the sync. Returns True if there are.
"""
@@ -1362,36 +1463,23 @@ class SyncHandler(object):
assert since_token
# Get a list of membership change events that have happened.
- rooms_changed = yield self.store.get_membership_changes_for_user(
+ rooms_changed = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
if rooms_changed:
- defer.returnValue(True)
+ return True
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id):
- defer.returnValue(True)
- defer.returnValue(False)
+ return True
+ return False
- @defer.inlineCallbacks
- def _get_rooms_changed(self, sync_result_builder, ignored_users):
+ async def _get_rooms_changed(
+ self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str]
+ ) -> _RoomChanges:
"""Gets the the changes that have happened since the last sync.
-
- Args:
- sync_result_builder(SyncResultBuilder)
- ignored_users(set(str)): Set of users ignored by user.
-
- Returns:
- Deferred(tuple): Returns a tuple of the form:
- `(room_entries, invited_rooms, newly_joined_rooms, newly_left_rooms)`
-
- where:
- room_entries is a list [RoomSyncResultBuilder]
- invited_rooms is a list [InvitedSyncResult]
- newly_joined_rooms is a list[str] of room ids
- newly_left_rooms is a list[str] of room ids
"""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -1401,11 +1489,11 @@ class SyncHandler(object):
assert since_token
# Get a list of membership change events that have happened.
- rooms_changed = yield self.store.get_membership_changes_for_user(
+ rooms_changed = await self.store.get_membership_changes_for_user(
user_id, since_token.room_key, now_token.room_key
)
- mem_change_events_by_room_id = {}
+ mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]]
for event in rooms_changed:
mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
@@ -1414,7 +1502,7 @@ class SyncHandler(object):
room_entries = []
invited = []
for room_id, events in iteritems(mem_change_events_by_room_id):
- logger.info(
+ logger.debug(
"Membership changes in %s: [%s]",
room_id,
", ".join(("%s (%s)" % (e.event_id, e.membership) for e in events)),
@@ -1439,11 +1527,11 @@ class SyncHandler(object):
continue
if room_id in sync_result_builder.joined_room_ids or has_join:
- old_state_ids = yield self.get_state_at(room_id, since_token)
+ old_state_ids = await self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
old_mem_ev = None
if old_mem_ev_id:
- old_mem_ev = yield self.store.get_event(
+ old_mem_ev = await self.store.get_event(
old_mem_ev_id, allow_none=True
)
@@ -1454,7 +1542,9 @@ class SyncHandler(object):
prev_membership = old_mem_ev.membership
issue4422_logger.debug(
"Previous membership for room %s with join: %s (event %s)",
- room_id, prev_membership, old_mem_ev_id,
+ room_id,
+ prev_membership,
+ old_mem_ev_id,
)
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
@@ -1474,14 +1564,13 @@ class SyncHandler(object):
newly_left_rooms.append(room_id)
else:
if not old_state_ids:
- old_state_ids = yield self.get_state_at(room_id, since_token)
+ old_state_ids = await self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get(
- (EventTypes.Member, user_id),
- None,
+ (EventTypes.Member, user_id), None
)
old_mem_ev = None
if old_mem_ev_id:
- old_mem_ev = yield self.store.get_event(
+ old_mem_ev = await self.store.get_event(
old_mem_ev_id, allow_none=True
)
if old_mem_ev and old_mem_ev.membership == Membership.JOIN:
@@ -1498,13 +1587,14 @@ class SyncHandler(object):
# Always include leave/ban events. Just take the last one.
# TODO: How do we handle ban -> leave in same batch?
leave_events = [
- e for e in non_joins
+ e
+ for e in non_joins
if e.membership in (Membership.LEAVE, Membership.BAN)
]
if leave_events:
leave_event = leave_events[-1]
- leave_stream_token = yield self.store.get_stream_token_for_event(
+ leave_stream_token = await self.store.get_stream_token_for_event(
leave_event.event_id
)
leave_token = since_token.copy_and_replace(
@@ -1522,24 +1612,26 @@ class SyncHandler(object):
# This is all screaming out for a refactor, as the logic here is
# subtle and the moving parts numerous.
if leave_event.internal_metadata.is_out_of_band_membership():
- batch_events = [leave_event]
+ batch_events = [leave_event] # type: Optional[List[EventBase]]
else:
batch_events = None
- room_entries.append(RoomSyncResultBuilder(
- room_id=room_id,
- rtype="archived",
- events=batch_events,
- newly_joined=room_id in newly_joined_rooms,
- full_state=False,
- since_token=since_token,
- upto_token=leave_token,
- ))
+ room_entries.append(
+ RoomSyncResultBuilder(
+ room_id=room_id,
+ rtype="archived",
+ events=batch_events,
+ newly_joined=room_id in newly_joined_rooms,
+ full_state=False,
+ since_token=since_token,
+ upto_token=leave_token,
+ )
+ )
timeline_limit = sync_config.filter_collection.timeline_limit()
# Get all events for rooms we're currently joined to.
- room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=sync_result_builder.joined_room_ids,
from_key=since_token.room_key,
to_key=now_token.room_key,
@@ -1581,23 +1673,22 @@ class SyncHandler(object):
# debugging for https://github.com/matrix-org/synapse/issues/4422
issue4422_logger.debug(
"RoomSyncResultBuilder events for newly joined room %s: %r",
- room_id, entry.events,
+ room_id,
+ entry.events,
)
room_entries.append(entry)
- defer.returnValue((room_entries, invited, newly_joined_rooms, newly_left_rooms))
+ return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
- @defer.inlineCallbacks
- def _get_all_rooms(self, sync_result_builder, ignored_users):
+ async def _get_all_rooms(
+ self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str]
+ ) -> _RoomChanges:
"""Returns entries for all rooms for the user.
Args:
- sync_result_builder(SyncResultBuilder)
- ignored_users(set(str)): Set of users ignored by user.
+ sync_result_builder
+ ignored_users: Set of users ignored by user.
- Returns:
- Deferred(tuple): Returns a tuple of the form:
- `([RoomSyncResultBuilder], [InvitedSyncResult], [])`
"""
user_id = sync_result_builder.sync_config.user.to_string()
@@ -1606,12 +1697,14 @@ class SyncHandler(object):
sync_config = sync_result_builder.sync_config
membership_list = (
- Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN
+ Membership.INVITE,
+ Membership.JOIN,
+ Membership.LEAVE,
+ Membership.BAN,
)
- room_list = yield self.store.get_rooms_for_user_where_membership_is(
- user_id=user_id,
- membership_list=membership_list
+ room_list = await self.store.get_rooms_for_local_user_where_membership_is(
+ user_id=user_id, membership_list=membership_list
)
room_entries = []
@@ -1619,23 +1712,22 @@ class SyncHandler(object):
for event in room_list:
if event.membership == Membership.JOIN:
- room_entries.append(RoomSyncResultBuilder(
- room_id=event.room_id,
- rtype="joined",
- events=None,
- newly_joined=False,
- full_state=True,
- since_token=since_token,
- upto_token=now_token,
- ))
+ room_entries.append(
+ RoomSyncResultBuilder(
+ room_id=event.room_id,
+ rtype="joined",
+ events=None,
+ newly_joined=False,
+ full_state=True,
+ since_token=since_token,
+ upto_token=now_token,
+ )
+ )
elif event.membership == Membership.INVITE:
if event.sender in ignored_users:
continue
- invite = yield self.store.get_event(event.event_id)
- invited.append(InvitedSyncResult(
- room_id=event.room_id,
- invite=invite,
- ))
+ invite = await self.store.get_event(event.event_id)
+ invited.append(InvitedSyncResult(room_id=event.room_id, invite=invite))
elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave:
@@ -1646,41 +1738,47 @@ class SyncHandler(object):
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
- room_entries.append(RoomSyncResultBuilder(
- room_id=event.room_id,
- rtype="archived",
- events=None,
- newly_joined=False,
- full_state=True,
- since_token=since_token,
- upto_token=leave_token,
- ))
-
- defer.returnValue((room_entries, invited, []))
+ room_entries.append(
+ RoomSyncResultBuilder(
+ room_id=event.room_id,
+ rtype="archived",
+ events=None,
+ newly_joined=False,
+ full_state=True,
+ since_token=since_token,
+ upto_token=leave_token,
+ )
+ )
- @defer.inlineCallbacks
- def _generate_room_entry(self, sync_result_builder, ignored_users,
- room_builder, ephemeral, tags, account_data,
- always_include=False):
+ return _RoomChanges(room_entries, invited, [], [])
+
+ async def _generate_room_entry(
+ self,
+ sync_result_builder: "SyncResultBuilder",
+ ignored_users: Set[str],
+ room_builder: "RoomSyncResultBuilder",
+ ephemeral: List[JsonDict],
+ tags: Optional[List[JsonDict]],
+ account_data: Dict[str, JsonDict],
+ always_include: bool = False,
+ ):
"""Populates the `joined` and `archived` section of `sync_result_builder`
based on the `room_builder`.
Args:
- sync_result_builder(SyncResultBuilder)
- ignored_users(set(str)): Set of users ignored by user.
- room_builder(RoomSyncResultBuilder)
- ephemeral(list): List of new ephemeral events for room
- tags(list): List of *all* tags for room, or None if there has been
+ sync_result_builder
+ ignored_users: Set of users ignored by user.
+ room_builder
+ ephemeral: List of new ephemeral events for room
+ tags: List of *all* tags for room, or None if there has been
no change.
- account_data(list): List of new account data for room
- always_include(bool): Always include this room in the sync response,
+ account_data: List of new account data for room
+ always_include: Always include this room in the sync response,
even if empty.
"""
newly_joined = room_builder.newly_joined
full_state = (
- room_builder.full_state
- or newly_joined
- or sync_result_builder.full_state
+ room_builder.full_state or newly_joined or sync_result_builder.full_state
)
events = room_builder.events
@@ -1696,19 +1794,25 @@ class SyncHandler(object):
since_token = room_builder.since_token
upto_token = room_builder.upto_token
- batch = yield self._load_filtered_recents(
- room_id, sync_config,
+ batch = await self._load_filtered_recents(
+ room_id,
+ sync_config,
now_token=upto_token,
since_token=since_token,
- recents=events,
+ potential_recents=events,
newly_joined_room=newly_joined,
)
+ # Note: `batch` can be both empty and limited here in the case where
+ # `_load_filtered_recents` can't find any events the user should see
+ # (e.g. due to having ignored the sender of the last 50 events).
+
if newly_joined:
# debug for https://github.com/matrix-org/synapse/issues/4422
issue4422_logger.debug(
"Timeline events after filtering in newly-joined room %s: %r",
- room_id, batch,
+ room_id,
+ batch,
)
# When we join the room (or the client requests full_state), we should
@@ -1717,7 +1821,7 @@ class SyncHandler(object):
# tag was added by synapse e.g. for server notice rooms.
if full_state:
user_id = sync_result_builder.sync_config.user.to_string()
- tags = yield self.store.get_tags_for_room(user_id, room_id)
+ tags = await self.store.get_tags_for_room(user_id, room_id)
# If there aren't any tags, don't send the empty tags list down
# sync
@@ -1726,16 +1830,10 @@ class SyncHandler(object):
account_data_events = []
if tags is not None:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
+ account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
+ account_data_events.append({"type": account_data_type, "content": content})
account_data_events = sync_config.filter_collection.filter_room_account_data(
account_data_events
@@ -1743,46 +1841,40 @@ class SyncHandler(object):
ephemeral = sync_config.filter_collection.filter_room_ephemeral(ephemeral)
- if not (always_include
- or batch
- or account_data_events
- or ephemeral
- or full_state):
+ if not (
+ always_include or batch or account_data_events or ephemeral or full_state
+ ):
return
- state = yield self.compute_state_delta(
- room_id, batch, sync_config, since_token, now_token,
- full_state=full_state
+ state = await self.compute_state_delta(
+ room_id, batch, sync_config, since_token, now_token, full_state=full_state
)
- summary = {}
+ summary = {} # type: Optional[JsonDict]
# we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form
# the name itself).
- if (
- sync_config.filter_collection.lazy_load_members() and
- (
- # we recalulate the summary:
- # if there are membership changes in the timeline, or
- # if membership has changed during a gappy sync, or
- # if this is an initial sync.
- any(ev.type == EventTypes.Member for ev in batch.events) or
- (
- # XXX: this may include false positives in the form of LL
- # members which have snuck into state
- batch.limited and
- any(t == EventTypes.Member for (t, k) in state)
- ) or
- since_token is None
+ if sync_config.filter_collection.lazy_load_members() and (
+ # we recalulate the summary:
+ # if there are membership changes in the timeline, or
+ # if membership has changed during a gappy sync, or
+ # if this is an initial sync.
+ any(ev.type == EventTypes.Member for ev in batch.events)
+ or (
+ # XXX: this may include false positives in the form of LL
+ # members which have snuck into state
+ batch.limited
+ and any(t == EventTypes.Member for (t, k) in state)
)
+ or since_token is None
):
- summary = yield self.compute_summary(
+ summary = await self.compute_summary(
room_id, sync_config, batch, state, now_token
)
if room_builder.rtype == "joined":
- unread_notifications = {}
+ unread_notifications = {} # type: Dict[str, str]
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -1794,9 +1886,7 @@ class SyncHandler(object):
)
if room_sync or always_include:
- notifs = yield self.unread_notifs_for_room_id(
- room_id, sync_config
- )
+ notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
if notifs is not None:
unread_notifications["notification_count"] = notifs["notify_count"]
@@ -1806,27 +1896,25 @@ class SyncHandler(object):
if batch.limited and since_token:
user_id = sync_result_builder.sync_config.user.to_string()
- logger.info(
- "Incremental gappy sync of %s for user %s with %d state events" % (
- room_id,
- user_id,
- len(state),
- )
+ logger.debug(
+ "Incremental gappy sync of %s for user %s with %d state events"
+ % (room_id, user_id, len(state))
)
elif room_builder.rtype == "archived":
- room_sync = ArchivedSyncResult(
+ archived_room_sync = ArchivedSyncResult(
room_id=room_id,
timeline=batch,
state=state,
account_data=account_data_events,
)
- if room_sync or always_include:
- sync_result_builder.archived.append(room_sync)
+ if archived_room_sync or always_include:
+ sync_result_builder.archived.append(archived_room_sync)
else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
- @defer.inlineCallbacks
- def get_rooms_for_user_at(self, user_id, stream_ordering):
+ async def get_rooms_for_user_at(
+ self, user_id: str, stream_ordering: int
+ ) -> FrozenSet[str]:
"""Get set of joined rooms for a user at the given stream ordering.
The stream ordering *must* be recent, otherwise this may throw an
@@ -1834,16 +1922,13 @@ class SyncHandler(object):
current token, which should be perfectly fine).
Args:
- user_id (str)
- stream_ordering (int)
+ user_id
+ stream_ordering
ReturnValue:
- Deferred[frozenset[str]]: Set of room_ids the user is in at given
- stream_ordering.
+ Set of room_ids the user is in at given stream_ordering.
"""
- joined_rooms = yield self.store.get_rooms_for_user_with_stream_ordering(
- user_id,
- )
+ joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id)
joined_room_ids = set()
@@ -1861,20 +1946,17 @@ class SyncHandler(object):
logger.info("User joined room after current token: %s", room_id)
- extrems = yield self.store.get_forward_extremeties_for_room(
- room_id, stream_ordering,
- )
- users_in_room = yield self.state.get_current_users_in_room(
- room_id, extrems,
+ extrems = await self.store.get_forward_extremeties_for_room(
+ room_id, stream_ordering
)
+ users_in_room = await self.state.get_current_users_in_room(room_id, extrems)
if user_id in users_in_room:
joined_room_ids.add(room_id)
- joined_room_ids = frozenset(joined_room_ids)
- defer.returnValue(joined_room_ids)
+ return frozenset(joined_room_ids)
-def _action_has_highlight(actions):
+def _action_has_highlight(actions: List[JsonDict]) -> bool:
for action in actions:
try:
if action.get("set_tweak", None) == "highlight":
@@ -1886,22 +1968,23 @@ def _action_has_highlight(actions):
def _calculate_state(
- timeline_contains, timeline_start, previous, current, lazy_load_members,
-):
+ timeline_contains: StateMap[str],
+ timeline_start: StateMap[str],
+ previous: StateMap[str],
+ current: StateMap[str],
+ lazy_load_members: bool,
+) -> StateMap[str]:
"""Works out what state to include in a sync response.
Args:
- timeline_contains (dict): state in the timeline
- timeline_start (dict): state at the start of the timeline
- previous (dict): state at the end of the previous sync (or empty dict
+ timeline_contains: state in the timeline
+ timeline_start: state at the start of the timeline
+ previous: state at the end of the previous sync (or empty dict
if this is an initial sync)
- current (dict): state at the end of the timeline
- lazy_load_members (bool): whether to return members from timeline_start
+ current: state at the end of the timeline
+ lazy_load_members: whether to return members from timeline_start
or not. assumes that timeline_start has already been filtered to
include only the members the client needs to know about.
-
- Returns:
- dict
"""
event_id_to_key = {
e: key
@@ -1913,10 +1996,10 @@ def _calculate_state(
)
}
- c_ids = set(e for e in itervalues(current))
- ts_ids = set(e for e in itervalues(timeline_start))
- p_ids = set(e for e in itervalues(previous))
- tc_ids = set(e for e in itervalues(timeline_contains))
+ c_ids = set(itervalues(current))
+ ts_ids = set(itervalues(timeline_start))
+ p_ids = set(itervalues(previous))
+ tc_ids = set(itervalues(timeline_contains))
# If we are lazyloading room members, we explicitly add the membership events
# for the senders in the timeline into the state block returned by /sync,
@@ -1930,26 +2013,24 @@ def _calculate_state(
if lazy_load_members:
p_ids.difference_update(
- e for t, e in iteritems(timeline_start)
- if t[0] == EventTypes.Member
+ e for t, e in iteritems(timeline_start) if t[0] == EventTypes.Member
)
state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
- return {
- event_id_to_key[e]: e for e in state_ids
- }
+ return {event_id_to_key[e]: e for e in state_ids}
-class SyncResultBuilder(object):
+@attr.s
+class SyncResultBuilder:
"""Used to help build up a new SyncResult for a user
Attributes:
- sync_config (SyncConfig)
- full_state (bool)
- since_token (StreamToken)
- now_token (StreamToken)
- joined_room_ids (list[str])
+ sync_config
+ full_state: The full_state flag as specified by user
+ since_token: The token supplied by user, or None.
+ now_token: The token to sync up to.
+ joined_room_ids: List of rooms the user is joined to
# The following mirror the fields in a sync response
presence (list)
@@ -1957,57 +2038,45 @@ class SyncResultBuilder(object):
joined (list[JoinedSyncResult])
invited (list[InvitedSyncResult])
archived (list[ArchivedSyncResult])
- device (list)
groups (GroupsSyncResult|None)
to_device (list)
"""
- def __init__(self, sync_config, full_state, since_token, now_token,
- joined_room_ids):
- """
- Args:
- sync_config (SyncConfig)
- full_state (bool): The full_state flag as specified by user
- since_token (StreamToken): The token supplied by user, or None.
- now_token (StreamToken): The token to sync up to.
- joined_room_ids (list[str]): List of rooms the user is joined to
- """
- self.sync_config = sync_config
- self.full_state = full_state
- self.since_token = since_token
- self.now_token = now_token
- self.joined_room_ids = joined_room_ids
-
- self.presence = []
- self.account_data = []
- self.joined = []
- self.invited = []
- self.archived = []
- self.device = []
- self.groups = None
- self.to_device = []
+ sync_config = attr.ib(type=SyncConfig)
+ full_state = attr.ib(type=bool)
+ since_token = attr.ib(type=Optional[StreamToken])
+ now_token = attr.ib(type=StreamToken)
+ joined_room_ids = attr.ib(type=FrozenSet[str])
+
+ presence = attr.ib(type=List[JsonDict], default=attr.Factory(list))
+ account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
+ joined = attr.ib(type=List[JoinedSyncResult], default=attr.Factory(list))
+ invited = attr.ib(type=List[InvitedSyncResult], default=attr.Factory(list))
+ archived = attr.ib(type=List[ArchivedSyncResult], default=attr.Factory(list))
+ groups = attr.ib(type=Optional[GroupsSyncResult], default=None)
+ to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
+
+@attr.s
class RoomSyncResultBuilder(object):
"""Stores information needed to create either a `JoinedSyncResult` or
`ArchivedSyncResult`.
+
+ Attributes:
+ room_id
+ rtype: One of `"joined"` or `"archived"`
+ events: List of events to include in the room (more events may be added
+ when generating result).
+ newly_joined: If the user has newly joined the room
+ full_state: Whether the full state should be sent in result
+ since_token: Earliest point to return events from, or None
+ upto_token: Latest point to return events from.
"""
- def __init__(self, room_id, rtype, events, newly_joined, full_state,
- since_token, upto_token):
- """
- Args:
- room_id(str)
- rtype(str): One of `"joined"` or `"archived"`
- events(list[FrozenEvent]): List of events to include in the room
- (more events may be added when generating result).
- newly_joined(bool): If the user has newly joined the room
- full_state(bool): Whether the full state should be sent in result
- since_token(StreamToken): Earliest point to return events from, or None
- upto_token(StreamToken): Latest point to return events from.
- """
- self.room_id = room_id
- self.rtype = rtype
- self.events = events
- self.newly_joined = newly_joined
- self.full_state = full_state
- self.since_token = since_token
- self.upto_token = upto_token
+
+ room_id = attr.ib(type=str)
+ rtype = attr.ib(type=str)
+ events = attr.ib(type=Optional[List[EventBase]])
+ newly_joined = attr.ib(type=bool)
+ full_state = attr.ib(type=bool)
+ since_token = attr.ib(type=Optional[StreamToken])
+ upto_token = attr.ib(type=StreamToken)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 972662eb48..391bceb0c4 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -19,9 +19,9 @@ from collections import namedtuple
from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
+from synapse.logging.context import run_in_background
from synapse.types import UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -68,13 +68,10 @@ class TypingHandler(object):
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
- "TypingStreamChangeCache", self._latest_room_serial,
+ "TypingStreamChangeCache", self._latest_room_serial
)
- self.clock.looping_call(
- self._handle_timeouts,
- 5000,
- )
+ self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self):
"""
@@ -86,7 +83,7 @@ class TypingHandler(object):
self._room_typing = {}
def _handle_timeouts(self):
- logger.info("Checking for typing timeouts")
+ logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()
@@ -108,19 +105,11 @@ class TypingHandler(object):
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
- run_in_background(
- self._push_remote,
- member=member,
- typing=True
- )
+ run_in_background(self._push_remote, member=member, typing=True)
# Add a paranoia timer to ensure that we always have a timer for
# each person typing.
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + 60 * 1000,
- )
+ self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
@@ -131,16 +120,14 @@ class TypingHandler(object):
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
- yield self.auth.check_joined_room(room_id, target_user_id)
+ yield self.auth.check_user_in_room(room_id, target_user_id)
- logger.debug(
- "%s has started typing in %s", target_user_id, room_id
- )
+ logger.debug("%s has started typing in %s", target_user_id, room_id)
member = RoomMember(room_id=room_id, user_id=target_user_id)
@@ -149,20 +136,13 @@ class TypingHandler(object):
now = self.clock.time_msec()
self._member_typing_until[member] = now + timeout
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + timeout,
- )
+ self.wheel_timer.insert(now=now, obj=member, then=now + timeout)
if was_present:
# No point sending another notification
- defer.returnValue(None)
+ return None
- self._push_update(
- member=member,
- typing=True,
- )
+ self._push_update(member=member, typing=True)
@defer.inlineCallbacks
def stopped_typing(self, target_user, auth_user, room_id):
@@ -170,16 +150,14 @@ class TypingHandler(object):
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
- yield self.auth.check_joined_room(room_id, target_user_id)
+ yield self.auth.check_user_in_room(room_id, target_user_id)
- logger.debug(
- "%s has stopped typing in %s", target_user_id, room_id
- )
+ logger.debug("%s has stopped typing in %s", target_user_id, room_id)
member = RoomMember(room_id=room_id, user_id=target_user_id)
@@ -195,25 +173,19 @@ class TypingHandler(object):
def _stopped_typing(self, member):
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
- defer.returnValue(None)
+ return None
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
- self._push_update(
- member=member,
- typing=False,
- )
+ self._push_update(member=member, typing=False)
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_in_background(self._push_remote, member, typing)
- self._push_update_local(
- member=member,
- typing=typing
- )
+ self._push_update_local(member=member, typing=typing)
@defer.inlineCallbacks
def _push_remote(self, member, typing):
@@ -223,12 +195,10 @@ class TypingHandler(object):
now = self.clock.time_msec()
self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_PING_INTERVAL,
+ now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)
- for domain in set(get_domain_from_id(u) for u in users):
+ for domain in {get_domain_from_id(u) for u in users}:
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
@@ -256,27 +226,19 @@ class TypingHandler(object):
if user.domain != origin:
logger.info(
- "Got typing update from %r with bad 'user_id': %r",
- origin, user_id,
+ "Got typing update from %r with bad 'user_id': %r", origin, user_id
)
return
users = yield self.state.get_current_users_in_room(room_id)
- domains = set(get_domain_from_id(u) for u in users)
+ domains = {get_domain_from_id(u) for u in users}
if self.server_name in domains:
logger.info("Got typing update from %s: %r", user_id, content)
now = self.clock.time_msec()
self._member_typing_until[member] = now + FEDERATION_TIMEOUT
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_TIMEOUT,
- )
- self._push_update_local(
- member=member,
- typing=content["typing"]
- )
+ self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
+ self._push_update_local(member=member, typing=content["typing"])
def _push_update_local(self, member, typing):
room_set = self._room_typing.setdefault(member.room_id, set())
@@ -288,19 +250,19 @@ class TypingHandler(object):
self._latest_room_serial += 1
self._room_serials[member.room_id] = self._latest_room_serial
self._typing_stream_change_cache.entity_has_changed(
- member.room_id, self._latest_room_serial,
+ member.room_id, self._latest_room_serial
)
self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
- def get_all_typing_updates(self, last_id, current_id):
+ async def get_all_typing_updates(self, last_id, current_id):
if last_id == current_id:
return []
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
- last_id,
+ last_id
)
if changed_rooms is None:
@@ -334,9 +296,7 @@ class TypingNotificationEventSource(object):
return {
"type": "m.typing",
"room_id": room_id,
- "content": {
- "user_ids": list(typing),
- },
+ "content": {"user_ids": list(typing)},
}
def get_new_events(self, from_key, room_ids, **kwargs):
@@ -353,10 +313,7 @@ class TypingNotificationEventSource(object):
events.append(self._make_event_for(room_id))
- return events, handler._latest_room_serial
+ return defer.succeed((events, handler._latest_room_serial))
def get_current_key(self):
return self.get_typing_handler()._latest_room_serial
-
- def get_pagination_rows(self, user, pagination_config, key):
- return ([], pagination_config.from_key)
diff --git a/synapse/handlers/ui_auth/__init__.py b/synapse/handlers/ui_auth/__init__.py
new file mode 100644
index 0000000000..824f37f8f8
--- /dev/null
+++ b/synapse/handlers/ui_auth/__init__.py
@@ -0,0 +1,22 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module implements user-interactive auth verification.
+
+TODO: move more stuff out of AuthHandler in here.
+
+"""
+
+from synapse.handlers.ui_auth.checkers import INTERACTIVE_AUTH_CHECKERS # noqa: F401
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
new file mode 100644
index 0000000000..8363d887a9
--- /dev/null
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -0,0 +1,247 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from canonicaljson import json
+
+from twisted.internet import defer
+from twisted.web.client import PartialDownloadError
+
+from synapse.api.constants import LoginType
+from synapse.api.errors import Codes, LoginError, SynapseError
+from synapse.config.emailconfig import ThreepidBehaviour
+
+logger = logging.getLogger(__name__)
+
+
+class UserInteractiveAuthChecker:
+ """Abstract base class for an interactive auth checker"""
+
+ def __init__(self, hs):
+ pass
+
+ def is_enabled(self):
+ """Check if the configuration of the homeserver allows this checker to work
+
+ Returns:
+ bool: True if this login type is enabled.
+ """
+
+ def check_auth(self, authdict, clientip):
+ """Given the authentication dict from the client, attempt to check this step
+
+ Args:
+ authdict (dict): authentication dictionary from the client
+ clientip (str): The IP address of the client.
+
+ Raises:
+ SynapseError if authentication failed
+
+ Returns:
+ Deferred: the result of authentication (to pass back to the client?)
+ """
+ raise NotImplementedError()
+
+
+class DummyAuthChecker(UserInteractiveAuthChecker):
+ AUTH_TYPE = LoginType.DUMMY
+
+ def is_enabled(self):
+ return True
+
+ def check_auth(self, authdict, clientip):
+ return defer.succeed(True)
+
+
+class TermsAuthChecker(UserInteractiveAuthChecker):
+ AUTH_TYPE = LoginType.TERMS
+
+ def is_enabled(self):
+ return True
+
+ def check_auth(self, authdict, clientip):
+ return defer.succeed(True)
+
+
+class RecaptchaAuthChecker(UserInteractiveAuthChecker):
+ AUTH_TYPE = LoginType.RECAPTCHA
+
+ def __init__(self, hs):
+ super().__init__(hs)
+ self._enabled = bool(hs.config.recaptcha_private_key)
+ self._http_client = hs.get_proxied_http_client()
+ self._url = hs.config.recaptcha_siteverify_api
+ self._secret = hs.config.recaptcha_private_key
+
+ def is_enabled(self):
+ return self._enabled
+
+ @defer.inlineCallbacks
+ def check_auth(self, authdict, clientip):
+ try:
+ user_response = authdict["response"]
+ except KeyError:
+ # Client tried to provide captcha but didn't give the parameter:
+ # bad request.
+ raise LoginError(
+ 400, "Captcha response is required", errcode=Codes.CAPTCHA_NEEDED
+ )
+
+ logger.info(
+ "Submitting recaptcha response %s with remoteip %s", user_response, clientip
+ )
+
+ # TODO: get this from the homeserver rather than creating a new one for
+ # each request
+ try:
+ resp_body = yield self._http_client.post_urlencoded_get_json(
+ self._url,
+ args={
+ "secret": self._secret,
+ "response": user_response,
+ "remoteip": clientip,
+ },
+ )
+ except PartialDownloadError as pde:
+ # Twisted is silly
+ data = pde.response
+ resp_body = json.loads(data)
+
+ if "success" in resp_body:
+ # Note that we do NOT check the hostname here: we explicitly
+ # intend the CAPTCHA to be presented by whatever client the
+ # user is using, we just care that they have completed a CAPTCHA.
+ logger.info(
+ "%s reCAPTCHA from hostname %s",
+ "Successful" if resp_body["success"] else "Failed",
+ resp_body.get("hostname"),
+ )
+ if resp_body["success"]:
+ return True
+ raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+
+
+class _BaseThreepidAuthChecker:
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def _check_threepid(self, medium, authdict):
+ if "threepid_creds" not in authdict:
+ raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
+
+ threepid_creds = authdict["threepid_creds"]
+
+ identity_handler = self.hs.get_handlers().identity_handler
+
+ logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
+
+ # msisdns are currently always ThreepidBehaviour.REMOTE
+ if medium == "msisdn":
+ if not self.hs.config.account_threepid_delegate_msisdn:
+ raise SynapseError(
+ 400, "Phone number verification is not enabled on this homeserver"
+ )
+ threepid = yield identity_handler.threepid_from_creds(
+ self.hs.config.account_threepid_delegate_msisdn, threepid_creds
+ )
+ elif medium == "email":
+ if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ assert self.hs.config.account_threepid_delegate_email
+ threepid = yield identity_handler.threepid_from_creds(
+ self.hs.config.account_threepid_delegate_email, threepid_creds
+ )
+ elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ threepid = None
+ row = yield self.store.get_threepid_validation_session(
+ medium,
+ threepid_creds["client_secret"],
+ sid=threepid_creds["sid"],
+ validated=True,
+ )
+
+ if row:
+ threepid = {
+ "medium": row["medium"],
+ "address": row["address"],
+ "validated_at": row["validated_at"],
+ }
+
+ # Valid threepid returned, delete from the db
+ yield self.store.delete_threepid_session(threepid_creds["sid"])
+ else:
+ raise SynapseError(
+ 400, "Email address verification is not enabled on this homeserver"
+ )
+ else:
+ # this can't happen!
+ raise AssertionError("Unrecognized threepid medium: %s" % (medium,))
+
+ if not threepid:
+ raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
+
+ if threepid["medium"] != medium:
+ raise LoginError(
+ 401,
+ "Expecting threepid of type '%s', got '%s'"
+ % (medium, threepid["medium"]),
+ errcode=Codes.UNAUTHORIZED,
+ )
+
+ threepid["threepid_creds"] = authdict["threepid_creds"]
+
+ return threepid
+
+
+class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
+ AUTH_TYPE = LoginType.EMAIL_IDENTITY
+
+ def __init__(self, hs):
+ UserInteractiveAuthChecker.__init__(self, hs)
+ _BaseThreepidAuthChecker.__init__(self, hs)
+
+ def is_enabled(self):
+ return self.hs.config.threepid_behaviour_email in (
+ ThreepidBehaviour.REMOTE,
+ ThreepidBehaviour.LOCAL,
+ )
+
+ def check_auth(self, authdict, clientip):
+ return self._check_threepid("email", authdict)
+
+
+class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker):
+ AUTH_TYPE = LoginType.MSISDN
+
+ def __init__(self, hs):
+ UserInteractiveAuthChecker.__init__(self, hs)
+ _BaseThreepidAuthChecker.__init__(self, hs)
+
+ def is_enabled(self):
+ return bool(self.hs.config.account_threepid_delegate_msisdn)
+
+ def check_auth(self, authdict, clientip):
+ return self._check_threepid("msisdn", authdict)
+
+
+INTERACTIVE_AUTH_CHECKERS = [
+ DummyAuthChecker,
+ TermsAuthChecker,
+ RecaptchaAuthChecker,
+ EmailIdentityAuthChecker,
+ MsisdnAuthChecker,
+]
+"""A list of UserInteractiveAuthChecker classes"""
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 5de9630950..722760c59d 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -52,6 +52,7 @@ class UserDirectoryHandler(StateDeltasHandler):
self.is_mine_id = hs.is_mine_id
self.update_user_directory = hs.config.update_user_directory
self.search_all_users = hs.config.user_directory_search_all_users
+ self.spam_checker = hs.get_spam_checker()
# The current position in the current_state_delta stream
self.pos = None
@@ -65,7 +66,7 @@ class UserDirectoryHandler(StateDeltasHandler):
# we start populating the user directory
self.clock.call_later(0, self.notify_new_event)
- def search_users(self, user_id, search_term, limit):
+ async def search_users(self, user_id, search_term, limit):
"""Searches for users in directory
Returns:
@@ -82,7 +83,16 @@ class UserDirectoryHandler(StateDeltasHandler):
]
}
"""
- return self.store.search_user_dir(user_id, search_term, limit)
+ results = await self.store.search_user_dir(user_id, search_term, limit)
+
+ # Remove any spammy users from the results.
+ results["results"] = [
+ user
+ for user in results["results"]
+ if not self.spam_checker.check_username_for_spam(user)
+ ]
+
+ return results
def notify_new_event(self):
"""Called when there may be more deltas to process
@@ -133,26 +143,33 @@ class UserDirectoryHandler(StateDeltasHandler):
# If still None then the initial background update hasn't happened yet
if self.pos is None:
- defer.returnValue(None)
+ return None
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
- deltas = yield self.store.get_current_state_deltas(self.pos)
- if not deltas:
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+ if self.pos == room_max_stream_ordering:
return
- logger.info("Handling %d state deltas", len(deltas))
+ logger.debug(
+ "Processing user stats %s->%s", self.pos, room_max_stream_ordering
+ )
+ max_pos, deltas = yield self.store.get_current_state_deltas(
+ self.pos, room_max_stream_ordering
+ )
+
+ logger.debug("Handling %d state deltas", len(deltas))
yield self._handle_deltas(deltas)
- self.pos = deltas[-1]["stream_id"]
+ self.pos = max_pos
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("user_dir").set(
- self.pos
+ max_pos
)
- yield self.store.update_user_directory_stream_pos(self.pos)
+ yield self.store.update_user_directory_stream_pos(max_pos)
@defer.inlineCallbacks
def _handle_deltas(self, deltas):
@@ -188,7 +205,7 @@ class UserDirectoryHandler(StateDeltasHandler):
room_id, self.server_name
)
if not is_in_room:
- logger.info("Server left room: %r", room_id)
+ logger.debug("Server left room: %r", room_id)
# Fetch all the users that we marked as being in user
# directory due to being in the room and then check if
# need to remove those users or not
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index d36bcd6336..3880ce0d94 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -25,6 +25,7 @@ from synapse.api.errors import SynapseError
class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request"""
+
def __init__(self):
super(RequestTimedOutError, self).__init__(504, "Timed out")
@@ -40,15 +41,14 @@ def cancelled_to_request_timed_out_error(value, timeout):
return value
-ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
+CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")
def redact_uri(uri):
- """Strips access tokens from the uri replaces with <redacted>"""
- return ACCESS_TOKEN_RE.sub(
- r'\1<redacted>\3',
- uri
- )
+ """Strips sensitive information from the uri replaces with <redacted>"""
+ uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
+ return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
class QuieterFileBodyProducer(FileBodyProducer):
@@ -57,6 +57,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
Workaround for https://github.com/matrix-org/synapse/issues/4003 /
https://twistedmatrix.com/trac/ticket/6528
"""
+
def stopProducing(self):
try:
FileBodyProducer.stopProducing(self)
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 0e10e3f8f7..096619a8c2 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -28,6 +28,7 @@ class AdditionalResource(Resource):
This class is also where we wrap the request handler with logging, metrics,
and exception handling.
"""
+
def __init__(self, hs, handler):
"""Initialise AdditionalResource
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 896e71cef3..3797545824 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -17,7 +17,7 @@
import logging
from io import BytesIO
-from six import text_type
+from six import raise_from, text_type
from six.moves import urllib
import treq
@@ -35,7 +35,7 @@ from twisted.internet.interfaces import (
)
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
-from twisted.web.client import Agent, HTTPConnectionPool, PartialDownloadError, readBody
+from twisted.web.client import Agent, HTTPConnectionPool, readBody
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
@@ -46,9 +46,10 @@ from synapse.http import (
redact_uri,
)
from synapse.http.proxyagent import ProxyAgent
+from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util.async_helpers import timeout_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
-from synapse.util.logcontext import make_deferred_yieldable
logger = logging.getLogger(__name__)
@@ -104,8 +105,8 @@ class IPBlacklistingResolver(object):
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
- "Dropped %s from DNS resolution to %s due to blacklist" %
- (ip_address, hostname)
+ "Dropped %s from DNS resolution to %s due to blacklist"
+ % (ip_address, hostname)
)
has_bad_ip = True
@@ -157,7 +158,7 @@ class BlacklistingAgentWrapper(Agent):
self._ip_blacklist = ip_blacklist
def request(self, method, uri, headers=None, bodyProducer=None):
- h = urllib.parse.urlparse(uri.decode('ascii'))
+ h = urllib.parse.urlparse(uri.decode("ascii"))
try:
ip_address = IPAddress(h.hostname)
@@ -165,10 +166,7 @@ class BlacklistingAgentWrapper(Agent):
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
- logger.info(
- "Blocking access to %s due to blacklist" %
- (ip_address,)
- )
+ logger.info("Blocking access to %s due to blacklist" % (ip_address,))
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
except Exception:
@@ -217,7 +215,7 @@ class SimpleHttpClient(object):
if hs.config.user_agent_suffix:
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix)
- self.user_agent = self.user_agent.encode('ascii')
+ self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
real_reactor = hs.get_reactor()
@@ -282,42 +280,56 @@ class SimpleHttpClient(object):
# log request but strip `access_token` (AS requests for example include this)
logger.info("Sending request %s %s", method, redact_uri(uri))
- try:
- body_producer = None
- if data is not None:
- body_producer = QuieterFileBodyProducer(BytesIO(data))
-
- request_deferred = treq.request(
- method,
- uri,
- agent=self.agent,
- data=body_producer,
- headers=headers,
- **self._extra_treq_args
- )
- request_deferred = timeout_deferred(
- request_deferred,
- 60,
- self.hs.get_reactor(),
- cancelled_to_request_timed_out_error,
- )
- response = yield make_deferred_yieldable(request_deferred)
+ with start_active_span(
+ "outgoing-client-request",
+ tags={
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT,
+ tags.HTTP_METHOD: method,
+ tags.HTTP_URL: uri,
+ },
+ finish_on_close=True,
+ ):
+ try:
+ body_producer = None
+ if data is not None:
+ body_producer = QuieterFileBodyProducer(BytesIO(data))
+
+ request_deferred = treq.request(
+ method,
+ uri,
+ agent=self.agent,
+ data=body_producer,
+ headers=headers,
+ **self._extra_treq_args
+ )
+ request_deferred = timeout_deferred(
+ request_deferred,
+ 60,
+ self.hs.get_reactor(),
+ cancelled_to_request_timed_out_error,
+ )
+ response = yield make_deferred_yieldable(request_deferred)
- incoming_responses_counter.labels(method, response.code).inc()
- logger.info(
- "Received response to %s %s: %s", method, redact_uri(uri), response.code
- )
- defer.returnValue(response)
- except Exception as e:
- incoming_responses_counter.labels(method, "ERR").inc()
- logger.info(
- "Error sending request to %s %s: %s %s",
- method,
- redact_uri(uri),
- type(e).__name__,
- e.args[0],
- )
- raise
+ incoming_responses_counter.labels(method, response.code).inc()
+ logger.info(
+ "Received response to %s %s: %s",
+ method,
+ redact_uri(uri),
+ response.code,
+ )
+ return response
+ except Exception as e:
+ incoming_responses_counter.labels(method, "ERR").inc()
+ logger.info(
+ "Error sending request to %s %s: %s %s",
+ method,
+ redact_uri(uri),
+ type(e).__name__,
+ e.args[0],
+ )
+ set_tag(tags.ERROR, True)
+ set_tag("error_reason", e.args[0])
+ raise
@defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}, headers=None):
@@ -325,7 +337,7 @@ class SimpleHttpClient(object):
Args:
uri (str):
args (dict[str, str|List[str]]): query params
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
@@ -358,7 +370,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(json.loads(body))
+ return json.loads(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -369,7 +381,7 @@ class SimpleHttpClient(object):
Args:
uri (str):
post_json (object):
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
@@ -398,7 +410,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(json.loads(body))
+ return json.loads(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -412,7 +424,7 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@@ -423,7 +435,7 @@ class SimpleHttpClient(object):
ValueError: if the response was not JSON
"""
body = yield self.get_raw(uri, args, headers=headers)
- defer.returnValue(json.loads(body))
+ return json.loads(body)
@defer.inlineCallbacks
def put_json(self, uri, json_body, args={}, headers=None):
@@ -436,7 +448,7 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@@ -466,7 +478,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(json.loads(body))
+ return json.loads(body)
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -480,7 +492,7 @@ class SimpleHttpClient(object):
None.
**Note**: The value of each key is assumed to be an iterable
and *not* a string.
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
Deferred: Succeeds when we get *any* 2xx HTTP response, with the
@@ -501,7 +513,7 @@ class SimpleHttpClient(object):
body = yield make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
- defer.returnValue(body)
+ return body
else:
raise HttpResponseException(response.code, response.phrase, body)
@@ -514,7 +526,7 @@ class SimpleHttpClient(object):
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
- headers (dict[str, List[str]]|None): If not None, a map from
+ headers (dict[str|bytes, List[str|bytes]]|None): If not None, a map from
header name to a list of values for that header
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
@@ -530,10 +542,10 @@ class SimpleHttpClient(object):
resp_headers = dict(response.headers.getAllRawHeaders())
if (
- b'Content-Length' in resp_headers
- and int(resp_headers[b'Content-Length'][0]) > max_size
+ b"Content-Length" in resp_headers
+ and int(resp_headers[b"Content-Length"][0]) > max_size
):
- logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
+ logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
@@ -541,7 +553,7 @@ class SimpleHttpClient(object):
)
if response.code > 299:
- logger.warn("Got %d when downloading %s" % (response.code, url))
+ logger.warning("Got %d when downloading %s" % (response.code, url))
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
# TODO: if our Content-Type is HTML or something, just read the first
@@ -552,19 +564,17 @@ class SimpleHttpClient(object):
length = yield make_deferred_yieldable(
_readBodyToFile(response, output_stream, max_size)
)
+ except SynapseError:
+ # This can happen e.g. because the body is too large.
+ raise
except Exception as e:
- logger.exception("Failed to download body")
- raise SynapseError(
- 502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN
- )
+ raise_from(SynapseError(502, ("Failed to download remote body: %s" % e)), e)
- defer.returnValue(
- (
- length,
- resp_headers,
- response.request.absoluteURI.decode('ascii'),
- response.code,
- )
+ return (
+ length,
+ resp_headers,
+ response.request.absoluteURI.decode("ascii"),
+ response.code,
)
@@ -614,45 +624,13 @@ def _readBodyToFile(response, stream, max_size):
return d
-class CaptchaServerHttpClient(SimpleHttpClient):
- """
- Separate HTTP client for talking to google's captcha servers
- Only slightly special because accepts partial download responses
-
- used only by c/s api v1
- """
-
- @defer.inlineCallbacks
- def post_urlencoded_get_raw(self, url, args={}):
- query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True)
-
- response = yield self.request(
- "POST",
- url,
- data=query_bytes,
- headers=Headers(
- {
- b"Content-Type": [b"application/x-www-form-urlencoded"],
- b"User-Agent": [self.user_agent],
- }
- ),
- )
-
- try:
- body = yield make_deferred_yieldable(readBody(response))
- defer.returnValue(body)
- except PartialDownloadError as e:
- # twisted dislikes google's response, no content length.
- defer.returnValue(e.response)
-
-
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
def encode_urlencode_arg(arg):
if isinstance(arg, text_type):
- return arg.encode('utf-8')
+ return arg.encode("utf-8")
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index cd79ebab62..92a5b606c8 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -31,7 +31,7 @@ def parse_server_name(server_name):
ValueError if the server name could not be parsed.
"""
try:
- if server_name[-1] == ']':
+ if server_name[-1] == "]":
# ipv6 literal, hopefully
return server_name, None
@@ -43,9 +43,7 @@ def parse_server_name(server_name):
raise ValueError("Invalid server name '%s'" % server_name)
-VALID_HOST_REGEX = re.compile(
- "\\A[0-9a-zA-Z.-]+\\Z",
-)
+VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z")
def parse_and_validate_server_name(server_name):
@@ -67,17 +65,15 @@ def parse_and_validate_server_name(server_name):
# that nobody is sneaking IP literals in that look like hostnames, etc.
# look for ipv6 literals
- if host[0] == '[':
- if host[-1] != ']':
- raise ValueError("Mismatched [...] in server name '%s'" % (
- server_name,
- ))
+ if host[0] == "[":
+ if host[-1] != "]":
+ raise ValueError("Mismatched [...] in server name '%s'" % (server_name,))
return host, port
# otherwise it should only be alphanumerics.
if not VALID_HOST_REGEX.match(host):
- raise ValueError("Server name '%s' contains invalid characters" % (
- server_name,
- ))
+ raise ValueError(
+ "Server name '%s' contains invalid characters" % (server_name,)
+ )
return host, port
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index f595349a0e..f5f917f5ae 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -12,49 +12,33 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
+
import logging
-import random
-import time
+import urllib
-import attr
-from netaddr import IPAddress
+from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
-from twisted.web.http import stringToDatetime
+from twisted.web.client import Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent
+from twisted.web.iweb import IAgent, IAgentEndpointFactory
-from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list
+from synapse.http.federation.srv_resolver import Server, SrvResolver
+from synapse.http.federation.well_known_resolver import WellKnownResolver
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util import Clock
-from synapse.util.caches.ttlcache import TTLCache
-from synapse.util.logcontext import make_deferred_yieldable
-from synapse.util.metrics import Measure
-
-# period to cache .well-known results for by default
-WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
-
-# jitter to add to the .well-known default cache ttl
-WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60
-
-# period to cache failure to fetch .well-known for
-WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
-
-# cap for .well-known cache period
-WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
logger = logging.getLogger(__name__)
-well_known_cache = TTLCache('well-known')
@implementer(IAgent)
class MatrixFederationAgent(object):
- """An Agent-like thing which provides a `request` method which will look up a matrix
- server and send an HTTP request to it.
+ """An Agent-like thing which provides a `request` method which correctly
+ handles resolving matrix server names when using matrix://. Handles standard
+ https URIs as normal.
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
@@ -68,63 +52,59 @@ class MatrixFederationAgent(object):
SRVResolver impl to use for looking up SRV records. None to use a default
implementation.
- _well_known_cache (TTLCache|None):
- TTLCache impl for storing cached well-known lookups. None to use a default
- implementation.
+ _well_known_resolver (WellKnownResolver|None):
+ WellKnownResolver to use to perform well-known lookups. None to use a
+ default implementation.
"""
def __init__(
- self, reactor, tls_client_options_factory,
- _well_known_tls_policy=None,
+ self,
+ reactor,
+ tls_client_options_factory,
_srv_resolver=None,
- _well_known_cache=well_known_cache,
+ _well_known_resolver=None,
):
self._reactor = reactor
self._clock = Clock(reactor)
-
- self._tls_client_options_factory = tls_client_options_factory
- if _srv_resolver is None:
- _srv_resolver = SrvResolver()
- self._srv_resolver = _srv_resolver
-
self._pool = HTTPConnectionPool(reactor)
self._pool.retryAutomatically = False
self._pool.maxPersistentPerHost = 5
self._pool.cachedConnectionTimeout = 2 * 60
- _well_known_agent = RedirectAgent(
- Agent(
+ self._agent = Agent.usingEndpointFactory(
+ self._reactor,
+ MatrixHostnameEndpointFactory(
+ reactor, tls_client_options_factory, _srv_resolver
+ ),
+ pool=self._pool,
+ )
+
+ if _well_known_resolver is None:
+ _well_known_resolver = WellKnownResolver(
self._reactor,
- pool=self._pool,
- contextFactory=tls_client_options_factory,
+ agent=Agent(
+ self._reactor,
+ pool=self._pool,
+ contextFactory=tls_client_options_factory,
+ ),
)
- )
- self._well_known_agent = _well_known_agent
- # our cache of .well-known lookup results, mapping from server name
- # to delegated name. The values can be:
- # `bytes`: a valid server-name
- # `None`: there is no (valid) .well-known here
- self._well_known_cache = _well_known_cache
+ self._well_known_resolver = _well_known_resolver
@defer.inlineCallbacks
def request(self, method, uri, headers=None, bodyProducer=None):
"""
Args:
method (bytes): HTTP method: GET/POST/etc
-
uri (bytes): Absolute URI to be retrieved
-
headers (twisted.web.http_headers.Headers|None):
HTTP headers to send with the request, or None to
send no extra headers.
-
bodyProducer (twisted.web.iweb.IBodyProducer|None):
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
no body.
-
Returns:
Deferred[twisted.web.iweb.IResponse]:
fires when the header of the response has been received (regardless of the
@@ -132,320 +112,207 @@ class MatrixFederationAgent(object):
response from being received (including problems that prevent the request
from being sent).
"""
- parsed_uri = URI.fromBytes(uri, defaultPort=-1)
- res = yield self._route_matrix_uri(parsed_uri)
+ # We use urlparse as that will set `port` to None if there is no
+ # explicit port.
+ parsed_uri = urllib.parse.urlparse(uri)
- # set up the TLS connection params
+ # If this is a matrix:// URI check if the server has delegated matrix
+ # traffic using well-known delegation.
#
- # XXX disabling TLS is really only supported here for the benefit of the
- # unit tests. We should make the UTs cope with TLS rather than having to make
- # the code support the unit tests.
- if self._tls_client_options_factory is None:
- tls_options = None
- else:
- tls_options = self._tls_client_options_factory.get_options(
- res.tls_server_name.decode("ascii"),
+ # We have to do this here and not in the endpoint as we need to rewrite
+ # the host header with the delegated server name.
+ delegated_server = None
+ if (
+ parsed_uri.scheme == b"matrix"
+ and not _is_ip_literal(parsed_uri.hostname)
+ and not parsed_uri.port
+ ):
+ well_known_result = yield self._well_known_resolver.get_well_known(
+ parsed_uri.hostname
+ )
+ delegated_server = well_known_result.delegated_server
+
+ if delegated_server:
+ # Ok, the server has delegated matrix traffic to somewhere else, so
+ # lets rewrite the URL to replace the server with the delegated
+ # server name.
+ uri = urllib.parse.urlunparse(
+ (
+ parsed_uri.scheme,
+ delegated_server,
+ parsed_uri.path,
+ parsed_uri.params,
+ parsed_uri.query,
+ parsed_uri.fragment,
+ )
)
+ parsed_uri = urllib.parse.urlparse(uri)
- # make sure that the Host header is set correctly
+ # We need to make sure the host header is set to the netloc of the
+ # server.
if headers is None:
headers = Headers()
else:
headers = headers.copy()
- if not headers.hasHeader(b'host'):
- headers.addRawHeader(b'host', res.host_header)
-
- class EndpointFactory(object):
- @staticmethod
- def endpointForURI(_uri):
- ep = LoggingHostnameEndpoint(
- self._reactor, res.target_host, res.target_port,
- )
- if tls_options is not None:
- ep = wrapClientTLS(tls_options, ep)
- return ep
+ if not headers.hasHeader(b"host"):
+ headers.addRawHeader(b"host", parsed_uri.netloc)
- agent = Agent.usingEndpointFactory(self._reactor, EndpointFactory(), self._pool)
res = yield make_deferred_yieldable(
- agent.request(method, uri, headers, bodyProducer)
+ self._agent.request(method, uri, headers, bodyProducer)
)
- defer.returnValue(res)
- @defer.inlineCallbacks
- def _route_matrix_uri(self, parsed_uri, lookup_well_known=True):
- """Helper for `request`: determine the routing for a Matrix URI
-
- Args:
- parsed_uri (twisted.web.client.URI): uri to route. Note that it should be
- parsed with URI.fromBytes(uri, defaultPort=-1) to set the `port` to -1
- if there is no explicit port given.
+ return res
- lookup_well_known (bool): True if we should look up the .well-known file if
- there is no SRV record.
-
- Returns:
- Deferred[_RoutingResult]
- """
- # check for an IP literal
- try:
- ip_address = IPAddress(parsed_uri.host.decode("ascii"))
- except Exception:
- # not an IP address
- ip_address = None
-
- if ip_address:
- port = parsed_uri.port
- if port == -1:
- port = 8448
- defer.returnValue(_RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=port,
- ))
-
- if parsed_uri.port != -1:
- # there is an explicit port
- defer.returnValue(_RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=parsed_uri.host,
- target_port=parsed_uri.port,
- ))
-
- if lookup_well_known:
- # try a .well-known lookup
- well_known_server = yield self._get_well_known(parsed_uri.host)
-
- if well_known_server:
- # if we found a .well-known, start again, but don't do another
- # .well-known lookup.
-
- # parse the server name in the .well-known response into host/port.
- # (This code is lifted from twisted.web.client.URI.fromBytes).
- if b':' in well_known_server:
- well_known_host, well_known_port = well_known_server.rsplit(b':', 1)
- try:
- well_known_port = int(well_known_port)
- except ValueError:
- # the part after the colon could not be parsed as an int
- # - we assume it is an IPv6 literal with no port (the closing
- # ']' stops it being parsed as an int)
- well_known_host, well_known_port = well_known_server, -1
- else:
- well_known_host, well_known_port = well_known_server, -1
-
- new_uri = URI(
- scheme=parsed_uri.scheme,
- netloc=well_known_server,
- host=well_known_host,
- port=well_known_port,
- path=parsed_uri.path,
- params=parsed_uri.params,
- query=parsed_uri.query,
- fragment=parsed_uri.fragment,
- )
- res = yield self._route_matrix_uri(new_uri, lookup_well_known=False)
- defer.returnValue(res)
+@implementer(IAgentEndpointFactory)
+class MatrixHostnameEndpointFactory(object):
+ """Factory for MatrixHostnameEndpoint for parsing to an Agent.
+ """
- # try a SRV lookup
- service_name = b"_matrix._tcp.%s" % (parsed_uri.host,)
- server_list = yield self._srv_resolver.resolve_service(service_name)
+ def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+ self._reactor = reactor
+ self._tls_client_options_factory = tls_client_options_factory
- if not server_list:
- target_host = parsed_uri.host
- port = 8448
- logger.debug(
- "No SRV record for %s, using %s:%i",
- parsed_uri.host.decode("ascii"), target_host.decode("ascii"), port,
- )
- else:
- target_host, port = pick_server_from_list(server_list)
- logger.debug(
- "Picked %s:%i from SRV records for %s",
- target_host.decode("ascii"), port, parsed_uri.host.decode("ascii"),
- )
+ if srv_resolver is None:
+ srv_resolver = SrvResolver()
- defer.returnValue(_RoutingResult(
- host_header=parsed_uri.netloc,
- tls_server_name=parsed_uri.host,
- target_host=target_host,
- target_port=port,
- ))
+ self._srv_resolver = srv_resolver
- @defer.inlineCallbacks
- def _get_well_known(self, server_name):
- """Attempt to fetch and parse a .well-known file for the given server
+ def endpointForURI(self, parsed_uri):
+ return MatrixHostnameEndpoint(
+ self._reactor,
+ self._tls_client_options_factory,
+ self._srv_resolver,
+ parsed_uri,
+ )
- Args:
- server_name (bytes): name of the server, from the requested url
- Returns:
- Deferred[bytes|None]: either the new server name, from the .well-known, or
- None if there was no .well-known file.
- """
- try:
- result = self._well_known_cache[server_name]
- except KeyError:
- # TODO: should we linearise so that we don't end up doing two .well-known
- # requests for the same server in parallel?
- with Measure(self._clock, "get_well_known"):
- result, cache_period = yield self._do_get_well_known(server_name)
+@implementer(IStreamClientEndpoint)
+class MatrixHostnameEndpoint(object):
+ """An endpoint that resolves matrix:// URLs using Matrix server name
+ resolution (i.e. via SRV). Does not check for well-known delegation.
- if cache_period > 0:
- self._well_known_cache.set(server_name, result, cache_period)
+ Args:
+ reactor (IReactor)
+ tls_client_options_factory (ClientTLSOptionsFactory|None):
+ factory to use for fetching client tls options, or none to disable TLS.
+ srv_resolver (SrvResolver): The SRV resolver to use
+ parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
+ to connect to.
+ """
- defer.returnValue(result)
+ def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+ self._reactor = reactor
- @defer.inlineCallbacks
- def _do_get_well_known(self, server_name):
- """Actually fetch and parse a .well-known, without checking the cache
+ self._parsed_uri = parsed_uri
- Args:
- server_name (bytes): name of the server, from the requested url
+ # set up the TLS connection params
+ #
+ # XXX disabling TLS is really only supported here for the benefit of the
+ # unit tests. We should make the UTs cope with TLS rather than having to make
+ # the code support the unit tests.
- Returns:
- Deferred[Tuple[bytes|None|object],int]:
- result, cache period, where result is one of:
- - the new server name from the .well-known (as a `bytes`)
- - None if there was no .well-known file.
- - INVALID_WELL_KNOWN if the .well-known was invalid
- """
- uri = b"https://%s/.well-known/matrix/server" % (server_name, )
- uri_str = uri.decode("ascii")
- logger.info("Fetching %s", uri_str)
- try:
- response = yield make_deferred_yieldable(
- self._well_known_agent.request(b"GET", uri),
- )
- body = yield make_deferred_yieldable(readBody(response))
- if response.code != 200:
- raise Exception("Non-200 response %s" % (response.code, ))
-
- parsed_body = json.loads(body.decode('utf-8'))
- logger.info("Response from .well-known: %s", parsed_body)
- if not isinstance(parsed_body, dict):
- raise Exception("not a dict")
- if "m.server" not in parsed_body:
- raise Exception("Missing key 'm.server'")
- except Exception as e:
- logger.info("Error fetching %s: %s", uri_str, e)
-
- # add some randomness to the TTL to avoid a stampeding herd every hour
- # after startup
- cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
- cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
- defer.returnValue((None, cache_period))
-
- result = parsed_body["m.server"].encode("ascii")
-
- cache_period = _cache_period_from_headers(
- response.headers,
- time_now=self._reactor.seconds,
- )
- if cache_period is None:
- cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
- # add some randomness to the TTL to avoid a stampeding herd every 24 hours
- # after startup
- cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
+ if tls_client_options_factory is None:
+ self._tls_options = None
else:
- cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
-
- defer.returnValue((result, cache_period))
-
+ self._tls_options = tls_client_options_factory.get_options(
+ self._parsed_uri.host
+ )
-@implementer(IStreamClientEndpoint)
-class LoggingHostnameEndpoint(object):
- """A wrapper for HostnameEndpint which logs when it connects"""
- def __init__(self, reactor, host, port, *args, **kwargs):
- self.host = host
- self.port = port
- self.ep = HostnameEndpoint(reactor, host, port, *args, **kwargs)
+ self._srv_resolver = srv_resolver
def connect(self, protocol_factory):
- logger.info("Connecting to %s:%i", self.host.decode("ascii"), self.port)
- return self.ep.connect(protocol_factory)
+ """Implements IStreamClientEndpoint interface
+ """
+
+ return run_in_background(self._do_connect, protocol_factory)
+ @defer.inlineCallbacks
+ def _do_connect(self, protocol_factory):
+ first_exception = None
+
+ server_list = yield self._resolve_server()
+
+ for server in server_list:
+ host = server.host
+ port = server.port
+
+ try:
+ logger.info("Connecting to %s:%i", host.decode("ascii"), port)
+ endpoint = HostnameEndpoint(self._reactor, host, port)
+ if self._tls_options:
+ endpoint = wrapClientTLS(self._tls_options, endpoint)
+ result = yield make_deferred_yieldable(
+ endpoint.connect(protocol_factory)
+ )
-def _cache_period_from_headers(headers, time_now=time.time):
- cache_controls = _parse_cache_control(headers)
+ return result
+ except Exception as e:
+ logger.info(
+ "Failed to connect to %s:%i: %s", host.decode("ascii"), port, e
+ )
+ if not first_exception:
+ first_exception = e
- if b'no-store' in cache_controls:
- return 0
+ # We return the first failure because that's probably the most interesting.
+ if first_exception:
+ raise first_exception
- if b'max-age' in cache_controls:
- try:
- max_age = int(cache_controls[b'max-age'])
- return max_age
- except ValueError:
- pass
+ # This shouldn't happen as we should always have at least one host/port
+ # to try and if that doesn't work then we'll have an exception.
+ raise Exception("Failed to resolve server %r" % (self._parsed_uri.netloc,))
- expires = headers.getRawHeaders(b'expires')
- if expires is not None:
- try:
- expires_date = stringToDatetime(expires[-1])
- return expires_date - time_now()
- except ValueError:
- # RFC7234 says 'A cache recipient MUST interpret invalid date formats,
- # especially the value "0", as representing a time in the past (i.e.,
- # "already expired").
- return 0
+ @defer.inlineCallbacks
+ def _resolve_server(self):
+ """Resolves the server name to a list of hosts and ports to attempt to
+ connect to.
- return None
+ Returns:
+ Deferred[list[Server]]
+ """
+ if self._parsed_uri.scheme != b"matrix":
+ return [Server(host=self._parsed_uri.host, port=self._parsed_uri.port)]
-def _parse_cache_control(headers):
- cache_controls = {}
- for hdr in headers.getRawHeaders(b'cache-control', []):
- for directive in hdr.split(b','):
- splits = [x.strip() for x in directive.split(b'=', 1)]
- k = splits[0].lower()
- v = splits[1] if len(splits) > 1 else None
- cache_controls[k] = v
- return cache_controls
+ # Note: We don't do well-known lookup as that needs to have happened
+ # before now, due to needing to rewrite the Host header of the HTTP
+ # request.
+ # We reparse the URI so that defaultPort is -1 rather than 80
+ parsed_uri = urllib.parse.urlparse(self._parsed_uri.toBytes())
-@attr.s
-class _RoutingResult(object):
- """The result returned by `_route_matrix_uri`.
+ host = parsed_uri.hostname
+ port = parsed_uri.port
- Contains the parameters needed to direct a federation connection to a particular
- server.
+ # If there is an explicit port or the host is an IP address we bypass
+ # SRV lookups and just use the given host/port.
+ if port or _is_ip_literal(host):
+ return [Server(host, port or 8448)]
- Where a SRV record points to several servers, this object contains a single server
- chosen from the list.
- """
+ server_list = yield self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
- host_header = attr.ib()
- """
- The value we should assign to the Host header (host:port from the matrix
- URI, or .well-known).
+ if server_list:
+ return server_list
- :type: bytes
- """
+ # No SRV records, so we fallback to host and 8448
+ return [Server(host, 8448)]
- tls_server_name = attr.ib()
- """
- The server name we should set in the SNI (typically host, without port, from the
- matrix URI or .well-known)
- :type: bytes
- """
+def _is_ip_literal(host):
+ """Test if the given host name is either an IPv4 or IPv6 literal.
- target_host = attr.ib()
- """
- The hostname (or IP literal) we should route the TCP connection to (the target of the
- SRV record, or the hostname from the URL/.well-known)
+ Args:
+ host (bytes)
- :type: bytes
+ Returns:
+ bool
"""
- target_port = attr.ib()
- """
- The port we should route the TCP connection to (the target of the SRV record, or
- the port from the URL/.well-known, or 8448)
+ host = host.decode("ascii")
- :type: int
- """
+ try:
+ IPAddress(host)
+ return True
+ except AddrFormatError:
+ return False
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 71830c549d..021b233a7d 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -25,14 +25,14 @@ from twisted.internet.error import ConnectError
from twisted.names import client, dns
from twisted.names.error import DNSNameError, DomainError
-from synapse.util.logcontext import make_deferred_yieldable
+from synapse.logging.context import make_deferred_yieldable
logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-@attr.s
+@attr.s(slots=True, frozen=True)
class Server(object):
"""
Our record of an individual server which can be tried to reach a destination.
@@ -45,6 +45,7 @@ class Server(object):
expires (int): when the cache should expire this record - in *seconds* since
the epoch
"""
+
host = attr.ib()
port = attr.ib()
priority = attr.ib(default=0)
@@ -52,36 +53,47 @@ class Server(object):
expires = attr.ib(default=0)
-def pick_server_from_list(server_list):
- """Randomly choose a server from the server list
+def _sort_server_list(server_list):
+ """Given a list of SRV records sort them into priority order and shuffle
+ each priority with the given weight.
+ """
+ priority_map = {}
- Args:
- server_list (list[Server]): list of candidate servers
+ for server in server_list:
+ priority_map.setdefault(server.priority, []).append(server)
- Returns:
- Tuple[bytes, int]: (host, port) pair for the chosen server
- """
- if not server_list:
- raise RuntimeError("pick_server_from_list called with empty list")
+ results = []
+ for priority in sorted(priority_map):
+ servers = priority_map[priority]
+
+ # This algorithms roughly follows the algorithm described in RFC2782,
+ # changed to remove an off-by-one error.
+ #
+ # N.B. Weights can be zero, which means that they should be picked
+ # rarely.
- # TODO: currently we only use the lowest-priority servers. We should maintain a
- # cache of servers known to be "down" and filter them out
+ total_weight = sum(s.weight for s in servers)
- min_priority = min(s.priority for s in server_list)
- eligible_servers = list(s for s in server_list if s.priority == min_priority)
- total_weight = sum(s.weight for s in eligible_servers)
- target_weight = random.randint(0, total_weight)
+ # Total weight can become zero if there are only zero weight servers
+ # left, which we handle by just shuffling and appending to the results.
+ while servers and total_weight:
+ target_weight = random.randint(1, total_weight)
- for s in eligible_servers:
- target_weight -= s.weight
+ for s in servers:
+ target_weight -= s.weight
- if target_weight <= 0:
- return s.host, s.port
+ if target_weight <= 0:
+ break
- # this should be impossible.
- raise RuntimeError(
- "pick_server_from_list got to end of eligible server list.",
- )
+ results.append(s)
+ servers.remove(s)
+ total_weight -= s.weight
+
+ if servers:
+ random.shuffle(servers)
+ results.extend(servers)
+
+ return results
class SrvResolver(object):
@@ -95,6 +107,7 @@ class SrvResolver(object):
cache (dict): cache object
get_time (callable): clock implementation. Should return seconds since the epoch
"""
+
def __init__(self, dns_client=client, cache=SERVER_CACHE, get_time=time.time):
self._dns_client = dns_client
self._cache = cache
@@ -120,33 +133,34 @@ class SrvResolver(object):
if cache_entry:
if all(s.expires > now for s in cache_entry):
servers = list(cache_entry)
- defer.returnValue(servers)
+ return _sort_server_list(servers)
try:
answers, _, _ = yield make_deferred_yieldable(
- self._dns_client.lookupService(service_name),
+ self._dns_client.lookupService(service_name)
)
except DNSNameError:
# TODO: cache this. We can get the SOA out of the exception, and use
# the negative-TTL value.
- defer.returnValue([])
+ return []
except DomainError as e:
# We failed to resolve the name (other than a NameError)
# Try something in the cache, else rereaise
cache_entry = self._cache.get(service_name, None)
if cache_entry:
- logger.warn(
- "Failed to resolve %r, falling back to cache. %r",
- service_name, e
+ logger.warning(
+ "Failed to resolve %r, falling back to cache. %r", service_name, e
)
- defer.returnValue(list(cache_entry))
+ return list(cache_entry)
else:
raise e
- if (len(answers) == 1
- and answers[0].type == dns.SRV
- and answers[0].payload
- and answers[0].payload.target == dns.Name(b'.')):
+ if (
+ len(answers) == 1
+ and answers[0].type == dns.SRV
+ and answers[0].payload
+ and answers[0].payload.target == dns.Name(b".")
+ ):
raise ConnectError("Service %s unavailable" % service_name)
servers = []
@@ -157,13 +171,15 @@ class SrvResolver(object):
payload = answer.payload
- servers.append(Server(
- host=payload.target.name,
- port=payload.port,
- priority=payload.priority,
- weight=payload.weight,
- expires=now + answer.ttl,
- ))
+ servers.append(
+ Server(
+ host=payload.target.name,
+ port=payload.port,
+ priority=payload.priority,
+ weight=payload.weight,
+ expires=now + answer.ttl,
+ )
+ )
self._cache[service_name] = list(servers)
- defer.returnValue(servers)
+ return _sort_server_list(servers)
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
new file mode 100644
index 0000000000..7ddfad286d
--- /dev/null
+++ b/synapse/http/federation/well_known_resolver.py
@@ -0,0 +1,301 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import random
+import time
+
+import attr
+
+from twisted.internet import defer
+from twisted.web.client import RedirectAgent, readBody
+from twisted.web.http import stringToDatetime
+
+from synapse.logging.context import make_deferred_yieldable
+from synapse.util import Clock
+from synapse.util.caches.ttlcache import TTLCache
+from synapse.util.metrics import Measure
+
+# period to cache .well-known results for by default
+WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600
+
+# jitter factor to add to the .well-known default cache ttls
+WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 0.1
+
+# period to cache failure to fetch .well-known for
+WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
+
+# period to cache failure to fetch .well-known if there has recently been a
+# valid well-known for that domain.
+WELL_KNOWN_DOWN_CACHE_PERIOD = 2 * 60
+
+# period to remember there was a valid well-known after valid record expires
+WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID = 2 * 3600
+
+# cap for .well-known cache period
+WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
+
+# lower bound for .well-known cache period
+WELL_KNOWN_MIN_CACHE_PERIOD = 5 * 60
+
+# Attempt to refetch a cached well-known N% of the TTL before it expires.
+# e.g. if set to 0.2 and we have a cached entry with a TTL of 5mins, then
+# we'll start trying to refetch 1 minute before it expires.
+WELL_KNOWN_GRACE_PERIOD_FACTOR = 0.2
+
+# Number of times we retry fetching a well-known for a domain we know recently
+# had a valid entry.
+WELL_KNOWN_RETRY_ATTEMPTS = 3
+
+
+logger = logging.getLogger(__name__)
+
+
+_well_known_cache = TTLCache("well-known")
+_had_valid_well_known_cache = TTLCache("had-valid-well-known")
+
+
+@attr.s(slots=True, frozen=True)
+class WellKnownLookupResult(object):
+ delegated_server = attr.ib()
+
+
+class WellKnownResolver(object):
+ """Handles well-known lookups for matrix servers.
+ """
+
+ def __init__(
+ self, reactor, agent, well_known_cache=None, had_well_known_cache=None
+ ):
+ self._reactor = reactor
+ self._clock = Clock(reactor)
+
+ if well_known_cache is None:
+ well_known_cache = _well_known_cache
+
+ if had_well_known_cache is None:
+ had_well_known_cache = _had_valid_well_known_cache
+
+ self._well_known_cache = well_known_cache
+ self._had_valid_well_known_cache = had_well_known_cache
+ self._well_known_agent = RedirectAgent(agent)
+
+ @defer.inlineCallbacks
+ def get_well_known(self, server_name):
+ """Attempt to fetch and parse a .well-known file for the given server
+
+ Args:
+ server_name (bytes): name of the server, from the requested url
+
+ Returns:
+ Deferred[WellKnownLookupResult]: The result of the lookup
+ """
+ try:
+ prev_result, expiry, ttl = self._well_known_cache.get_with_expiry(
+ server_name
+ )
+
+ now = self._clock.time()
+ if now < expiry - WELL_KNOWN_GRACE_PERIOD_FACTOR * ttl:
+ return WellKnownLookupResult(delegated_server=prev_result)
+ except KeyError:
+ prev_result = None
+
+ # TODO: should we linearise so that we don't end up doing two .well-known
+ # requests for the same server in parallel?
+ try:
+ with Measure(self._clock, "get_well_known"):
+ result, cache_period = yield self._fetch_well_known(server_name)
+
+ except _FetchWellKnownFailure as e:
+ if prev_result and e.temporary:
+ # This is a temporary failure and we have a still valid cached
+ # result, so lets return that. Hopefully the next time we ask
+ # the remote will be back up again.
+ return WellKnownLookupResult(delegated_server=prev_result)
+
+ result = None
+
+ if self._had_valid_well_known_cache.get(server_name, False):
+ # We have recently seen a valid well-known record for this
+ # server, so we cache the lack of well-known for a shorter time.
+ cache_period = WELL_KNOWN_DOWN_CACHE_PERIOD
+ else:
+ cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
+
+ # add some randomness to the TTL to avoid a stampeding herd
+ cache_period *= random.uniform(
+ 1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+ 1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+ )
+
+ if cache_period > 0:
+ self._well_known_cache.set(server_name, result, cache_period)
+
+ return WellKnownLookupResult(delegated_server=result)
+
+ @defer.inlineCallbacks
+ def _fetch_well_known(self, server_name):
+ """Actually fetch and parse a .well-known, without checking the cache
+
+ Args:
+ server_name (bytes): name of the server, from the requested url
+
+ Raises:
+ _FetchWellKnownFailure if we fail to lookup a result
+
+ Returns:
+ Deferred[Tuple[bytes,int]]: The lookup result and cache period.
+ """
+
+ had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
+
+ # We do this in two steps to differentiate between possibly transient
+ # errors (e.g. can't connect to host, 503 response) and more permenant
+ # errors (such as getting a 404 response).
+ response, body = yield self._make_well_known_request(
+ server_name, retry=had_valid_well_known
+ )
+
+ try:
+ if response.code != 200:
+ raise Exception("Non-200 response %s" % (response.code,))
+
+ parsed_body = json.loads(body.decode("utf-8"))
+ logger.info("Response from .well-known: %s", parsed_body)
+
+ result = parsed_body["m.server"].encode("ascii")
+ except defer.CancelledError:
+ # Bail if we've been cancelled
+ raise
+ except Exception as e:
+ logger.info("Error parsing well-known for %s: %s", server_name, e)
+ raise _FetchWellKnownFailure(temporary=False)
+
+ cache_period = _cache_period_from_headers(
+ response.headers, time_now=self._reactor.seconds
+ )
+ if cache_period is None:
+ cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD
+ # add some randomness to the TTL to avoid a stampeding herd every 24 hours
+ # after startup
+ cache_period *= random.uniform(
+ 1 - WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+ 1 + WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER,
+ )
+ else:
+ cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD)
+ cache_period = max(cache_period, WELL_KNOWN_MIN_CACHE_PERIOD)
+
+ # We got a success, mark as such in the cache
+ self._had_valid_well_known_cache.set(
+ server_name,
+ bool(result),
+ cache_period + WELL_KNOWN_REMEMBER_DOMAIN_HAD_VALID,
+ )
+
+ return result, cache_period
+
+ @defer.inlineCallbacks
+ def _make_well_known_request(self, server_name, retry):
+ """Make the well known request.
+
+ This will retry the request if requested and it fails (with unable
+ to connect or receives a 5xx error).
+
+ Args:
+ server_name (bytes)
+ retry (bool): Whether to retry the request if it fails.
+
+ Returns:
+ Deferred[tuple[IResponse, bytes]] Returns the response object and
+ body. Response may be a non-200 response.
+ """
+ uri = b"https://%s/.well-known/matrix/server" % (server_name,)
+ uri_str = uri.decode("ascii")
+
+ i = 0
+ while True:
+ i += 1
+
+ logger.info("Fetching %s", uri_str)
+ try:
+ response = yield make_deferred_yieldable(
+ self._well_known_agent.request(b"GET", uri)
+ )
+ body = yield make_deferred_yieldable(readBody(response))
+
+ if 500 <= response.code < 600:
+ raise Exception("Non-200 response %s" % (response.code,))
+
+ return response, body
+ except defer.CancelledError:
+ # Bail if we've been cancelled
+ raise
+ except Exception as e:
+ if not retry or i >= WELL_KNOWN_RETRY_ATTEMPTS:
+ logger.info("Error fetching %s: %s", uri_str, e)
+ raise _FetchWellKnownFailure(temporary=True)
+
+ logger.info("Error fetching %s: %s. Retrying", uri_str, e)
+
+ # Sleep briefly in the hopes that they come back up
+ yield self._clock.sleep(0.5)
+
+
+def _cache_period_from_headers(headers, time_now=time.time):
+ cache_controls = _parse_cache_control(headers)
+
+ if b"no-store" in cache_controls:
+ return 0
+
+ if b"max-age" in cache_controls:
+ try:
+ max_age = int(cache_controls[b"max-age"])
+ return max_age
+ except ValueError:
+ pass
+
+ expires = headers.getRawHeaders(b"expires")
+ if expires is not None:
+ try:
+ expires_date = stringToDatetime(expires[-1])
+ return expires_date - time_now()
+ except ValueError:
+ # RFC7234 says 'A cache recipient MUST interpret invalid date formats,
+ # especially the value "0", as representing a time in the past (i.e.,
+ # "already expired").
+ return 0
+
+ return None
+
+
+def _parse_cache_control(headers):
+ cache_controls = {}
+ for hdr in headers.getRawHeaders(b"cache-control", []):
+ for directive in hdr.split(b","):
+ splits = [x.strip() for x in directive.split(b"=", 1)]
+ k = splits[0].lower()
+ v = splits[1] if len(splits) > 1 else None
+ cache_controls[k] = v
+ return cache_controls
+
+
+@attr.s()
+class _FetchWellKnownFailure(Exception):
+ # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
+ # a temporary failure.
+ temporary = attr.ib()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 663ea72a7a..6f1bb04d8b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -48,16 +48,24 @@ from synapse.api.errors import (
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
+from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.opentracing import (
+ inject_active_span_byte_dict,
+ set_tag,
+ start_active_span,
+ tags,
+)
from synapse.util.async_helpers import timeout_deferred
-from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
-outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
- "", ["method"])
-incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses",
- "", ["method", "code"])
+outgoing_requests_counter = Counter(
+ "synapse_http_matrixfederationclient_requests", "", ["method"]
+)
+incoming_responses_counter = Counter(
+ "synapse_http_matrixfederationclient_responses", "", ["method", "code"]
+)
MAX_LONG_RETRIES = 10
@@ -137,15 +145,11 @@ def _handle_json_response(reactor, timeout_sec, request, response):
check_content_type_is_json(response.headers)
d = treq.json_content(response)
- d = timeout_deferred(
- d,
- timeout=timeout_sec,
- reactor=reactor,
- )
+ d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
body = yield make_deferred_yieldable(d)
except Exception as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Error reading response: %s",
request.txn_id,
request.destination,
@@ -157,9 +161,9 @@ def _handle_json_response(reactor, timeout_sec, request, response):
request.txn_id,
request.destination,
response.code,
- response.phrase.decode('ascii', errors='replace'),
+ response.phrase.decode("ascii", errors="replace"),
)
- defer.returnValue(body)
+ return body
class MatrixFederationHttpClient(object):
@@ -181,7 +185,7 @@ class MatrixFederationHttpClient(object):
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
nameResolver = IPBlacklistingResolver(
- real_reactor, None, hs.config.federation_ip_range_blacklist,
+ real_reactor, None, hs.config.federation_ip_range_blacklist
)
@implementer(IReactorPluggableNameResolver)
@@ -194,21 +198,19 @@ class MatrixFederationHttpClient(object):
self.reactor = Reactor()
- self.agent = MatrixFederationAgent(
- self.reactor,
- tls_client_options_factory,
- )
+ self.agent = MatrixFederationAgent(self.reactor, tls_client_options_factory)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent, self.reactor,
+ self.agent,
+ self.reactor,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
self._store = hs.get_datastore()
- self.version_string_bytes = hs.version_string.encode('ascii')
+ self.version_string_bytes = hs.version_string.encode("ascii")
self.default_timeout = 60
def schedule(x):
@@ -218,10 +220,7 @@ class MatrixFederationHttpClient(object):
@defer.inlineCallbacks
def _send_request_with_optional_trailing_slash(
- self,
- request,
- try_trailing_slash_on_400=False,
- **send_request_args
+ self, request, try_trailing_slash_on_400=False, **send_request_args
):
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
@@ -244,9 +243,7 @@ class MatrixFederationHttpClient(object):
Deferred[Dict]: Parsed JSON response body.
"""
try:
- response = yield self._send_request(
- request, **send_request_args
- )
+ response = yield self._send_request(request, **send_request_args)
except HttpResponseException as e:
# Received an HTTP error > 300. Check if it meets the requirements
# to retry with a trailing slash
@@ -262,11 +259,9 @@ class MatrixFederationHttpClient(object):
logger.info("Retrying request with trailing slash")
request.path += "/"
- response = yield self._send_request(
- request, **send_request_args
- )
+ response = yield self._send_request(request, **send_request_args)
- defer.returnValue(response)
+ return response
@defer.inlineCallbacks
def _send_request(
@@ -329,8 +324,8 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
if (
- self.hs.config.federation_domain_whitelist is not None and
- request.destination not in self.hs.config.federation_domain_whitelist
+ self.hs.config.federation_domain_whitelist is not None
+ and request.destination not in self.hs.config.federation_domain_whitelist
):
raise FederationDeniedError(request.destination)
@@ -350,11 +345,24 @@ class MatrixFederationHttpClient(object):
else:
query_bytes = b""
- headers_dict = {
- b"User-Agent": [self.version_string_bytes],
- }
+ scope = start_active_span(
+ "outgoing-federation-request",
+ tags={
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT,
+ tags.PEER_ADDRESS: request.destination,
+ tags.HTTP_METHOD: request.method,
+ tags.HTTP_URL: request.path,
+ },
+ finish_on_close=True,
+ )
+
+ # Inject the span into the headers
+ headers_dict = {}
+ inject_active_span_byte_dict(headers_dict, request.destination)
- with limiter:
+ headers_dict[b"User-Agent"] = [self.version_string_bytes]
+
+ with limiter, scope:
# XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place)
if long_retries:
@@ -362,16 +370,14 @@ class MatrixFederationHttpClient(object):
else:
retries_left = MAX_SHORT_RETRIES
- url_bytes = urllib.parse.urlunparse((
- b"matrix", destination_bytes,
- path_bytes, None, query_bytes, b"",
- ))
- url_str = url_bytes.decode('ascii')
+ url_bytes = urllib.parse.urlunparse(
+ (b"matrix", destination_bytes, path_bytes, None, query_bytes, b"")
+ )
+ url_str = url_bytes.decode("ascii")
- url_to_sign_bytes = urllib.parse.urlunparse((
- b"", b"",
- path_bytes, None, query_bytes, b"",
- ))
+ url_to_sign_bytes = urllib.parse.urlunparse(
+ (b"", b"", path_bytes, None, query_bytes, b"")
+ )
while True:
try:
@@ -379,28 +385,31 @@ class MatrixFederationHttpClient(object):
if json:
headers_dict[b"Content-Type"] = [b"application/json"]
auth_headers = self.build_auth_headers(
- destination_bytes, method_bytes, url_to_sign_bytes,
- json,
+ destination_bytes, method_bytes, url_to_sign_bytes, json
)
data = encode_canonical_json(json)
producer = QuieterFileBodyProducer(
- BytesIO(data),
- cooperator=self._cooperator,
+ BytesIO(data), cooperator=self._cooperator
)
else:
producer = None
auth_headers = self.build_auth_headers(
- destination_bytes, method_bytes, url_to_sign_bytes,
+ destination_bytes, method_bytes, url_to_sign_bytes
)
headers_dict[b"Authorization"] = auth_headers
logger.info(
"{%s} [%s] Sending request: %s %s; timeout %fs",
- request.txn_id, request.destination, request.method,
- url_str, _sec_timeout,
+ request.txn_id,
+ request.destination,
+ request.method,
+ url_str,
+ _sec_timeout,
)
+ outgoing_requests_counter.labels(method_bytes).inc()
+
try:
with Measure(self.clock, "outbound_request"):
# we don't want all the fancy cookie and redirect handling
@@ -430,9 +439,13 @@ class MatrixFederationHttpClient(object):
request.txn_id,
request.destination,
response.code,
- response.phrase.decode('ascii', errors='replace'),
+ response.phrase.decode("ascii", errors="replace"),
)
+ incoming_responses_counter.labels(method_bytes, response.code).inc()
+
+ set_tag(tags.HTTP_STATUS_CODE, response.code)
+
if 200 <= response.code < 300:
pass
else:
@@ -440,9 +453,7 @@ class MatrixFederationHttpClient(object):
# Update transactions table?
d = treq.content(response)
d = timeout_deferred(
- d,
- timeout=_sec_timeout,
- reactor=self.reactor,
+ d, timeout=_sec_timeout, reactor=self.reactor
)
try:
@@ -450,7 +461,7 @@ class MatrixFederationHttpClient(object):
except Exception as e:
# Eh, we're already going to raise an exception so lets
# ignore if this fails.
- logger.warn(
+ logger.warning(
"{%s} [%s] Failed to get error response: %s %s: %s",
request.txn_id,
request.destination,
@@ -460,9 +471,7 @@ class MatrixFederationHttpClient(object):
)
body = None
- e = HttpResponseException(
- response.code, response.phrase, body
- )
+ e = HttpResponseException(response.code, response.phrase, body)
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
@@ -473,7 +482,7 @@ class MatrixFederationHttpClient(object):
break
except RequestSendFailed as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
@@ -508,7 +517,7 @@ class MatrixFederationHttpClient(object):
raise
except Exception as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Request failed: %s %s: %s",
request.txn_id,
request.destination,
@@ -517,16 +526,15 @@ class MatrixFederationHttpClient(object):
_flatten_response_never_received(e),
)
raise
-
- defer.returnValue(response)
+ return response
def build_auth_headers(
- self, destination, method, url_bytes, content=None, destination_is=None,
+ self, destination, method, url_bytes, content=None, destination_is=None
):
"""
Builds the Authorization headers for a federation request
Args:
- destination (bytes|None): The desination home server of the request.
+ destination (bytes|None): The desination homeserver of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
method (bytes): The HTTP method of the request
@@ -538,11 +546,7 @@ class MatrixFederationHttpClient(object):
Returns:
list[bytes]: a list of headers to be added as "Authorization:" headers
"""
- request = {
- "method": method,
- "uri": url_bytes,
- "origin": self.server_name,
- }
+ request = {"method": method, "uri": url_bytes, "origin": self.server_name}
if destination is not None:
request["destination"] = destination
@@ -558,20 +562,28 @@ class MatrixFederationHttpClient(object):
auth_headers = []
for key, sig in request["signatures"][self.server_name].items():
- auth_headers.append((
- "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % (
- self.server_name, key, sig,
- )).encode('ascii')
+ auth_headers.append(
+ (
+ 'X-Matrix origin=%s,key="%s",sig="%s"'
+ % (self.server_name, key, sig)
+ ).encode("ascii")
)
return auth_headers
@defer.inlineCallbacks
- def put_json(self, destination, path, args={}, data={},
- json_data_callback=None,
- long_retries=False, timeout=None,
- ignore_backoff=False,
- backoff_on_404=False,
- try_trailing_slash_on_400=False):
+ def put_json(
+ self,
+ destination,
+ path,
+ args={},
+ data={},
+ json_data_callback=None,
+ long_retries=False,
+ timeout=None,
+ ignore_backoff=False,
+ backoff_on_404=False,
+ try_trailing_slash_on_400=False,
+ ):
""" Sends the specifed json data using PUT
Args:
@@ -635,14 +647,22 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.reactor, self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
- def post_json(self, destination, path, data={}, long_retries=False,
- timeout=None, ignore_backoff=False, args={}):
+ def post_json(
+ self,
+ destination,
+ path,
+ data={},
+ long_retries=False,
+ timeout=None,
+ ignore_backoff=False,
+ args={},
+ ):
""" Sends the specifed json data using POST
Args:
@@ -681,11 +701,7 @@ class MatrixFederationHttpClient(object):
"""
request = MatrixFederationRequest(
- method="POST",
- destination=destination,
- path=path,
- query=args,
- json=data,
+ method="POST", destination=destination, path=path, query=args, json=data
)
response = yield self._send_request(
@@ -701,14 +717,21 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
body = yield _handle_json_response(
- self.reactor, _sec_timeout, request, response,
+ self.reactor, _sec_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
- def get_json(self, destination, path, args=None, retry_on_dns_fail=True,
- timeout=None, ignore_backoff=False,
- try_trailing_slash_on_400=False):
+ def get_json(
+ self,
+ destination,
+ path,
+ args=None,
+ retry_on_dns_fail=True,
+ timeout=None,
+ ignore_backoff=False,
+ try_trailing_slash_on_400=False,
+ ):
""" GETs some json from the given host homeserver and path
Args:
@@ -745,10 +768,7 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
- method="GET",
- destination=destination,
- path=path,
- query=args,
+ method="GET", destination=destination, path=path, query=args
)
response = yield self._send_request_with_optional_trailing_slash(
@@ -761,14 +781,21 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.reactor, self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
- def delete_json(self, destination, path, long_retries=False,
- timeout=None, ignore_backoff=False, args={}):
+ def delete_json(
+ self,
+ destination,
+ path,
+ long_retries=False,
+ timeout=None,
+ ignore_backoff=False,
+ args={},
+ ):
"""Send a DELETE request to the remote expecting some json response
Args:
@@ -802,10 +829,7 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
- method="DELETE",
- destination=destination,
- path=path,
- query=args,
+ method="DELETE", destination=destination, path=path, query=args
)
response = yield self._send_request(
@@ -816,14 +840,21 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.reactor, self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response
)
- defer.returnValue(body)
+ return body
@defer.inlineCallbacks
- def get_file(self, destination, path, output_stream, args={},
- retry_on_dns_fail=True, max_size=None,
- ignore_backoff=False):
+ def get_file(
+ self,
+ destination,
+ path,
+ output_stream,
+ args={},
+ retry_on_dns_fail=True,
+ max_size=None,
+ ignore_backoff=False,
+ ):
"""GETs a file from a given homeserver
Args:
destination (str): The remote server to send the HTTP request to.
@@ -848,16 +879,11 @@ class MatrixFederationHttpClient(object):
remote, due to e.g. DNS failures, connection timeouts etc.
"""
request = MatrixFederationRequest(
- method="GET",
- destination=destination,
- path=path,
- query=args,
+ method="GET", destination=destination, path=path, query=args
)
response = yield self._send_request(
- request,
- retry_on_dns_fail=retry_on_dns_fail,
- ignore_backoff=ignore_backoff,
+ request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff
)
headers = dict(response.headers.getAllRawHeaders())
@@ -867,7 +893,7 @@ class MatrixFederationHttpClient(object):
d.addTimeout(self.default_timeout, self.reactor)
length = yield make_deferred_yieldable(d)
except Exception as e:
- logger.warn(
+ logger.warning(
"{%s} [%s] Error reading response: %s",
request.txn_id,
request.destination,
@@ -879,10 +905,10 @@ class MatrixFederationHttpClient(object):
request.txn_id,
request.destination,
response.code,
- response.phrase.decode('ascii', errors='replace'),
+ response.phrase.decode("ascii", errors="replace"),
length,
)
- defer.returnValue((length, headers))
+ return (length, headers)
class _ReadBodyToFileProtocol(protocol.Protocol):
@@ -896,11 +922,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- ))
+ self.deferred.errback(
+ SynapseError(
+ 502,
+ "Requested file is too large > %r bytes" % (self.max_size,),
+ Codes.TOO_LARGE,
+ )
+ )
self.deferred = defer.Deferred()
self.transport.loseConnection()
@@ -920,8 +948,7 @@ def _readBodyToFile(response, stream, max_size):
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
- _flatten_response_never_received(f.value)
- for f in e.reasons
+ _flatten_response_never_received(f.value) for f in e.reasons
)
return "%s:[%s]" % (type(e).__name__, reasons)
@@ -943,16 +970,15 @@ def check_content_type_is_json(headers):
"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
- raise RequestSendFailed(RuntimeError(
- "No Content-Type header"
- ), can_retry=False)
+ raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False)
- c_type = c_type[0].decode('ascii') # only the first header
+ c_type = c_type[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
- raise RequestSendFailed(RuntimeError(
- "Content-Type not application/json: was '%s'" % c_type
- ), can_retry=False)
+ raise RequestSendFailed(
+ RuntimeError("Content-Type not application/json: was '%s'" % c_type),
+ can_retry=False,
+ )
def encode_query_args(args):
@@ -967,4 +993,4 @@ def encode_query_args(args):
query_bytes = urllib.parse.urlencode(encoded_args, True)
- return query_bytes.encode('utf8')
+ return query_bytes.encode("utf8")
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 62045a918b..58f9cc61c8 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -19,8 +19,8 @@ import threading
from prometheus_client.core import Counter, Histogram
+from synapse.logging.context import LoggingContext
from synapse.metrics import LaterGauge
-from synapse.util.logcontext import LoggingContext
logger = logging.getLogger(__name__)
@@ -170,7 +170,7 @@ class RequestMetrics(object):
tag = context.tag
if context != self.start_context:
- logger.warn(
+ logger.warning(
"Context have unexpectedly changed %r, %r",
context,
self.start_context,
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 16fb7935da..042a605198 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -14,12 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import cgi
import collections
+import html
+import http.client
import logging
-
-from six import PY3
-from six.moves import http_client, urllib
+import types
+import urllib
+from io import BytesIO
from canonicaljson import encode_canonical_json, encode_pretty_printed_json, json
@@ -35,16 +36,13 @@ import synapse.metrics
from synapse.api.errors import (
CodeMessageException,
Codes,
+ RedirectException,
SynapseError,
UnrecognizedRequestError,
)
+from synapse.logging.context import preserve_fn
+from synapse.logging.opentracing import trace_servlet
from synapse.util.caches import intern_dict
-from synapse.util.logcontext import preserve_fn
-
-if PY3:
- from io import BytesIO
-else:
- from cStringIO import StringIO as BytesIO
logger = logging.getLogger(__name__)
@@ -69,21 +67,18 @@ def wrap_json_request_handler(h):
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
- The handler must return a deferred. If the deferred succeeds we assume that
- a response has been sent. If the deferred fails with a SynapseError we use
+ The handler must return a deferred or a coroutine. If the deferred succeeds
+ we assume that a response has been sent. If the deferred fails with a SynapseError we use
it to send a JSON response with the appropriate HTTP reponse code. If the
deferred fails with any other type of error we send a 500 reponse.
"""
- @defer.inlineCallbacks
- def wrapped_request_handler(self, request):
+ async def wrapped_request_handler(self, request):
try:
- yield h(self, request)
+ await h(self, request)
except SynapseError as e:
code = e.code
- logger.info(
- "%s SynapseError: %s - %s", request, code, e.msg
- )
+ logger.info("%s SynapseError: %s - %s", request, code, e.msg)
# Only respond with an error response if we haven't already started
# writing, otherwise lets just kill the connection
@@ -96,7 +91,10 @@ def wrap_json_request_handler(h):
pass
else:
respond_with_json(
- request, code, e.error_dict(), send_cors=True,
+ request,
+ code,
+ e.error_dict(),
+ send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
@@ -124,10 +122,7 @@ def wrap_json_request_handler(h):
respond_with_json(
request,
500,
- {
- "error": "Internal server error",
- "errcode": Codes.UNKNOWN,
- },
+ {"error": "Internal server error", "errcode": Codes.UNKNOWN},
send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
)
@@ -143,10 +138,13 @@ def wrap_html_request_handler(h):
The handler method must have a signature of "handle_foo(self, request)",
where "request" must be a SynapseRequest.
"""
- def wrapped_request_handler(self, request):
- d = defer.maybeDeferred(h, self, request)
- d.addErrback(_return_html_error, request)
- return d
+
+ async def wrapped_request_handler(self, request):
+ try:
+ return await h(self, request)
+ except Exception:
+ f = failure.Failure()
+ return _return_html_error(f, request)
return wrap_async_request_handler(wrapped_request_handler)
@@ -156,17 +154,19 @@ def _return_html_error(f, request):
Args:
f (twisted.python.failure.Failure):
- request (twisted.web.iweb.IRequest):
+ request (twisted.web.server.Request):
"""
if f.check(CodeMessageException):
cme = f.value
code = cme.code
msg = cme.msg
- if isinstance(cme, SynapseError):
- logger.info(
- "%s SynapseError: %s - %s", request, code, msg
- )
+ if isinstance(cme, RedirectException):
+ logger.info("%s redirect to %s", request, cme.location)
+ request.setHeader(b"location", cme.location)
+ request.cookies.extend(cme.cookies)
+ elif isinstance(cme, SynapseError):
+ logger.info("%s SynapseError: %s - %s", request, code, msg)
else:
logger.error(
"Failed handle request %r",
@@ -174,7 +174,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
else:
- code = http_client.INTERNAL_SERVER_ERROR
+ code = http.client.INTERNAL_SERVER_ERROR
msg = "Internal server error"
logger.error(
@@ -183,9 +183,7 @@ def _return_html_error(f, request):
exc_info=(f.type, f.value, f.getTracebackObject()),
)
- body = HTML_ERROR_TEMPLATE.format(
- code=code, msg=cgi.escape(msg),
- ).encode("utf-8")
+ body = HTML_ERROR_TEMPLATE.format(code=code, msg=html.escape(msg)).encode("utf-8")
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.setHeader(b"Content-Length", b"%i" % (len(body),))
@@ -205,10 +203,10 @@ def wrap_async_request_handler(h):
The handler may return a deferred, in which case the completion of the request isn't
logged until the deferred completes.
"""
- @defer.inlineCallbacks
- def wrapped_async_request_handler(self, request):
+
+ async def wrapped_async_request_handler(self, request):
with request.processing():
- yield h(self, request)
+ await h(self, request)
# we need to preserve_fn here, because the synchronous render method won't yield for
# us (obviously)
@@ -253,7 +251,9 @@ class JsonResource(HttpServer, resource.Resource):
isLeaf = True
- _PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
+ _PathEntry = collections.namedtuple(
+ "_PathEntry", ["pattern", "callback", "servlet_classname"]
+ )
def __init__(self, hs, canonical_json=True):
resource.Resource.__init__(self)
@@ -263,57 +263,73 @@ class JsonResource(HttpServer, resource.Resource):
self.path_regexs = {}
self.hs = hs
- def register_paths(self, method, path_patterns, callback):
+ def register_paths(
+ self, method, path_patterns, callback, servlet_classname, trace=True
+ ):
+ """
+ Registers a request handler against a regular expression. Later request URLs are
+ checked against these regular expressions in order to identify an appropriate
+ handler for that request.
+
+ Args:
+ method (str): GET, POST etc
+
+ path_patterns (Iterable[str]): A list of regular expressions to which
+ the request URLs are compared.
+
+ callback (function): The handler for the request. Usually a Servlet
+
+ servlet_classname (str): The name of the handler to be used in prometheus
+ and opentracing logs.
+
+ trace (bool): Whether we should start a span to trace the servlet.
+ """
method = method.encode("utf-8") # method is bytes on py3
+
+ if trace:
+ # We don't extract the context from the servlet because we can't
+ # trust the sender
+ callback = trace_servlet(servlet_classname)(callback)
+
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
- self._PathEntry(path_pattern, callback)
+ self._PathEntry(path_pattern, callback, servlet_classname)
)
def render(self, request):
""" This gets called by twisted every time someone sends us a request.
"""
- self._async_render(request)
+ defer.ensureDeferred(self._async_render(request))
return NOT_DONE_YET
@wrap_json_request_handler
- @defer.inlineCallbacks
- def _async_render(self, request):
+ async def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
"""
- callback, group_dict = self._get_handler_for_request(request)
+ callback, servlet_classname, group_dict = self._get_handler_for_request(request)
- servlet_instance = getattr(callback, "__self__", None)
- if servlet_instance is not None:
- servlet_classname = servlet_instance.__class__.__name__
- else:
- servlet_classname = "%r" % callback
+ # Make sure we have a name for this handler in prometheus.
request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
# installed by @request_handler.
+ kwargs = intern_dict(
+ {
+ name: urllib.parse.unquote(value) if value else value
+ for name, value in group_dict.items()
+ }
+ )
+
+ callback_return = callback(request, **kwargs)
+
+ # Is it synchronous? We'll allow this for now.
+ if isinstance(callback_return, (defer.Deferred, types.CoroutineType)):
+ callback_return = await callback_return
- def _unquote(s):
- if PY3:
- # On Python 3, unquote is unicode -> unicode
- return urllib.parse.unquote(s)
- else:
- # On Python 2, unquote is bytes -> bytes We need to encode the
- # URL again (as it was decoded by _get_handler_for request), as
- # ASCII because it's a URL, and then decode it to get the UTF-8
- # characters that were quoted.
- return urllib.parse.unquote(s.encode('ascii')).decode('utf8')
-
- kwargs = intern_dict({
- name: _unquote(value) if value else value
- for name, value in group_dict.items()
- })
-
- callback_return = yield callback(request, **kwargs)
if callback_return is not None:
code, response = callback_return
self._send_response(request, code, response)
@@ -325,7 +341,8 @@ class JsonResource(HttpServer, resource.Resource):
request (twisted.web.http.Request):
Returns:
- Tuple[Callable, dict[unicode, unicode]]: callback method, and the
+ Tuple[Callable, str, dict[unicode, unicode]]: callback method, the
+ label to use for that method in prometheus metrics, and the
dict mapping keys to path components as specified in the
handler's path match regexp.
@@ -334,24 +351,29 @@ class JsonResource(HttpServer, resource.Resource):
None, or a tuple of (http code, response body).
"""
if request.method == b"OPTIONS":
- return _options_handler, {}
+ return _options_handler, "options_request_handler", {}
+
+ request_path = request.path.decode("ascii")
# Loop through all the registered callbacks to check if the method
# and path regex match
for path_entry in self.path_regexs.get(request.method, []):
- m = path_entry.pattern.match(request.path.decode('ascii'))
+ m = path_entry.pattern.match(request_path)
if m:
# We found a match!
- return path_entry.callback, m.groupdict()
+ return path_entry.callback, path_entry.servlet_classname, m.groupdict()
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
- return _unrecognised_request_handler, {}
+ return _unrecognised_request_handler, "unrecognised_request_handler", {}
- def _send_response(self, request, code, response_json_object,
- response_code_message=None):
+ def _send_response(
+ self, request, code, response_json_object, response_code_message=None
+ ):
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
- request, code, response_json_object,
+ request,
+ code,
+ response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
@@ -359,6 +381,29 @@ class JsonResource(HttpServer, resource.Resource):
)
+class DirectServeResource(resource.Resource):
+ def render(self, request):
+ """
+ Render the request, using an asynchronous render handler if it exists.
+ """
+ async_render_callback_name = "_async_render_" + request.method.decode("ascii")
+
+ # Try and get the async renderer
+ callback = getattr(self, async_render_callback_name, None)
+
+ # No async renderer for this request method.
+ if not callback:
+ return super().render(request)
+
+ resp = trace_servlet(self.__class__.__name__)(callback)(request)
+
+ # If it's a coroutine, turn it into a Deferred
+ if isinstance(resp, types.CoroutineType):
+ defer.ensureDeferred(resp)
+
+ return NOT_DONE_YET
+
+
def _options_handler(request):
"""Request handler for OPTIONS requests
@@ -395,7 +440,7 @@ class RootRedirect(resource.Resource):
self.url = path
def render_GET(self, request):
- return redirectTo(self.url.encode('ascii'), request)
+ return redirectTo(self.url.encode("ascii"), request)
def getChild(self, name, request):
if len(name) == 0:
@@ -403,16 +448,22 @@ class RootRedirect(resource.Resource):
return resource.Resource.getChild(self, name, request)
-def respond_with_json(request, code, json_object, send_cors=False,
- response_code_message=None, pretty_print=False,
- canonical_json=True):
+def respond_with_json(
+ request,
+ code,
+ json_object,
+ send_cors=False,
+ response_code_message=None,
+ pretty_print=False,
+ canonical_json=True,
+):
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
if request._disconnected:
- logger.warn(
- "Not sending response to request %s, already disconnected.",
- request)
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
return
if pretty_print:
@@ -425,14 +476,17 @@ def respond_with_json(request, code, json_object, send_cors=False,
json_bytes = json.dumps(json_object).encode("utf-8")
return respond_with_json_bytes(
- request, code, json_bytes,
+ request,
+ code,
+ json_bytes,
send_cors=send_cors,
response_code_message=response_code_message,
)
-def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
- response_code_message=None):
+def respond_with_json_bytes(
+ request, code, json_bytes, send_cors=False, response_code_message=None
+):
"""Sends encoded JSON in response to the given request.
Args:
@@ -474,7 +528,7 @@ def set_cors_headers(request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
- b"Origin, X-Requested-With, Content-Type, Accept, Authorization"
+ b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
)
@@ -498,9 +552,7 @@ def finish_request(request):
def _request_user_agent_is_curl(request):
- user_agents = request.requestHeaders.getRawHeaders(
- b"User-Agent", default=[]
- )
+ user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
for user_agent in user_agents:
if b"curl" in user_agent:
return True
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 197c652850..13fcb408a6 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -48,7 +48,7 @@ def parse_integer(request, name, default=None, required=False):
def parse_integer_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
- name = name.encode('ascii')
+ name = name.encode("ascii")
if name in args:
try:
@@ -89,18 +89,14 @@ def parse_boolean(request, name, default=None, required=False):
def parse_boolean_from_args(args, name, default=None, required=False):
if not isinstance(name, bytes):
- name = name.encode('ascii')
+ name = name.encode("ascii")
if name in args:
try:
- return {
- b"true": True,
- b"false": False,
- }[args[name][0]]
+ return {b"true": True, b"false": False}[args[name][0]]
except Exception:
message = (
- "Boolean query parameter %r must be one of"
- " ['true', 'false']"
+ "Boolean query parameter %r must be one of ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
@@ -111,8 +107,15 @@ def parse_boolean_from_args(args, name, default=None, required=False):
return default
-def parse_string(request, name, default=None, required=False,
- allowed_values=None, param_type="string", encoding='ascii'):
+def parse_string(
+ request,
+ name,
+ default=None,
+ required=False,
+ allowed_values=None,
+ param_type="string",
+ encoding="ascii",
+):
"""
Parse a string parameter from the request query string.
@@ -145,21 +148,34 @@ def parse_string(request, name, default=None, required=False,
)
-def parse_string_from_args(args, name, default=None, required=False,
- allowed_values=None, param_type="string", encoding='ascii'):
+def parse_string_from_args(
+ args,
+ name,
+ default=None,
+ required=False,
+ allowed_values=None,
+ param_type="string",
+ encoding="ascii",
+):
if not isinstance(name, bytes):
- name = name.encode('ascii')
+ name = name.encode("ascii")
if name in args:
value = args[name][0]
if encoding:
- value = value.decode(encoding)
+ try:
+ value = value.decode(encoding)
+ except ValueError:
+ raise SynapseError(
+ 400, "Query parameter %r must be %s" % (name, encoding)
+ )
if allowed_values is not None and value not in allowed_values:
message = "Query parameter %r must be one of [%s]" % (
- name, ", ".join(repr(v) for v in allowed_values)
+ name,
+ ", ".join(repr(v) for v in allowed_values),
)
raise SynapseError(400, message)
else:
@@ -201,15 +217,15 @@ def parse_json_value_from_request(request, allow_empty_body=False):
# Decode to Unicode so that simplejson will return Unicode strings on
# Python 2
try:
- content_unicode = content_bytes.decode('utf8')
+ content_unicode = content_bytes.decode("utf8")
except UnicodeDecodeError:
- logger.warn("Unable to decode UTF-8")
+ logger.warning("Unable to decode UTF-8")
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
try:
content = json.loads(content_unicode)
except Exception as e:
- logger.warn("Unable to parse JSON: %s", e)
+ logger.warning("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
return content
@@ -227,9 +243,7 @@ def parse_json_object_from_request(request, allow_empty_body=False):
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""
- content = parse_json_value_from_request(
- request, allow_empty_body=allow_empty_body,
- )
+ content = parse_json_value_from_request(request, allow_empty_body=allow_empty_body)
if allow_empty_body and content is None:
return {}
@@ -280,8 +294,11 @@ class RestServlet(object):
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
if hasattr(self, "on_%s" % (method,)):
+ servlet_classname = self.__class__.__name__
method_handler = getattr(self, "on_%s" % (method,))
- http_server.register_paths(method, patterns, method_handler)
+ http_server.register_paths(
+ method, patterns, method_handler, servlet_classname
+ )
else:
raise NotImplementedError("RestServlet must register something.")
diff --git a/synapse/http/site.py b/synapse/http/site.py
index e508c0bd4f..e092193c9c 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -19,7 +19,7 @@ from twisted.web.server import Request, Site
from synapse.http import redact_uri
from synapse.http.request_metrics import RequestMetrics, requests_counter
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__)
@@ -46,10 +46,11 @@ class SynapseRequest(Request):
Attributes:
logcontext(LoggingContext) : the log context for this request
"""
- def __init__(self, site, channel, *args, **kw):
+
+ def __init__(self, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
- self.site = site
- self._channel = channel # this is used by the tests
+ self.site = channel.site
+ self._channel = channel # this is used by the tests
self.authenticated_entity = None
self.start_time = 0
@@ -72,12 +73,12 @@ class SynapseRequest(Request):
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
- return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
+ return "<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>" % (
self.__class__.__name__,
id(self),
self.get_method(),
self.get_redacted_uri(),
- self.clientproto.decode('ascii', errors='replace'),
+ self.clientproto.decode("ascii", errors="replace"),
self.site.site_tag,
)
@@ -87,7 +88,7 @@ class SynapseRequest(Request):
def get_redacted_uri(self):
uri = self.uri
if isinstance(uri, bytes):
- uri = self.uri.decode('ascii')
+ uri = self.uri.decode("ascii", errors="replace")
return redact_uri(uri)
def get_method(self):
@@ -102,7 +103,7 @@ class SynapseRequest(Request):
"""
method = self.method
if isinstance(method, bytes):
- method = self.method.decode('ascii')
+ method = self.method.decode("ascii")
return method
def get_user_agent(self):
@@ -134,8 +135,7 @@ class SynapseRequest(Request):
# dispatching to the handler, so that the handler
# can update the servlet name in the request
# metrics
- requests_counter.labels(self.get_method(),
- self.request_metrics.name).inc()
+ requests_counter.labels(self.get_method(), self.request_metrics.name).inc()
@contextlib.contextmanager
def processing(self):
@@ -199,8 +199,8 @@ class SynapseRequest(Request):
# It's useful to log it here so that we can get an idea of when
# the client disconnects.
with PreserveLoggingContext(self.logcontext):
- logger.warn(
- "Error processing request %r: %s %s", self, reason.type, reason.value,
+ logger.warning(
+ "Error processing request %r: %s %s", self, reason.type, reason.value
)
if not self._is_processing:
@@ -222,15 +222,15 @@ class SynapseRequest(Request):
self.start_time = time.time()
self.request_metrics = RequestMetrics()
self.request_metrics.start(
- self.start_time, name=servlet_name, method=self.get_method(),
+ self.start_time, name=servlet_name, method=self.get_method()
)
- self.site.access_logger.info(
+ self.site.access_logger.debug(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.get_method(),
- self.get_redacted_uri()
+ self.get_redacted_uri(),
)
def _finished_processing(self):
@@ -282,7 +282,7 @@ class SynapseRequest(Request):
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
- " %sB %s \"%s %s %s\" \"%s\" [%d dbevts]",
+ ' %sB %s "%s %s %s" "%s" [%d dbevts]',
self.getClientIP(),
self.site.site_tag,
authenticated_entity,
@@ -297,7 +297,7 @@ class SynapseRequest(Request):
code,
self.get_method(),
self.get_redacted_uri(),
- self.clientproto.decode('ascii', errors='replace'),
+ self.clientproto.decode("ascii", errors="replace"),
user_agent,
usage.evt_db_fetch_count,
)
@@ -305,7 +305,7 @@ class SynapseRequest(Request):
try:
self.request_metrics.stop(self.finish_time, self.code, self.sentLength)
except Exception as e:
- logger.warn("Failed to stop metrics: %r", e)
+ logger.warning("Failed to stop metrics: %r", e)
class XForwardedForRequest(SynapseRequest):
@@ -316,26 +316,19 @@ class XForwardedForRequest(SynapseRequest):
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
+
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
- return self.requestHeaders.getRawHeaders(
- b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip().decode('ascii')
-
-
-class SynapseRequestFactory(object):
- def __init__(self, site, x_forwarded_for):
- self.site = site
- self.x_forwarded_for = x_forwarded_for
-
- def __call__(self, *args, **kwargs):
- if self.x_forwarded_for:
- return XForwardedForRequest(self.site, *args, **kwargs)
- else:
- return SynapseRequest(self.site, *args, **kwargs)
+ return (
+ self.requestHeaders.getRawHeaders(b"x-forwarded-for", [b"-"])[0]
+ .split(b",")[0]
+ .strip()
+ .decode("ascii")
+ )
class SynapseSite(Site):
@@ -343,16 +336,25 @@ class SynapseSite(Site):
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
- def __init__(self, logger_name, site_tag, config, resource,
- server_version_string, *args, **kwargs):
+
+ def __init__(
+ self,
+ logger_name,
+ site_tag,
+ config,
+ resource,
+ server_version_string,
+ *args,
+ **kwargs
+ ):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
proxied = config.get("x_forwarded", False)
- self.requestFactory = SynapseRequestFactory(self, proxied)
+ self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
self.access_logger = logging.getLogger(logger_name)
- self.server_version_string = server_version_string.encode('ascii')
+ self.server_version_string = server_version_string.encode("ascii")
def log(self, request):
pass
diff --git a/synapse/rest/media/v0/__init__.py b/synapse/logging/__init__.py
index e69de29bb2..e69de29bb2 100644
--- a/synapse/rest/media/v0/__init__.py
+++ b/synapse/logging/__init__.py
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
new file mode 100644
index 0000000000..ffa7b20ca8
--- /dev/null
+++ b/synapse/logging/_structured.py
@@ -0,0 +1,386 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os.path
+import sys
+import typing
+import warnings
+from typing import List
+
+import attr
+from constantly import NamedConstant, Names, ValueConstant, Values
+from zope.interface import implementer
+
+from twisted.logger import (
+ FileLogObserver,
+ FilteringLogObserver,
+ ILogObserver,
+ LogBeginner,
+ Logger,
+ LogLevel,
+ LogLevelFilterPredicate,
+ LogPublisher,
+ eventAsText,
+ jsonFileLogObserver,
+)
+
+from synapse.config._base import ConfigError
+from synapse.logging._terse_json import (
+ TerseJSONToConsoleLogObserver,
+ TerseJSONToTCPLogObserver,
+)
+from synapse.logging.context import LoggingContext
+
+
+def stdlib_log_level_to_twisted(level: str) -> LogLevel:
+ """
+ Convert a stdlib log level to Twisted's log level.
+ """
+ lvl = level.lower().replace("warning", "warn")
+ return LogLevel.levelWithName(lvl)
+
+
+@attr.s
+@implementer(ILogObserver)
+class LogContextObserver(object):
+ """
+ An ILogObserver which adds Synapse-specific log context information.
+
+ Attributes:
+ observer (ILogObserver): The target parent observer.
+ """
+
+ observer = attr.ib()
+
+ def __call__(self, event: dict) -> None:
+ """
+ Consume a log event and emit it to the parent observer after filtering
+ and adding log context information.
+
+ Args:
+ event (dict)
+ """
+ # Filter out some useless events that Twisted outputs
+ if "log_text" in event:
+ if event["log_text"].startswith("DNSDatagramProtocol starting on "):
+ return
+
+ if event["log_text"].startswith("(UDP Port "):
+ return
+
+ if event["log_text"].startswith("Timing out client") or event[
+ "log_format"
+ ].startswith("Timing out client"):
+ return
+
+ context = LoggingContext.current_context()
+
+ # Copy the context information to the log event.
+ if context is not None:
+ context.copy_to_twisted_log_entry(event)
+ else:
+ # If there's no logging context, not even the root one, we might be
+ # starting up or it might be from non-Synapse code. Log it as if it
+ # came from the root logger.
+ event["request"] = None
+ event["scope"] = None
+
+ self.observer(event)
+
+
+class PythonStdlibToTwistedLogger(logging.Handler):
+ """
+ Transform a Python stdlib log message into a Twisted one.
+ """
+
+ def __init__(self, observer, *args, **kwargs):
+ """
+ Args:
+ observer (ILogObserver): A Twisted logging observer.
+ *args, **kwargs: Args/kwargs to be passed to logging.Handler.
+ """
+ self.observer = observer
+ super().__init__(*args, **kwargs)
+
+ def emit(self, record: logging.LogRecord) -> None:
+ """
+ Emit a record to Twisted's observer.
+
+ Args:
+ record (logging.LogRecord)
+ """
+
+ self.observer(
+ {
+ "log_time": record.created,
+ "log_text": record.getMessage(),
+ "log_format": "{log_text}",
+ "log_namespace": record.name,
+ "log_level": stdlib_log_level_to_twisted(record.levelname),
+ }
+ )
+
+
+def SynapseFileLogObserver(outFile: typing.IO[str]) -> FileLogObserver:
+ """
+ A log observer that formats events like the traditional log formatter and
+ sends them to `outFile`.
+
+ Args:
+ outFile (file object): The file object to write to.
+ """
+
+ def formatEvent(_event: dict) -> str:
+ event = dict(_event)
+ event["log_level"] = event["log_level"].name.upper()
+ event["log_format"] = "- {log_namespace} - {log_level} - {request} - " + (
+ event.get("log_format", "{log_text}") or "{log_text}"
+ )
+ return eventAsText(event, includeSystem=False) + "\n"
+
+ return FileLogObserver(outFile, formatEvent)
+
+
+class DrainType(Names):
+ CONSOLE = NamedConstant()
+ CONSOLE_JSON = NamedConstant()
+ CONSOLE_JSON_TERSE = NamedConstant()
+ FILE = NamedConstant()
+ FILE_JSON = NamedConstant()
+ NETWORK_JSON_TERSE = NamedConstant()
+
+
+class OutputPipeType(Values):
+ stdout = ValueConstant(sys.__stdout__)
+ stderr = ValueConstant(sys.__stderr__)
+
+
+@attr.s
+class DrainConfiguration(object):
+ name = attr.ib()
+ type = attr.ib()
+ location = attr.ib()
+ options = attr.ib(default=None)
+
+
+@attr.s
+class NetworkJSONTerseOptions(object):
+ maximum_buffer = attr.ib(type=int)
+
+
+DEFAULT_LOGGERS = {"synapse": {"level": "INFO"}}
+
+
+def parse_drain_configs(
+ drains: dict,
+) -> typing.Generator[DrainConfiguration, None, None]:
+ """
+ Parse the drain configurations.
+
+ Args:
+ drains (dict): A list of drain configurations.
+
+ Yields:
+ DrainConfiguration instances.
+
+ Raises:
+ ConfigError: If any of the drain configuration items are invalid.
+ """
+ for name, config in drains.items():
+ if "type" not in config:
+ raise ConfigError("Logging drains require a 'type' key.")
+
+ try:
+ logging_type = DrainType.lookupByName(config["type"].upper())
+ except ValueError:
+ raise ConfigError(
+ "%s is not a known logging drain type." % (config["type"],)
+ )
+
+ if logging_type in [
+ DrainType.CONSOLE,
+ DrainType.CONSOLE_JSON,
+ DrainType.CONSOLE_JSON_TERSE,
+ ]:
+ location = config.get("location")
+ if location is None or location not in ["stdout", "stderr"]:
+ raise ConfigError(
+ (
+ "The %s drain needs the 'location' key set to "
+ "either 'stdout' or 'stderr'."
+ )
+ % (logging_type,)
+ )
+
+ pipe = OutputPipeType.lookupByName(location).value
+
+ yield DrainConfiguration(name=name, type=logging_type, location=pipe)
+
+ elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
+ if "location" not in config:
+ raise ConfigError(
+ "The %s drain needs the 'location' key set." % (logging_type,)
+ )
+
+ location = config.get("location")
+ if os.path.abspath(location) != location:
+ raise ConfigError(
+ "File paths need to be absolute, '%s' is a relative path"
+ % (location,)
+ )
+ yield DrainConfiguration(name=name, type=logging_type, location=location)
+
+ elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
+ host = config.get("host")
+ port = config.get("port")
+ maximum_buffer = config.get("maximum_buffer", 1000)
+ yield DrainConfiguration(
+ name=name,
+ type=logging_type,
+ location=(host, port),
+ options=NetworkJSONTerseOptions(maximum_buffer=maximum_buffer),
+ )
+
+ else:
+ raise ConfigError(
+ "The %s drain type is currently not implemented."
+ % (config["type"].upper(),)
+ )
+
+
+class StoppableLogPublisher(LogPublisher):
+ """
+ A log publisher that can tell its observers to shut down any external
+ communications.
+ """
+
+ def stop(self):
+ for obs in self._observers:
+ if hasattr(obs, "stop"):
+ obs.stop()
+
+
+def setup_structured_logging(
+ hs,
+ config,
+ log_config: dict,
+ logBeginner: LogBeginner,
+ redirect_stdlib_logging: bool = True,
+) -> LogPublisher:
+ """
+ Set up Twisted's structured logging system.
+
+ Args:
+ hs: The homeserver to use.
+ config (HomeserverConfig): The configuration of the Synapse homeserver.
+ log_config (dict): The log configuration to use.
+ """
+ if config.no_redirect_stdio:
+ raise ConfigError(
+ "no_redirect_stdio cannot be defined using structured logging."
+ )
+
+ logger = Logger()
+
+ if "drains" not in log_config:
+ raise ConfigError("The logging configuration requires a list of drains.")
+
+ observers = [] # type: List[ILogObserver]
+
+ for observer in parse_drain_configs(log_config["drains"]):
+ # Pipe drains
+ if observer.type == DrainType.CONSOLE:
+ logger.debug(
+ "Starting up the {name} console logger drain", name=observer.name
+ )
+ observers.append(SynapseFileLogObserver(observer.location))
+ elif observer.type == DrainType.CONSOLE_JSON:
+ logger.debug(
+ "Starting up the {name} JSON console logger drain", name=observer.name
+ )
+ observers.append(jsonFileLogObserver(observer.location))
+ elif observer.type == DrainType.CONSOLE_JSON_TERSE:
+ logger.debug(
+ "Starting up the {name} terse JSON console logger drain",
+ name=observer.name,
+ )
+ observers.append(
+ TerseJSONToConsoleLogObserver(observer.location, metadata={})
+ )
+
+ # File drains
+ elif observer.type == DrainType.FILE:
+ logger.debug("Starting up the {name} file logger drain", name=observer.name)
+ log_file = open(observer.location, "at", buffering=1, encoding="utf8")
+ observers.append(SynapseFileLogObserver(log_file))
+ elif observer.type == DrainType.FILE_JSON:
+ logger.debug(
+ "Starting up the {name} JSON file logger drain", name=observer.name
+ )
+ log_file = open(observer.location, "at", buffering=1, encoding="utf8")
+ observers.append(jsonFileLogObserver(log_file))
+
+ elif observer.type == DrainType.NETWORK_JSON_TERSE:
+ metadata = {"server_name": hs.config.server_name}
+ log_observer = TerseJSONToTCPLogObserver(
+ hs=hs,
+ host=observer.location[0],
+ port=observer.location[1],
+ metadata=metadata,
+ maximum_buffer=observer.options.maximum_buffer,
+ )
+ log_observer.start()
+ observers.append(log_observer)
+ else:
+ # We should never get here, but, just in case, throw an error.
+ raise ConfigError("%s drain type cannot be configured" % (observer.type,))
+
+ publisher = StoppableLogPublisher(*observers)
+ log_filter = LogLevelFilterPredicate()
+
+ for namespace, namespace_config in log_config.get(
+ "loggers", DEFAULT_LOGGERS
+ ).items():
+ # Set the log level for twisted.logger.Logger namespaces
+ log_filter.setLogLevelForNamespace(
+ namespace,
+ stdlib_log_level_to_twisted(namespace_config.get("level", "INFO")),
+ )
+
+ # Also set the log levels for the stdlib logger namespaces, to prevent
+ # them getting to PythonStdlibToTwistedLogger and having to be formatted
+ if "level" in namespace_config:
+ logging.getLogger(namespace).setLevel(namespace_config.get("level"))
+
+ f = FilteringLogObserver(publisher, [log_filter])
+ lco = LogContextObserver(f)
+
+ if redirect_stdlib_logging:
+ stuff_into_twisted = PythonStdlibToTwistedLogger(lco)
+ stdliblogger = logging.getLogger()
+ stdliblogger.addHandler(stuff_into_twisted)
+
+ # Always redirect standard I/O, otherwise other logging outputs might miss
+ # it.
+ logBeginner.beginLoggingTo([lco], redirectStandardIO=True)
+
+ return publisher
+
+
+def reload_structured_logging(*args, log_config=None) -> None:
+ warnings.warn(
+ "Currently the structured logging system can not be reloaded, doing nothing"
+ )
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
new file mode 100644
index 0000000000..c0b9384189
--- /dev/null
+++ b/synapse/logging/_terse_json.py
@@ -0,0 +1,329 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Log formatters that output terse JSON.
+"""
+
+import json
+import sys
+import traceback
+from collections import deque
+from ipaddress import IPv4Address, IPv6Address, ip_address
+from math import floor
+from typing import IO, Optional
+
+import attr
+from zope.interface import implementer
+
+from twisted.application.internet import ClientService
+from twisted.internet.defer import Deferred
+from twisted.internet.endpoints import (
+ HostnameEndpoint,
+ TCP4ClientEndpoint,
+ TCP6ClientEndpoint,
+)
+from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.protocol import Factory, Protocol
+from twisted.logger import FileLogObserver, ILogObserver, Logger
+
+_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
+
+
+def flatten_event(event: dict, metadata: dict, include_time: bool = False):
+ """
+ Flatten a Twisted logging event to an dictionary capable of being sent
+ as a log event to a logging aggregation system.
+
+ The format is vastly simplified and is not designed to be a "human readable
+ string" in the sense that traditional logs are. Instead, the structure is
+ optimised for searchability and filtering, with human-understandable log
+ keys.
+
+ Args:
+ event (dict): The Twisted logging event we are flattening.
+ metadata (dict): Additional data to include with each log message. This
+ can be information like the server name. Since the target log
+ consumer does not know who we are other than by host IP, this
+ allows us to forward through static information.
+ include_time (bool): Should we include the `time` key? If False, the
+ event time is stripped from the event.
+ """
+ new_event = {}
+
+ # If it's a failure, make the new event's log_failure be the traceback text.
+ if "log_failure" in event:
+ new_event["log_failure"] = event["log_failure"].getTraceback()
+
+ # If it's a warning, copy over a string representation of the warning.
+ if "warning" in event:
+ new_event["warning"] = str(event["warning"])
+
+ # Stdlib logging events have "log_text" as their human-readable portion,
+ # Twisted ones have "log_format". For now, include the log_format, so that
+ # context only given in the log format (e.g. what is being logged) is
+ # available.
+ if "log_text" in event:
+ new_event["log"] = event["log_text"]
+ else:
+ new_event["log"] = event["log_format"]
+
+ # We want to include the timestamp when forwarding over the network, but
+ # exclude it when we are writing to stdout. This is because the log ingester
+ # (e.g. logstash, fluentd) can add its own timestamp.
+ if include_time:
+ new_event["time"] = round(event["log_time"], 2)
+
+ # Convert the log level to a textual representation.
+ new_event["level"] = event["log_level"].name.upper()
+
+ # Ignore these keys, and do not transfer them over to the new log object.
+ # They are either useless (isError), transferred manually above (log_time,
+ # log_level, etc), or contain Python objects which are not useful for output
+ # (log_logger, log_source).
+ keys_to_delete = [
+ "isError",
+ "log_failure",
+ "log_format",
+ "log_level",
+ "log_logger",
+ "log_source",
+ "log_system",
+ "log_time",
+ "log_text",
+ "observer",
+ "warning",
+ ]
+
+ # If it's from the Twisted legacy logger (twisted.python.log), it adds some
+ # more keys we want to purge.
+ if event.get("log_namespace") == "log_legacy":
+ keys_to_delete.extend(["message", "system", "time"])
+
+ # Rather than modify the dictionary in place, construct a new one with only
+ # the content we want. The original event should be considered 'frozen'.
+ for key in event.keys():
+
+ if key in keys_to_delete:
+ continue
+
+ if isinstance(event[key], (str, int, bool, float)) or event[key] is None:
+ # If it's a plain type, include it as is.
+ new_event[key] = event[key]
+ else:
+ # If it's not one of those basic types, write out a string
+ # representation. This should probably be a warning in development,
+ # so that we are sure we are only outputting useful data.
+ new_event[key] = str(event[key])
+
+ # Add the metadata information to the event (e.g. the server_name).
+ new_event.update(metadata)
+
+ return new_event
+
+
+def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogObserver:
+ """
+ A log observer that formats events to a flattened JSON representation.
+
+ Args:
+ outFile: The file object to write to.
+ metadata: Metadata to be added to each log object.
+ """
+
+ def formatEvent(_event: dict) -> str:
+ flattened = flatten_event(_event, metadata)
+ return _encoder.encode(flattened) + "\n"
+
+ return FileLogObserver(outFile, formatEvent)
+
+
+@attr.s
+@implementer(IPushProducer)
+class LogProducer(object):
+ """
+ An IPushProducer that writes logs from its buffer to its transport when it
+ is resumed.
+
+ Args:
+ buffer: Log buffer to read logs from.
+ transport: Transport to write to.
+ """
+
+ transport = attr.ib(type=ITransport)
+ _buffer = attr.ib(type=deque)
+ _paused = attr.ib(default=False, type=bool, init=False)
+
+ def pauseProducing(self):
+ self._paused = True
+
+ def stopProducing(self):
+ self._paused = True
+ self._buffer = deque()
+
+ def resumeProducing(self):
+ self._paused = False
+
+ while self._paused is False and (self._buffer and self.transport.connected):
+ try:
+ event = self._buffer.popleft()
+ self.transport.write(_encoder.encode(event).encode("utf8"))
+ self.transport.write(b"\n")
+ except Exception:
+ # Something has gone wrong writing to the transport -- log it
+ # and break out of the while.
+ traceback.print_exc(file=sys.__stderr__)
+ break
+
+
+@attr.s
+@implementer(ILogObserver)
+class TerseJSONToTCPLogObserver(object):
+ """
+ An IObserver that writes JSON logs to a TCP target.
+
+ Args:
+ hs (HomeServer): The homeserver that is being logged for.
+ host: The host of the logging target.
+ port: The logging target's port.
+ metadata: Metadata to be added to each log entry.
+ """
+
+ hs = attr.ib()
+ host = attr.ib(type=str)
+ port = attr.ib(type=int)
+ metadata = attr.ib(type=dict)
+ maximum_buffer = attr.ib(type=int)
+ _buffer = attr.ib(default=attr.Factory(deque), type=deque)
+ _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
+ _logger = attr.ib(default=attr.Factory(Logger))
+ _producer = attr.ib(default=None, type=Optional[LogProducer])
+
+ def start(self) -> None:
+
+ # Connect without DNS lookups if it's a direct IP.
+ try:
+ ip = ip_address(self.host)
+ if isinstance(ip, IPv4Address):
+ endpoint = TCP4ClientEndpoint(
+ self.hs.get_reactor(), self.host, self.port
+ )
+ elif isinstance(ip, IPv6Address):
+ endpoint = TCP6ClientEndpoint(
+ self.hs.get_reactor(), self.host, self.port
+ )
+ except ValueError:
+ endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
+
+ factory = Factory.forProtocol(Protocol)
+ self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
+ self._service.startService()
+ self._connect()
+
+ def stop(self):
+ self._service.stopService()
+
+ def _connect(self) -> None:
+ """
+ Triggers an attempt to connect then write to the remote if not already writing.
+ """
+ if self._connection_waiter:
+ return
+
+ self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
+
+ @self._connection_waiter.addErrback
+ def fail(r):
+ r.printTraceback(file=sys.__stderr__)
+ self._connection_waiter = None
+ self._connect()
+
+ @self._connection_waiter.addCallback
+ def writer(r):
+ # We have a connection. If we already have a producer, and its
+ # transport is the same, just trigger a resumeProducing.
+ if self._producer and r.transport is self._producer.transport:
+ self._producer.resumeProducing()
+ self._connection_waiter = None
+ return
+
+ # If the producer is still producing, stop it.
+ if self._producer:
+ self._producer.stopProducing()
+
+ # Make a new producer and start it.
+ self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
+ r.transport.registerProducer(self._producer, True)
+ self._producer.resumeProducing()
+ self._connection_waiter = None
+
+ def _handle_pressure(self) -> None:
+ """
+ Handle backpressure by shedding events.
+
+ The buffer will, in this order, until the buffer is below the maximum:
+ - Shed DEBUG events
+ - Shed INFO events
+ - Shed the middle 50% of the events.
+ """
+ if len(self._buffer) <= self.maximum_buffer:
+ return
+
+ # Strip out DEBUGs
+ self._buffer = deque(
+ filter(lambda event: event["level"] != "DEBUG", self._buffer)
+ )
+
+ if len(self._buffer) <= self.maximum_buffer:
+ return
+
+ # Strip out INFOs
+ self._buffer = deque(
+ filter(lambda event: event["level"] != "INFO", self._buffer)
+ )
+
+ if len(self._buffer) <= self.maximum_buffer:
+ return
+
+ # Cut the middle entries out
+ buffer_split = floor(self.maximum_buffer / 2)
+
+ old_buffer = self._buffer
+ self._buffer = deque()
+
+ for i in range(buffer_split):
+ self._buffer.append(old_buffer.popleft())
+
+ end_buffer = []
+ for i in range(buffer_split):
+ end_buffer.append(old_buffer.pop())
+
+ self._buffer.extend(reversed(end_buffer))
+
+ def __call__(self, event: dict) -> None:
+ flattened = flatten_event(event, self.metadata, include_time=True)
+ self._buffer.append(flattened)
+
+ # Handle backpressure, if it exists.
+ try:
+ self._handle_pressure()
+ except Exception:
+ # If handling backpressure fails,clear the buffer and log the
+ # exception.
+ self._buffer.clear()
+ self._logger.failure("Failed clearing backpressure")
+
+ # Try and write immediately.
+ self._connect()
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
new file mode 100644
index 0000000000..860b99a4c6
--- /dev/null
+++ b/synapse/logging/context.py
@@ -0,0 +1,761 @@
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Thread-local-alike tracking of log contexts within synapse
+
+This module provides objects and utilities for tracking contexts through
+synapse code, so that log lines can include a request identifier, and so that
+CPU and database activity can be accounted for against the request that caused
+them.
+
+See doc/log_contexts.rst for details on how this works.
+"""
+
+import inspect
+import logging
+import threading
+import types
+from typing import TYPE_CHECKING, Optional, Tuple, TypeVar, Union
+
+from typing_extensions import Literal
+
+from twisted.internet import defer, threads
+
+if TYPE_CHECKING:
+ from synapse.logging.scopecontextmanager import _LogContextScope
+
+logger = logging.getLogger(__name__)
+
+try:
+ import resource
+
+ # Python doesn't ship with a definition of RUSAGE_THREAD but it's defined
+ # to be 1 on linux so we hard code it.
+ RUSAGE_THREAD = 1
+
+ # If the system doesn't support RUSAGE_THREAD then this should throw an
+ # exception.
+ resource.getrusage(RUSAGE_THREAD)
+
+ is_thread_resource_usage_supported = True
+
+ def get_thread_resource_usage():
+ return resource.getrusage(RUSAGE_THREAD)
+
+
+except Exception:
+ # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
+ # won't track resource usage.
+ is_thread_resource_usage_supported = False
+
+ def get_thread_resource_usage():
+ return None
+
+
+# get an id for the current thread.
+#
+# threading.get_ident doesn't actually return an OS-level tid, and annoyingly,
+# on Linux it actually returns the same value either side of a fork() call. However
+# we only fork in one place, so it's not worth the hoop-jumping to get a real tid.
+#
+get_thread_id = threading.get_ident
+
+
+class ContextResourceUsage(object):
+ """Object for tracking the resources used by a log context
+
+ Attributes:
+ ru_utime (float): user CPU time (in seconds)
+ ru_stime (float): system CPU time (in seconds)
+ db_txn_count (int): number of database transactions done
+ db_sched_duration_sec (float): amount of time spent waiting for a
+ database connection
+ db_txn_duration_sec (float): amount of time spent doing database
+ transactions (excluding scheduling time)
+ evt_db_fetch_count (int): number of events requested from the database
+ """
+
+ __slots__ = [
+ "ru_stime",
+ "ru_utime",
+ "db_txn_count",
+ "db_txn_duration_sec",
+ "db_sched_duration_sec",
+ "evt_db_fetch_count",
+ ]
+
+ def __init__(self, copy_from: "Optional[ContextResourceUsage]" = None) -> None:
+ """Create a new ContextResourceUsage
+
+ Args:
+ copy_from (ContextResourceUsage|None): if not None, an object to
+ copy stats from
+ """
+ if copy_from is None:
+ self.reset()
+ else:
+ # FIXME: mypy can't infer the types set via reset() above, so specify explicitly for now
+ self.ru_utime = copy_from.ru_utime # type: float
+ self.ru_stime = copy_from.ru_stime # type: float
+ self.db_txn_count = copy_from.db_txn_count # type: int
+
+ self.db_txn_duration_sec = copy_from.db_txn_duration_sec # type: float
+ self.db_sched_duration_sec = copy_from.db_sched_duration_sec # type: float
+ self.evt_db_fetch_count = copy_from.evt_db_fetch_count # type: int
+
+ def copy(self) -> "ContextResourceUsage":
+ return ContextResourceUsage(copy_from=self)
+
+ def reset(self) -> None:
+ self.ru_stime = 0.0
+ self.ru_utime = 0.0
+ self.db_txn_count = 0
+
+ self.db_txn_duration_sec = 0.0
+ self.db_sched_duration_sec = 0.0
+ self.evt_db_fetch_count = 0
+
+ def __repr__(self) -> str:
+ return (
+ "<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
+ "db_txn_count='%r', db_txn_duration_sec='%r', "
+ "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>"
+ ) % (
+ self.ru_stime,
+ self.ru_utime,
+ self.db_txn_count,
+ self.db_txn_duration_sec,
+ self.db_sched_duration_sec,
+ self.evt_db_fetch_count,
+ )
+
+ def __iadd__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
+ """Add another ContextResourceUsage's stats to this one's.
+
+ Args:
+ other (ContextResourceUsage): the other resource usage object
+ """
+ self.ru_utime += other.ru_utime
+ self.ru_stime += other.ru_stime
+ self.db_txn_count += other.db_txn_count
+ self.db_txn_duration_sec += other.db_txn_duration_sec
+ self.db_sched_duration_sec += other.db_sched_duration_sec
+ self.evt_db_fetch_count += other.evt_db_fetch_count
+ return self
+
+ def __isub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
+ self.ru_utime -= other.ru_utime
+ self.ru_stime -= other.ru_stime
+ self.db_txn_count -= other.db_txn_count
+ self.db_txn_duration_sec -= other.db_txn_duration_sec
+ self.db_sched_duration_sec -= other.db_sched_duration_sec
+ self.evt_db_fetch_count -= other.evt_db_fetch_count
+ return self
+
+ def __add__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
+ res = ContextResourceUsage(copy_from=self)
+ res += other
+ return res
+
+ def __sub__(self, other: "ContextResourceUsage") -> "ContextResourceUsage":
+ res = ContextResourceUsage(copy_from=self)
+ res -= other
+ return res
+
+
+LoggingContextOrSentinel = Union["LoggingContext", "LoggingContext.Sentinel"]
+
+
+class LoggingContext(object):
+ """Additional context for log formatting. Contexts are scoped within a
+ "with" block.
+
+ If a parent is given when creating a new context, then:
+ - logging fields are copied from the parent to the new context on entry
+ - when the new context exits, the cpu usage stats are copied from the
+ child to the parent
+
+ Args:
+ name (str): Name for the context for debugging.
+ parent_context (LoggingContext|None): The parent of the new context
+ """
+
+ __slots__ = [
+ "previous_context",
+ "name",
+ "parent_context",
+ "_resource_usage",
+ "usage_start",
+ "main_thread",
+ "alive",
+ "request",
+ "tag",
+ "scope",
+ ]
+
+ thread_local = threading.local()
+
+ class Sentinel(object):
+ """Sentinel to represent the root context"""
+
+ __slots__ = ["previous_context", "alive", "request", "scope", "tag"]
+
+ def __init__(self) -> None:
+ # Minimal set for compatibility with LoggingContext
+ self.previous_context = None
+ self.alive = None
+ self.request = None
+ self.scope = None
+ self.tag = None
+
+ def __str__(self):
+ return "sentinel"
+
+ def copy_to(self, record):
+ pass
+
+ def copy_to_twisted_log_entry(self, record):
+ record["request"] = None
+ record["scope"] = None
+
+ def start(self):
+ pass
+
+ def stop(self):
+ pass
+
+ def add_database_transaction(self, duration_sec):
+ pass
+
+ def add_database_scheduled(self, sched_sec):
+ pass
+
+ def record_event_fetch(self, event_count):
+ pass
+
+ def __nonzero__(self):
+ return False
+
+ __bool__ = __nonzero__ # python3
+
+ sentinel = Sentinel()
+
+ def __init__(self, name=None, parent_context=None, request=None) -> None:
+ self.previous_context = LoggingContext.current_context()
+ self.name = name
+
+ # track the resources used by this context so far
+ self._resource_usage = ContextResourceUsage()
+
+ # If alive has the thread resource usage when the logcontext last
+ # became active.
+ self.usage_start = None
+
+ self.main_thread = get_thread_id()
+ self.request = None
+ self.tag = ""
+ self.alive = True
+ self.scope = None # type: Optional[_LogContextScope]
+
+ self.parent_context = parent_context
+
+ if self.parent_context is not None:
+ self.parent_context.copy_to(self)
+
+ if request is not None:
+ # the request param overrides the request from the parent context
+ self.request = request
+
+ def __str__(self) -> str:
+ if self.request:
+ return str(self.request)
+ return "%s@%x" % (self.name, id(self))
+
+ @classmethod
+ def current_context(cls) -> LoggingContextOrSentinel:
+ """Get the current logging context from thread local storage
+
+ Returns:
+ LoggingContext: the current logging context
+ """
+ return getattr(cls.thread_local, "current_context", cls.sentinel)
+
+ @classmethod
+ def set_current_context(
+ cls, context: LoggingContextOrSentinel
+ ) -> LoggingContextOrSentinel:
+ """Set the current logging context in thread local storage
+ Args:
+ context(LoggingContext): The context to activate.
+ Returns:
+ The context that was previously active
+ """
+ current = cls.current_context()
+
+ if current is not context:
+ current.stop()
+ cls.thread_local.current_context = context
+ context.start()
+ return current
+
+ def __enter__(self) -> "LoggingContext":
+ """Enters this logging context into thread local storage"""
+ old_context = self.set_current_context(self)
+ if self.previous_context != old_context:
+ logger.warning(
+ "Expected previous context %r, found %r",
+ self.previous_context,
+ old_context,
+ )
+ self.alive = True
+
+ return self
+
+ def __exit__(self, type, value, traceback) -> None:
+ """Restore the logging context in thread local storage to the state it
+ was before this context was entered.
+ Returns:
+ None to avoid suppressing any exceptions that were thrown.
+ """
+ current = self.set_current_context(self.previous_context)
+ if current is not self:
+ if current is self.sentinel:
+ logger.warning("Expected logging context %s was lost", self)
+ else:
+ logger.warning(
+ "Expected logging context %s but found %s", self, current
+ )
+ self.alive = False
+
+ # if we have a parent, pass our CPU usage stats on
+ if self.parent_context is not None and hasattr(
+ self.parent_context, "_resource_usage"
+ ):
+ self.parent_context._resource_usage += self._resource_usage
+
+ # reset them in case we get entered again
+ self._resource_usage.reset()
+
+ def copy_to(self, record) -> None:
+ """Copy logging fields from this context to a log record or
+ another LoggingContext
+ """
+
+ # we track the current request
+ record.request = self.request
+
+ # we also track the current scope:
+ record.scope = self.scope
+
+ def copy_to_twisted_log_entry(self, record) -> None:
+ """
+ Copy logging fields from this context to a Twisted log record.
+ """
+ record["request"] = self.request
+ record["scope"] = self.scope
+
+ def start(self) -> None:
+ if get_thread_id() != self.main_thread:
+ logger.warning("Started logcontext %s on different thread", self)
+ return
+
+ # If we haven't already started record the thread resource usage so
+ # far
+ if not self.usage_start:
+ self.usage_start = get_thread_resource_usage()
+
+ def stop(self) -> None:
+ if get_thread_id() != self.main_thread:
+ logger.warning("Stopped logcontext %s on different thread", self)
+ return
+
+ # When we stop, let's record the cpu used since we started
+ if not self.usage_start:
+ # Log a warning on platforms that support thread usage tracking
+ if is_thread_resource_usage_supported:
+ logger.warning(
+ "Called stop on logcontext %s without calling start", self
+ )
+ return
+
+ utime_delta, stime_delta = self._get_cputime()
+ self._resource_usage.ru_utime += utime_delta
+ self._resource_usage.ru_stime += stime_delta
+
+ self.usage_start = None
+
+ def get_resource_usage(self) -> ContextResourceUsage:
+ """Get resources used by this logcontext so far.
+
+ Returns:
+ ContextResourceUsage: a *copy* of the object tracking resource
+ usage so far
+ """
+ # we always return a copy, for consistency
+ res = self._resource_usage.copy()
+
+ # If we are on the correct thread and we're currently running then we
+ # can include resource usage so far.
+ is_main_thread = get_thread_id() == self.main_thread
+ if self.alive and self.usage_start and is_main_thread:
+ utime_delta, stime_delta = self._get_cputime()
+ res.ru_utime += utime_delta
+ res.ru_stime += stime_delta
+
+ return res
+
+ def _get_cputime(self) -> Tuple[float, float]:
+ """Get the cpu usage time so far
+
+ Returns: Tuple[float, float]: seconds in user mode, seconds in system mode
+ """
+ assert self.usage_start is not None
+
+ current = get_thread_resource_usage()
+
+ # Indicate to mypy that we know that self.usage_start is None.
+ assert self.usage_start is not None
+
+ utime_delta = current.ru_utime - self.usage_start.ru_utime
+ stime_delta = current.ru_stime - self.usage_start.ru_stime
+
+ # sanity check
+ if utime_delta < 0:
+ logger.error(
+ "utime went backwards! %f < %f",
+ current.ru_utime,
+ self.usage_start.ru_utime,
+ )
+ utime_delta = 0
+
+ if stime_delta < 0:
+ logger.error(
+ "stime went backwards! %f < %f",
+ current.ru_stime,
+ self.usage_start.ru_stime,
+ )
+ stime_delta = 0
+
+ return utime_delta, stime_delta
+
+ def add_database_transaction(self, duration_sec: float) -> None:
+ if duration_sec < 0:
+ raise ValueError("DB txn time can only be non-negative")
+ self._resource_usage.db_txn_count += 1
+ self._resource_usage.db_txn_duration_sec += duration_sec
+
+ def add_database_scheduled(self, sched_sec: float) -> None:
+ """Record a use of the database pool
+
+ Args:
+ sched_sec (float): number of seconds it took us to get a
+ connection
+ """
+ if sched_sec < 0:
+ raise ValueError("DB scheduling time can only be non-negative")
+ self._resource_usage.db_sched_duration_sec += sched_sec
+
+ def record_event_fetch(self, event_count: int) -> None:
+ """Record a number of events being fetched from the db
+
+ Args:
+ event_count (int): number of events being fetched
+ """
+ self._resource_usage.evt_db_fetch_count += event_count
+
+
+class LoggingContextFilter(logging.Filter):
+ """Logging filter that adds values from the current logging context to each
+ record.
+ Args:
+ **defaults: Default values to avoid formatters complaining about
+ missing fields
+ """
+
+ def __init__(self, **defaults) -> None:
+ self.defaults = defaults
+
+ def filter(self, record) -> Literal[True]:
+ """Add each fields from the logging contexts to the record.
+ Returns:
+ True to include the record in the log output.
+ """
+ context = LoggingContext.current_context()
+ for key, value in self.defaults.items():
+ setattr(record, key, value)
+
+ # context should never be None, but if it somehow ends up being, then
+ # we end up in a death spiral of infinite loops, so let's check, for
+ # robustness' sake.
+ if context is not None:
+ context.copy_to(record)
+
+ return True
+
+
+class PreserveLoggingContext(object):
+ """Captures the current logging context and restores it when the scope is
+ exited. Used to restore the context after a function using
+ @defer.inlineCallbacks is resumed by a callback from the reactor."""
+
+ __slots__ = ["current_context", "new_context", "has_parent"]
+
+ def __init__(self, new_context: Optional[LoggingContextOrSentinel] = None) -> None:
+ if new_context is None:
+ self.new_context = LoggingContext.sentinel # type: LoggingContextOrSentinel
+ else:
+ self.new_context = new_context
+
+ def __enter__(self) -> None:
+ """Captures the current logging context"""
+ self.current_context = LoggingContext.set_current_context(self.new_context)
+
+ if self.current_context:
+ self.has_parent = self.current_context.previous_context is not None
+ if not self.current_context.alive:
+ logger.debug("Entering dead context: %s", self.current_context)
+
+ def __exit__(self, type, value, traceback) -> None:
+ """Restores the current logging context"""
+ context = LoggingContext.set_current_context(self.current_context)
+
+ if context != self.new_context:
+ if context is LoggingContext.sentinel:
+ logger.warning("Expected logging context %s was lost", self.new_context)
+ else:
+ logger.warning(
+ "Expected logging context %s but found %s",
+ self.new_context,
+ context,
+ )
+
+ if self.current_context is not LoggingContext.sentinel:
+ if not self.current_context.alive:
+ logger.debug("Restoring dead context: %s", self.current_context)
+
+
+def nested_logging_context(
+ suffix: str, parent_context: Optional[LoggingContext] = None
+) -> LoggingContext:
+ """Creates a new logging context as a child of another.
+
+ The nested logging context will have a 'request' made up of the parent context's
+ request, plus the given suffix.
+
+ CPU/db usage stats will be added to the parent context's on exit.
+
+ Normal usage looks like:
+
+ with nested_logging_context(suffix):
+ # ... do stuff
+
+ Args:
+ suffix (str): suffix to add to the parent context's 'request'.
+ parent_context (LoggingContext|None): parent context. Will use the current context
+ if None.
+
+ Returns:
+ LoggingContext: new logging context.
+ """
+ if parent_context is not None:
+ context = parent_context # type: LoggingContextOrSentinel
+ else:
+ context = LoggingContext.current_context()
+ return LoggingContext(
+ parent_context=context, request=str(context.request) + "-" + suffix
+ )
+
+
+def preserve_fn(f):
+ """Function decorator which wraps the function with run_in_background"""
+
+ def g(*args, **kwargs):
+ return run_in_background(f, *args, **kwargs)
+
+ return g
+
+
+def run_in_background(f, *args, **kwargs):
+ """Calls a function, ensuring that the current context is restored after
+ return from the function, and that the sentinel context is set once the
+ deferred returned by the function completes.
+
+ Useful for wrapping functions that return a deferred or coroutine, which you don't
+ yield or await on (for instance because you want to pass it to
+ deferred.gatherResults()).
+
+ If f returns a Coroutine object, it will be wrapped into a Deferred (which will have
+ the side effect of executing the coroutine).
+
+ Note that if you completely discard the result, you should make sure that
+ `f` doesn't raise any deferred exceptions, otherwise a scary-looking
+ CRITICAL error about an unhandled error will be logged without much
+ indication about where it came from.
+ """
+ current = LoggingContext.current_context()
+ try:
+ res = f(*args, **kwargs)
+ except: # noqa: E722
+ # the assumption here is that the caller doesn't want to be disturbed
+ # by synchronous exceptions, so let's turn them into Failures.
+ return defer.fail()
+
+ if isinstance(res, types.CoroutineType):
+ res = defer.ensureDeferred(res)
+
+ if not isinstance(res, defer.Deferred):
+ return res
+
+ if res.called and not res.paused:
+ # The function should have maintained the logcontext, so we can
+ # optimise out the messing about
+ return res
+
+ # The function may have reset the context before returning, so
+ # we need to restore it now.
+ ctx = LoggingContext.set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ res.addBoth(_set_context_cb, ctx)
+ return res
+
+
+def make_deferred_yieldable(deferred):
+ """Given a deferred (or coroutine), make it follow the Synapse logcontext
+ rules:
+
+ If the deferred has completed (or is not actually a Deferred), essentially
+ does nothing (just returns another completed deferred with the
+ result/failure).
+
+ If the deferred has not yet completed, resets the logcontext before
+ returning a deferred. Then, when the deferred completes, restores the
+ current logcontext before running callbacks/errbacks.
+
+ (This is more-or-less the opposite operation to run_in_background.)
+ """
+ if inspect.isawaitable(deferred):
+ # If we're given a coroutine we convert it to a deferred so that we
+ # run it and find out if it immediately finishes, it it does then we
+ # don't need to fiddle with log contexts at all and can return
+ # immediately.
+ deferred = defer.ensureDeferred(deferred)
+
+ if not isinstance(deferred, defer.Deferred):
+ return deferred
+
+ if deferred.called and not deferred.paused:
+ # it looks like this deferred is ready to run any callbacks we give it
+ # immediately. We may as well optimise out the logcontext faffery.
+ return deferred
+
+ # ok, we can't be sure that a yield won't block, so let's reset the
+ # logcontext, and add a callback to the deferred to restore it.
+ prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+ deferred.addBoth(_set_context_cb, prev_context)
+ return deferred
+
+
+ResultT = TypeVar("ResultT")
+
+
+def _set_context_cb(result: ResultT, context: LoggingContext) -> ResultT:
+ """A callback function which just sets the logging context"""
+ LoggingContext.set_current_context(context)
+ return result
+
+
+def defer_to_thread(reactor, f, *args, **kwargs):
+ """
+ Calls the function `f` using a thread from the reactor's default threadpool and
+ returns the result as a Deferred.
+
+ Creates a new logcontext for `f`, which is created as a child of the current
+ logcontext (so its CPU usage metrics will get attributed to the current
+ logcontext). `f` should preserve the logcontext it is given.
+
+ The result deferred follows the Synapse logcontext rules: you should `yield`
+ on it.
+
+ Args:
+ reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
+ the Deferred will be invoked, and whose threadpool we should use for the
+ function.
+
+ Normally this will be hs.get_reactor().
+
+ f (callable): The function to call.
+
+ args: positional arguments to pass to f.
+
+ kwargs: keyword arguments to pass to f.
+
+ Returns:
+ Deferred: A Deferred which fires a callback with the result of `f`, or an
+ errback if `f` throws an exception.
+ """
+ return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
+
+
+def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
+ """
+ A wrapper for twisted.internet.threads.deferToThreadpool, which handles
+ logcontexts correctly.
+
+ Calls the function `f` using a thread from the given threadpool and returns
+ the result as a Deferred.
+
+ Creates a new logcontext for `f`, which is created as a child of the current
+ logcontext (so its CPU usage metrics will get attributed to the current
+ logcontext). `f` should preserve the logcontext it is given.
+
+ The result deferred follows the Synapse logcontext rules: you should `yield`
+ on it.
+
+ Args:
+ reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
+ the Deferred will be invoked. Normally this will be hs.get_reactor().
+
+ threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for
+ running `f`. Normally this will be hs.get_reactor().getThreadPool().
+
+ f (callable): The function to call.
+
+ args: positional arguments to pass to f.
+
+ kwargs: keyword arguments to pass to f.
+
+ Returns:
+ Deferred: A Deferred which fires a callback with the result of `f`, or an
+ errback if `f` throws an exception.
+ """
+ logcontext = LoggingContext.current_context()
+
+ def g():
+ with LoggingContext(parent_context=logcontext):
+ return f(*args, **kwargs)
+
+ return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py
new file mode 100644
index 0000000000..fbf570c756
--- /dev/null
+++ b/synapse/logging/formatter.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import logging
+import traceback
+
+from six import StringIO
+
+
+class LogFormatter(logging.Formatter):
+ """Log formatter which gives more detail for exceptions
+
+ This is the same as the standard log formatter, except that when logging
+ exceptions [typically via log.foo("msg", exc_info=1)], it prints the
+ sequence that led up to the point at which the exception was caught.
+ (Normally only stack frames between the point the exception was raised and
+ where it was caught are logged).
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(LogFormatter, self).__init__(*args, **kwargs)
+
+ def formatException(self, ei):
+ sio = StringIO()
+ (typ, val, tb) = ei
+
+ # log the stack above the exception capture point if possible, but
+ # check that we actually have an f_back attribute to work around
+ # https://twistedmatrix.com/trac/ticket/9305
+
+ if tb and hasattr(tb.tb_frame, "f_back"):
+ sio.write("Capture point (most recent call last):\n")
+ traceback.print_stack(tb.tb_frame.f_back, None, sio)
+
+ traceback.print_exception(typ, val, tb, None, sio)
+ s = sio.getvalue()
+ sio.close()
+ if s[-1:] == "\n":
+ s = s[:-1]
+ return s
diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py
new file mode 100644
index 0000000000..0638cec429
--- /dev/null
+++ b/synapse/logging/opentracing.py
@@ -0,0 +1,810 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# NOTE
+# This is a small wrapper around opentracing because opentracing is not currently
+# packaged downstream (specifically debian). Since opentracing instrumentation is
+# fairly invasive it was awkward to make it optional. As a result we opted to encapsulate
+# all opentracing state in these methods which effectively noop if opentracing is
+# not present. We should strongly consider encouraging the downstream distributers
+# to package opentracing and making opentracing a full dependency. In order to facilitate
+# this move the methods have work very similarly to opentracing's and it should only
+# be a matter of few regexes to move over to opentracing's access patterns proper.
+
+"""
+============================
+Using OpenTracing in Synapse
+============================
+
+Python-specific tracing concepts are at https://opentracing.io/guides/python/.
+Note that Synapse wraps OpenTracing in a small module (this one) in order to make the
+OpenTracing dependency optional. That means that the access patterns are
+different to those demonstrated in the OpenTracing guides. However, it is
+still useful to know, especially if OpenTracing is included as a full dependency
+in the future or if you are modifying this module.
+
+
+OpenTracing is encapsulated so that
+no span objects from OpenTracing are exposed in Synapse's code. This allows
+OpenTracing to be easily disabled in Synapse and thereby have OpenTracing as
+an optional dependency. This does however limit the number of modifiable spans
+at any point in the code to one. From here out references to `opentracing`
+in the code snippets refer to the Synapses module.
+Most methods provided in the module have a direct correlation to those provided
+by opentracing. Refer to docs there for a more in-depth documentation on some of
+the args and methods.
+
+Tracing
+-------
+
+In Synapse it is not possible to start a non-active span. Spans can be started
+using the ``start_active_span`` method. This returns a scope (see
+OpenTracing docs) which is a context manager that needs to be entered and
+exited. This is usually done by using ``with``.
+
+.. code-block:: python
+
+ from synapse.logging.opentracing import start_active_span
+
+ with start_active_span("operation name"):
+ # Do something we want to tracer
+
+Forgetting to enter or exit a scope will result in some mysterious and grievous log
+context errors.
+
+At anytime where there is an active span ``opentracing.set_tag`` can be used to
+set a tag on the current active span.
+
+Tracing functions
+-----------------
+
+Functions can be easily traced using decorators. The name of
+the function becomes the operation name for the span.
+
+.. code-block:: python
+
+ from synapse.logging.opentracing import trace
+
+ # Start a span using 'interesting_function' as the operation name
+ @trace
+ def interesting_function(*args, **kwargs):
+ # Does all kinds of cool and expected things
+ return something_usual_and_useful
+
+
+Operation names can be explicitly set for a function by passing the
+operation name to ``trace``
+
+.. code-block:: python
+
+ from synapse.logging.opentracing import trace
+
+ @trace(opname="a_better_operation_name")
+ def interesting_badly_named_function(*args, **kwargs):
+ # Does all kinds of cool and expected things
+ return something_usual_and_useful
+
+Setting Tags
+------------
+
+To set a tag on the active span do
+
+.. code-block:: python
+
+ from synapse.logging.opentracing import set_tag
+
+ set_tag(tag_name, tag_value)
+
+There's a convenient decorator to tag all the args of the method. It uses
+inspection in order to use the formal parameter names prefixed with 'ARG_' as
+tag names. It uses kwarg names as tag names without the prefix.
+
+.. code-block:: python
+
+ from synapse.logging.opentracing import tag_args
+
+ @tag_args
+ def set_fates(clotho, lachesis, atropos, father="Zues", mother="Themis"):
+ pass
+
+ set_fates("the story", "the end", "the act")
+ # This will have the following tags
+ # - ARG_clotho: "the story"
+ # - ARG_lachesis: "the end"
+ # - ARG_atropos: "the act"
+ # - father: "Zues"
+ # - mother: "Themis"
+
+Contexts and carriers
+---------------------
+
+There are a selection of wrappers for injecting and extracting contexts from
+carriers provided. Unfortunately OpenTracing's three context injection
+techniques are not adequate for our inject of OpenTracing span-contexts into
+Twisted's http headers, EDU contents and our database tables. Also note that
+the binary encoding format mandated by OpenTracing is not actually implemented
+by jaeger_client v4.0.0 - it will silently noop.
+Please refer to the end of ``logging/opentracing.py`` for the available
+injection and extraction methods.
+
+Homeserver whitelisting
+-----------------------
+
+Most of the whitelist checks are encapsulated in the modules's injection
+and extraction method but be aware that using custom carriers or crossing
+unchartered waters will require the enforcement of the whitelist.
+``logging/opentracing.py`` has a ``whitelisted_homeserver`` method which takes
+in a destination and compares it to the whitelist.
+
+Most injection methods take a 'destination' arg. The context will only be injected
+if the destination matches the whitelist or the destination is None.
+
+=======
+Gotchas
+=======
+
+- Checking whitelists on span propagation
+- Inserting pii
+- Forgetting to enter or exit a scope
+- Span source: make sure that the span you expect to be active across a
+ function call really will be that one. Does the current function have more
+ than one caller? Will all of those calling functions have be in a context
+ with an active span?
+"""
+
+import contextlib
+import inspect
+import logging
+import re
+import types
+from functools import wraps
+from typing import Dict
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.config import ConfigError
+
+# Helper class
+
+
+class _DummyTagNames(object):
+ """wrapper of opentracings tags. We need to have them if we
+ want to reference them without opentracing around. Clearly they
+ should never actually show up in a trace. `set_tags` overwrites
+ these with the correct ones."""
+
+ INVALID_TAG = "invalid-tag"
+ COMPONENT = INVALID_TAG
+ DATABASE_INSTANCE = INVALID_TAG
+ DATABASE_STATEMENT = INVALID_TAG
+ DATABASE_TYPE = INVALID_TAG
+ DATABASE_USER = INVALID_TAG
+ ERROR = INVALID_TAG
+ HTTP_METHOD = INVALID_TAG
+ HTTP_STATUS_CODE = INVALID_TAG
+ HTTP_URL = INVALID_TAG
+ MESSAGE_BUS_DESTINATION = INVALID_TAG
+ PEER_ADDRESS = INVALID_TAG
+ PEER_HOSTNAME = INVALID_TAG
+ PEER_HOST_IPV4 = INVALID_TAG
+ PEER_HOST_IPV6 = INVALID_TAG
+ PEER_PORT = INVALID_TAG
+ PEER_SERVICE = INVALID_TAG
+ SAMPLING_PRIORITY = INVALID_TAG
+ SERVICE = INVALID_TAG
+ SPAN_KIND = INVALID_TAG
+ SPAN_KIND_CONSUMER = INVALID_TAG
+ SPAN_KIND_PRODUCER = INVALID_TAG
+ SPAN_KIND_RPC_CLIENT = INVALID_TAG
+ SPAN_KIND_RPC_SERVER = INVALID_TAG
+
+
+try:
+ import opentracing
+
+ tags = opentracing.tags
+except ImportError:
+ opentracing = None
+ tags = _DummyTagNames
+try:
+ from jaeger_client import Config as JaegerConfig
+ from synapse.logging.scopecontextmanager import LogContextScopeManager
+except ImportError:
+ JaegerConfig = None # type: ignore
+ LogContextScopeManager = None # type: ignore
+
+
+logger = logging.getLogger(__name__)
+
+
+# Block everything by default
+# A regex which matches the server_names to expose traces for.
+# None means 'block everything'.
+_homeserver_whitelist = None
+
+# Util methods
+
+
+def only_if_tracing(func):
+ """Executes the function only if we're tracing. Otherwise returns None."""
+
+ @wraps(func)
+ def _only_if_tracing_inner(*args, **kwargs):
+ if opentracing:
+ return func(*args, **kwargs)
+ else:
+ return
+
+ return _only_if_tracing_inner
+
+
+def ensure_active_span(message, ret=None):
+ """Executes the operation only if opentracing is enabled and there is an active span.
+ If there is no active span it logs message at the error level.
+
+ Args:
+ message (str): Message which fills in "There was no active span when trying to %s"
+ in the error log if there is no active span and opentracing is enabled.
+ ret (object): return value if opentracing is None or there is no active span.
+
+ Returns (object): The result of the func or ret if opentracing is disabled or there
+ was no active span.
+ """
+
+ def ensure_active_span_inner_1(func):
+ @wraps(func)
+ def ensure_active_span_inner_2(*args, **kwargs):
+ if not opentracing:
+ return ret
+
+ if not opentracing.tracer.active_span:
+ logger.error(
+ "There was no active span when trying to %s."
+ " Did you forget to start one or did a context slip?",
+ message,
+ )
+
+ return ret
+
+ return func(*args, **kwargs)
+
+ return ensure_active_span_inner_2
+
+ return ensure_active_span_inner_1
+
+
+@contextlib.contextmanager
+def _noop_context_manager(*args, **kwargs):
+ """Does exactly what it says on the tin"""
+ yield
+
+
+# Setup
+
+
+def init_tracer(config):
+ """Set the whitelists and initialise the JaegerClient tracer
+
+ Args:
+ config (HomeserverConfig): The config used by the homeserver
+ """
+ global opentracing
+ if not config.opentracer_enabled:
+ # We don't have a tracer
+ opentracing = None
+ return
+
+ if not opentracing or not JaegerConfig:
+ raise ConfigError(
+ "The server has been configured to use opentracing but opentracing is not "
+ "installed."
+ )
+
+ # Include the worker name
+ name = config.worker_name if config.worker_name else "master"
+
+ # Pull out the jaeger config if it was given. Otherwise set it to something sensible.
+ # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py
+
+ set_homeserver_whitelist(config.opentracer_whitelist)
+
+ JaegerConfig(
+ config=config.jaeger_config,
+ service_name="{} {}".format(config.server_name, name),
+ scope_manager=LogContextScopeManager(config),
+ ).initialize_tracer()
+
+
+# Whitelisting
+
+
+@only_if_tracing
+def set_homeserver_whitelist(homeserver_whitelist):
+ """Sets the homeserver whitelist
+
+ Args:
+ homeserver_whitelist (Iterable[str]): regex of whitelisted homeservers
+ """
+ global _homeserver_whitelist
+ if homeserver_whitelist:
+ # Makes a single regex which accepts all passed in regexes in the list
+ _homeserver_whitelist = re.compile(
+ "({})".format(")|(".join(homeserver_whitelist))
+ )
+
+
+@only_if_tracing
+def whitelisted_homeserver(destination):
+ """Checks if a destination matches the whitelist
+
+ Args:
+ destination (str)
+ """
+
+ if _homeserver_whitelist:
+ return _homeserver_whitelist.match(destination)
+ return False
+
+
+# Start spans and scopes
+
+# Could use kwargs but I want these to be explicit
+def start_active_span(
+ operation_name,
+ child_of=None,
+ references=None,
+ tags=None,
+ start_time=None,
+ ignore_active_span=False,
+ finish_on_close=True,
+):
+ """Starts an active opentracing span. Note, the scope doesn't become active
+ until it has been entered, however, the span starts from the time this
+ message is called.
+ Args:
+ See opentracing.tracer
+ Returns:
+ scope (Scope) or noop_context_manager
+ """
+
+ if opentracing is None:
+ return _noop_context_manager()
+
+ return opentracing.tracer.start_active_span(
+ operation_name,
+ child_of=child_of,
+ references=references,
+ tags=tags,
+ start_time=start_time,
+ ignore_active_span=ignore_active_span,
+ finish_on_close=finish_on_close,
+ )
+
+
+def start_active_span_follows_from(operation_name, contexts):
+ if opentracing is None:
+ return _noop_context_manager()
+
+ references = [opentracing.follows_from(context) for context in contexts]
+ scope = start_active_span(operation_name, references=references)
+ return scope
+
+
+def start_active_span_from_request(
+ request,
+ operation_name,
+ references=None,
+ tags=None,
+ start_time=None,
+ ignore_active_span=False,
+ finish_on_close=True,
+):
+ """
+ Extracts a span context from a Twisted Request.
+ args:
+ headers (twisted.web.http.Request)
+
+ For the other args see opentracing.tracer
+
+ returns:
+ span_context (opentracing.span.SpanContext)
+ """
+ # Twisted encodes the values as lists whereas opentracing doesn't.
+ # So, we take the first item in the list.
+ # Also, twisted uses byte arrays while opentracing expects strings.
+
+ if opentracing is None:
+ return _noop_context_manager()
+
+ header_dict = {
+ k.decode(): v[0].decode() for k, v in request.requestHeaders.getAllRawHeaders()
+ }
+ context = opentracing.tracer.extract(opentracing.Format.HTTP_HEADERS, header_dict)
+
+ return opentracing.tracer.start_active_span(
+ operation_name,
+ child_of=context,
+ references=references,
+ tags=tags,
+ start_time=start_time,
+ ignore_active_span=ignore_active_span,
+ finish_on_close=finish_on_close,
+ )
+
+
+def start_active_span_from_edu(
+ edu_content,
+ operation_name,
+ references=[],
+ tags=None,
+ start_time=None,
+ ignore_active_span=False,
+ finish_on_close=True,
+):
+ """
+ Extracts a span context from an edu and uses it to start a new active span
+
+ Args:
+ edu_content (dict): and edu_content with a `context` field whose value is
+ canonical json for a dict which contains opentracing information.
+
+ For the other args see opentracing.tracer
+ """
+
+ if opentracing is None:
+ return _noop_context_manager()
+
+ carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
+ context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+ _references = [
+ opentracing.child_of(span_context_from_string(x))
+ for x in carrier.get("references", [])
+ ]
+
+ # For some reason jaeger decided not to support the visualization of multiple parent
+ # spans or explicitely show references. I include the span context as a tag here as
+ # an aid to people debugging but it's really not an ideal solution.
+
+ references += _references
+
+ scope = opentracing.tracer.start_active_span(
+ operation_name,
+ child_of=context,
+ references=references,
+ tags=tags,
+ start_time=start_time,
+ ignore_active_span=ignore_active_span,
+ finish_on_close=finish_on_close,
+ )
+
+ scope.span.set_tag("references", carrier.get("references", []))
+ return scope
+
+
+# Opentracing setters for tags, logs, etc
+
+
+@ensure_active_span("set a tag")
+def set_tag(key, value):
+ """Sets a tag on the active span"""
+ opentracing.tracer.active_span.set_tag(key, value)
+
+
+@ensure_active_span("log")
+def log_kv(key_values, timestamp=None):
+ """Log to the active span"""
+ opentracing.tracer.active_span.log_kv(key_values, timestamp)
+
+
+@ensure_active_span("set the traces operation name")
+def set_operation_name(operation_name):
+ """Sets the operation name of the active span"""
+ opentracing.tracer.active_span.set_operation_name(operation_name)
+
+
+# Injection and extraction
+
+
+@ensure_active_span("inject the span into a header")
+def inject_active_span_twisted_headers(headers, destination, check_destination=True):
+ """
+ Injects a span context into twisted headers in-place
+
+ Args:
+ headers (twisted.web.http_headers.Headers)
+ destination (str): address of entity receiving the span context. If check_destination
+ is true the context will only be injected if the destination matches the
+ opentracing whitelist
+ check_destination (bool): If false, destination will be ignored and the context
+ will always be injected.
+ span (opentracing.Span)
+
+ Returns:
+ In-place modification of headers
+
+ Note:
+ The headers set by the tracer are custom to the tracer implementation which
+ should be unique enough that they don't interfere with any headers set by
+ synapse or twisted. If we're still using jaeger these headers would be those
+ here:
+ https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
+ """
+
+ if check_destination and not whitelisted_homeserver(destination):
+ return
+
+ span = opentracing.tracer.active_span
+ carrier = {} # type: Dict[str, str]
+ opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
+
+ for key, value in carrier.items():
+ headers.addRawHeaders(key, value)
+
+
+@ensure_active_span("inject the span into a byte dict")
+def inject_active_span_byte_dict(headers, destination, check_destination=True):
+ """
+ Injects a span context into a dict where the headers are encoded as byte
+ strings
+
+ Args:
+ headers (dict)
+ destination (str): address of entity receiving the span context. If check_destination
+ is true the context will only be injected if the destination matches the
+ opentracing whitelist
+ check_destination (bool): If false, destination will be ignored and the context
+ will always be injected.
+ span (opentracing.Span)
+
+ Returns:
+ In-place modification of headers
+
+ Note:
+ The headers set by the tracer are custom to the tracer implementation which
+ should be unique enough that they don't interfere with any headers set by
+ synapse or twisted. If we're still using jaeger these headers would be those
+ here:
+ https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
+ """
+ if check_destination and not whitelisted_homeserver(destination):
+ return
+
+ span = opentracing.tracer.active_span
+
+ carrier = {} # type: Dict[str, str]
+ opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
+
+ for key, value in carrier.items():
+ headers[key.encode()] = [value.encode()]
+
+
+@ensure_active_span("inject the span into a text map")
+def inject_active_span_text_map(carrier, destination, check_destination=True):
+ """
+ Injects a span context into a dict
+
+ Args:
+ carrier (dict)
+ destination (str): address of entity receiving the span context. If check_destination
+ is true the context will only be injected if the destination matches the
+ opentracing whitelist
+ check_destination (bool): If false, destination will be ignored and the context
+ will always be injected.
+
+ Returns:
+ In-place modification of carrier
+
+ Note:
+ The headers set by the tracer are custom to the tracer implementation which
+ should be unique enough that they don't interfere with any headers set by
+ synapse or twisted. If we're still using jaeger these headers would be those
+ here:
+ https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/constants.py
+ """
+
+ if check_destination and not whitelisted_homeserver(destination):
+ return
+
+ opentracing.tracer.inject(
+ opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+ )
+
+
+@ensure_active_span("get the active span context as a dict", ret={})
+def get_active_span_text_map(destination=None):
+ """
+ Gets a span context as a dict. This can be used instead of manually
+ injecting a span into an empty carrier.
+
+ Args:
+ destination (str): the name of the remote server.
+
+ Returns:
+ dict: the active span's context if opentracing is enabled, otherwise empty.
+ """
+
+ if destination and not whitelisted_homeserver(destination):
+ return {}
+
+ carrier = {} # type: Dict[str, str]
+ opentracing.tracer.inject(
+ opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+ )
+
+ return carrier
+
+
+@ensure_active_span("get the span context as a string.", ret={})
+def active_span_context_as_string():
+ """
+ Returns:
+ The active span context encoded as a string.
+ """
+ carrier = {} # type: Dict[str, str]
+ if opentracing:
+ opentracing.tracer.inject(
+ opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
+ )
+ return json.dumps(carrier)
+
+
+@only_if_tracing
+def span_context_from_string(carrier):
+ """
+ Returns:
+ The active span context decoded from a string.
+ """
+ carrier = json.loads(carrier)
+ return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+
+
+@only_if_tracing
+def extract_text_map(carrier):
+ """
+ Wrapper method for opentracing's tracer.extract for TEXT_MAP.
+ Args:
+ carrier (dict): a dict possibly containing a span context.
+
+ Returns:
+ The active span context extracted from carrier.
+ """
+ return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
+
+
+# Tracing decorators
+
+
+def trace(func=None, opname=None):
+ """
+ Decorator to trace a function.
+ Sets the operation name to that of the function's or that given
+ as operation_name. See the module's doc string for usage
+ examples.
+ """
+
+ def decorator(func):
+ if opentracing is None:
+ return func
+
+ _opname = opname if opname else func.__name__
+
+ @wraps(func)
+ def _trace_inner(*args, **kwargs):
+ if opentracing is None:
+ return func(*args, **kwargs)
+
+ scope = start_active_span(_opname)
+ scope.__enter__()
+
+ try:
+ result = func(*args, **kwargs)
+ if isinstance(result, defer.Deferred):
+
+ def call_back(result):
+ scope.__exit__(None, None, None)
+ return result
+
+ def err_back(result):
+ scope.span.set_tag(tags.ERROR, True)
+ scope.__exit__(None, None, None)
+ return result
+
+ result.addCallbacks(call_back, err_back)
+
+ else:
+ scope.__exit__(None, None, None)
+
+ return result
+
+ except Exception as e:
+ scope.__exit__(type(e), None, e.__traceback__)
+ raise
+
+ return _trace_inner
+
+ if func:
+ return decorator(func)
+ else:
+ return decorator
+
+
+def tag_args(func):
+ """
+ Tags all of the args to the active span.
+ """
+
+ if not opentracing:
+ return func
+
+ @wraps(func)
+ def _tag_args_inner(*args, **kwargs):
+ argspec = inspect.getargspec(func)
+ for i, arg in enumerate(argspec.args[1:]):
+ set_tag("ARG_" + arg, args[i])
+ set_tag("args", args[len(argspec.args) :])
+ set_tag("kwargs", kwargs)
+ return func(*args, **kwargs)
+
+ return _tag_args_inner
+
+
+def trace_servlet(servlet_name, extract_context=False):
+ """Decorator which traces a serlet. It starts a span with some servlet specific
+ tags such as the servlet_name and request information
+
+ Args:
+ servlet_name (str): The name to be used for the span's operation_name
+ extract_context (bool): Whether to attempt to extract the opentracing
+ context from the request the servlet is handling.
+
+ """
+
+ def _trace_servlet_inner_1(func):
+ if not opentracing:
+ return func
+
+ @wraps(func)
+ async def _trace_servlet_inner(request, *args, **kwargs):
+ request_tags = {
+ "request_id": request.get_request_id(),
+ tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
+ tags.HTTP_METHOD: request.get_method(),
+ tags.HTTP_URL: request.get_redacted_uri(),
+ tags.PEER_HOST_IPV6: request.getClientIP(),
+ }
+
+ if extract_context:
+ scope = start_active_span_from_request(
+ request, servlet_name, tags=request_tags
+ )
+ else:
+ scope = start_active_span(servlet_name, tags=request_tags)
+
+ with scope:
+ result = func(request, *args, **kwargs)
+
+ if not isinstance(result, (types.CoroutineType, defer.Deferred)):
+ # Some servlets aren't async and just return results
+ # directly, so we handle that here.
+ return result
+
+ return await result
+
+ return _trace_servlet_inner
+
+ return _trace_servlet_inner_1
diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py
new file mode 100644
index 0000000000..4eed4f2338
--- /dev/null
+++ b/synapse/logging/scopecontextmanager.py
@@ -0,0 +1,138 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.import logging
+
+import logging
+
+from opentracing import Scope, ScopeManager
+
+import twisted
+
+from synapse.logging.context import LoggingContext, nested_logging_context
+
+logger = logging.getLogger(__name__)
+
+
+class LogContextScopeManager(ScopeManager):
+ """
+ The LogContextScopeManager tracks the active scope in opentracing
+ by using the log contexts which are native to synapse. This is so
+ that the basic opentracing api can be used across twisted defereds.
+ (I would love to break logcontexts and this into an OS package. but
+ let's wait for twisted's contexts to be released.)
+ """
+
+ def __init__(self, config):
+ pass
+
+ @property
+ def active(self):
+ """
+ Returns the currently active Scope which can be used to access the
+ currently active Scope.span.
+ If there is a non-null Scope, its wrapped Span
+ becomes an implicit parent of any newly-created Span at
+ Tracer.start_active_span() time.
+
+ Return:
+ (Scope) : the Scope that is active, or None if not
+ available.
+ """
+ ctx = LoggingContext.current_context()
+ if ctx is LoggingContext.sentinel:
+ return None
+ else:
+ return ctx.scope
+
+ def activate(self, span, finish_on_close):
+ """
+ Makes a Span active.
+ Args
+ span (Span): the span that should become active.
+ finish_on_close (Boolean): whether Span should be automatically
+ finished when Scope.close() is called.
+
+ Returns:
+ Scope to control the end of the active period for
+ *span*. It is a programming error to neglect to call
+ Scope.close() on the returned instance.
+ """
+
+ enter_logcontext = False
+ ctx = LoggingContext.current_context()
+
+ if ctx is LoggingContext.sentinel:
+ # We don't want this scope to affect.
+ logger.error("Tried to activate scope outside of loggingcontext")
+ return Scope(None, span)
+ elif ctx.scope is not None:
+ # We want the logging scope to look exactly the same so we give it
+ # a blank suffix
+ ctx = nested_logging_context("")
+ enter_logcontext = True
+
+ scope = _LogContextScope(self, span, ctx, enter_logcontext, finish_on_close)
+ ctx.scope = scope
+ return scope
+
+
+class _LogContextScope(Scope):
+ """
+ A custom opentracing scope. The only significant difference is that it will
+ close the log context it's related to if the logcontext was created specifically
+ for this scope.
+ """
+
+ def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close):
+ """
+ Args:
+ manager (LogContextScopeManager):
+ the manager that is responsible for this scope.
+ span (Span):
+ the opentracing span which this scope represents the local
+ lifetime for.
+ logcontext (LogContext):
+ the logcontext to which this scope is attached.
+ enter_logcontext (Boolean):
+ if True the logcontext will be entered and exited when the scope
+ is entered and exited respectively
+ finish_on_close (Boolean):
+ if True finish the span when the scope is closed
+ """
+ super(_LogContextScope, self).__init__(manager, span)
+ self.logcontext = logcontext
+ self._finish_on_close = finish_on_close
+ self._enter_logcontext = enter_logcontext
+
+ def __enter__(self):
+ if self._enter_logcontext:
+ self.logcontext.__enter__()
+
+ def __exit__(self, type, value, traceback):
+ if type == twisted.internet.defer._DefGen_Return:
+ super(_LogContextScope, self).__exit__(None, None, None)
+ else:
+ super(_LogContextScope, self).__exit__(type, value, traceback)
+ if self._enter_logcontext:
+ self.logcontext.__exit__(type, value, traceback)
+ else: # the logcontext existed before the creation of the scope
+ self.logcontext.scope = None
+
+ def close(self):
+ if self.manager.active is not self:
+ logger.error("Tried to close a non-active scope!")
+ return
+
+ if self._finish_on_close:
+ self.span.finish()
diff --git a/synapse/util/logutils.py b/synapse/logging/utils.py
index ef31458226..0c2527bd86 100644
--- a/synapse/util/logutils.py
+++ b/synapse/logging/utils.py
@@ -44,7 +44,7 @@ def _log_debug_as_f(f, msg, msg_args):
lineno=lineno,
msg=msg,
args=msg_args,
- exc_info=None
+ exc_info=None,
)
logger.handle(record)
@@ -70,20 +70,11 @@ def log_function(f):
r = r[:50] + "..."
return r
- func_args = [
- "%s=%s" % (k, format(v)) for k, v in bound_args.items()
- ]
+ func_args = ["%s=%s" % (k, format(v)) for k, v in bound_args.items()]
- msg_args = {
- "func_name": func_name,
- "args": ", ".join(func_args)
- }
+ msg_args = {"func_name": func_name, "args": ", ".join(func_args)}
- _log_debug_as_f(
- f,
- "Invoked '%(func_name)s' with args: %(args)s",
- msg_args
- )
+ _log_debug_as_f(f, "Invoked '%(func_name)s' with args: %(args)s", msg_args)
return f(*args, **kwargs)
@@ -103,19 +94,13 @@ def time_function(f):
start = time.clock()
try:
- _log_debug_as_f(
- f,
- "[FUNC START] {%s-%d}",
- (func_name, id),
- )
+ _log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id))
r = f(*args, **kwargs)
finally:
end = time.clock()
_log_debug_as_f(
- f,
- "[FUNC END] {%s-%d} %.3f sec",
- (func_name, id, end - start,),
+ f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start)
)
return r
@@ -134,12 +119,15 @@ def trace_function(f):
logger = logging.getLogger(name)
level = logging.DEBUG
- s = inspect.currentframe().f_back
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+
+ s = frame.f_back
to_print = [
- "\t%s:%s %s. Args: args=%s, kwargs=%s" % (
- pathname, linenum, func_name, args, kwargs
- )
+ "\t%s:%s %s. Args: args=%s, kwargs=%s"
+ % (pathname, linenum, func_name, args, kwargs)
]
while s:
if True or s.f_globals["__name__"].startswith("synapse"):
@@ -147,9 +135,7 @@ def trace_function(f):
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
to_print.append(
- "\t%s:%d %s. Args: %s" % (
- filename, lineno, function, args_string
- )
+ "\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string)
)
s = s.f_back
@@ -162,8 +148,8 @@ def trace_function(f):
pathname=pathname,
lineno=lineno,
msg=msg,
- args=None,
- exc_info=None
+ args=(),
+ exc_info=None,
)
logger.handle(record)
@@ -175,24 +161,32 @@ def trace_function(f):
def get_previous_frames():
- s = inspect.currentframe().f_back.f_back
+
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+
+ s = frame.f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
- to_return.append("{{ %s:%d %s - Args: %s }}" % (
- filename, lineno, function, args_string
- ))
+ to_return.append(
+ "{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string)
+ )
s = s.f_back
- return ", ". join(to_return)
+ return ", ".join(to_return)
def get_previous_frame(ignore=[]):
- s = inspect.currentframe().f_back.f_back
+ frame = inspect.currentframe()
+ if frame is None:
+ raise Exception("Can't get current frame!")
+ s = frame.f_back.f_back
while s:
if s.f_globals["__name__"].startswith("synapse"):
@@ -201,7 +195,10 @@ def get_previous_frame(ignore=[]):
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
return "{{ %s:%d %s - Args: %s }}" % (
- filename, lineno, function, args_string
+ filename,
+ lineno,
+ function,
+ args_string,
)
s = s.f_back
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index ef48984fdd..d2fd29acb4 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -20,27 +20,35 @@ import os
import platform
import threading
import time
+from typing import Callable, Dict, Iterable, Optional, Tuple, Union
import six
import attr
from prometheus_client import Counter, Gauge, Histogram
-from prometheus_client.core import REGISTRY, GaugeMetricFamily
+from prometheus_client.core import REGISTRY, GaugeMetricFamily, HistogramMetricFamily
from twisted.internet import reactor
+import synapse
+from synapse.metrics._exposition import (
+ MetricsResource,
+ generate_latest,
+ start_http_server,
+)
+from synapse.util.versionstring import get_version_string
+
logger = logging.getLogger(__name__)
+METRICS_PREFIX = "/_synapse/metrics"
+
running_on_pypy = platform.python_implementation() == "PyPy"
-all_metrics = []
-all_collectors = []
-all_gauges = {}
+all_gauges = {} # type: Dict[str, Union[LaterGauge, InFlightGauge, BucketCollector]]
HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
class RegistryProxy(object):
-
@staticmethod
def collect():
for metric in REGISTRY.collect():
@@ -51,10 +59,12 @@ class RegistryProxy(object):
@attr.s(hash=True)
class LaterGauge(object):
- name = attr.ib()
- desc = attr.ib()
- labels = attr.ib(hash=False)
- caller = attr.ib()
+ name = attr.ib(type=str)
+ desc = attr.ib(type=str)
+ labels = attr.ib(hash=False, type=Optional[Iterable[str]])
+ # callback: should either return a value (if there are no labels for this metric),
+ # or dict mapping from a label tuple to a value
+ caller = attr.ib(type=Callable[[], Union[Dict[Tuple[str, ...], float], float]])
def collect(self):
@@ -63,10 +73,7 @@ class LaterGauge(object):
try:
calls = self.caller()
except Exception:
- logger.exception(
- "Exception running callback for LaterGauge(%s)",
- self.name,
- )
+ logger.exception("Exception running callback for LaterGauge(%s)", self.name)
yield g
return
@@ -116,13 +123,11 @@ class InFlightGauge(object):
# Create a class which have the sub_metrics values as attributes, which
# default to 0 on initialization. Used to pass to registered callbacks.
self._metrics_class = attr.make_class(
- "_MetricsEntry",
- attrs={x: attr.ib(0) for x in sub_metrics},
- slots=True,
+ "_MetricsEntry", attrs={x: attr.ib(0) for x in sub_metrics}, slots=True
)
# Counts number of in flight blocks for a given set of label values
- self._registrations = {}
+ self._registrations = {} # type: Dict
# Protects access to _registrations
self._lock = threading.Lock()
@@ -157,7 +162,9 @@ class InFlightGauge(object):
Note: may be called by a separate thread.
"""
- in_flight = GaugeMetricFamily(self.name + "_total", self.desc, labels=self.labels)
+ in_flight = GaugeMetricFamily(
+ self.name + "_total", self.desc, labels=self.labels
+ )
metrics_by_key = {}
@@ -179,7 +186,9 @@ class InFlightGauge(object):
yield in_flight
for name in self.sub_metrics:
- gauge = GaugeMetricFamily("_".join([self.name, name]), "", labels=self.labels)
+ gauge = GaugeMetricFamily(
+ "_".join([self.name, name]), "", labels=self.labels
+ )
for key, metrics in six.iteritems(metrics_by_key):
gauge.add_metric(key, getattr(metrics, name))
yield gauge
@@ -193,17 +202,76 @@ class InFlightGauge(object):
all_gauges[self.name] = self
+@attr.s(hash=True)
+class BucketCollector(object):
+ """
+ Like a Histogram, but allows buckets to be point-in-time instead of
+ incrementally added to.
+
+ Args:
+ name (str): Base name of metric to be exported to Prometheus.
+ data_collector (callable -> dict): A synchronous callable that
+ returns a dict mapping bucket to number of items in the
+ bucket. If these buckets are not the same as the buckets
+ given to this class, they will be remapped into them.
+ buckets (list[float]): List of floats/ints of the buckets to
+ give to Prometheus. +Inf is ignored, if given.
+
+ """
+
+ name = attr.ib()
+ data_collector = attr.ib()
+ buckets = attr.ib()
+
+ def collect(self):
+
+ # Fetch the data -- this must be synchronous!
+ data = self.data_collector()
+
+ buckets = {} # type: Dict[float, int]
+
+ res = []
+ for x in data.keys():
+ for i, bound in enumerate(self.buckets):
+ if x <= bound:
+ buckets[bound] = buckets.get(bound, 0) + data[x]
+
+ for i in self.buckets:
+ res.append([str(i), buckets.get(i, 0)])
+
+ res.append(["+Inf", sum(data.values())])
+
+ metric = HistogramMetricFamily(
+ self.name, "", buckets=res, sum_value=sum(x * y for x, y in data.items())
+ )
+ yield metric
+
+ def __attrs_post_init__(self):
+ self.buckets = [float(x) for x in self.buckets if x != "+Inf"]
+ if self.buckets != sorted(self.buckets):
+ raise ValueError("Buckets not sorted")
+
+ self.buckets = tuple(self.buckets)
+
+ if self.name in all_gauges.keys():
+ logger.warning("%s already registered, reregistering" % (self.name,))
+ REGISTRY.unregister(all_gauges.pop(self.name))
+
+ REGISTRY.register(self)
+ all_gauges[self.name] = self
+
+
#
# Detailed CPU metrics
#
-class CPUMetrics(object):
+class CPUMetrics(object):
def __init__(self):
ticks_per_sec = 100
try:
# Try and get the system config
- ticks_per_sec = os.sysconf('SC_CLK_TCK')
+ ticks_per_sec = os.sysconf("SC_CLK_TCK")
except (ValueError, TypeError, AttributeError):
pass
@@ -237,13 +305,28 @@ gc_time = Histogram(
"python_gc_time",
"Time taken to GC (sec)",
["gen"],
- buckets=[0.0025, 0.005, 0.01, 0.025, 0.05, 0.10, 0.25, 0.50, 1.00, 2.50,
- 5.00, 7.50, 15.00, 30.00, 45.00, 60.00],
+ buckets=[
+ 0.0025,
+ 0.005,
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.10,
+ 0.25,
+ 0.50,
+ 1.00,
+ 2.50,
+ 5.00,
+ 7.50,
+ 15.00,
+ 30.00,
+ 45.00,
+ 60.00,
+ ],
)
class GCCounts(object):
-
def collect(self):
cm = GaugeMetricFamily("python_gc_counts", "GC object counts", labels=["gen"])
for n, m in enumerate(gc.get_count()):
@@ -279,9 +362,7 @@ sent_transactions_counter = Counter("synapse_federation_client_sent_transactions
events_processed_counter = Counter("synapse_federation_client_events_processed", "")
event_processing_loop_counter = Counter(
- "synapse_event_processing_loop_count",
- "Event processing loop iterations",
- ["name"],
+ "synapse_event_processing_loop_count", "Event processing loop iterations", ["name"]
)
event_processing_loop_room_count = Counter(
@@ -307,11 +388,20 @@ event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name"
# finished being processed.
event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"])
+# Build info of the running server.
+build_info = Gauge(
+ "synapse_build_info", "Build information", ["pythonversion", "version", "osversion"]
+)
+build_info.labels(
+ " ".join([platform.python_implementation(), platform.python_version()]),
+ get_version_string(synapse),
+ " ".join([platform.system(), platform.release()]),
+).set(1)
+
last_ticked = time.time()
class ReactorLastSeenMetric(object):
-
def collect(self):
cm = GaugeMetricFamily(
"python_twisted_reactor_last_seen",
@@ -325,7 +415,6 @@ REGISTRY.register(ReactorLastSeenMetric())
def runUntilCurrentTimer(func):
-
@functools.wraps(func)
def f(*args, **kwargs):
now = reactor.seconds()
@@ -369,7 +458,10 @@ def runUntilCurrentTimer(func):
counts = gc.get_count()
for i in (2, 1, 0):
if threshold[i] < counts[i]:
- logger.info("Collecting gc %d", i)
+ if i == 0:
+ logger.debug("Collecting gc %d", i)
+ else:
+ logger.info("Collecting gc %d", i)
start = time.time()
unreachable = gc.collect(i)
@@ -399,3 +491,12 @@ try:
gc.disable()
except AttributeError:
pass
+
+__all__ = [
+ "MetricsResource",
+ "generate_latest",
+ "start_http_server",
+ "LaterGauge",
+ "InFlightGauge",
+ "BucketCollector",
+]
diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py
new file mode 100644
index 0000000000..a248103191
--- /dev/null
+++ b/synapse/metrics/_exposition.py
@@ -0,0 +1,260 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015-2019 Prometheus Python Client Developers
+# Copyright 2019 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+This code is based off `prometheus_client/exposition.py` from version 0.7.1.
+
+Due to the renaming of metrics in prometheus_client 0.4.0, this customised
+vendoring of the code will emit both the old versions that Synapse dashboards
+expect, and the newer "best practice" version of the up-to-date official client.
+"""
+
+import math
+import threading
+from collections import namedtuple
+from http.server import BaseHTTPRequestHandler, HTTPServer
+from socketserver import ThreadingMixIn
+from urllib.parse import parse_qs, urlparse
+
+from prometheus_client import REGISTRY
+
+from twisted.web.resource import Resource
+
+try:
+ from prometheus_client.samples import Sample
+except ImportError:
+ Sample = namedtuple( # type: ignore[no-redef] # noqa
+ "Sample", ["name", "labels", "value", "timestamp", "exemplar"]
+ )
+
+
+CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")
+
+
+INF = float("inf")
+MINUS_INF = float("-inf")
+
+
+def floatToGoString(d):
+ d = float(d)
+ if d == INF:
+ return "+Inf"
+ elif d == MINUS_INF:
+ return "-Inf"
+ elif math.isnan(d):
+ return "NaN"
+ else:
+ s = repr(d)
+ dot = s.find(".")
+ # Go switches to exponents sooner than Python.
+ # We only need to care about positive values for le/quantile.
+ if d > 0 and dot > 6:
+ mantissa = "{0}.{1}{2}".format(s[0], s[1:dot], s[dot + 1 :]).rstrip("0.")
+ return "{0}e+0{1}".format(mantissa, dot - 1)
+ return s
+
+
+def sample_line(line, name):
+ if line.labels:
+ labelstr = "{{{0}}}".format(
+ ",".join(
+ [
+ '{0}="{1}"'.format(
+ k,
+ v.replace("\\", r"\\").replace("\n", r"\n").replace('"', r"\""),
+ )
+ for k, v in sorted(line.labels.items())
+ ]
+ )
+ )
+ else:
+ labelstr = ""
+ timestamp = ""
+ if line.timestamp is not None:
+ # Convert to milliseconds.
+ timestamp = " {0:d}".format(int(float(line.timestamp) * 1000))
+ return "{0}{1} {2}{3}\n".format(
+ name, labelstr, floatToGoString(line.value), timestamp
+ )
+
+
+def nameify_sample(sample):
+ """
+ If we get a prometheus_client<0.4.0 sample as a tuple, transform it into a
+ namedtuple which has the names we expect.
+ """
+ if not isinstance(sample, Sample):
+ sample = Sample(*sample, None, None)
+
+ return sample
+
+
+def generate_latest(registry, emit_help=False):
+ output = []
+
+ for metric in registry.collect():
+
+ if metric.name.startswith("__unused"):
+ continue
+
+ if not metric.samples:
+ # No samples, don't bother.
+ continue
+
+ mname = metric.name
+ mnewname = metric.name
+ mtype = metric.type
+
+ # OpenMetrics -> Prometheus
+ if mtype == "counter":
+ mnewname = mnewname + "_total"
+ elif mtype == "info":
+ mtype = "gauge"
+ mnewname = mnewname + "_info"
+ elif mtype == "stateset":
+ mtype = "gauge"
+ elif mtype == "gaugehistogram":
+ mtype = "histogram"
+ elif mtype == "unknown":
+ mtype = "untyped"
+
+ # Output in the old format for compatibility.
+ if emit_help:
+ output.append(
+ "# HELP {0} {1}\n".format(
+ mname,
+ metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
+ )
+ )
+ output.append("# TYPE {0} {1}\n".format(mname, mtype))
+ for sample in map(nameify_sample, metric.samples):
+ # Get rid of the OpenMetrics specific samples
+ for suffix in ["_created", "_gsum", "_gcount"]:
+ if sample.name.endswith(suffix):
+ break
+ else:
+ newname = sample.name.replace(mnewname, mname)
+ if ":" in newname and newname.endswith("_total"):
+ newname = newname[: -len("_total")]
+ output.append(sample_line(sample, newname))
+
+ # Get rid of the weird colon things while we're at it
+ if mtype == "counter":
+ mnewname = mnewname.replace(":total", "")
+ mnewname = mnewname.replace(":", "_")
+
+ if mname == mnewname:
+ continue
+
+ # Also output in the new format, if it's different.
+ if emit_help:
+ output.append(
+ "# HELP {0} {1}\n".format(
+ mnewname,
+ metric.documentation.replace("\\", r"\\").replace("\n", r"\n"),
+ )
+ )
+ output.append("# TYPE {0} {1}\n".format(mnewname, mtype))
+ for sample in map(nameify_sample, metric.samples):
+ # Get rid of the OpenMetrics specific samples
+ for suffix in ["_created", "_gsum", "_gcount"]:
+ if sample.name.endswith(suffix):
+ break
+ else:
+ output.append(
+ sample_line(
+ sample, sample.name.replace(":total", "").replace(":", "_")
+ )
+ )
+
+ return "".join(output).encode("utf-8")
+
+
+class MetricsHandler(BaseHTTPRequestHandler):
+ """HTTP handler that gives metrics from ``REGISTRY``."""
+
+ registry = REGISTRY
+
+ def do_GET(self):
+ registry = self.registry
+ params = parse_qs(urlparse(self.path).query)
+
+ if "help" in params:
+ emit_help = True
+ else:
+ emit_help = False
+
+ try:
+ output = generate_latest(registry, emit_help=emit_help)
+ except Exception:
+ self.send_error(500, "error generating metric output")
+ raise
+ self.send_response(200)
+ self.send_header("Content-Type", CONTENT_TYPE_LATEST)
+ self.end_headers()
+ self.wfile.write(output)
+
+ def log_message(self, format, *args):
+ """Log nothing."""
+
+ @classmethod
+ def factory(cls, registry):
+ """Returns a dynamic MetricsHandler class tied
+ to the passed registry.
+ """
+ # This implementation relies on MetricsHandler.registry
+ # (defined above and defaulted to REGISTRY).
+
+ # As we have unicode_literals, we need to create a str()
+ # object for type().
+ cls_name = str(cls.__name__)
+ MyMetricsHandler = type(cls_name, (cls, object), {"registry": registry})
+ return MyMetricsHandler
+
+
+class _ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
+ """Thread per request HTTP server."""
+
+ # Make worker threads "fire and forget". Beginning with Python 3.7 this
+ # prevents a memory leak because ``ThreadingMixIn`` starts to gather all
+ # non-daemon threads in a list in order to join on them at server close.
+ # Enabling daemon threads virtually makes ``_ThreadingSimpleServer`` the
+ # same as Python 3.7's ``ThreadingHTTPServer``.
+ daemon_threads = True
+
+
+def start_http_server(port, addr="", registry=REGISTRY):
+ """Starts an HTTP server for prometheus metrics as a daemon thread"""
+ CustomMetricsHandler = MetricsHandler.factory(registry)
+ httpd = _ThreadingSimpleServer((addr, port), CustomMetricsHandler)
+ t = threading.Thread(target=httpd.serve_forever)
+ t.daemon = True
+ t.start()
+
+
+class MetricsResource(Resource):
+ """
+ Twisted ``Resource`` that serves prometheus metrics.
+ """
+
+ isLeaf = True
+
+ def __init__(self, registry=REGISTRY):
+ self.registry = registry
+
+ def render_GET(self, request):
+ request.setHeader(b"Content-Type", CONTENT_TYPE_LATEST.encode("ascii"))
+ return generate_latest(self.registry)
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 037f1c490e..8449ef82f7 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -15,6 +15,9 @@
import logging
import threading
+from asyncio import iscoroutine
+from functools import wraps
+from typing import Dict, Set
import six
@@ -22,7 +25,7 @@ from prometheus_client.core import REGISTRY, Counter, GaugeMetricFamily
from twisted.internet import defer
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
logger = logging.getLogger(__name__)
@@ -60,8 +63,10 @@ _background_process_db_txn_count = Counter(
_background_process_db_txn_duration = Counter(
"synapse_background_process_db_txn_duration_seconds",
- ("Seconds spent by background processes waiting for database "
- "transactions, excluding scheduling time"),
+ (
+ "Seconds spent by background processes waiting for database "
+ "transactions, excluding scheduling time"
+ ),
["name"],
registry=None,
)
@@ -76,13 +81,13 @@ _background_process_db_sched_duration = Counter(
# map from description to a counter, so that we can name our logcontexts
# incrementally. (It actually duplicates _background_process_start_count, but
# it's much simpler to do so than to try to combine them.)
-_background_process_counts = dict() # type: dict[str, int]
+_background_process_counts = {} # type: Dict[str, int]
# map from description to the currently running background processes.
#
# it's kept as a dict of sets rather than a big set so that we can keep track
# of process descriptions that no longer have any active processes.
-_background_processes = dict() # type: dict[str, set[_BackgroundProcess]]
+_background_processes = {} # type: Dict[str, Set[_BackgroundProcess]]
# A lock that covers the above dicts
_bg_metrics_lock = threading.Lock()
@@ -94,6 +99,7 @@ class _Collector(object):
Ensures that all of the metrics are up-to-date with any in-flight processes
before they are returned.
"""
+
def collect(self):
background_process_in_flight_count = GaugeMetricFamily(
"synapse_background_process_in_flight_count",
@@ -105,14 +111,11 @@ class _Collector(object):
# We also copy the process lists as that can also change
with _bg_metrics_lock:
_background_processes_copy = {
- k: list(v)
- for k, v in six.iteritems(_background_processes)
+ k: list(v) for k, v in six.iteritems(_background_processes)
}
for desc, processes in six.iteritems(_background_processes_copy):
- background_process_in_flight_count.add_metric(
- (desc,), len(processes),
- )
+ background_process_in_flight_count.add_metric((desc,), len(processes))
for process in processes:
process.update_metrics()
@@ -121,11 +124,11 @@ class _Collector(object):
# now we need to run collect() over each of the static Counters, and
# yield each metric they return.
for m in (
- _background_process_ru_utime,
- _background_process_ru_stime,
- _background_process_db_txn_count,
- _background_process_db_txn_duration,
- _background_process_db_sched_duration,
+ _background_process_ru_utime,
+ _background_process_ru_stime,
+ _background_process_db_txn_count,
+ _background_process_db_txn_duration,
+ _background_process_db_sched_duration,
):
for r in m.collect():
yield r
@@ -151,14 +154,12 @@ class _BackgroundProcess(object):
_background_process_ru_utime.labels(self.desc).inc(diff.ru_utime)
_background_process_ru_stime.labels(self.desc).inc(diff.ru_stime)
- _background_process_db_txn_count.labels(self.desc).inc(
- diff.db_txn_count,
- )
+ _background_process_db_txn_count.labels(self.desc).inc(diff.db_txn_count)
_background_process_db_txn_duration.labels(self.desc).inc(
- diff.db_txn_duration_sec,
+ diff.db_txn_duration_sec
)
_background_process_db_sched_duration.labels(self.desc).inc(
- diff.db_sched_duration_sec,
+ diff.db_sched_duration_sec
)
@@ -175,13 +176,14 @@ def run_as_background_process(desc, func, *args, **kwargs):
Args:
desc (str): a description for this background process type
- func: a function, which may return a Deferred
+ func: a function, which may return a Deferred or a coroutine
args: positional args for func
kwargs: keyword args for func
Returns: Deferred which returns the result of func, but note that it does not
follow the synapse logcontext rules.
"""
+
@defer.inlineCallbacks
def run():
with _bg_metrics_lock:
@@ -198,7 +200,17 @@ def run_as_background_process(desc, func, *args, **kwargs):
_background_processes.setdefault(desc, set()).add(proc)
try:
- yield func(*args, **kwargs)
+ result = func(*args, **kwargs)
+
+ # We probably don't have an ensureDeferred in our call stack to handle
+ # coroutine results, so we need to ensureDeferred here.
+ #
+ # But we need this check because ensureDeferred doesn't like being
+ # called on immediate values (as opposed to Deferreds or coroutines).
+ if iscoroutine(result):
+ result = defer.ensureDeferred(result)
+
+ return (yield result)
except Exception:
logger.exception("Background process '%s' threw an exception", desc)
finally:
@@ -209,3 +221,20 @@ def run_as_background_process(desc, func, *args, **kwargs):
with PreserveLoggingContext():
return run()
+
+
+def wrap_as_background_process(desc):
+ """Decorator that wraps a function that gets called as a background
+ process.
+
+ Equivalent of calling the function with `run_as_background_process`
+ """
+
+ def wrap_as_background_process_inner(func):
+ @wraps(func)
+ def wrap_as_background_process_inner_2(*args, **kwargs):
+ return run_as_background_process(desc, func, *args, **kwargs)
+
+ return wrap_as_background_process_inner_2
+
+ return wrap_as_background_process_inner
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index b3abd1b3c6..c7fffd72f2 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,17 +13,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
+from synapse.http.site import SynapseRequest
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import UserID
+"""
+This package defines the 'stable' API which can be used by extension modules which
+are loaded into Synapse.
+"""
+
+__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"]
+
+logger = logging.getLogger(__name__)
+
class ModuleApi(object):
- """A proxy object that gets passed to password auth providers so they
+ """A proxy object that gets passed to various plugin modules so they
can register new users etc if necessary.
"""
+
def __init__(self, hs, auth_handler):
- self.hs = hs
+ self._hs = hs
self._store = hs.get_datastore()
self._auth = hs.get_auth()
@@ -57,9 +72,9 @@ class ModuleApi(object):
Returns:
str: qualified @user:id
"""
- if username.startswith('@'):
+ if username.startswith("@"):
return username
- return UserID(username, self.hs.hostname).to_string()
+ return UserID(username, self._hs.hostname).to_string()
def check_user_exists(self, user_id):
"""Check if user exists.
@@ -75,8 +90,13 @@ class ModuleApi(object):
@defer.inlineCallbacks
def register(self, localpart, displayname=None, emails=[]):
- """Registers a new user with given localpart and optional
- displayname, emails.
+ """Registers a new user with given localpart and optional displayname, emails.
+
+ Also returns an access token for the new user.
+
+ Deprecated: avoid this, as it generates a new device with no way to
+ return that device to the user. Prefer separate calls to register_user and
+ register_device.
Args:
localpart (str): The localpart of the new user.
@@ -84,16 +104,74 @@ class ModuleApi(object):
emails (List[str]): Emails to bind to the new user.
Returns:
- Deferred: a 2-tuple of (user_id, access_token)
+ Deferred[tuple[str, str]]: a 2-tuple of (user_id, access_token)
+ """
+ logger.warning(
+ "Using deprecated ModuleApi.register which creates a dummy user device."
+ )
+ user_id = yield self.register_user(localpart, displayname, emails)
+ _, access_token = yield self.register_device(user_id)
+ return user_id, access_token
+
+ def register_user(self, localpart, displayname=None, emails=[]):
+ """Registers a new user with given localpart and optional displayname, emails.
+
+ Args:
+ localpart (str): The localpart of the new user.
+ displayname (str|None): The displayname of the new user.
+ emails (List[str]): Emails to bind to the new user.
+
+ Raises:
+ SynapseError if there is an error performing the registration. Check the
+ 'errcode' property for more information on the reason for failure
+
+ Returns:
+ Deferred[str]: user_id
"""
- # Register the user
- reg = self.hs.get_registration_handler()
- user_id, access_token = yield reg.register(
- localpart=localpart, default_display_name=displayname,
- bind_emails=emails,
+ return self._hs.get_registration_handler().register_user(
+ localpart=localpart, default_display_name=displayname, bind_emails=emails
)
- defer.returnValue((user_id, access_token))
+ def register_device(self, user_id, device_id=None, initial_display_name=None):
+ """Register a device for a user and generate an access token.
+
+ Args:
+ user_id (str): full canonical @user:id
+ device_id (str|None): The device ID to check, or None to generate
+ a new one.
+ initial_display_name (str|None): An optional display name for the
+ device.
+
+ Returns:
+ defer.Deferred[tuple[str, str]]: Tuple of device ID and access token
+ """
+ return self._hs.get_registration_handler().register_device(
+ user_id=user_id,
+ device_id=device_id,
+ initial_display_name=initial_display_name,
+ )
+
+ def record_user_external_id(
+ self, auth_provider_id: str, remote_user_id: str, registered_user_id: str
+ ) -> defer.Deferred:
+ """Record a mapping from an external user id to a mxid
+
+ Args:
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+ user_id: complete mxid that it is mapped to
+ """
+ return self._store.record_user_external_id(
+ auth_provider_id, remote_user_id, registered_user_id
+ )
+
+ def generate_short_term_login_token(
+ self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ ) -> str:
+ """Generate a login token suitable for m.login.token authentication"""
+ return self._hs.get_macaroon_generator().generate_short_term_login_token(
+ user_id, duration_in_ms
+ )
@defer.inlineCallbacks
def invalidate_access_token(self, access_token):
@@ -115,7 +193,7 @@ class ModuleApi(object):
user_id = user_info["user"].to_string()
if device_id:
# delete the device, which will also delete its access tokens
- yield self.hs.get_device_handler().delete_device(user_id, device_id)
+ yield self._hs.get_device_handler().delete_device(user_id, device_id)
else:
# no associated device. Just delete the access token.
yield self._auth_handler.delete_access_token(access_token)
@@ -133,4 +211,22 @@ class ModuleApi(object):
Returns:
Deferred[object]: result of func
"""
- return self._store.runInteraction(desc, func, *args, **kwargs)
+ return self._store.db.runInteraction(desc, func, *args, **kwargs)
+
+ def complete_sso_login(
+ self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
+ ):
+ """Complete a SSO login by redirecting the user to a page to confirm whether they
+ want their access token sent to `client_redirect_url`, or redirect them to that
+ URL with a token directly if the URL matches with one of the whitelisted clients.
+
+ Args:
+ registered_user_id: The MXID that has been registered as a previous step of
+ of this SSO login.
+ request: The request to respond to.
+ client_redirect_url: The URL to which to offer to redirect the user (or to
+ redirect them directly if whitelisted).
+ """
+ self._auth_handler.complete_sso_login(
+ registered_user_id, request, client_redirect_url,
+ )
diff --git a/synapse/replication/slave/storage/user_directory.py b/synapse/module_api/errors.py
index 0d7b1a4a83..b15441772c 100644
--- a/synapse/replication/slave/storage/user_directory.py
+++ b/synapse/module_api/errors.py
@@ -13,10 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.user_directory import UserDirectoryStore
+"""Exception types which are exposed as part of the stable module API"""
-from ._base import BaseSlavedStore
-
-
-class SlavedUserDirectoryStore(UserDirectoryStore, BaseSlavedStore):
- pass
+from synapse.api.errors import RedirectException, SynapseError # noqa: F401
diff --git a/synapse/notifier.py b/synapse/notifier.py
index ff589660da..6132727cbd 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -15,20 +15,22 @@
import logging
from collections import namedtuple
+from typing import Callable, List
from prometheus_client import Counter
from twisted.internet import defer
+import synapse.server
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state
+from synapse.logging.context import PreserveLoggingContext
+from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import StreamToken
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
-from synapse.util.logcontext import PreserveLoggingContext
-from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@@ -37,7 +39,8 @@ logger = logging.getLogger(__name__)
notified_events_counter = Counter("synapse_notifier_notified_events", "")
users_woken_by_stream_counter = Counter(
- "synapse_notifier_users_woken_by_stream", "", ["stream"])
+ "synapse_notifier_users_woken_by_stream", "", ["stream"]
+)
# TODO(paul): Should be shared somewhere
@@ -55,6 +58,7 @@ class _NotificationListener(object):
The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred.
"""
+
__slots__ = ["deferred"]
def __init__(self, deferred):
@@ -95,9 +99,7 @@ class _NotifierUserStream(object):
stream_id(str): The new id for the stream the event came from.
time_now_ms(int): The current time in milliseconds.
"""
- self.current_token = self.current_token.copy_and_advance(
- stream_key, stream_id
- )
+ self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
self.last_notified_token = self.current_token
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
@@ -141,6 +143,7 @@ class _NotifierUserStream(object):
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
def __nonzero__(self):
return bool(self.events)
+
__bool__ = __nonzero__ # python3
@@ -153,16 +156,22 @@ class Notifier(object):
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
- def __init__(self, hs):
+ def __init__(self, hs: "synapse.server.HomeServer"):
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.hs = hs
+ self.storage = hs.get_storage()
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
- self.replication_callbacks = []
+ # Called when there are new things to stream over replication
+ self.replication_callbacks = [] # type: List[Callable[[], None]]
+
+ # Called when remote servers have come back online after having been
+ # down.
+ self.remote_server_up_callbacks = [] # type: List[Callable[[str], None]]
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
@@ -190,18 +199,20 @@ class Notifier(object):
all_user_streams.add(x)
return sum(stream.count_listeners() for stream in all_user_streams)
+
LaterGauge("synapse_notifier_listeners", "", [], count_listeners)
LaterGauge(
- "synapse_notifier_rooms", "", [],
+ "synapse_notifier_rooms",
+ "",
+ [],
lambda: count(bool, list(self.room_to_user_streams.values())),
)
LaterGauge(
- "synapse_notifier_users", "", [],
- lambda: len(self.user_to_user_stream),
+ "synapse_notifier_users", "", [], lambda: len(self.user_to_user_stream)
)
- def add_replication_callback(self, cb):
+ def add_replication_callback(self, cb: Callable[[], None]):
"""Add a callback that will be called when some new data is available.
Callback is not given any arguments. It should *not* return a Deferred - if
it needs to do any asynchronous work, a background thread should be started and
@@ -209,8 +220,15 @@ class Notifier(object):
"""
self.replication_callbacks.append(cb)
- def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
- extra_users=[]):
+ def add_remote_server_up_callback(self, cb: Callable[[str], None]):
+ """Add a callback that will be called when synapse detects a server
+ has been
+ """
+ self.remote_server_up_callbacks.append(cb)
+
+ def on_new_room_event(
+ self, event, room_stream_id, max_room_stream_id, extra_users=[]
+ ):
""" Used by handlers to inform the notifier something has happened
in the room, room event wise.
@@ -222,9 +240,7 @@ class Notifier(object):
until all previous events have been persisted before notifying
the client streams.
"""
- self.pending_new_room_events.append((
- room_stream_id, event, extra_users
- ))
+ self.pending_new_room_events.append((room_stream_id, event, extra_users))
self._notify_pending_new_room_events(max_room_stream_id)
self.notify_replication()
@@ -240,9 +256,9 @@ class Notifier(object):
self.pending_new_room_events = []
for room_stream_id, event, extra_users in pending:
if room_stream_id > max_room_stream_id:
- self.pending_new_room_events.append((
- room_stream_id, event, extra_users
- ))
+ self.pending_new_room_events.append(
+ (room_stream_id, event, extra_users)
+ )
else:
self._on_new_room_event(event, room_stream_id, extra_users)
@@ -250,8 +266,7 @@ class Notifier(object):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
run_as_background_process(
- "notify_app_services",
- self._notify_app_services, room_stream_id,
+ "notify_app_services", self._notify_app_services, room_stream_id
)
if self.federation_sender:
@@ -261,9 +276,7 @@ class Notifier(object):
self._user_joined_room(event.state_key, event.room_id)
self.on_new_event(
- "room_key", room_stream_id,
- users=extra_users,
- rooms=[event.room_id],
+ "room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
)
@defer.inlineCallbacks
@@ -304,17 +317,17 @@ class Notifier(object):
without waking up any of the normal user event streams"""
self.notify_replication()
- @defer.inlineCallbacks
- def wait_for_events(self, user_id, timeout, callback, room_ids=None,
- from_token=StreamToken.START):
+ async def wait_for_events(
+ self, user_id, timeout, callback, room_ids=None, from_token=StreamToken.START
+ ):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
user_stream = self.user_to_user_stream.get(user_id)
if user_stream is None:
- current_token = yield self.event_sources.get_current_token()
+ current_token = await self.event_sources.get_current_token()
if room_ids is None:
- room_ids = yield self.store.get_rooms_for_user(user_id)
+ room_ids = await self.store.get_rooms_for_user(user_id)
user_stream = _NotifierUserStream(
user_id=user_id,
rooms=room_ids,
@@ -339,15 +352,15 @@ class Notifier(object):
listener = user_stream.new_listener(prev_token)
listener.deferred = timeout_deferred(
listener.deferred,
- (end_time - now) / 1000.,
+ (end_time - now) / 1000.0,
self.hs.get_reactor(),
)
with PreserveLoggingContext():
- yield listener.deferred
+ await listener.deferred
current_token = user_stream.current_token
- result = yield callback(prev_token, current_token)
+ result = await callback(prev_token, current_token)
if result:
break
@@ -363,14 +376,19 @@ class Notifier(object):
# This happened if there was no timeout or if the timeout had
# already expired.
current_token = user_stream.current_token
- result = yield callback(prev_token, current_token)
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def get_events_for(self, user, pagination_config, timeout,
- only_keys=None,
- is_guest=False, explicit_room_id=None):
+ result = await callback(prev_token, current_token)
+
+ return result
+
+ async def get_events_for(
+ self,
+ user,
+ pagination_config,
+ timeout,
+ only_keys=None,
+ is_guest=False,
+ explicit_room_id=None,
+ ):
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.
@@ -384,17 +402,16 @@ class Notifier(object):
"""
from_token = pagination_config.from_token
if not from_token:
- from_token = yield self.event_sources.get_current_token()
+ from_token = await self.event_sources.get_current_token()
limit = pagination_config.limit
- room_ids, is_joined = yield self._get_room_ids(user, explicit_room_id)
+ room_ids, is_joined = await self._get_room_ids(user, explicit_room_id)
is_peeking = not is_joined
- @defer.inlineCallbacks
- def check_for_updates(before_token, after_token):
+ async def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
- defer.returnValue(EventStreamResult([], (from_token, from_token)))
+ return EventStreamResult([], (from_token, from_token))
events = []
end_token = from_token
@@ -408,7 +425,7 @@ class Notifier(object):
if only_keys and name not in only_keys:
continue
- new_events, new_key = yield source.get_new_events(
+ new_events, new_key = await source.get_new_events(
user=user,
from_key=getattr(from_token, keyname),
limit=limit,
@@ -418,8 +435,8 @@ class Notifier(object):
)
if name == "room":
- new_events = yield filter_events_for_client(
- self.store,
+ new_events = await filter_events_for_client(
+ self.storage,
user.to_string(),
new_events,
is_peeking=is_peeking,
@@ -437,7 +454,7 @@ class Notifier(object):
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
- defer.returnValue(EventStreamResult(events, (from_token, end_token)))
+ return EventStreamResult(events, (from_token, end_token))
user_id_for_stream = user.to_string()
if is_peeking:
@@ -450,10 +467,11 @@ class Notifier(object):
#
# I am sorry for what I have done.
user_id_for_stream = "_PEEKING_%s_%s" % (
- explicit_room_id, user_id_for_stream
+ explicit_room_id,
+ user_id_for_stream,
)
- result = yield self.wait_for_events(
+ result = await self.wait_for_events(
user_id_for_stream,
timeout,
check_for_updates,
@@ -461,30 +479,28 @@ class Notifier(object):
from_token=from_token,
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _get_room_ids(self, user, explicit_room_id):
joined_room_ids = yield self.store.get_rooms_for_user(user.to_string())
if explicit_room_id:
if explicit_room_id in joined_room_ids:
- defer.returnValue(([explicit_room_id], True))
+ return [explicit_room_id], True
if (yield self._is_world_readable(explicit_room_id)):
- defer.returnValue(([explicit_room_id], False))
+ return [explicit_room_id], False
raise AuthError(403, "Non-joined access not allowed")
- defer.returnValue((joined_room_ids, True))
+ return joined_room_ids, True
@defer.inlineCallbacks
def _is_world_readable(self, room_id):
state = yield self.state_handler.get_current_state(
- room_id,
- EventTypes.RoomHistoryVisibility,
- "",
+ room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
- defer.returnValue(state.content["history_visibility"] == "world_readable")
+ return state.content["history_visibility"] == "world_readable"
else:
- defer.returnValue(False)
+ return False
@log_function
def remove_expired_streams(self):
@@ -519,3 +535,15 @@ class Notifier(object):
"""Notify the any replication listeners that there's a new event"""
for cb in self.replication_callbacks:
cb()
+
+ def notify_remote_server_up(self, server: str):
+ """Notify any replication that a remote server has come back up
+ """
+ # We call federation_sender directly rather than registering as a
+ # callback as a) we already have a reference to it and b) it introduces
+ # circular dependencies.
+ if self.federation_sender:
+ self.federation_sender.wake_destination(server)
+
+ for cb in self.remote_server_up_callbacks:
+ cb(server)
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index a5de75c48a..1ffd5e2df3 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -40,6 +40,4 @@ class ActionGenerator(object):
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "action_for_event_by_user"):
- yield self.bulk_evaluator.action_for_event_by_user(
- event, context
- )
+ yield self.bulk_evaluator.action_for_event_by_user(event, context)
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 3523a40108..286374d0b5 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -1,5 +1,6 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -31,48 +32,54 @@ def list_with_base_rules(rawrules):
# Grab the base rules that the user has modified.
# The modified base rules have a priority_class of -1.
- modified_base_rules = {
- r['rule_id']: r for r in rawrules if r['priority_class'] < 0
- }
+ modified_base_rules = {r["rule_id"]: r for r in rawrules if r["priority_class"] < 0}
# Remove the modified base rules from the list, They'll be added back
# in the default postions in the list.
- rawrules = [r for r in rawrules if r['priority_class'] >= 0]
+ rawrules = [r for r in rawrules if r["priority_class"] >= 0]
# shove the server default rules for each kind onto the end of each
current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1]
- ruleslist.extend(make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
- ))
+ ruleslist.extend(
+ make_base_prepend_rules(
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ )
+ )
for r in rawrules:
- if r['priority_class'] < current_prio_class:
- while r['priority_class'] < current_prio_class:
- ruleslist.extend(make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- ))
- current_prio_class -= 1
- if current_prio_class > 0:
- ruleslist.extend(make_base_prepend_rules(
+ if r["priority_class"] < current_prio_class:
+ while r["priority_class"] < current_prio_class:
+ ruleslist.extend(
+ make_base_append_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
modified_base_rules,
- ))
+ )
+ )
+ current_prio_class -= 1
+ if current_prio_class > 0:
+ ruleslist.extend(
+ make_base_prepend_rules(
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+ modified_base_rules,
+ )
+ )
ruleslist.append(r)
while current_prio_class > 0:
- ruleslist.extend(make_base_append_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- ))
+ ruleslist.extend(
+ make_base_append_rules(
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ )
+ )
current_prio_class -= 1
if current_prio_class > 0:
- ruleslist.extend(make_base_prepend_rules(
- PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
- modified_base_rules,
- ))
+ ruleslist.extend(
+ make_base_prepend_rules(
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
+ )
+ )
return ruleslist
@@ -80,20 +87,20 @@ def list_with_base_rules(rawrules):
def make_base_append_rules(kind, modified_base_rules):
rules = []
- if kind == 'override':
+ if kind == "override":
rules = BASE_APPEND_OVERRIDE_RULES
- elif kind == 'underride':
+ elif kind == "underride":
rules = BASE_APPEND_UNDERRIDE_RULES
- elif kind == 'content':
+ elif kind == "content":
rules = BASE_APPEND_CONTENT_RULES
# Copy the rules before modifying them
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
- modified = modified_base_rules.get(r['rule_id'])
+ modified = modified_base_rules.get(r["rule_id"])
if modified:
- r['actions'] = modified['actions']
+ r["actions"] = modified["actions"]
return rules
@@ -101,103 +108,86 @@ def make_base_append_rules(kind, modified_base_rules):
def make_base_prepend_rules(kind, modified_base_rules):
rules = []
- if kind == 'override':
+ if kind == "override":
rules = BASE_PREPEND_OVERRIDE_RULES
# Copy the rules before modifying them
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
- modified = modified_base_rules.get(r['rule_id'])
+ modified = modified_base_rules.get(r["rule_id"])
if modified:
- r['actions'] = modified['actions']
+ r["actions"] = modified["actions"]
return rules
BASE_APPEND_CONTENT_RULES = [
{
- 'rule_id': 'global/content/.m.rule.contains_user_name',
- 'conditions': [
+ "rule_id": "global/content/.m.rule.contains_user_name",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'content.body',
- 'pattern_type': 'user_localpart'
+ "kind": "event_match",
+ "key": "content.body",
+ "pattern_type": "user_localpart",
}
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default',
- }, {
- 'set_tweak': 'highlight'
- }
- ]
- },
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight"},
+ ],
+ }
]
BASE_PREPEND_OVERRIDE_RULES = [
{
- 'rule_id': 'global/override/.m.rule.master',
- 'enabled': False,
- 'conditions': [],
- 'actions': [
- "dont_notify"
- ]
+ "rule_id": "global/override/.m.rule.master",
+ "enabled": False,
+ "conditions": [],
+ "actions": ["dont_notify"],
}
]
BASE_APPEND_OVERRIDE_RULES = [
{
- 'rule_id': 'global/override/.m.rule.suppress_notices',
- 'conditions': [
+ "rule_id": "global/override/.m.rule.suppress_notices",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'content.msgtype',
- 'pattern': 'm.notice',
- '_id': '_suppress_notices',
+ "kind": "event_match",
+ "key": "content.msgtype",
+ "pattern": "m.notice",
+ "_id": "_suppress_notices",
}
],
- 'actions': [
- 'dont_notify',
- ]
+ "actions": ["dont_notify"],
},
# NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event
# otherwise invites will be matched by .m.rule.member_event
{
- 'rule_id': 'global/override/.m.rule.invite_for_me',
- 'conditions': [
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.member',
- '_id': '_member',
- },
+ "rule_id": "global/override/.m.rule.invite_for_me",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'content.membership',
- 'pattern': 'invite',
- '_id': '_invite_member',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.member",
+ "_id": "_member",
},
{
- 'kind': 'event_match',
- 'key': 'state_key',
- 'pattern_type': 'user_id'
+ "kind": "event_match",
+ "key": "content.membership",
+ "pattern": "invite",
+ "_id": "_invite_member",
},
+ {"kind": "event_match", "key": "state_key", "pattern_type": "user_id"},
+ ],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
},
# Will we sometimes want to know about people joining and leaving?
# Perhaps: if so, this could be expanded upon. Seems the most usual case
@@ -206,217 +196,182 @@ BASE_APPEND_OVERRIDE_RULES = [
# join/leave/avatar/displayname events.
# See also: https://matrix.org/jira/browse/SYN-607
{
- 'rule_id': 'global/override/.m.rule.member_event',
- 'conditions': [
+ "rule_id": "global/override/.m.rule.member_event",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.member',
- '_id': '_member',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.member",
+ "_id": "_member",
}
],
- 'actions': [
- 'dont_notify'
- ]
+ "actions": ["dont_notify"],
},
# This was changed from underride to override so it's closer in priority
# to the content rules where the user name highlight rule lives. This
# way a room rule is lower priority than both but a custom override rule
# is higher priority than both.
{
- 'rule_id': 'global/override/.m.rule.contains_display_name',
- 'conditions': [
- {
- 'kind': 'contains_display_name'
- }
+ "rule_id": "global/override/.m.rule.contains_display_name",
+ "conditions": [{"kind": "contains_display_name"}],
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight"},
],
- 'actions': [
- 'notify',
+ },
+ {
+ "rule_id": "global/override/.m.rule.roomnotif",
+ "conditions": [
{
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight'
- }
- ]
+ "kind": "event_match",
+ "key": "content.body",
+ "pattern": "@room",
+ "_id": "_roomnotif_content",
+ },
+ {
+ "kind": "sender_notification_permission",
+ "key": "room",
+ "_id": "_roomnotif_pl",
+ },
+ ],
+ "actions": ["notify", {"set_tweak": "highlight", "value": True}],
},
{
- 'rule_id': 'global/override/.m.rule.roomnotif',
- 'conditions': [
+ "rule_id": "global/override/.m.rule.tombstone",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'content.body',
- 'pattern': '@room',
- '_id': '_roomnotif_content',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.tombstone",
+ "_id": "_tombstone",
},
{
- 'kind': 'sender_notification_permission',
- 'key': 'room',
- '_id': '_roomnotif_pl',
+ "kind": "event_match",
+ "key": "state_key",
+ "pattern": "",
+ "_id": "_tombstone_statekey",
},
],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': True,
- }
- ]
+ "actions": ["notify", {"set_tweak": "highlight", "value": True}],
},
{
- 'rule_id': 'global/override/.m.rule.tombstone',
- 'conditions': [
+ "rule_id": "global/override/.m.rule.reaction",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.tombstone',
- '_id': '_tombstone',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.reaction",
+ "_id": "_reaction",
}
],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': True,
- }
- ]
- }
+ "actions": ["dont_notify"],
+ },
]
BASE_APPEND_UNDERRIDE_RULES = [
{
- 'rule_id': 'global/underride/.m.rule.call',
- 'conditions': [
+ "rule_id": "global/underride/.m.rule.call",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.call.invite',
- '_id': '_call',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.call.invite",
+ "_id": "_call",
}
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'ring'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "ring"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
# XXX: once m.direct is standardised everywhere, we should use it to detect
# a DM from the user's perspective rather than this heuristic.
{
- 'rule_id': 'global/underride/.m.rule.room_one_to_one',
- 'conditions': [
+ "rule_id": "global/underride/.m.rule.room_one_to_one",
+ "conditions": [
+ {"kind": "room_member_count", "is": "2", "_id": "member_count"},
{
- 'kind': 'room_member_count',
- 'is': '2',
- '_id': 'member_count',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.message",
+ "_id": "_message",
},
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.message',
- '_id': '_message',
- }
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
{
- 'rule_id': 'global/underride/.m.rule.encrypted_room_one_to_one',
- 'conditions': [
+ "rule_id": "global/underride/.m.rule.encrypted_room_one_to_one",
+ "conditions": [
+ {"kind": "room_member_count", "is": "2", "_id": "member_count"},
{
- 'kind': 'room_member_count',
- 'is': '2',
- '_id': 'member_count',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.encrypted",
+ "_id": "_encrypted",
},
- {
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.encrypted',
- '_id': '_encrypted',
- }
],
- 'actions': [
- 'notify',
- {
- 'set_tweak': 'sound',
- 'value': 'default'
- }, {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
+ "actions": [
+ "notify",
+ {"set_tweak": "sound", "value": "default"},
+ {"set_tweak": "highlight", "value": False},
+ ],
},
{
- 'rule_id': 'global/underride/.m.rule.message',
- 'conditions': [
+ "rule_id": "global/underride/.m.rule.message",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.message',
- '_id': '_message',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.message",
+ "_id": "_message",
}
],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
+ "actions": ["notify", {"set_tweak": "highlight", "value": False}],
},
# XXX: this is going to fire for events which aren't m.room.messages
# but are encrypted (e.g. m.call.*)...
{
- 'rule_id': 'global/underride/.m.rule.encrypted',
- 'conditions': [
+ "rule_id": "global/underride/.m.rule.encrypted",
+ "conditions": [
{
- 'kind': 'event_match',
- 'key': 'type',
- 'pattern': 'm.room.encrypted',
- '_id': '_encrypted',
+ "kind": "event_match",
+ "key": "type",
+ "pattern": "m.room.encrypted",
+ "_id": "_encrypted",
}
],
- 'actions': [
- 'notify', {
- 'set_tweak': 'highlight',
- 'value': False
- }
- ]
- }
+ "actions": ["notify", {"set_tweak": "highlight", "value": False}],
+ },
]
BASE_RULE_IDS = set()
for r in BASE_APPEND_CONTENT_RULES:
- r['priority_class'] = PRIORITY_CLASS_MAP['content']
- r['default'] = True
- BASE_RULE_IDS.add(r['rule_id'])
+ r["priority_class"] = PRIORITY_CLASS_MAP["content"]
+ r["default"] = True
+ BASE_RULE_IDS.add(r["rule_id"])
for r in BASE_PREPEND_OVERRIDE_RULES:
- r['priority_class'] = PRIORITY_CLASS_MAP['override']
- r['default'] = True
- BASE_RULE_IDS.add(r['rule_id'])
+ r["priority_class"] = PRIORITY_CLASS_MAP["override"]
+ r["default"] = True
+ BASE_RULE_IDS.add(r["rule_id"])
for r in BASE_APPEND_OVERRIDE_RULES:
- r['priority_class'] = PRIORITY_CLASS_MAP['override']
- r['default'] = True
- BASE_RULE_IDS.add(r['rule_id'])
+ r["priority_class"] = PRIORITY_CLASS_MAP["override"]
+ r["default"] = True
+ BASE_RULE_IDS.add(r["rule_id"])
for r in BASE_APPEND_UNDERRIDE_RULES:
- r['priority_class'] = PRIORITY_CLASS_MAP['underride']
- r['default'] = True
- BASE_RULE_IDS.add(r['rule_id'])
+ r["priority_class"] = PRIORITY_CLASS_MAP["underride"]
+ r["default"] = True
+ BASE_RULE_IDS.add(r["rule_id"])
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 8f9a76147f..433ca2f416 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -39,9 +39,11 @@ rules_by_room = {}
push_rules_invalidation_counter = Counter(
- "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "")
+ "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", ""
+)
push_rules_state_size_counter = Counter(
- "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "")
+ "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", ""
+)
# Measures whether we use the fast path of using state deltas, or if we have to
# recalculate from scratch
@@ -77,13 +79,13 @@ class BulkPushRuleEvaluator(object):
dict of user_id -> push_rules
"""
room_id = event.room_id
- rules_for_room = self._get_rules_for_room(room_id)
+ rules_for_room = yield self._get_rules_for_room(room_id)
rules_by_user = yield rules_for_room.get_rules(event, context)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
- if event.type == 'm.room.member' and event.content['membership'] == 'invite':
+ if event.type == "m.room.member" and event.content["membership"] == "invite":
invited = event.state_key
if invited and self.hs.is_mine_id(invited):
has_pusher = yield self.store.user_has_pusher(invited)
@@ -93,7 +95,7 @@ class BulkPushRuleEvaluator(object):
invited
)
- defer.returnValue(rules_by_user)
+ return rules_by_user
@cached()
def _get_rules_for_room(self, room_id):
@@ -106,13 +108,15 @@ class BulkPushRuleEvaluator(object):
# before any lookup methods get called on it as otherwise there may be
# a race if invalidate_all gets called (which assumes its in the cache)
return RulesForRoom(
- self.hs, room_id, self._get_rules_for_room.cache,
+ self.hs,
+ room_id,
+ self._get_rules_for_room.cache,
self.room_push_rule_cache_metrics,
)
@defer.inlineCallbacks
def _get_power_levels_and_sender_level(self, event, context):
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
@@ -121,18 +125,16 @@ class BulkPushRuleEvaluator(object):
auth_events = {POWER_KEY: pl_event}
else:
auth_events_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids, for_verification=False,
+ event, prev_state_ids, for_verification=False
)
auth_events = yield self.store.get_events(auth_events_ids)
- auth_events = {
- (e.type, e.state_key): e for e in itervalues(auth_events)
- }
+ auth_events = {(e.type, e.state_key): e for e in itervalues(auth_events)}
sender_level = get_user_power_level(event.sender, auth_events)
pl_event = auth_events.get(POWER_KEY)
- defer.returnValue((pl_event.content if pl_event else {}, sender_level))
+ return pl_event.content if pl_event else {}, sender_level
@defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
@@ -145,16 +147,15 @@ class BulkPushRuleEvaluator(object):
rules_by_user = yield self._get_rules_for_event(event, context)
actions_by_user = {}
- room_members = yield self.store.get_joined_users_from_context(
- event, context
- )
+ room_members = yield self.store.get_joined_users_from_context(event, context)
- (power_levels, sender_power_level) = (
- yield self._get_power_levels_and_sender_level(event, context)
- )
+ (
+ power_levels,
+ sender_power_level,
+ ) = yield self._get_power_levels_and_sender_level(event, context)
evaluator = PushRuleEvaluatorForEvent(
- event, len(room_members), sender_power_level, power_levels,
+ event, len(room_members), sender_power_level, power_levels
)
condition_cache = {}
@@ -180,15 +181,15 @@ class BulkPushRuleEvaluator(object):
display_name = event.content.get("displayname", None)
for rule in rules:
- if 'enabled' in rule and not rule['enabled']:
+ if "enabled" in rule and not rule["enabled"]:
continue
matches = _condition_checker(
- evaluator, rule['conditions'], uid, display_name, condition_cache
+ evaluator, rule["conditions"], uid, display_name, condition_cache
)
if matches:
- actions = [x for x in rule['actions'] if x != 'dont_notify']
- if actions and 'notify' in actions:
+ actions = [x for x in rule["actions"] if x != "dont_notify"]
+ if actions and "notify" in actions:
# Push rules say we should notify the user of this event
actions_by_user[uid] = actions
break
@@ -196,9 +197,7 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
- yield self.store.add_push_actions_to_staging(
- event.event_id, actions_by_user,
- )
+ yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -285,13 +284,13 @@ class RulesForRoom(object):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
- defer.returnValue(self.rules_by_user)
+ return self.rules_by_user
with (yield self.linearizer.queue(())):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
- defer.returnValue(self.rules_by_user)
+ return self.rules_by_user
self.room_push_rule_cache_metrics.inc_misses()
@@ -305,7 +304,7 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits()
else:
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids))
@@ -361,19 +360,19 @@ class RulesForRoom(object):
self.sequence,
members={}, # There were no membership changes
rules_by_user=ret_rules_by_user,
- state_group=state_group
+ state_group=state_group,
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
- "Returning push rules for %r %r",
- self.room_id, ret_rules_by_user.keys(),
+ "Returning push rules for %r %r", self.room_id, ret_rules_by_user.keys()
)
- defer.returnValue(ret_rules_by_user)
+ return ret_rules_by_user
@defer.inlineCallbacks
- def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
- state_group, event):
+ def _update_rules_with_member_event_ids(
+ self, ret_rules_by_user, member_event_ids, state_group, event
+ ):
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
@@ -387,20 +386,9 @@ class RulesForRoom(object):
"""
sequence = self.sequence
- rows = yield self.store._simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=member_event_ids.values(),
- retcols=('user_id', 'membership', 'event_id'),
- keyvalues={},
- batch_size=500,
- desc="_get_rules_for_member_event_ids",
- )
+ rows = yield self.store.get_membership_from_event_ids(member_event_ids.values())
- members = {
- row["event_id"]: (row["user_id"], row["membership"])
- for row in rows
- }
+ members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
# If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it.
@@ -412,26 +400,26 @@ class RulesForRoom(object):
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
- interested_in_user_ids = set(
- user_id for user_id, membership in itervalues(members)
+ interested_in_user_ids = {
+ user_id
+ for user_id, membership in itervalues(members)
if membership == Membership.JOIN
- )
+ }
logger.debug("Joined: %r", interested_in_user_ids)
if_users_with_pushers = yield self.store.get_if_users_have_pushers(
- interested_in_user_ids,
- on_invalidate=self.invalidate_all_cb,
+ interested_in_user_ids, on_invalidate=self.invalidate_all_cb
)
- user_ids = set(
+ user_ids = {
uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
- )
+ }
logger.debug("With pushers: %r", user_ids)
users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
- self.room_id, on_invalidate=self.invalidate_all_cb,
+ self.room_id, on_invalidate=self.invalidate_all_cb
)
logger.debug("With receipts: %r", users_with_receipts)
@@ -442,7 +430,7 @@ class RulesForRoom(object):
user_ids.add(uid)
rules_by_user = yield self.store.bulk_get_push_rules(
- user_ids, on_invalidate=self.invalidate_all_cb,
+ user_ids, on_invalidate=self.invalidate_all_cb
)
ret_rules_by_user.update(
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index 8bd96b1178..a59b639f15 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -25,14 +25,14 @@ def format_push_rules_for_user(user, ruleslist):
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist)
- rules = {'global': {}, 'device': {}}
+ rules = {"global": {}, "device": {}}
- rules['global'] = _add_empty_priority_class_arrays(rules['global'])
+ rules["global"] = _add_empty_priority_class_arrays(rules["global"])
for r in ruleslist:
rulearray = None
- template_name = _priority_class_to_template_name(r['priority_class'])
+ template_name = _priority_class_to_template_name(r["priority_class"])
# Remove internal stuff.
for c in r["conditions"]:
@@ -44,14 +44,14 @@ def format_push_rules_for_user(user, ruleslist):
elif pattern_type == "user_localpart":
c["pattern"] = user.localpart
- rulearray = rules['global'][template_name]
+ rulearray = rules["global"][template_name]
template_rule = _rule_to_template(r)
if template_rule:
- if 'enabled' in r:
- template_rule['enabled'] = r['enabled']
+ if "enabled" in r:
+ template_rule["enabled"] = r["enabled"]
else:
- template_rule['enabled'] = True
+ template_rule["enabled"] = True
rulearray.append(template_rule)
return rules
@@ -65,33 +65,33 @@ def _add_empty_priority_class_arrays(d):
def _rule_to_template(rule):
unscoped_rule_id = None
- if 'rule_id' in rule:
- unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
+ if "rule_id" in rule:
+ unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
- template_name = _priority_class_to_template_name(rule['priority_class'])
- if template_name in ['override', 'underride']:
+ template_name = _priority_class_to_template_name(rule["priority_class"])
+ if template_name in ["override", "underride"]:
templaterule = {k: rule[k] for k in ["conditions", "actions"]}
elif template_name in ["sender", "room"]:
- templaterule = {'actions': rule['actions']}
- unscoped_rule_id = rule['conditions'][0]['pattern']
- elif template_name == 'content':
+ templaterule = {"actions": rule["actions"]}
+ unscoped_rule_id = rule["conditions"][0]["pattern"]
+ elif template_name == "content":
if len(rule["conditions"]) != 1:
return None
thecond = rule["conditions"][0]
if "pattern" not in thecond:
return None
- templaterule = {'actions': rule['actions']}
+ templaterule = {"actions": rule["actions"]}
templaterule["pattern"] = thecond["pattern"]
if unscoped_rule_id:
- templaterule['rule_id'] = unscoped_rule_id
- if 'default' in rule:
- templaterule['default'] = rule['default']
+ templaterule["rule_id"] = unscoped_rule_id
+ if "default" in rule:
+ templaterule["default"] = rule["default"]
return templaterule
def _rule_id_from_namespaced(in_rule_id):
- return in_rule_id.split('/')[-1]
+ return in_rule_id.split("/")[-1]
def _priority_class_to_template_name(pc):
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index e8ee67401f..ba4551d619 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -32,13 +32,13 @@ DELAY_BEFORE_MAIL_MS = 10 * 60 * 1000
THROTTLE_START_MS = 10 * 60 * 1000
THROTTLE_MAX_MS = 24 * 60 * 60 * 1000 # 24h
# THROTTLE_MULTIPLIER = 6 # 10 mins, 1 hour, 6 hours, 24 hours
-THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day
+THROTTLE_MULTIPLIER = 144 # 10 mins, 24 hours - i.e. jump straight to 1 day
# If no event triggers a notification for this long after the previous,
# the throttle is released.
# 12 hours - a gap of 12 hours in conversation is surely enough to merit a new
# notification when things get going again...
-THROTTLE_RESET_AFTER_MS = (12 * 60 * 60 * 1000)
+THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
# does each email include all unread notifs, or just the ones which have happened
# since the last mail?
@@ -53,17 +53,18 @@ class EmailPusher(object):
This shares quite a bit of code with httpusher: it would be good to
factor out the common parts
"""
+
def __init__(self, hs, pusherdict, mailer):
self.hs = hs
self.mailer = mailer
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
- self.pusher_id = pusherdict['id']
- self.user_id = pusherdict['user_name']
- self.app_id = pusherdict['app_id']
- self.email = pusherdict['pushkey']
- self.last_stream_ordering = pusherdict['last_stream_ordering']
+ self.pusher_id = pusherdict["id"]
+ self.user_id = pusherdict["user_name"]
+ self.app_id = pusherdict["app_id"]
+ self.email = pusherdict["pushkey"]
+ self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.timed_call = None
self.throttle_params = None
@@ -93,7 +94,9 @@ class EmailPusher(object):
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
if self.max_stream_ordering:
- self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+ self.max_stream_ordering = max(
+ max_stream_ordering, self.max_stream_ordering
+ )
else:
self.max_stream_ordering = max_stream_ordering
self._start_processing()
@@ -114,6 +117,21 @@ class EmailPusher(object):
run_as_background_process("emailpush.process", self._process)
+ def _pause_processing(self):
+ """Used by tests to temporarily pause processing of events.
+
+ Asserts that its not currently processing.
+ """
+ assert not self._is_processing
+ self._is_processing = True
+
+ def _resume_processing(self):
+ """Used by tests to resume processing of events after pausing.
+ """
+ assert self._is_processing
+ self._is_processing = False
+ self._start_processing()
+
@defer.inlineCallbacks
def _process(self):
# we should never get here if we are already processing
@@ -159,14 +177,12 @@ class EmailPusher(object):
return
for push_action in unprocessed:
- received_at = push_action['received_ts']
+ received_at = push_action["received_ts"]
if received_at is None:
received_at = 0
notif_ready_at = received_at + DELAY_BEFORE_MAIL_MS
- room_ready_at = self.room_ready_to_notify_at(
- push_action['room_id']
- )
+ room_ready_at = self.room_ready_to_notify_at(push_action["room_id"])
should_notify_at = max(notif_ready_at, room_ready_at)
@@ -177,25 +193,23 @@ class EmailPusher(object):
# to be delivered.
reason = {
- 'room_id': push_action['room_id'],
- 'now': self.clock.time_msec(),
- 'received_at': received_at,
- 'delay_before_mail_ms': DELAY_BEFORE_MAIL_MS,
- 'last_sent_ts': self.get_room_last_sent_ts(push_action['room_id']),
- 'throttle_ms': self.get_room_throttle_ms(push_action['room_id']),
+ "room_id": push_action["room_id"],
+ "now": self.clock.time_msec(),
+ "received_at": received_at,
+ "delay_before_mail_ms": DELAY_BEFORE_MAIL_MS,
+ "last_sent_ts": self.get_room_last_sent_ts(push_action["room_id"]),
+ "throttle_ms": self.get_room_throttle_ms(push_action["room_id"]),
}
yield self.send_notification(unprocessed, reason)
- yield self.save_last_stream_ordering_and_success(max([
- ea['stream_ordering'] for ea in unprocessed
- ]))
+ yield self.save_last_stream_ordering_and_success(
+ max(ea["stream_ordering"] for ea in unprocessed)
+ )
# we update the throttle on all the possible unprocessed push actions
for ea in unprocessed:
- yield self.sent_notif_update_throttle(
- ea['room_id'], ea
- )
+ yield self.sent_notif_update_throttle(ea["room_id"], ea)
break
else:
if soonest_due_at is None or should_notify_at < soonest_due_at:
@@ -215,11 +229,22 @@ class EmailPusher(object):
@defer.inlineCallbacks
def save_last_stream_ordering_and_success(self, last_stream_ordering):
+ if last_stream_ordering is None:
+ # This happens if we haven't yet processed anything
+ return
+
self.last_stream_ordering = last_stream_ordering
- yield self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id, self.email, self.user_id,
- last_stream_ordering, self.clock.time_msec()
+ pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
+ self.app_id,
+ self.email,
+ self.user_id,
+ last_stream_ordering,
+ self.clock.time_msec(),
)
+ if not pusher_still_exists:
+ # The pusher has been deleted while we were processing, so
+ # lets just stop and return.
+ self.on_stop()
def seconds_until(self, ts_msec):
secs = (ts_msec - self.clock.time_msec()) / 1000
@@ -257,10 +282,10 @@ class EmailPusher(object):
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
# notif, we release the throttle. Otherwise, the throttle is increased.
time_of_previous_notifs = yield self.store.get_time_of_last_push_action_before(
- notified_push_action['stream_ordering']
+ notified_push_action["stream_ordering"]
)
- time_of_this_notifs = notified_push_action['received_ts']
+ time_of_this_notifs = notified_push_action["received_ts"]
if time_of_previous_notifs is not None and time_of_this_notifs is not None:
gap = time_of_this_notifs - time_of_previous_notifs
@@ -279,12 +304,11 @@ class EmailPusher(object):
new_throttle_ms = THROTTLE_START_MS
else:
new_throttle_ms = min(
- current_throttle_ms * THROTTLE_MULTIPLIER,
- THROTTLE_MAX_MS
+ current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
)
self.throttle_params[room_id] = {
"last_sent_ts": self.clock.time_msec(),
- "throttle_ms": new_throttle_ms
+ "throttle_ms": new_throttle_ms,
}
yield self.store.set_throttle_params(
self.pusher_id, room_id, self.throttle_params[room_id]
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index a21b164266..5bb17d1228 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -22,6 +22,7 @@ from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
@@ -63,18 +64,19 @@ class HttpPusher(object):
def __init__(self, hs, pusherdict):
self.hs = hs
self.store = self.hs.get_datastore()
+ self.storage = self.hs.get_storage()
self.clock = self.hs.get_clock()
self.state_handler = self.hs.get_state_handler()
- self.user_id = pusherdict['user_name']
- self.app_id = pusherdict['app_id']
- self.app_display_name = pusherdict['app_display_name']
- self.device_display_name = pusherdict['device_display_name']
- self.pushkey = pusherdict['pushkey']
- self.pushkey_ts = pusherdict['ts']
- self.data = pusherdict['data']
- self.last_stream_ordering = pusherdict['last_stream_ordering']
+ self.user_id = pusherdict["user_name"]
+ self.app_id = pusherdict["app_id"]
+ self.app_display_name = pusherdict["app_display_name"]
+ self.device_display_name = pusherdict["device_display_name"]
+ self.pushkey = pusherdict["pushkey"]
+ self.pushkey_ts = pusherdict["ts"]
+ self.data = pusherdict["data"]
+ self.last_stream_ordering = pusherdict["last_stream_ordering"]
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.failing_since = pusherdict['failing_since']
+ self.failing_since = pusherdict["failing_since"]
self.timed_call = None
self._is_processing = False
@@ -85,32 +87,26 @@ class HttpPusher(object):
# off as None though as we don't know any better.
self.max_stream_ordering = None
- if 'data' not in pusherdict:
- raise PusherConfigException(
- "No 'data' key for HTTP pusher"
- )
- self.data = pusherdict['data']
+ if "data" not in pusherdict:
+ raise PusherConfigException("No 'data' key for HTTP pusher")
+ self.data = pusherdict["data"]
self.name = "%s/%s/%s" % (
- pusherdict['user_name'],
- pusherdict['app_id'],
- pusherdict['pushkey'],
+ pusherdict["user_name"],
+ pusherdict["app_id"],
+ pusherdict["pushkey"],
)
if self.data is None:
- raise PusherConfigException(
- "data can not be null for HTTP pusher"
- )
+ raise PusherConfigException("data can not be null for HTTP pusher")
- if 'url' not in self.data:
- raise PusherConfigException(
- "'url' required in data for HTTP pusher"
- )
- self.url = self.data['url']
+ if "url" not in self.data:
+ raise PusherConfigException("'url' required in data for HTTP pusher")
+ self.url = self.data["url"]
self.http_client = hs.get_proxied_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
- del self.data_minus_url['url']
+ del self.data_minus_url["url"]
def on_started(self, should_check_for_notifs):
"""Called when this pusher has been started.
@@ -124,7 +120,9 @@ class HttpPusher(object):
self._start_processing()
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
- self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0)
+ self.max_stream_ordering = max(
+ max_stream_ordering, self.max_stream_ordering or 0
+ )
self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id):
@@ -192,186 +190,207 @@ class HttpPusher(object):
logger.info(
"Processing %i unprocessed push actions for %s starting at "
"stream_ordering %s",
- len(unprocessed), self.name, self.last_stream_ordering,
+ len(unprocessed),
+ self.name,
+ self.last_stream_ordering,
)
for push_action in unprocessed:
- processed = yield self._process_one(push_action)
+ with opentracing.start_active_span(
+ "http-push",
+ tags={
+ "authenticated_entity": self.user_id,
+ "event_id": push_action["event_id"],
+ "app_id": self.app_id,
+ "app_display_name": self.app_display_name,
+ },
+ ):
+ processed = yield self._process_one(push_action)
+
if processed:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.last_stream_ordering = push_action['stream_ordering']
- yield self.store.update_pusher_last_stream_ordering_and_success(
- self.app_id, self.pushkey, self.user_id,
+ self.last_stream_ordering = push_action["stream_ordering"]
+ pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
+ self.app_id,
+ self.pushkey,
+ self.user_id,
self.last_stream_ordering,
- self.clock.time_msec()
+ self.clock.time_msec(),
)
+ if not pusher_still_exists:
+ # The pusher has been deleted while we were processing, so
+ # lets just stop and return.
+ self.on_stop()
+ return
+
if self.failing_since:
self.failing_since = None
yield self.store.update_pusher_failing_since(
- self.app_id, self.pushkey, self.user_id,
- self.failing_since
+ self.app_id, self.pushkey, self.user_id, self.failing_since
)
else:
http_push_failed_counter.inc()
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
- self.app_id, self.pushkey, self.user_id,
- self.failing_since
+ self.app_id, self.pushkey, self.user_id, self.failing_since
)
if (
- self.failing_since and
- self.failing_since <
- self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS
+ self.failing_since
+ and self.failing_since
+ < self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS
):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
- logger.warn("Giving up on a notification to user %s, "
- "pushkey %s",
- self.user_id, self.pushkey)
+ logger.warning(
+ "Giving up on a notification to user %s, pushkey %s",
+ self.user_id,
+ self.pushkey,
+ )
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
- self.last_stream_ordering = push_action['stream_ordering']
- yield self.store.update_pusher_last_stream_ordering(
+ self.last_stream_ordering = push_action["stream_ordering"]
+ pusher_still_exists = yield self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
self.user_id,
- self.last_stream_ordering
+ self.last_stream_ordering,
)
+ if not pusher_still_exists:
+ # The pusher has been deleted while we were processing, so
+ # lets just stop and return.
+ self.on_stop()
+ return
self.failing_since = None
yield self.store.update_pusher_failing_since(
- self.app_id,
- self.pushkey,
- self.user_id,
- self.failing_since
+ self.app_id, self.pushkey, self.user_id, self.failing_since
)
else:
logger.info("Push failed: delaying for %ds", self.backoff_delay)
self.timed_call = self.hs.get_reactor().callLater(
self.backoff_delay, self.on_timer
)
- self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC)
+ self.backoff_delay = min(
+ self.backoff_delay * 2, self.MAX_BACKOFF_SEC
+ )
break
@defer.inlineCallbacks
def _process_one(self, push_action):
- if 'notify' not in push_action['actions']:
- defer.returnValue(True)
+ if "notify" not in push_action["actions"]:
+ return True
- tweaks = push_rule_evaluator.tweaks_for_actions(push_action['actions'])
+ tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
- event = yield self.store.get_event(push_action['event_id'], allow_none=True)
+ event = yield self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:
- defer.returnValue(True) # It's been redacted
+ return True # It's been redacted
rejected = yield self.dispatch_push(event, tweaks, badge)
if rejected is False:
- defer.returnValue(False)
+ return False
if isinstance(rejected, list) or isinstance(rejected, tuple):
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
- logger.warn(
- ("Ignoring rejected pushkey %s because we"
- " didn't send it"), pk
+ logger.warning(
+ ("Ignoring rejected pushkey %s because we didn't send it"), pk,
)
else:
- logger.info(
- "Pushkey %s was rejected: removing",
- pk
- )
- yield self.hs.remove_pusher(
- self.app_id, pk, self.user_id
- )
- defer.returnValue(True)
+ logger.info("Pushkey %s was rejected: removing", pk)
+ yield self.hs.remove_pusher(self.app_id, pk, self.user_id)
+ return True
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
- if self.data.get('format') == 'event_id_only':
+ if self.data.get("format") == "event_id_only":
d = {
- 'notification': {
- 'event_id': event.event_id,
- 'room_id': event.room_id,
- 'counts': {
- 'unread': badge,
- },
- 'devices': [
+ "notification": {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "counts": {"unread": badge},
+ "devices": [
{
- 'app_id': self.app_id,
- 'pushkey': self.pushkey,
- 'pushkey_ts': long(self.pushkey_ts / 1000),
- 'data': self.data_minus_url,
+ "app_id": self.app_id,
+ "pushkey": self.pushkey,
+ "pushkey_ts": long(self.pushkey_ts / 1000),
+ "data": self.data_minus_url,
}
- ]
+ ],
}
}
- defer.returnValue(d)
+ return d
ctx = yield push_tools.get_context_for_event(
- self.store, self.state_handler, event, self.user_id
+ self.storage, self.state_handler, event, self.user_id
)
d = {
- 'notification': {
- 'id': event.event_id, # deprecated: remove soon
- 'event_id': event.event_id,
- 'room_id': event.room_id,
- 'type': event.type,
- 'sender': event.user_id,
- 'counts': { # -- we don't mark messages as read yet so
- # we have no way of knowing
+ "notification": {
+ "id": event.event_id, # deprecated: remove soon
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "type": event.type,
+ "sender": event.user_id,
+ "counts": { # -- we don't mark messages as read yet so
+ # we have no way of knowing
# Just set the badge to 1 until we have read receipts
- 'unread': badge,
+ "unread": badge,
# 'missed_calls': 2
},
- 'devices': [
+ "devices": [
{
- 'app_id': self.app_id,
- 'pushkey': self.pushkey,
- 'pushkey_ts': long(self.pushkey_ts / 1000),
- 'data': self.data_minus_url,
- 'tweaks': tweaks
+ "app_id": self.app_id,
+ "pushkey": self.pushkey,
+ "pushkey_ts": long(self.pushkey_ts / 1000),
+ "data": self.data_minus_url,
+ "tweaks": tweaks,
}
- ]
+ ],
}
}
- if event.type == 'm.room.member' and event.is_state():
- d['notification']['membership'] = event.content['membership']
- d['notification']['user_is_target'] = event.state_key == self.user_id
+ if event.type == "m.room.member" and event.is_state():
+ d["notification"]["membership"] = event.content["membership"]
+ d["notification"]["user_is_target"] = event.state_key == self.user_id
if self.hs.config.push_include_content and event.content:
- d['notification']['content'] = event.content
+ d["notification"]["content"] = event.content
# We no longer send aliases separately, instead, we send the human
# readable name of the room, which may be an alias.
- if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0:
- d['notification']['sender_display_name'] = ctx['sender_display_name']
- if 'name' in ctx and len(ctx['name']) > 0:
- d['notification']['room_name'] = ctx['name']
+ if "sender_display_name" in ctx and len(ctx["sender_display_name"]) > 0:
+ d["notification"]["sender_display_name"] = ctx["sender_display_name"]
+ if "name" in ctx and len(ctx["name"]) > 0:
+ d["notification"]["room_name"] = ctx["name"]
- defer.returnValue(d)
+ return d
@defer.inlineCallbacks
def dispatch_push(self, event, tweaks, badge):
notification_dict = yield self._build_notification_dict(event, tweaks, badge)
if not notification_dict:
- defer.returnValue([])
+ return []
try:
- resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
+ resp = yield self.http_client.post_json_get_json(
+ self.url, notification_dict
+ )
except Exception as e:
logger.warning(
"Failed to push event %s to %s: %s %s",
- event.event_id, self.name, type(e), e,
+ event.event_id,
+ self.name,
+ type(e),
+ e,
)
- defer.returnValue(False)
+ return False
rejected = []
- if 'rejected' in resp:
- rejected = resp['rejected']
- defer.returnValue(rejected)
+ if "rejected" in resp:
+ rejected = resp["rejected"]
+ return rejected
@defer.inlineCallbacks
def _send_badge(self, badge):
@@ -379,23 +398,21 @@ class HttpPusher(object):
Args:
badge (int): number of unread messages
"""
- logger.info("Sending updated badge count %d to %s", badge, self.name)
+ logger.debug("Sending updated badge count %d to %s", badge, self.name)
d = {
- 'notification': {
- 'id': '',
- 'type': None,
- 'sender': '',
- 'counts': {
- 'unread': badge
- },
- 'devices': [
+ "notification": {
+ "id": "",
+ "type": None,
+ "sender": "",
+ "counts": {"unread": badge},
+ "devices": [
{
- 'app_id': self.app_id,
- 'pushkey': self.pushkey,
- 'pushkey_ts': long(self.pushkey_ts / 1000),
- 'data': self.data_minus_url,
+ "app_id": self.app_id,
+ "pushkey": self.pushkey,
+ "pushkey_ts": long(self.pushkey_ts / 1000),
+ "data": self.data_minus_url,
}
- ]
+ ],
}
}
try:
@@ -403,7 +420,6 @@ class HttpPusher(object):
http_badges_processed_counter.inc()
except Exception as e:
logger.warning(
- "Failed to send badge count to %s: %s %s",
- self.name, type(e), e,
+ "Failed to send badge count to %s: %s %s", self.name, type(e), e
)
http_badges_failed_counter.inc()
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 099f9545ab..73580c1c6c 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -29,6 +29,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
+from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import (
calculate_room_name,
descriptor_from_member_events,
@@ -36,23 +37,26 @@ from synapse.push.presentable_names import (
)
from synapse.types import UserID
from synapse.util.async_helpers import concurrently_execute
-from synapse.util.logcontext import make_deferred_yieldable
from synapse.visibility import filter_events_for_client
logger = logging.getLogger(__name__)
-MESSAGE_FROM_PERSON_IN_ROOM = "You have a message on %(app)s from %(person)s " \
- "in the %(room)s room..."
+MESSAGE_FROM_PERSON_IN_ROOM = (
+ "You have a message on %(app)s from %(person)s in the %(room)s room..."
+)
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
MESSAGES_IN_ROOM = "You have messages on %(app)s in the %(room)s room..."
-MESSAGES_IN_ROOM_AND_OTHERS = \
+MESSAGES_IN_ROOM_AND_OTHERS = (
"You have messages on %(app)s in the %(room)s room and others..."
-MESSAGES_FROM_PERSON_AND_OTHERS = \
+)
+MESSAGES_FROM_PERSON_AND_OTHERS = (
"You have messages on %(app)s from %(person)s and others..."
-INVITE_FROM_PERSON_TO_ROOM = "%(person)s has invited you to join the " \
- "%(room)s room on %(app)s..."
+)
+INVITE_FROM_PERSON_TO_ROOM = (
+ "%(person)s has invited you to join the %(room)s room on %(app)s..."
+)
INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..."
CONTEXT_BEFORE = 1
@@ -60,12 +64,38 @@ CONTEXT_AFTER = 1
# From https://github.com/matrix-org/matrix-react-sdk/blob/master/src/HtmlUtils.js
ALLOWED_TAGS = [
- 'font', # custom to matrix for IRC-style font coloring
- 'del', # for markdown
+ "font", # custom to matrix for IRC-style font coloring
+ "del", # for markdown
# deliberately no h1/h2 to stop people shouting.
- 'h3', 'h4', 'h5', 'h6', 'blockquote', 'p', 'a', 'ul', 'ol',
- 'nl', 'li', 'b', 'i', 'u', 'strong', 'em', 'strike', 'code', 'hr', 'br', 'div',
- 'table', 'thead', 'caption', 'tbody', 'tr', 'th', 'td', 'pre'
+ "h3",
+ "h4",
+ "h5",
+ "h6",
+ "blockquote",
+ "p",
+ "a",
+ "ul",
+ "ol",
+ "nl",
+ "li",
+ "b",
+ "i",
+ "u",
+ "strong",
+ "em",
+ "strike",
+ "code",
+ "hr",
+ "br",
+ "div",
+ "table",
+ "thead",
+ "caption",
+ "tbody",
+ "tr",
+ "th",
+ "td",
+ "pre",
]
ALLOWED_ATTRS = {
# custom ones first:
@@ -89,59 +119,105 @@ class Mailer(object):
self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
+ self.storage = hs.get_storage()
self.app_name = app_name
logger.info("Created Mailer for app_name %s" % app_name)
@defer.inlineCallbacks
- def send_password_reset_mail(
- self,
- email_address,
- token,
- client_secret,
- sid,
- ):
+ def send_password_reset_mail(self, email_address, token, client_secret, sid):
"""Send an email with a password reset link to a user
Args:
email_address (str): Email address we're sending the password
reset to
token (str): Unique token generated by the server to verify
- password reset email was received
+ the email was received
client_secret (str): Unique token generated by the client to
group together multiple email sending attempts
sid (str): The generated session ID
"""
- if email.utils.parseaddr(email_address)[1] == '':
- raise RuntimeError("Invalid 'to' email address")
+ params = {"token": token, "client_secret": client_secret, "sid": sid}
+ link = (
+ self.hs.config.public_baseurl
+ + "_matrix/client/unstable/password_reset/email/submit_token?%s"
+ % urllib.parse.urlencode(params)
+ )
+ template_vars = {"link": link}
+
+ yield self.send_email(
+ email_address,
+ "[%s] Password Reset" % self.hs.config.server_name,
+ template_vars,
+ )
+
+ @defer.inlineCallbacks
+ def send_registration_mail(self, email_address, token, client_secret, sid):
+ """Send an email with a registration confirmation link to a user
+
+ Args:
+ email_address (str): Email address we're sending the registration
+ link to
+ token (str): Unique token generated by the server to verify
+ the email was received
+ client_secret (str): Unique token generated by the client to
+ group together multiple email sending attempts
+ sid (str): The generated session ID
+ """
+ params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
- self.hs.config.public_baseurl +
- "_matrix/client/unstable/password_reset/email/submit_token"
- "?token=%s&client_secret=%s&sid=%s" %
- (token, client_secret, sid)
+ self.hs.config.public_baseurl
+ + "_matrix/client/unstable/registration/email/submit_token?%s"
+ % urllib.parse.urlencode(params)
)
- template_vars = {
- "link": link,
- }
+ template_vars = {"link": link}
yield self.send_email(
email_address,
- "[%s] Password Reset Email" % self.hs.config.server_name,
+ "[%s] Register your Email Address" % self.hs.config.server_name,
template_vars,
)
@defer.inlineCallbacks
- def send_notification_mail(self, app_id, user_id, email_address,
- push_actions, reason):
- """Send email regarding a user's room notifications"""
- rooms_in_order = deduped_ordered_list(
- [pa['room_id'] for pa in push_actions]
+ def send_add_threepid_mail(self, email_address, token, client_secret, sid):
+ """Send an email with a validation link to a user for adding a 3pid to their account
+
+ Args:
+ email_address (str): Email address we're sending the validation link to
+
+ token (str): Unique token generated by the server to verify the email was received
+
+ client_secret (str): Unique token generated by the client to group together
+ multiple email sending attempts
+
+ sid (str): The generated session ID
+ """
+ params = {"token": token, "client_secret": client_secret, "sid": sid}
+ link = (
+ self.hs.config.public_baseurl
+ + "_matrix/client/unstable/add_threepid/email/submit_token?%s"
+ % urllib.parse.urlencode(params)
+ )
+
+ template_vars = {"link": link}
+
+ yield self.send_email(
+ email_address,
+ "[%s] Validate Your Email" % self.hs.config.server_name,
+ template_vars,
)
+ @defer.inlineCallbacks
+ def send_notification_mail(
+ self, app_id, user_id, email_address, push_actions, reason
+ ):
+ """Send email regarding a user's room notifications"""
+ rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
+
notif_events = yield self.store.get_events(
- [pa['event_id'] for pa in push_actions]
+ [pa["event_id"] for pa in push_actions]
)
notifs_by_room = {}
@@ -171,9 +247,7 @@ class Mailer(object):
yield concurrently_execute(_fetch_room_state, rooms_in_order, 3)
# actually sort our so-called rooms_in_order list, most recent room first
- rooms_in_order.sort(
- key=lambda r: -(notifs_by_room[r][-1]['received_ts'] or 0)
- )
+ rooms_in_order.sort(key=lambda r: -(notifs_by_room[r][-1]["received_ts"] or 0))
rooms = []
@@ -183,9 +257,11 @@ class Mailer(object):
)
rooms.append(roomvars)
- reason['room_name'] = yield calculate_room_name(
- self.store, state_by_room[reason['room_id']], user_id,
- fallback_to_members=True
+ reason["room_name"] = yield calculate_room_name(
+ self.store,
+ state_by_room[reason["room_id"]],
+ user_id,
+ fallback_to_members=True,
)
summary_text = yield self.make_summary_text(
@@ -204,25 +280,21 @@ class Mailer(object):
}
yield self.send_email(
- email_address,
- "[%s] %s" % (self.app_name, summary_text),
- template_vars,
+ email_address, "[%s] %s" % (self.app_name, summary_text), template_vars
)
@defer.inlineCallbacks
def send_email(self, email_address, subject, template_vars):
"""Send an email with the given information and template text"""
try:
- from_string = self.hs.config.email_notif_from % {
- "app": self.app_name
- }
+ from_string = self.hs.config.email_notif_from % {"app": self.app_name}
except TypeError:
from_string = self.hs.config.email_notif_from
raw_from = email.utils.parseaddr(from_string)[1]
raw_to = email.utils.parseaddr(email_address)[1]
- if raw_to == '':
+ if raw_to == "":
raise RuntimeError("Invalid 'to' address")
html_text = self.template_html.render(**template_vars)
@@ -231,27 +303,31 @@ class Mailer(object):
plain_text = self.template_text.render(**template_vars)
text_part = MIMEText(plain_text, "plain", "utf8")
- multipart_msg = MIMEMultipart('alternative')
- multipart_msg['Subject'] = subject
- multipart_msg['From'] = from_string
- multipart_msg['To'] = email_address
- multipart_msg['Date'] = email.utils.formatdate()
- multipart_msg['Message-ID'] = email.utils.make_msgid()
+ multipart_msg = MIMEMultipart("alternative")
+ multipart_msg["Subject"] = subject
+ multipart_msg["From"] = from_string
+ multipart_msg["To"] = email_address
+ multipart_msg["Date"] = email.utils.formatdate()
+ multipart_msg["Message-ID"] = email.utils.make_msgid()
multipart_msg.attach(text_part)
multipart_msg.attach(html_part)
- logger.info("Sending email push notification to %s" % email_address)
-
- yield make_deferred_yieldable(self.sendmail(
- self.hs.config.email_smtp_host,
- raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
- reactor=self.hs.get_reactor(),
- port=self.hs.config.email_smtp_port,
- requireAuthentication=self.hs.config.email_smtp_user is not None,
- username=self.hs.config.email_smtp_user,
- password=self.hs.config.email_smtp_pass,
- requireTransportSecurity=self.hs.config.require_transport_security
- ))
+ logger.info("Sending email to %s" % email_address)
+
+ yield make_deferred_yieldable(
+ self.sendmail(
+ self.hs.config.email_smtp_host,
+ raw_from,
+ raw_to,
+ multipart_msg.as_string().encode("utf8"),
+ reactor=self.hs.get_reactor(),
+ port=self.hs.config.email_smtp_port,
+ requireAuthentication=self.hs.config.email_smtp_user is not None,
+ username=self.hs.config.email_smtp_user,
+ password=self.hs.config.email_smtp_pass,
+ requireTransportSecurity=self.hs.config.require_transport_security,
+ )
+ )
@defer.inlineCallbacks
def get_room_vars(self, room_id, user_id, notifs, notif_events, room_state_ids):
@@ -272,17 +348,18 @@ class Mailer(object):
if not is_invite:
for n in notifs:
notifvars = yield self.get_notif_vars(
- n, user_id, notif_events[n['event_id']], room_state_ids
+ n, user_id, notif_events[n["event_id"]], room_state_ids
)
# merge overlapping notifs together.
# relies on the notifs being in chronological order.
merge = False
- if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]:
- prev_messages = room_vars['notifs'][-1]['messages']
- for message in notifvars['messages']:
- pm = list(filter(lambda pm: pm['id'] == message['id'],
- prev_messages))
+ if room_vars["notifs"] and "messages" in room_vars["notifs"][-1]:
+ prev_messages = room_vars["notifs"][-1]["messages"]
+ for message in notifvars["messages"]:
+ pm = list(
+ filter(lambda pm: pm["id"] == message["id"], prev_messages)
+ )
if pm:
if not message["is_historical"]:
pm[0]["is_historical"] = False
@@ -293,34 +370,36 @@ class Mailer(object):
prev_messages.append(message)
if not merge:
- room_vars['notifs'].append(notifvars)
+ room_vars["notifs"].append(notifvars)
- defer.returnValue(room_vars)
+ return room_vars
@defer.inlineCallbacks
def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
results = yield self.store.get_events_around(
- notif['room_id'], notif['event_id'],
- before_limit=CONTEXT_BEFORE, after_limit=CONTEXT_AFTER
+ notif["room_id"],
+ notif["event_id"],
+ before_limit=CONTEXT_BEFORE,
+ after_limit=CONTEXT_AFTER,
)
ret = {
"link": self.make_notif_link(notif),
- "ts": notif['received_ts'],
+ "ts": notif["received_ts"],
"messages": [],
}
the_events = yield filter_events_for_client(
- self.store, user_id, results["events_before"]
+ self.storage, user_id, results["events_before"]
)
the_events.append(notif_event)
for event in the_events:
messagevars = yield self.get_message_vars(notif, event, room_state_ids)
if messagevars is not None:
- ret['messages'].append(messagevars)
+ ret["messages"].append(messagevars)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_message_vars(self, notif, event, room_state_ids):
@@ -340,7 +419,7 @@ class Mailer(object):
ret = {
"msgtype": msgtype,
- "is_historical": event.event_id != notif['event_id'],
+ "is_historical": event.event_id != notif["event_id"],
"id": event.event_id,
"ts": event.origin_server_ts,
"sender_name": sender_name,
@@ -356,7 +435,7 @@ class Mailer(object):
if "body" in event.content:
ret["body_text_plain"] = event.content["body"]
- defer.returnValue(ret)
+ return ret
def add_text_message_vars(self, messagevars, event):
msgformat = event.content.get("format")
@@ -379,8 +458,9 @@ class Mailer(object):
return messagevars
@defer.inlineCallbacks
- def make_summary_text(self, notifs_by_room, room_state_ids,
- notif_events, user_id, reason):
+ def make_summary_text(
+ self, notifs_by_room, room_state_ids, notif_events, user_id, reason
+ ):
if len(notifs_by_room) == 1:
# Only one room has new stuff
room_id = list(notifs_by_room.keys())[0]
@@ -404,16 +484,16 @@ class Mailer(object):
inviter_name = name_from_member_event(inviter_member_event)
if room_name is None:
- defer.returnValue(INVITE_FROM_PERSON % {
+ return INVITE_FROM_PERSON % {
"person": inviter_name,
- "app": self.app_name
- })
+ "app": self.app_name,
+ }
else:
- defer.returnValue(INVITE_FROM_PERSON_TO_ROOM % {
+ return INVITE_FROM_PERSON_TO_ROOM % {
"person": inviter_name,
"room": room_name,
"app": self.app_name,
- })
+ }
sender_name = None
if len(notifs_by_room[room_id]) == 1:
@@ -427,67 +507,71 @@ class Mailer(object):
sender_name = name_from_member_event(state_event)
if sender_name is not None and room_name is not None:
- defer.returnValue(MESSAGE_FROM_PERSON_IN_ROOM % {
+ return MESSAGE_FROM_PERSON_IN_ROOM % {
"person": sender_name,
"room": room_name,
"app": self.app_name,
- })
+ }
elif sender_name is not None:
- defer.returnValue(MESSAGE_FROM_PERSON % {
+ return MESSAGE_FROM_PERSON % {
"person": sender_name,
"app": self.app_name,
- })
+ }
else:
# There's more than one notification for this room, so just
# say there are several
if room_name is not None:
- defer.returnValue(MESSAGES_IN_ROOM % {
- "room": room_name,
- "app": self.app_name,
- })
+ return MESSAGES_IN_ROOM % {"room": room_name, "app": self.app_name}
else:
# If the room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
- sender_ids = list(set([
- notif_events[n['event_id']].sender
- for n in notifs_by_room[room_id]
- ]))
-
- member_events = yield self.store.get_events([
- room_state_ids[room_id][("m.room.member", s)]
- for s in sender_ids
- ])
-
- defer.returnValue(MESSAGES_FROM_PERSON % {
+ sender_ids = list(
+ {
+ notif_events[n["event_id"]].sender
+ for n in notifs_by_room[room_id]
+ }
+ )
+
+ member_events = yield self.store.get_events(
+ [
+ room_state_ids[room_id][("m.room.member", s)]
+ for s in sender_ids
+ ]
+ )
+
+ return MESSAGES_FROM_PERSON % {
"person": descriptor_from_member_events(member_events.values()),
"app": self.app_name,
- })
+ }
else:
# Stuff's happened in multiple different rooms
# ...but we still refer to the 'reason' room which triggered the mail
- if reason['room_name'] is not None:
- defer.returnValue(MESSAGES_IN_ROOM_AND_OTHERS % {
- "room": reason['room_name'],
+ if reason["room_name"] is not None:
+ return MESSAGES_IN_ROOM_AND_OTHERS % {
+ "room": reason["room_name"],
"app": self.app_name,
- })
+ }
else:
# If the reason room doesn't have a name, say who the messages
# are from explicitly to avoid, "messages in the Bob room"
- sender_ids = list(set([
- notif_events[n['event_id']].sender
- for n in notifs_by_room[reason['room_id']]
- ]))
+ room_id = reason["room_id"]
+
+ sender_ids = list(
+ {
+ notif_events[n["event_id"]].sender
+ for n in notifs_by_room[room_id]
+ }
+ )
- member_events = yield self.store.get_events([
- room_state_ids[room_id][("m.room.member", s)]
- for s in sender_ids
- ])
+ member_events = yield self.store.get_events(
+ [room_state_ids[room_id][("m.room.member", s)] for s in sender_ids]
+ )
- defer.returnValue(MESSAGES_FROM_PERSON_AND_OTHERS % {
+ return MESSAGES_FROM_PERSON_AND_OTHERS % {
"person": descriptor_from_member_events(member_events.values()),
"app": self.app_name,
- })
+ }
def make_room_link(self, room_id):
if self.hs.config.email_riot_base_url:
@@ -503,17 +587,17 @@ class Mailer(object):
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
- notif['room_id'], notif['event_id']
+ notif["room_id"],
+ notif["event_id"],
)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
return "https://vector.im/beta/#/room/%s/%s" % (
- notif['room_id'], notif['event_id']
+ notif["room_id"],
+ notif["event_id"],
)
else:
- return "https://matrix.to/#/%s/%s" % (
- notif['room_id'], notif['event_id']
- )
+ return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
def make_unsubscribe_link(self, user_id, app_id, email_address):
params = {
@@ -530,12 +614,18 @@ class Mailer(object):
def safe_markup(raw_html):
- return jinja2.Markup(bleach.linkify(bleach.clean(
- raw_html, tags=ALLOWED_TAGS, attributes=ALLOWED_ATTRS,
- # bleach master has this, but it isn't released yet
- # protocols=ALLOWED_SCHEMES,
- strip=True
- )))
+ return jinja2.Markup(
+ bleach.linkify(
+ bleach.clean(
+ raw_html,
+ tags=ALLOWED_TAGS,
+ attributes=ALLOWED_ATTRS,
+ # bleach master has this, but it isn't released yet
+ # protocols=ALLOWED_SCHEMES,
+ strip=True,
+ )
+ )
+ )
def safe_text(raw_text):
@@ -543,10 +633,9 @@ def safe_text(raw_text):
Process text: treat it as HTML but escape any tags (ie. just escape the
HTML) then linkify it.
"""
- return jinja2.Markup(bleach.linkify(bleach.clean(
- raw_text, tags=[], attributes={},
- strip=False
- )))
+ return jinja2.Markup(
+ bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False))
+ )
def deduped_ordered_list(l):
@@ -570,42 +659,63 @@ def format_ts_filter(value, format):
return time.strftime(format, time.localtime(value / 1000))
-def load_jinja2_templates(config, template_html_name, template_text_name):
- """Load the jinja2 email templates from disk
+def load_jinja2_templates(
+ template_dir,
+ template_filenames,
+ apply_format_ts_filter=False,
+ apply_mxc_to_http_filter=False,
+ public_baseurl=None,
+):
+ """Loads and returns one or more jinja2 templates and applies optional filters
+
+ Args:
+ template_dir (str): The directory where templates are stored
+ template_filenames (list[str]): A list of template filenames
+ apply_format_ts_filter (bool): Whether to apply a template filter that formats
+ timestamps
+ apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts
+ mxc urls to http urls
+ public_baseurl (str|None): The public baseurl of the server. Required for
+ apply_mxc_to_http_filter to be enabled
Returns:
- (template_html, template_text)
+ A list of jinja2 templates corresponding to the given list of filenames,
+ with order preserved
"""
- logger.info("loading email templates from '%s'", config.email_template_dir)
- loader = jinja2.FileSystemLoader(config.email_template_dir)
+ logger.info(
+ "loading email templates %s from '%s'", template_filenames, template_dir
+ )
+ loader = jinja2.FileSystemLoader(template_dir)
env = jinja2.Environment(loader=loader)
- env.filters["format_ts"] = format_ts_filter
- env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
- template_html = env.get_template(template_html_name)
- template_text = env.get_template(template_text_name)
+ if apply_format_ts_filter:
+ env.filters["format_ts"] = format_ts_filter
- return template_html, template_text
+ if apply_mxc_to_http_filter and public_baseurl:
+ env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl)
+ templates = []
+ for template_filename in template_filenames:
+ template = env.get_template(template_filename)
+ templates.append(template)
-def _create_mxc_to_http_filter(config):
+ return templates
+
+
+def _create_mxc_to_http_filter(public_baseurl):
def mxc_to_http_filter(value, width, height, resize_method="crop"):
if value[0:6] != "mxc://":
return ""
serverAndMediaId = value[6:]
fragment = None
- if '#' in serverAndMediaId:
- (serverAndMediaId, fragment) = serverAndMediaId.split('#', 1)
+ if "#" in serverAndMediaId:
+ (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1)
fragment = "#" + fragment
- params = {
- "width": width,
- "height": height,
- "method": resize_method,
- }
+ params = {"width": width, "height": height, "method": resize_method}
return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
- config.public_baseurl,
+ public_baseurl,
serverAndMediaId,
urllib.parse.urlencode(params),
fragment or "",
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index eef6e18c2e..0644a13cfc 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -18,6 +18,8 @@ import re
from twisted.internet import defer
+from synapse.api.constants import EventTypes
+
logger = logging.getLogger(__name__)
# intentionally looser than what aliases we allow to be registered since
@@ -28,8 +30,13 @@ ALL_ALONE = "Empty Room"
@defer.inlineCallbacks
-def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True,
- fallback_to_single_member=True):
+def calculate_room_name(
+ store,
+ room_state_ids,
+ user_id,
+ fallback_to_members=True,
+ fallback_to_single_member=True,
+):
"""
Works out a user-facing name for the given room as per Matrix
spec recommendations.
@@ -45,79 +52,69 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
(string or None) A human readable name for the room.
"""
# does it have a name?
- if ("m.room.name", "") in room_state_ids:
+ if (EventTypes.Name, "") in room_state_ids:
m_room_name = yield store.get_event(
- room_state_ids[("m.room.name", "")], allow_none=True
+ room_state_ids[(EventTypes.Name, "")], allow_none=True
)
if m_room_name and m_room_name.content and m_room_name.content["name"]:
- defer.returnValue(m_room_name.content["name"])
+ return m_room_name.content["name"]
# does it have a canonical alias?
- if ("m.room.canonical_alias", "") in room_state_ids:
+ if (EventTypes.CanonicalAlias, "") in room_state_ids:
canon_alias = yield store.get_event(
- room_state_ids[("m.room.canonical_alias", "")], allow_none=True
+ room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True
)
if (
- canon_alias and canon_alias.content and canon_alias.content["alias"] and
- _looks_like_an_alias(canon_alias.content["alias"])
+ canon_alias
+ and canon_alias.content
+ and canon_alias.content["alias"]
+ and _looks_like_an_alias(canon_alias.content["alias"])
):
- defer.returnValue(canon_alias.content["alias"])
+ return canon_alias.content["alias"]
# at this point we're going to need to search the state by all state keys
# for an event type, so rearrange the data structure
room_state_bytype_ids = _state_as_two_level_dict(room_state_ids)
- # right then, any aliases at all?
- if "m.room.aliases" in room_state_bytype_ids:
- m_room_aliases = room_state_bytype_ids["m.room.aliases"]
- for alias_id in m_room_aliases.values():
- alias_event = yield store.get_event(
- alias_id, allow_none=True
- )
- if alias_event and alias_event.content.get("aliases"):
- the_aliases = alias_event.content["aliases"]
- if len(the_aliases) > 0 and _looks_like_an_alias(the_aliases[0]):
- defer.returnValue(the_aliases[0])
-
if not fallback_to_members:
- defer.returnValue(None)
+ return None
my_member_event = None
- if ("m.room.member", user_id) in room_state_ids:
+ if (EventTypes.Member, user_id) in room_state_ids:
my_member_event = yield store.get_event(
- room_state_ids[("m.room.member", user_id)], allow_none=True
+ room_state_ids[(EventTypes.Member, user_id)], allow_none=True
)
if (
- my_member_event is not None and
- my_member_event.content['membership'] == "invite"
+ my_member_event is not None
+ and my_member_event.content["membership"] == "invite"
):
- if ("m.room.member", my_member_event.sender) in room_state_ids:
+ if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = yield store.get_event(
- room_state_ids[("m.room.member", my_member_event.sender)],
+ room_state_ids[(EventTypes.Member, my_member_event.sender)],
allow_none=True,
)
if inviter_member_event:
if fallback_to_single_member:
- defer.returnValue(
- "Invite from %s" % (
- name_from_member_event(inviter_member_event),
- )
+ return "Invite from %s" % (
+ name_from_member_event(inviter_member_event),
)
else:
return
else:
- defer.returnValue("Room Invite")
+ return "Room Invite"
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
- if "m.room.member" in room_state_bytype_ids:
+ if EventTypes.Member in room_state_bytype_ids:
member_events = yield store.get_events(
- list(room_state_bytype_ids["m.room.member"].values())
+ list(room_state_bytype_ids[EventTypes.Member].values())
)
all_members = [
- ev for ev in member_events.values()
- if ev.content['membership'] == "join" or ev.content['membership'] == "invite"
+ ev
+ for ev in member_events.values()
+ if ev.content["membership"] == "join"
+ or ev.content["membership"] == "invite"
]
# Sort the member events oldest-first so the we name people in the
# order the joined (it should at least be deterministic rather than
@@ -133,10 +130,10 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
# self-chat, peeked room with 1 participant,
# or inbound invite, or outbound 3PID invite.
if all_members[0].sender == user_id:
- if "m.room.third_party_invite" in room_state_bytype_ids:
- third_party_invites = (
- room_state_bytype_ids["m.room.third_party_invite"].values()
- )
+ if EventTypes.ThirdPartyInvite in room_state_bytype_ids:
+ third_party_invites = room_state_bytype_ids[
+ EventTypes.ThirdPartyInvite
+ ].values()
if len(third_party_invites) > 0:
# technically third party invite events are not member
@@ -148,20 +145,31 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
# return "Inviting %s" % (
# descriptor_from_member_events(third_party_invites)
# )
- defer.returnValue("Inviting email address")
+ return "Inviting email address"
else:
- defer.returnValue(ALL_ALONE)
+ return ALL_ALONE
else:
- defer.returnValue(name_from_member_event(all_members[0]))
+ return name_from_member_event(all_members[0])
else:
- defer.returnValue(ALL_ALONE)
+ return ALL_ALONE
elif len(other_members) == 1 and not fallback_to_single_member:
return
else:
- defer.returnValue(descriptor_from_member_events(other_members))
+ return descriptor_from_member_events(other_members)
def descriptor_from_member_events(member_events):
+ """Get a description of the room based on the member events.
+
+ Args:
+ member_events (Iterable[FrozenEvent])
+
+ Returns:
+ str
+ """
+
+ member_events = list(member_events)
+
if len(member_events) == 0:
return "nobody"
elif len(member_events) == 1:
@@ -180,8 +188,9 @@ def descriptor_from_member_events(member_events):
def name_from_member_event(member_event):
if (
- member_event.content and "displayname" in member_event.content and
- member_event.content["displayname"]
+ member_event.content
+ and "displayname" in member_event.content
+ and member_event.content["displayname"]
):
return member_event.content["displayname"]
return member_event.state_key
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index cf6c8b875e..4cd702b5fa 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -16,9 +16,11 @@
import logging
import re
+from typing import Pattern
from six import string_types
+from synapse.events import EventBase
from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
@@ -26,8 +28,8 @@ from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__)
-GLOB_REGEX = re.compile(r'\\\[(\\\!|)(.*)\\\]')
-IS_GLOB = re.compile(r'[\?\*\[\]]')
+GLOB_REGEX = re.compile(r"\\\[(\\\!|)(.*)\\\]")
+IS_GLOB = re.compile(r"[\?\*\[\]]")
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@@ -36,38 +38,38 @@ def _room_member_count(ev, condition, room_member_count):
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
- notif_level_key = condition.get('key')
+ notif_level_key = condition.get("key")
if notif_level_key is None:
return False
- notif_levels = power_levels.get('notifications', {})
+ notif_levels = power_levels.get("notifications", {})
room_notif_level = notif_levels.get(notif_level_key, 50)
return sender_power_level >= room_notif_level
def _test_ineq_condition(condition, number):
- if 'is' not in condition:
+ if "is" not in condition:
return False
- m = INEQUALITY_EXPR.match(condition['is'])
+ m = INEQUALITY_EXPR.match(condition["is"])
if not m:
return False
ineq = m.group(1)
rhs = m.group(2)
if not rhs.isdigit():
return False
- rhs = int(rhs)
-
- if ineq == '' or ineq == '==':
- return number == rhs
- elif ineq == '<':
- return number < rhs
- elif ineq == '>':
- return number > rhs
- elif ineq == '>=':
- return number >= rhs
- elif ineq == '<=':
- return number <= rhs
+ rhs_int = int(rhs)
+
+ if ineq == "" or ineq == "==":
+ return number == rhs_int
+ elif ineq == "<":
+ return number < rhs_int
+ elif ineq == ">":
+ return number > rhs_int
+ elif ineq == ">=":
+ return number >= rhs_int
+ elif ineq == "<=":
+ return number <= rhs_int
else:
return False
@@ -77,13 +79,19 @@ def tweaks_for_actions(actions):
for a in actions:
if not isinstance(a, dict):
continue
- if 'set_tweak' in a and 'value' in a:
- tweaks[a['set_tweak']] = a['value']
+ if "set_tweak" in a and "value" in a:
+ tweaks[a["set_tweak"]] = a["value"]
return tweaks
class PushRuleEvaluatorForEvent(object):
- def __init__(self, event, room_member_count, sender_power_level, power_levels):
+ def __init__(
+ self,
+ event: EventBase,
+ room_member_count: int,
+ sender_power_level: int,
+ power_levels: dict,
+ ):
self._event = event
self._room_member_count = room_member_count
self._sender_power_level = sender_power_level
@@ -92,51 +100,49 @@ class PushRuleEvaluatorForEvent(object):
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
- def matches(self, condition, user_id, display_name):
- if condition['kind'] == 'event_match':
+ def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
+ if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
- elif condition['kind'] == 'contains_display_name':
+ elif condition["kind"] == "contains_display_name":
return self._contains_display_name(display_name)
- elif condition['kind'] == 'room_member_count':
- return _room_member_count(
- self._event, condition, self._room_member_count
- )
- elif condition['kind'] == 'sender_notification_permission':
+ elif condition["kind"] == "room_member_count":
+ return _room_member_count(self._event, condition, self._room_member_count)
+ elif condition["kind"] == "sender_notification_permission":
return _sender_notification_permission(
- self._event, condition, self._sender_power_level, self._power_levels,
+ self._event, condition, self._sender_power_level, self._power_levels
)
else:
return True
- def _event_match(self, condition, user_id):
- pattern = condition.get('pattern', None)
+ def _event_match(self, condition: dict, user_id: str) -> bool:
+ pattern = condition.get("pattern", None)
if not pattern:
- pattern_type = condition.get('pattern_type', None)
+ pattern_type = condition.get("pattern_type", None)
if pattern_type == "user_id":
pattern = user_id
elif pattern_type == "user_localpart":
pattern = UserID.from_string(user_id).localpart
if not pattern:
- logger.warn("event_match condition with no pattern")
+ logger.warning("event_match condition with no pattern")
return False
# XXX: optimisation: cache our pattern regexps
- if condition['key'] == 'content.body':
+ if condition["key"] == "content.body":
body = self._event.content.get("body", None)
if not body:
return False
return _glob_matches(pattern, body, word_boundary=True)
else:
- haystack = self._get_value(condition['key'])
+ haystack = self._get_value(condition["key"])
if haystack is None:
return False
return _glob_matches(pattern, haystack)
- def _contains_display_name(self, display_name):
+ def _contains_display_name(self, display_name: str) -> bool:
if not display_name:
return False
@@ -144,65 +150,63 @@ class PushRuleEvaluatorForEvent(object):
if not body:
return False
- return _glob_matches(display_name, body, word_boundary=True)
+ # Similar to _glob_matches, but do not treat display_name as a glob.
+ r = regex_cache.get((display_name, False, True), None)
+ if not r:
+ r = re.escape(display_name)
+ r = _re_word_boundary(r)
+ r = re.compile(r, flags=re.IGNORECASE)
+ regex_cache[(display_name, False, True)] = r
+
+ return r.search(body)
- def _get_value(self, dotted_key):
+ def _get_value(self, dotted_key: str) -> str:
return self._value_cache.get(dotted_key, None)
-# Caches (glob, word_boundary) -> regex for push. See _glob_matches
+# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
register_cache("cache", "regex_push_cache", regex_cache)
-def _glob_matches(glob, value, word_boundary=False):
+def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
"""Tests if value matches glob.
Args:
- glob (string)
- value (string): String to test against glob.
- word_boundary (bool): Whether to match against word boundaries or entire
+ glob
+ value: String to test against glob.
+ word_boundary: Whether to match against word boundaries or entire
string. Defaults to False.
-
- Returns:
- bool
"""
try:
- r = regex_cache.get((glob, word_boundary), None)
+ r = regex_cache.get((glob, True, word_boundary), None)
if not r:
r = _glob_to_re(glob, word_boundary)
- regex_cache[(glob, word_boundary)] = r
+ regex_cache[(glob, True, word_boundary)] = r
return r.search(value)
except re.error:
- logger.warn("Failed to parse glob to regex: %r", glob)
+ logger.warning("Failed to parse glob to regex: %r", glob)
return False
-def _glob_to_re(glob, word_boundary):
+def _glob_to_re(glob: str, word_boundary: bool) -> Pattern:
"""Generates regex for a given glob.
Args:
- glob (string)
- word_boundary (bool): Whether to match against word boundaries or entire
- string. Defaults to False.
-
- Returns:
- regex object
+ glob
+ word_boundary: Whether to match against word boundaries or entire string.
"""
if IS_GLOB.search(glob):
r = re.escape(glob)
- r = r.replace(r'\*', '.*?')
- r = r.replace(r'\?', '.')
+ r = r.replace(r"\*", ".*?")
+ r = r.replace(r"\?", ".")
# handle [abc], [a-z] and [!a-z] style ranges.
r = GLOB_REGEX.sub(
lambda x: (
- '[%s%s]' % (
- x.group(1) and '^' or '',
- x.group(2).replace(r'\\\-', '-')
- )
+ "[%s%s]" % (x.group(1) and "^" or "", x.group(2).replace(r"\\\-", "-"))
),
r,
)
@@ -224,7 +228,7 @@ def _glob_to_re(glob, word_boundary):
return re.compile(r, flags=re.IGNORECASE)
-def _re_word_boundary(r):
+def _re_word_boundary(r: str) -> str:
"""
Adds word boundary characters to the start and end of an
expression to require that the match occur as a whole word,
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 8049c298c2..5dae4648c0 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -16,16 +16,15 @@
from twisted.internet import defer
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
+from synapse.storage import Storage
@defer.inlineCallbacks
def get_badge_count(store, user_id):
- invites = yield store.get_invited_rooms_for_user(user_id)
+ invites = yield store.get_invited_rooms_for_local_user(user_id)
joins = yield store.get_rooms_for_user(user_id)
- my_receipts_by_room = yield store.get_receipts_for_user(
- user_id, "m.read",
- )
+ my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")
badge = len(invites)
@@ -41,26 +40,26 @@ def get_badge_count(store, user_id):
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
- defer.returnValue(badge)
+ return badge
@defer.inlineCallbacks
-def get_context_for_event(store, state_handler, ev, user_id):
+def get_context_for_event(storage: Storage, state_handler, ev, user_id):
ctx = {}
- room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
+ room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or
# a list of people in the room
name = yield calculate_room_name(
- store, room_state_ids, user_id, fallback_to_single_member=False
+ storage.main, room_state_ids, user_id, fallback_to_single_member=False
)
if name:
- ctx['name'] = name
+ ctx["name"] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
- sender_state_event = yield store.get_event(sender_state_event_id)
- ctx['sender_display_name'] = name_from_member_event(sender_state_event)
+ sender_state_event = yield storage.main.get_event(sender_state_event_id)
+ ctx["sender_display_name"] = name_from_member_event(sender_state_event)
- defer.returnValue(ctx)
+ return ctx
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index aff85daeb5..8ad0bf5936 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -35,28 +35,31 @@ except Exception:
class PusherFactory(object):
def __init__(self, hs):
self.hs = hs
+ self.config = hs.config
- self.pusher_types = {
- "http": HttpPusher,
- }
+ self.pusher_types = {"http": HttpPusher}
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
self.mailers = {} # app_name -> Mailer
- templates = load_jinja2_templates(
- config=hs.config,
- template_html_name=hs.config.email_notif_template_html,
- template_text_name=hs.config.email_notif_template_text,
+ self.notif_template_html, self.notif_template_text = load_jinja2_templates(
+ self.config.email_template_dir,
+ [
+ self.config.email_notif_template_html,
+ self.config.email_notif_template_text,
+ ],
+ apply_format_ts_filter=True,
+ apply_mxc_to_http_filter=True,
+ public_baseurl=self.config.public_baseurl,
)
- self.notif_template_html, self.notif_template_text = templates
self.pusher_types["email"] = self._create_email_pusher
logger.info("defined email pusher type")
def create_pusher(self, pusherdict):
- kind = pusherdict['kind']
+ kind = pusherdict["kind"]
f = self.pusher_types.get(kind, None)
if not f:
return None
@@ -77,9 +80,11 @@ class PusherFactory(object):
return EmailPusher(self.hs, pusherdict, mailer)
def _app_name_from_pusherdict(self, pusherdict):
- if 'data' in pusherdict and 'brand' in pusherdict['data']:
- app_name = pusherdict['data']['brand']
- else:
- app_name = self.hs.config.email_app_name
+ data = pusherdict["data"]
- return app_name
+ if isinstance(data, dict):
+ brand = data.get("brand")
+ if isinstance(brand, str):
+ return brand
+
+ return self.config.email_app_name
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 40a7709c09..88d203aa44 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,11 +15,17 @@
# limitations under the License.
import logging
+from collections import defaultdict
+from threading import Lock
+from typing import Dict, Tuple, Union
from twisted.internet import defer
+from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
+from synapse.push.emailpusher import EmailPusher
+from synapse.push.httppusher import HttpPusher
from synapse.push.pusher import PusherFactory
from synapse.util.async_helpers import concurrently_execute
@@ -40,13 +46,36 @@ class PusherPool:
notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and
Pusher.on_new_receipts are not expected to return deferreds.
"""
+
def __init__(self, _hs):
self.hs = _hs
self.pusher_factory = PusherFactory(_hs)
self._should_start_pushers = _hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
- self.pushers = {}
+
+ # map from user id to app_id:pushkey to pusher
+ self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
+
+ # a lock for the pushers dict, since `count_pushers` is called from an different
+ # and we otherwise get concurrent modification errors
+ self._pushers_lock = Lock()
+
+ def count_pushers():
+ results = defaultdict(int) # type: Dict[Tuple[str, str], int]
+ with self._pushers_lock:
+ for pushers in self.pushers.values():
+ for pusher in pushers.values():
+ k = (type(pusher).__name__, pusher.app_id)
+ results[k] += 1
+ return results
+
+ LaterGauge(
+ name="synapse_pushers",
+ desc="the number of active pushers",
+ labels=["kind", "app_id"],
+ caller=count_pushers,
+ )
def start(self):
"""Starts the pushers off in a background process.
@@ -57,37 +86,52 @@ class PusherPool:
run_as_background_process("start_pushers", self._start_pushers)
@defer.inlineCallbacks
- def add_pusher(self, user_id, access_token, kind, app_id,
- app_display_name, device_display_name, pushkey, lang, data,
- profile_tag=""):
+ def add_pusher(
+ self,
+ user_id,
+ access_token,
+ kind,
+ app_id,
+ app_display_name,
+ device_display_name,
+ pushkey,
+ lang,
+ data,
+ profile_tag="",
+ ):
+ """Creates a new pusher and adds it to the pool
+
+ Returns:
+ Deferred[EmailPusher|HttpPusher]
+ """
time_now_msec = self.clock.time_msec()
# we try to create the pusher just to validate the config: it
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
- self.pusher_factory.create_pusher({
- "id": None,
- "user_name": user_id,
- "kind": kind,
- "app_id": app_id,
- "app_display_name": app_display_name,
- "device_display_name": device_display_name,
- "pushkey": pushkey,
- "ts": time_now_msec,
- "lang": lang,
- "data": data,
- "last_stream_ordering": None,
- "last_success": None,
- "failing_since": None
- })
+ self.pusher_factory.create_pusher(
+ {
+ "id": None,
+ "user_name": user_id,
+ "kind": kind,
+ "app_id": app_id,
+ "app_display_name": app_display_name,
+ "device_display_name": device_display_name,
+ "pushkey": pushkey,
+ "ts": time_now_msec,
+ "lang": lang,
+ "data": data,
+ "last_stream_ordering": None,
+ "last_success": None,
+ "failing_since": None,
+ }
+ )
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
# pushes from this point onwards.
- last_stream_ordering = (
- yield self.store.get_latest_push_action_stream_ordering()
- )
+ last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering()
yield self.store.add_pusher(
user_id=user_id,
@@ -103,21 +147,24 @@ class PusherPool:
last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag,
)
- yield self.start_pusher_by_id(app_id, pushkey, user_id)
+ pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id)
+
+ return pusher
@defer.inlineCallbacks
- def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
- not_user_id):
- to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(
- app_id, pushkey
- )
+ def remove_pushers_by_app_id_and_pushkey_not_user(
+ self, app_id, pushkey, not_user_id
+ ):
+ to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove:
- if p['user_name'] != not_user_id:
+ if p["user_name"] != not_user_id:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
- app_id, pushkey, p['user_name']
+ app_id,
+ pushkey,
+ p["user_name"],
)
- yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+ yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks
def remove_pushers_by_access_token(self, user_id, access_tokens):
@@ -131,14 +178,14 @@ class PusherPool:
"""
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
- if p['access_token'] in tokens:
+ if p["access_token"] in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
- p['app_id'], p['pushkey'], p['user_name']
- )
- yield self.remove_pusher(
- p['app_id'], p['pushkey'], p['user_name'],
+ p["app_id"],
+ p["pushkey"],
+ p["user_name"],
)
+ yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
@@ -172,7 +219,7 @@ class PusherPool:
min_stream_id - 1, max_stream_id
)
# This returns a tuple, user_id is at index 3
- users_affected = set([r[3] for r in updated_receipts])
+ users_affected = {r[3] for r in updated_receipts}
for u in users_affected:
if u in self.pushers:
@@ -184,21 +231,26 @@ class PusherPool:
@defer.inlineCallbacks
def start_pusher_by_id(self, app_id, pushkey, user_id):
- """Look up the details for the given pusher, and start it"""
+ """Look up the details for the given pusher, and start it
+
+ Returns:
+ Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any
+ """
if not self._should_start_pushers:
return
- resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
- app_id, pushkey
- )
+ resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
- p = None
+ pusher_dict = None
for r in resultlist:
- if r['user_name'] == user_id:
- p = r
+ if r["user_name"] == user_id:
+ pusher_dict = r
+
+ pusher = None
+ if pusher_dict:
+ pusher = yield self._start_pusher(pusher_dict)
- if p:
- yield self._start_pusher(p)
+ return pusher
@defer.inlineCallbacks
def _start_pushers(self):
@@ -208,7 +260,6 @@ class PusherPool:
Deferred
"""
pushers = yield self.store.get_all_pushers()
- logger.info("Starting %d pushers", len(pushers))
# Stagger starting up the pushers so we don't completely drown the
# process on start up.
@@ -221,38 +272,39 @@ class PusherPool:
"""Start the given pusher
Args:
- pusherdict (dict):
+ pusherdict (dict): dict with the values pulled from the db table
Returns:
- None
+ Deferred[EmailPusher|HttpPusher]
"""
try:
p = self.pusher_factory.create_pusher(pusherdict)
except PusherConfigException as e:
logger.warning(
- "Pusher incorrectly configured user=%s, appid=%s, pushkey=%s: %s",
- pusherdict.get('user_name'),
- pusherdict.get('app_id'),
- pusherdict.get('pushkey'),
+ "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
+ pusherdict["id"],
+ pusherdict.get("user_name"),
+ pusherdict.get("app_id"),
+ pusherdict.get("pushkey"),
e,
)
return
except Exception:
- logger.exception("Couldn't start a pusher: caught Exception")
+ logger.exception(
+ "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
+ )
return
if not p:
return
- appid_pushkey = "%s:%s" % (
- pusherdict['app_id'],
- pusherdict['pushkey'],
- )
- byuser = self.pushers.setdefault(pusherdict['user_name'], {})
+ appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
- if appid_pushkey in byuser:
- byuser[appid_pushkey].on_stop()
- byuser[appid_pushkey] = p
+ with self._pushers_lock:
+ byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+ if appid_pushkey in byuser:
+ byuser[appid_pushkey].on_stop()
+ byuser[appid_pushkey] = p
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
@@ -261,7 +313,7 @@ class PusherPool:
last_stream_ordering = pusherdict["last_stream_ordering"]
if last_stream_ordering:
have_notifs = yield self.store.get_if_maybe_push_in_range_for_user(
- user_id, last_stream_ordering,
+ user_id, last_stream_ordering
)
else:
# We always want to default to starting up the pusher rather than
@@ -270,6 +322,8 @@ class PusherPool:
p.on_started(have_notifs)
+ return p
+
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id):
appid_pushkey = "%s:%s" % (app_id, pushkey)
@@ -279,7 +333,9 @@ class PusherPool:
if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
byuser[appid_pushkey].on_stop()
- del byuser[appid_pushkey]
+ with self._pushers_lock:
+ del byuser[appid_pushkey]
+
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
)
diff --git a/synapse/push/rulekinds.py b/synapse/push/rulekinds.py
index 4cae48ac07..ce7cc1b4ee 100644
--- a/synapse/push/rulekinds.py
+++ b/synapse/push/rulekinds.py
@@ -13,10 +13,10 @@
# limitations under the License.
PRIORITY_CLASS_MAP = {
- 'underride': 1,
- 'sender': 2,
- 'room': 3,
- 'content': 4,
- 'override': 5,
+ "underride": 1,
+ "sender": 2,
+ "room": 3,
+ "content": 4,
+ "override": 5,
}
PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()}
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 437e79f27c..8de8cb2c12 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -1,6 +1,7 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
# limitations under the License.
import logging
+from typing import List, Set
from pkg_resources import (
DistributionNotFound,
@@ -42,19 +44,15 @@ REQUIREMENTS = [
"frozendict>=1",
"unpaddedbase64>=1.1.0",
"canonicaljson>=1.1.3",
- # Pin signedjson to 1.0.0 because this version of Synapse relies on a function that's
- # been removed in 1.1.0. Hopefully, this will be fixed by the upcoming mainline merge.
- "signedjson==1.0.0",
+ # we use the type definitions added in signedjson 1.1.
+ "signedjson>=1.1.0",
"pynacl>=1.2.1",
- "idna>=2",
-
+ "idna>=2.5",
# validating SSL certs for IP addresses requires service_identity 18.1.
"service_identity>=18.1.0",
-
- # our logcontext handling relies on the ability to cancel inlineCallbacks
- # (https://twistedmatrix.com/trac/ticket/4632) which landed in Twisted 18.7.
- "Twisted>=18.7.0",
-
+ # Twisted 18.9 introduces some logger improvements that the structured
+ # logger utilises
+ "Twisted>=18.9.0",
"treq>=15.1",
# Twisted has required pyopenssl 16.0 since about Twisted 16.6.
"pyopenssl>=16.0.0",
@@ -65,50 +63,44 @@ REQUIREMENTS = [
"bcrypt>=3.1.0",
"pillow>=4.3.0",
"sortedcontainers>=1.4.4",
- "psutil>=2.0.0",
"pymacaroons>=0.13.0",
- "msgpack>=0.5.0",
+ "msgpack>=0.5.2",
"phonenumbers>=8.2.0",
"six>=1.10",
- # prometheus_client 0.4.0 changed the format of counter metrics
- # (cf https://github.com/matrix-org/synapse/issues/4001)
- "prometheus_client>=0.0.18,<0.4.0",
-
+ "prometheus_client>=0.0.18,<0.8.0",
# we use attr.s(slots), which arrived in 16.0.0
# Twisted 18.7.0 requires attrs>=17.4.0
"attrs>=17.4.0",
-
"netaddr>=0.7.18",
+ "Jinja2>=2.9",
+ "bleach>=1.4.3",
+ "typing-extensions>=3.7.4",
]
CONDITIONAL_REQUIREMENTS = {
- "email": ["Jinja2>=2.9", "bleach>=1.4.3"],
"matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
-
# we use execute_batch, which arrived in psycopg 2.7.
"postgres": ["psycopg2>=2.7"],
-
# ConsentResource uses select_autoescape, which arrived in jinja 2.9
"resources.consent": ["Jinja2>=2.9"],
-
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
"acme": [
"txacme>=0.9.2",
-
# txacme depends on eliot. Eliot 1.8.0 is incompatible with
# python 3.5.2, as per https://github.com/itamarst/eliot/issues/418
'eliot<1.8.0;python_version<"3.5.3"',
],
-
"saml2": ["pysaml2>=4.5.0"],
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
"test": ["mock>=2.0", "parameterized"],
"sentry": ["sentry-sdk>=0.7.2"],
+ "opentracing": ["jaeger-client>=4.0.0", "opentracing>=2.2.0"],
+ "jwt": ["pyjwt>=1.6.4"],
}
-ALL_OPTIONAL_REQUIREMENTS = set()
+ALL_OPTIONAL_REQUIREMENTS = set() # type: Set[str]
for name, optional_deps in CONDITIONAL_REQUIREMENTS.items():
# Exclude systemd as it's a system-based requirement.
@@ -123,12 +115,14 @@ def list_requirements():
class DependencyException(Exception):
@property
def message(self):
- return "\n".join([
- "Missing Requirements: %s" % (", ".join(self.dependencies),),
- "To install run:",
- " pip install --upgrade --force %s" % (" ".join(self.dependencies),),
- "",
- ])
+ return "\n".join(
+ [
+ "Missing Requirements: %s" % (", ".join(self.dependencies),),
+ "To install run:",
+ " pip install --upgrade --force %s" % (" ".join(self.dependencies),),
+ "",
+ ]
+ )
@property
def dependencies(self):
@@ -152,16 +146,26 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency)
errors.append(
"Needed %s, got %s==%s"
- % (dependency, e.dist.project_name, e.dist.version)
+ % (
+ dependency,
+ e.dist.project_name, # type: ignore[attr-defined] # noqa
+ e.dist.version, # type: ignore[attr-defined] # noqa
+ )
)
except DistributionNotFound:
deps_needed.append(dependency)
- errors.append("Needed %s but it was not installed" % (dependency,))
+ if for_feature:
+ errors.append(
+ "Needed %s for the '%s' feature but it was not installed"
+ % (dependency, for_feature)
+ )
+ else:
+ errors.append("Needed %s but it was not installed" % (dependency,))
if not for_feature:
# Check the optional dependencies are up to date. We allow them to not be
# installed.
- OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), [])
+ OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]
for dependency in OPTS:
try:
@@ -170,15 +174,19 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency)
errors.append(
"Needed optional %s, got %s==%s"
- % (dependency, e.dist.project_name, e.dist.version)
+ % (
+ dependency,
+ e.dist.project_name, # type: ignore[attr-defined] # noqa
+ e.dist.version, # type: ignore[attr-defined] # noqa
+ )
)
except DistributionNotFound:
# If it's not found, we don't care
pass
if deps_needed:
- for e in errors:
- logging.error(e)
+ for err in errors:
+ logging.error(err)
raise DependencyException(deps_needed)
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 81b85352b1..28dbc6fcba 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -14,7 +14,14 @@
# limitations under the License.
from synapse.http.server import JsonResource
-from synapse.replication.http import federation, login, membership, register, send_event
+from synapse.replication.http import (
+ devices,
+ federation,
+ login,
+ membership,
+ register,
+ send_event,
+)
REPLICATION_PREFIX = "/_synapse/replication"
@@ -30,3 +37,4 @@ class ReplicationRestResource(JsonResource):
federation.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
+ devices.register_servlets(hs, self)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index e81456ab2b..1be1ccbdf3 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -16,12 +16,24 @@
import abc
import logging
import re
+from typing import Dict, List, Tuple
+from six import raise_from
from six.moves import urllib
from twisted.internet import defer
-from synapse.api.errors import CodeMessageException, HttpResponseException
+from synapse.api.errors import (
+ CodeMessageException,
+ HttpResponseException,
+ RequestSendFailed,
+ SynapseError,
+)
+from synapse.logging.opentracing import (
+ inject_active_span_byte_dict,
+ trace,
+ trace_servlet,
+)
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
@@ -32,7 +44,7 @@ class ReplicationEndpoint(object):
"""Helper base class for defining new replication HTTP endpoints.
This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
- (with an `/:txn_id` prefix for cached requests.), where NAME is a name,
+ (with a `/:txn_id` suffix for cached requests), where NAME is a name,
PATH_ARGS are a tuple of parameters to be encoded in the URL.
For example, if `NAME` is "send_event" and `PATH_ARGS` is `("event_id",)`,
@@ -67,9 +79,8 @@ class ReplicationEndpoint(object):
__metaclass__ = abc.ABCMeta
- NAME = abc.abstractproperty()
- PATH_ARGS = abc.abstractproperty()
-
+ NAME = abc.abstractproperty() # type: str # type: ignore
+ PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
METHOD = "POST"
CACHE = True
RETRY_ON_TIMEOUT = True
@@ -77,8 +88,7 @@ class ReplicationEndpoint(object):
def __init__(self, hs):
if self.CACHE:
self.response_cache = ResponseCache(
- hs, "repl." + self.NAME,
- timeout_ms=30 * 60 * 1000,
+ hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
)
assert self.METHOD in ("PUT", "POST", "GET")
@@ -100,14 +110,14 @@ class ReplicationEndpoint(object):
return {}
@abc.abstractmethod
- def _handle_request(self, request, **kwargs):
+ async def _handle_request(self, request, **kwargs):
"""Handle incoming request.
This is called with the request object and PATH_ARGS.
Returns:
- Deferred[dict]: A JSON serialisable dict to be used as response
- body of request.
+ tuple[int, dict]: HTTP status code and a JSON serialisable dict
+ to be used as response body of request.
"""
pass
@@ -123,13 +133,13 @@ class ReplicationEndpoint(object):
client = hs.get_simple_http_client()
+ @trace(opname="outgoing_replication_request")
@defer.inlineCallbacks
def send_request(**kwargs):
data = yield cls._serialize_payload(**kwargs)
url_args = [
- urllib.parse.quote(kwargs[name], safe='')
- for name in cls.PATH_ARGS
+ urllib.parse.quote(kwargs[name], safe="") for name in cls.PATH_ARGS
]
if cls.CACHE:
@@ -150,7 +160,10 @@ class ReplicationEndpoint(object):
)
uri = "http://%s:%s/_synapse/replication/%s/%s" % (
- host, port, cls.NAME, "/".join(url_args)
+ host,
+ port,
+ cls.NAME,
+ "/".join(url_args),
)
try:
@@ -158,14 +171,16 @@ class ReplicationEndpoint(object):
# have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not.
while True:
+ headers = {} # type: Dict[bytes, List[bytes]]
+ inject_active_span_byte_dict(headers, None, check_destination=False)
try:
- result = yield request_func(uri, data)
+ result = yield request_func(uri, data, headers=headers)
break
except CodeMessageException as e:
if e.code != 504 or not cls.RETRY_ON_TIMEOUT:
raise
- logger.warn("%s request timed out", cls.NAME)
+ logger.warning("%s request timed out", cls.NAME)
# If we timed out we probably don't need to worry about backing
# off too much, but lets just wait a little anyway.
@@ -175,8 +190,10 @@ class ReplicationEndpoint(object):
# on the master process that we should send to the client. (And
# importantly, not stack traces everywhere)
raise e.to_synapse_error()
+ except RequestSendFailed as e:
+ raise_from(SynapseError(502, "Failed to talk to master"), e)
- defer.returnValue(result)
+ return result
return send_request
@@ -190,16 +207,18 @@ class ReplicationEndpoint(object):
method = self.METHOD
if self.CACHE:
- handler = self._cached_handler
+ handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
- pattern = re.compile("^/_synapse/replication/%s/%s$" % (
- self.NAME,
- args
- ))
+ pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
- http_server.register_paths(method, [pattern], handler)
+ handler = trace_servlet(self.__class__.__name__, extract_context=True)(handler)
+ # We don't let register paths trace this servlet using the default tracing
+ # options because we wish to extract the context explicitly.
+ http_server.register_paths(
+ method, [pattern], handler, self.__class__.__name__, trace=False
+ )
def _cached_handler(self, request, txn_id, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
@@ -211,8 +230,4 @@ class ReplicationEndpoint(object):
assert self.CACHE
- return self.response_cache.wrap(
- txn_id,
- self._handle_request,
- request, **kwargs
- )
+ return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
new file mode 100644
index 0000000000..e32aac0a25
--- /dev/null
+++ b/synapse/replication/http/devices.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
+ """Ask master to resync the device list for a user by contacting their
+ server.
+
+ This must happen on master so that the results can be correctly cached in
+ the database and streamed to workers.
+
+ Request format:
+
+ POST /_synapse/replication/user_device_resync/:user_id
+
+ {}
+
+ Response is equivalent to ` /_matrix/federation/v1/user/devices/:user_id`
+ response, e.g.:
+
+ {
+ "user_id": "@alice:example.org",
+ "devices": [
+ {
+ "device_id": "JLAFKJWSCS",
+ "keys": { ... },
+ "device_display_name": "Alice's Mobile Phone"
+ }
+ ]
+ }
+ """
+
+ NAME = "user_device_resync"
+ PATH_ARGS = ("user_id",)
+ CACHE = False
+
+ def __init__(self, hs):
+ super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs)
+
+ self.device_list_updater = hs.get_device_handler().device_list_updater
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ def _serialize_payload(user_id):
+ return {}
+
+ async def _handle_request(self, request, user_id):
+ user_devices = await self.device_list_updater.user_device_resync(user_id)
+
+ return 200, user_devices
+
+
+def register_servlets(hs, http_server):
+ ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 0f0a07c422..7e23b565b9 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
-from synapse.events import event_type_from_format_version
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -37,6 +38,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
{
"events": [{
"event": { .. serialized event .. },
+ "room_version": .., // "1", "2", "3", etc: the version of the room
+ // containing the event
+ "event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
@@ -51,6 +55,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
super(ReplicationFederationSendEventsRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.clock = hs.get_clock()
self.federation_handler = hs.get_handlers().federation_handler
@@ -68,23 +73,22 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
for event, context in event_and_contexts:
serialized_context = yield context.serialize(event, store)
- event_payloads.append({
- "event": event.get_pdu_json(),
- "event_format_version": event.format_version,
- "internal_metadata": event.internal_metadata.get_dict(),
- "rejected_reason": event.rejected_reason,
- "context": serialized_context,
- })
-
- payload = {
- "events": event_payloads,
- "backfilled": backfilled,
- }
+ event_payloads.append(
+ {
+ "event": event.get_pdu_json(),
+ "room_version": event.room_version.identifier,
+ "event_format_version": event.format_version,
+ "internal_metadata": event.internal_metadata.get_dict(),
+ "rejected_reason": event.rejected_reason,
+ "context": serialized_context,
+ }
+ )
- defer.returnValue(payload)
+ payload = {"events": event_payloads, "backfilled": backfilled}
- @defer.inlineCallbacks
- def _handle_request(self, request):
+ return payload
+
+ async def _handle_request(self, request):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
@@ -95,29 +99,27 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
event_and_contexts = []
for event_payload in event_payloads:
event_dict = event_payload["event"]
- format_ver = event_payload["event_format_version"]
+ room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]]
internal_metadata = event_payload["internal_metadata"]
rejected_reason = event_payload["rejected_reason"]
- EventType = event_type_from_format_version(format_ver)
- event = EventType(event_dict, internal_metadata, rejected_reason)
+ event = make_event_from_dict(
+ event_dict, room_ver, internal_metadata, rejected_reason
+ )
- context = yield EventContext.deserialize(
- self.store, event_payload["context"],
+ context = EventContext.deserialize(
+ self.storage, event_payload["context"]
)
event_and_contexts.append((event, context))
- logger.info(
- "Got %d events from federation",
- len(event_and_contexts),
- )
+ logger.info("Got %d events from federation", len(event_and_contexts))
- yield self.federation_handler.persist_events_and_notify(
- event_and_contexts, backfilled,
+ await self.federation_handler.persist_events_and_notify(
+ event_and_contexts, backfilled
)
- defer.returnValue((200, {}))
+ return 200, {}
class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
@@ -146,27 +148,20 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
@staticmethod
def _serialize_payload(edu_type, origin, content):
- return {
- "origin": origin,
- "content": content,
- }
+ return {"origin": origin, "content": content}
- @defer.inlineCallbacks
- def _handle_request(self, request, edu_type):
+ async def _handle_request(self, request, edu_type):
with Measure(self.clock, "repl_fed_send_edu_parse"):
content = parse_json_object_from_request(request)
origin = content["origin"]
edu_content = content["content"]
- logger.info(
- "Got %r edu from %s",
- edu_type, origin,
- )
+ logger.info("Got %r edu from %s", edu_type, origin)
- result = yield self.registry.on_edu(edu_type, origin, edu_content)
+ result = await self.registry.on_edu(edu_type, origin, edu_content)
- defer.returnValue((200, result))
+ return 200, result
class ReplicationGetQueryRestServlet(ReplicationEndpoint):
@@ -201,25 +196,19 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
query_type (str)
args (dict): The arguments received for the given query type
"""
- return {
- "args": args,
- }
+ return {"args": args}
- @defer.inlineCallbacks
- def _handle_request(self, request, query_type):
+ async def _handle_request(self, request, query_type):
with Measure(self.clock, "repl_fed_query_parse"):
content = parse_json_object_from_request(request)
args = content["args"]
- logger.info(
- "Got %r query",
- query_type,
- )
+ logger.info("Got %r query", query_type)
- result = yield self.registry.on_query(query_type, args)
+ result = await self.registry.on_query(query_type, args)
- defer.returnValue((200, result))
+ return 200, result
class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
@@ -228,7 +217,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
Request format:
- POST /_synapse/replication/fed_query/:fed_cleanup_room/:txn_id
+ POST /_synapse/replication/fed_cleanup_room/:room_id/:txn_id
{}
"""
@@ -249,11 +238,42 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
"""
return {}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id):
- yield self.store.clean_room_for_join(room_id)
+ async def _handle_request(self, request, room_id):
+ await self.store.clean_room_for_join(room_id)
+
+ return 200, {}
+
+
+class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
+ """Called to clean up any data in DB for a given room, ready for the
+ server to join the room.
+
+ Request format:
+
+ POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id
+
+ {
+ "room_version": "1",
+ }
+ """
+
+ NAME = "store_room_on_invite"
+ PATH_ARGS = ("room_id",)
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
+ self.store = hs.get_datastore()
+
+ @staticmethod
+ def _serialize_payload(room_id, room_version):
+ return {"room_version": room_version.identifier}
- defer.returnValue((200, {}))
+ async def _handle_request(self, request, room_id):
+ content = parse_json_object_from_request(request)
+ room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
+ await self.store.maybe_store_room_on_invite(room_id, room_version)
+ return 200, {}
def register_servlets(hs, http_server):
@@ -261,3 +281,4 @@ def register_servlets(hs, http_server):
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)
+ ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 63bc0405ea..798b9d3af5 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -52,22 +50,18 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
"is_guest": is_guest,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
device_id = content["device_id"]
initial_display_name = content["initial_display_name"]
is_guest = content["is_guest"]
- device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest,
+ device_id, access_token = await self.registration_handler.register_device(
+ user_id, device_id, initial_display_name, is_guest
)
- defer.returnValue((200, {
- "device_id": device_id,
- "access_token": access_token,
- }))
+ return 200, {"device_id": device_id, "access_token": access_token}
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 81a2b204c7..3577611fd7 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
@@ -40,7 +38,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"""
NAME = "remote_join"
- PATH_ARGS = ("room_id", "user_id",)
+ PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs):
super(ReplicationRemoteJoinRestServlet, self).__init__(hs)
@@ -50,8 +48,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts,
- content):
+ def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
"""
Args:
requester(Requester)
@@ -66,8 +63,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -78,19 +74,13 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
if requester.user:
request.authenticated_entity = requester.user.to_string()
- logger.info(
- "remote_join: %s into room: %s",
- user_id, room_id,
- )
+ logger.info("remote_join: %s into room: %s", user_id, room_id)
- yield self.federation_handler.do_invite_join(
- remote_room_hosts,
- room_id,
- user_id,
- event_content,
+ await self.federation_handler.do_invite_join(
+ remote_room_hosts, room_id, user_id, event_content
)
- defer.returnValue((200, {}))
+ return 200, {}
class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@@ -103,11 +93,12 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
{
"requester": ...,
"remote_room_hosts": [...],
+ "content": { ... }
}
"""
NAME = "remote_reject_invite"
- PATH_ARGS = ("room_id", "user_id",)
+ PATH_ARGS = ("room_id", "user_id")
def __init__(self, hs):
super(ReplicationRemoteRejectInviteRestServlet, self).__init__(hs)
@@ -117,7 +108,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts):
+ def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
"""
Args:
requester(Requester)
@@ -128,29 +119,25 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
return {
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
+ "content": content,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, room_id, user_id):
+ async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
+ event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
if requester.user:
request.authenticated_entity = requester.user.to_string()
- logger.info(
- "remote_reject_invite: %s out of room: %s",
- user_id, room_id,
- )
+ logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
try:
- event = yield self.federation_handler.do_remotely_reject_invite(
- remote_room_hosts,
- room_id,
- user_id,
+ event = await self.federation_handler.do_remotely_reject_invite(
+ remote_room_hosts, room_id, user_id, event_content,
)
ret = event.get_pdu_json()
except Exception as e:
@@ -160,78 +147,12 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
# The 'except' clause is very broad, but we need to
# capture everything from DNS failures upwards
#
- logger.warn("Failed to reject invite: %s", e)
+ logger.warning("Failed to reject invite: %s", e)
- yield self.store.locally_reject_invite(
- user_id, room_id
- )
+ await self.store.locally_reject_invite(user_id, room_id)
ret = {}
- defer.returnValue((200, ret))
-
-
-class ReplicationRegister3PIDGuestRestServlet(ReplicationEndpoint):
- """Gets/creates a guest account for given 3PID.
-
- Request format:
-
- POST /_synapse/replication/get_or_register_3pid_guest/
-
- {
- "requester": ...,
- "medium": ...,
- "address": ...,
- "inviter_user_id": ...
- }
- """
-
- NAME = "get_or_register_3pid_guest"
- PATH_ARGS = ()
-
- def __init__(self, hs):
- super(ReplicationRegister3PIDGuestRestServlet, self).__init__(hs)
-
- self.registeration_handler = hs.get_registration_handler()
- self.store = hs.get_datastore()
- self.clock = hs.get_clock()
-
- @staticmethod
- def _serialize_payload(requester, medium, address, inviter_user_id):
- """
- Args:
- requester(Requester)
- medium (str)
- address (str)
- inviter_user_id (str): The user ID who is trying to invite the
- 3PID
- """
- return {
- "requester": requester.serialize(),
- "medium": medium,
- "address": address,
- "inviter_user_id": inviter_user_id,
- }
-
- @defer.inlineCallbacks
- def _handle_request(self, request):
- content = parse_json_object_from_request(request)
-
- medium = content["medium"]
- address = content["address"]
- inviter_user_id = content["inviter_user_id"]
-
- requester = Requester.deserialize(self.store, content["requester"])
-
- if requester.user:
- request.authenticated_entity = requester.user.to_string()
-
- logger.info("get_or_register_3pid_guest: %r", content)
-
- ret = yield self.registeration_handler.get_or_register_3pid_guest(
- medium, address, inviter_user_id,
- )
-
- defer.returnValue((200, ret))
+ return 200, ret
class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
@@ -264,7 +185,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
user_id (str)
change (str): Either "joined" or "left"
"""
- assert change in ("joined", "left",)
+ assert change in ("joined", "left")
return {}
@@ -280,11 +201,10 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
else:
raise Exception("Unrecognized change: %r", change)
- return (200, {})
+ return 200, {}
def register_servlets(hs, http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
- ReplicationRegister3PIDGuestRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 912a5ac341..0c4aca1291 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -37,15 +35,19 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
@staticmethod
def _serialize_payload(
- user_id, token, password_hash, was_guest, make_guest, appservice_id,
- create_profile_with_displayname, admin, user_type, address,
+ user_id,
+ password_hash,
+ was_guest,
+ make_guest,
+ appservice_id,
+ create_profile_with_displayname,
+ admin,
+ user_type,
+ address,
):
"""
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user. If this
- is not None, the given access token is associated with the user
- id.
password_hash (str|None): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -60,7 +62,6 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
address (str|None): the IP address used to perform the regitration.
"""
return {
- "token": token,
"password_hash": password_hash,
"was_guest": was_guest,
"make_guest": make_guest,
@@ -71,13 +72,13 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
"address": address,
}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
- yield self.registration_handler.register_with_store(
+ self.registration_handler.check_registration_ratelimit(content["address"])
+
+ await self.registration_handler.register_with_store(
user_id=user_id,
- token=content["token"],
password_hash=content["password_hash"],
was_guest=content["was_guest"],
make_guest=content["make_guest"],
@@ -85,10 +86,10 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
create_profile_with_displayname=content["create_profile_with_displayname"],
admin=content["admin"],
user_type=content["user_type"],
- address=content["address"]
+ address=content["address"],
)
- defer.returnValue((200, {}))
+ return 200, {}
class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
@@ -104,8 +105,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
self.registration_handler = hs.get_registration_handler()
@staticmethod
- def _serialize_payload(user_id, auth_result, access_token, bind_email,
- bind_msisdn):
+ def _serialize_payload(user_id, auth_result, access_token):
"""
Args:
user_id (str): The user ID that consented
@@ -113,36 +113,20 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
registered user.
access_token (str|None): The access token of the newly logged in
device, or None if `inhibit_login` enabled.
- bind_email (bool): Whether to bind the email with the identity
- server
- bind_msisdn (bool): Whether to bind the msisdn with the identity
- server
"""
- return {
- "auth_result": auth_result,
- "access_token": access_token,
- "bind_email": bind_email,
- "bind_msisdn": bind_msisdn,
- }
+ return {"auth_result": auth_result, "access_token": access_token}
- @defer.inlineCallbacks
- def _handle_request(self, request, user_id):
+ async def _handle_request(self, request, user_id):
content = parse_json_object_from_request(request)
auth_result = content["auth_result"]
access_token = content["access_token"]
- bind_email = content["bind_email"]
- bind_msisdn = content["bind_msisdn"]
- yield self.registration_handler.post_registration_actions(
- user_id=user_id,
- auth_result=auth_result,
- access_token=access_token,
- bind_email=bind_email,
- bind_msisdn=bind_msisdn,
+ await self.registration_handler.post_registration_actions(
+ user_id=user_id, auth_result=auth_result, access_token=access_token
)
- defer.returnValue((200, {}))
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index 3635015eda..b74b088ff4 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -17,7 +17,8 @@ import logging
from twisted.internet import defer
-from synapse.events import event_type_from_format_version
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -37,6 +38,9 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
{
"event": { .. serialized event .. },
+ "room_version": .., // "1", "2", "3", etc: the version of the room
+ // containing the event
+ "event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. },
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
@@ -45,6 +49,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"extra_users": [],
}
"""
+
NAME = "send_event"
PATH_ARGS = ("event_id",)
@@ -53,12 +58,14 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastore()
+ self.storage = hs.get_storage()
self.clock = hs.get_clock()
@staticmethod
@defer.inlineCallbacks
- def _serialize_payload(event_id, store, event, context, requester,
- ratelimit, extra_users):
+ def _serialize_payload(
+ event_id, store, event, context, requester, ratelimit, extra_users
+ ):
"""
Args:
event_id (str)
@@ -74,6 +81,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
payload = {
"event": event.get_pdu_json(),
+ "room_version": event.room_version.identifier,
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
@@ -83,23 +91,23 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"extra_users": [u.to_string() for u in extra_users],
}
- defer.returnValue(payload)
+ return payload
- @defer.inlineCallbacks
- def _handle_request(self, request, event_id):
+ async def _handle_request(self, request, event_id):
with Measure(self.clock, "repl_send_event_parse"):
content = parse_json_object_from_request(request)
event_dict = content["event"]
- format_ver = content["event_format_version"]
+ room_ver = KNOWN_ROOM_VERSIONS[content["room_version"]]
internal_metadata = content["internal_metadata"]
rejected_reason = content["rejected_reason"]
- EventType = event_type_from_format_version(format_ver)
- event = EventType(event_dict, internal_metadata, rejected_reason)
+ event = make_event_from_dict(
+ event_dict, room_ver, internal_metadata, rejected_reason
+ )
requester = Requester.deserialize(self.store, content["requester"])
- context = yield EventContext.deserialize(self.store, content["context"])
+ context = EventContext.deserialize(self.storage, content["context"])
ratelimit = content["ratelimit"]
extra_users = [UserID.from_string(u) for u in content["extra_users"]]
@@ -108,17 +116,14 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
request.authenticated_entity = requester.user.to_string()
logger.info(
- "Got event to send with ID: %s into room: %s",
- event.event_id, event.room_id,
+ "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
- yield self.event_creation_handler.persist_and_notify_client_event(
- requester, event, context,
- ratelimit=ratelimit,
- extra_users=extra_users,
+ await self.event_creation_handler.persist_and_notify_client_event(
+ requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
- defer.returnValue((200, {}))
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 817d1f67f9..f45cbd37a0 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -14,10 +14,13 @@
# limitations under the License.
import logging
+from typing import Dict, Optional
import six
-from synapse.storage._base import _CURRENT_STATE_CACHE_NAME, SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker
@@ -33,18 +36,25 @@ def __func__(inp):
class BaseSlavedStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(BaseSlavedStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(BaseSlavedStore, self).__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker(
- db_conn, "cache_invalidation_stream", "stream_id",
- )
+ db_conn, "cache_invalidation_stream", "stream_id"
+ ) # type: Optional[SlavedIdTracker]
else:
self._cache_id_gen = None
self.hs = hs
- def stream_positions(self):
+ def stream_positions(self) -> Dict[str, int]:
+ """
+ Get the current positions of all the streams this store wants to subscribe to
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
pos = {}
if self._cache_id_gen:
pos["caches"] = self._cache_id_gen.get_current_token()
@@ -52,14 +62,20 @@ class BaseSlavedStore(SQLBaseStore):
def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches":
- self._cache_id_gen.advance(token)
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(token)
for row in rows:
- if row.cache_func == _CURRENT_STATE_CACHE_NAME:
+ if row.cache_func == CURRENT_STATE_CACHE_NAME:
+ if row.keys is None:
+ raise Exception(
+ "Can't send an 'invalidate all' for current state cache"
+ )
+
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
else:
- self._attempt_to_invalidate_cache(row.cache_func, tuple(row.keys))
+ self._attempt_to_invalidate_cache(row.cache_func, row.keys)
def _invalidate_cache_and_stream(self, txn, cache_func, keys):
txn.call_after(cache_func.invalidate, keys)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index d9ba6d69b1..ebe94909cb 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -16,18 +16,18 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.account_data import AccountDataWorkerStore
-from synapse.storage.tags import TagsWorkerStore
+from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
+from synapse.storage.data_stores.main.tags import TagsWorkerStore
+from synapse.storage.database import Database
class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore):
-
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = SlavedIdTracker(
- db_conn, "account_data_max_stream_id", "stream_id",
+ db_conn, "account_data_max_stream_id", "stream_id"
)
- super(SlavedAccountDataStore, self).__init__(db_conn, hs)
+ super(SlavedAccountDataStore, self).__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()
@@ -45,24 +45,20 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
self._account_data_id_gen.advance(token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
- self._account_data_stream_cache.entity_has_changed(
- row.user_id, token
- )
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
elif stream_name == "account_data":
self._account_data_id_gen.advance(token)
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
- (row.data_type, row.user_id,)
+ (row.data_type, row.user_id)
)
self.get_account_data_for_user.invalidate((row.user_id,))
- self.get_account_data_for_room.invalidate((row.user_id, row.room_id,))
+ self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
self.get_account_data_for_room_and_type.invalidate(
- (row.user_id, row.room_id, row.data_type,),
- )
- self._account_data_stream_cache.entity_has_changed(
- row.user_id, token
+ (row.user_id, row.room_id, row.data_type)
)
+ self._account_data_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedAccountDataStore, self).process_replication_rows(
stream_name, token, rows
)
diff --git a/synapse/replication/slave/storage/appservice.py b/synapse/replication/slave/storage/appservice.py
index b53a4c6bd1..a67fbeffb7 100644
--- a/synapse/replication/slave/storage/appservice.py
+++ b/synapse/replication/slave/storage/appservice.py
@@ -14,12 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.appservice import (
+from synapse.storage.data_stores.main.appservice import (
ApplicationServiceTransactionWorkerStore,
ApplicationServiceWorkerStore,
)
-class SlavedApplicationServiceStore(ApplicationServiceTransactionWorkerStore,
- ApplicationServiceWorkerStore):
+class SlavedApplicationServiceStore(
+ ApplicationServiceTransactionWorkerStore, ApplicationServiceWorkerStore
+):
pass
diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py
index 5b8521c770..fbf996e33a 100644
--- a/synapse/replication/slave/storage/client_ips.py
+++ b/synapse/replication/slave/storage/client_ips.py
@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY
+from synapse.storage.database import Database
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.caches.descriptors import Cache
@@ -21,13 +22,11 @@ from ._base import BaseSlavedStore
class SlavedClientIpStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedClientIpStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedClientIpStore, self).__init__(database, db_conn, hs)
self.client_ip_last_seen = Cache(
- name="client_ip_last_seen",
- keylen=4,
- max_entries=50000 * CACHE_SIZE_FACTOR,
+ name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index 4d59778863..0c237c6e0f 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -15,24 +15,25 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.data_stores.main.deviceinbox import DeviceInboxWorkerStore
+from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedDeviceInboxStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
- db_conn, "device_max_stream_id", "stream_id",
+ db_conn, "device_max_stream_id", "stream_id"
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
- self._device_inbox_id_gen.get_current_token()
+ self._device_inbox_id_gen.get_current_token(),
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache",
- self._device_inbox_id_gen.get_current_token()
+ self._device_inbox_id_gen.get_current_token(),
)
self._last_device_delete_cache = ExpiringCache(
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 16c9a162c5..1c77687eea 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -15,54 +15,63 @@
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
-from synapse.storage.devices import DeviceWorkerStore
-from synapse.storage.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
+from synapse.storage.data_stores.main.devices import DeviceWorkerStore
+from synapse.storage.data_stores.main.end_to_end_keys import EndToEndKeyWorkerStore
+from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedDeviceStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedDeviceStore, self).__init__(database, db_conn, hs)
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
- db_conn, "device_lists_stream", "stream_id",
+ db_conn, "device_lists_stream", "stream_id"
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache", device_list_max,
+ "DeviceListStreamChangeCache", device_list_max
+ )
+ self._user_signature_stream_cache = StreamChangeCache(
+ "UserSignatureStreamChangeCache", device_list_max
)
self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max,
+ "DeviceListFederationStreamChangeCache", device_list_max
)
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
- result["device_lists"] = self._device_list_id_gen.get_current_token()
+ # The user signature stream uses the same stream ID generator as the
+ # device list stream, so set them both to the device list ID
+ # generator's current token.
+ current_token = self._device_list_id_gen.get_current_token()
+ result[DeviceListsStream.NAME] = current_token
+ result[UserSignatureStream.NAME] = current_token
return result
def process_replication_rows(self, stream_name, token, rows):
- if stream_name == "device_lists":
+ if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
for row in rows:
- self._invalidate_caches_for_devices(
- token, row.user_id, row.destination,
- )
+ self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ elif stream_name == UserSignatureStream.NAME:
+ for row in rows:
+ self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
def _invalidate_caches_for_devices(self, token, user_id, destination):
- self._device_list_stream_cache.entity_has_changed(
- user_id, token
- )
+ self._device_list_stream_cache.entity_has_changed(user_id, token)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
- self._get_cached_devices_for_user.invalidate((user_id,))
+ self.get_cached_devices_for_user.invalidate((user_id,))
self._get_cached_user_device.invalidate_many((user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
diff --git a/synapse/replication/slave/storage/directory.py b/synapse/replication/slave/storage/directory.py
index 1d1d48709a..8b9717c46f 100644
--- a/synapse/replication/slave/storage/directory.py
+++ b/synapse/replication/slave/storage/directory.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.directory import DirectoryWorkerStore
+from synapse.storage.data_stores.main.directory import DirectoryWorkerStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index a3952506c1..e73342c657 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -20,15 +20,19 @@ from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
EventsStreamEventRow,
)
-from synapse.storage.event_federation import EventFederationWorkerStore
-from synapse.storage.event_push_actions import EventPushActionsWorkerStore
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.relations import RelationsWorkerStore
-from synapse.storage.roommember import RoomMemberWorkerStore
-from synapse.storage.signatures import SignatureWorkerStore
-from synapse.storage.state import StateGroupWorkerStore
-from synapse.storage.stream import StreamWorkerStore
-from synapse.storage.user_erasure_store import UserErasureWorkerStore
+from synapse.storage.data_stores.main.event_federation import EventFederationWorkerStore
+from synapse.storage.data_stores.main.event_push_actions import (
+ EventPushActionsWorkerStore,
+)
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.relations import RelationsWorkerStore
+from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.data_stores.main.stream import StreamWorkerStore
+from synapse.storage.data_stores.main.user_erasure_store import UserErasureWorkerStore
+from synapse.storage.database import Database
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -45,26 +49,40 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class.
-class SlavedEventStore(EventFederationWorkerStore,
- RoomMemberWorkerStore,
- EventPushActionsWorkerStore,
- StreamWorkerStore,
- StateGroupWorkerStore,
- EventsWorkerStore,
- SignatureWorkerStore,
- UserErasureWorkerStore,
- RelationsWorkerStore,
- BaseSlavedStore):
-
- def __init__(self, db_conn, hs):
- self._stream_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering",
- )
+class SlavedEventStore(
+ EventFederationWorkerStore,
+ RoomMemberWorkerStore,
+ EventPushActionsWorkerStore,
+ StreamWorkerStore,
+ StateGroupWorkerStore,
+ EventsWorkerStore,
+ SignatureWorkerStore,
+ UserErasureWorkerStore,
+ RelationsWorkerStore,
+ BaseSlavedStore,
+):
+ def __init__(self, database: Database, db_conn, hs):
+ self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
- super(SlavedEventStore, self).__init__(db_conn, hs)
+ super(SlavedEventStore, self).__init__(database, db_conn, hs)
+
+ events_max = self._stream_id_gen.get_current_token()
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+ db_conn,
+ "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=events_max, # As we share the stream id with events token
+ limit=1000,
+ )
+ self._curr_state_delta_stream_cache = StreamChangeCache(
+ "_curr_state_delta_stream_cache",
+ min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
+ )
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
@@ -90,8 +108,13 @@ class SlavedEventStore(EventFederationWorkerStore,
self._backfill_id_gen.advance(-token)
for row in rows:
self.invalidate_caches_for_event(
- -token, row.event_id, row.room_id, row.type, row.state_key,
- row.redacts, row.relates_to,
+ -token,
+ row.event_id,
+ row.room_id,
+ row.type,
+ row.state_key,
+ row.redacts,
+ row.relates_to,
backfilled=True,
)
return super(SlavedEventStore, self).process_replication_rows(
@@ -103,42 +126,53 @@ class SlavedEventStore(EventFederationWorkerStore,
if row.type == EventsStreamEventRow.TypeId:
self.invalidate_caches_for_event(
- token, data.event_id, data.room_id, data.type, data.state_key,
- data.redacts, data.relates_to,
+ token,
+ data.event_id,
+ data.room_id,
+ data.type,
+ data.state_key,
+ data.redacts,
+ data.relates_to,
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
+ self._curr_state_delta_stream_cache.entity_has_changed(
+ row.data.room_id, token
+ )
+
if data.type == EventTypes.Member:
self.get_rooms_for_user_with_stream_ordering.invalidate(
- (data.state_key, ),
+ (data.state_key,)
)
else:
- raise Exception("Unknown events stream row type %s" % (row.type, ))
-
- def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
- etype, state_key, redacts, relates_to,
- backfilled):
+ raise Exception("Unknown events stream row type %s" % (row.type,))
+
+ def invalidate_caches_for_event(
+ self,
+ stream_ordering,
+ event_id,
+ room_id,
+ etype,
+ state_key,
+ redacts,
+ relates_to,
+ backfilled,
+ ):
self._invalidate_get_event_cache(event_id)
self.get_latest_event_ids_in_room.invalidate((room_id,))
- self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
- (room_id,)
- )
+ self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
- self._events_stream_cache.entity_has_changed(
- room_id, stream_ordering
- )
+ self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
if redacts:
self._invalidate_get_event_cache(redacts)
if etype == EventTypes.Member:
- self._membership_stream_cache.entity_has_changed(
- state_key, stream_ordering
- )
- self.get_invited_rooms_for_user.invalidate((state_key,))
+ self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
+ self.get_invited_rooms_for_local_user.invalidate((state_key,))
if relates_to:
self.get_relations_for_event.invalidate_many((relates_to,))
diff --git a/synapse/replication/slave/storage/filtering.py b/synapse/replication/slave/storage/filtering.py
index 456a14cd5c..bcb0688954 100644
--- a/synapse/replication/slave/storage/filtering.py
+++ b/synapse/replication/slave/storage/filtering.py
@@ -13,14 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.filtering import FilteringStore
+from synapse.storage.data_stores.main.filtering import FilteringStore
+from synapse.storage.database import Database
from ._base import BaseSlavedStore
class SlavedFilteringStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedFilteringStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedFilteringStore, self).__init__(database, db_conn, hs)
# Filters are immutable so this cache doesn't need to be expired
get_user_filter = FilteringStore.__dict__["get_user_filter"]
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index e933b170bb..2d4fd08cf5 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -13,29 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage import DataStore
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.data_stores.main.group_server import GroupServerWorkerStore
+from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from ._base import BaseSlavedStore, __func__
-from ._slaved_id_tracker import SlavedIdTracker
-
-class SlavedGroupServerStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedGroupServerStore, self).__init__(db_conn, hs)
+class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedGroupServerStore, self).__init__(database, db_conn, hs)
self.hs = hs
self._group_updates_id_gen = SlavedIdTracker(
- db_conn, "local_group_updates", "stream_id",
+ db_conn, "local_group_updates", "stream_id"
)
self._group_updates_stream_cache = StreamChangeCache(
- "_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
+ "_group_updates_stream_cache",
+ self._group_updates_id_gen.get_current_token(),
)
- get_groups_changes_for_user = __func__(DataStore.get_groups_changes_for_user)
- get_group_stream_token = __func__(DataStore.get_group_stream_token)
- get_all_groups_for_user = __func__(DataStore.get_all_groups_for_user)
+ def get_group_stream_token(self):
+ return self._group_updates_id_gen.get_current_token()
def stream_positions(self):
result = super(SlavedGroupServerStore, self).stream_positions()
@@ -46,9 +46,7 @@ class SlavedGroupServerStore(BaseSlavedStore):
if stream_name == "groups":
self._group_updates_id_gen.advance(token)
for row in rows:
- self._group_updates_stream_cache.entity_has_changed(
- row.user_id, token
- )
+ self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedGroupServerStore, self).process_replication_rows(
stream_name, token, rows
diff --git a/synapse/replication/slave/storage/keys.py b/synapse/replication/slave/storage/keys.py
index cc6f7f009f..3def367ae9 100644
--- a/synapse/replication/slave/storage/keys.py
+++ b/synapse/replication/slave/storage/keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage import KeyStore
+from synapse.storage.data_stores.main.keys import KeyStore
# KeyStore isn't really safe to use from a worker, but for now we do so and hope that
# the races it creates aren't too bad.
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index 0ec1db25ce..ad8f0c15a9 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -14,7 +14,8 @@
# limitations under the License.
from synapse.storage import DataStore
-from synapse.storage.presence import PresenceStore
+from synapse.storage.data_stores.main.presence import PresenceStore
+from synapse.storage.database import Database
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore, __func__
@@ -22,15 +23,13 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedPresenceStore(BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(SlavedPresenceStore, self).__init__(db_conn, hs)
- self._presence_id_gen = SlavedIdTracker(
- db_conn, "presence_stream", "stream_id",
- )
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedPresenceStore, self).__init__(database, db_conn, hs)
+ self._presence_id_gen = SlavedIdTracker(db_conn, "presence_stream", "stream_id")
self._presence_on_startup = self._get_active_presence(db_conn)
- self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache(
+ self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
)
@@ -55,9 +54,7 @@ class SlavedPresenceStore(BaseSlavedStore):
if stream_name == "presence":
self._presence_id_gen.advance(token)
for row in rows:
- self.presence_stream_cache.entity_has_changed(
- row.user_id, token
- )
+ self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
return super(SlavedPresenceStore, self).process_replication_rows(
stream_name, token, rows
diff --git a/synapse/replication/slave/storage/profile.py b/synapse/replication/slave/storage/profile.py
index 46c28d4171..28c508aad3 100644
--- a/synapse/replication/slave/storage/profile.py
+++ b/synapse/replication/slave/storage/profile.py
@@ -14,7 +14,7 @@
# limitations under the License.
from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.storage.profile import ProfileWorkerStore
+from synapse.storage.data_stores.main.profile import ProfileWorkerStore
class SlavedProfileStore(ProfileWorkerStore, BaseSlavedStore):
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 45fc913c52..eebd5a1fb6 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -14,18 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.push_rule import PushRulesWorkerStore
+from synapse.storage.data_stores.main.push_rule import PushRulesWorkerStore
+from synapse.storage.database import Database
from ._slaved_id_tracker import SlavedIdTracker
from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self._push_rules_stream_id_gen = SlavedIdTracker(
- db_conn, "push_rules_stream", "stream_id",
+ db_conn, "push_rules_stream", "stream_id"
)
- super(SlavedPushRuleStore, self).__init__(db_conn, hs)
+ super(SlavedPushRuleStore, self).__init__(database, db_conn, hs)
def get_push_rules_stream_token(self):
return (
@@ -47,9 +48,7 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
for row in rows:
self.get_push_rules_for_user.invalidate((row.user_id,))
self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
- self.push_rules_stream_cache.entity_has_changed(
- row.user_id, token
- )
+ self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedPushRuleStore, self).process_replication_rows(
stream_name, token, rows
)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 3b2213c0d4..f22c2d44a3 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -14,19 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.pusher import PusherWorkerStore
+from synapse.storage.data_stores.main.pusher import PusherWorkerStore
+from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-
- def __init__(self, db_conn, hs):
- super(SlavedPusherStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SlavedPusherStore, self).__init__(database, db_conn, hs)
self._pushers_id_gen = SlavedIdTracker(
- db_conn, "pushers", "id",
- extra_tables=[("deleted_pushers", "stream_id")],
+ db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
def stream_positions(self):
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ed12342f40..d40dc6e1f5 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -14,7 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.receipts import ReceiptsWorkerStore
+from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
+from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
@@ -29,15 +30,14 @@ from ._slaved_id_tracker import SlavedIdTracker
class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
-
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
- super(SlavedReceiptsStore, self).__init__(db_conn, hs)
+ super(SlavedReceiptsStore, self).__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/registration.py b/synapse/replication/slave/storage/registration.py
index 408d91df1c..4b8553e250 100644
--- a/synapse/replication/slave/storage/registration.py
+++ b/synapse/replication/slave/storage/registration.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.registration import RegistrationWorkerStore
+from synapse.storage.data_stores.main.registration import RegistrationWorkerStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 0cb474928c..3a20f45316 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.room import RoomWorkerStore
+from synapse.storage.data_stores.main.room import RoomWorkerStore
+from synapse.storage.database import Database
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
class RoomStore(RoomWorkerStore, BaseSlavedStore):
- def __init__(self, db_conn, hs):
- super(RoomStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomStore, self).__init__(database, db_conn, hs)
self._public_room_id_gen = SlavedIdTracker(
db_conn, "public_room_list_stream", "stream_id"
)
@@ -38,6 +39,4 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
if stream_name == "public_rooms":
self._public_room_id_gen.advance(token)
- return super(RoomStore, self).process_replication_rows(
- stream_name, token, rows
- )
+ return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
diff --git a/synapse/replication/slave/storage/transactions.py b/synapse/replication/slave/storage/transactions.py
index 3527beb3c9..ac88e6b8c3 100644
--- a/synapse/replication/slave/storage/transactions.py
+++ b/synapse/replication/slave/storage/transactions.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.transactions import TransactionStore
+from synapse.storage.data_stores.main.transactions import TransactionStore
from ._base import BaseSlavedStore
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 206dc3b397..02ab5b66ea 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,18 +16,26 @@
"""
import logging
+from typing import Dict, List, Optional
from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory
+from synapse.replication.slave.storage._base import BaseSlavedStore
+from synapse.replication.tcp.protocol import (
+ AbstractReplicationClientHandler,
+ ClientReplicationStreamProtocol,
+)
+
from .commands import (
+ Command,
FederationAckCommand,
InvalidateCacheCommand,
+ RemoteServerUpCommand,
RemovePusherCommand,
UserIpCommand,
UserSyncCommand,
)
-from .protocol import ClientReplicationStreamProtocol
logger = logging.getLogger(__name__)
@@ -39,9 +47,11 @@ class ReplicationClientFactory(ReconnectingClientFactory):
Accepts a handler that will be called when new data is available or data
is required.
"""
- maxDelay = 30 # Try at least once every N seconds
- def __init__(self, hs, client_name, handler):
+ initialDelay = 0.1
+ maxDelay = 1 # Try at least once every N seconds
+
+ def __init__(self, hs, client_name, handler: AbstractReplicationClientHandler):
self.client_name = client_name
self.handler = handler
self.server_name = hs.config.server_name
@@ -64,17 +74,16 @@ class ReplicationClientFactory(ReconnectingClientFactory):
def clientConnectionFailed(self, connector, reason):
logger.error("Failed to connect to replication: %r", reason)
- ReconnectingClientFactory.clientConnectionFailed(
- self, connector, reason
- )
+ ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
-class ReplicationClientHandler(object):
+class ReplicationClientHandler(AbstractReplicationClientHandler):
"""A base handler that can be passed to the ReplicationClientFactory.
By default proxies incoming replication data to the SlaveStore.
"""
- def __init__(self, store):
+
+ def __init__(self, store: BaseSlavedStore):
self.store = store
# The current connection. None if we are currently (re)connecting
@@ -82,15 +91,15 @@ class ReplicationClientHandler(object):
# Any pending commands to be sent once a new connection has been
# established
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiveing a SYNC with
# the given string.
# Used for tests.
- self.awaiting_syncs = {}
+ self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections.
- self.factory = None
+ self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
@@ -102,7 +111,7 @@ class ReplicationClientHandler(object):
port = hs.config.worker_replication_port
hs.get_reactor().connectTCP(host, port, self.factory)
- def on_rdata(self, stream_name, token, rows):
+ async def on_rdata(self, stream_name, token, rows):
"""Called to handle a batch of replication data with a given stream token.
By default this just pokes the slave store. Can be overridden in subclasses to
@@ -113,20 +122,17 @@ class ReplicationClientHandler(object):
token (int): stream token for this batch of rows
rows (list): a list of Stream.ROW_TYPE objects as returned by
Stream.parse_row.
-
- Returns:
- Deferred|None
"""
logger.debug("Received rdata %s -> %s", stream_name, token)
- return self.store.process_replication_rows(stream_name, token, rows)
+ self.store.process_replication_rows(stream_name, token, rows)
- def on_position(self, stream_name, token):
+ async def on_position(self, stream_name, token):
"""Called when we get new position data. By default this just pokes
the slave store.
Can be overriden in subclasses to handle more.
"""
- return self.store.process_replication_rows(stream_name, token, [])
+ self.store.process_replication_rows(stream_name, token, [])
def on_sync(self, data):
"""When we received a SYNC we wake up any deferreds that were waiting
@@ -138,11 +144,16 @@ class ReplicationClientHandler(object):
if d:
d.callback(data)
- def get_streams_to_replicate(self):
+ def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+
+ def get_streams_to_replicate(self) -> Dict[str, int]:
"""Called when a new connection has been established and we need to
subscribe to streams.
- Returns a dictionary of stream name to token.
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
"""
args = self.store.stream_positions()
user_account_data = args.pop("user_account_data", None)
@@ -168,7 +179,7 @@ class ReplicationClientHandler(object):
if self.connection:
self.connection.send_command(cmd)
else:
- logger.warn("Queuing command as not connected: %r", cmd.NAME)
+ logger.warning("Queuing command as not connected: %r", cmd.NAME)
self.pending_commands.append(cmd)
def send_federation_ack(self, token):
@@ -200,6 +211,9 @@ class ReplicationClientHandler(object):
cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
self.send_command(cmd)
+ def send_remote_server_up(self, server: str):
+ self.send_command(RemoteServerUpCommand(server))
+
def await_sync(self, data):
"""Returns a deferred that is resolved when we receive a SYNC command
with given data.
@@ -226,4 +240,5 @@ class ReplicationClientHandler(object):
# We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the
# server.
- self.factory.resetDelay()
+ if self.factory:
+ self.factory.resetDelay()
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 2098c32a77..451671412d 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -20,13 +20,16 @@ allowed to be sent by which side.
import logging
import platform
+from typing import Tuple, Type
if platform.python_implementation() == "PyPy":
import json
+
_json_encoder = json.JSONEncoder()
else:
- import simplejson as json
- _json_encoder = json.JSONEncoder(namedtuple_as_object=False)
+ import simplejson as json # type: ignore[no-redef] # noqa: F821
+
+ _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
logger = logging.getLogger(__name__)
@@ -41,7 +44,8 @@ class Command(object):
The default implementation creates a command of form `<NAME> <data>`
"""
- NAME = None
+
+ NAME = None # type: str
def __init__(self, data):
self.data = data
@@ -73,6 +77,7 @@ class ServerCommand(Command):
SERVER <server_name>
"""
+
NAME = "SERVER"
@@ -99,6 +104,7 @@ class RdataCommand(Command):
RDATA presence batch ["@bar:example.com", "online", ...]
RDATA presence 59 ["@baz:example.com", "online", ...]
"""
+
NAME = "RDATA"
def __init__(self, stream_name, token, row):
@@ -110,17 +116,17 @@ class RdataCommand(Command):
def from_line(cls, line):
stream_name, token, row_json = line.split(" ", 2)
return cls(
- stream_name,
- None if token == "batch" else int(token),
- json.loads(row_json)
+ stream_name, None if token == "batch" else int(token), json.loads(row_json)
)
def to_line(self):
- return " ".join((
- self.stream_name,
- str(self.token) if self.token is not None else "batch",
- _json_encoder.encode(self.row),
- ))
+ return " ".join(
+ (
+ self.stream_name,
+ str(self.token) if self.token is not None else "batch",
+ _json_encoder.encode(self.row),
+ )
+ )
def get_logcontext_id(self):
return "RDATA-" + self.stream_name
@@ -133,6 +139,7 @@ class PositionCommand(Command):
Sent to the client after all missing updates for a stream have been sent
to the client and they're now up to date.
"""
+
NAME = "POSITION"
def __init__(self, stream_name, token):
@@ -145,19 +152,21 @@ class PositionCommand(Command):
return cls(stream_name, int(token))
def to_line(self):
- return " ".join((self.stream_name, str(self.token),))
+ return " ".join((self.stream_name, str(self.token)))
class ErrorCommand(Command):
"""Sent by either side if there was an ERROR. The data is a string describing
the error.
"""
+
NAME = "ERROR"
class PingCommand(Command):
"""Sent by either side as a keep alive. The data is arbitary (often timestamp)
"""
+
NAME = "PING"
@@ -165,6 +174,7 @@ class NameCommand(Command):
"""Sent by client to inform the server of the client's identity. The data
is the name
"""
+
NAME = "NAME"
@@ -184,6 +194,7 @@ class ReplicateCommand(Command):
REPLICATE ALL NOW
"""
+
NAME = "REPLICATE"
def __init__(self, stream_name, token):
@@ -200,7 +211,7 @@ class ReplicateCommand(Command):
return cls(stream_name, token)
def to_line(self):
- return " ".join((self.stream_name, str(self.token),))
+ return " ".join((self.stream_name, str(self.token)))
def get_logcontext_id(self):
return "REPLICATE-" + self.stream_name
@@ -218,6 +229,7 @@ class UserSyncCommand(Command):
Where <state> is either "start" or "stop"
"""
+
NAME = "USER_SYNC"
def __init__(self, user_id, is_syncing, last_sync_ms):
@@ -235,9 +247,13 @@ class UserSyncCommand(Command):
return cls(user_id, state == "start", int(last_sync_ms))
def to_line(self):
- return " ".join((
- self.user_id, "start" if self.is_syncing else "end", str(self.last_sync_ms),
- ))
+ return " ".join(
+ (
+ self.user_id,
+ "start" if self.is_syncing else "end",
+ str(self.last_sync_ms),
+ )
+ )
class FederationAckCommand(Command):
@@ -251,6 +267,7 @@ class FederationAckCommand(Command):
FEDERATION_ACK <token>
"""
+
NAME = "FEDERATION_ACK"
def __init__(self, token):
@@ -268,6 +285,7 @@ class SyncCommand(Command):
"""Used for testing. The client protocol implementation allows waiting
on a SYNC command with a specified data.
"""
+
NAME = "SYNC"
@@ -278,6 +296,7 @@ class RemovePusherCommand(Command):
REMOVE_PUSHER <app_id> <push_key> <user_id>
"""
+
NAME = "REMOVE_PUSHER"
def __init__(self, app_id, push_key, user_id):
@@ -309,6 +328,7 @@ class InvalidateCacheCommand(Command):
Where <keys_json> is a json list.
"""
+
NAME = "INVALIDATE_CACHE"
def __init__(self, cache_func, keys):
@@ -322,9 +342,7 @@ class InvalidateCacheCommand(Command):
return cls(cache_func, json.loads(keys_json))
def to_line(self):
- return " ".join((
- self.cache_func, _json_encoder.encode(self.keys),
- ))
+ return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
class UserIpCommand(Command):
@@ -334,6 +352,7 @@ class UserIpCommand(Command):
USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
"""
+
NAME = "USER_IP"
def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
@@ -350,36 +369,57 @@ class UserIpCommand(Command):
access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
- return cls(
- user_id, access_token, ip, user_agent, device_id, last_seen
- )
+ return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
def to_line(self):
- return self.user_id + " " + _json_encoder.encode((
- self.access_token, self.ip, self.user_agent, self.device_id,
- self.last_seen,
- ))
+ return (
+ self.user_id
+ + " "
+ + _json_encoder.encode(
+ (
+ self.access_token,
+ self.ip,
+ self.user_agent,
+ self.device_id,
+ self.last_seen,
+ )
+ )
+ )
+
+
+class RemoteServerUpCommand(Command):
+ """Sent when a worker has detected that a remote server is no longer
+ "down" and retry timings should be reset.
+
+ If sent from a client the server will relay to all other workers.
+
+ Format::
+
+ REMOTE_SERVER_UP <server>
+ """
+ NAME = "REMOTE_SERVER_UP"
+
+
+_COMMANDS = (
+ ServerCommand,
+ RdataCommand,
+ PositionCommand,
+ ErrorCommand,
+ PingCommand,
+ NameCommand,
+ ReplicateCommand,
+ UserSyncCommand,
+ FederationAckCommand,
+ SyncCommand,
+ RemovePusherCommand,
+ InvalidateCacheCommand,
+ UserIpCommand,
+ RemoteServerUpCommand,
+) # type: Tuple[Type[Command], ...]
# Map of command name to command type.
-COMMAND_MAP = {
- cmd.NAME: cmd
- for cmd in (
- ServerCommand,
- RdataCommand,
- PositionCommand,
- ErrorCommand,
- PingCommand,
- NameCommand,
- ReplicateCommand,
- UserSyncCommand,
- FederationAckCommand,
- SyncCommand,
- RemovePusherCommand,
- InvalidateCacheCommand,
- UserIpCommand,
- )
-}
+COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
# The commands the server is allowed to send
VALID_SERVER_COMMANDS = (
@@ -389,6 +429,7 @@ VALID_SERVER_COMMANDS = (
ErrorCommand.NAME,
PingCommand.NAME,
SyncCommand.NAME,
+ RemoteServerUpCommand.NAME,
)
# The commands the client is allowed to send
@@ -402,4 +443,5 @@ VALID_CLIENT_COMMANDS = (
InvalidateCacheCommand.NAME,
UserIpCommand.NAME,
ErrorCommand.NAME,
+ RemoteServerUpCommand.NAME,
)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index b51590cf8f..d185cc0c8f 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -48,11 +48,12 @@ indicate which side is sending, these are *not* included on the wire::
> ERROR server stopping
* connection closed by server *
"""
-
+import abc
import fcntl
import logging
import struct
from collections import defaultdict
+from typing import Any, DefaultDict, Dict, List, Set, Tuple
from six import iteritems, iterkeys
@@ -62,29 +63,33 @@ from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
-from synapse.util.stringutils import random_string
-
-from .commands import (
+from synapse.replication.tcp.commands import (
COMMAND_MAP,
VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS,
+ Command,
ErrorCommand,
NameCommand,
PingCommand,
PositionCommand,
RdataCommand,
+ RemoteServerUpCommand,
ReplicateCommand,
ServerCommand,
SyncCommand,
UserSyncCommand,
)
-from .streams import STREAMS_MAP
+from synapse.replication.tcp.streams import STREAMS_MAP
+from synapse.types import Collection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
connection_close_counter = Counter(
- "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
+ "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"]
+)
# A list of all connected protocols. This allows us to send metrics about the
# connections.
@@ -119,10 +124,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
It also sends `PING` periodically, and correctly times out remote connections
(if they send a `PING` command)
"""
- delimiter = b'\n'
- VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive
- VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send
+ delimiter = b"\n"
+
+ # Valid commands we expect to receive
+ VALID_INBOUND_COMMANDS = [] # type: Collection[str]
+
+ # Valid commands we can send
+ VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000
@@ -141,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.conn_id = random_string(5) # To dedupe in case of name clashes.
# List of pending commands to send once we've established the connection
- self.pending_commands = []
+ self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = defaultdict(int)
- self.outbound_commands_counter = defaultdict(int)
+ self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
+ self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -183,10 +192,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if now - self.last_sent_command >= PING_TIME:
self.send_command(PingCommand(now))
- if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS:
+ if (
+ self.received_ping
+ and now - self.last_received_command > PING_TIMEOUT_MS
+ ):
logger.info(
"[%s] Connection hasn't received command in %r ms. Closing.",
- self.id(), now - self.last_received_command
+ self.id(),
+ now - self.last_received_command,
)
self.send_error("ping timeout")
@@ -208,7 +221,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
self.inbound_commands_counter[cmd_name] = (
- self.inbound_commands_counter[cmd_name] + 1)
+ self.inbound_commands_counter[cmd_name] + 1
+ )
cmd_cls = COMMAND_MAP[cmd_name]
try:
@@ -224,27 +238,22 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# Now lets try and call on_<CMD_NAME> function
run_as_background_process(
- "replication-" + cmd.get_logcontext_id(),
- self.handle_command,
- cmd,
+ "replication-" + cmd.get_logcontext_id(), self.handle_command, cmd
)
- def handle_command(self, cmd):
+ async def handle_command(self, cmd: Command):
"""Handle a command we have received over the replication stream.
- By default delegates to on_<COMMAND>
+ By default delegates to on_<COMMAND>, which should return an awaitable.
Args:
- cmd (synapse.replication.tcp.commands.Command): received command
-
- Returns:
- Deferred
+ cmd: received command
"""
handler = getattr(self, "on_%s" % (cmd.NAME,))
- return handler(cmd)
+ await handler(cmd)
def close(self):
- logger.warn("[%s] Closing connection", self.id())
+ logger.warning("[%s] Closing connection", self.id())
self.time_we_closed = self.clock.time_msec()
self.transport.loseConnection()
self.on_connection_closed()
@@ -274,8 +283,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
return
self.outbound_commands_counter[cmd.NAME] = (
- self.outbound_commands_counter[cmd.NAME] + 1)
- string = "%s %s" % (cmd.NAME, cmd.to_line(),)
+ self.outbound_commands_counter[cmd.NAME] + 1
+ )
+ string = "%s %s" % (cmd.NAME, cmd.to_line())
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -283,10 +293,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if len(encoded_string) > self.MAX_LENGTH:
raise Exception(
- "Failed to send command %s as too long (%d > %d)" % (
- cmd.NAME,
- len(encoded_string), self.MAX_LENGTH,
- )
+ "Failed to send command %s as too long (%d > %d)"
+ % (cmd.NAME, len(encoded_string), self.MAX_LENGTH)
)
self.sendLine(encoded_string)
@@ -315,10 +323,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
for cmd in pending:
self.send_command(cmd)
- def on_PING(self, line):
+ async def on_PING(self, line):
self.received_ping = True
- def on_ERROR(self, cmd):
+ async def on_ERROR(self, cmd):
logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)
def pauseProducing(self):
@@ -379,7 +387,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
if self.transport:
addr = str(self.transport.getPeer())
return "ReplicationConnection<name=%s,conn_id=%s,addr=%s>" % (
- self.name, self.conn_id, addr,
+ self.name,
+ self.conn_id,
+ addr,
)
def id(self):
@@ -402,68 +412,69 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer = streamer
# The streams the client has subscribed to and is up to date with
- self.replication_streams = set()
+ self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to.
- self.connecting_streams = set()
+ self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream.
- self.pending_rdata = {}
+ self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self):
self.send_command(ServerCommand(self.server_name))
BaseReplicationStreamProtocol.connectionMade(self)
self.streamer.new_connection(self)
- def on_NAME(self, cmd):
+ async def on_NAME(self, cmd):
logger.info("[%s] Renamed to %r", self.id(), cmd.data)
self.name = cmd.data
- def on_USER_SYNC(self, cmd):
- return self.streamer.on_user_sync(
- self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms,
+ async def on_USER_SYNC(self, cmd):
+ await self.streamer.on_user_sync(
+ self.conn_id, cmd.user_id, cmd.is_syncing, cmd.last_sync_ms
)
- def on_REPLICATE(self, cmd):
+ async def on_REPLICATE(self, cmd):
stream_name = cmd.stream_name
token = cmd.token
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
deferreds = [
- run_in_background(
- self.subscribe_to_stream,
- stream, token,
- )
+ run_in_background(self.subscribe_to_stream, stream, token)
for stream in iterkeys(self.streamer.streams_by_name)
]
- return make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
else:
- return self.subscribe_to_stream(stream_name, token)
+ await self.subscribe_to_stream(stream_name, token)
- def on_FEDERATION_ACK(self, cmd):
- return self.streamer.federation_ack(cmd.token)
+ async def on_FEDERATION_ACK(self, cmd):
+ self.streamer.federation_ack(cmd.token)
- def on_REMOVE_PUSHER(self, cmd):
- return self.streamer.on_remove_pusher(
- cmd.app_id, cmd.push_key, cmd.user_id,
- )
+ async def on_REMOVE_PUSHER(self, cmd):
+ await self.streamer.on_remove_pusher(cmd.app_id, cmd.push_key, cmd.user_id)
- def on_INVALIDATE_CACHE(self, cmd):
- return self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
+ async def on_INVALIDATE_CACHE(self, cmd):
+ await self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
- def on_USER_IP(self, cmd):
- return self.streamer.on_user_ip(
- cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+ async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ self.streamer.on_remote_server_up(cmd.data)
+
+ async def on_USER_IP(self, cmd):
+ await self.streamer.on_user_ip(
+ cmd.user_id,
+ cmd.access_token,
+ cmd.ip,
+ cmd.user_agent,
+ cmd.device_id,
cmd.last_seen,
)
- @defer.inlineCallbacks
- def subscribe_to_stream(self, stream_name, token):
+ async def subscribe_to_stream(self, stream_name, token):
"""Subscribe the remote to a stream.
This invloves checking if they've missed anything and sending those
@@ -475,8 +486,8 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
try:
# Get missing updates
- updates, current_token = yield self.streamer.get_stream_updates(
- stream_name, token,
+ updates, current_token = await self.streamer.get_stream_updates(
+ stream_name, token
)
# Send all the missing updates
@@ -548,16 +559,90 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
def send_sync(self, data):
self.send_command(SyncCommand(data))
+ def send_remote_server_up(self, server: str):
+ self.send_command(RemoteServerUpCommand(server))
+
def on_connection_closed(self):
BaseReplicationStreamProtocol.on_connection_closed(self)
self.streamer.lost_connection(self)
+class AbstractReplicationClientHandler(metaclass=abc.ABCMeta):
+ """
+ The interface for the handler that should be passed to
+ ClientReplicationStreamProtocol
+ """
+
+ @abc.abstractmethod
+ async def on_rdata(self, stream_name, token, rows):
+ """Called to handle a batch of replication data with a given stream token.
+
+ Args:
+ stream_name (str): name of the replication stream for this batch of rows
+ token (int): stream token for this batch of rows
+ rows (list): a list of Stream.ROW_TYPE objects as returned by
+ Stream.parse_row.
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_sync(self, data):
+ """Called when get a new SYNC command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ async def on_remote_server_up(self, server: str):
+ """Called when get a new REMOTE_SERVER_UP command."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_streams_to_replicate(self):
+ """Called when a new connection has been established and we need to
+ subscribe to streams.
+
+ Returns:
+ map from stream name to the most recent update we have for
+ that stream (ie, the point we want to start replicating from)
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_currently_syncing_users(self):
+ """Get the list of currently syncing users (if any). This is called
+ when a connection has been established and we need to send the
+ currently syncing users."""
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def update_connection(self, connection):
+ """Called when a connection has been established (or lost with None).
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def finished_connecting(self):
+ """Called when we have successfully subscribed and caught up to all
+ streams we're interested in.
+ """
+ raise NotImplementedError()
+
+
class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
VALID_INBOUND_COMMANDS = VALID_SERVER_COMMANDS
VALID_OUTBOUND_COMMANDS = VALID_CLIENT_COMMANDS
- def __init__(self, client_name, server_name, clock, handler):
+ def __init__(
+ self,
+ client_name: str,
+ server_name: str,
+ clock: Clock,
+ handler: AbstractReplicationClientHandler,
+ ):
BaseReplicationStreamProtocol.__init__(self, clock)
self.client_name = client_name
@@ -567,11 +652,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully
# connected to the remote.
- self.streams_connecting = set()
+ self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how
# batching works.
- self.pending_batches = {}
+ self.pending_batches = {} # type: Dict[str, Any]
def connectionMade(self):
self.send_command(NameCommand(self.client_name))
@@ -595,12 +680,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
if not self.streams_connecting:
self.handler.finished_connecting()
- def on_SERVER(self, cmd):
+ async def on_SERVER(self, cmd):
if cmd.data != self.server_name:
logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
self.send_error("Wrong remote")
- def on_RDATA(self, cmd):
+ async def on_RDATA(self, cmd):
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
@@ -608,8 +693,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception(
- "[%s] Failed to parse RDATA: %r %r",
- self.id(), stream_name, cmd.row
+ "[%s] Failed to parse RDATA: %r %r", self.id(), stream_name, cmd.row
)
raise
@@ -621,19 +705,22 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Check if this is the last of a batch of updates
rows = self.pending_batches.pop(stream_name, [])
rows.append(row)
- return self.handler.on_rdata(stream_name, cmd.token, rows)
+ await self.handler.on_rdata(stream_name, cmd.token, rows)
- def on_POSITION(self, cmd):
+ async def on_POSITION(self, cmd):
# When we get a `POSITION` command it means we've finished getting
# missing updates for the given stream, and are now up to date.
self.streams_connecting.discard(cmd.stream_name)
if not self.streams_connecting:
self.handler.finished_connecting()
- return self.handler.on_position(cmd.stream_name, cmd.token)
+ await self.handler.on_position(cmd.stream_name, cmd.token)
- def on_SYNC(self, cmd):
- return self.handler.on_sync(cmd.data)
+ async def on_SYNC(self, cmd):
+ self.handler.on_sync(cmd.data)
+
+ async def on_REMOTE_SERVER_UP(self, cmd: RemoteServerUpCommand):
+ self.handler.on_remote_server_up(cmd.data)
def replicate(self, stream_name, token):
"""Send the subscription request to the server
@@ -643,7 +730,9 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
logger.info(
"[%s] Subscribing to replication stream: %r from %r",
- self.id(), stream_name, token
+ self.id(),
+ stream_name,
+ token,
)
self.streams_connecting.add(stream_name)
@@ -661,9 +750,7 @@ pending_commands = LaterGauge(
"synapse_replication_tcp_protocol_pending_commands",
"",
["name"],
- lambda: {
- (p.name,): len(p.pending_commands) for p in connected_connections
- },
+ lambda: {(p.name,): len(p.pending_commands) for p in connected_connections},
)
@@ -678,9 +765,7 @@ transport_send_buffer = LaterGauge(
"synapse_replication_tcp_protocol_transport_send_buffer",
"",
["name"],
- lambda: {
- (p.name,): transport_buffer_size(p) for p in connected_connections
- },
+ lambda: {(p.name,): transport_buffer_size(p) for p in connected_connections},
)
@@ -694,7 +779,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
op = SIOCINQ
else:
op = SIOCOUTQ
- size = struct.unpack("I", fcntl.ioctl(fileno, op, '\0\0\0\0'))[0]
+ size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
return size
return 0
@@ -726,7 +811,7 @@ tcp_inbound_commands = LaterGauge(
"",
["command", "name"],
lambda: {
- (k, p.name,): count
+ (k, p.name): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
},
@@ -737,7 +822,7 @@ tcp_outbound_commands = LaterGauge(
"",
["command", "name"],
lambda: {
- (k, p.name,): count
+ (k, p.name): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
},
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index f6a38f5140..ce9d1fae12 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -17,12 +17,12 @@
import logging
import random
+from typing import Any, List
from six import itervalues
from prometheus_client import Counter
-from twisted.internet import defer
from twisted.internet.protocol import Factory
from synapse.metrics import LaterGauge
@@ -33,13 +33,15 @@ from .protocol import ServerReplicationStreamProtocol
from .streams import STREAMS_MAP
from .streams.federation import FederationStream
-stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
- "", ["stream_name"])
+stream_updates_counter = Counter(
+ "synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
+)
user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
-invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
- "")
+invalidate_cache_counter = Counter(
+ "synapse_replication_tcp_resource_invalidate_cache", ""
+)
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -48,6 +50,7 @@ logger = logging.getLogger(__name__)
class ReplicationStreamProtocolFactory(Factory):
"""Factory for new replication connections.
"""
+
def __init__(self, hs):
self.streamer = ReplicationStreamer(hs)
self.clock = hs.get_clock()
@@ -55,9 +58,7 @@ class ReplicationStreamProtocolFactory(Factory):
def buildProtocol(self, addr):
return ServerReplicationStreamProtocol(
- self.server_name,
- self.clock,
- self.streamer,
+ self.server_name, self.clock, self.streamer
)
@@ -78,37 +79,48 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level
# Current connections.
- self.connections = []
+ self.connections = [] # type: List[ServerReplicationStreamProtocol]
- LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
- lambda: len(self.connections))
+ LaterGauge(
+ "synapse_replication_tcp_resource_total_connections",
+ "",
+ [],
+ lambda: len(self.connections),
+ )
# List of streams that clients can subscribe to.
# We only support federation stream if federation sending hase been
# disabled on the master.
self.streams = [
- stream(hs) for stream in itervalues(STREAMS_MAP)
+ stream(hs)
+ for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
LaterGauge(
- "synapse_replication_tcp_resource_connections_per_stream", "",
+ "synapse_replication_tcp_resource_connections_per_stream",
+ "",
["stream_name"],
lambda: {
- (stream_name,): len([
- conn for conn in self.connections
- if stream_name in conn.replication_streams
- ])
+ (stream_name,): len(
+ [
+ conn
+ for conn in self.connections
+ if stream_name in conn.replication_streams
+ ]
+ )
for stream_name in self.streams_by_name
- })
+ },
+ )
self.federation_sender = None
if not hs.config.send_federation:
self.federation_sender = hs.get_federation_sender()
self.notifier.add_replication_callback(self.on_notifier_poke)
+ self.notifier.add_remote_server_up_callback(self.send_remote_server_up)
# Keeps track of whether we are currently checking for updates
self.is_looping = False
@@ -143,8 +155,7 @@ class ReplicationStreamer(object):
run_as_background_process("replication_notifier", self._run_notifier_loop)
- @defer.inlineCallbacks
- def _run_notifier_loop(self):
+ async def _run_notifier_loop(self):
self.is_looping = True
try:
@@ -173,23 +184,26 @@ class ReplicationStreamer(object):
continue
if self._replication_torture_level:
- yield self.clock.sleep(
+ await self.clock.sleep(
self._replication_torture_level / 1000.0
)
logger.debug(
"Getting stream: %s: %s -> %s",
- stream.NAME, stream.last_token, stream.upto_token
+ stream.NAME,
+ stream.last_token,
+ stream.upto_token,
)
try:
- updates, current_token = yield stream.get_updates()
+ updates, current_token = await stream.get_updates()
except Exception:
logger.info("Failed to handle stream %s", stream.NAME)
raise
logger.debug(
"Sending %d updates to %d connections",
- len(updates), len(self.connections),
+ len(updates),
+ len(self.connections),
)
if updates:
@@ -218,7 +232,7 @@ class ReplicationStreamer(object):
self.is_looping = False
@measure_func("repl.get_stream_updates")
- def get_stream_updates(self, stream_name, token):
+ async def get_stream_updates(self, stream_name, token):
"""For a given stream get all updates since token. This is called when
a client first subscribes to a stream.
"""
@@ -226,7 +240,7 @@ class ReplicationStreamer(object):
if not stream:
raise Exception("unknown stream %s", stream_name)
- return stream.get_updates_since(token)
+ return await stream.get_updates_since(token)
@measure_func("repl.federation_ack")
def federation_ack(self, token):
@@ -237,44 +251,54 @@ class ReplicationStreamer(object):
self.federation_sender.federation_ack(token)
@measure_func("repl.on_user_sync")
- @defer.inlineCallbacks
- def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
+ async def on_user_sync(self, conn_id, user_id, is_syncing, last_sync_ms):
"""A client has started/stopped syncing on a worker.
"""
user_sync_counter.inc()
- yield self.presence_handler.update_external_syncs_row(
- conn_id, user_id, is_syncing, last_sync_ms,
+ await self.presence_handler.update_external_syncs_row(
+ conn_id, user_id, is_syncing, last_sync_ms
)
@measure_func("repl.on_remove_pusher")
- @defer.inlineCallbacks
- def on_remove_pusher(self, app_id, push_key, user_id):
+ async def on_remove_pusher(self, app_id, push_key, user_id):
"""A client has asked us to remove a pusher
"""
remove_pusher_counter.inc()
- yield self.store.delete_pusher_by_app_id_pushkey_user_id(
+ await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id=app_id, pushkey=push_key, user_id=user_id
)
self.notifier.on_new_replication_data()
@measure_func("repl.on_invalidate_cache")
- def on_invalidate_cache(self, cache_func, keys):
+ async def on_invalidate_cache(self, cache_func: str, keys: List[Any]):
"""The client has asked us to invalidate a cache
"""
invalidate_cache_counter.inc()
- getattr(self.store, cache_func).invalidate(tuple(keys))
+
+ # We invalidate the cache locally, but then also stream that to other
+ # workers.
+ await self.store.invalidate_cache_and_stream(cache_func, tuple(keys))
@measure_func("repl.on_user_ip")
- @defer.inlineCallbacks
- def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+ async def on_user_ip(
+ self, user_id, access_token, ip, user_agent, device_id, last_seen
+ ):
"""The client saw a user request
"""
user_ip_cache_counter.inc()
- yield self.store.insert_client_ip(
- user_id, access_token, ip, user_agent, device_id, last_seen,
+ await self.store.insert_client_ip(
+ user_id, access_token, ip, user_agent, device_id, last_seen
)
- yield self._server_notices_sender.on_user_ip(user_id)
+ await self._server_notices_sender.on_user_ip(user_id)
+
+ @measure_func("repl.on_remote_server_up")
+ def on_remote_server_up(self, server: str):
+ self.notifier.notify_remote_server_up(server)
+
+ def send_remote_server_up(self, server: str):
+ for conn in self.connections:
+ conn.send_remote_server_up(server)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
@@ -299,7 +323,11 @@ class ReplicationStreamer(object):
# We need to tell the presence handler that the connection has been
# lost so that it can handle any ongoing syncs on that connection.
- self.presence_handler.update_external_syncs_clear(connection.conn_id)
+ run_as_background_process(
+ "update_external_syncs_clear",
+ self.presence_handler.update_external_syncs_clear,
+ connection.conn_id,
+ )
def _batch_updates(updates):
diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py
index 634f636dc9..5f52264e84 100644
--- a/synapse/replication/tcp/streams/__init__.py
+++ b/synapse/replication/tcp/streams/__init__.py
@@ -45,5 +45,6 @@ STREAMS_MAP = {
_base.TagAccountDataStream,
_base.AccountDataStream,
_base.GroupServerStream,
+ _base.UserSignatureStream,
)
}
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index b6ce7a7bee..208e8a667b 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,90 +14,101 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
import itertools
import logging
from collections import namedtuple
+from typing import Any, List, Optional
-from twisted.internet import defer
+import attr
logger = logging.getLogger(__name__)
-MAX_EVENTS_BEHIND = 10000
-
-BackfillStreamRow = namedtuple("BackfillStreamRow", (
- "event_id", # str
- "room_id", # str
- "type", # str
- "state_key", # str, optional
- "redacts", # str, optional
- "relates_to", # str, optional
-))
-PresenceStreamRow = namedtuple("PresenceStreamRow", (
- "user_id", # str
- "state", # str
- "last_active_ts", # int
- "last_federation_update_ts", # int
- "last_user_sync_ts", # int
- "status_msg", # str
- "currently_active", # bool
-))
-TypingStreamRow = namedtuple("TypingStreamRow", (
- "room_id", # str
- "user_ids", # list(str)
-))
-ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", (
- "room_id", # str
- "receipt_type", # str
- "user_id", # str
- "event_id", # str
- "data", # dict
-))
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", (
- "user_id", # str
-))
-PushersStreamRow = namedtuple("PushersStreamRow", (
- "user_id", # str
- "app_id", # str
- "pushkey", # str
- "deleted", # bool
-))
-CachesStreamRow = namedtuple("CachesStreamRow", (
- "cache_func", # str
- "keys", # list(str)
- "invalidation_ts", # int
-))
-PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", (
- "room_id", # str
- "visibility", # str
- "appservice_id", # str, optional
- "network_id", # str, optional
-))
-DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
- "user_id", # str
- "destination", # str
-))
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
- "entity", # str
-))
-TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
- "user_id", # str
- "room_id", # str
- "data", # dict
-))
-AccountDataStreamRow = namedtuple("AccountDataStream", (
- "user_id", # str
- "room_id", # str
- "data_type", # str
- "data", # dict
-))
-GroupsStreamRow = namedtuple("GroupsStreamRow", (
- "group_id", # str
- "user_id", # str
- "type", # str
- "content", # dict
-))
+MAX_EVENTS_BEHIND = 500000
+
+BackfillStreamRow = namedtuple(
+ "BackfillStreamRow",
+ (
+ "event_id", # str
+ "room_id", # str
+ "type", # str
+ "state_key", # str, optional
+ "redacts", # str, optional
+ "relates_to", # str, optional
+ ),
+)
+PresenceStreamRow = namedtuple(
+ "PresenceStreamRow",
+ (
+ "user_id", # str
+ "state", # str
+ "last_active_ts", # int
+ "last_federation_update_ts", # int
+ "last_user_sync_ts", # int
+ "status_msg", # str
+ "currently_active", # bool
+ ),
+)
+TypingStreamRow = namedtuple(
+ "TypingStreamRow", ("room_id", "user_ids") # str # list(str)
+)
+ReceiptsStreamRow = namedtuple(
+ "ReceiptsStreamRow",
+ (
+ "room_id", # str
+ "receipt_type", # str
+ "user_id", # str
+ "event_id", # str
+ "data", # dict
+ ),
+)
+PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",)) # str
+PushersStreamRow = namedtuple(
+ "PushersStreamRow",
+ ("user_id", "app_id", "pushkey", "deleted"), # str # str # str # bool
+)
+
+
+@attr.s
+class CachesStreamRow:
+ """Stream to inform workers they should invalidate their cache.
+
+ Attributes:
+ cache_func: Name of the cached function.
+ keys: The entry in the cache to invalidate. If None then will
+ invalidate all.
+ invalidation_ts: Timestamp of when the invalidation took place.
+ """
+
+ cache_func = attr.ib(type=str)
+ keys = attr.ib(type=Optional[List[Any]])
+ invalidation_ts = attr.ib(type=int)
+
+
+PublicRoomsStreamRow = namedtuple(
+ "PublicRoomsStreamRow",
+ (
+ "room_id", # str
+ "visibility", # str
+ "appservice_id", # str, optional
+ "network_id", # str, optional
+ ),
+)
+DeviceListsStreamRow = namedtuple(
+ "DeviceListsStreamRow", ("user_id", "destination") # str # str
+)
+ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
+TagAccountDataStreamRow = namedtuple(
+ "TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
+)
+AccountDataStreamRow = namedtuple(
+ "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
+)
+GroupsStreamRow = namedtuple(
+ "GroupsStreamRow",
+ ("group_id", "user_id", "type", "content"), # str # str # str # dict
+)
+UserSignatureStreamRow = namedtuple("UserSignatureStreamRow", ("user_id")) # str
class Stream(object):
@@ -106,8 +117,10 @@ class Stream(object):
Provides a `get_updates()` function that returns new updates since the last
time it was called up until the point `advance_current_token` was called.
"""
- NAME = None # The name of the stream
- ROW_TYPE = None # The type of the row. Used by the default impl of parse_row.
+
+ NAME = None # type: str # The name of the stream
+ # The type of the row. Used by the default impl of parse_row.
+ ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit
@classmethod
@@ -145,8 +158,7 @@ class Stream(object):
self.upto_token = self.current_token()
self.last_token = self.upto_token
- @defer.inlineCallbacks
- def get_updates(self):
+ async def get_updates(self):
"""Gets all updates since the last time this function was called (or
since the stream was constructed if it hadn't been called before),
until the `upto_token`
@@ -157,13 +169,12 @@ class Stream(object):
list of ``(token, row)`` entries. ``row`` will be json-serialised and
sent over the replication steam.
"""
- updates, current_token = yield self.get_updates_since(self.last_token)
+ updates, current_token = await self.get_updates_since(self.last_token)
self.last_token = current_token
- defer.returnValue((updates, current_token))
+ return updates, current_token
- @defer.inlineCallbacks
- def get_updates_since(self, from_token):
+ async def get_updates_since(self, from_token):
"""Like get_updates except allows specifying from when we should
stream updates
@@ -174,27 +185,25 @@ class Stream(object):
sent over the replication steam.
"""
if from_token in ("NOW", "now"):
- defer.returnValue(([], self.upto_token))
+ return [], self.upto_token
current_token = self.upto_token
from_token = int(from_token)
if from_token == current_token:
- defer.returnValue(([], current_token))
+ return [], current_token
+ logger.info("get_updates_since: %s", self.__class__)
if self._LIMITED:
- rows = yield self.update_function(
- from_token, current_token,
- limit=MAX_EVENTS_BEHIND + 1,
+ rows = await self.update_function(
+ from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
)
# never turn more than MAX_EVENTS_BEHIND + 1 into updates.
rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
else:
- rows = yield self.update_function(
- from_token, current_token,
- )
+ rows = await self.update_function(from_token, current_token)
updates = [(row[0], row[1:]) for row in rows]
@@ -203,7 +212,7 @@ class Stream(object):
if self._LIMITED and len(updates) >= MAX_EVENTS_BEHIND:
raise Exception("stream %s has fallen behind" % (self.NAME))
- defer.returnValue((updates, current_token))
+ return updates, current_token
def current_token(self):
"""Gets the current token of the underlying streams. Should be provided
@@ -230,13 +239,14 @@ class BackfillStream(Stream):
"""We fetched some old events and either we had never seen that event before
or it went from being an outlier to not.
"""
+
NAME = "backfill"
ROW_TYPE = BackfillStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_backfill_token
- self.update_function = store.get_all_new_backfill_event_rows
+ self.current_token = store.get_current_backfill_token # type: ignore
+ self.update_function = store.get_all_new_backfill_event_rows # type: ignore
super(BackfillStream, self).__init__(hs)
@@ -250,8 +260,8 @@ class PresenceStream(Stream):
store = hs.get_datastore()
presence_handler = hs.get_presence_handler()
- self.current_token = store.get_current_presence_token
- self.update_function = presence_handler.get_all_presence_updates
+ self.current_token = store.get_current_presence_token # type: ignore
+ self.update_function = presence_handler.get_all_presence_updates # type: ignore
super(PresenceStream, self).__init__(hs)
@@ -264,8 +274,8 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
- self.current_token = typing_handler.get_current_token
- self.update_function = typing_handler.get_all_typing_updates
+ self.current_token = typing_handler.get_current_token # type: ignore
+ self.update_function = typing_handler.get_all_typing_updates # type: ignore
super(TypingStream, self).__init__(hs)
@@ -277,8 +287,8 @@ class ReceiptsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_max_receipt_stream_id
- self.update_function = store.get_all_updated_receipts
+ self.current_token = store.get_max_receipt_stream_id # type: ignore
+ self.update_function = store.get_all_updated_receipts # type: ignore
super(ReceiptsStream, self).__init__(hs)
@@ -286,6 +296,7 @@ class ReceiptsStream(Stream):
class PushRulesStream(Stream):
"""A user has changed their push rules
"""
+
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow
@@ -297,23 +308,23 @@ class PushRulesStream(Stream):
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
- @defer.inlineCallbacks
- def update_function(self, from_token, to_token, limit):
- rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
- defer.returnValue([(row[0], row[2]) for row in rows])
+ async def update_function(self, from_token, to_token, limit):
+ rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
+ return [(row[0], row[2]) for row in rows]
class PushersStream(Stream):
"""A user has added/changed/removed a pusher
"""
+
NAME = "pushers"
ROW_TYPE = PushersStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_pushers_stream_token
- self.update_function = store.get_all_updated_pushers_rows
+ self.current_token = store.get_pushers_stream_token # type: ignore
+ self.update_function = store.get_all_updated_pushers_rows # type: ignore
super(PushersStream, self).__init__(hs)
@@ -322,14 +333,15 @@ class CachesStream(Stream):
"""A cache was invalidated on the master and no other stream would invalidate
the cache on the workers
"""
+
NAME = "caches"
ROW_TYPE = CachesStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_cache_stream_token
- self.update_function = store.get_all_updated_caches
+ self.current_token = store.get_cache_stream_token # type: ignore
+ self.update_function = store.get_all_updated_caches # type: ignore
super(CachesStream, self).__init__(hs)
@@ -337,14 +349,15 @@ class CachesStream(Stream):
class PublicRoomsStream(Stream):
"""The public rooms list changed
"""
+
NAME = "public_rooms"
ROW_TYPE = PublicRoomsStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_current_public_room_stream_id
- self.update_function = store.get_all_new_public_rooms
+ self.current_token = store.get_current_public_room_stream_id # type: ignore
+ self.update_function = store.get_all_new_public_rooms # type: ignore
super(PublicRoomsStream, self).__init__(hs)
@@ -352,6 +365,7 @@ class PublicRoomsStream(Stream):
class DeviceListsStream(Stream):
"""Someone added/changed/removed a device
"""
+
NAME = "device_lists"
_LIMITED = False
ROW_TYPE = DeviceListsStreamRow
@@ -359,8 +373,8 @@ class DeviceListsStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_device_stream_token
- self.update_function = store.get_all_device_list_changes_for_remotes
+ self.current_token = store.get_device_stream_token # type: ignore
+ self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
super(DeviceListsStream, self).__init__(hs)
@@ -368,14 +382,15 @@ class DeviceListsStream(Stream):
class ToDeviceStream(Stream):
"""New to_device messages for a client
"""
+
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_to_device_stream_token
- self.update_function = store.get_all_new_device_messages
+ self.current_token = store.get_to_device_stream_token # type: ignore
+ self.update_function = store.get_all_new_device_messages # type: ignore
super(ToDeviceStream, self).__init__(hs)
@@ -383,14 +398,15 @@ class ToDeviceStream(Stream):
class TagAccountDataStream(Stream):
"""Someone added/removed a tag for a room
"""
+
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_max_account_data_stream_id
- self.update_function = store.get_all_updated_tags
+ self.current_token = store.get_max_account_data_stream_id # type: ignore
+ self.update_function = store.get_all_updated_tags # type: ignore
super(TagAccountDataStream, self).__init__(hs)
@@ -398,29 +414,29 @@ class TagAccountDataStream(Stream):
class AccountDataStream(Stream):
"""Global or per room account data was changed
"""
+
NAME = "account_data"
ROW_TYPE = AccountDataStreamRow
def __init__(self, hs):
self.store = hs.get_datastore()
- self.current_token = self.store.get_max_account_data_stream_id
+ self.current_token = self.store.get_max_account_data_stream_id # type: ignore
super(AccountDataStream, self).__init__(hs)
- @defer.inlineCallbacks
- def update_function(self, from_token, to_token, limit):
- global_results, room_results = yield self.store.get_all_updated_account_data(
+ async def update_function(self, from_token, to_token, limit):
+ global_results, room_results = await self.store.get_all_updated_account_data(
from_token, from_token, to_token, limit
)
results = list(room_results)
results.extend(
- (stream_id, user_id, None, account_data_type, content,)
- for stream_id, user_id, account_data_type, content in global_results
+ (stream_id, user_id, None, account_data_type)
+ for stream_id, user_id, account_data_type in global_results
)
- defer.returnValue(results)
+ return results
class GroupServerStream(Stream):
@@ -430,7 +446,24 @@ class GroupServerStream(Stream):
def __init__(self, hs):
store = hs.get_datastore()
- self.current_token = store.get_group_stream_token
- self.update_function = store.get_all_groups_changes
+ self.current_token = store.get_group_stream_token # type: ignore
+ self.update_function = store.get_all_groups_changes # type: ignore
super(GroupServerStream, self).__init__(hs)
+
+
+class UserSignatureStream(Stream):
+ """A user has signed their own device with their user-signing key
+ """
+
+ NAME = "user_signature"
+ _LIMITED = False
+ ROW_TYPE = UserSignatureStreamRow
+
+ def __init__(self, hs):
+ store = hs.get_datastore()
+
+ self.current_token = store.get_device_stream_token # type: ignore
+ self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
+
+ super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index f1290d022a..b3afabb8cd 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import heapq
+from typing import Tuple, Type
import attr
-from twisted.internet import defer
-
from ._base import Stream
@@ -52,6 +52,7 @@ data part are:
@attr.s(slots=True, frozen=True)
class EventsStreamRow(object):
"""A parsed row from the events replication stream"""
+
type = attr.ib() # str: the TypeId of one of the *EventsStreamRows
data = attr.ib() # BaseEventsStreamRow
@@ -62,7 +63,8 @@ class BaseEventsStreamRow(object):
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
+ # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = None # type: str
@classmethod
def from_data(cls, data):
@@ -80,11 +82,11 @@ class BaseEventsStreamRow(object):
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib() # str
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str, optional
- redacts = attr.ib() # str, optional
+ event_id = attr.ib() # str
+ room_id = attr.ib() # str
+ type = attr.ib() # str
+ state_key = attr.ib() # str, optional
+ redacts = attr.ib() # str, optional
relates_to = attr.ib() # str, optional
@@ -92,53 +94,50 @@ class EventsStreamEventRow(BaseEventsStreamRow):
class EventsStreamCurrentStateRow(BaseEventsStreamRow):
TypeId = "state"
- room_id = attr.ib() # str
- type = attr.ib() # str
+ room_id = attr.ib() # str
+ type = attr.ib() # str
state_key = attr.ib() # str
- event_id = attr.ib() # str, optional
+ event_id = attr.ib() # str, optional
-TypeToRow = {
- Row.TypeId: Row
- for Row in (
- EventsStreamEventRow,
- EventsStreamCurrentStateRow,
- )
-}
+_EventRows = (
+ EventsStreamEventRow,
+ EventsStreamCurrentStateRow,
+) # type: Tuple[Type[BaseEventsStreamRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream):
"""We received a new event, or an event went from being an outlier to not
"""
+
NAME = "events"
def __init__(self, hs):
self._store = hs.get_datastore()
- self.current_token = self._store.get_current_events_token
+ self.current_token = self._store.get_current_events_token # type: ignore
super(EventsStream, self).__init__(hs)
- @defer.inlineCallbacks
- def update_function(self, from_token, current_token, limit=None):
- event_rows = yield self._store.get_all_new_forward_event_rows(
- from_token, current_token, limit,
+ async def update_function(self, from_token, current_token, limit=None):
+ event_rows = await self._store.get_all_new_forward_event_rows(
+ from_token, current_token, limit
)
event_updates = (
- (row[0], EventsStreamEventRow.TypeId, row[1:])
- for row in event_rows
+ (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
)
- state_rows = yield self._store.get_all_updated_current_state_deltas(
+ state_rows = await self._store.get_all_updated_current_state_deltas(
from_token, current_token, limit
)
state_updates = (
- (row[0], EventsStreamCurrentStateRow.TypeId, row[1:])
- for row in state_rows
+ (row[0], EventsStreamCurrentStateRow.TypeId, row[1:]) for row in state_rows
)
all_updates = heapq.merge(event_updates, state_updates)
- defer.returnValue(all_updates)
+ return all_updates
@classmethod
def parse_row(cls, row):
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index 9aa43aa8d2..615f3dc9ac 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -17,23 +17,27 @@ from collections import namedtuple
from ._base import Stream
-FederationStreamRow = namedtuple("FederationStreamRow", (
- "type", # str, the type of data as defined in the BaseFederationRows
- "data", # dict, serialization of a federation.send_queue.BaseFederationRow
-))
+FederationStreamRow = namedtuple(
+ "FederationStreamRow",
+ (
+ "type", # str, the type of data as defined in the BaseFederationRows
+ "data", # dict, serialization of a federation.send_queue.BaseFederationRow
+ ),
+)
class FederationStream(Stream):
"""Data to be sent over federation. Only available when master has federation
sending disabled.
"""
+
NAME = "federation"
ROW_TYPE = FederationStreamRow
def __init__(self, hs):
federation_sender = hs.get_federation_sender()
- self.current_token = federation_sender.get_current_token
- self.update_function = federation_sender.get_replication_rows
+ self.current_token = federation_sender.get_current_token # type: ignore
+ self.update_function = federation_sender.get_replication_rows # type: ignore
super(FederationStream, self).__init__(hs)
diff --git a/synapse/res/templates/add_threepid.html b/synapse/res/templates/add_threepid.html
new file mode 100644
index 0000000000..cc4ab07e09
--- /dev/null
+++ b/synapse/res/templates/add_threepid.html
@@ -0,0 +1,9 @@
+<html>
+<body>
+ <p>A request to add an email address to your Matrix account has been received. If this was you, please click the link below to confirm adding this email:</p>
+
+ <a href="{{ link }}">{{ link }}</a>
+
+ <p>If this was not you, you can safely ignore this email. Thank you.</p>
+</body>
+</html>
diff --git a/synapse/res/templates/add_threepid.txt b/synapse/res/templates/add_threepid.txt
new file mode 100644
index 0000000000..a60c1ff659
--- /dev/null
+++ b/synapse/res/templates/add_threepid.txt
@@ -0,0 +1,6 @@
+A request to add an email address to your Matrix account has been received. If this was you,
+please click the link below to confirm adding this email:
+
+{{ link }}
+
+If this was not you, you can safely ignore this email. Thank you.
diff --git a/synapse/res/templates/add_threepid_failure.html b/synapse/res/templates/add_threepid_failure.html
new file mode 100644
index 0000000000..441d11c846
--- /dev/null
+++ b/synapse/res/templates/add_threepid_failure.html
@@ -0,0 +1,8 @@
+<html>
+<head></head>
+<body>
+<p>The request failed for the following reason: {{ failure_reason }}.</p>
+
+<p>No changes have been made to your account.</p>
+</body>
+</html>
diff --git a/synapse/res/templates/add_threepid_success.html b/synapse/res/templates/add_threepid_success.html
new file mode 100644
index 0000000000..fbd6e4018f
--- /dev/null
+++ b/synapse/res/templates/add_threepid_success.html
@@ -0,0 +1,6 @@
+<html>
+<head></head>
+<body>
+<p>Your email has now been validated, please return to your client. You may now close this window.</p>
+</body>
+</html>
diff --git a/synapse/res/templates/password_reset.html b/synapse/res/templates/password_reset.html
index 4fa7b36734..a197bf872c 100644
--- a/synapse/res/templates/password_reset.html
+++ b/synapse/res/templates/password_reset.html
@@ -4,6 +4,6 @@
<a href="{{ link }}">{{ link }}</a>
- <p>If this was not you, please disregard this email and contact your server administrator. Thank you.</p>
+ <p>If this was not you, <strong>do not</strong> click the link above and instead contact your server administrator. Thank you.</p>
</body>
</html>
diff --git a/synapse/res/templates/password_reset.txt b/synapse/res/templates/password_reset.txt
index f0deff59a7..6aa6527560 100644
--- a/synapse/res/templates/password_reset.txt
+++ b/synapse/res/templates/password_reset.txt
@@ -3,5 +3,5 @@ was you, please click the link below to confirm resetting your password:
{{ link }}
-If this was not you, please disregard this email and contact your server
-administrator. Thank you.
+If this was not you, DO NOT click the link above and instead contact your
+server administrator. Thank you.
diff --git a/synapse/res/templates/password_reset_failure.html b/synapse/res/templates/password_reset_failure.html
index 0b132cf8db..9e3c4446e3 100644
--- a/synapse/res/templates/password_reset_failure.html
+++ b/synapse/res/templates/password_reset_failure.html
@@ -1,6 +1,8 @@
<html>
<head></head>
<body>
-<p>{{ failure_reason }}. Your password has not been reset.</p>
+<p>The request failed for the following reason: {{ failure_reason }}.</p>
+
+<p>Your password has not been reset.</p>
</body>
</html>
diff --git a/synapse/res/templates/registration.html b/synapse/res/templates/registration.html
new file mode 100644
index 0000000000..16730a527f
--- /dev/null
+++ b/synapse/res/templates/registration.html
@@ -0,0 +1,11 @@
+<html>
+<body>
+ <p>You have asked us to register this email with a new Matrix account. If this was you, please click the link below to confirm your email address:</p>
+
+ <a href="{{ link }}">Verify Your Email Address</a>
+
+ <p>If this was not you, you can safely disregard this email.</p>
+
+ <p>Thank you.</p>
+</body>
+</html>
diff --git a/synapse/res/templates/registration.txt b/synapse/res/templates/registration.txt
new file mode 100644
index 0000000000..cb4f16a90c
--- /dev/null
+++ b/synapse/res/templates/registration.txt
@@ -0,0 +1,10 @@
+Hello there,
+
+You have asked us to register this email with a new Matrix account. If this
+was you, please click the link below to confirm your email address:
+
+{{ link }}
+
+If this was not you, you can safely disregard this email.
+
+Thank you.
diff --git a/synapse/res/templates/registration_failure.html b/synapse/res/templates/registration_failure.html
new file mode 100644
index 0000000000..2833d79c37
--- /dev/null
+++ b/synapse/res/templates/registration_failure.html
@@ -0,0 +1,6 @@
+<html>
+<head></head>
+<body>
+<p>Validation failed for the following reason: {{ failure_reason }}.</p>
+</body>
+</html>
diff --git a/synapse/res/templates/registration_success.html b/synapse/res/templates/registration_success.html
new file mode 100644
index 0000000000..fbd6e4018f
--- /dev/null
+++ b/synapse/res/templates/registration_success.html
@@ -0,0 +1,6 @@
+<html>
+<head></head>
+<body>
+<p>Your email has now been validated, please return to your client. You may now close this window.</p>
+</body>
+</html>
diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html
new file mode 100644
index 0000000000..bfd6449c5d
--- /dev/null
+++ b/synapse/res/templates/saml_error.html
@@ -0,0 +1,45 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <title>SSO error</title>
+</head>
+<body>
+ <p>Oops! Something went wrong during authentication<span id="errormsg"></span>.</p>
+ <p>
+ If you are seeing this page after clicking a link sent to you via email, make
+ sure you only click the confirmation link once, and that you open the
+ validation link in the same client you're logging in from.
+ </p>
+ <p>
+ Try logging in again from your Matrix client and if the problem persists
+ please contact the server's administrator.
+ </p>
+
+ <script type="text/javascript">
+ // Error handling to support Auth0 errors that we might get through a GET request
+ // to the validation endpoint. If an error is provided, it's either going to be
+ // located in the query string or in a query string-like URI fragment.
+ // We try to locate the error from any of these two locations, but if we can't
+ // we just don't print anything specific.
+ let searchStr = "";
+ if (window.location.search) {
+ // window.location.searchParams isn't always defined when
+ // window.location.search is, so it's more reliable to parse the latter.
+ searchStr = window.location.search;
+ } else if (window.location.hash) {
+ // Replace the # with a ? so that URLSearchParams does the right thing and
+ // doesn't parse the first parameter incorrectly.
+ searchStr = window.location.hash.replace("#", "?");
+ }
+
+ // We might end up with no error in the URL, so we need to check if we have one
+ // to print one.
+ let errorDesc = new URLSearchParams(searchStr).get("error_description")
+ if (errorDesc) {
+
+ document.getElementById("errormsg").innerText = ` ("${errorDesc}")`;
+ }
+ </script>
+</body>
+</html>
\ No newline at end of file
diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html
new file mode 100644
index 0000000000..20a15e1e74
--- /dev/null
+++ b/synapse/res/templates/sso_redirect_confirm.html
@@ -0,0 +1,14 @@
+<!DOCTYPE html>
+<html lang="en">
+<head>
+ <meta charset="UTF-8">
+ <title>SSO redirect confirmation</title>
+</head>
+ <body>
+ <p>The application at <span style="font-weight:bold">{{ display_url | e }}</span> is requesting full access to your <span style="font-weight:bold">{{ server_name }}</span> Matrix account.</p>
+ <p>If you don't recognise this address, you should ignore this and close this tab.</p>
+ <p>
+ <a href="{{ redirect_url | e }}">I trust this address</a>
+ </p>
+ </body>
+</html>
\ No newline at end of file
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 195f103cdd..14eca70ba4 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -67,13 +67,14 @@ class ClientRestResource(JsonResource):
* /_matrix/client/unstable
* etc
"""
+
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@staticmethod
def register_servlets(client_resource, hs):
- versions.register_servlets(client_resource)
+ versions.register_servlets(hs, client_resource)
# Deprecated in r0
initial_sync.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index d6c4dcdb18..42cc2b062a 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -14,273 +14,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import hashlib
-import hmac
import logging
import platform
import re
-from six import text_type
-from six.moves import http_client
-
-from twisted.internet import defer
-
import synapse
-from synapse.api.constants import Membership, UserTypes
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
-from synapse.http.servlet import (
- RestServlet,
- assert_params_in_dict,
- parse_integer,
- parse_json_object_from_request,
- parse_string,
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.rest.admin._base import (
+ assert_requester_is_admin,
+ historical_admin_path_patterns,
)
-from synapse.rest.admin._base import assert_requester_is_admin, assert_user_is_admin
+from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
+from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
+from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
+from synapse.rest.admin.rooms import ListRoomRestServlet, ShutdownRoomRestServlet
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
-from synapse.types import UserID, create_requester
+from synapse.rest.admin.users import (
+ AccountValidityRenewServlet,
+ DeactivateAccountRestServlet,
+ ResetPasswordRestServlet,
+ SearchUsersRestServlet,
+ UserAdminServlet,
+ UserRegisterServlet,
+ UserRestServletV2,
+ UsersRestServlet,
+ UsersRestServletV2,
+ WhoisRestServlet,
+)
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
-def historical_admin_path_patterns(path_regex):
- """Returns the list of patterns for an admin endpoint, including historical ones
-
- This is a backwards-compatibility hack. Previously, the Admin API was exposed at
- various paths under /_matrix/client. This function returns a list of patterns
- matching those paths (as well as the new one), so that existing scripts which rely
- on the endpoints being available there are not broken.
-
- Note that this should only be used for existing endpoints: new ones should just
- register for the /_synapse/admin path.
- """
- return list(
- re.compile(prefix + path_regex)
- for prefix in (
- "^/_synapse/admin/v1",
- "^/_matrix/client/api/v1/admin",
- "^/_matrix/client/unstable/admin",
- "^/_matrix/client/r0/admin"
- )
- )
-
-
-class UsersRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)")
-
- def __init__(self, hs):
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- target_user = UserID.from_string(user_id)
- yield assert_requester_is_admin(self.auth, request)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- ret = yield self.handlers.admin_handler.get_users()
-
- defer.returnValue((200, ret))
-
-
class VersionServlet(RestServlet):
- PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"), )
+ PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
def __init__(self, hs):
self.res = {
- 'server_version': get_version_string(synapse),
- 'python_version': platform.python_version(),
+ "server_version": get_version_string(synapse),
+ "python_version": platform.python_version(),
}
def on_GET(self, request):
return 200, self.res
-class UserRegisterServlet(RestServlet):
- """
- Attributes:
- NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
- nonces (dict[str, int]): The nonces that we will accept. A dict of
- nonce to the time it was generated, in int seconds.
- """
- PATTERNS = historical_admin_path_patterns("/register")
- NONCE_TIMEOUT = 60
-
- def __init__(self, hs):
- self.handlers = hs.get_handlers()
- self.reactor = hs.get_reactor()
- self.nonces = {}
- self.hs = hs
-
- def _clear_old_nonces(self):
- """
- Clear out old nonces that are older than NONCE_TIMEOUT.
- """
- now = int(self.reactor.seconds())
-
- for k, v in list(self.nonces.items()):
- if now - v > self.NONCE_TIMEOUT:
- del self.nonces[k]
-
- def on_GET(self, request):
- """
- Generate a new nonce.
- """
- self._clear_old_nonces()
-
- nonce = self.hs.get_secrets().token_hex(64)
- self.nonces[nonce] = int(self.reactor.seconds())
- return (200, {"nonce": nonce})
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- self._clear_old_nonces()
-
- if not self.hs.config.registration_shared_secret:
- raise SynapseError(400, "Shared secret registration is not enabled")
-
- body = parse_json_object_from_request(request)
-
- if "nonce" not in body:
- raise SynapseError(
- 400, "nonce must be specified", errcode=Codes.BAD_JSON,
- )
-
- nonce = body["nonce"]
-
- if nonce not in self.nonces:
- raise SynapseError(
- 400, "unrecognised nonce",
- )
-
- # Delete the nonce, so it can't be reused, even if it's invalid
- del self.nonces[nonce]
-
- if "username" not in body:
- raise SynapseError(
- 400, "username must be specified", errcode=Codes.BAD_JSON,
- )
- else:
- if (
- not isinstance(body['username'], text_type)
- or len(body['username']) > 512
- ):
- raise SynapseError(400, "Invalid username")
-
- username = body["username"].encode("utf-8")
- if b"\x00" in username:
- raise SynapseError(400, "Invalid username")
-
- if "password" not in body:
- raise SynapseError(
- 400, "password must be specified", errcode=Codes.BAD_JSON,
- )
- else:
- if (
- not isinstance(body['password'], text_type)
- or len(body['password']) > 512
- ):
- raise SynapseError(400, "Invalid password")
-
- password = body["password"].encode("utf-8")
- if b"\x00" in password:
- raise SynapseError(400, "Invalid password")
-
- admin = body.get("admin", None)
- user_type = body.get("user_type", None)
-
- if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
- raise SynapseError(400, "Invalid user type")
-
- got_mac = body["mac"]
-
- want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret.encode(),
- digestmod=hashlib.sha1,
- )
- want_mac.update(nonce.encode('utf8'))
- want_mac.update(b"\x00")
- want_mac.update(username)
- want_mac.update(b"\x00")
- want_mac.update(password)
- want_mac.update(b"\x00")
- want_mac.update(b"admin" if admin else b"notadmin")
- if user_type:
- want_mac.update(b"\x00")
- want_mac.update(user_type.encode('utf8'))
- want_mac = want_mac.hexdigest()
-
- if not hmac.compare_digest(
- want_mac.encode('ascii'),
- got_mac.encode('ascii')
- ):
- raise SynapseError(403, "HMAC incorrect")
-
- # Reuse the parts of RegisterRestServlet to reduce code duplication
- from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-
- register = RegisterRestServlet(self.hs)
-
- (user_id, _) = yield register.registration_handler.register(
- localpart=body['username'].lower(),
- password=body["password"],
- admin=bool(admin),
- generate_token=False,
- user_type=user_type,
- )
-
- result = yield register._create_registration_details(user_id, body)
- defer.returnValue((200, result))
-
-
-class WhoisRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
-
- def __init__(self, hs):
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- target_user = UserID.from_string(user_id)
- requester = yield self.auth.get_user_by_req(request)
- auth_user = requester.user
-
- if target_user != auth_user:
- yield assert_user_is_admin(self.auth, auth_user)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only whois a local user")
-
- ret = yield self.handlers.admin_handler.get_whois(target_user)
-
- defer.returnValue((200, ret))
-
-
-class PurgeMediaCacheRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/purge_media_cache")
-
- def __init__(self, hs):
- self.media_repository = hs.get_media_repository()
- self.auth = hs.get_auth()
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield assert_requester_is_admin(self.auth, request)
-
- before_ts = parse_integer(request, "before_ts", required=True)
- logger.info("before_ts: %r", before_ts)
-
- ret = yield self.media_repository.delete_old_remote_media(before_ts)
-
- defer.returnValue((200, ret))
-
-
class PurgeHistoryRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
@@ -296,9 +76,8 @@ class PurgeHistoryRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_POST(self, request, room_id, event_id):
+ await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
@@ -308,54 +87,47 @@ class PurgeHistoryRestServlet(RestServlet):
# user can provide an event_id in the URL or the request body, or can
# provide a timestamp in the request body.
if event_id is None:
- event_id = body.get('purge_up_to_event_id')
+ event_id = body.get("purge_up_to_event_id")
if event_id is not None:
- event = yield self.store.get_event(event_id)
+ event = await self.store.get_event(event_id)
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
- token = yield self.store.get_topological_token_for_event(event_id)
+ token = await self.store.get_topological_token_for_event(event_id)
- logger.info(
- "[purge] purging up to token %s (event_id %s)",
- token, event_id,
- )
- elif 'purge_up_to_ts' in body:
- ts = body['purge_up_to_ts']
+ logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
+ elif "purge_up_to_ts" in body:
+ ts = body["purge_up_to_ts"]
if not isinstance(ts, int):
raise SynapseError(
- 400, "purge_up_to_ts must be an int",
- errcode=Codes.BAD_JSON,
+ 400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
)
- stream_ordering = (
- yield self.store.find_first_stream_ordering_after_ts(ts)
- )
+ stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
- r = (
- yield self.store.get_room_event_after_stream_ordering(
- room_id, stream_ordering,
- )
+ r = await self.store.get_room_event_before_stream_ordering(
+ room_id, stream_ordering
)
if not r:
- logger.warn(
+ logger.warning(
"[purge] purging events not possible: No event found "
"(received_ts %i => stream_ordering %i)",
- ts, stream_ordering,
+ ts,
+ stream_ordering,
)
raise SynapseError(
- 404,
- "there is no event to be purged",
- errcode=Codes.NOT_FOUND,
+ 404, "there is no event to be purged", errcode=Codes.NOT_FOUND
)
(stream, topo, _event_id) = r
token = "t%d-%d" % (topo, stream)
logger.info(
"[purge] purging up to token %s (received_ts %i => "
"stream_ordering %i)",
- token, ts, stream_ordering,
+ token,
+ ts,
+ stream_ordering,
)
else:
raise SynapseError(
@@ -364,14 +136,11 @@ class PurgeHistoryRestServlet(RestServlet):
errcode=Codes.BAD_JSON,
)
- purge_id = yield self.pagination_handler.start_purge_history(
- room_id, token,
- delete_local_events=delete_local_events,
+ purge_id = self.pagination_handler.start_purge_history(
+ room_id, token, delete_local_events=delete_local_events
)
- defer.returnValue((200, {
- "purge_id": purge_id,
- }))
+ return 200, {"purge_id": purge_id}
class PurgeHistoryStatusRestServlet(RestServlet):
@@ -388,427 +157,16 @@ class PurgeHistoryStatusRestServlet(RestServlet):
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, purge_id):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_GET(self, request, purge_id):
+ await assert_requester_is_admin(self.auth, request)
purge_status = self.pagination_handler.get_purge_status(purge_id)
if purge_status is None:
raise NotFoundError("purge id '%s' not found" % purge_id)
- defer.returnValue((200, purge_status.asdict()))
-
-
-class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
-
- def __init__(self, hs):
- self._deactivate_account_handler = hs.get_deactivate_account_handler()
- self.auth = hs.get_auth()
-
- @defer.inlineCallbacks
- def on_POST(self, request, target_user_id):
- yield assert_requester_is_admin(self.auth, request)
- body = parse_json_object_from_request(request, allow_empty_body=True)
- erase = body.get("erase", False)
- if not isinstance(erase, bool):
- raise SynapseError(
- http_client.BAD_REQUEST,
- "Param 'erase' must be a boolean, if given",
- Codes.BAD_JSON,
- )
-
- UserID.from_string(target_user_id)
-
- result = yield self._deactivate_account_handler.deactivate_account(
- target_user_id, erase,
- )
- if result:
- id_server_unbind_result = "success"
- else:
- id_server_unbind_result = "no-support"
-
- defer.returnValue((200, {
- "id_server_unbind_result": id_server_unbind_result,
- }))
+ return 200, purge_status.asdict()
-class ShutdownRoomRestServlet(RestServlet):
- """Shuts down a room by removing all local users from the room and blocking
- all future invites and joins to the room. Any local aliases will be repointed
- to a new room created by `new_room_user_id` and kicked users will be auto
- joined to the new room.
- """
- PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
-
- DEFAULT_MESSAGE = (
- "Sharing illegal content on this server is not permitted and rooms in"
- " violation will be blocked."
- )
-
- def __init__(self, hs):
- self.hs = hs
- self.store = hs.get_datastore()
- self.state = hs.get_state_handler()
- self._room_creation_handler = hs.get_room_creation_handler()
- self.event_creation_handler = hs.get_event_creation_handler()
- self.room_member_handler = hs.get_room_member_handler()
- self.auth = hs.get_auth()
-
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
-
- content = parse_json_object_from_request(request)
- assert_params_in_dict(content, ["new_room_user_id"])
- new_room_user_id = content["new_room_user_id"]
-
- room_creator_requester = create_requester(new_room_user_id)
-
- message = content.get("message", self.DEFAULT_MESSAGE)
- room_name = content.get("room_name", "Content Violation Notification")
-
- info = yield self._room_creation_handler.create_room(
- room_creator_requester,
- config={
- "preset": "public_chat",
- "name": room_name,
- "power_level_content_override": {
- "users_default": -10,
- },
- },
- ratelimit=False,
- )
- new_room_id = info["room_id"]
-
- requester_user_id = requester.user.to_string()
-
- logger.info(
- "Shutting down room %r, joining to new room: %r",
- room_id, new_room_id,
- )
-
- # This will work even if the room is already blocked, but that is
- # desirable in case the first attempt at blocking the room failed below.
- yield self.store.block_room(room_id, requester_user_id)
-
- users = yield self.state.get_current_users_in_room(room_id)
- kicked_users = []
- failed_to_kick_users = []
- for user_id in users:
- if not self.hs.is_mine_id(user_id):
- continue
-
- logger.info("Kicking %r from %r...", user_id, room_id)
-
- try:
- target_requester = create_requester(user_id)
- yield self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=room_id,
- action=Membership.LEAVE,
- content={},
- ratelimit=False,
- require_consent=False,
- )
-
- yield self.room_member_handler.forget(target_requester.user, room_id)
-
- yield self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=new_room_id,
- action=Membership.JOIN,
- content={},
- ratelimit=False,
- require_consent=False,
- )
-
- kicked_users.append(user_id)
- except Exception:
- logger.exception(
- "Failed to leave old room and join new room for %r", user_id,
- )
- failed_to_kick_users.append(user_id)
-
- yield self.event_creation_handler.create_and_send_nonmember_event(
- room_creator_requester,
- {
- "type": "m.room.message",
- "content": {"body": message, "msgtype": "m.text"},
- "room_id": new_room_id,
- "sender": new_room_user_id,
- },
- ratelimit=False,
- )
-
- aliases_for_room = yield self.store.get_aliases_for_room(room_id)
-
- yield self.store.update_aliases_for_room(
- room_id, new_room_id, requester_user_id
- )
-
- defer.returnValue((200, {
- "kicked_users": kicked_users,
- "failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
- "new_room_id": new_room_id,
- }))
-
-
-class QuarantineMediaInRoom(RestServlet):
- """Quarantines all media in a room so that no one can download it via
- this server.
- """
- PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.auth = hs.get_auth()
-
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
-
- num_quarantined = yield self.store.quarantine_media_ids_in_room(
- room_id, requester.user.to_string(),
- )
-
- defer.returnValue((200, {"num_quarantined": num_quarantined}))
-
-
-class ListMediaInRoom(RestServlet):
- """Lists all of the media in a given room.
- """
- PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media")
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
-
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
-
- local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
-
- defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
-
-
-class ResetPasswordRestServlet(RestServlet):
- """Post request to allow an administrator reset password for a user.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/reset_password/
- @user:to_reset_password?access_token=admin_access_token
- JsonBodyToSend:
- {
- "new_password": "secret"
- }
- Returns:
- 200 OK with empty object if success otherwise an error.
- """
- PATTERNS = historical_admin_path_patterns("/reset_password/(?P<target_user_id>[^/]*)")
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self._set_password_handler = hs.get_set_password_handler()
-
- @defer.inlineCallbacks
- def on_POST(self, request, target_user_id):
- """Post request to allow an administrator reset password for a user.
- This needs user to have administrator access in Synapse.
- """
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
-
- UserID.from_string(target_user_id)
-
- params = parse_json_object_from_request(request)
- assert_params_in_dict(params, ["new_password"])
- new_password = params['new_password']
-
- yield self._set_password_handler.set_password(
- target_user_id, new_password, requester
- )
- defer.returnValue((200, {}))
-
-
-class GetUsersPaginatedRestServlet(RestServlet):
- """Get request to get specific number of users from Synapse.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/users_paginate/
- @admin:user?access_token=admin_access_token&start=0&limit=10
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
- PATTERNS = historical_admin_path_patterns("/users_paginate/(?P<target_user_id>[^/]*)")
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_GET(self, request, target_user_id):
- """Get request to get specific number of users from Synapse.
- This needs user to have administrator access in Synapse.
- """
- yield assert_requester_is_admin(self.auth, request)
-
- target_user = UserID.from_string(target_user_id)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- order = "name" # order by name in user table
- start = parse_integer(request, "start", required=True)
- limit = parse_integer(request, "limit", required=True)
-
- logger.info("limit: %s, start: %s", limit, start)
-
- ret = yield self.handlers.admin_handler.get_users_paginate(
- order, start, limit
- )
- defer.returnValue((200, ret))
-
- @defer.inlineCallbacks
- def on_POST(self, request, target_user_id):
- """Post request to get specific number of users from Synapse..
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/users_paginate/
- @admin:user?access_token=admin_access_token
- JsonBodyToSend:
- {
- "start": "0",
- "limit": "10
- }
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
- yield assert_requester_is_admin(self.auth, request)
- UserID.from_string(target_user_id)
-
- order = "name" # order by name in user table
- params = parse_json_object_from_request(request)
- assert_params_in_dict(params, ["limit", "start"])
- limit = params['limit']
- start = params['start']
- logger.info("limit: %s, start: %s", limit, start)
-
- ret = yield self.handlers.admin_handler.get_users_paginate(
- order, start, limit
- )
- defer.returnValue((200, ret))
-
-
-class SearchUsersRestServlet(RestServlet):
- """Get request to search user table for specific users according to
- search term.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/search_users/
- @admin:user?access_token=admin_access_token&term=alice
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
- PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_GET(self, request, target_user_id):
- """Get request to search user table for specific users according to
- search term.
- This needs user to have a administrator access in Synapse.
- """
- yield assert_requester_is_admin(self.auth, request)
-
- target_user = UserID.from_string(target_user_id)
-
- # To allow all users to get the users list
- # if not is_admin and target_user != auth_user:
- # raise AuthError(403, "You are not a server admin")
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- term = parse_string(request, "term", required=True)
- logger.info("term: %s ", term)
-
- ret = yield self.handlers.admin_handler.search_users(
- term
- )
- defer.returnValue((200, ret))
-
-
-class DeleteGroupAdminRestServlet(RestServlet):
- """Allows deleting of local groups
- """
- PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
-
- def __init__(self, hs):
- self.group_server = hs.get_groups_server_handler()
- self.is_mine_id = hs.is_mine_id
- self.auth = hs.get_auth()
-
- @defer.inlineCallbacks
- def on_POST(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
- yield assert_user_is_admin(self.auth, requester.user)
-
- if not self.is_mine_id(group_id):
- raise SynapseError(400, "Can only delete local groups")
-
- yield self.group_server.delete_group(group_id, requester.user.to_string())
- defer.returnValue((200, {}))
-
-
-class AccountValidityRenewServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
- self.hs = hs
- self.account_activity_handler = hs.get_account_validity_handler()
- self.auth = hs.get_auth()
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield assert_requester_is_admin(self.auth, request)
-
- body = parse_json_object_from_request(request)
-
- if "user_id" not in body:
- raise SynapseError(400, "Missing property 'user_id' in the request body")
-
- expiration_ts = yield self.account_activity_handler.renew_account_for_user(
- body["user_id"], body.get("expiration_ts"),
- not body.get("enable_renewal_emails", True),
- )
-
- res = {
- "expiration_ts": expiration_ts,
- }
- defer.returnValue((200, res))
-
########################################################################################
#
# please don't add more servlets here: this file is already long and unwieldy. Put
@@ -830,26 +188,35 @@ def register_servlets(hs, http_server):
Register all the admin servlets.
"""
register_servlets_for_client_rest_resource(hs, http_server)
+ ListRoomRestServlet(hs).register(http_server)
+ PurgeRoomServlet(hs).register(http_server)
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
+ UserAdminServlet(hs).register(http_server)
+ UserRestServletV2(hs).register(http_server)
+ UsersRestServletV2(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
"""Register only the servlets which need to be exposed on /_matrix/client/xxx"""
WhoisRestServlet(hs).register(http_server)
- PurgeMediaCacheRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
- GetUsersPaginatedRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
- QuarantineMediaInRoom(hs).register(http_server)
- ListMediaInRoom(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
DeleteGroupAdminRestServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)
+
+ # Load the media repo ones if we're using them. Otherwise load the servlets which
+ # don't need a media repo (typically readonly admin APIs).
+ if hs.config.can_load_media_repo:
+ register_servlets_for_media_repo(hs, http_server)
+ else:
+ ListMediaInRoom(hs).register(http_server)
+
# don't add more things here: new servlets should only be exposed on
# /_synapse/admin so should not go here. Instead register them in AdminRestResource.
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index 881d67b89c..a96f75ce26 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -12,13 +12,50 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+
+import re
from synapse.api.errors import AuthError
-@defer.inlineCallbacks
-def assert_requester_is_admin(auth, request):
+def historical_admin_path_patterns(path_regex):
+ """Returns the list of patterns for an admin endpoint, including historical ones
+
+ This is a backwards-compatibility hack. Previously, the Admin API was exposed at
+ various paths under /_matrix/client. This function returns a list of patterns
+ matching those paths (as well as the new one), so that existing scripts which rely
+ on the endpoints being available there are not broken.
+
+ Note that this should only be used for existing endpoints: new ones should just
+ register for the /_synapse/admin path.
+ """
+ return [
+ re.compile(prefix + path_regex)
+ for prefix in (
+ "^/_synapse/admin/v1",
+ "^/_matrix/client/api/v1/admin",
+ "^/_matrix/client/unstable/admin",
+ "^/_matrix/client/r0/admin",
+ )
+ ]
+
+
+def admin_patterns(path_regex: str):
+ """Returns the list of patterns for an admin endpoint
+
+ Args:
+ path_regex: The regex string to match. This should NOT have a ^
+ as this will be prefixed.
+
+ Returns:
+ A list of regex patterns.
+ """
+ admin_prefix = "^/_synapse/admin/v1"
+ patterns = [re.compile(admin_prefix + path_regex)]
+ return patterns
+
+
+async def assert_requester_is_admin(auth, request):
"""Verify that the requester is an admin user
WARNING: MAKE SURE YOU YIELD ON THE RESULT!
@@ -33,12 +70,11 @@ def assert_requester_is_admin(auth, request):
Raises:
AuthError if the requester is not an admin
"""
- requester = yield auth.get_user_by_req(request)
- yield assert_user_is_admin(auth, requester.user)
+ requester = await auth.get_user_by_req(request)
+ await assert_user_is_admin(auth, requester.user)
-@defer.inlineCallbacks
-def assert_user_is_admin(auth, user_id):
+async def assert_user_is_admin(auth, user_id):
"""Verify that the given user is an admin user
WARNING: MAKE SURE YOU YIELD ON THE RESULT!
@@ -54,6 +90,6 @@ def assert_user_is_admin(auth, user_id):
AuthError if the user is not an admin
"""
- is_admin = yield auth.is_server_admin(user_id)
+ is_admin = await auth.is_server_admin(user_id)
if not is_admin:
raise AuthError(403, "You are not a server admin")
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
new file mode 100644
index 0000000000..0b54ca09f4
--- /dev/null
+++ b/synapse/rest/admin/groups.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.rest.admin._base import (
+ assert_user_is_admin,
+ historical_admin_path_patterns,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DeleteGroupAdminRestServlet(RestServlet):
+ """Allows deleting of local groups
+ """
+
+ PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.group_server = hs.get_groups_server_handler()
+ self.is_mine_id = hs.is_mine_id
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ if not self.is_mine_id(group_id):
+ raise SynapseError(400, "Can only delete local groups")
+
+ await self.group_server.delete_group(group_id, requester.user.to_string())
+ return 200, {}
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
new file mode 100644
index 0000000000..ee75095c0e
--- /dev/null
+++ b/synapse/rest/admin/media.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.api.errors import AuthError
+from synapse.http.servlet import RestServlet, parse_integer
+from synapse.rest.admin._base import (
+ assert_requester_is_admin,
+ assert_user_is_admin,
+ historical_admin_path_patterns,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class QuarantineMediaInRoom(RestServlet):
+ """Quarantines all media in a room so that no one can download it via
+ this server.
+ """
+
+ PATTERNS = (
+ historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+ +
+ # This path kept around for legacy reasons
+ historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, room_id: str):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Quarantining room: %s", room_id)
+
+ # Quarantine all media in this room
+ num_quarantined = await self.store.quarantine_media_ids_in_room(
+ room_id, requester.user.to_string()
+ )
+
+ return 200, {"num_quarantined": num_quarantined}
+
+
+class QuarantineMediaByUser(RestServlet):
+ """Quarantines all local media by a given user so that no one can download it via
+ this server.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/user/(?P<user_id>[^/]+)/media/quarantine"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, user_id: str):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Quarantining local media by user: %s", user_id)
+
+ # Quarantine all media this user has uploaded
+ num_quarantined = await self.store.quarantine_media_ids_by_user(
+ user_id, requester.user.to_string()
+ )
+
+ return 200, {"num_quarantined": num_quarantined}
+
+
+class QuarantineMediaByID(RestServlet):
+ """Quarantines local or remote media by a given ID so that no one can download
+ it via this server.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, server_name: str, media_id: str):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Quarantining local media by ID: %s/%s", server_name, media_id)
+
+ # Quarantine this media id
+ await self.store.quarantine_media_by_id(
+ server_name, media_id, requester.user.to_string()
+ )
+
+ return 200, {}
+
+
+class ListMediaInRoom(RestServlet):
+ """Lists all of the media in a given room.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ is_admin = await self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
+
+ return 200, {"local": local_mxcs, "remote": remote_mxcs}
+
+
+class PurgeMediaCacheRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/purge_media_cache")
+
+ def __init__(self, hs):
+ self.media_repository = hs.get_media_repository()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ before_ts = parse_integer(request, "before_ts", required=True)
+ logger.info("before_ts: %r", before_ts)
+
+ ret = await self.media_repository.delete_old_remote_media(before_ts)
+
+ return 200, ret
+
+
+def register_servlets_for_media_repo(hs, http_server):
+ """
+ Media repo specific APIs.
+ """
+ PurgeMediaCacheRestServlet(hs).register(http_server)
+ QuarantineMediaInRoom(hs).register(http_server)
+ QuarantineMediaByID(hs).register(http_server)
+ QuarantineMediaByUser(hs).register(http_server)
+ ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
new file mode 100644
index 0000000000..f474066542
--- /dev/null
+++ b/synapse/rest/admin/purge_room_servlet.py
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.rest.admin import assert_requester_is_admin
+
+
+class PurgeRoomServlet(RestServlet):
+ """Servlet which will remove all trace of a room from the database
+
+ POST /_synapse/admin/v1/purge_room
+ {
+ "room_id": "!room:id"
+ }
+
+ returns:
+
+ {}
+ """
+
+ PATTERNS = (re.compile("^/_synapse/admin/v1/purge_room$"),)
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.pagination_handler = hs.get_pagination_handler()
+
+ async def on_POST(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ("room_id",))
+
+ await self.pagination_handler.purge_room(body["room_id"])
+
+ return 200, {}
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
new file mode 100644
index 0000000000..f9b8c0a4f0
--- /dev/null
+++ b/synapse/rest/admin/rooms.py
@@ -0,0 +1,239 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.constants import Membership
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.admin._base import (
+ admin_patterns,
+ assert_user_is_admin,
+ historical_admin_path_patterns,
+)
+from synapse.storage.data_stores.main.room import RoomSortOrder
+from synapse.types import create_requester
+from synapse.util.async_helpers import maybe_awaitable
+
+logger = logging.getLogger(__name__)
+
+
+class ShutdownRoomRestServlet(RestServlet):
+ """Shuts down a room by removing all local users from the room and blocking
+ all future invites and joins to the room. Any local aliases will be repointed
+ to a new room created by `new_room_user_id` and kicked users will be auto
+ joined to the new room.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
+
+ DEFAULT_MESSAGE = (
+ "Sharing illegal content on this server is not permitted and rooms in"
+ " violation will be blocked."
+ )
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ content = parse_json_object_from_request(request)
+ assert_params_in_dict(content, ["new_room_user_id"])
+ new_room_user_id = content["new_room_user_id"]
+
+ room_creator_requester = create_requester(new_room_user_id)
+
+ message = content.get("message", self.DEFAULT_MESSAGE)
+ room_name = content.get("room_name", "Content Violation Notification")
+
+ info = await self._room_creation_handler.create_room(
+ room_creator_requester,
+ config={
+ "preset": "public_chat",
+ "name": room_name,
+ "power_level_content_override": {"users_default": -10},
+ },
+ ratelimit=False,
+ )
+ new_room_id = info["room_id"]
+
+ requester_user_id = requester.user.to_string()
+
+ logger.info(
+ "Shutting down room %r, joining to new room: %r", room_id, new_room_id
+ )
+
+ # This will work even if the room is already blocked, but that is
+ # desirable in case the first attempt at blocking the room failed below.
+ await self.store.block_room(room_id, requester_user_id)
+
+ users = await self.state.get_current_users_in_room(room_id)
+ kicked_users = []
+ failed_to_kick_users = []
+ for user_id in users:
+ if not self.hs.is_mine_id(user_id):
+ continue
+
+ logger.info("Kicking %r from %r...", user_id, room_id)
+
+ try:
+ target_requester = create_requester(user_id)
+ await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=room_id,
+ action=Membership.LEAVE,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
+
+ await self.room_member_handler.forget(target_requester.user, room_id)
+
+ await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=new_room_id,
+ action=Membership.JOIN,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
+
+ kicked_users.append(user_id)
+ except Exception:
+ logger.exception(
+ "Failed to leave old room and join new room for %r", user_id
+ )
+ failed_to_kick_users.append(user_id)
+
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ room_creator_requester,
+ {
+ "type": "m.room.message",
+ "content": {"body": message, "msgtype": "m.text"},
+ "room_id": new_room_id,
+ "sender": new_room_user_id,
+ },
+ ratelimit=False,
+ )
+
+ aliases_for_room = await maybe_awaitable(
+ self.store.get_aliases_for_room(room_id)
+ )
+
+ await self.store.update_aliases_for_room(
+ room_id, new_room_id, requester_user_id
+ )
+
+ return (
+ 200,
+ {
+ "kicked_users": kicked_users,
+ "failed_to_kick_users": failed_to_kick_users,
+ "local_aliases": aliases_for_room,
+ "new_room_id": new_room_id,
+ },
+ )
+
+
+class ListRoomRestServlet(RestServlet):
+ """
+ List all rooms that are known to the homeserver. Results are returned
+ in a dictionary containing room information. Supports pagination.
+ """
+
+ PATTERNS = admin_patterns("/rooms")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ # Extract query parameters
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+ order_by = parse_string(request, "order_by", default="alphabetical")
+ if order_by not in (
+ RoomSortOrder.ALPHABETICAL.value,
+ RoomSortOrder.SIZE.value,
+ ):
+ raise SynapseError(
+ 400,
+ "Unknown value for order_by: %s" % (order_by,),
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ search_term = parse_string(request, "search_term")
+ if search_term == "":
+ raise SynapseError(
+ 400,
+ "search_term cannot be an empty string",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ direction = parse_string(request, "dir", default="f")
+ if direction not in ("f", "b"):
+ raise SynapseError(
+ 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM
+ )
+
+ reverse_order = True if direction == "b" else False
+
+ # Return list of rooms according to parameters
+ rooms, total_rooms = await self.store.get_rooms_paginate(
+ start, limit, order_by, reverse_order, search_term
+ )
+ response = {
+ # next_token should be opaque, so return a value the client can parse
+ "offset": start,
+ "rooms": rooms,
+ "total_rooms": total_rooms,
+ }
+
+ # Are there more rooms to paginate through after this?
+ if (start + limit) < total_rooms:
+ # There are. Calculate where the query should start from next time
+ # to get the next part of the list
+ response["next_batch"] = start + limit
+
+ # Is it possible to paginate backwards? Check if we currently have an
+ # offset
+ if start > 0:
+ if start > limit:
+ # Going back one iteration won't take us to the start.
+ # Calculate new offset
+ response["prev_batch"] = start - limit
+ else:
+ response["prev_batch"] = 0
+
+ return 200, response
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index ae5aca9dac..6e9a874121 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -14,8 +14,6 @@
# limitations under the License.
import re
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
@@ -46,6 +44,7 @@ class SendServerNoticeServlet(RestServlet):
"event_id": "$1895723857jgskldgujpious"
}
"""
+
def __init__(self, hs):
"""
Args:
@@ -59,19 +58,17 @@ class SendServerNoticeServlet(RestServlet):
def register(self, json_resource):
PATTERN = "^/_synapse/admin/v1/send_server_notice"
json_resource.register_paths(
- "POST",
- (re.compile(PATTERN + "$"), ),
- self.on_POST,
+ "POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__
)
json_resource.register_paths(
"PUT",
- (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$",), ),
+ (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),),
self.on_PUT,
+ self.__class__.__name__,
)
- @defer.inlineCallbacks
- def on_POST(self, request, txn_id=None):
- yield assert_requester_is_admin(self.auth, request)
+ async def on_POST(self, request, txn_id=None):
+ await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
event_type = body.get("type", EventTypes.Message)
@@ -85,16 +82,16 @@ class SendServerNoticeServlet(RestServlet):
if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Server notices can only be sent to local users")
- event = yield self.snm.send_notice(
+ event = await self.snm.send_notice(
user_id=body["user_id"],
type=event_type,
state_key=state_key,
event_content=body["content"],
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ return 200, {"event_id": event.event_id}
def on_PUT(self, request, txn_id):
return self.txns.fetch_or_execute_request(
- request, self.on_POST, request, txn_id,
+ request, self.on_POST, request, txn_id
)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
new file mode 100644
index 0000000000..8551ac19b8
--- /dev/null
+++ b/synapse/rest/admin/users.py
@@ -0,0 +1,656 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import hashlib
+import hmac
+import logging
+import re
+
+from six import text_type
+from six.moves import http_client
+
+from synapse.api.constants import UserTypes
+from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_boolean,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.admin._base import (
+ assert_requester_is_admin,
+ assert_user_is_admin,
+ historical_admin_path_patterns,
+)
+from synapse.types import UserID
+
+logger = logging.getLogger(__name__)
+
+
+class UsersRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+
+ async def on_GET(self, request, user_id):
+ target_user = UserID.from_string(user_id)
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ ret = await self.store.get_users()
+
+ return 200, ret
+
+
+class UsersRestServletV2(RestServlet):
+ PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),)
+
+ """Get request to list all local users.
+ This needs user to have administrator access in Synapse.
+
+ GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
+
+ returns:
+ 200 OK with list of users if success otherwise an error.
+
+ The parameters `from` and `limit` are required only for pagination.
+ By default, a `limit` of 100 is used.
+ The parameter `user_id` can be used to filter by user id.
+ The parameter `guests` can be used to exclude guest users.
+ The parameter `deactivated` can be used to include deactivated users.
+ """
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+
+ async def on_GET(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+ user_id = parse_string(request, "user_id", default=None)
+ guests = parse_boolean(request, "guests", default=True)
+ deactivated = parse_boolean(request, "deactivated", default=False)
+
+ users = await self.store.get_users_paginate(
+ start, limit, user_id, guests, deactivated
+ )
+ ret = {"users": users}
+ if len(users) >= limit:
+ ret["next_token"] = str(start + len(users))
+
+ return 200, ret
+
+
+class UserRestServletV2(RestServlet):
+ PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]+)$"),)
+
+ """Get request to list user details.
+ This needs user to have administrator access in Synapse.
+
+ GET /_synapse/admin/v2/users/<user_id>
+
+ returns:
+ 200 OK with user details if success otherwise an error.
+
+ Put request to allow an administrator to add or modify a user.
+ This needs user to have administrator access in Synapse.
+ We use PUT instead of POST since we already know the id of the user
+ object to create. POST could be used to create guests.
+
+ PUT /_synapse/admin/v2/users/<user_id>
+ {
+ "password": "secret",
+ "displayname": "User"
+ }
+
+ returns:
+ 201 OK with new user object if user was created or
+ 200 OK with modified user object if user was modified
+ otherwise an error.
+ """
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+ self.store = hs.get_datastore()
+ self.auth_handler = hs.get_auth_handler()
+ self.profile_handler = hs.get_profile_handler()
+ self.set_password_handler = hs.get_set_password_handler()
+ self.deactivate_account_handler = hs.get_deactivate_account_handler()
+ self.registration_handler = hs.get_registration_handler()
+
+ async def on_GET(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only lookup local users")
+
+ ret = await self.admin_handler.get_user(target_user)
+
+ if not ret:
+ raise NotFoundError("User not found")
+
+ return 200, ret
+
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ target_user = UserID.from_string(user_id)
+ body = parse_json_object_from_request(request)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "This endpoint can only be used with local users")
+
+ user = await self.admin_handler.get_user(target_user)
+ user_id = target_user.to_string()
+
+ if user: # modify user
+ if "displayname" in body:
+ await self.profile_handler.set_displayname(
+ target_user, requester, body["displayname"], True
+ )
+
+ if "threepids" in body:
+ # check for required parameters for each threepid
+ for threepid in body["threepids"]:
+ assert_params_in_dict(threepid, ["medium", "address"])
+
+ # remove old threepids from user
+ threepids = await self.store.user_get_threepids(user_id)
+ for threepid in threepids:
+ try:
+ await self.auth_handler.delete_threepid(
+ user_id, threepid["medium"], threepid["address"], None
+ )
+ except Exception:
+ logger.exception("Failed to remove threepids")
+ raise SynapseError(500, "Failed to remove threepids")
+
+ # add new threepids to user
+ current_time = self.hs.get_clock().time_msec()
+ for threepid in body["threepids"]:
+ await self.auth_handler.add_threepid(
+ user_id, threepid["medium"], threepid["address"], current_time
+ )
+
+ if "avatar_url" in body:
+ await self.profile_handler.set_avatar_url(
+ target_user, requester, body["avatar_url"], True
+ )
+
+ if "admin" in body:
+ set_admin_to = bool(body["admin"])
+ if set_admin_to != user["admin"]:
+ auth_user = requester.user
+ if target_user == auth_user and not set_admin_to:
+ raise SynapseError(400, "You may not demote yourself.")
+
+ await self.store.set_server_admin(target_user, set_admin_to)
+
+ if "password" in body:
+ if (
+ not isinstance(body["password"], text_type)
+ or len(body["password"]) > 512
+ ):
+ raise SynapseError(400, "Invalid password")
+ else:
+ new_password = body["password"]
+ logout_devices = True
+ await self.set_password_handler.set_password(
+ target_user.to_string(), new_password, logout_devices, requester
+ )
+
+ if "deactivated" in body:
+ deactivate = body["deactivated"]
+ if not isinstance(deactivate, bool):
+ raise SynapseError(
+ 400, "'deactivated' parameter is not of type boolean"
+ )
+
+ if deactivate and not user["deactivated"]:
+ await self.deactivate_account_handler.deactivate_account(
+ target_user.to_string(), False
+ )
+
+ user = await self.admin_handler.get_user(target_user)
+ return 200, user
+
+ else: # create user
+ password = body.get("password")
+ if password is not None and (
+ not isinstance(body["password"], text_type)
+ or len(body["password"]) > 512
+ ):
+ raise SynapseError(400, "Invalid password")
+
+ admin = body.get("admin", None)
+ user_type = body.get("user_type", None)
+ displayname = body.get("displayname", None)
+ threepids = body.get("threepids", None)
+
+ if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
+ raise SynapseError(400, "Invalid user type")
+
+ user_id = await self.registration_handler.register_user(
+ localpart=target_user.localpart,
+ password=password,
+ admin=bool(admin),
+ default_display_name=displayname,
+ user_type=user_type,
+ )
+
+ if "threepids" in body:
+ # check for required parameters for each threepid
+ for threepid in body["threepids"]:
+ assert_params_in_dict(threepid, ["medium", "address"])
+
+ current_time = self.hs.get_clock().time_msec()
+ for threepid in body["threepids"]:
+ await self.auth_handler.add_threepid(
+ user_id, threepid["medium"], threepid["address"], current_time
+ )
+
+ if "avatar_url" in body:
+ await self.profile_handler.set_avatar_url(
+ user_id, requester, body["avatar_url"], True
+ )
+
+ ret = await self.admin_handler.get_user(target_user)
+
+ return 201, ret
+
+
+class UserRegisterServlet(RestServlet):
+ """
+ Attributes:
+ NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
+ nonces (dict[str, int]): The nonces that we will accept. A dict of
+ nonce to the time it was generated, in int seconds.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/register")
+ NONCE_TIMEOUT = 60
+
+ def __init__(self, hs):
+ self.handlers = hs.get_handlers()
+ self.reactor = hs.get_reactor()
+ self.nonces = {}
+ self.hs = hs
+
+ def _clear_old_nonces(self):
+ """
+ Clear out old nonces that are older than NONCE_TIMEOUT.
+ """
+ now = int(self.reactor.seconds())
+
+ for k, v in list(self.nonces.items()):
+ if now - v > self.NONCE_TIMEOUT:
+ del self.nonces[k]
+
+ def on_GET(self, request):
+ """
+ Generate a new nonce.
+ """
+ self._clear_old_nonces()
+
+ nonce = self.hs.get_secrets().token_hex(64)
+ self.nonces[nonce] = int(self.reactor.seconds())
+ return 200, {"nonce": nonce}
+
+ async def on_POST(self, request):
+ self._clear_old_nonces()
+
+ if not self.hs.config.registration_shared_secret:
+ raise SynapseError(400, "Shared secret registration is not enabled")
+
+ body = parse_json_object_from_request(request)
+
+ if "nonce" not in body:
+ raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
+
+ nonce = body["nonce"]
+
+ if nonce not in self.nonces:
+ raise SynapseError(400, "unrecognised nonce")
+
+ # Delete the nonce, so it can't be reused, even if it's invalid
+ del self.nonces[nonce]
+
+ if "username" not in body:
+ raise SynapseError(
+ 400, "username must be specified", errcode=Codes.BAD_JSON
+ )
+ else:
+ if (
+ not isinstance(body["username"], text_type)
+ or len(body["username"]) > 512
+ ):
+ raise SynapseError(400, "Invalid username")
+
+ username = body["username"].encode("utf-8")
+ if b"\x00" in username:
+ raise SynapseError(400, "Invalid username")
+
+ if "password" not in body:
+ raise SynapseError(
+ 400, "password must be specified", errcode=Codes.BAD_JSON
+ )
+ else:
+ if (
+ not isinstance(body["password"], text_type)
+ or len(body["password"]) > 512
+ ):
+ raise SynapseError(400, "Invalid password")
+
+ password = body["password"].encode("utf-8")
+ if b"\x00" in password:
+ raise SynapseError(400, "Invalid password")
+
+ admin = body.get("admin", None)
+ user_type = body.get("user_type", None)
+
+ if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
+ raise SynapseError(400, "Invalid user type")
+
+ got_mac = body["mac"]
+
+ want_mac_builder = hmac.new(
+ key=self.hs.config.registration_shared_secret.encode(),
+ digestmod=hashlib.sha1,
+ )
+ want_mac_builder.update(nonce.encode("utf8"))
+ want_mac_builder.update(b"\x00")
+ want_mac_builder.update(username)
+ want_mac_builder.update(b"\x00")
+ want_mac_builder.update(password)
+ want_mac_builder.update(b"\x00")
+ want_mac_builder.update(b"admin" if admin else b"notadmin")
+ if user_type:
+ want_mac_builder.update(b"\x00")
+ want_mac_builder.update(user_type.encode("utf8"))
+
+ want_mac = want_mac_builder.hexdigest()
+
+ if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
+ raise SynapseError(403, "HMAC incorrect")
+
+ # Reuse the parts of RegisterRestServlet to reduce code duplication
+ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+
+ register = RegisterRestServlet(self.hs)
+
+ user_id = await register.registration_handler.register_user(
+ localpart=body["username"].lower(),
+ password=body["password"],
+ admin=bool(admin),
+ user_type=user_type,
+ )
+
+ result = await register._create_registration_details(user_id, body)
+ return 200, result
+
+
+class WhoisRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ async def on_GET(self, request, user_id):
+ target_user = UserID.from_string(user_id)
+ requester = await self.auth.get_user_by_req(request)
+ auth_user = requester.user
+
+ if target_user != auth_user:
+ await assert_user_is_admin(self.auth, auth_user)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only whois a local user")
+
+ ret = await self.handlers.admin_handler.get_whois(target_user)
+
+ return 200, ret
+
+
+class DeactivateAccountRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, target_user_id):
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+ erase = body.get("erase", False)
+ if not isinstance(erase, bool):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'erase' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
+ UserID.from_string(target_user_id)
+
+ result = await self._deactivate_account_handler.deactivate_account(
+ target_user_id, erase
+ )
+ if result:
+ id_server_unbind_result = "success"
+ else:
+ id_server_unbind_result = "no-support"
+
+ return 200, {"id_server_unbind_result": id_server_unbind_result}
+
+
+class AccountValidityRenewServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ self.hs = hs
+ self.account_activity_handler = hs.get_account_validity_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ body = parse_json_object_from_request(request)
+
+ if "user_id" not in body:
+ raise SynapseError(400, "Missing property 'user_id' in the request body")
+
+ expiration_ts = await self.account_activity_handler.renew_account_for_user(
+ body["user_id"],
+ body.get("expiration_ts"),
+ not body.get("enable_renewal_emails", True),
+ )
+
+ res = {"expiration_ts": expiration_ts}
+ return 200, res
+
+
+class ResetPasswordRestServlet(RestServlet):
+ """Post request to allow an administrator reset password for a user.
+ This needs user to have administrator access in Synapse.
+ Example:
+ http://localhost:8008/_synapse/admin/v1/reset_password/
+ @user:to_reset_password?access_token=admin_access_token
+ JsonBodyToSend:
+ {
+ "new_password": "secret"
+ }
+ Returns:
+ 200 OK with empty object if success otherwise an error.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/reset_password/(?P<target_user_id>[^/]*)"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self._set_password_handler = hs.get_set_password_handler()
+
+ async def on_POST(self, request, target_user_id):
+ """Post request to allow an administrator reset password for a user.
+ This needs user to have administrator access in Synapse.
+ """
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ UserID.from_string(target_user_id)
+
+ params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["new_password"])
+ new_password = params["new_password"]
+ logout_devices = params.get("logout_devices", True)
+
+ await self._set_password_handler.set_password(
+ target_user_id, new_password, logout_devices, requester
+ )
+ return 200, {}
+
+
+class SearchUsersRestServlet(RestServlet):
+ """Get request to search user table for specific users according to
+ search term.
+ This needs user to have administrator access in Synapse.
+ Example:
+ http://localhost:8008/_synapse/admin/v1/search_users/
+ @admin:user?access_token=admin_access_token&term=alice
+ Returns:
+ 200 OK with json object {list[dict[str, Any]], count} or empty object.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ async def on_GET(self, request, target_user_id):
+ """Get request to search user table for specific users according to
+ search term.
+ This needs user to have a administrator access in Synapse.
+ """
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(target_user_id)
+
+ # To allow all users to get the users list
+ # if not is_admin and target_user != auth_user:
+ # raise AuthError(403, "You are not a server admin")
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ term = parse_string(request, "term", required=True)
+ logger.info("term: %s ", term)
+
+ ret = await self.handlers.store.search_users(term)
+ return 200, ret
+
+
+class UserAdminServlet(RestServlet):
+ """
+ Get or set whether or not a user is a server administrator.
+
+ Note that only local users can be server administrators, and that an
+ administrator may not demote themselves.
+
+ Only server administrators can use this API.
+
+ Examples:
+ * Get
+ GET /_synapse/admin/v1/users/@nonadmin:example.com/admin
+ response on success:
+ {
+ "admin": false
+ }
+ * Set
+ PUT /_synapse/admin/v1/users/@reivilibre:librepush.net/admin
+ request body:
+ {
+ "admin": true
+ }
+ response on success:
+ {}
+ """
+
+ PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>[^/]*)/admin$"),)
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_GET(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(user_id)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Only local users can be admins of this homeserver")
+
+ is_admin = await self.store.is_server_admin(target_user)
+
+ return 200, {"admin": is_admin}
+
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+ auth_user = requester.user
+
+ target_user = UserID.from_string(user_id)
+
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["admin"])
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Only local users can be admins of this homeserver")
+
+ set_admin_to = bool(body["admin"])
+
+ if target_user == auth_user and not set_admin_to:
+ raise SynapseError(400, "You may not demote yourself.")
+
+ await self.store.set_server_admin(target_user, set_admin_to)
+
+ return 200, {}
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 48c17f1b6d..6da71dc46f 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -17,8 +17,8 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util.async_helpers import ObservableDeferred
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -26,7 +26,6 @@ CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
-
def __init__(self, hs):
self.hs = hs
self.auth = self.hs.get_auth()
@@ -53,7 +52,7 @@ class HttpTransactionCache(object):
str: A transaction key
"""
token = self.auth.get_access_token_from_request(request)
- return request.path.decode('utf8') + "/" + token
+ return request.path.decode("utf8") + "/" + token
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 0035182bb9..5934b1fe8b 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -16,9 +16,13 @@
import logging
-from twisted.internet import defer
-
-from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InvalidClientCredentialsError,
+ NotFoundError,
+ SynapseError,
+)
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import RoomAlias
@@ -41,23 +45,22 @@ class ClientDirectoryServer(RestServlet):
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_alias):
+ async def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
dir_handler = self.handlers.directory_handler
- res = yield dir_handler.get_association(room_alias)
+ res = await dir_handler.get_association(room_alias)
- defer.returnValue((200, res))
+ return 200, res
- @defer.inlineCallbacks
- def on_PUT(self, request, room_alias):
+ async def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request)
if "room_id" not in content:
- raise SynapseError(400, 'Missing params: ["room_id"]',
- errcode=Codes.BAD_JSON)
+ raise SynapseError(
+ 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON
+ )
logger.debug("Got content: %s", content)
logger.debug("Got room name: %s", room_alias.to_string())
@@ -70,54 +73,47 @@ class ClientDirectoryServer(RestServlet):
# TODO(erikj): Check types.
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if room is None:
raise SynapseError(400, "Room does not exist")
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
- yield self.handlers.directory_handler.create_association(
+ await self.handlers.directory_handler.create_association(
requester, room_alias, room_id, servers
)
- defer.returnValue((200, {}))
+ return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, room_alias):
+ async def on_DELETE(self, request, room_alias):
dir_handler = self.handlers.directory_handler
try:
- service = yield self.auth.get_appservice_by_req(request)
+ service = await self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
- yield dir_handler.delete_appservice_association(
- service, room_alias
- )
+ await dir_handler.delete_appservice_association(service, room_alias)
logger.info(
"Application service at %s deleted alias %s",
service.url,
- room_alias.to_string()
+ room_alias.to_string(),
)
- defer.returnValue((200, {}))
- except AuthError:
+ return 200, {}
+ except InvalidClientCredentialsError:
# fallback to default user behaviour if they aren't an AS
pass
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user = requester.user
room_alias = RoomAlias.from_string(room_alias)
- yield dir_handler.delete_association(
- requester, room_alias
- )
+ await dir_handler.delete_association(requester, room_alias)
logger.info(
- "User %s deleted alias %s",
- user.to_string(),
- room_alias.to_string()
+ "User %s deleted alias %s", user.to_string(), room_alias.to_string()
)
- defer.returnValue((200, {}))
+ return 200, {}
class ClientDirectoryListServer(RestServlet):
@@ -129,38 +125,33 @@ class ClientDirectoryListServer(RestServlet):
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- room = yield self.store.get_room(room_id)
+ async def on_GET(self, request, room_id):
+ room = await self.store.get_room(room_id)
if room is None:
raise NotFoundError("Unknown room")
- defer.returnValue((200, {
- "visibility": "public" if room["is_public"] else "private"
- }))
+ return 200, {"visibility": "public" if room["is_public"] else "private"}
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
- yield self.handlers.directory_handler.edit_published_room_list(
- requester, room_id, visibility,
+ await self.handlers.directory_handler.edit_published_room_list(
+ requester, room_id, visibility
)
- defer.returnValue((200, {}))
+ return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- yield self.handlers.directory_handler.edit_published_room_list(
- requester, room_id, "private",
+ await self.handlers.directory_handler.edit_published_room_list(
+ requester, room_id, "private"
)
- defer.returnValue((200, {}))
+ return 200, {}
class ClientAppserviceDirectoryListServer(RestServlet):
@@ -182,16 +173,15 @@ class ClientAppserviceDirectoryListServer(RestServlet):
def on_DELETE(self, request, network_id, room_id):
return self._edit(request, network_id, room_id, "private")
- @defer.inlineCallbacks
- def _edit(self, request, network_id, room_id, visibility):
- requester = yield self.auth.get_user_by_req(request)
+ async def _edit(self, request, network_id, room_id, visibility):
+ requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(
403, "Only appservices can edit the appservice published room list"
)
- yield self.handlers.directory_handler.edit_published_appservice_room_list(
- requester.app_service.id, network_id, room_id, visibility,
+ await self.handlers.directory_handler.edit_published_appservice_room_list(
+ requester.app_service.id, network_id, room_id, visibility
)
- defer.returnValue((200, {}))
+ return 200, {}
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 84ca36270b..25effd0261 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -16,8 +16,6 @@
"""This module contains REST servlets to do with event streaming, /events."""
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -36,19 +34,15 @@ class EventStreamRestServlet(RestServlet):
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(
- request,
- allow_guest=True,
- )
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest
room_id = None
if is_guest:
if b"room_id" not in request.args:
raise SynapseError(400, "Guest users must specify room_id param")
if b"room_id" in request.args:
- room_id = request.args[b"room_id"][0].decode('ascii')
+ room_id = request.args[b"room_id"][0].decode("ascii")
pagin_config = PaginationConfig.from_request(request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
@@ -60,7 +54,7 @@ class EventStreamRestServlet(RestServlet):
as_client_event = b"raw" not in request.args
- chunk = yield self.event_stream_handler.get_stream(
+ chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(),
pagin_config,
timeout=timeout,
@@ -70,13 +64,12 @@ class EventStreamRestServlet(RestServlet):
is_guest=is_guest,
)
- defer.returnValue((200, chunk))
+ return 200, chunk
def on_OPTIONS(self, request):
- return (200, {})
+ return 200, {}
-# TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
@@ -84,19 +77,19 @@ class EventRestServlet(RestServlet):
super(EventRestServlet, self).__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
+ self.auth = hs.get_auth()
self._event_serializer = hs.get_event_client_serializer()
- @defer.inlineCallbacks
- def on_GET(self, request, event_id):
- requester = yield self.auth.get_user_by_req(request)
- event = yield self.event_handler.get_event(requester.user, None, event_id)
+ async def on_GET(self, request, event_id):
+ requester = await self.auth.get_user_by_req(request)
+ event = await self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec()
if event:
- event = yield self._event_serializer.serialize_event(event, time_now)
- defer.returnValue((200, event))
+ event = await self._event_serializer.serialize_event(event, time_now)
+ return 200, event
else:
- defer.returnValue((404, "Event not found."))
+ return 404, "Event not found."
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 0fe5f2d79b..910b3b4eeb 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_boolean
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -29,20 +28,19 @@ class InitialSyncRestServlet(RestServlet):
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
include_archived = parse_boolean(request, "archived", default=False)
- content = yield self.initial_sync_handler.snapshot_all_rooms(
+ content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,
as_client_event=as_client_event,
include_archived=include_archived,
)
- defer.returnValue((200, content))
+ return 200, content
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 7c86b88f30..d0d4999795 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,7 +18,6 @@ import xml.etree.ElementTree as ET
from six.moves import urllib
-from twisted.internet import defer
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
@@ -29,6 +28,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.push.mailer import load_jinja2_templates
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -44,10 +44,7 @@ def login_submission_legacy_convert(submission):
to a typed object.
"""
if "user" in submission:
- submission["identifier"] = {
- "type": "m.id.user",
- "user": submission["user"],
- }
+ submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
del submission["user"]
if "medium" in submission and "address" in submission:
@@ -73,11 +70,7 @@ def login_id_thirdparty_from_phone(identifier):
msisdn = phone_number_to_msisdn(identifier["country"], identifier["number"])
- return {
- "type": "m.id.thirdparty",
- "medium": "msisdn",
- "address": msisdn,
- }
+ return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
class LoginRestServlet(RestServlet):
@@ -93,17 +86,24 @@ class LoginRestServlet(RestServlet):
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
+ self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
+ self._clock = hs.get_clock()
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter()
+ self._account_ratelimiter = Ratelimiter()
+ self._failed_attempts_ratelimiter = Ratelimiter()
def on_GET(self, request):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
+ if self.saml2_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
+ flows.append({"type": LoginRestServlet.TOKEN_TYPE})
if self.cas_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
@@ -120,19 +120,19 @@ class LoginRestServlet(RestServlet):
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
- flows.extend((
- {"type": t} for t in self.auth_handler.get_supported_login_types()
- ))
+ flows.extend(
+ ({"type": t} for t in self.auth_handler.get_supported_login_types())
+ )
- return (200, {"flows": flows})
+ return 200, {"flows": flows}
def on_OPTIONS(self, request):
- return (200, {})
+ return 200, {}
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
self._address_ratelimiter.ratelimit(
- request.getClientIP(), time_now_s=self.hs.clock.time(),
+ request.getClientIP(),
+ time_now_s=self.hs.clock.time(),
rate_hz=self.hs.config.rc_login_address.per_second,
burst_count=self.hs.config.rc_login_address.burst_count,
update=True,
@@ -140,23 +140,23 @@ class LoginRestServlet(RestServlet):
login_submission = parse_json_object_from_request(request)
try:
- if self.jwt_enabled and (login_submission["type"] ==
- LoginRestServlet.JWT_TYPE):
- result = yield self.do_jwt_login(login_submission)
+ if self.jwt_enabled and (
+ login_submission["type"] == LoginRestServlet.JWT_TYPE
+ ):
+ result = await self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
- result = yield self.do_token_login(login_submission)
+ result = await self.do_token_login(login_submission)
else:
- result = yield self._do_other_login(login_submission)
+ result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
well_known_data = self._well_known_builder.get_well_known()
if well_known_data:
result["well_known"] = well_known_data
- defer.returnValue((200, result))
+ return 200, result
- @defer.inlineCallbacks
- def _do_other_login(self, login_submission):
+ async def _do_other_login(self, login_submission):
"""Handle non-token/saml/jwt logins
Args:
@@ -170,10 +170,10 @@ class LoginRestServlet(RestServlet):
# field)
logger.info(
"Got login request with identifier: %r, medium: %r, address: %r, user: %r",
- login_submission.get('identifier'),
- login_submission.get('medium'),
- login_submission.get('address'),
- login_submission.get('user'),
+ login_submission.get("identifier"),
+ login_submission.get("medium"),
+ login_submission.get("address"),
+ login_submission.get("user"),
)
login_submission_legacy_convert(login_submission)
@@ -190,49 +190,70 @@ class LoginRestServlet(RestServlet):
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
- address = identifier.get('address')
- medium = identifier.get('medium')
+ address = identifier.get("address")
+ medium = identifier.get("medium")
if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
- if medium == 'email':
+ if medium == "email":
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
+ # We also apply account rate limiting using the 3PID as a key, as
+ # otherwise using 3PID bypasses the ratelimiting based on user ID.
+ self._failed_attempts_ratelimiter.ratelimit(
+ (medium, address),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=False,
+ )
+
# Check for login providers that support 3pid login types
- canonical_user_id, callback_3pid = (
- yield self.auth_handler.check_password_provider_3pid(
- medium,
- address,
- login_submission["password"],
- )
+ (
+ canonical_user_id,
+ callback_3pid,
+ ) = await self.auth_handler.check_password_provider_3pid(
+ medium, address, login_submission["password"]
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
- result = yield self._register_device_with_callback(
- canonical_user_id, login_submission, callback_3pid,
+
+ result = await self._complete_login(
+ canonical_user_id, login_submission, callback_3pid
)
- defer.returnValue(result)
+ return result
# No password providers were able to handle this 3pid
# Check local store
- user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- medium, address,
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+ medium, address
)
if not user_id:
- logger.warn(
- "unknown 3pid identifier medium %s, address %r",
- medium, address,
+ logger.warning(
+ "unknown 3pid identifier medium %s, address %r", medium, address
+ )
+ # We mark that we've failed to log in here, as
+ # `check_password_provider_3pid` might have returned `None` due
+ # to an incorrect password, rather than the account not
+ # existing.
+ #
+ # If it returned None but the 3PID was bound then we won't hit
+ # this code path, which is fine as then the per-user ratelimit
+ # will kick in below.
+ self._failed_attempts_ratelimiter.can_do_action(
+ (medium, address),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=True,
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
- identifier = {
- "type": "m.id.user",
- "user": user_id,
- }
+ identifier = {"type": "m.id.user", "user": user_id}
# by this point, the identifier should be an m.id.user: if it's anything
# else, we haven't understood it.
@@ -241,39 +262,87 @@ class LoginRestServlet(RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
- canonical_user_id, callback = yield self.auth_handler.validate_login(
- identifier["user"],
- login_submission,
+ if identifier["user"].startswith("@"):
+ qualified_user_id = identifier["user"]
+ else:
+ qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
+
+ # Check if we've hit the failed ratelimit (but don't update it)
+ self._failed_attempts_ratelimiter.ratelimit(
+ qualified_user_id.lower(),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=False,
)
- result = yield self._register_device_with_callback(
- canonical_user_id, login_submission, callback,
+ try:
+ canonical_user_id, callback = await self.auth_handler.validate_login(
+ identifier["user"], login_submission
+ )
+ except LoginError:
+ # The user has failed to log in, so we need to update the rate
+ # limiter. Using `can_do_action` avoids us raising a ratelimit
+ # exception and masking the LoginError. The actual ratelimiting
+ # should have happened above.
+ self._failed_attempts_ratelimiter.can_do_action(
+ qualified_user_id.lower(),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=True,
+ )
+ raise
+
+ result = await self._complete_login(
+ canonical_user_id, login_submission, callback
)
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def _register_device_with_callback(
- self,
- user_id,
- login_submission,
- callback=None,
+ return result
+
+ async def _complete_login(
+ self, user_id, login_submission, callback=None, create_non_existant_users=False
):
- """ Registers a device with a given user_id. Optionally run a callback
- function after registration has completed.
+ """Called when we've successfully authed the user and now need to
+ actually login them in (e.g. create devices). This gets called on
+ all succesful logins.
+
+ Applies the ratelimiting for succesful login attempts against an
+ account.
Args:
user_id (str): ID of the user to register.
login_submission (dict): Dictionary of login information.
callback (func|None): Callback function to run after registration.
+ create_non_existant_users (bool): Whether to create the user if
+ they don't exist. Defaults to False.
Returns:
result (Dict[str,str]): Dictionary of account information after
successful registration.
"""
+
+ # Before we actually log them in we check if they've already logged in
+ # too often. This happens here rather than before as we don't
+ # necessarily know the user before now.
+ self._account_ratelimiter.ratelimit(
+ user_id.lower(),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_account.per_second,
+ burst_count=self.hs.config.rc_login_account.burst_count,
+ update=True,
+ )
+
+ if create_non_existant_users:
+ user_id = await self.auth_handler.check_user_exists(user_id)
+ if not user_id:
+ user_id = await self.registration_handler.register_user(
+ localpart=UserID.from_string(user_id).localpart
+ )
+
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name,
+ device_id, access_token = await self.registration_handler.register_device(
+ user_id, device_id, initial_display_name
)
result = {
@@ -284,47 +353,34 @@ class LoginRestServlet(RestServlet):
}
if callback is not None:
- yield callback(result)
+ await callback(result)
- defer.returnValue(result)
+ return result
- @defer.inlineCallbacks
- def do_token_login(self, login_submission):
- token = login_submission['token']
+ async def do_token_login(self, login_submission):
+ token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = (
- yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
- )
-
- device_id = login_submission.get("device_id")
- initial_display_name = login_submission.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name,
+ user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
+ token
)
- result = {
- "user_id": user_id, # may have changed
- "access_token": access_token,
- "home_server": self.hs.hostname,
- "device_id": device_id,
- }
-
- defer.returnValue(result)
+ result = await self._complete_login(user_id, login_submission)
+ return result
- @defer.inlineCallbacks
- def do_jwt_login(self, login_submission):
+ async def do_jwt_login(self, login_submission):
token = login_submission.get("token", None)
if token is None:
raise LoginError(
- 401, "Token field for JWT is missing",
- errcode=Codes.UNAUTHORIZED
+ 401, "Token field for JWT is missing", errcode=Codes.UNAUTHORIZED
)
import jwt
from jwt.exceptions import InvalidTokenError
try:
- payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
+ payload = jwt.decode(
+ token, self.jwt_secret, algorithms=[self.jwt_algorithm]
+ )
except jwt.ExpiredSignatureError:
raise LoginError(401, "JWT expired", errcode=Codes.UNAUTHORIZED)
except InvalidTokenError:
@@ -335,63 +391,55 @@ class LoginRestServlet(RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID(user, self.hs.hostname).to_string()
+ result = await self._complete_login(
+ user_id, login_submission, create_non_existant_users=True
+ )
+ return result
- auth_handler = self.auth_handler
- registered_user_id = yield auth_handler.check_user_exists(user_id)
- if registered_user_id:
- device_id = login_submission.get("device_id")
- initial_display_name = login_submission.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
- registered_user_id, device_id, initial_display_name,
- )
- result = {
- "user_id": registered_user_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- }
- else:
- user_id, access_token = (
- yield self.handlers.registration_handler.register(localpart=user)
- )
+class BaseSSORedirectServlet(RestServlet):
+ """Common base class for /login/sso/redirect impls"""
- device_id = login_submission.get("device_id")
- initial_display_name = login_submission.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
- registered_user_id, device_id, initial_display_name,
- )
+ PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+
+ def on_GET(self, request):
+ args = request.args
+ if b"redirectUrl" not in args:
+ return 400, "Redirect URL not specified for SSO auth"
+ client_redirect_url = args[b"redirectUrl"][0]
+ sso_url = self.get_sso_url(client_redirect_url)
+ request.redirect(sso_url)
+ finish_request(request)
- result = {
- "user_id": user_id, # may have changed
- "access_token": access_token,
- "home_server": self.hs.hostname,
- }
+ def get_sso_url(self, client_redirect_url):
+ """Get the URL to redirect to, to perform SSO auth
- defer.returnValue(result)
+ Args:
+ client_redirect_url (bytes): the URL that we should redirect the
+ client to when everything is done
+ Returns:
+ bytes: URL to redirect to
+ """
+ # to be implemented by subclasses
+ raise NotImplementedError()
-class CasRedirectServlet(RestServlet):
- PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url.encode('ascii')
- self.cas_service_url = hs.config.cas_service_url.encode('ascii')
+ self.cas_server_url = hs.config.cas_server_url.encode("ascii")
+ self.cas_service_url = hs.config.cas_service_url.encode("ascii")
- def on_GET(self, request):
- args = request.args
- if b"redirectUrl" not in args:
- return (400, "Redirect URL not specified for CAS auth")
- client_redirect_url_param = urllib.parse.urlencode({
- b"redirectUrl": args[b"redirectUrl"][0]
- }).encode('ascii')
- hs_redirect_url = (self.cas_service_url +
- b"/_matrix/client/r0/login/cas/ticket")
- service_param = urllib.parse.urlencode({
- b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
- }).encode('ascii')
- request.redirect(b"%s/login?%s" % (self.cas_server_url, service_param))
- finish_request(request)
+ def get_sso_url(self, client_redirect_url):
+ client_redirect_url_param = urllib.parse.urlencode(
+ {b"redirectUrl": client_redirect_url}
+ ).encode("ascii")
+ hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
+ service_param = urllib.parse.urlencode(
+ {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
+ ).encode("ascii")
+ return b"%s/login?%s" % (self.cas_server_url, service_param)
class CasTicketServlet(RestServlet):
@@ -401,29 +449,30 @@ class CasTicketServlet(RestServlet):
super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
+ self.cas_displayname_attribute = hs.config.cas_displayname_attribute
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_proxied_http_client()
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True)
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
- "service": self.cas_service_url
+ "service": self.cas_service_url,
}
try:
- body = yield self._http_client.get_raw(uri, args)
+ body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
- result = yield self.handle_cas_response(request, body, client_redirect_url)
- defer.returnValue(result)
+ result = await self.handle_cas_response(request, body, client_redirect_url)
+ return result
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
+ displayname = attributes.pop(self.cas_displayname_attribute, None)
for required_attribute, required_value in self.cas_required_attributes.items():
# If required attribute was not in CAS Response - Forbidden
@@ -438,7 +487,7 @@ class CasTicketServlet(RestServlet):
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
return self._sso_auth_handler.on_successful_auth(
- user, request, client_redirect_url,
+ user, request, client_redirect_url, displayname
)
def parse_cas_response(self, cas_response_body):
@@ -448,7 +497,7 @@ class CasTicketServlet(RestServlet):
root = ET.fromstring(cas_response_body)
if not root.tag.endswith("serviceResponse"):
raise Exception("root of CAS response is not serviceResponse")
- success = (root[0].tag.endswith("authenticationSuccess"))
+ success = root[0].tag.endswith("authenticationSuccess")
for child in root[0]:
if child.tag.endswith("user"):
user = child.text
@@ -465,15 +514,25 @@ class CasTicketServlet(RestServlet):
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
- logger.error("Error parsing CAS response", exc_info=1)
- raise LoginError(401, "Invalid CAS response",
- errcode=Codes.UNAUTHORIZED)
+ logger.exception("Error parsing CAS response")
+ raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
- raise LoginError(401, "Unsuccessful CAS response",
- errcode=Codes.UNAUTHORIZED)
+ raise LoginError(
+ 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+ )
return user, attributes
+class SAMLRedirectServlet(BaseSSORedirectServlet):
+ PATTERNS = client_patterns("/login/sso/redirect", v1=True)
+
+ def __init__(self, hs):
+ self._saml_handler = hs.get_saml_handler()
+
+ def get_sso_url(self, client_redirect_url):
+ return self._saml_handler.handle_redirect_request(client_redirect_url)
+
+
class SSOAuthHandler(object):
"""
Utility class for Resources and Servlets which handle the response from a SSO
@@ -482,16 +541,25 @@ class SSOAuthHandler(object):
Args:
hs (synapse.server.HomeServer)
"""
+
def __init__(self, hs):
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator()
- @defer.inlineCallbacks
- def on_successful_auth(
- self, username, request, client_redirect_url,
- user_display_name=None,
+ # Load the redirect page HTML template
+ self._template = load_jinja2_templates(
+ hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+ )[0]
+
+ self._server_name = hs.config.server_name
+
+ # cast to tuple for use with str.startswith
+ self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+
+ async def on_successful_auth(
+ self, username, request, client_redirect_url, user_display_name=None
):
"""Called once the user has successfully authenticated with the SSO.
@@ -516,32 +584,15 @@ class SSOAuthHandler(object):
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = yield self._auth_handler.check_user_exists(user_id)
+ registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
- registered_user_id, _ = (
- yield self._registration_handler.register(
- localpart=localpart,
- generate_token=False,
- default_display_name=user_display_name,
- )
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart, default_display_name=user_display_name
)
- login_token = self._macaroon_gen.generate_short_term_login_token(
- registered_user_id
- )
- redirect_url = self._add_login_token_to_redirect_url(
- client_redirect_url, login_token
+ self._auth_handler.complete_sso_login(
+ registered_user_id, request, client_redirect_url
)
- request.redirect(redirect_url)
- finish_request(request)
-
- @staticmethod
- def _add_login_token_to_redirect_url(url, token):
- url_parts = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({"loginToken": token})
- url_parts[4] = urllib.parse.urlencode(query)
- return urllib.parse.urlunparse(url_parts)
def register_servlets(hs, http_server):
@@ -549,3 +600,5 @@ def register_servlets(hs, http_server):
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
+ elif hs.config.saml2_enabled:
+ SAMLRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index b8064f261e..1cf3caf832 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -33,22 +31,22 @@ class LogoutRestServlet(RestServlet):
self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
- return (200, {})
+ return 200, {}
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = self.auth.get_access_token_from_request(request)
- yield self._auth_handler.delete_access_token(access_token)
+ await self._auth_handler.delete_access_token(access_token)
else:
- yield self._device_handler.delete_device(
- requester.user.to_string(), requester.device_id)
+ await self._device_handler.delete_device(
+ requester.user.to_string(), requester.device_id
+ )
- defer.returnValue((200, {}))
+ return 200, {}
class LogoutAllRestServlet(RestServlet):
@@ -61,20 +59,19 @@ class LogoutAllRestServlet(RestServlet):
self._device_handler = hs.get_device_handler()
def on_OPTIONS(self, request):
- return (200, {})
+ return 200, {}
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# first delete all of the user's devices
- yield self._device_handler.delete_all_devices_for_user(user_id)
+ await self._device_handler.delete_all_devices_for_user(user_id)
# .. and then delete any access tokens which weren't associated with
# devices.
- yield self._auth_handler.delete_access_tokens_for_user(user_id)
- defer.returnValue((200, {}))
+ await self._auth_handler.delete_access_tokens_for_user(user_id)
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index e263da3cb7..eec16f8ad8 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -19,8 +19,6 @@ import logging
from six import string_types
-from twisted.internet import defer
-
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -40,27 +38,25 @@ class PresenceStatusRestServlet(RestServlet):
self.clock = hs.get_clock()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if requester.user != user:
- allowed = yield self.presence_handler.is_visible(
- observed_user=user, observer_user=requester.user,
+ allowed = await self.presence_handler.is_visible(
+ observed_user=user, observer_user=requester.user
)
if not allowed:
raise AuthError(403, "You are not allowed to see their presence.")
- state = yield self.presence_handler.get_state(target_user=user)
+ state = await self.presence_handler.get_state(target_user=user)
state = format_user_presence_state(state, self.clock.time_msec())
- defer.returnValue((200, state))
+ return 200, state
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if requester.user != user:
@@ -86,12 +82,12 @@ class PresenceStatusRestServlet(RestServlet):
raise SynapseError(400, "Unable to parse state")
if self.hs.config.use_presence:
- yield self.presence_handler.set_state(user, state)
+ await self.presence_handler.set_state(user, state)
- defer.returnValue((200, {}))
+ return 200, {}
def on_OPTIONS(self, request):
- return (200, {})
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 34361697df..165313b572 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,16 +14,13 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
-import logging
-
from twisted.internet import defer
+from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
-logger = logging.getLogger(__name__)
-
class ProfileDisplaynameRestServlet(RestServlet):
PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
@@ -35,52 +32,47 @@ class ProfileDisplaynameRestServlet(RestServlet):
self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
user = UserID.from_string(user_id)
- yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+ await self.profile_handler.check_profile_query_allowed(user, requester_user)
- displayname = yield self.profile_handler.get_displayname(user)
+ displayname = await self.profile_handler.get_displayname(user)
ret = {}
if displayname is not None:
ret["displayname"] = displayname
- defer.returnValue((200, ret))
+ return 200, ret
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
- is_admin = yield self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request)
try:
new_name = content["displayname"]
except Exception:
- defer.returnValue((400, "Unable to parse name"))
+ return 400, "Unable to parse name"
- yield self.profile_handler.set_displayname(
- user, requester, new_name, is_admin)
+ await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
if self.hs.config.shadow_server:
- shadow_user = UserID(
- user.localpart, self.hs.config.shadow_server.get("hs")
- )
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
self.shadow_displayname(shadow_user.to_string(), content)
- defer.returnValue((200, {}))
+ return 200, {}
def on_OPTIONS(self, request, user_id):
- return (200, {})
+ return 200, {}
@defer.inlineCallbacks
def shadow_displayname(self, user_id, body):
@@ -89,10 +81,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.put_json(
- "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s" % (
- shadow_hs_url, user_id, as_token, user_id
- ),
- body
+ "%s/_matrix/client/r0/profile/%s/displayname?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
)
@@ -106,52 +97,50 @@ class ProfileAvatarURLRestServlet(RestServlet):
self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
user = UserID.from_string(user_id)
- yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+ await self.profile_handler.check_profile_query_allowed(user, requester_user)
- avatar_url = yield self.profile_handler.get_avatar_url(user)
+ avatar_url = await self.profile_handler.get_avatar_url(user)
ret = {}
if avatar_url is not None:
ret["avatar_url"] = avatar_url
- defer.returnValue((200, ret))
+ return 200, ret
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
- is_admin = yield self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request)
try:
new_avatar_url = content["avatar_url"]
- except Exception:
- defer.returnValue((400, "Unable to parse name"))
+ except KeyError:
+ raise SynapseError(
+ 400, "Missing key 'avatar_url'", errcode=Codes.MISSING_PARAM
+ )
- yield self.profile_handler.set_avatar_url(
+ await self.profile_handler.set_avatar_url(
user, requester, new_avatar_url, is_admin
)
if self.hs.config.shadow_server:
- shadow_user = UserID(
- user.localpart, self.hs.config.shadow_server.get("hs")
- )
+ shadow_user = UserID(user.localpart, self.hs.config.shadow_server.get("hs"))
self.shadow_avatar_url(shadow_user.to_string(), content)
- defer.returnValue((200, {}))
+ return 200, {}
def on_OPTIONS(self, request, user_id):
- return (200, {})
+ return 200, {}
@defer.inlineCallbacks
def shadow_avatar_url(self, user_id, body):
@@ -160,10 +149,9 @@ class ProfileAvatarURLRestServlet(RestServlet):
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.put_json(
- "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s" % (
- shadow_hs_url, user_id, as_token, user_id
- ),
- body
+ "%s/_matrix/client/r0/profile/%s/avatar_url?access_token=%s&user_id=%s"
+ % (shadow_hs_url, user_id, as_token, user_id),
+ body,
)
@@ -176,20 +164,19 @@ class ProfileRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
user = UserID.from_string(user_id)
- yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+ await self.profile_handler.check_profile_query_allowed(user, requester_user)
- displayname = yield self.profile_handler.get_displayname(user)
- avatar_url = yield self.profile_handler.get_avatar_url(user)
+ displayname = await self.profile_handler.get_displayname(user)
+ avatar_url = await self.profile_handler.get_avatar_url(user)
ret = {}
if displayname is not None:
@@ -197,7 +184,7 @@ class ProfileRestServlet(RestServlet):
if avatar_url is not None:
ret["avatar_url"] = avatar_url
- defer.returnValue((200, ret))
+ return 200, ret
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 3d6326fe2f..9fd4908136 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from synapse.api.errors import (
NotFoundError,
@@ -21,7 +20,11 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
-from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string
+from synapse.http.servlet import (
+ RestServlet,
+ parse_json_value_from_request,
+ parse_string,
+)
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
@@ -32,7 +35,8 @@ from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundExc
class PushRuleRestServlet(RestServlet):
PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
- "Unrecognised request: You probably wanted a trailing slash")
+ "Unrecognised request: You probably wanted a trailing slash"
+ )
def __init__(self, hs):
super(PushRuleRestServlet, self).__init__()
@@ -41,40 +45,37 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
- @defer.inlineCallbacks
- def on_PUT(self, request, path):
+ async def on_PUT(self, request, path):
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
- spec = _rule_spec_from_path([x for x in path.split("/")])
+ spec = _rule_spec_from_path(path.split("/"))
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
raise SynapseError(400, str(e))
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
- if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
+ if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
raise SynapseError(400, "rule_id may not contain slashes")
content = parse_json_value_from_request(request)
user_id = requester.user.to_string()
- if 'attr' in spec:
- yield self.set_rule_attr(user_id, spec, content)
+ if "attr" in spec:
+ await self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
- defer.returnValue((200, {}))
+ return 200, {}
- if spec['rule_id'].startswith('.'):
+ if spec["rule_id"].startswith("."):
# Rule ids starting with '.' are reserved for server default rules.
raise SynapseError(400, "cannot add new rule_ids that start with '.'")
try:
(conditions, actions) = _rule_tuple_from_request_object(
- spec['template'],
- spec['rule_id'],
- content,
+ spec["template"], spec["rule_id"], content
)
except InvalidRuleException as e:
raise SynapseError(400, str(e))
@@ -88,14 +89,14 @@ class PushRuleRestServlet(RestServlet):
after = _namespaced_rule_id(spec, after)
try:
- yield self.store.add_push_rule(
+ await self.store.add_push_rule(
user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class,
conditions=conditions,
actions=actions,
before=before,
- after=after
+ after=after,
)
self.notify_user(user_id)
except InconsistentRuleException as e:
@@ -103,45 +104,41 @@ class PushRuleRestServlet(RestServlet):
except RuleNotFoundException as e:
raise SynapseError(400, str(e))
- defer.returnValue((200, {}))
+ return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, path):
+ async def on_DELETE(self, request, path):
if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
- spec = _rule_spec_from_path([x for x in path.split("/")])
+ spec = _rule_spec_from_path(path.split("/"))
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try:
- yield self.store.delete_push_rule(
- user_id, namespaced_rule_id
- )
+ await self.store.delete_push_rule(user_id, namespaced_rule_id)
self.notify_user(user_id)
- defer.returnValue((200, {}))
+ return 200, {}
except StoreError as e:
if e.code == 404:
raise NotFoundError()
else:
raise
- @defer.inlineCallbacks
- def on_GET(self, request, path):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, path):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
- rules = yield self.store.get_push_rules_for_user(user_id)
+ rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(requester.user, rules)
- path = [x for x in path.split("/")][1:]
+ path = path.split("/")[1:]
if path == []:
# we're a reference impl: pedantry is our job.
@@ -149,11 +146,11 @@ class PushRuleRestServlet(RestServlet):
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
- if path[0] == '':
- defer.returnValue((200, rules))
- elif path[0] == 'global':
- result = _filter_ruleset_with_path(rules['global'], path[1:])
- defer.returnValue((200, result))
+ if path[0] == "":
+ return 200, rules
+ elif path[0] == "global":
+ result = _filter_ruleset_with_path(rules["global"], path[1:])
+ return 200, result
else:
raise UnrecognizedRequestError()
@@ -162,12 +159,10 @@ class PushRuleRestServlet(RestServlet):
def notify_user(self, user_id):
stream_id, _ = self.store.get_push_rules_stream_token()
- self.notifier.on_new_event(
- "push_rules_key", stream_id, users=[user_id]
- )
+ self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
def set_rule_attr(self, user_id, spec, val):
- if spec['attr'] == 'enabled':
+ if spec["attr"] == "enabled":
if isinstance(val, dict) and "enabled" in val:
val = val["enabled"]
if not isinstance(val, bool):
@@ -176,14 +171,12 @@ class PushRuleRestServlet(RestServlet):
# bools directly, so let's not break them.
raise SynapseError(400, "Value for 'enabled' must be boolean")
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- return self.store.set_push_rule_enabled(
- user_id, namespaced_rule_id, val
- )
- elif spec['attr'] == 'actions':
- actions = val.get('actions')
+ return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val)
+ elif spec["attr"] == "actions":
+ actions = val.get("actions")
_check_actions(actions)
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
- rule_id = spec['rule_id']
+ rule_id = spec["rule_id"]
is_default_rule = rule_id.startswith(".")
if is_default_rule:
if namespaced_rule_id not in BASE_RULE_IDS:
@@ -210,12 +203,12 @@ def _rule_spec_from_path(path):
"""
if len(path) < 2:
raise UnrecognizedRequestError()
- if path[0] != 'pushrules':
+ if path[0] != "pushrules":
raise UnrecognizedRequestError()
scope = path[1]
path = path[2:]
- if scope != 'global':
+ if scope != "global":
raise UnrecognizedRequestError()
if len(path) == 0:
@@ -229,56 +222,40 @@ def _rule_spec_from_path(path):
rule_id = path[0]
- spec = {
- 'scope': scope,
- 'template': template,
- 'rule_id': rule_id
- }
+ spec = {"scope": scope, "template": template, "rule_id": rule_id}
path = path[1:]
if len(path) > 0 and len(path[0]) > 0:
- spec['attr'] = path[0]
+ spec["attr"] = path[0]
return spec
def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
- if rule_template in ['override', 'underride']:
- if 'conditions' not in req_obj:
+ if rule_template in ["override", "underride"]:
+ if "conditions" not in req_obj:
raise InvalidRuleException("Missing 'conditions'")
- conditions = req_obj['conditions']
+ conditions = req_obj["conditions"]
for c in conditions:
- if 'kind' not in c:
+ if "kind" not in c:
raise InvalidRuleException("Condition without 'kind'")
- elif rule_template == 'room':
- conditions = [{
- 'kind': 'event_match',
- 'key': 'room_id',
- 'pattern': rule_id
- }]
- elif rule_template == 'sender':
- conditions = [{
- 'kind': 'event_match',
- 'key': 'user_id',
- 'pattern': rule_id
- }]
- elif rule_template == 'content':
- if 'pattern' not in req_obj:
+ elif rule_template == "room":
+ conditions = [{"kind": "event_match", "key": "room_id", "pattern": rule_id}]
+ elif rule_template == "sender":
+ conditions = [{"kind": "event_match", "key": "user_id", "pattern": rule_id}]
+ elif rule_template == "content":
+ if "pattern" not in req_obj:
raise InvalidRuleException("Content rule missing 'pattern'")
- pat = req_obj['pattern']
+ pat = req_obj["pattern"]
- conditions = [{
- 'kind': 'event_match',
- 'key': 'content.body',
- 'pattern': pat
- }]
+ conditions = [{"kind": "event_match", "key": "content.body", "pattern": pat}]
else:
raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
- if 'actions' not in req_obj:
+ if "actions" not in req_obj:
raise InvalidRuleException("No actions found")
- actions = req_obj['actions']
+ actions = req_obj["actions"]
_check_actions(actions)
@@ -290,9 +267,9 @@ def _check_actions(actions):
raise InvalidRuleException("No actions found")
for a in actions:
- if a in ['notify', 'dont_notify', 'coalesce']:
+ if a in ["notify", "dont_notify", "coalesce"]:
pass
- elif isinstance(a, dict) and 'set_tweak' in a:
+ elif isinstance(a, dict) and "set_tweak" in a:
pass
else:
raise InvalidRuleException("Unrecognised action")
@@ -304,7 +281,7 @@ def _filter_ruleset_with_path(ruleset, path):
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
- if path[0] == '':
+ if path[0] == "":
return ruleset
template_kind = path[0]
if template_kind not in ruleset:
@@ -314,13 +291,13 @@ def _filter_ruleset_with_path(ruleset, path):
raise UnrecognizedRequestError(
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
- if path[0] == '':
+ if path[0] == "":
return ruleset[template_kind]
rule_id = path[0]
the_rule = None
for r in ruleset[template_kind]:
- if r['rule_id'] == rule_id:
+ if r["rule_id"] == rule_id:
the_rule = r
if the_rule is None:
raise NotFoundError
@@ -339,19 +316,19 @@ def _filter_ruleset_with_path(ruleset, path):
def _priority_class_from_spec(spec):
- if spec['template'] not in PRIORITY_CLASS_MAP.keys():
- raise InvalidRuleException("Unknown template: %s" % (spec['template']))
- pc = PRIORITY_CLASS_MAP[spec['template']]
+ if spec["template"] not in PRIORITY_CLASS_MAP.keys():
+ raise InvalidRuleException("Unknown template: %s" % (spec["template"]))
+ pc = PRIORITY_CLASS_MAP[spec["template"]]
return pc
def _namespaced_rule_id_from_spec(spec):
- return _namespaced_rule_id(spec, spec['rule_id'])
+ return _namespaced_rule_id(spec, spec["rule_id"])
def _namespaced_rule_id(spec, rule_id):
- return "global/%s/%s" % (spec['template'], rule_id)
+ return "global/%s/%s" % (spec["template"], rule_id)
class InvalidRuleException(Exception):
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 15d860db37..550a2f1b44 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import (
@@ -30,6 +28,17 @@ from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
+ALLOWED_KEYS = {
+ "app_display_name",
+ "app_id",
+ "data",
+ "device_display_name",
+ "kind",
+ "lang",
+ "profile_tag",
+ "pushkey",
+}
+
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
@@ -39,32 +48,17 @@ class PushersRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
user = requester.user
- pushers = yield self.hs.get_datastore().get_pushers_by_user_id(
- user.to_string()
- )
+ pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
- allowed_keys = [
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
+ filtered_pushers = [
+ {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
]
- for p in pushers:
- for k, v in list(p.items()):
- if k not in allowed_keys:
- del p[k]
-
- defer.returnValue((200, {"pushers": pushers}))
+ return 200, {"pushers": filtered_pushers}
def on_OPTIONS(self, _):
return 200, {}
@@ -80,61 +74,71 @@ class PushersSetRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
user = requester.user
content = parse_json_object_from_request(request)
- if ('pushkey' in content and 'app_id' in content
- and 'kind' in content and
- content['kind'] is None):
- yield self.pusher_pool.remove_pusher(
- content['app_id'], content['pushkey'], user_id=user.to_string()
+ if (
+ "pushkey" in content
+ and "app_id" in content
+ and "kind" in content
+ and content["kind"] is None
+ ):
+ await self.pusher_pool.remove_pusher(
+ content["app_id"], content["pushkey"], user_id=user.to_string()
)
- defer.returnValue((200, {}))
+ return 200, {}
assert_params_in_dict(
content,
- ['kind', 'app_id', 'app_display_name',
- 'device_display_name', 'pushkey', 'lang', 'data']
+ [
+ "kind",
+ "app_id",
+ "app_display_name",
+ "device_display_name",
+ "pushkey",
+ "lang",
+ "data",
+ ],
)
- logger.debug("set pushkey %s to kind %s", content['pushkey'], content['kind'])
+ logger.debug("set pushkey %s to kind %s", content["pushkey"], content["kind"])
logger.debug("Got pushers request with body: %r", content)
append = False
- if 'append' in content:
- append = content['append']
+ if "append" in content:
+ append = content["append"]
if not append:
- yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
- app_id=content['app_id'],
- pushkey=content['pushkey'],
- not_user_id=user.to_string()
+ await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
+ app_id=content["app_id"],
+ pushkey=content["pushkey"],
+ not_user_id=user.to_string(),
)
try:
- yield self.pusher_pool.add_pusher(
+ await self.pusher_pool.add_pusher(
user_id=user.to_string(),
access_token=requester.access_token_id,
- kind=content['kind'],
- app_id=content['app_id'],
- app_display_name=content['app_display_name'],
- device_display_name=content['device_display_name'],
- pushkey=content['pushkey'],
- lang=content['lang'],
- data=content['data'],
- profile_tag=content.get('profile_tag', ""),
+ kind=content["kind"],
+ app_id=content["app_id"],
+ app_display_name=content["app_display_name"],
+ device_display_name=content["device_display_name"],
+ pushkey=content["pushkey"],
+ lang=content["lang"],
+ data=content["data"],
+ profile_tag=content.get("profile_tag", ""),
)
except PusherConfigException as pce:
- raise SynapseError(400, "Config Error: " + str(pce),
- errcode=Codes.MISSING_PARAM)
+ raise SynapseError(
+ 400, "Config Error: " + str(pce), errcode=Codes.MISSING_PARAM
+ )
self.notifier.on_new_replication_data()
- defer.returnValue((200, {}))
+ return 200, {}
def on_OPTIONS(self, _):
return 200, {}
@@ -144,6 +148,7 @@ class PushersRemoveRestServlet(RestServlet):
"""
To allow pusher to be delete by clicking a link (ie. GET request)
"""
+
PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
@@ -154,19 +159,16 @@ class PushersRemoveRestServlet(RestServlet):
self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True)
try:
- yield self.pusher_pool.remove_pusher(
- app_id=app_id,
- pushkey=pushkey,
- user_id=user.to_string(),
+ await self.pusher_pool.remove_pusher(
+ app_id=app_id, pushkey=pushkey, user_id=user.to_string()
)
except StoreError as se:
if se.code != 404:
@@ -177,12 +179,12 @@ class PushersRemoveRestServlet(RestServlet):
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (
- len(PushersRemoveRestServlet.SUCCESS_HTML),
- ))
+ request.setHeader(
+ b"Content-Length", b"%d" % (len(PushersRemoveRestServlet.SUCCESS_HTML),)
+ )
request.write(PushersRemoveRestServlet.SUCCESS_HTML)
finish_request(request)
- defer.returnValue(None)
+ return None
def on_OPTIONS(self, _):
return 200, {}
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 151b553730..e788eb0193 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -16,15 +16,20 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
+import re
+from typing import List, Optional
from six.moves.urllib import parse as urlparse
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ InvalidClientCredentialsError,
+ SynapseError,
+)
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
@@ -34,12 +39,17 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
+MYPY = False
+if MYPY:
+ import synapse.server
+
logger = logging.getLogger(__name__)
@@ -61,35 +71,39 @@ class RoomCreateRestServlet(TransactionRestServlet):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
- http_server.register_paths("OPTIONS",
- client_patterns("/rooms(?:/.*)?$", v1=True),
- self.on_OPTIONS)
+ http_server.register_paths(
+ "OPTIONS",
+ client_patterns("/rooms(?:/.*)?$", v1=True),
+ self.on_OPTIONS,
+ self.__class__.__name__,
+ )
# define CORS for /createRoom[/txnid]
- http_server.register_paths("OPTIONS",
- client_patterns("/createRoom(?:/.*)?$", v1=True),
- self.on_OPTIONS)
+ http_server.register_paths(
+ "OPTIONS",
+ client_patterns("/createRoom(?:/.*)?$", v1=True),
+ self.on_OPTIONS,
+ self.__class__.__name__,
+ )
def on_PUT(self, request, txn_id):
- return self.txns.fetch_or_execute_request(
- request, self.on_POST, request
- )
+ set_tag("txn_id", txn_id)
+ return self.txns.fetch_or_execute_request(request, self.on_POST, request)
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
- info = yield self._room_creation_handler.create_room(
+ info = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
- defer.returnValue((200, info))
+ return 200, info
def get_room_config(self, request):
user_supplied_config = parse_json_object_from_request(request)
return user_supplied_config
def on_OPTIONS(self, request):
- return (200, {})
+ return 200, {}
# TODO: Needs unit testing for generic events
@@ -107,21 +121,35 @@ class RoomStateEventRestServlet(TransactionRestServlet):
no_state_key = "/rooms/(?P<room_id>[^/]*)/state/(?P<event_type>[^/]*)$"
# /room/$roomid/state/$eventtype/$statekey
- state_key = ("/rooms/(?P<room_id>[^/]*)/state/"
- "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
-
- http_server.register_paths("GET",
- client_patterns(state_key, v1=True),
- self.on_GET)
- http_server.register_paths("PUT",
- client_patterns(state_key, v1=True),
- self.on_PUT)
- http_server.register_paths("GET",
- client_patterns(no_state_key, v1=True),
- self.on_GET_no_state_key)
- http_server.register_paths("PUT",
- client_patterns(no_state_key, v1=True),
- self.on_PUT_no_state_key)
+ state_key = (
+ "/rooms/(?P<room_id>[^/]*)/state/"
+ "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$"
+ )
+
+ http_server.register_paths(
+ "GET",
+ client_patterns(state_key, v1=True),
+ self.on_GET,
+ self.__class__.__name__,
+ )
+ http_server.register_paths(
+ "PUT",
+ client_patterns(state_key, v1=True),
+ self.on_PUT,
+ self.__class__.__name__,
+ )
+ http_server.register_paths(
+ "GET",
+ client_patterns(no_state_key, v1=True),
+ self.on_GET_no_state_key,
+ self.__class__.__name__,
+ )
+ http_server.register_paths(
+ "PUT",
+ client_patterns(no_state_key, v1=True),
+ self.on_PUT_no_state_key,
+ self.__class__.__name__,
+ )
def on_GET_no_state_key(self, request, room_id, event_type):
return self.on_GET(request, room_id, event_type, "")
@@ -129,14 +157,14 @@ class RoomStateEventRestServlet(TransactionRestServlet):
def on_PUT_no_state_key(self, request, room_id, event_type):
return self.on_PUT(request, room_id, event_type, "")
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_type, state_key):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- format = parse_string(request, "format", default="content",
- allowed_values=["content", "event"])
+ async def on_GET(self, request, room_id, event_type, state_key):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ format = parse_string(
+ request, "format", default="content", allowed_values=["content", "event"]
+ )
msg_handler = self.message_handler
- data = yield msg_handler.get_room_data(
+ data = await msg_handler.get_room_data(
user_id=requester.user.to_string(),
room_id=room_id,
event_type=event_type,
@@ -145,19 +173,19 @@ class RoomStateEventRestServlet(TransactionRestServlet):
)
if not data:
- raise SynapseError(
- 404, "Event not found.", errcode=Codes.NOT_FOUND
- )
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
if format == "event":
event = format_event_for_client_v2(data.get_dict())
- defer.returnValue((200, event))
+ return 200, event
elif format == "content":
- defer.returnValue((200, data.get_dict()["content"]))
+ return 200, data.get_dict()["content"]
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+ requester = await self.auth.get_user_by_req(request)
+
+ if txn_id:
+ set_tag("txn_id", txn_id)
content = parse_json_object_from_request(request)
@@ -173,7 +201,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = yield self.room_member_handler.update_membership(
+ event = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -181,21 +209,19 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content,
)
else:
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- event_dict,
- txn_id=txn_id,
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict, txn_id=txn_id
)
- ret = {}
+ ret = {} # type: dict
if event:
+ set_tag("event_id", event.event_id)
ret = {"event_id": event.event_id}
- defer.returnValue((200, ret))
+ return 200, ret
# TODO: Needs unit testing for generic events + feedback
class RoomSendEventRestServlet(TransactionRestServlet):
-
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
@@ -203,12 +229,11 @@ class RoomSendEventRestServlet(TransactionRestServlet):
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
- PATTERNS = ("/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)")
+ PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
register_txn_path(self, PATTERNS, http_server, with_get=True)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_type, txn_id=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, room_id, event_type, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
event_dict = {
@@ -218,21 +243,22 @@ class RoomSendEventRestServlet(TransactionRestServlet):
"sender": requester.user.to_string(),
}
- if b'ts' in request.args and requester.app_service:
- event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+ if b"ts" in request.args and requester.app_service:
+ event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- event_dict,
- txn_id=txn_id,
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict, txn_id=txn_id
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ set_tag("event_id", event.event_id)
+ return 200, {"event_id": event.event_id}
def on_GET(self, request, room_id, event_type, txn_id):
- return (200, "Not implemented")
+ return 200, "Not implemented"
def on_PUT(self, request, room_id, event_type, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_type, txn_id
)
@@ -247,15 +273,11 @@ class JoinRoomAliasServlet(TransactionRestServlet):
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
- PATTERNS = ("/join/(?P<room_identifier>[^/]*)")
+ PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_identifier, txn_id=None):
- requester = yield self.auth.get_user_by_req(
- request,
- allow_guest=True,
- )
+ async def on_POST(self, request, room_identifier, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
try:
content = parse_json_object_from_request(request)
@@ -268,21 +290,21 @@ class JoinRoomAliasServlet(TransactionRestServlet):
room_id = room_identifier
try:
remote_room_hosts = [
- x.decode('ascii') for x in request.args[b"server_name"]
- ]
+ x.decode("ascii") for x in request.args[b"server_name"]
+ ] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
+ room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
- raise SynapseError(400, "%s was not legal room ID or room alias" % (
- room_identifier,
- ))
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@@ -293,9 +315,11 @@ class JoinRoomAliasServlet(TransactionRestServlet):
third_party_signed=content.get("third_party_signed", None),
)
- defer.returnValue((200, {"room_id": room_id}))
+ return 200, {"room_id": room_id}
def on_PUT(self, request, room_identifier, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_identifier, txn_id
)
@@ -310,13 +334,12 @@ class PublicRoomListRestServlet(TransactionRestServlet):
self.hs = hs
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
server = parse_string(request, "server", default=None)
try:
- yield self.auth.get_user_by_req(request, allow_guest=True)
- except AuthError as e:
+ await self.auth.get_user_by_req(request, allow_guest=True)
+ except InvalidClientCredentialsError as e:
# Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private
# federations.
@@ -336,29 +359,29 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None)
+ if limit == 0:
+ # zero is a special value which corresponds to no limit.
+ limit = None
+
handler = self.hs.get_room_list_handler()
if server:
- data = yield handler.get_remote_public_room_list(
- server,
- limit=limit,
- since_token=since_token,
+ data = await handler.get_remote_public_room_list(
+ server, limit=limit, since_token=since_token
)
else:
- data = yield handler.get_local_public_room_list(
- limit=limit,
- since_token=since_token,
+ data = await handler.get_local_public_room_list(
+ limit=limit, since_token=since_token
)
- defer.returnValue((200, data))
+ return 200, data
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)
- limit = int(content.get("limit", 100))
+ limit = int(content.get("limit", 100)) # type: Optional[int]
since_token = content.get("since", None)
search_filter = content.get("filter", None)
@@ -376,9 +399,13 @@ class PublicRoomListRestServlet(TransactionRestServlet):
else:
network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+ if limit == 0:
+ # zero is a special value which corresponds to no limit.
+ limit = None
+
handler = self.hs.get_room_list_handler()
if server:
- data = yield handler.get_remote_public_room_list(
+ data = await handler.get_remote_public_room_list(
server,
limit=limit,
since_token=since_token,
@@ -387,14 +414,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
third_party_instance_id=third_party_instance_id,
)
else:
- data = yield handler.get_local_public_room_list(
+ data = await handler.get_local_public_room_list(
limit=limit,
since_token=since_token,
search_filter=search_filter,
network_tuple=network_tuple,
)
- defer.returnValue((200, data))
+ return 200, data
# TODO: Needs unit testing
@@ -406,10 +433,9 @@ class RoomMemberListRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
+ async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
handler = self.message_handler
# request the state as of a given event, as identified by a stream token,
@@ -429,7 +455,7 @@ class RoomMemberListRestServlet(RestServlet):
membership = parse_string(request, "membership")
not_membership = parse_string(request, "not_membership")
- events = yield handler.get_state_events(
+ events = await handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
at_token=at_token,
@@ -439,16 +465,13 @@ class RoomMemberListRestServlet(RestServlet):
chunk = []
for event in events:
- if (
- (membership and event['content'].get("membership") != membership) or
- (not_membership and event['content'].get("membership") == not_membership)
+ if (membership and event["content"].get("membership") != membership) or (
+ not_membership and event["content"].get("membership") == not_membership
):
continue
chunk.append(event)
- defer.returnValue((200, {
- "chunk": chunk
- }))
+ return 200, {"chunk": chunk}
# deprecated in favour of /members?membership=join?
@@ -461,17 +484,14 @@ class JoinedRoomMemberListRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- users_with_profile = yield self.message_handler.get_joined_members(
- requester, room_id,
+ users_with_profile = await self.message_handler.get_joined_members(
+ requester, room_id
)
- defer.returnValue((200, {
- "joined": users_with_profile,
- }))
+ return 200, {"joined": users_with_profile}
# TODO: Needs better unit testing
@@ -483,22 +503,24 @@ class RoomMessageListRestServlet(RestServlet):
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = PaginationConfig.from_request(
- request, default_limit=10,
- )
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ pagination_config = PaginationConfig.from_request(request, default_limit=10)
as_client_event = b"raw" not in request.args
filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
- event_filter = Filter(json.loads(filter_json))
- if event_filter.filter_json.get("event_format", "client") == "federation":
+ event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
+ if (
+ event_filter
+ and event_filter.filter_json.get("event_format", "client")
+ == "federation"
+ ):
as_client_event = False
else:
event_filter = None
- msgs = yield self.pagination_handler.get_messages(
+
+ msgs = await self.pagination_handler.get_messages(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
@@ -506,7 +528,7 @@ class RoomMessageListRestServlet(RestServlet):
event_filter=event_filter,
)
- defer.returnValue((200, msgs))
+ return 200, msgs
# TODO: Needs unit testing
@@ -518,16 +540,15 @@ class RoomStateRestServlet(RestServlet):
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
# Get all the current state for this room
- events = yield self.message_handler.get_state_events(
+ events = await self.message_handler.get_state_events(
room_id=room_id,
user_id=requester.user.to_string(),
is_guest=requester.is_guest,
)
- defer.returnValue((200, events))
+ return 200, events
# TODO: Needs unit testing
@@ -539,16 +560,13 @@ class RoomInitialSyncRestServlet(RestServlet):
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request)
- content = yield self.initial_sync_handler.room_initial_sync(
- room_id=room_id,
- requester=requester,
- pagin_config=pagination_config,
+ content = await self.initial_sync_handler.room_initial_sync(
+ room_id=room_id, requester=requester, pagin_config=pagination_config
)
- defer.returnValue((200, content))
+ return 200, content
class RoomEventServlet(RestServlet):
@@ -563,17 +581,24 @@ class RoomEventServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- event = yield self.event_handler.get_event(requester.user, room_id, event_id)
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ try:
+ event = await self.event_handler.get_event(
+ requester.user, room_id, event_id
+ )
+ except AuthError:
+ # This endpoint is supposed to return a 404 when the requester does
+ # not have permission to access the event
+ # https://matrix.org/docs/spec/client_server/r0.5.0#get-matrix-client-r0-rooms-roomid-event-eventid
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
if event:
- event = yield self._event_serializer.serialize_event(event, time_now)
- defer.returnValue((200, event))
- else:
- defer.returnValue((404, "Event not found."))
+ event = await self._event_serializer.serialize_event(event, time_now)
+ return 200, event
+
+ return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
class RoomEventContextServlet(RestServlet):
@@ -588,9 +613,8 @@ class RoomEventContextServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
limit = parse_integer(request, "limit", default=10)
@@ -598,38 +622,32 @@ class RoomEventContextServlet(RestServlet):
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes)
- event_filter = Filter(json.loads(filter_json))
+ event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
else:
event_filter = None
- results = yield self.room_context_handler.get_event_context(
- requester.user,
- room_id,
- event_id,
- limit,
- event_filter,
+ results = await self.room_context_handler.get_event_context(
+ requester.user, room_id, event_id, limit, event_filter
)
if not results:
- raise SynapseError(
- 404, "Event not found.", errcode=Codes.NOT_FOUND
- )
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
- results["events_before"] = yield self._event_serializer.serialize_events(
- results["events_before"], time_now,
+ results["events_before"] = await self._event_serializer.serialize_events(
+ results["events_before"], time_now
)
- results["event"] = yield self._event_serializer.serialize_event(
- results["event"], time_now,
+ results["event"] = await self._event_serializer.serialize_event(
+ results["event"], time_now
)
- results["events_after"] = yield self._event_serializer.serialize_events(
- results["events_after"], time_now,
+ results["events_after"] = await self._event_serializer.serialize_events(
+ results["events_after"], time_now
)
- results["state"] = yield self._event_serializer.serialize_events(
- results["state"], time_now,
+ results["state"] = await self._event_serializer.serialize_events(
+ results["state"], time_now
)
- defer.returnValue((200, results))
+ return 200, results
class RoomForgetRestServlet(TransactionRestServlet):
@@ -639,24 +657,19 @@ class RoomForgetRestServlet(TransactionRestServlet):
self.auth = hs.get_auth()
def register(self, http_server):
- PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
+ PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, txn_id=None):
- requester = yield self.auth.get_user_by_req(
- request,
- allow_guest=False,
- )
+ async def on_POST(self, request, room_id, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
- yield self.room_member_handler.forget(
- user=requester.user,
- room_id=room_id,
- )
+ await self.room_member_handler.forget(user=requester.user, room_id=room_id)
- defer.returnValue((200, {}))
+ return 200, {}
def on_PUT(self, request, room_id, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, txn_id
)
@@ -664,7 +677,6 @@ class RoomForgetRestServlet(TransactionRestServlet):
# TODO: Needs unit testing
class RoomMembershipRestServlet(TransactionRestServlet):
-
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
@@ -672,20 +684,18 @@ class RoomMembershipRestServlet(TransactionRestServlet):
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
- PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
- "(?P<membership_action>join|invite|leave|ban|unban|kick)")
+ PATTERNS = (
+ "/rooms/(?P<room_id>[^/]*)/"
+ "(?P<membership_action>join|invite|leave|ban|unban|kick)"
+ )
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, membership_action, txn_id=None):
- requester = yield self.auth.get_user_by_req(
- request,
- allow_guest=True,
- )
+ async def on_POST(self, request, room_id, membership_action, txn_id=None):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
if requester.is_guest and membership_action not in {
Membership.JOIN,
- Membership.LEAVE
+ Membership.LEAVE,
}:
raise AuthError(403, "Guest access not allowed")
@@ -697,7 +707,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- yield self.room_member_handler.do_3pid_invite(
+ await self.room_member_handler.do_3pid_invite(
room_id,
requester.user,
content["medium"],
@@ -706,9 +716,9 @@ class RoomMembershipRestServlet(TransactionRestServlet):
requester,
txn_id,
new_room=False,
+ id_access_token=content.get("id_access_token"),
)
- defer.returnValue((200, {}))
- return
+ return 200, {}
target = requester.user
if membership_action in ["invite", "ban", "unban", "kick"]:
@@ -716,10 +726,10 @@ class RoomMembershipRestServlet(TransactionRestServlet):
target = UserID.from_string(content["user_id"])
event_content = None
- if 'reason' in content and membership_action in ['kick', 'ban']:
- event_content = {'reason': content['reason']}
+ if "reason" in content:
+ event_content = {"reason": content["reason"]}
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
@@ -734,7 +744,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
if membership_action == "join":
return_value["room_id"] = room_id
- defer.returnValue((200, return_value))
+ return 200, return_value
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address"}:
@@ -743,6 +753,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return True
def on_PUT(self, request, room_id, membership_action, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, membership_action, txn_id
)
@@ -756,15 +768,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
self.auth = hs.get_auth()
def register(self, http_server):
- PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
+ PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id, txn_id=None):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, event_id, txn_id=None):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
@@ -776,9 +787,12 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
txn_id=txn_id,
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ set_tag("event_id", event.event_id)
+ return 200, {"event_id": event.event_id}
def on_PUT(self, request, room_id, event_id, txn_id):
+ set_tag("txn_id", txn_id)
+
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, room_id, event_id, txn_id
)
@@ -795,35 +809,55 @@ class RoomTypingRestServlet(RestServlet):
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
room_id = urlparse.unquote(room_id)
target_user = UserID.from_string(urlparse.unquote(user_id))
content = parse_json_object_from_request(request)
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
# Limit timeout to stop people from setting silly typing timeouts.
timeout = min(content.get("timeout", 30000), 120000)
if content["typing"]:
- yield self.typing_handler.started_typing(
+ await self.typing_handler.started_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,
timeout=timeout,
)
else:
- yield self.typing_handler.stopped_typing(
- target_user=target_user,
- auth_user=requester.user,
- room_id=room_id,
+ await self.typing_handler.stopped_typing(
+ target_user=target_user, auth_user=requester.user, room_id=room_id
)
- defer.returnValue((200, {}))
+ return 200, {}
+
+
+class RoomAliasListServlet(RestServlet):
+ PATTERNS = [
+ re.compile(
+ r"^/_matrix/client/unstable/org\.matrix\.msc2432"
+ r"/rooms/(?P<room_id>[^/]*)/aliases"
+ ),
+ ]
+
+ def __init__(self, hs: "synapse.server.HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.directory_handler = hs.get_handlers().directory_handler
+
+ async def on_GET(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+
+ alias_list = await self.directory_handler.get_aliases_for_room(
+ requester, room_id
+ )
+
+ return 200, {"aliases": alias_list}
class SearchRestServlet(RestServlet):
@@ -834,20 +868,17 @@ class SearchRestServlet(RestServlet):
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
batch = parse_string(request, "next_batch")
- results = yield self.handlers.search_handler.search(
- requester.user,
- content,
- batch,
+ results = await self.handlers.search_handler.search(
+ requester.user, content, batch
)
- defer.returnValue((200, results))
+ return 200, results
class JoinedRoomsRestServlet(RestServlet):
@@ -858,12 +889,11 @@ class JoinedRoomsRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
- defer.returnValue((200, {"joined_rooms": list(room_ids)}))
+ room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
+ return 200, {"joined_rooms": list(room_ids)}
def register_txn_path(servlet, regex_string, http_server, with_get=False):
@@ -882,18 +912,21 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
http_server.register_paths(
"POST",
client_patterns(regex_string + "$", v1=True),
- servlet.on_POST
+ servlet.on_POST,
+ servlet.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_PUT
+ servlet.on_PUT,
+ servlet.__class__.__name__,
)
if with_get:
http_server.register_paths(
"GET",
client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- servlet.on_GET
+ servlet.on_GET,
+ servlet.__class__.__name__,
)
@@ -915,6 +948,7 @@ def register_servlets(hs, http_server):
JoinedRoomsRestServlet(hs).register(http_server)
RoomEventServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)
+ RoomAliasListServlet(hs).register(http_server)
def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 6381049210..747d46eac2 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -17,8 +17,6 @@ import base64
import hashlib
import hmac
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -31,11 +29,9 @@ class VoipRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(
- request,
- self.hs.config.turn_allow_guests
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(
+ request, self.hs.config.turn_allow_guests
)
turnUris = self.hs.config.turn_uris
@@ -49,9 +45,7 @@ class VoipRestServlet(RestServlet):
username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(
- turnSecret.encode(),
- msg=username.encode(),
- digestmod=hashlib.sha1
+ turnSecret.encode(), msg=username.encode(), digestmod=hashlib.sha1
)
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
@@ -63,17 +57,20 @@ class VoipRestServlet(RestServlet):
password = turnPassword
else:
- defer.returnValue((200, {}))
-
- defer.returnValue((200, {
- 'username': username,
- 'password': password,
- 'ttl': userLifetime / 1000,
- 'uris': turnUris,
- }))
+ return 200, {}
+
+ return (
+ 200,
+ {
+ "username": username,
+ "password": password,
+ "ttl": userLifetime / 1000,
+ "uris": turnUris,
+ },
+ )
def on_OPTIONS(self, request):
- return (200, {})
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 5236d5d566..bc11b4dda4 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -32,11 +32,12 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
Args:
path_regex (str): The regex string to match. This should NOT have a ^
- as this will be prefixed.
+ as this will be prefixed.
Returns:
SRE_Pattern
"""
patterns = []
+
if unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
@@ -46,17 +47,18 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
+
return patterns
def set_timeline_upper_limit(filter_json, filter_timeline_limit):
if filter_timeline_limit < 0:
return # no upper limits
- timeline = filter_json.get('room', {}).get('timeline', {})
- if 'limit' in timeline:
- filter_json['room']['timeline']["limit"] = min(
- filter_json['room']['timeline']['limit'],
- filter_timeline_limit)
+ timeline = filter_json.get("room", {}).get("timeline", {})
+ if "limit" in timeline:
+ filter_json["room"]["timeline"]["limit"] = min(
+ filter_json["room"]["timeline"]["limit"], filter_timeline_limit
+ )
def interactive_auth_handler(orig):
@@ -74,10 +76,12 @@ def interactive_auth_handler(orig):
# ...
yield self.auth_handler.check_auth
"""
+
def wrapped(*args, **kwargs):
- res = defer.maybeDeferred(orig, *args, **kwargs)
+ res = defer.ensureDeferred(orig(*args, **kwargs))
res.addErrback(_catch_incomplete_interactive_auth)
return res
+
return wrapped
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 974465e90c..7d2cd29a60 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -19,12 +19,11 @@ import re
from six.moves import http_client
-import jinja2
-
from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
+from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -32,9 +31,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.push.mailer import Mailer, load_jinja2_templates
from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
-from synapse.util.stringutils import assert_valid_client_secret, random_string
+from synapse.util.stringutils import assert_valid_client_secret
from synapse.util.threepids import check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
@@ -52,30 +52,37 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.config = hs.config
self.identity_handler = hs.get_handlers().identity_handler
- if self.config.email_password_reset_behaviour == "local":
- from synapse.push.mailer import Mailer, load_jinja2_templates
- templates = load_jinja2_templates(
- config=hs.config,
- template_html_name=hs.config.email_password_reset_template_html,
- template_text_name=hs.config.email_password_reset_template_text,
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ template_html, template_text = load_jinja2_templates(
+ self.config.email_template_dir,
+ [
+ self.config.email_password_reset_template_html,
+ self.config.email_password_reset_template_text,
+ ],
+ apply_format_ts_filter=True,
+ apply_mxc_to_http_filter=True,
+ public_baseurl=self.config.public_baseurl,
)
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email_app_name,
- template_html=templates[0],
- template_text=templates[1],
+ template_html=template_html,
+ template_text=template_text,
)
- @defer.inlineCallbacks
- def on_POST(self, request):
- if self.config.email_password_reset_behaviour == "off":
- raise SynapseError(400, "Password resets have been disabled on this server")
+ async def on_POST(self, request):
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "User password resets have been disabled due to lack of email config"
+ )
+ raise SynapseError(
+ 400, "Email-based password resets have been disabled on this server"
+ )
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'client_secret', 'email', 'send_attempt'
- ])
+ assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
# Extract params from body
client_secret = body["client_secret"]
@@ -92,151 +99,45 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
- 'email', email,
+ existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+ "email", email
)
- if existingUid is None:
+ if existing_user_id is None:
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
- if self.config.email_password_reset_behaviour == "remote":
- if 'id_server' not in body:
- raise SynapseError(400, "Missing 'id_server' param in body")
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ assert self.hs.config.account_threepid_delegate_email
- # Have the identity server handle the password reset flow
- ret = yield self.identity_handler.requestEmailToken(
- body["id_server"], email, client_secret, send_attempt, next_link,
+ # Have the configured identity server handle the request
+ ret = await self.identity_handler.requestEmailToken(
+ self.hs.config.account_threepid_delegate_email,
+ email,
+ client_secret,
+ send_attempt,
+ next_link,
)
else:
# Send password reset emails from Synapse
- sid = yield self.send_password_reset(
- email, client_secret, send_attempt, next_link,
+ sid = await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_password_reset_mail,
+ next_link,
)
# Wrap the session id in a JSON object
ret = {"sid": sid}
- defer.returnValue((200, ret))
-
- @defer.inlineCallbacks
- def send_password_reset(
- self,
- email,
- client_secret,
- send_attempt,
- next_link=None,
- ):
- """Send a password reset email
-
- Args:
- email (str): The user's email address
- client_secret (str): The provided client secret
- send_attempt (int): Which send attempt this is
-
- Returns:
- The new session_id upon success
-
- Raises:
- SynapseError is an error occurred when sending the email
- """
- # Check that this email/client_secret/send_attempt combo is new or
- # greater than what we've seen previously
- session = yield self.datastore.get_threepid_validation_session(
- "email", client_secret, address=email, validated=False,
- )
-
- # Check to see if a session already exists and that it is not yet
- # marked as validated
- if session and session.get("validated_at") is None:
- session_id = session['session_id']
- last_send_attempt = session['last_send_attempt']
-
- # Check that the send_attempt is higher than previous attempts
- if send_attempt <= last_send_attempt:
- # If not, just return a success without sending an email
- defer.returnValue(session_id)
- else:
- # An non-validated session does not exist yet.
- # Generate a session id
- session_id = random_string(16)
-
- # Generate a new validation token
- token = random_string(32)
-
- # Send the mail with the link containing the token, client_secret
- # and session_id
- try:
- yield self.mailer.send_password_reset_mail(
- email, token, client_secret, session_id,
- )
- except Exception:
- logger.exception(
- "Error sending a password reset email to %s", email,
- )
- raise SynapseError(
- 500, "An error was encountered when sending the password reset email"
- )
-
- token_expires = (self.hs.clock.time_msec() +
- self.config.email_validation_token_lifetime)
-
- yield self.datastore.start_or_continue_validation_session(
- "email", email, session_id, client_secret, send_attempt,
- next_link, token, token_expires,
- )
-
- defer.returnValue(session_id)
-
-
-class MsisdnPasswordRequestTokenRestServlet(RestServlet):
- PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
-
- def __init__(self, hs):
- super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
- self.hs = hs
- self.datastore = self.hs.get_datastore()
- self.identity_handler = hs.get_handlers().identity_handler
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- if not self.config.email_password_reset_behaviour == "off":
- raise SynapseError(400, "Password resets have been disabled on this server")
-
- body = parse_json_object_from_request(request)
-
- assert_params_in_dict(body, [
- 'id_server', 'client_secret',
- 'country', 'phone_number', 'send_attempt',
- ])
-
- msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
-
- if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
- raise SynapseError(
- 403,
- "Account phone numbers are not authorized on this server",
- Codes.THREEPID_DENIED,
- )
-
- assert_valid_client_secret(body["client_secret"])
-
- existingUid = yield self.datastore.get_user_id_by_threepid(
- 'msisdn', msisdn
- )
-
- if existingUid is None:
- raise SynapseError(400, "MSISDN not found", Codes.THREEPID_NOT_FOUND)
-
- ret = yield self.identity_handler.requestMsisdnToken(**body)
- defer.returnValue((200, ret))
+ return 200, ret
class PasswordResetSubmitTokenServlet(RestServlet):
"""Handles 3PID validation token submission"""
+
PATTERNS = client_patterns(
- "/password_reset/(?P<medium>[^/]*)/submit_token/*$",
- releases=(),
- unstable=True,
+ "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
)
def __init__(self, hs):
@@ -249,105 +150,64 @@ class PasswordResetSubmitTokenServlet(RestServlet):
self.auth = hs.get_auth()
self.config = hs.config
self.clock = hs.get_clock()
- self.datastore = hs.get_datastore()
+ self.store = hs.get_datastore()
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ (self.failure_email_template,) = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_password_reset_template_failure_html],
+ )
- @defer.inlineCallbacks
- def on_GET(self, request, medium):
+ async def on_GET(self, request, medium):
+ # We currently only handle threepid token submissions for email
if medium != "email":
raise SynapseError(
- 400,
- "This medium is currently not supported for password resets",
+ 400, "This medium is currently not supported for password resets"
+ )
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "Password reset emails have been disabled due to lack of an email config"
+ )
+ raise SynapseError(
+ 400, "Email-based password resets are disabled on this server"
)
- sid = parse_string(request, "sid")
- client_secret = parse_string(request, "client_secret")
-
+ sid = parse_string(request, "sid", required=True)
+ token = parse_string(request, "token", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
assert_valid_client_secret(client_secret)
- token = parse_string(request, "token")
-
- # Attempt to validate a 3PID sesssion
+ # Attempt to validate a 3PID session
try:
# Mark the session as valid
- next_link = yield self.datastore.validate_threepid_session(
- sid,
- client_secret,
- token,
- self.clock.time_msec(),
+ next_link = await self.store.validate_threepid_session(
+ sid, client_secret, token, self.clock.time_msec()
)
# Perform a 302 redirect if next_link is set
if next_link:
if next_link.startswith("file:///"):
- logger.warn(
+ logger.warning(
"Not redirecting to next_link as it is a local file: address"
)
else:
request.setResponseCode(302)
request.setHeader("Location", next_link)
finish_request(request)
- defer.returnValue(None)
+ return None
# Otherwise show the success template
- html = self.config.email_password_reset_success_html_content
+ html = self.config.email_password_reset_template_success_html
request.setResponseCode(200)
except ThreepidValidationError as e:
- # Show a failure page with a reason
- html = self.load_jinja2_template(
- self.config.email_template_dir,
- self.config.email_password_reset_failure_template,
- template_vars={
- "failure_reason": e.msg,
- }
- )
request.setResponseCode(e.code)
- request.write(html.encode('utf-8'))
- finish_request(request)
- defer.returnValue(None)
-
- def load_jinja2_template(self, template_dir, template_filename, template_vars):
- """Loads a jinja2 template with variables to insert
-
- Args:
- template_dir (str): The directory where templates are stored
- template_filename (str): The name of the template in the template_dir
- template_vars (Dict): Dictionary of keys in the template
- alongside their values to insert
-
- Returns:
- str containing the contents of the rendered template
- """
- loader = jinja2.FileSystemLoader(template_dir)
- env = jinja2.Environment(loader=loader)
-
- template = env.get_template(template_filename)
- return template.render(**template_vars)
-
- @defer.inlineCallbacks
- def on_POST(self, request, medium):
- if medium != "email":
- raise SynapseError(
- 400,
- "This medium is currently not supported for password resets",
- )
-
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'sid', 'client_secret', 'token',
- ])
-
- assert_valid_client_secret(body["client_secret"])
-
- valid, _ = yield self.datastore.validate_threepid_session(
- body['sid'],
- body['client_secret'],
- body['token'],
- self.clock.time_msec(),
- )
- response_code = 200 if valid else 400
+ # Show a failure page with a reason
+ template_vars = {"failure_reason": e.msg}
+ html = self.failure_email_template.render(**template_vars)
- defer.returnValue((response_code, {"success": valid}))
+ request.write(html.encode("utf-8"))
+ finish_request(request)
class PasswordRestServlet(RestServlet):
@@ -363,8 +223,7 @@ class PasswordRestServlet(RestServlet):
self.http_client = hs.get_simple_http_client()
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
# there are two possibilities here. Either the user does not have an
@@ -378,35 +237,33 @@ class PasswordRestServlet(RestServlet):
# In the second case, we require a password to confirm their identity.
if self.auth.has_access_token(request):
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
# blindly trust ASes without UI-authing them
if requester.app_service:
params = body
else:
- params = yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request),
+ params = await self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request)
)
user_id = requester.user.to_string()
else:
requester = None
- result, params, _ = yield self.auth_handler.check_auth(
- [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]],
- body, self.hs.get_ip_from_request(request),
- password_servlet=True,
+ result, params, _ = await self.auth_handler.check_auth(
+ [[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
)
if LoginType.EMAIL_IDENTITY in result:
threepid = result[LoginType.EMAIL_IDENTITY]
- if 'medium' not in threepid or 'address' not in threepid:
+ if "medium" not in threepid or "address" not in threepid:
raise SynapseError(500, "Malformed threepid")
- if threepid['medium'] == 'email':
+ if threepid["medium"] == "email":
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
- threepid['address'] = threepid['address'].lower()
+ threepid["address"] = threepid["address"].lower()
# if using email, we must know about the email they're authing with!
- threepid_user_id = yield self.datastore.get_user_id_by_threepid(
- threepid['medium'], threepid['address']
+ threepid_user_id = await self.datastore.get_user_id_by_threepid(
+ threepid["medium"], threepid["address"]
)
if not threepid_user_id:
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
@@ -416,10 +273,11 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(500, "", Codes.UNKNOWN)
assert_params_in_dict(params, ["new_password"])
- new_password = params['new_password']
+ new_password = params["new_password"]
+ logout_devices = params.get("logout_devices", True)
- yield self._set_password_handler.set_password(
- user_id, new_password, requester
+ await self._set_password_handler.set_password(
+ user_id, new_password, logout_devices, requester
)
if self.hs.config.shadow_server:
@@ -428,7 +286,7 @@ class PasswordRestServlet(RestServlet):
)
self.shadow_password(params, shadow_user.to_string())
- defer.returnValue((200, {}))
+ return 200, {}
def on_OPTIONS(self, _):
return 200, {}
@@ -440,10 +298,9 @@ class PasswordRestServlet(RestServlet):
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.post_json_get_json(
- "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s" % (
- shadow_hs_url, as_token, user_id,
- ),
- body
+ "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
)
@@ -458,8 +315,7 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@@ -469,50 +325,75 @@ class DeactivateAccountRestServlet(RestServlet):
Codes.BAD_JSON,
)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
# allow ASes to dectivate their own users
if requester.app_service:
- yield self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase,
+ await self._deactivate_account_handler.deactivate_account(
+ requester.user.to_string(), erase
)
- defer.returnValue((200, {}))
+ return 200, {}
- yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request),
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request)
)
- result = yield self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase,
- id_server=body.get("id_server"),
+ result = await self._deactivate_account_handler.deactivate_account(
+ requester.user.to_string(), erase, id_server=body.get("id_server")
)
if result:
id_server_unbind_result = "success"
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {
- "id_server_unbind_result": id_server_unbind_result,
- }))
+ return 200, {"id_server_unbind_result": id_server_unbind_result}
class EmailThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
- self.hs = hs
super(EmailThreepidRequestTokenRestServlet, self).__init__()
+ self.hs = hs
+ self.config = hs.config
self.identity_handler = hs.get_handlers().identity_handler
- self.datastore = self.hs.get_datastore()
+ self.store = self.hs.get_datastore()
+
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ template_html, template_text = load_jinja2_templates(
+ self.config.email_template_dir,
+ [
+ self.config.email_add_threepid_template_html,
+ self.config.email_add_threepid_template_text,
+ ],
+ public_baseurl=self.config.public_baseurl,
+ )
+ self.mailer = Mailer(
+ hs=self.hs,
+ app_name=self.config.email_app_name,
+ template_html=template_html,
+ template_text=template_text,
+ )
+
+ async def on_POST(self, request):
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "Adding emails have been disabled due to lack of an email config"
+ )
+ raise SynapseError(
+ 400, "Adding an email to your account is disabled on this server"
+ )
- @defer.inlineCallbacks
- def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_dict(
- body,
- ['id_server', 'client_secret', 'email', 'send_attempt'],
- )
+ assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
+ client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
- if not (yield check_3pid_allowed(self.hs, "email", body['email'])):
+ email = body["email"]
+ send_attempt = body["send_attempt"]
+ next_link = body.get("next_link") # Optional param
+
+ if not (await check_3pid_allowed(self.hs, "email", email)):
raise SynapseError(
403,
"Your email is not authorized on this server",
@@ -521,15 +402,38 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
assert_valid_client_secret(body["client_secret"])
- existingUid = yield self.datastore.get_user_id_by_threepid(
- 'email', body['email']
+ existing_user_id = await self.store.get_user_id_by_threepid(
+ "email", body["email"]
)
- if existingUid is not None:
+ if existing_user_id is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- ret = yield self.identity_handler.requestEmailToken(**body)
- defer.returnValue((200, ret))
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ assert self.hs.config.account_threepid_delegate_email
+
+ # Have the configured identity server handle the request
+ ret = await self.identity_handler.requestEmailToken(
+ self.hs.config.account_threepid_delegate_email,
+ email,
+ client_secret,
+ send_attempt,
+ next_link,
+ )
+ else:
+ # Send threepid validation emails from Synapse
+ sid = await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_add_threepid_mail,
+ next_link,
+ )
+
+ # Wrap the session id in a JSON object
+ ret = {"sid": sid}
+
+ return 200, ret
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@@ -538,20 +442,25 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
def __init__(self, hs):
self.hs = hs
super(MsisdnThreepidRequestTokenRestServlet, self).__init__()
+ self.store = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler
- self.datastore = self.hs.get_datastore()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'id_server', 'client_secret',
- 'country', 'phone_number', 'send_attempt',
- ])
+ assert_params_in_dict(
+ body, ["client_secret", "country", "phone_number", "send_attempt"]
+ )
+ client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
- msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ country = body["country"]
+ phone_number = body["phone_number"]
+ send_attempt = body["send_attempt"]
+ next_link = body.get("next_link") # Optional param
- if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
+ msisdn = phone_number_to_msisdn(country, phone_number)
+
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Account phone numbers are not authorized on this server",
@@ -560,15 +469,149 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
assert_valid_client_secret(body["client_secret"])
- existingUid = yield self.datastore.get_user_id_by_threepid(
- 'msisdn', msisdn
- )
+ existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
- if existingUid is not None:
+ if existing_user_id is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
- ret = yield self.identity_handler.requestMsisdnToken(**body)
- defer.returnValue((200, ret))
+ if not self.hs.config.account_threepid_delegate_msisdn:
+ logger.warning(
+ "No upstream msisdn account_threepid_delegate configured on the server to "
+ "handle this request"
+ )
+ raise SynapseError(
+ 400,
+ "Adding phone numbers to user account is not supported by this homeserver",
+ )
+
+ ret = await self.identity_handler.requestMsisdnToken(
+ self.hs.config.account_threepid_delegate_msisdn,
+ country,
+ phone_number,
+ client_secret,
+ send_attempt,
+ next_link,
+ )
+
+ return 200, ret
+
+
+class AddThreepidEmailSubmitTokenServlet(RestServlet):
+ """Handles 3PID validation token submission for adding an email to a user's account"""
+
+ PATTERNS = client_patterns(
+ "/add_threepid/email/submit_token$", releases=(), unstable=True
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.config = hs.config
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ (self.failure_email_template,) = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_add_threepid_template_failure_html],
+ )
+
+ async def on_GET(self, request):
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "Adding emails have been disabled due to lack of an email config"
+ )
+ raise SynapseError(
+ 400, "Adding an email to your account is disabled on this server"
+ )
+ elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ raise SynapseError(
+ 400,
+ "This homeserver is not validating threepids. Use an identity server "
+ "instead.",
+ )
+
+ sid = parse_string(request, "sid", required=True)
+ token = parse_string(request, "token", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
+
+ # Attempt to validate a 3PID session
+ try:
+ # Mark the session as valid
+ next_link = await self.store.validate_threepid_session(
+ sid, client_secret, token, self.clock.time_msec()
+ )
+
+ # Perform a 302 redirect if next_link is set
+ if next_link:
+ if next_link.startswith("file:///"):
+ logger.warning(
+ "Not redirecting to next_link as it is a local file: address"
+ )
+ else:
+ request.setResponseCode(302)
+ request.setHeader("Location", next_link)
+ finish_request(request)
+ return None
+
+ # Otherwise show the success template
+ html = self.config.email_add_threepid_template_success_html_content
+ request.setResponseCode(200)
+ except ThreepidValidationError as e:
+ request.setResponseCode(e.code)
+
+ # Show a failure page with a reason
+ template_vars = {"failure_reason": e.msg}
+ html = self.failure_email_template.render(**template_vars)
+
+ request.write(html.encode("utf-8"))
+ finish_request(request)
+
+
+class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
+ """Handles 3PID validation token submission for adding a phone number to a user's
+ account
+ """
+
+ PATTERNS = client_patterns(
+ "/add_threepid/msisdn/submit_token$", releases=(), unstable=True
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super().__init__()
+ self.config = hs.config
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+ self.identity_handler = hs.get_handlers().identity_handler
+
+ async def on_POST(self, request):
+ if not self.config.account_threepid_delegate_msisdn:
+ raise SynapseError(
+ 400,
+ "This homeserver is not validating phone numbers. Use an identity server "
+ "instead.",
+ )
+
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ["client_secret", "sid", "token"])
+ assert_valid_client_secret(body["client_secret"])
+
+ # Proxy submit_token request to msisdn threepid delegate
+ response = await self.identity_handler.proxy_msisdn_submit_token(
+ self.config.account_threepid_delegate_msisdn,
+ body["client_secret"],
+ body["sid"],
+ body["token"],
+ )
+ return 200, response
class ThreepidRestServlet(RestServlet):
@@ -583,74 +626,147 @@ class ThreepidRestServlet(RestServlet):
self.datastore = hs.get_datastore()
self.http_client = hs.get_simple_http_client()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
- threepids = yield self.datastore.user_get_threepids(
- requester.user.to_string()
- )
+ threepids = await self.datastore.user_get_threepids(requester.user.to_string())
- defer.returnValue((200, {'threepids': threepids}))
+ return 200, {"threepids": threepids}
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if self.hs.config.disable_3pid_changes:
raise SynapseError(400, "3PID changes disabled on this server")
- body = parse_json_object_from_request(request)
-
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
# skip validation if this is a shadow 3PID from an AS
- if not requester.app_service:
- threePidCreds = body.get('threePidCreds')
- threePidCreds = body.get('three_pid_creds', threePidCreds)
- if threePidCreds is None:
- raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
-
- threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
-
- if not threepid:
- raise SynapseError(
- 400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
- )
-
- for reqd in ['medium', 'address', 'validated_at']:
- if reqd not in threepid:
- logger.warn("Couldn't add 3pid: invalid response from ID server")
- raise SynapseError(500, "Invalid response from ID Server")
- else:
+ if requester.app_service:
# XXX: ASes pass in a validated threepid directly to bypass the IS.
# This makes the API entirely change shape when we have an AS token;
# it really should be an entirely separate API - perhaps
# /account/3pid/replicate or something.
- threepid = body.get('threepid')
+ threepid = body.get("threepid")
- yield self.auth_handler.add_threepid(
- user_id,
- threepid['medium'],
- threepid['address'],
- threepid['validated_at'],
- )
+ await self.auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
+
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
- if not requester.app_service and ('bind' in body and body['bind']):
- logger.debug(
- "Binding threepid %s to %s",
- threepid, user_id
+ return 200, {}
+
+ threepid_creds = body.get("threePidCreds") or body.get("three_pid_creds")
+ if threepid_creds is None:
+ raise SynapseError(
+ 400, "Missing param three_pid_creds", Codes.MISSING_PARAM
)
- yield self.identity_handler.bind_threepid(
- threePidCreds, user_id
+ assert_params_in_dict(threepid_creds, ["client_secret", "sid"])
+
+ sid = threepid_creds["sid"]
+ client_secret = threepid_creds["client_secret"]
+ assert_valid_client_secret(client_secret)
+
+ validation_session = await self.identity_handler.validate_threepid_session(
+ client_secret, sid
+ )
+ if validation_session:
+ await self.auth_handler.add_threepid(
+ user_id,
+ validation_session["medium"],
+ validation_session["address"],
+ validation_session["validated_at"],
)
- if self.hs.config.shadow_server:
- shadow_user = UserID(
- requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+
+ return 200, {}
+
+ raise SynapseError(
+ 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
+ )
+
+ @defer.inlineCallbacks
+ def shadow_3pid(self, body, user_id):
+ # TODO: retries
+ shadow_hs_url = self.hs.config.shadow_server.get("hs_url")
+ as_token = self.hs.config.shadow_server.get("as_token")
+
+ yield self.http_client.post_json_get_json(
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
+ )
+
+
+class ThreepidAddRestServlet(RestServlet):
+ PATTERNS = client_patterns("/account/3pid/add$", releases=(), unstable=True)
+
+ def __init__(self, hs):
+ super(ThreepidAddRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
+ self.http_client = hs.get_simple_http_client()
+
+ @interactive_auth_handler
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["client_secret", "sid"])
+ sid = body["sid"]
+ client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
+
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request)
+ )
+
+ validation_session = await self.identity_handler.validate_threepid_session(
+ client_secret, sid
+ )
+ if validation_session:
+ await self.auth_handler.add_threepid(
+ user_id,
+ validation_session["medium"],
+ validation_session["address"],
+ validation_session["validated_at"],
)
- self.shadow_3pid({'threepid': threepid}, shadow_user.to_string())
+ if self.hs.config.shadow_server:
+ shadow_user = UserID(
+ requester.user.localpart, self.hs.config.shadow_server.get("hs")
+ )
+ threepid = {
+ "medium": validation_session["medium"],
+ "address": validation_session["address"],
+ "validated_at": validation_session["validated_at"],
+ }
+ self.shadow_3pid({"threepid": threepid}, shadow_user.to_string())
+ return 200, {}
- defer.returnValue((200, {}))
+ raise SynapseError(
+ 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED
+ )
@defer.inlineCallbacks
def shadow_3pid(self, body, user_id):
@@ -659,13 +775,72 @@ class ThreepidRestServlet(RestServlet):
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.post_json_get_json(
- "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" % (
- shadow_hs_url, as_token, user_id,
- ),
- body
+ "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
)
+class ThreepidBindRestServlet(RestServlet):
+ PATTERNS = client_patterns("/account/3pid/bind$", releases=(), unstable=True)
+
+ def __init__(self, hs):
+ super(ThreepidBindRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request):
+ body = parse_json_object_from_request(request)
+
+ assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
+ id_server = body["id_server"]
+ sid = body["sid"]
+ id_access_token = body.get("id_access_token") # optional
+ client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
+
+ requester = await self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+
+ await self.identity_handler.bind_threepid(
+ client_secret, sid, user_id, id_server, id_access_token
+ )
+
+ return 200, {}
+
+
+class ThreepidUnbindRestServlet(RestServlet):
+ PATTERNS = client_patterns("/account/3pid/unbind$", releases=(), unstable=True)
+
+ def __init__(self, hs):
+ super(ThreepidUnbindRestServlet, self).__init__()
+ self.hs = hs
+ self.identity_handler = hs.get_handlers().identity_handler
+ self.auth = hs.get_auth()
+ self.datastore = self.hs.get_datastore()
+
+ async def on_POST(self, request):
+ """Unbind the given 3pid from a specific identity server, or identity servers that are
+ known to have this 3pid bound
+ """
+ requester = await self.auth.get_user_by_req(request)
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ["medium", "address"])
+
+ medium = body.get("medium")
+ address = body.get("address")
+ id_server = body.get("id_server")
+
+ # Attempt to unbind the threepid from an identity server. If id_server is None, try to
+ # unbind from all identity servers this threepid has been added to in the past
+ result = await self.identity_handler.try_unbind_threepid(
+ requester.user.to_string(),
+ {"address": address, "medium": medium, "id_server": id_server},
+ )
+ return 200, {"id_server_unbind_result": "success" if result else "no-support"}
+
+
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/delete$")
@@ -676,20 +851,19 @@ class ThreepidDeleteRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.http_client = hs.get_simple_http_client()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if self.hs.config.disable_3pid_changes:
raise SynapseError(400, "3PID changes disabled on this server")
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ['medium', 'address'])
+ assert_params_in_dict(body, ["medium", "address"])
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
try:
- ret = yield self.auth_handler.delete_threepid(
- user_id, body['medium'], body['address'], body.get("id_server"),
+ ret = await self.auth_handler.delete_threepid(
+ user_id, body["medium"], body["address"], body.get("id_server")
)
except Exception:
# NB. This endpoint should succeed if there is nothing to
@@ -709,9 +883,7 @@ class ThreepidDeleteRestServlet(RestServlet):
else:
id_server_unbind_result = "no-support"
- defer.returnValue((200, {
- "id_server_unbind_result": id_server_unbind_result,
- }))
+ return 200, {"id_server_unbind_result": id_server_unbind_result}
@defer.inlineCallbacks
def shadow_3pid_delete(self, body, user_id):
@@ -720,10 +892,9 @@ class ThreepidDeleteRestServlet(RestServlet):
as_token = self.hs.config.shadow_server.get("as_token")
yield self.http_client.post_json_get_json(
- "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s" % (
- shadow_hs_url, as_token, user_id
- ),
- body
+ "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s"
+ % (shadow_hs_url, as_token, user_id),
+ body,
)
@@ -753,7 +924,7 @@ class ThreepidLookupRestServlet(RestServlet):
# Proxy the request to the identity server. lookup_3pid handles checking
# if the lookup is allowed so we don't need to do it here.
- ret = yield self.identity_handler.lookup_3pid(id_server, medium, address)
+ ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
defer.returnValue((200, ret))
@@ -779,8 +950,8 @@ class ThreepidBulkLookupRestServlet(RestServlet):
# Proxy the request to the identity server. lookup_3pid handles checking
# if the lookup is allowed so we don't need to do it here.
- ret = yield self.identity_handler.bulk_lookup_3pid(
- body["id_server"], body["threepids"],
+ ret = yield self.identity_handler.proxy_bulk_lookup_3pid(
+ body["id_server"], body["threepids"]
)
defer.returnValue((200, ret))
@@ -793,22 +964,25 @@ class WhoamiRestServlet(RestServlet):
super(WhoamiRestServlet, self).__init__()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
- defer.returnValue((200, {'user_id': requester.user.to_string()}))
+ return 200, {"user_id": requester.user.to_string()}
def register_servlets(hs, http_server):
EmailPasswordRequestTokenRestServlet(hs).register(http_server)
- MsisdnPasswordRequestTokenRestServlet(hs).register(http_server)
PasswordResetSubmitTokenServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
EmailThreepidRequestTokenRestServlet(hs).register(http_server)
MsisdnThreepidRequestTokenRestServlet(hs).register(http_server)
+ AddThreepidEmailSubmitTokenServlet(hs).register(http_server)
+ AddThreepidMsisdnSubmitTokenServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
+ ThreepidAddRestServlet(hs).register(http_server)
+ ThreepidBindRestServlet(hs).register(http_server)
+ ThreepidUnbindRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)
ThreepidLookupRestServlet(hs).register(http_server)
ThreepidBulkLookupRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 17b967d363..ddb011d864 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
@@ -31,6 +29,7 @@ class AccountDataServlet(RestServlet):
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
"""
+
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
@@ -40,11 +39,14 @@ class AccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
+ self._is_worker = hs.config.worker_app is not None
self._profile_handler = hs.get_profile_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id, account_data_type):
+ if self._is_worker:
+ raise Exception("Cannot handle PUT /account_data on worker")
+
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -52,33 +54,30 @@ class AccountDataServlet(RestServlet):
if account_data_type == "im.vector.hide_profile":
user = UserID.from_string(user_id)
- hide_profile = body.get('hide_profile')
- yield self._profile_handler.set_active(user, not hide_profile, True)
+ hide_profile = body.get("hide_profile")
+ await self._profile_handler.set_active(user, not hide_profile, True)
- max_id = yield self.store.add_account_data_for_user(
+ max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
- self.notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
- )
+ self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return 200, {}
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id, account_data_type):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
- event = yield self.store.get_global_account_data_by_type_for_user(
- account_data_type, user_id,
+ event = await self.store.get_global_account_data_by_type_for_user(
+ account_data_type, user_id
)
if event is None:
raise NotFoundError("Account data not found")
- defer.returnValue((200, event))
+ return 200, event
class RoomAccountDataServlet(RestServlet):
@@ -86,6 +85,7 @@ class RoomAccountDataServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
"""
+
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)"
@@ -97,10 +97,13 @@ class RoomAccountDataServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
+ self._is_worker = hs.config.worker_app is not None
+
+ async def on_PUT(self, request, user_id, room_id, account_data_type):
+ if self._is_worker:
+ raise Exception("Cannot handle PUT /account_data on worker")
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id, room_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -110,33 +113,30 @@ class RoomAccountDataServlet(RestServlet):
raise SynapseError(
405,
"Cannot set m.fully_read through this API."
- " Use /rooms/!roomId:server.name/read_markers"
+ " Use /rooms/!roomId:server.name/read_markers",
)
- max_id = yield self.store.add_account_data_to_room(
+ max_id = await self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
- self.notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
- )
+ self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return 200, {}
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, room_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id, room_id, account_data_type):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
- event = yield self.store.get_account_data_for_room_and_type(
- user_id, room_id, account_data_type,
+ event = await self.store.get_account_data_for_room_and_type(
+ user_id, room_id, account_data_type
)
if event is None:
raise NotFoundError("Room account data not found")
- defer.returnValue((200, event))
+ return 200, event
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index 8091b78285..2f10fa64e2 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet
@@ -28,7 +26,9 @@ logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
PATTERNS = client_patterns("/account_validity/renew$")
- SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
+ SUCCESS_HTML = (
+ b"<html><body>Your account has been successfully renewed.</body><html>"
+ )
def __init__(self, hs):
"""
@@ -43,13 +43,12 @@ class AccountValidityRenewServlet(RestServlet):
self.success_html = hs.config.account_validity.account_renewed_html_content
self.failure_html = hs.config.account_validity.invalid_token_html_content
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
- token_valid = yield self.account_activity_handler.renew_account(
+ token_valid = await self.account_activity_handler.renew_account(
renewal_token.decode("utf8")
)
@@ -65,7 +64,6 @@ class AccountValidityRenewServlet(RestServlet):
request.setHeader(b"Content-Length", b"%d" % (len(response),))
request.write(response.encode("utf8"))
finish_request(request)
- defer.returnValue(None)
class AccountValiditySendMailServlet(RestServlet):
@@ -83,16 +81,17 @@ class AccountValiditySendMailServlet(RestServlet):
self.auth = hs.get_auth()
self.account_validity = self.hs.config.account_validity
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if not self.account_validity.renew_by_email_enabled:
- raise AuthError(403, "Account renewal via email is disabled on this server.")
+ raise AuthError(
+ 403, "Account renewal via email is disabled on this server."
+ )
- requester = yield self.auth.get_user_by_req(request, allow_expired=True)
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
- yield self.account_activity_handler.send_renewal_email_to_user(user_id)
+ await self.account_activity_handler.send_renewal_email_to_user(user_id)
- defer.returnValue((200, {}))
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8dfe5cba02..50e080673b 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
@@ -122,6 +120,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth.
"""
+
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
@@ -138,11 +137,10 @@ class AuthRestServlet(RestServlet):
if stagetype == LoginType.RECAPTCHA:
html = RECAPTCHA_TEMPLATE % {
- 'session': session,
- 'myurl': "%s/r0/auth/%s/fallback/web" % (
- CLIENT_API_PREFIX, LoginType.RECAPTCHA
- ),
- 'sitekey': self.hs.config.recaptcha_public_key,
+ "session": session,
+ "myurl": "%s/r0/auth/%s/fallback/web"
+ % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
+ "sitekey": self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
@@ -154,14 +152,11 @@ class AuthRestServlet(RestServlet):
return None
elif stagetype == LoginType.TERMS:
html = TERMS_TEMPLATE % {
- 'session': session,
- 'terms_url': "%s_matrix/consent?v=%s" % (
- self.hs.config.public_baseurl,
- self.hs.config.user_consent_version,
- ),
- 'myurl': "%s/r0/auth/%s/fallback/web" % (
- CLIENT_API_PREFIX, LoginType.TERMS
- ),
+ "session": session,
+ "terms_url": "%s_matrix/consent?v=%s"
+ % (self.hs.config.public_baseurl, self.hs.config.user_consent_version),
+ "myurl": "%s/r0/auth/%s/fallback/web"
+ % (CLIENT_API_PREFIX, LoginType.TERMS),
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
@@ -174,8 +169,7 @@ class AuthRestServlet(RestServlet):
else:
raise SynapseError(404, "Unknown auth stage type")
- @defer.inlineCallbacks
- def on_POST(self, request, stagetype):
+ async def on_POST(self, request, stagetype):
session = parse_string(request, "session")
if not session:
@@ -187,26 +181,20 @@ class AuthRestServlet(RestServlet):
if not response:
raise SynapseError(400, "No captcha response supplied")
- authdict = {
- 'response': response,
- 'session': session,
- }
+ authdict = {"response": response, "session": session}
- success = yield self.auth_handler.add_oob_auth(
- LoginType.RECAPTCHA,
- authdict,
- self.hs.get_ip_from_request(request)
+ success = await self.auth_handler.add_oob_auth(
+ LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
)
if success:
html = SUCCESS_TEMPLATE
else:
html = RECAPTCHA_TEMPLATE % {
- 'session': session,
- 'myurl': "%s/r0/auth/%s/fallback/web" % (
- CLIENT_API_PREFIX, LoginType.RECAPTCHA
- ),
- 'sitekey': self.hs.config.recaptcha_public_key,
+ "session": session,
+ "myurl": "%s/r0/auth/%s/fallback/web"
+ % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
+ "sitekey": self.hs.config.recaptcha_public_key,
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
@@ -216,33 +204,26 @@ class AuthRestServlet(RestServlet):
request.write(html_bytes)
finish_request(request)
- defer.returnValue(None)
+ return None
elif stagetype == LoginType.TERMS:
- if ('session' not in request.args or
- len(request.args['session'])) == 0:
- raise SynapseError(400, "No session supplied")
-
- session = request.args['session'][0]
- authdict = {'session': session}
+ authdict = {"session": session}
- success = yield self.auth_handler.add_oob_auth(
- LoginType.TERMS,
- authdict,
- self.hs.get_ip_from_request(request)
+ success = await self.auth_handler.add_oob_auth(
+ LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
)
if success:
html = SUCCESS_TEMPLATE
else:
html = TERMS_TEMPLATE % {
- 'session': session,
- 'terms_url': "%s_matrix/consent?v=%s" % (
+ "session": session,
+ "terms_url": "%s_matrix/consent?v=%s"
+ % (
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
- 'myurl': "%s/r0/auth/%s/fallback/web" % (
- CLIENT_API_PREFIX, LoginType.TERMS
- ),
+ "myurl": "%s/r0/auth/%s/fallback/web"
+ % (CLIENT_API_PREFIX, LoginType.TERMS),
}
html_bytes = html.encode("utf8")
request.setResponseCode(200)
@@ -251,7 +232,7 @@ class AuthRestServlet(RestServlet):
request.write(html_bytes)
finish_request(request)
- defer.returnValue(None)
+ return None
else:
raise SynapseError(404, "Unknown auth stage type")
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index fc7e2f4dd5..fe9d019c44 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -14,8 +14,6 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
@@ -40,10 +38,9 @@ class CapabilitiesRestServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- user = yield self.store.get_user_by_id(requester.user.to_string())
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user = await self.store.get_user_by_id(requester.user.to_string())
change_password = bool(user["password_hash"])
response = {
@@ -58,7 +55,7 @@ class CapabilitiesRestServlet(RestServlet):
"m.change_password": {"enabled": change_password},
}
}
- defer.returnValue((200, response))
+ return 200, response
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 78665304a5..94ff73f384 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api import errors
from synapse.http.servlet import (
RestServlet,
@@ -42,13 +40,12 @@ class DevicesRestServlet(RestServlet):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- devices = yield self.device_handler.get_devices_by_user(
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
)
- defer.returnValue((200, {"devices": devices}))
+ return 200, {"devices": devices}
class DeleteDevicesRestServlet(RestServlet):
@@ -56,6 +53,7 @@ class DeleteDevicesRestServlet(RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
+
PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs):
@@ -66,9 +64,8 @@ class DeleteDevicesRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
try:
body = parse_json_object_from_request(request)
@@ -83,15 +80,14 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
- yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request),
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request)
)
- yield self.device_handler.delete_devices(
- requester.user.to_string(),
- body['devices'],
+ await self.device_handler.delete_devices(
+ requester.user.to_string(), body["devices"]
)
- defer.returnValue((200, {}))
+ return 200, {}
class DeviceRestServlet(RestServlet):
@@ -108,19 +104,16 @@ class DeviceRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- device = yield self.device_handler.get_device(
- requester.user.to_string(),
- device_id,
+ async def on_GET(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ device = await self.device_handler.get_device(
+ requester.user.to_string(), device_id
)
- defer.returnValue((200, device))
+ return 200, device
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_DELETE(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request)
try:
body = parse_json_object_from_request(request)
@@ -133,26 +126,21 @@ class DeviceRestServlet(RestServlet):
else:
raise
- yield self.auth_handler.validate_user_via_ui_auth(
- requester, body, self.hs.get_ip_from_request(request),
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request)
)
- yield self.device_handler.delete_device(
- requester.user.to_string(), device_id,
- )
- defer.returnValue((200, {}))
+ await self.device_handler.delete_device(requester.user.to_string(), device_id)
+ return 200, {}
- @defer.inlineCallbacks
- def on_PUT(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_PUT(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request)
- yield self.device_handler.update_device(
- requester.user.to_string(),
- device_id,
- body
+ await self.device_handler.update_device(
+ requester.user.to_string(), device_id, body
)
- defer.returnValue((200, {}))
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 65db48c3cc..b28da017cd 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -15,9 +15,7 @@
import logging
-from twisted.internet import defer
-
-from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
+from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
@@ -35,10 +33,9 @@ class GetFilterRestServlet(RestServlet):
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, filter_id):
+ async def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if target_user != requester.user:
raise AuthError(403, "Cannot get filters for other users")
@@ -52,14 +49,15 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id")
try:
- filter = yield self.filtering.get_user_filter(
- user_localpart=target_user.localpart,
- filter_id=filter_id,
+ filter_collection = await self.filtering.get_user_filter(
+ user_localpart=target_user.localpart, filter_id=filter_id
)
+ except StoreError as e:
+ if e.code != 404:
+ raise
+ raise NotFoundError("No such filter")
- defer.returnValue((200, filter.get_filter_json()))
- except (KeyError, StoreError):
- raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
+ return 200, filter_collection.get_filter_json()
class CreateFilterRestServlet(RestServlet):
@@ -71,11 +69,10 @@ class CreateFilterRestServlet(RestServlet):
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- @defer.inlineCallbacks
- def on_POST(self, request, user_id):
+ async def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if target_user != requester.user:
raise AuthError(403, "Cannot create filters for other users")
@@ -84,17 +81,13 @@ class CreateFilterRestServlet(RestServlet):
raise AuthError(403, "Can only create filters for local users")
content = parse_json_object_from_request(request)
- set_timeline_upper_limit(
- content,
- self.hs.config.filter_timeline_limit
- )
+ set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
- filter_id = yield self.filtering.add_user_filter(
- user_localpart=target_user.localpart,
- user_filter=content,
+ filter_id = await self.filtering.add_user_filter(
+ user_localpart=target_user.localpart, user_filter=content
)
- defer.returnValue((200, {"filter_id": str(filter_id)}))
+ return 200, {"filter_id": str(filter_id)}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index d082385ec7..d84a6d7e11 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
@@ -29,6 +27,7 @@ logger = logging.getLogger(__name__)
class GroupServlet(RestServlet):
"""Get the group profile
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs):
@@ -37,34 +36,32 @@ class GroupServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- group_description = yield self.groups_handler.get_group_profile(
- group_id,
- requester_user_id,
+ group_description = await self.groups_handler.get_group_profile(
+ group_id, requester_user_id
)
- defer.returnValue((200, group_description))
+ return 200, group_description
- @defer.inlineCallbacks
- def on_POST(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- yield self.groups_handler.update_group_profile(
- group_id, requester_user_id, content,
+ await self.groups_handler.update_group_profile(
+ group_id, requester_user_id, content
)
- defer.returnValue((200, {}))
+ return 200, {}
class GroupSummaryServlet(RestServlet):
"""Get the full group summary
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs):
@@ -73,17 +70,15 @@ class GroupSummaryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- get_group_summary = yield self.groups_handler.get_group_summary(
- group_id,
- requester_user_id,
+ get_group_summary = await self.groups_handler.get_group_summary(
+ group_id, requester_user_id
)
- defer.returnValue((200, get_group_summary))
+ return 200, get_group_summary
class GroupSummaryRoomsCatServlet(RestServlet):
@@ -93,6 +88,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
- /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?"
@@ -105,38 +101,36 @@ class GroupSummaryRoomsCatServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, category_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, category_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_summary_room(
- group_id, requester_user_id,
+ resp = await self.groups_handler.update_group_summary_room(
+ group_id,
+ requester_user_id,
room_id=room_id,
category_id=category_id,
content=content,
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, category_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, category_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_summary_room(
- group_id, requester_user_id,
- room_id=room_id,
- category_id=category_id,
+ resp = await self.groups_handler.delete_group_summary_room(
+ group_id, requester_user_id, room_id=room_id, category_id=category_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class GroupCategoryServlet(RestServlet):
"""Get/add/update/delete a group category
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
@@ -147,51 +141,43 @@ class GroupCategoryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id, category_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id, category_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_category(
- group_id, requester_user_id,
- category_id=category_id,
+ category = await self.groups_handler.get_group_category(
+ group_id, requester_user_id, category_id=category_id
)
- defer.returnValue((200, category))
+ return 200, category
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, category_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, category_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_category(
- group_id, requester_user_id,
- category_id=category_id,
- content=content,
+ resp = await self.groups_handler.update_group_category(
+ group_id, requester_user_id, category_id=category_id, content=content
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, category_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, category_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_category(
- group_id, requester_user_id,
- category_id=category_id,
+ resp = await self.groups_handler.delete_group_category(
+ group_id, requester_user_id, category_id=category_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class GroupCategoriesServlet(RestServlet):
"""Get all group categories
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/categories/$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
def __init__(self, hs):
super(GroupCategoriesServlet, self).__init__()
@@ -199,24 +185,22 @@ class GroupCategoriesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_categories(
- group_id, requester_user_id,
+ category = await self.groups_handler.get_group_categories(
+ group_id, requester_user_id
)
- defer.returnValue((200, category))
+ return 200, category
class GroupRoleServlet(RestServlet):
"""Get/add/update/delete a group role
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
def __init__(self, hs):
super(GroupRoleServlet, self).__init__()
@@ -224,51 +208,43 @@ class GroupRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id, role_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id, role_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_role(
- group_id, requester_user_id,
- role_id=role_id,
+ category = await self.groups_handler.get_group_role(
+ group_id, requester_user_id, role_id=role_id
)
- defer.returnValue((200, category))
+ return 200, category
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, role_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, role_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_role(
- group_id, requester_user_id,
- role_id=role_id,
- content=content,
+ resp = await self.groups_handler.update_group_role(
+ group_id, requester_user_id, role_id=role_id, content=content
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, role_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, role_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_role(
- group_id, requester_user_id,
- role_id=role_id,
+ resp = await self.groups_handler.delete_group_role(
+ group_id, requester_user_id, role_id=role_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class GroupRolesServlet(RestServlet):
"""Get all group roles
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/roles/$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
def __init__(self, hs):
super(GroupRolesServlet, self).__init__()
@@ -276,16 +252,15 @@ class GroupRolesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_roles(
- group_id, requester_user_id,
+ category = await self.groups_handler.get_group_roles(
+ group_id, requester_user_id
)
- defer.returnValue((200, category))
+ return 200, category
class GroupSummaryUsersRoleServlet(RestServlet):
@@ -295,6 +270,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
- /groups/:group/summary/users/:room_id
- /groups/:group/summary/roles/:role/users/:user_id
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?"
@@ -307,38 +283,36 @@ class GroupSummaryUsersRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, role_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, role_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_summary_user(
- group_id, requester_user_id,
+ resp = await self.groups_handler.update_group_summary_user(
+ group_id,
+ requester_user_id,
user_id=user_id,
role_id=role_id,
content=content,
)
- defer.returnValue((200, resp))
+ return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, role_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, role_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_summary_user(
- group_id, requester_user_id,
- user_id=user_id,
- role_id=role_id,
+ resp = await self.groups_handler.delete_group_summary_user(
+ group_id, requester_user_id, user_id=user_id, role_id=role_id
)
- defer.returnValue((200, resp))
+ return 200, resp
class GroupRoomServlet(RestServlet):
"""Get all rooms in a group
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs):
@@ -347,19 +321,21 @@ class GroupRoomServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_rooms_in_group(group_id, requester_user_id)
+ result = await self.groups_handler.get_rooms_in_group(
+ group_id, requester_user_id
+ )
- defer.returnValue((200, result))
+ return 200, result
class GroupUsersServlet(RestServlet):
"""Get all users in a group
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs):
@@ -368,19 +344,21 @@ class GroupUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_users_in_group(group_id, requester_user_id)
+ result = await self.groups_handler.get_users_in_group(
+ group_id, requester_user_id
+ )
- defer.returnValue((200, result))
+ return 200, result
class GroupInvitedUsersServlet(RestServlet):
"""Get users invited to a group
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs):
@@ -389,22 +367,21 @@ class GroupInvitedUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_invited_users_in_group(
- group_id,
- requester_user_id,
+ result = await self.groups_handler.get_invited_users_in_group(
+ group_id, requester_user_id
)
- defer.returnValue((200, result))
+ return 200, result
class GroupSettingJoinPolicyServlet(RestServlet):
"""Set group join policy
"""
+
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs):
@@ -412,25 +389,23 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.set_group_join_policy(
- group_id,
- requester_user_id,
- content,
+ result = await self.groups_handler.set_group_join_policy(
+ group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return 200, result
class GroupCreateServlet(RestServlet):
"""Create a group
"""
+
PATTERNS = client_patterns("/create_group$")
def __init__(self, hs):
@@ -440,9 +415,8 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
# TODO: Create group on remote server
@@ -450,18 +424,17 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string()
- result = yield self.groups_handler.create_group(
- group_id,
- requester_user_id,
- content,
+ result = await self.groups_handler.create_group(
+ group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return 200, result
class GroupAdminRoomsServlet(RestServlet):
"""Add a room to the group
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
@@ -472,33 +445,32 @@ class GroupAdminRoomsServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.add_room_to_group(
- group_id, requester_user_id, room_id, content,
+ result = await self.groups_handler.add_room_to_group(
+ group_id, requester_user_id, room_id, content
)
- defer.returnValue((200, result))
+ return 200, result
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.remove_room_from_group(
- group_id, requester_user_id, room_id,
+ result = await self.groups_handler.remove_room_from_group(
+ group_id, requester_user_id, room_id
)
- defer.returnValue((200, result))
+ return 200, result
class GroupAdminRoomsConfigServlet(RestServlet):
"""Update the config of a room in a group
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$"
@@ -510,22 +482,22 @@ class GroupAdminRoomsConfigServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, room_id, config_key):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, room_id, config_key):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.update_room_in_group(
- group_id, requester_user_id, room_id, config_key, content,
+ result = await self.groups_handler.update_room_in_group(
+ group_id, requester_user_id, room_id, config_key, content
)
- defer.returnValue((200, result))
+ return 200, result
class GroupAdminUsersInviteServlet(RestServlet):
"""Invite a user to the group
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
@@ -538,23 +510,23 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
config = content.get("config", {})
- result = yield self.groups_handler.invite(
- group_id, user_id, requester_user_id, config,
+ result = await self.groups_handler.invite(
+ group_id, user_id, requester_user_id, config
)
- defer.returnValue((200, result))
+ return 200, result
class GroupAdminUsersKickServlet(RestServlet):
"""Kick a user from the group
"""
+
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
@@ -565,25 +537,23 @@ class GroupAdminUsersKickServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.remove_user_from_group(
- group_id, user_id, requester_user_id, content,
+ result = await self.groups_handler.remove_user_from_group(
+ group_id, user_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return 200, result
class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/self/leave$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
def __init__(self, hs):
super(GroupSelfLeaveServlet, self).__init__()
@@ -591,25 +561,23 @@ class GroupSelfLeaveServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.remove_user_from_group(
- group_id, requester_user_id, requester_user_id, content,
+ result = await self.groups_handler.remove_user_from_group(
+ group_id, requester_user_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return 200, result
class GroupSelfJoinServlet(RestServlet):
"""Attempt to join a group, or knock
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/self/join$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
def __init__(self, hs):
super(GroupSelfJoinServlet, self).__init__()
@@ -617,25 +585,23 @@ class GroupSelfJoinServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.join_group(
- group_id, requester_user_id, content,
+ result = await self.groups_handler.join_group(
+ group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return 200, result
class GroupSelfAcceptInviteServlet(RestServlet):
"""Accept a group invite
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/self/accept_invite$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
def __init__(self, hs):
super(GroupSelfAcceptInviteServlet, self).__init__()
@@ -643,25 +609,23 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.accept_invite(
- group_id, requester_user_id, content,
+ result = await self.groups_handler.accept_invite(
+ group_id, requester_user_id, content
)
- defer.returnValue((200, result))
+ return 200, result
class GroupSelfUpdatePublicityServlet(RestServlet):
"""Update whether we publicise a users membership of a group
"""
- PATTERNS = client_patterns(
- "/groups/(?P<group_id>[^/]*)/self/update_publicity$"
- )
+
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
def __init__(self, hs):
super(GroupSelfUpdatePublicityServlet, self).__init__()
@@ -669,26 +633,22 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
publicise = content["publicise"]
- yield self.store.update_group_publicity(
- group_id, requester_user_id, publicise,
- )
+ await self.store.update_group_publicity(group_id, requester_user_id, publicise)
- defer.returnValue((200, {}))
+ return 200, {}
class PublicisedGroupsForUserServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
- PATTERNS = client_patterns(
- "/publicised_groups/(?P<user_id>[^/]*)$"
- )
+
+ PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
def __init__(self, hs):
super(PublicisedGroupsForUserServlet, self).__init__()
@@ -697,23 +657,19 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, user_id):
+ await self.auth.get_user_by_req(request, allow_guest=True)
- result = yield self.groups_handler.get_publicised_groups_for_user(
- user_id
- )
+ result = await self.groups_handler.get_publicised_groups_for_user(user_id)
- defer.returnValue((200, result))
+ return 200, result
class PublicisedGroupsForUsersServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
- PATTERNS = client_patterns(
- "/publicised_groups$"
- )
+
+ PATTERNS = client_patterns("/publicised_groups$")
def __init__(self, hs):
super(PublicisedGroupsForUsersServlet, self).__init__()
@@ -722,26 +678,22 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
user_ids = content["user_ids"]
- result = yield self.groups_handler.bulk_get_publicised_groups(
- user_ids
- )
+ result = await self.groups_handler.bulk_get_publicised_groups(user_ids)
- defer.returnValue((200, result))
+ return 200, result
class GroupsForUserServlet(RestServlet):
"""Get all groups the logged in user is joined to
"""
- PATTERNS = client_patterns(
- "/joined_groups$"
- )
+
+ PATTERNS = client_patterns("/joined_groups$")
def __init__(self, hs):
super(GroupsForUserServlet, self).__init__()
@@ -749,14 +701,13 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_joined_groups(requester_user_id)
+ result = await self.groups_handler.get_joined_groups(requester_user_id)
- defer.returnValue((200, result))
+ return 200, result
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 4cbfbf5631..f7ed4daf90 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -24,9 +23,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.types import StreamToken
-from ._base import client_patterns
+from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -56,6 +56,7 @@ class KeyUploadServlet(RestServlet):
},
}
"""
+
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
@@ -67,33 +68,42 @@ class KeyUploadServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ @trace(opname="upload_keys")
+ async def on_POST(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
if device_id is not None:
# passing the device_id here is deprecated; however, we allow it
# for now for compatibility with older clients.
- if (requester.device_id is not None and
- device_id != requester.device_id):
- logger.warning("Client uploading keys for a different device "
- "(logged in as %s, uploading for %s)",
- requester.device_id, device_id)
+ if requester.device_id is not None and device_id != requester.device_id:
+ set_tag("error", True)
+ log_kv(
+ {
+ "message": "Client uploading keys for a different device",
+ "logged_in_id": requester.device_id,
+ "key_being_uploaded": device_id,
+ }
+ )
+ logger.warning(
+ "Client uploading keys for a different device "
+ "(logged in as %s, uploading for %s)",
+ requester.device_id,
+ device_id,
+ )
else:
device_id = requester.device_id
if device_id is None:
raise SynapseError(
- 400,
- "To upload keys, you must pass device_id when authenticating"
+ 400, "To upload keys, you must pass device_id when authenticating"
)
- result = yield self.e2e_keys_handler.upload_keys_for_user(
+ result = await self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body
)
- defer.returnValue((200, result))
+ return 200, result
class KeyQueryServlet(RestServlet):
@@ -141,13 +151,13 @@ class KeyQueryServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user_id = requester.user.to_string()
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.e2e_keys_handler.query_devices(body, timeout)
- defer.returnValue((200, result))
+ result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
+ return 200, result
class KeyChangesServlet(RestServlet):
@@ -159,6 +169,7 @@ class KeyChangesServlet(RestServlet):
200 OK
{ "changed": ["@foo:example.com"] }
"""
+
PATTERNS = client_patterns("/keys/changes$")
def __init__(self, hs):
@@ -170,25 +181,23 @@ class KeyChangesServlet(RestServlet):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from")
+ set_tag("from", from_token_string)
# We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before.
- parse_string(request, "to")
+ set_tag("to", parse_string(request, "to"))
from_token = StreamToken.from_string(from_token_string)
user_id = requester.user.to_string()
- results = yield self.device_handler.get_user_ids_changed(
- user_id, from_token,
- )
+ results = await self.device_handler.get_user_ids_changed(user_id, from_token)
- defer.returnValue((200, results))
+ return 200, results
class OneTimeKeyServlet(RestServlet):
@@ -209,6 +218,7 @@ class OneTimeKeyServlet(RestServlet):
} } } }
"""
+
PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs):
@@ -216,16 +226,97 @@ class OneTimeKeyServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.e2e_keys_handler.claim_one_time_keys(
- body,
- timeout,
+ result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
+ return 200, result
+
+
+class SigningKeyUploadServlet(RestServlet):
+ """
+ POST /keys/device_signing/upload HTTP/1.1
+ Content-Type: application/json
+
+ {
+ }
+ """
+
+ PATTERNS = client_patterns("/keys/device_signing/upload$", releases=())
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(SigningKeyUploadServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ self.auth_handler = hs.get_auth_handler()
+
+ @interactive_auth_handler
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
+ user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
+
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request)
+ )
+
+ result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
+ return 200, result
+
+
+class SignaturesUploadServlet(RestServlet):
+ """
+ POST /keys/signatures/upload HTTP/1.1
+ Content-Type: application/json
+
+ {
+ "@alice:example.com": {
+ "<device_id>": {
+ "user_id": "<user_id>",
+ "device_id": "<device_id>",
+ "algorithms": [
+ "m.olm.curve25519-aes-sha256",
+ "m.megolm.v1.aes-sha"
+ ],
+ "keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ "signatures": {
+ "<signing_user_id>": {
+ "<algorithm>:<signing_key_base64>": "<signature_base64>>"
+ }
+ }
+ }
+ }
+ }
+ """
+
+ PATTERNS = client_patterns("/keys/signatures/upload$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(SignaturesUploadServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user_id = requester.user.to_string()
+ body = parse_json_object_from_request(request)
+
+ result = await self.e2e_keys_handler.upload_signatures_for_device_keys(
+ user_id, body
)
- defer.returnValue((200, result))
+ return 200, result
def register_servlets(hs, http_server):
@@ -233,3 +324,5 @@ def register_servlets(hs, http_server):
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)
+ SigningKeyUploadServlet(hs).register(http_server)
+ SignaturesUploadServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index 53e666989b..aa911d75ee 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -35,9 +33,8 @@ class NotificationsServlet(RestServlet):
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
from_token = parse_string(request, "from", required=False)
@@ -46,16 +43,16 @@ class NotificationsServlet(RestServlet):
limit = min(limit, 500)
- push_actions = yield self.store.get_push_actions_for_user(
+ push_actions = await self.store.get_push_actions_for_user(
user_id, from_token, limit, only_highlight=(only == "highlight")
)
- receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
- user_id, 'm.read'
+ receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
+ user_id, "m.read"
)
notif_event_ids = [pa["event_id"] for pa in push_actions]
- notif_events = yield self.store.get_events(notif_event_ids)
+ notif_events = await self.store.get_events(notif_event_ids)
returned_push_actions = []
@@ -67,11 +64,13 @@ class NotificationsServlet(RestServlet):
"profile_tag": pa["profile_tag"],
"actions": pa["actions"],
"ts": pa["received_ts"],
- "event": (yield self._event_serializer.serialize_event(
- notif_events[pa["event_id"]],
- self.clock.time_msec(),
- event_format=format_event_for_client_v2_without_room_id,
- )),
+ "event": (
+ await self._event_serializer.serialize_event(
+ notif_events[pa["event_id"]],
+ self.clock.time_msec(),
+ event_format=format_event_for_client_v2_without_room_id,
+ )
+ ),
}
if pa["room_id"] not in receipts_by_room:
@@ -80,17 +79,13 @@ class NotificationsServlet(RestServlet):
receipt = receipts_by_room[pa["room_id"]]
returned_pa["read"] = (
- receipt["topological_ordering"], receipt["stream_ordering"]
- ) >= (
- pa["topological_ordering"], pa["stream_ordering"]
- )
+ receipt["topological_ordering"],
+ receipt["stream_ordering"],
+ ) >= (pa["topological_ordering"], pa["stream_ordering"])
returned_push_actions.append(returned_pa)
next_token = str(pa["stream_ordering"])
- defer.returnValue((200, {
- "notifications": returned_push_actions,
- "next_token": next_token,
- }))
+ return 200, {"notifications": returned_push_actions, "next_token": next_token}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index bb927d9f9d..6ae9a5a8e9 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string
@@ -56,9 +54,8 @@ class IdTokenServlet(RestServlet):
"expires_in": 3600,
}
"""
- PATTERNS = client_patterns(
- "/user/(?P<user_id>[^/]*)/openid/request_token"
- )
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/openid/request_token")
EXPIRES_MS = 3600 * 1000
@@ -69,9 +66,8 @@ class IdTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
- @defer.inlineCallbacks
- def on_POST(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.")
@@ -82,14 +78,17 @@ class IdTokenServlet(RestServlet):
token = random_string(24)
ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS
- yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
-
- defer.returnValue((200, {
- "access_token": token,
- "token_type": "Bearer",
- "matrix_server_name": self.server_name,
- "expires_in": self.EXPIRES_MS / 1000,
- }))
+ await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
+
+ return (
+ 200,
+ {
+ "access_token": token,
+ "token_type": "Bearer",
+ "matrix_server_name": self.server_name,
+ "expires_in": self.EXPIRES_MS / 1000,
+ },
+ )
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index f4bd0d077f..67cbc37312 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from ._base import client_patterns
@@ -34,32 +32,31 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
body = parse_json_object_from_request(request)
read_event_id = body.get("m.read", None)
if read_event_id:
- yield self.receipts_handler.received_client_receipt(
+ await self.receipts_handler.received_client_receipt(
room_id,
"m.read",
user_id=requester.user.to_string(),
- event_id=read_event_id
+ event_id=read_event_id,
)
read_marker_event_id = body.get("m.fully_read", None)
if read_marker_event_id:
- yield self.read_marker_handler.received_client_read_marker(
+ await self.read_marker_handler.received_client_read_marker(
room_id,
user_id=requester.user.to_string(),
- event_id=read_marker_event_id
+ event_id=read_marker_event_id,
)
- defer.returnValue((200, {}))
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index fa12ac3e4d..92555bd4a9 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
@@ -39,23 +37,19 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler()
self.presence_handler = hs.get_presence_handler()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, receipt_type, event_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, receipt_type, event_id):
+ requester = await self.auth.get_user_by_req(request)
if receipt_type != "m.read":
raise SynapseError(400, "Receipt type must be 'm.read'")
- yield self.presence_handler.bump_presence_active_time(requester.user)
+ await self.presence_handler.bump_presence_active_time(requester.user)
- yield self.receipts_handler.received_client_receipt(
- room_id,
- receipt_type,
- user_id=requester.user.to_string(),
- event_id=event_id
+ await self.receipts_handler.received_client_receipt(
+ room_id, receipt_type, user_id=requester.user.to_string(), event_id=event_id
)
- defer.returnValue((200, {}))
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 3d5a198278..c3c96a9e86 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -18,29 +18,37 @@
import hmac
import logging
import re
-from hashlib import sha1
+from typing import List, Union
from six import string_types
-from twisted.internet import defer
-
import synapse
+import synapse.api.auth
import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
LimitExceededError,
SynapseError,
+ ThreepidValidationError,
UnrecognizedRequestError,
)
+from synapse.config import ConfigError
+from synapse.config.captcha import CaptchaConfig
+from synapse.config.consent_config import ConsentConfig
+from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
+from synapse.handlers.auth import AuthHandler
+from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
parse_string,
)
+from synapse.push.mailer import load_jinja2_templates
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.stringutils import assert_valid_client_secret
@@ -55,6 +63,7 @@ from ._base import client_patterns, interactive_auth_handler
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
+
def compare_digest(a, b):
return a == b
@@ -73,33 +82,88 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
super(EmailRegisterRequestTokenRestServlet, self).__init__()
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
+ self.config = hs.config
+
+ if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ from synapse.push.mailer import Mailer, load_jinja2_templates
+
+ template_html, template_text = load_jinja2_templates(
+ self.config.email_template_dir,
+ [
+ self.config.email_registration_template_html,
+ self.config.email_registration_template_text,
+ ],
+ apply_format_ts_filter=True,
+ apply_mxc_to_http_filter=True,
+ public_baseurl=self.config.public_baseurl,
+ )
+ self.mailer = Mailer(
+ hs=self.hs,
+ app_name=self.config.email_app_name,
+ template_html=template_html,
+ template_text=template_text,
+ )
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
+ if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "Email registration has been disabled due to lack of email config"
+ )
+ raise SynapseError(
+ 400, "Email-based registration has been disabled on this server"
+ )
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'id_server', 'client_secret', 'email', 'send_attempt'
- ])
+ assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
- if not (yield check_3pid_allowed(self.hs, "email", body['email'])):
+ # Extract params from body
+ client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
+
+ email = body["email"]
+ send_attempt = body["send_attempt"]
+ next_link = body.get("next_link") # Optional param
+
+ if not (await check_3pid_allowed(self.hs, "email", body["email"])):
raise SynapseError(
403,
"Your email is not authorized to register on this server",
Codes.THREEPID_DENIED,
)
- assert_params_in_dict(body["client_secret"])
-
- existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
- 'email', body['email']
+ existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+ "email", body["email"]
)
- if existingUid is not None:
+ if existing_user_id is not None:
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- ret = yield self.identity_handler.requestEmailToken(**body)
- defer.returnValue((200, ret))
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
+ assert self.hs.config.account_threepid_delegate_email
+
+ # Have the configured identity server handle the request
+ ret = await self.identity_handler.requestEmailToken(
+ self.hs.config.account_threepid_delegate_email,
+ email,
+ client_secret,
+ send_attempt,
+ next_link,
+ )
+ else:
+ # Send registration emails from Synapse
+ sid = await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_registration_mail,
+ next_link,
+ )
+
+ # Wrap the session id in a JSON object
+ ret = {"sid": sid}
+
+ return 200, ret
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
@@ -114,38 +178,140 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
- assert_params_in_dict(body, [
- 'id_server', 'client_secret',
- 'country', 'phone_number',
- 'send_attempt',
- ])
+ assert_params_in_dict(
+ body, ["client_secret", "country", "phone_number", "send_attempt"]
+ )
+ client_secret = body["client_secret"]
+ country = body["country"]
+ phone_number = body["phone_number"]
+ send_attempt = body["send_attempt"]
+ next_link = body.get("next_link") # Optional param
- msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ msisdn = phone_number_to_msisdn(country, phone_number)
assert_valid_client_secret(body["client_secret"])
- if not (yield check_3pid_allowed(self.hs, "msisdn", msisdn)):
+ if not (await check_3pid_allowed(self.hs, "msisdn", msisdn)):
raise SynapseError(
403,
"Phone numbers are not authorized to register on this server",
Codes.THREEPID_DENIED,
)
- existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
- 'msisdn', msisdn
+ existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+ "msisdn", msisdn
)
- if existingUid is not None:
+ if existing_user_id is not None:
raise SynapseError(
400, "Phone number is already in use", Codes.THREEPID_IN_USE
)
- ret = yield self.identity_handler.requestMsisdnToken(**body)
- defer.returnValue((200, ret))
+ if not self.hs.config.account_threepid_delegate_msisdn:
+ logger.warning(
+ "No upstream msisdn account_threepid_delegate configured on the server to "
+ "handle this request"
+ )
+ raise SynapseError(
+ 400, "Registration by phone number is not supported on this homeserver"
+ )
+
+ ret = await self.identity_handler.requestMsisdnToken(
+ self.hs.config.account_threepid_delegate_msisdn,
+ country,
+ phone_number,
+ client_secret,
+ send_attempt,
+ next_link,
+ )
+
+ return 200, ret
+
+
+class RegistrationSubmitTokenServlet(RestServlet):
+ """Handles registration 3PID validation token submission"""
+
+ PATTERNS = client_patterns(
+ "/registration/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True
+ )
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ super(RegistrationSubmitTokenServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.config = hs.config
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ (self.failure_email_template,) = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_registration_template_failure_html],
+ )
+
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ (self.failure_email_template,) = load_jinja2_templates(
+ self.config.email_template_dir,
+ [self.config.email_registration_template_failure_html],
+ )
+
+ async def on_GET(self, request, medium):
+ if medium != "email":
+ raise SynapseError(
+ 400, "This medium is currently not supported for registration"
+ )
+ if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
+ if self.config.local_threepid_handling_disabled_due_to_email_config:
+ logger.warning(
+ "User registration via email has been disabled due to lack of email config"
+ )
+ raise SynapseError(
+ 400, "Email-based registration is disabled on this server"
+ )
+
+ sid = parse_string(request, "sid", required=True)
+ client_secret = parse_string(request, "client_secret", required=True)
+ token = parse_string(request, "token", required=True)
+
+ # Attempt to validate a 3PID session
+ try:
+ # Mark the session as valid
+ next_link = await self.store.validate_threepid_session(
+ sid, client_secret, token, self.clock.time_msec()
+ )
+
+ # Perform a 302 redirect if next_link is set
+ if next_link:
+ if next_link.startswith("file:///"):
+ logger.warning(
+ "Not redirecting to next_link as it is a local file: address"
+ )
+ else:
+ request.setResponseCode(302)
+ request.setHeader("Location", next_link)
+ finish_request(request)
+ return None
+
+ # Otherwise show the success template
+ html = self.config.email_registration_template_success_html_content
+
+ request.setResponseCode(200)
+ except ThreepidValidationError as e:
+ request.setResponseCode(e.code)
+
+ # Show a failure page with a reason
+ template_vars = {"failure_reason": e.msg}
+ html = self.failure_email_template.render(**template_vars)
+
+ request.write(html.encode("utf-8"))
+ finish_request(request)
class UsernameAvailabilityRestServlet(RestServlet):
@@ -172,20 +338,24 @@ class UsernameAvailabilityRestServlet(RestServlet):
reject_limit=1,
# Allow 1 request at a time
concurrent_requests=1,
- )
+ ),
)
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
+ if not self.hs.config.enable_registration:
+ raise SynapseError(
+ 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
+ )
+
ip = self.hs.get_ip_from_request(request)
with self.ratelimiter.ratelimit(ip) as wait_deferred:
- yield wait_deferred
+ await wait_deferred
username = parse_string(request, "username", required=True)
- yield self.registration_handler.check_username(username)
+ await self.registration_handler.check_username(username)
- defer.returnValue((200, {"available": True}))
+ return 200, {"available": True}
class RegisterRestServlet(RestServlet):
@@ -210,9 +380,12 @@ class RegisterRestServlet(RestServlet):
self.password_policy_handler = hs.get_password_policy_handler()
self.clock = hs.get_clock()
+ self._registration_flows = _calculate_registration_flows(
+ hs.config, self.auth_handler
+ )
+
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
client_addr = request.getClientIP()
@@ -220,7 +393,8 @@ class RegisterRestServlet(RestServlet):
time_now = self.clock.time()
allowed, time_allowed = self.ratelimiter.can_do_action(
- client_addr, time_now_s=time_now,
+ client_addr,
+ time_now_s=time_now,
rate_hz=self.hs.config.rc_registration.per_second,
burst_count=self.hs.config.rc_registration.burst_count,
update=False,
@@ -228,7 +402,7 @@ class RegisterRestServlet(RestServlet):
if not allowed:
raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now)),
+ retry_after_ms=int(1000 * (time_allowed - time_now))
)
kind = b"user"
@@ -236,39 +410,42 @@ class RegisterRestServlet(RestServlet):
kind = request.args[b"kind"][0]
if kind == b"guest":
- ret = yield self._do_guest_registration(body, address=client_addr)
- defer.returnValue(ret)
- return
+ ret = await self._do_guest_registration(body, address=client_addr)
+ return ret
elif kind != b"user":
raise UnrecognizedRequestError(
- "Do not understand membership kind: %s" % (kind,)
+ "Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the username/password provided to us.
desired_password = None
- if 'password' in body:
- if (not isinstance(body['password'], string_types) or
- len(body['password']) > 512):
+ if "password" in body:
+ if (
+ not isinstance(body["password"], string_types)
+ or len(body["password"]) > 512
+ ):
raise SynapseError(400, "Invalid password")
- self.password_policy_handler.validate_password(body['password'])
+ self.password_policy_handler.validate_password(body["password"])
desired_password = body["password"]
desired_username = None
- if 'username' in body:
- if (not isinstance(body['username'], string_types) or
- len(body['username']) > 512):
+ if "username" in body:
+ if (
+ not isinstance(body["username"], string_types)
+ or len(body["username"]) > 512
+ ):
raise SynapseError(400, "Invalid username")
- desired_username = body['username']
+ desired_username = body["username"]
- desired_display_name = body.get('display_name')
+ desired_display_name = body.get("display_name")
appservice = None
if self.auth.has_access_token(request):
- appservice = yield self.auth.get_appservice_by_req(request)
+ appservice = await self.auth.get_appservice_by_req(request)
- # fork off as soon as possible for ASes and shared secret auth which
- # have completely different registration flows to normal users
+ # fork off as soon as possible for ASes which have completely
+ # different registration flows to normal users
# == Application Service Registration ==
if appservice:
@@ -285,15 +462,17 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types):
- result = yield self._do_appservice_registration(
- desired_username, desired_password, desired_display_name,
- access_token, body
+ result = await self._do_appservice_registration(
+ desired_username,
+ desired_password,
+ desired_display_name,
+ access_token,
+ body,
)
- defer.returnValue((200, result)) # we throw for non 200 responses
- return
+ return 200, result # we throw for non 200 responses
- # for either shared secret or regular registration, downcase the
- # provided username before attempting to register it. This should mean
+ # for regular registration, downcase the provided username before
+ # attempting to register it. This should mean
# that people who try to register with upper-case in their usernames
# don't get a nasty surprise. (Note that we treat username
# case-insenstively in login, so they are free to carry on imagining
@@ -301,32 +480,19 @@ class RegisterRestServlet(RestServlet):
if desired_username is not None:
desired_username = desired_username.lower()
- # == Shared Secret Registration == (e.g. create new user scripts)
- if 'mac' in body:
- # FIXME: Should we really be determining if this is shared secret
- # auth based purely on the 'mac' key?
- result = yield self._do_shared_secret_registration(
- desired_username, desired_password, body
- )
- defer.returnValue((200, result)) # we throw for non 200 responses
- return
-
# == Normal User Registration == (everyone else)
if not self.hs.config.enable_registration:
raise SynapseError(403, "Registration has been disabled")
guest_access_token = body.get("guest_access_token", None)
- if (
- 'initial_device_display_name' in body and
- 'password' not in body
- ):
+ if "initial_device_display_name" in body and "password" not in body:
# ignore 'initial_device_display_name' if sent without
# a password to work around a client bug where it sent
# the 'initial_device_display_name' param alone, wiping out
# the original registration params
- logger.warn("Ignoring initial_device_display_name without password")
- del body['initial_device_display_name']
+ logger.warning("Ignoring initial_device_display_name without password")
+ del body["initial_device_display_name"]
session_id = self.auth_handler.get_session_id(body)
registered_user_id = None
@@ -340,77 +506,14 @@ class RegisterRestServlet(RestServlet):
)
if desired_username is not None:
- yield self.registration_handler.check_username(
+ await self.registration_handler.check_username(
desired_username,
guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
)
- # FIXME: need a better error than "no auth flow found" for scenarios
- # where we required 3PID for registration but the user didn't give one
- require_email = 'email' in self.hs.config.registrations_require_3pid
- require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
-
- show_msisdn = True
- if self.hs.config.disable_msisdn_registration:
- show_msisdn = False
- require_msisdn = False
-
- flows = []
- if self.hs.config.enable_registration_captcha:
- # only support 3PIDless registration if no 3PIDs are required
- if not require_email and not require_msisdn:
- # Also add a dummy flow here, otherwise if a client completes
- # recaptcha first we'll assume they were going for this flow
- # and complete the request, when they could have been trying to
- # complete one of the flows with email/msisdn auth.
- flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]])
- # only support the email-only flow if we don't require MSISDN 3PIDs
- if not require_msisdn:
- flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]])
-
- if show_msisdn:
- # only support the MSISDN-only flow if we don't require email 3PIDs
- if not require_email:
- flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
- # always let users provide both MSISDN & email
- flows.extend([
- [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
- ])
- else:
- # only support 3PIDless registration if no 3PIDs are required
- if not require_email and not require_msisdn:
- flows.extend([[LoginType.DUMMY]])
- # only support the email-only flow if we don't require MSISDN 3PIDs
- if not require_msisdn:
- flows.extend([[LoginType.EMAIL_IDENTITY]])
-
- if show_msisdn:
- # only support the MSISDN-only flow if we don't require email 3PIDs
- if not require_email or require_msisdn:
- flows.extend([[LoginType.MSISDN]])
- # always let users provide both MSISDN & email
- flows.extend([
- [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
- ])
-
- # Append m.login.terms to all flows if we're requiring consent
- if self.hs.config.user_consent_at_registration:
- new_flows = []
- for flow in flows:
- inserted = False
- # m.login.terms should go near the end but before msisdn or email auth
- for i, stage in enumerate(flow):
- if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN:
- flow.insert(i, LoginType.TERMS)
- inserted = True
- break
- if not inserted:
- flow.append(LoginType.TERMS)
- flows.extend(new_flows)
-
- auth_result, params, session_id = yield self.auth_handler.check_auth(
- flows, body, self.hs.get_ip_from_request(request)
+ auth_result, params, session_id = await self.auth_handler.check_auth(
+ self._registration_flows, body, self.hs.get_ip_from_request(request)
)
# Check that we're not trying to register a denied 3pid.
@@ -422,26 +525,24 @@ class RegisterRestServlet(RestServlet):
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
- medium = auth_result[login_type]['medium']
- address = auth_result[login_type]['address']
+ medium = auth_result[login_type]["medium"]
+ address = auth_result[login_type]["address"]
- if not (yield check_3pid_allowed(self.hs, medium, address)):
+ if not (await check_3pid_allowed(self.hs, medium, address)):
raise SynapseError(
403,
- "Third party identifiers (email/phone numbers)" +
- " are not authorized on this server",
+ "Third party identifiers (email/phone numbers)"
+ + " are not authorized on this server",
Codes.THREEPID_DENIED,
)
- existingUid = yield self.store.get_user_id_by_threepid(
- medium, address,
+ existingUid = await self.store.get_user_id_by_threepid(
+ medium, address
)
if existingUid is not None:
raise SynapseError(
- 400,
- "%s is already in use" % medium,
- Codes.THREEPID_IN_USE,
+ 400, "%s is already in use" % medium, Codes.THREEPID_IN_USE
)
if self.hs.config.register_mxid_from_3pid:
@@ -455,19 +556,19 @@ class RegisterRestServlet(RestServlet):
# desired_username
if auth_result:
if (
- self.hs.config.register_mxid_from_3pid == 'email' and
- LoginType.EMAIL_IDENTITY in auth_result
+ self.hs.config.register_mxid_from_3pid == "email"
+ and LoginType.EMAIL_IDENTITY in auth_result
):
- address = auth_result[LoginType.EMAIL_IDENTITY]['address']
+ address = auth_result[LoginType.EMAIL_IDENTITY]["address"]
desired_username = synapse.types.strip_invalid_mxid_characters(
- address.replace('@', '-').lower()
+ address.replace("@", "-").lower()
)
# find a unique mxid for the account, suffixing numbers
# if needed
while True:
try:
- yield self.registration_handler.check_username(
+ await self.registration_handler.check_username(
desired_username,
guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
@@ -476,7 +577,7 @@ class RegisterRestServlet(RestServlet):
break
except SynapseError as e:
if e.errcode == Codes.USER_IN_USE:
- m = re.match(r'^(.*?)(\d+)$', desired_username)
+ m = re.match(r"^(.*?)(\d+)$", desired_username)
if m:
desired_username = m.group(1) + str(
int(m.group(2)) + 1
@@ -491,19 +592,19 @@ class RegisterRestServlet(RestServlet):
desired_display_name = address
else:
# Custom mapping between email address and display name
- desired_display_name = self._map_email_to_displayname(address)
+ desired_display_name = _map_email_to_displayname(address)
elif (
- self.hs.config.register_mxid_from_3pid == 'msisdn' and
- LoginType.MSISDN in auth_result
+ self.hs.config.register_mxid_from_3pid == "msisdn"
+ and LoginType.MSISDN in auth_result
):
- desired_username = auth_result[LoginType.MSISDN]['address']
+ desired_username = auth_result[LoginType.MSISDN]["address"]
else:
raise SynapseError(
400, "Cannot derive mxid from 3pid; no recognised 3pid"
)
if desired_username is not None:
- yield self.registration_handler.check_username(
+ await self.registration_handler.check_username(
desired_username,
guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
@@ -511,8 +612,7 @@ class RegisterRestServlet(RestServlet):
if registered_user_id is not None:
logger.info(
- "Already registered user ID %r for this session",
- registered_user_id
+ "Already registered user ID %r for this session", registered_user_id
)
# don't re-register the threepids
registered = False
@@ -546,25 +646,24 @@ class RegisterRestServlet(RestServlet):
# the two activation emails, they would register the same 3pid twice.
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
if login_type in auth_result:
- medium = auth_result[login_type]['medium']
- address = auth_result[login_type]['address']
+ medium = auth_result[login_type]["medium"]
+ address = auth_result[login_type]["address"]
- existingUid = yield self.store.get_user_id_by_threepid(
- medium, address,
+ existing_user_id = await self.store.get_user_id_by_threepid(
+ medium, address
)
- if existingUid is not None:
+ if existing_user_id is not None:
raise SynapseError(
400,
"%s is already in use" % medium,
Codes.THREEPID_IN_USE,
)
- (registered_user_id, _) = yield self.registration_handler.register(
+ registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password=params.get("password", None),
guest_access_token=guest_access_token,
- generate_token=False,
default_display_name=desired_display_name,
threepid=threepid,
address=client_addr,
@@ -575,10 +674,10 @@ class RegisterRestServlet(RestServlet):
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
- yield self.store.upsert_monthly_active_user(registered_user_id)
+ await self.store.upsert_monthly_active_user(registered_user_id)
if self.hs.config.shadow_server:
- yield self.registration_handler.shadow_register(
+ await self.registration_handler.shadow_register(
localpart=desired_username,
display_name=desired_display_name,
auth_result=auth_result,
@@ -593,93 +692,48 @@ class RegisterRestServlet(RestServlet):
registered = True
- return_dict = yield self._create_registration_details(
+ return_dict = await self._create_registration_details(
registered_user_id, params
)
if registered:
- yield self.registration_handler.post_registration_actions(
+ await self.registration_handler.post_registration_actions(
user_id=registered_user_id,
auth_result=auth_result,
access_token=return_dict.get("access_token"),
- bind_email=params.get("bind_email"),
- bind_msisdn=params.get("bind_msisdn"),
)
- defer.returnValue((200, return_dict))
+ return 200, return_dict
def on_OPTIONS(self, _):
return 200, {}
- @defer.inlineCallbacks
- def _do_appservice_registration(
+ async def _do_appservice_registration(
self, username, password, display_name, as_token, body
):
-
# FIXME: appservice_register() is horribly duplicated with register()
# and they should probably just be combined together with a config flag.
- user_id = yield self.registration_handler.appservice_register(
+ user_id = await self.registration_handler.appservice_register(
username, as_token, password, display_name
)
- result = yield self._create_registration_details(user_id, body)
+ result = await self._create_registration_details(user_id, body)
- auth_result = body.get('auth_result')
+ auth_result = body.get("auth_result")
if auth_result and LoginType.EMAIL_IDENTITY in auth_result:
threepid = auth_result[LoginType.EMAIL_IDENTITY]
- yield self._register_email_threepid(
- user_id, threepid, result["access_token"],
- body.get("bind_email")
+ await self.registration_handler.register_email_threepid(
+ user_id, threepid, result["access_token"]
)
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
- yield self._register_msisdn_threepid(
- user_id, threepid, result["access_token"],
- body.get("bind_msisdn")
- )
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def _do_shared_secret_registration(self, username, password, body):
- if not self.hs.config.registration_shared_secret:
- raise SynapseError(400, "Shared secret registration is not enabled")
- if not username:
- raise SynapseError(
- 400, "username must be specified", errcode=Codes.BAD_JSON,
+ await self.registration_handler.register_msisdn_threepid(
+ user_id, threepid, result["access_token"]
)
- # use the username from the original request rather than the
- # downcased one in `username` for the mac calculation
- user = body["username"].encode("utf-8")
+ return result
- # str() because otherwise hmac complains that 'unicode' does not
- # have the buffer interface
- got_mac = str(body["mac"])
-
- # FIXME this is different to the /v1/register endpoint, which
- # includes the password and admin flag in the hashed text. Why are
- # these different?
- want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret.encode(),
- msg=user,
- digestmod=sha1,
- ).hexdigest()
-
- if not compare_digest(want_mac, got_mac):
- raise SynapseError(
- 403, "HMAC incorrect",
- )
-
- (user_id, _) = yield self.registration_handler.register(
- localpart=username, password=password, generate_token=False,
- )
-
- result = yield self._create_registration_details(user_id, body)
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def _create_registration_details(self, user_id, params):
+ async def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -691,47 +745,41 @@ class RegisterRestServlet(RestServlet):
Returns:
defer.Deferred: (object) dictionary for response from /register
"""
- result = {
- "user_id": user_id,
- "home_server": self.hs.hostname,
- }
+ result = {"user_id": user_id, "home_server": self.hs.hostname}
if not params.get("inhibit_login", False):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest=False,
+ device_id, access_token = await self.registration_handler.register_device(
+ user_id, device_id, initial_display_name, is_guest=False
)
- result.update({
- "access_token": access_token,
- "device_id": device_id,
- })
- defer.returnValue(result)
+ result.update({"access_token": access_token, "device_id": device_id})
+ return result
- @defer.inlineCallbacks
- def _do_guest_registration(self, params, address=None):
+ async def _do_guest_registration(self, params, address=None):
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
- user_id, _ = yield self.registration_handler.register(
- generate_token=False,
- make_guest=True,
- address=address,
+ user_id = await self.registration_handler.register_user(
+ make_guest=True, address=address
)
# we don't allow guests to specify their own device_id, because
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest=True,
+ device_id, access_token = await self.registration_handler.register_device(
+ user_id, device_id, initial_display_name, is_guest=True
)
- defer.returnValue((200, {
- "user_id": user_id,
- "device_id": device_id,
- "access_token": access_token,
- "home_server": self.hs.hostname,
- }))
+ return (
+ 200,
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "access_token": access_token,
+ "home_server": self.hs.hostname,
+ },
+ )
def cap(name):
@@ -764,10 +812,10 @@ def _map_email_to_displayname(address):
"""
# Split the part before and after the @ in the email.
# Replace all . with spaces in the first part
- parts = address.replace('.', ' ').split('@')
+ parts = address.replace(".", " ").split("@")
# Figure out which org this email address belongs to
- org_parts = parts[1].split(' ')
+ org_parts = parts[1].split(" ")
# If this is a ...matrix.org email, mark them as an Admin
if org_parts[-2] == "matrix" and org_parts[-1] == "org":
@@ -783,15 +831,91 @@ def _map_email_to_displayname(address):
else:
org = org_parts[-2]
- desired_display_name = (
- cap(parts[0]) + " [" + cap(org) + "]"
- )
+ desired_display_name = cap(parts[0]) + " [" + cap(org) + "]"
return desired_display_name
+def _calculate_registration_flows(
+ # technically `config` has to provide *all* of these interfaces, not just one
+ config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
+ auth_handler: AuthHandler,
+) -> List[List[str]]:
+ """Get a suitable flows list for registration
+
+ Args:
+ config: server configuration
+ auth_handler: authorization handler
+
+ Returns: a list of supported flows
+ """
+ # FIXME: need a better error than "no auth flow found" for scenarios
+ # where we required 3PID for registration but the user didn't give one
+ require_email = "email" in config.registrations_require_3pid
+ require_msisdn = "msisdn" in config.registrations_require_3pid
+
+ show_msisdn = True
+ show_email = True
+
+ if config.disable_msisdn_registration:
+ show_msisdn = False
+ require_msisdn = False
+
+ enabled_auth_types = auth_handler.get_enabled_auth_types()
+ if LoginType.EMAIL_IDENTITY not in enabled_auth_types:
+ show_email = False
+ if require_email:
+ raise ConfigError(
+ "Configuration requires email address at registration, but email "
+ "validation is not configured"
+ )
+
+ if LoginType.MSISDN not in enabled_auth_types:
+ show_msisdn = False
+ if require_msisdn:
+ raise ConfigError(
+ "Configuration requires msisdn at registration, but msisdn "
+ "validation is not configured"
+ )
+
+ flows = []
+
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ # Add a dummy step here, otherwise if a client completes
+ # recaptcha first we'll assume they were going for this flow
+ # and complete the request, when they could have been trying to
+ # complete one of the flows with email/msisdn auth.
+ flows.append([LoginType.DUMMY])
+
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if show_email and not require_msisdn:
+ flows.append([LoginType.EMAIL_IDENTITY])
+
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if show_msisdn and not require_email:
+ flows.append([LoginType.MSISDN])
+
+ if show_email and show_msisdn:
+ # always let users provide both MSISDN & email
+ flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
+
+ # Prepend m.login.terms to all flows if we're requiring consent
+ if config.user_consent_at_registration:
+ for flow in flows:
+ flow.insert(0, LoginType.TERMS)
+
+ # Prepend recaptcha to all flows if we're requiring captcha
+ if config.enable_registration_captcha:
+ for flow in flows:
+ flow.insert(0, LoginType.RECAPTCHA)
+
+ return flows
+
+
def register_servlets(hs, http_server):
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
UsernameAvailabilityRestServlet(hs).register(http_server)
+ RegistrationSubmitTokenServlet(hs).register(http_server)
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index f8f8742bdc..63f07b63da 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -21,8 +21,6 @@ any time to reflect changes in the MSC.
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
@@ -32,7 +30,11 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
+from synapse.storage.relations import (
+ AggregationPaginationToken,
+ PaginationChunk,
+ RelationPaginationToken,
+)
from ._base import client_patterns
@@ -68,11 +70,13 @@ class RelationSendServlet(RestServlet):
"POST",
client_patterns(self.PATTERN + "$", releases=()),
self.on_PUT_or_POST,
+ self.__class__.__name__,
)
http_server.register_paths(
"PUT",
client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
self.on_PUT,
+ self.__class__.__name__,
)
def on_PUT(self, request, *args, **kwargs):
@@ -80,11 +84,10 @@ class RelationSendServlet(RestServlet):
request, self.on_PUT_or_POST, request, *args, **kwargs
)
- @defer.inlineCallbacks
- def on_PUT_or_POST(
+ async def on_PUT_or_POST(
self, request, room_id, parent_id, relation_type, event_type, txn_id=None
):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member:
# Add relations to a membership is meaningless, so we just deny it
@@ -108,11 +111,11 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)
- defer.returnValue((200, {"event_id": event.event_id}))
+ return 200, {"event_id": event.event_id}
class RelationPaginationServlet(RestServlet):
@@ -134,48 +137,66 @@ class RelationPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(
+ self, request, room_id, parent_id, relation_type=None, event_type=None
+ ):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- yield self.auth.check_in_room_or_world_readable(
- room_id, requester.user.to_string()
+ await self.auth.check_user_in_room_or_world_readable(
+ room_id, requester.user.to_string(), allow_departed_users=True
)
- # This checks that a) the event exists and b) the user is allowed to
- # view it.
- yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ # This gets the original event and checks that a) the event exists and
+ # b) the user is allowed to view it.
+ event = await self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from")
to_token = parse_string(request, "to")
- if from_token:
- from_token = RelationPaginationToken.from_string(from_token)
-
- if to_token:
- to_token = RelationPaginationToken.from_string(to_token)
-
- result = yield self.store.get_relations_for_event(
- event_id=parent_id,
- relation_type=relation_type,
- event_type=event_type,
- limit=limit,
- from_token=from_token,
- to_token=to_token,
- )
-
- events = yield self.store.get_events_as_list(
- [c["event_id"] for c in result.chunk]
+ if event.internal_metadata.is_redacted():
+ # If the event is redacted, return an empty list of relations
+ pagination_chunk = PaginationChunk(chunk=[])
+ else:
+ # Return the relations
+ if from_token:
+ from_token = RelationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = RelationPaginationToken.from_string(to_token)
+
+ pagination_chunk = await self.store.get_relations_for_event(
+ event_id=parent_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = await self.store.get_events_as_list(
+ [c["event_id"] for c in pagination_chunk.chunk]
)
now = self.clock.time_msec()
- events = yield self._event_serializer.serialize_events(events, now)
+ # We set bundle_aggregations to False when retrieving the original
+ # event because we want the content before relations were applied to
+ # it.
+ original_event = await self._event_serializer.serialize_event(
+ event, now, bundle_aggregations=False
+ )
+ # Similarly, we don't allow relations to be applied to relations, so we
+ # return the original relations without any aggregations on top of them
+ # here.
+ events = await self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=False
+ )
- return_value = result.to_dict()
+ return_value = pagination_chunk.to_dict()
return_value["chunk"] = events
+ return_value["original_event"] = original_event
- defer.returnValue((200, return_value))
+ return 200, return_value
class RelationAggregationPaginationServlet(RestServlet):
@@ -209,17 +230,18 @@ class RelationAggregationPaginationServlet(RestServlet):
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(
+ self, request, room_id, parent_id, relation_type=None, event_type=None
+ ):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- yield self.auth.check_in_room_or_world_readable(
- room_id, requester.user.to_string()
+ await self.auth.check_user_in_room_or_world_readable(
+ room_id, requester.user.to_string(), allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
- yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -228,21 +250,26 @@ class RelationAggregationPaginationServlet(RestServlet):
from_token = parse_string(request, "from")
to_token = parse_string(request, "to")
- if from_token:
- from_token = AggregationPaginationToken.from_string(from_token)
+ if event.internal_metadata.is_redacted():
+ # If the event is redacted, return an empty list of relations
+ pagination_chunk = PaginationChunk(chunk=[])
+ else:
+ # Return the relations
+ if from_token:
+ from_token = AggregationPaginationToken.from_string(from_token)
- if to_token:
- to_token = AggregationPaginationToken.from_string(to_token)
+ if to_token:
+ to_token = AggregationPaginationToken.from_string(to_token)
- res = yield self.store.get_aggregation_groups_for_event(
- event_id=parent_id,
- event_type=event_type,
- limit=limit,
- from_token=from_token,
- to_token=to_token,
- )
+ pagination_chunk = await self.store.get_aggregation_groups_for_event(
+ event_id=parent_id,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
- defer.returnValue((200, res.to_dict()))
+ return 200, pagination_chunk.to_dict()
class RelationAggregationGroupPaginationServlet(RestServlet):
@@ -283,17 +310,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- yield self.auth.check_in_room_or_world_readable(
- room_id, requester.user.to_string()
+ await self.auth.check_user_in_room_or_world_readable(
+ room_id, requester.user.to_string(), allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
- yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ await self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -308,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token:
to_token = RelationPaginationToken.from_string(to_token)
- result = yield self.store.get_relations_for_event(
+ result = await self.store.get_relations_for_event(
event_id=parent_id,
relation_type=relation_type,
event_type=event_type,
@@ -318,17 +344,17 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token,
)
- events = yield self.store.get_events_as_list(
+ events = await self.store.get_events_as_list(
[c["event_id"] for c in result.chunk]
)
now = self.clock.time_msec()
- events = yield self._event_serializer.serialize_events(events, now)
+ events = await self._event_serializer.serialize_events(events, now)
return_value = result.to_dict()
return_value["chunk"] = events
- defer.returnValue((200, return_value))
+ return 200, return_value
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 10198662a9..f067b5edac 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -18,8 +18,6 @@ import logging
from six import string_types
from six.moves import http_client
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -33,9 +31,7 @@ logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet):
- PATTERNS = client_patterns(
- "/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
- )
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$")
def __init__(self, hs):
super(ReportEventRestServlet, self).__init__()
@@ -44,9 +40,8 @@ class ReportEventRestServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -65,7 +60,7 @@ class ReportEventRestServlet(RestServlet):
Codes.BAD_JSON,
)
- yield self.store.add_event_report(
+ await self.store.add_event_report(
room_id=room_id,
event_id=event_id,
user_id=user_id,
@@ -74,7 +69,7 @@ class ReportEventRestServlet(RestServlet):
received_ts=self.clock.time_msec(),
)
- defer.returnValue((200, {}))
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 87779645f9..38952a1d27 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -43,8 +41,7 @@ class RoomKeysServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, session_id):
+ async def on_PUT(self, request, room_id, session_id):
"""
Uploads one or more encrypted E2E room keys for backup purposes.
room_id: the ID of the room the keys are for (optional)
@@ -123,32 +120,21 @@ class RoomKeysServlet(RestServlet):
}
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
version = parse_string(request, "version")
if session_id:
- body = {
- "sessions": {
- session_id: body
- }
- }
+ body = {"sessions": {session_id: body}}
if room_id:
- body = {
- "rooms": {
- room_id: body
- }
- }
+ body = {"rooms": {room_id: body}}
- yield self.e2e_room_keys_handler.upload_room_keys(
- user_id, version, body
- )
- defer.returnValue((200, {}))
+ ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
+ return 200, ret
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, session_id):
+ async def on_GET(self, request, room_id, session_id):
"""
Retrieves one or more encrypted E2E room keys for backup purposes.
Symmetric with the PUT version of the API.
@@ -200,11 +186,11 @@ class RoomKeysServlet(RestServlet):
}
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version")
- room_keys = yield self.e2e_room_keys_handler.get_room_keys(
+ room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
)
@@ -212,10 +198,10 @@ class RoomKeysServlet(RestServlet):
if session_id:
# If the client requests a specific session, but that session was
# not backed up, then return an M_NOT_FOUND.
- if room_keys['rooms'] == {}:
+ if room_keys["rooms"] == {}:
raise NotFoundError("No room_keys found")
else:
- room_keys = room_keys['rooms'][room_id]['sessions'][session_id]
+ room_keys = room_keys["rooms"][room_id]["sessions"][session_id]
elif room_id:
# If the client requests all sessions from a room, but no sessions
# are found, then return an empty result rather than an error, so
@@ -223,15 +209,14 @@ class RoomKeysServlet(RestServlet):
# empty result is valid. (Similarly if the client requests all
# sessions from the backup, but in that case, room_keys is already
# in the right format, so we don't need to do anything about it.)
- if room_keys['rooms'] == {}:
- room_keys = {'sessions': {}}
+ if room_keys["rooms"] == {}:
+ room_keys = {"sessions": {}}
else:
- room_keys = room_keys['rooms'][room_id]
+ room_keys = room_keys["rooms"][room_id]
- defer.returnValue((200, room_keys))
+ return 200, room_keys
- @defer.inlineCallbacks
- def on_DELETE(self, request, room_id, session_id):
+ async def on_DELETE(self, request, room_id, session_id):
"""
Deletes one or more encrypted E2E room keys for a user for backup purposes.
@@ -245,20 +230,18 @@ class RoomKeysServlet(RestServlet):
the version must already have been created via the /change_secret API.
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version")
- yield self.e2e_room_keys_handler.delete_room_keys(
+ ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id
)
- defer.returnValue((200, {}))
+ return 200, ret
class RoomKeysNewVersionServlet(RestServlet):
- PATTERNS = client_patterns(
- "/room_keys/version$"
- )
+ PATTERNS = client_patterns("/room_keys/version$")
def __init__(self, hs):
"""
@@ -269,8 +252,7 @@ class RoomKeysNewVersionServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
"""
Create a new backup version for this user's room_keys with the given
info. The version is allocated by the server and returned to the user
@@ -300,23 +282,19 @@ class RoomKeysNewVersionServlet(RestServlet):
"version": 12345
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
info = parse_json_object_from_request(request)
- new_version = yield self.e2e_room_keys_handler.create_version(
- user_id, info
- )
- defer.returnValue((200, {"version": new_version}))
+ new_version = await self.e2e_room_keys_handler.create_version(user_id, info)
+ return 200, {"version": new_version}
# we deliberately don't have a PUT /version, as these things really should
# be immutable to avoid people footgunning
class RoomKeysVersionServlet(RestServlet):
- PATTERNS = client_patterns(
- "/room_keys/version(/(?P<version>[^/]+))?$"
- )
+ PATTERNS = client_patterns("/room_keys/version(/(?P<version>[^/]+))?$")
def __init__(self, hs):
"""
@@ -327,8 +305,7 @@ class RoomKeysVersionServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, version):
+ async def on_GET(self, request, version):
"""
Retrieve the version information about a given version of the user's
room_keys backup. If the version part is missing, returns info about the
@@ -346,20 +323,17 @@ class RoomKeysVersionServlet(RestServlet):
"auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
try:
- info = yield self.e2e_room_keys_handler.get_version_info(
- user_id, version
- )
+ info = await self.e2e_room_keys_handler.get_version_info(user_id, version)
except SynapseError as e:
if e.code == 404:
raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
- defer.returnValue((200, info))
+ return 200, info
- @defer.inlineCallbacks
- def on_DELETE(self, request, version):
+ async def on_DELETE(self, request, version):
"""
Delete the information about a given version of the user's
room_keys backup. If the version part is missing, deletes the most
@@ -372,16 +346,13 @@ class RoomKeysVersionServlet(RestServlet):
if version is None:
raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND)
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
- yield self.e2e_room_keys_handler.delete_version(
- user_id, version
- )
- defer.returnValue((200, {}))
+ await self.e2e_room_keys_handler.delete_version(user_id, version)
+ return 200, {}
- @defer.inlineCallbacks
- def on_PUT(self, request, version):
+ async def on_PUT(self, request, version):
"""
Update the information about a given version of the user's room_keys backup.
@@ -395,24 +366,24 @@ class RoomKeysVersionServlet(RestServlet):
"ed25519:something": "hijklmnop"
}
},
- "version": "42"
+ "version": "12345"
}
HTTP/1.1 200 OK
Content-Type: application/json
{}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
info = parse_json_object_from_request(request)
if version is None:
- raise SynapseError(400, "No version specified to update", Codes.MISSING_PARAM)
+ raise SynapseError(
+ 400, "No version specified to update", Codes.MISSING_PARAM
+ )
- yield self.e2e_room_keys_handler.update_version(
- user_id, version, info
- )
- defer.returnValue((200, {}))
+ await self.e2e_room_keys_handler.update_version(user_id, version, info)
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index c621a90fba..f357015a70 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import (
@@ -47,9 +45,10 @@ class RoomUpgradeRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
+
PATTERNS = client_patterns(
# /rooms/$roomid/upgrade
- "/rooms/(?P<room_id>[^/]*)/upgrade$",
+ "/rooms/(?P<room_id>[^/]*)/upgrade$"
)
def __init__(self, hs):
@@ -58,30 +57,28 @@ class RoomUpgradeRestServlet(RestServlet):
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self._auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id):
+ requester = await self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- assert_params_in_dict(content, ("new_version", ))
+ assert_params_in_dict(content, ("new_version",))
new_version = content["new_version"]
- if new_version not in KNOWN_ROOM_VERSIONS:
+ new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"])
+ if new_version is None:
raise SynapseError(
400,
"Your homeserver does not support this room version",
Codes.UNSUPPORTED_ROOM_VERSION,
)
- new_room_id = yield self._room_creation_handler.upgrade_room(
+ new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
)
- ret = {
- "replacement_room": new_room_id,
- }
+ ret = {"replacement_room": new_room_id}
- defer.returnValue((200, ret))
+ return 200, ret
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 120a713361..db829f3098 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -14,11 +14,11 @@
# limitations under the License.
import logging
-
-from twisted.internet import defer
+from typing import Tuple
from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
+from synapse.logging.opentracing import set_tag, trace
from synapse.rest.client.transactions import HttpTransactionCache
from ._base import client_patterns
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
PATTERNS = client_patterns(
- "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
+ "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$"
)
def __init__(self, hs):
@@ -42,25 +42,27 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
+ @trace(opname="sendToDevice")
def on_PUT(self, request, message_type, txn_id):
+ set_tag("message_type", message_type)
+ set_tag("txn_id", txn_id)
return self.txns.fetch_or_execute_request(
request, self._put, request, message_type, txn_id
)
- @defer.inlineCallbacks
- def _put(self, request, message_type, txn_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def _put(self, request, message_type, txn_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
sender_user_id = requester.user.to_string()
- yield self.device_message_handler.send_device_message(
+ await self.device_message_handler.send_device_message(
sender_user_id, message_type, content["messages"]
)
- response = (200, {})
- defer.returnValue(response)
+ response = (200, {}) # type: Tuple[int, dict]
+ return response
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 148fc6c985..8fa68dd37f 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -18,10 +18,8 @@ import logging
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import PresenceState
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
@@ -74,7 +72,7 @@ class SyncRestServlet(RestServlet):
"""
PATTERNS = client_patterns("/sync$")
- ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
+ ALLOWED_PRESENCE = {"online", "offline", "unavailable"}
def __init__(self, hs):
super(SyncRestServlet, self).__init__()
@@ -87,8 +85,7 @@ class SyncRestServlet(RestServlet):
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'.
@@ -96,50 +93,60 @@ class SyncRestServlet(RestServlet):
400, "'from' is not a valid query parameter. Did you mean 'since'?"
)
- requester = yield self.auth.get_user_by_req(
- request, allow_guest=True
- )
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = requester.user
device_id = requester.device_id
timeout = parse_integer(request, "timeout", default=0)
since = parse_string(request, "since")
set_presence = parse_string(
- request, "set_presence", default="online",
- allowed_values=self.ALLOWED_PRESENCE
+ request,
+ "set_presence",
+ default="online",
+ allowed_values=self.ALLOWED_PRESENCE,
)
filter_id = parse_string(request, "filter", default=None)
full_state = parse_boolean(request, "full_state", default=False)
logger.debug(
- "/sync: user=%r, timeout=%r, since=%r,"
- " set_presence=%r, filter_id=%r, device_id=%r" % (
- user, timeout, since, set_presence, filter_id, device_id
- )
+ "/sync: user=%r, timeout=%r, since=%r, "
+ "set_presence=%r, filter_id=%r, device_id=%r",
+ user,
+ timeout,
+ since,
+ set_presence,
+ filter_id,
+ device_id,
)
request_key = (user, timeout, since, filter_id, full_state, device_id)
- if filter_id:
- if filter_id.startswith('{'):
- try:
- filter_object = json.loads(filter_id)
- set_timeline_upper_limit(filter_object,
- self.hs.config.filter_timeline_limit)
- except Exception:
- raise SynapseError(400, "Invalid filter JSON")
- self.filtering.check_valid_filter(filter_object)
- filter = FilterCollection(filter_object)
- else:
- filter = yield self.filtering.get_user_filter(
- user.localpart, filter_id
+ if filter_id is None:
+ filter_collection = DEFAULT_FILTER_COLLECTION
+ elif filter_id.startswith("{"):
+ try:
+ filter_object = json.loads(filter_id)
+ set_timeline_upper_limit(
+ filter_object, self.hs.config.filter_timeline_limit
)
+ except Exception:
+ raise SynapseError(400, "Invalid filter JSON")
+ self.filtering.check_valid_filter(filter_object)
+ filter_collection = FilterCollection(filter_object)
else:
- filter = DEFAULT_FILTER_COLLECTION
+ try:
+ filter_collection = await self.filtering.get_user_filter(
+ user.localpart, filter_id
+ )
+ except StoreError as err:
+ if err.code != 404:
+ raise
+ # fix up the description and errcode to be more useful
+ raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM)
sync_config = SyncConfig(
user=user,
- filter_collection=filter,
+ filter_collection=filter_collection,
is_guest=requester.is_guest,
request_key=request_key,
device_id=device_id,
@@ -151,70 +158,70 @@ class SyncRestServlet(RestServlet):
since_token = None
# send any outstanding server notices to the user.
- yield self._server_notices_sender.on_user_syncing(user.to_string())
+ await self._server_notices_sender.on_user_syncing(user.to_string())
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
- yield self.presence_handler.set_state(user, {"presence": set_presence}, True)
+ await self.presence_handler.set_state(
+ user, {"presence": set_presence}, True
+ )
- context = yield self.presence_handler.user_syncing(
- user.to_string(), affect_presence=affect_presence,
+ context = await self.presence_handler.user_syncing(
+ user.to_string(), affect_presence=affect_presence
)
with context:
- sync_result = yield self.sync_handler.wait_for_sync_for_user(
- sync_config, since_token=since_token, timeout=timeout,
- full_state=full_state
+ sync_result = await self.sync_handler.wait_for_sync_for_user(
+ sync_config,
+ since_token=since_token,
+ timeout=timeout,
+ full_state=full_state,
)
time_now = self.clock.time_msec()
- response_content = yield self.encode_response(
- time_now, sync_result, requester.access_token_id, filter
+ response_content = await self.encode_response(
+ time_now, sync_result, requester.access_token_id, filter_collection
)
- defer.returnValue((200, response_content))
+ return 200, response_content
- @defer.inlineCallbacks
- def encode_response(self, time_now, sync_result, access_token_id, filter):
- if filter.event_format == 'client':
+ async def encode_response(self, time_now, sync_result, access_token_id, filter):
+ if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
- elif filter.event_format == 'federation':
+ elif filter.event_format == "federation":
event_formatter = format_event_raw
else:
- raise Exception("Unknown event format %s" % (filter.event_format, ))
+ raise Exception("Unknown event format %s" % (filter.event_format,))
- joined = yield self.encode_joined(
- sync_result.joined, time_now, access_token_id,
+ joined = await self.encode_joined(
+ sync_result.joined,
+ time_now,
+ access_token_id,
filter.event_fields,
event_formatter,
)
- invited = yield self.encode_invited(
- sync_result.invited, time_now, access_token_id,
- event_formatter,
+ invited = await self.encode_invited(
+ sync_result.invited, time_now, access_token_id, event_formatter
)
- archived = yield self.encode_archived(
- sync_result.archived, time_now, access_token_id,
+ archived = await self.encode_archived(
+ sync_result.archived,
+ time_now,
+ access_token_id,
filter.event_fields,
event_formatter,
)
- defer.returnValue({
+ return {
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
"device_lists": {
"changed": list(sync_result.device_lists.changed),
"left": list(sync_result.device_lists.left),
},
- "presence": SyncRestServlet.encode_presence(
- sync_result.presence, time_now
- ),
- "rooms": {
- "join": joined,
- "invite": invited,
- "leave": archived,
- },
+ "presence": SyncRestServlet.encode_presence(sync_result.presence, time_now),
+ "rooms": {"join": joined, "invite": invited, "leave": archived},
"groups": {
"join": sync_result.groups.join,
"invite": sync_result.groups.invite,
@@ -222,7 +229,7 @@ class SyncRestServlet(RestServlet):
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
- })
+ }
@staticmethod
def encode_presence(events, time_now):
@@ -239,8 +246,9 @@ class SyncRestServlet(RestServlet):
]
}
- @defer.inlineCallbacks
- def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter):
+ async def encode_joined(
+ self, rooms, time_now, token_id, event_fields, event_formatter
+ ):
"""
Encode the joined rooms in a sync result
@@ -261,15 +269,18 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = yield self.encode_room(
- room, time_now, token_id, joined=True, only_fields=event_fields,
+ joined[room.room_id] = await self.encode_room(
+ room,
+ time_now,
+ token_id,
+ joined=True,
+ only_fields=event_fields,
event_formatter=event_formatter,
)
- defer.returnValue(joined)
+ return joined
- @defer.inlineCallbacks
- def encode_invited(self, rooms, time_now, token_id, event_formatter):
+ async def encode_invited(self, rooms, time_now, token_id, event_formatter):
"""
Encode the invited rooms in a sync result
@@ -289,8 +300,10 @@ class SyncRestServlet(RestServlet):
"""
invited = {}
for room in rooms:
- invite = yield self._event_serializer.serialize_event(
- room.invite, time_now, token_id=token_id,
+ invite = await self._event_serializer.serialize_event(
+ room.invite,
+ time_now,
+ token_id=token_id,
event_format=event_formatter,
is_invite=True,
)
@@ -298,14 +311,13 @@ class SyncRestServlet(RestServlet):
invite["unsigned"] = unsigned
invited_state = list(unsigned.pop("invite_room_state", []))
invited_state.append(invite)
- invited[room.room_id] = {
- "invite_state": {"events": invited_state}
- }
+ invited[room.room_id] = {"invite_state": {"events": invited_state}}
- defer.returnValue(invited)
+ return invited
- @defer.inlineCallbacks
- def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter):
+ async def encode_archived(
+ self, rooms, time_now, token_id, event_fields, event_formatter
+ ):
"""
Encode the archived rooms in a sync result
@@ -326,18 +338,19 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = yield self.encode_room(
- room, time_now, token_id, joined=False,
+ joined[room.room_id] = await self.encode_room(
+ room,
+ time_now,
+ token_id,
+ joined=False,
only_fields=event_fields,
event_formatter=event_formatter,
)
- defer.returnValue(joined)
+ return joined
- @defer.inlineCallbacks
- def encode_room(
- self, room, time_now, token_id, joined,
- only_fields, event_formatter,
+ async def encode_room(
+ self, room, time_now, token_id, joined, only_fields, event_formatter
):
"""
Args:
@@ -355,9 +368,11 @@ class SyncRestServlet(RestServlet):
Returns:
dict[str, object]: the room, encoded in our response format
"""
+
def serialize(events):
return self._event_serializer.serialize_events(
- events, time_now=time_now,
+ events,
+ time_now=time_now,
# We don't bundle "live" events, as otherwise clients
# will end up double counting annotations.
bundle_aggregations=False,
@@ -375,13 +390,15 @@ class SyncRestServlet(RestServlet):
# We've had bug reports that events were coming down under the
# wrong room.
if event.room_id != room.room_id:
- logger.warn(
+ logger.warning(
"Event %r is under room %r instead of %r",
- event.event_id, room.room_id, event.room_id,
+ event.event_id,
+ room.room_id,
+ event.room_id,
)
- serialized_state = yield serialize(state_events)
- serialized_timeline = yield serialize(timeline_events)
+ serialized_state = await serialize(state_events)
+ serialized_timeline = await serialize(timeline_events)
account_data = room.account_data
@@ -401,7 +418,7 @@ class SyncRestServlet(RestServlet):
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
- defer.returnValue(result)
+ return result
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index ebff7cff45..a3f12e8a77 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -29,24 +27,22 @@ class TagListServlet(RestServlet):
"""
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
- PATTERNS = client_patterns(
- "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
- )
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags")
def __init__(self, hs):
super(TagListServlet, self).__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
- tags = yield self.store.get_tags_for_room(user_id, room_id)
+ tags = await self.store.get_tags_for_room(user_id, room_id)
- defer.returnValue((200, {"tags": tags}))
+ return 200, {"tags": tags}
class TagServlet(RestServlet):
@@ -54,6 +50,7 @@ class TagServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
"""
+
PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
@@ -64,35 +61,29 @@ class TagServlet(RestServlet):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id, room_id, tag):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id, room_id, tag):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
body = parse_json_object_from_request(request)
- max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
+ max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
- self.notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
- )
+ self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, user_id, room_id, tag):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, user_id, room_id, tag):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
- max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
+ max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
- self.notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
- )
+ self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
- defer.returnValue((200, {}))
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index e7a987466a..23709960ad 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet
@@ -35,12 +33,11 @@ class ThirdPartyProtocolsServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
- protocols = yield self.appservice_handler.get_3pe_protocols()
- defer.returnValue((200, protocols))
+ protocols = await self.appservice_handler.get_3pe_protocols()
+ return 200, protocols
class ThirdPartyProtocolServlet(RestServlet):
@@ -52,17 +49,16 @@ class ThirdPartyProtocolServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, protocol):
+ await self.auth.get_user_by_req(request, allow_guest=True)
- protocols = yield self.appservice_handler.get_3pe_protocols(
- only_protocol=protocol,
+ protocols = await self.appservice_handler.get_3pe_protocols(
+ only_protocol=protocol
)
if protocol in protocols:
- defer.returnValue((200, protocols[protocol]))
+ return 200, protocols[protocol]
else:
- defer.returnValue((404, {"error": "Unknown protocol"}))
+ return 404, {"error": "Unknown protocol"}
class ThirdPartyUserServlet(RestServlet):
@@ -74,18 +70,17 @@ class ThirdPartyUserServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, protocol):
+ await self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop(b"access_token", None)
- results = yield self.appservice_handler.query_3pe(
+ results = await self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
- defer.returnValue((200, results))
+ return 200, results
class ThirdPartyLocationServlet(RestServlet):
@@ -97,18 +92,17 @@ class ThirdPartyLocationServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, protocol):
+ await self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop(b"access_token", None)
- results = yield self.appservice_handler.query_3pe(
+ results = await self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
- defer.returnValue((200, results))
+ return 200, results
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 6c366142e1..83f3b6b70a 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet
@@ -26,13 +24,13 @@ class TokenRefreshRestServlet(RestServlet):
Exchanges refresh tokens for a pair of an access token and a new refresh
token.
"""
+
PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
raise AuthError(403, "tokenrefresh is no longer supported.")
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index b6f4d8b3f4..f9dfdce112 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -42,8 +42,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.user_directory_handler = hs.get_user_directory_handler()
self.http_client = hs.get_simple_http_client()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
"""Searches for users in directory
Returns:
@@ -60,24 +59,23 @@ class UserDirectorySearchRestServlet(RestServlet):
]
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
if not self.hs.config.user_directory_search_enabled:
- defer.returnValue((200, {
- "limited": False,
- "results": [],
- }))
+ return 200, {"limited": False, "results": []}
body = parse_json_object_from_request(request)
if self.hs.config.user_directory_defer_to_id_server:
- signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
+ signed_body = sign_json(
+ body, self.hs.hostname, self.hs.config.signing_key[0]
+ )
url = "%s/_matrix/identity/api/v1/user_directory/search" % (
self.hs.config.user_directory_defer_to_id_server,
)
- resp = yield self.http_client.post_json_get_json(url, signed_body)
- defer.returnValue((200, resp))
+ resp = await self.http_client.post_json_get_json(url, signed_body)
+ return 200, resp
limit = body.get("limit", 10)
limit = min(limit, 50)
@@ -87,20 +85,19 @@ class UserDirectorySearchRestServlet(RestServlet):
except Exception:
raise SynapseError(400, "`search_term` is required field")
- results = yield self.user_directory_handler.search_users(
- user_id, search_term, limit,
+ results = await self.user_directory_handler.search_users(
+ user_id, search_term, limit
)
- defer.returnValue((200, results))
+ return 200, results
class UserInfoServlet(RestServlet):
"""
GET /user/{user_id}/info HTTP/1.1
"""
- PATTERNS = client_patterns(
- "/user/(?P<user_id>[^/]*)/info$"
- )
+
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/info$")
def __init__(self, hs):
super(UserInfoServlet, self).__init__()
@@ -113,9 +110,7 @@ class UserInfoServlet(RestServlet):
registry = hs.get_federation_registry()
if not registry.query_handlers.get("user_info"):
- registry.register_query_handler(
- "user_info", self._on_federation_query
- )
+ registry.register_query_handler("user_info", self._on_federation_query)
@defer.inlineCallbacks
def on_GET(self, request, user_id):
@@ -127,7 +122,7 @@ class UserInfoServlet(RestServlet):
# Attempt to make a federation request to the server that owns this user
args = {"user_id": user_id}
res = yield self.transport_layer.make_query(
- user.domain, "user_info", args, retry_on_dns_fail=True,
+ user.domain, "user_info", args, retry_on_dns_fail=True
)
defer.returnValue((200, res))
@@ -174,10 +169,7 @@ class UserInfoServlet(RestServlet):
expiration_ts is not None and self.clock.time_msec() >= expiration_ts
)
- res = {
- "expired": is_expired,
- "deactivated": is_deactivated,
- }
+ res = {"expired": is_expired, "deactivated": is_deactivated}
defer.returnValue(res)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index babbf6a23c..c99250c2ee 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,29 +27,60 @@ logger = logging.getLogger(__name__)
class VersionsRestServlet(RestServlet):
PATTERNS = [re.compile("^/_matrix/client/versions$")]
+ def __init__(self, hs):
+ super(VersionsRestServlet, self).__init__()
+ self.config = hs.config
+
def on_GET(self, request):
- return (200, {
- "versions": [
- # XXX: at some point we need to decide whether we need to include
- # the previous version numbers, given we've defined r0.3.0 to be
- # backwards compatible with r0.2.0. But need to check how
- # conscientious we've been in compatibility, and decide whether the
- # middle number is the major revision when at 0.X.Y (as opposed to
- # X.Y.Z). And we need to decide whether it's fair to make clients
- # parse the version string to figure out what's going on.
- "r0.0.1",
- "r0.1.0",
- "r0.2.0",
- "r0.3.0",
- "r0.4.0",
- "r0.5.0",
- ],
- # as per MSC1497:
- "unstable_features": {
- "m.lazy_load_members": True,
- }
- })
-
-
-def register_servlets(http_server):
- VersionsRestServlet().register(http_server)
+ return (
+ 200,
+ {
+ "versions": [
+ # XXX: at some point we need to decide whether we need to include
+ # the previous version numbers, given we've defined r0.3.0 to be
+ # backwards compatible with r0.2.0. But need to check how
+ # conscientious we've been in compatibility, and decide whether the
+ # middle number is the major revision when at 0.X.Y (as opposed to
+ # X.Y.Z). And we need to decide whether it's fair to make clients
+ # parse the version string to figure out what's going on.
+ "r0.0.1",
+ "r0.1.0",
+ "r0.2.0",
+ "r0.3.0",
+ "r0.4.0",
+ "r0.5.0",
+ ],
+ # as per MSC1497:
+ "unstable_features": {
+ # as per MSC2190, as amended by MSC2264
+ # to be removed in r0.6.0
+ # "m.id_access_token": True,
+ # Advertise to clients that they need not include an `id_server`
+ # parameter during registration or password reset, as Synapse now decides
+ # itself which identity server to use (or none at all).
+ #
+ # This is also used by a client when they wish to bind a 3PID to their
+ # account, but not bind it to an identity server, the endpoint for which
+ # also requires `id_server`. If the homeserver is handling 3PID
+ # verification itself, there is no need to ask the user for `id_server` to
+ # be supplied.
+ # "m.require_identity_server": False,
+ # as per MSC2290
+ # "m.separate_add_and_bind": True,
+ # Implements support for label-based filtering as described in
+ # MSC2326.
+ "org.matrix.label_based_filtering": True,
+ # Implements support for cross signing as described in MSC1756
+ # "org.matrix.e2e_cross_signing": True,
+ # Implements additional endpoints as described in MSC2432
+ "org.matrix.msc2432": True,
+ # Tchap does not currently assume this rule for r0.5.0
+ # XXX: Remove this when it does
+ "m.lazy_load_members": True,
+ },
+ },
+ )
+
+
+def register_servlets(hs, http_server):
+ VersionsRestServlet(hs).register(http_server)
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 6b371bfa2f..1ddf9997ff 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -24,12 +24,14 @@ import jinja2
from jinja2 import TemplateNotFound
from twisted.internet import defer
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import NotFoundError, StoreError, SynapseError
from synapse.config import ConfigError
-from synapse.http.server import finish_request, wrap_html_request_handler
+from synapse.http.server import (
+ DirectServeResource,
+ finish_request,
+ wrap_html_request_handler,
+)
from synapse.http.servlet import parse_string
from synapse.types import UserID
@@ -42,11 +44,12 @@ logger = logging.getLogger(__name__)
if hasattr(hmac, "compare_digest"):
compare_digest = hmac.compare_digest
else:
+
def compare_digest(a, b):
return a == b
-class ConsentResource(Resource):
+class ConsentResource(DirectServeResource):
"""A twisted Resource to display a privacy policy and gather consent to it
When accessed via GET, returns the privacy policy via a template.
@@ -80,12 +83,13 @@ class ConsentResource(Resource):
For POST: required; gives the value to be recorded in the database
against the user.
"""
+
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): homeserver
"""
- Resource.__init__(self)
+ super().__init__()
self.hs = hs
self.store = hs.get_datastore()
@@ -98,37 +102,30 @@ class ConsentResource(Resource):
if self._default_consent_version is None:
raise ConfigError(
"Consent resource is enabled but user_consent section is "
- "missing in config file.",
+ "missing in config file."
)
consent_template_directory = hs.config.user_consent_template_dir
loader = jinja2.FileSystemLoader(consent_template_directory)
self._jinja_env = jinja2.Environment(
- loader=loader,
- autoescape=jinja2.select_autoescape(['html', 'htm', 'xml']),
+ loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
)
if hs.config.form_secret is None:
raise ConfigError(
"Consent resource is enabled but form_secret is not set in "
- "config file. It should be set to an arbitrary secret string.",
+ "config file. It should be set to an arbitrary secret string."
)
self._hmac_secret = hs.config.form_secret.encode("utf-8")
- def render_GET(self, request):
- self._async_render_GET(request)
- return NOT_DONE_YET
-
@wrap_html_request_handler
- @defer.inlineCallbacks
- def _async_render_GET(self, request):
+ async def _async_render_GET(self, request):
"""
Args:
request (twisted.web.http.Request):
"""
-
version = parse_string(request, "v", default=self._default_consent_version)
username = parse_string(request, "u", required=False, default="")
userhmac = None
@@ -139,12 +136,12 @@ class ConsentResource(Resource):
self._check_hash(username, userhmac_bytes)
- if username.startswith('@'):
+ if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
- u = yield self.store.get_user_by_id(qualified_user_id)
+ u = await defer.maybeDeferred(self.store.get_user_by_id, qualified_user_id)
if u is None:
raise NotFoundError("Unknown user")
@@ -153,7 +150,8 @@ class ConsentResource(Resource):
try:
self._render_template(
- request, "%s.html" % (version,),
+ request,
+ "%s.html" % (version,),
user=username,
userhmac=userhmac,
version=version,
@@ -163,13 +161,8 @@ class ConsentResource(Resource):
except TemplateNotFound:
raise NotFoundError("Unknown policy version")
- def render_POST(self, request):
- self._async_render_POST(request)
- return NOT_DONE_YET
-
@wrap_html_request_handler
- @defer.inlineCallbacks
- def _async_render_POST(self, request):
+ async def _async_render_POST(self, request):
"""
Args:
request (twisted.web.http.Request):
@@ -180,18 +173,18 @@ class ConsentResource(Resource):
self._check_hash(username, userhmac)
- if username.startswith('@'):
+ if username.startswith("@"):
qualified_user_id = username
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
try:
- yield self.store.user_set_consent_version(qualified_user_id, version)
+ await self.store.user_set_consent_version(qualified_user_id, version)
except StoreError as e:
if e.code != 404:
raise
raise NotFoundError("Unknown user")
- yield self.registration_handler.post_consent_actions(qualified_user_id)
+ await self.registration_handler.post_consent_actions(qualified_user_id)
try:
self._render_template(request, "success.html")
@@ -221,11 +214,13 @@ class ConsentResource(Resource):
SynapseError if the hash doesn't match
"""
- want_mac = hmac.new(
- key=self._hmac_secret,
- msg=userid.encode('utf-8'),
- digestmod=sha256,
- ).hexdigest().encode('ascii')
+ want_mac = (
+ hmac.new(
+ key=self._hmac_secret, msg=userid.encode("utf-8"), digestmod=sha256
+ )
+ .hexdigest()
+ .encode("ascii")
+ )
if not compare_digest(want_mac, userhmac):
raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect")
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index ec0ec7b431..c16280f668 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -80,33 +80,27 @@ class LocalKey(Resource):
for key in self.config.signing_key:
verify_key_bytes = key.verify_key.encode()
key_id = "%s:%s" % (key.alg, key.version)
- verify_keys[key_id] = {
- u"key": encode_base64(verify_key_bytes)
- }
+ verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)}
old_verify_keys = {}
for key_id, key in self.config.old_signing_keys.items():
verify_key_bytes = key.encode()
old_verify_keys[key_id] = {
- u"key": encode_base64(verify_key_bytes),
- u"expired_ts": key.expired_ts,
+ "key": encode_base64(verify_key_bytes),
+ "expired_ts": key.expired_ts,
}
tls_fingerprints = self.config.tls_fingerprints
json_object = {
- u"valid_until_ts": self.valid_until_ts,
- u"server_name": self.config.server_name,
- u"verify_keys": verify_keys,
- u"old_verify_keys": old_verify_keys,
- u"tls_fingerprints": tls_fingerprints,
+ "valid_until_ts": self.valid_until_ts,
+ "server_name": self.config.server_name,
+ "verify_keys": verify_keys,
+ "old_verify_keys": old_verify_keys,
+ "tls_fingerprints": tls_fingerprints,
}
for key in self.config.signing_key:
- json_object = sign_json(
- json_object,
- self.config.server_name,
- key,
- )
+ json_object = sign_json(json_object, self.config.server_name, key)
return json_object
def render_GET(self, request):
@@ -114,6 +108,4 @@ class LocalKey(Resource):
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
self.update_response_body(time_now)
- return respond_with_json_bytes(
- request, 200, self.response_body,
- )
+ return respond_with_json_bytes(request, 200, self.response_body)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 8a730bbc35..ab671f7334 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,21 +13,24 @@
# limitations under the License.
import logging
-from io import BytesIO
+from typing import Dict, Set
-from twisted.internet import defer
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
+from canonicaljson import encode_canonical_json, json
+from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
+from synapse.http.server import (
+ DirectServeResource,
+ respond_with_json_bytes,
+ wrap_json_request_handler,
+)
from synapse.http.servlet import parse_integer, parse_json_object_from_request
logger = logging.getLogger(__name__)
-class RemoteKey(Resource):
+class RemoteKey(DirectServeResource):
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@@ -93,55 +96,41 @@ class RemoteKey(Resource):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
-
- def render_GET(self, request):
- self.async_render_GET(request)
- return NOT_DONE_YET
+ self.config = hs.config
@wrap_json_request_handler
- @defer.inlineCallbacks
- def async_render_GET(self, request):
+ async def _async_render_GET(self, request):
if len(request.postpath) == 1:
- server, = request.postpath
- query = {server.decode('ascii'): {}}
+ (server,) = request.postpath
+ query = {server.decode("ascii"): {}} # type: dict
elif len(request.postpath) == 2:
server, key_id = request.postpath
- minimum_valid_until_ts = parse_integer(
- request, "minimum_valid_until_ts"
- )
+ minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
- query = {server.decode('ascii'): {key_id.decode('ascii'): arguments}}
+ query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}}
else:
- raise SynapseError(
- 404, "Not found %r" % request.postpath, Codes.NOT_FOUND
- )
-
- yield self.query_keys(request, query, query_remote_on_cache_miss=True)
+ raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
- def render_POST(self, request):
- self.async_render_POST(request)
- return NOT_DONE_YET
+ await self.query_keys(request, query, query_remote_on_cache_miss=True)
@wrap_json_request_handler
- @defer.inlineCallbacks
- def async_render_POST(self, request):
+ async def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
query = content["server_keys"]
- yield self.query_keys(request, query, query_remote_on_cache_miss=True)
+ await self.query_keys(request, query, query_remote_on_cache_miss=True)
- @defer.inlineCallbacks
- def query_keys(self, request, query, query_remote_on_cache_miss=False):
+ async def query_keys(self, request, query, query_remote_on_cache_miss=False):
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
if (
- self.federation_domain_whitelist is not None and
- server_name not in self.federation_domain_whitelist
+ self.federation_domain_whitelist is not None
+ and server_name not in self.federation_domain_whitelist
):
logger.debug("Federation denied with %s", server_name)
continue
@@ -151,17 +140,15 @@ class RemoteKey(Resource):
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
- cached = yield self.store.get_server_keys_json(store_queries)
+ cached = await self.store.get_server_keys_json(store_queries)
json_results = set()
time_now_ms = self.clock.time_msec()
- cache_misses = dict()
+ cache_misses = {} # type: Dict[str, Set[str]]
for (server_name, key_id, from_server), results in cached.items():
- results = [
- (result["ts_added_ms"], result) for result in results
- ]
+ results = [(result["ts_added_ms"], result) for result in results]
if not results and key_id is not None:
cache_misses.setdefault(server_name, set()).add(key_id)
@@ -178,23 +165,30 @@ class RemoteKey(Resource):
logger.debug(
"Cached response for %r/%r is older than requested"
": valid_until (%r) < minimum_valid_until (%r)",
- server_name, key_id,
- ts_valid_until_ms, req_valid_until
+ server_name,
+ key_id,
+ ts_valid_until_ms,
+ req_valid_until,
)
miss = True
else:
logger.debug(
"Cached response for %r/%r is newer than requested"
": valid_until (%r) >= minimum_valid_until (%r)",
- server_name, key_id,
- ts_valid_until_ms, req_valid_until
+ server_name,
+ key_id,
+ ts_valid_until_ms,
+ req_valid_until,
)
elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
logger.debug(
"Cached response for %r/%r is too old"
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
- server_name, key_id,
- ts_added_ms, ts_valid_until_ms, time_now_ms
+ server_name,
+ key_id,
+ ts_added_ms,
+ ts_valid_until_ms,
+ time_now_ms,
)
# We more than half way through the lifetime of the
# response. We should fetch a fresh copy.
@@ -203,8 +197,11 @@ class RemoteKey(Resource):
logger.debug(
"Cached response for %r/%r is still valid"
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
- server_name, key_id,
- ts_added_ms, ts_valid_until_ms, time_now_ms
+ server_name,
+ key_id,
+ ts_added_ms,
+ ts_valid_until_ms,
+ time_now_ms,
)
if miss:
@@ -215,22 +212,17 @@ class RemoteKey(Resource):
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
- yield self.fetcher.get_keys(cache_misses)
- yield self.query_keys(
- request, query, query_remote_on_cache_miss=False
- )
+ await self.fetcher.get_keys(cache_misses)
+ await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
- result_io = BytesIO()
- result_io.write(b"{\"server_keys\":")
- sep = b"["
- for json_bytes in json_results:
- result_io.write(sep)
- result_io.write(json_bytes)
- sep = b","
- if sep == b"[":
- result_io.write(sep)
- result_io.write(b"]}")
-
- respond_with_json_bytes(
- request, 200, result_io.getvalue(),
- )
+ signed_keys = []
+ for key_json in json_results:
+ key_json = json.loads(key_json)
+ for signing_key in self.config.key_server_signing_keys:
+ key_json = sign_json(key_json, self.config.server_name, signing_key)
+
+ signed_keys.append(key_json)
+
+ results = {"server_keys": signed_keys}
+
+ respond_with_json_bytes(request, 200, encode_canonical_json(results))
diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py
deleted file mode 100644
index 5a426ff2f6..0000000000
--- a/synapse/rest/media/v0/content_repository.py
+++ /dev/null
@@ -1,102 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import base64
-import logging
-import os
-import re
-
-from canonicaljson import json
-
-from twisted.protocols.basic import FileSender
-from twisted.web import resource, server
-
-from synapse.api.errors import Codes, cs_error
-from synapse.http.server import finish_request, respond_with_json_bytes
-
-logger = logging.getLogger(__name__)
-
-
-class ContentRepoResource(resource.Resource):
- """Provides file uploading and downloading.
-
- Uploads are POSTed to wherever this Resource is linked to. This resource
- returns a "content token" which can be used to GET this content again. The
- token is typically a path, but it may not be. Tokens can expire, be
- one-time uses, etc.
-
- In this case, the token is a path to the file and contains 3 interesting
- sections:
- - User ID base64d (for namespacing content to each user)
- - random 24 char string
- - Content type base64d (so we can return it when clients GET it)
-
- """
- isLeaf = True
-
- def __init__(self, hs, directory):
- resource.Resource.__init__(self)
- self.hs = hs
- self.directory = directory
-
- def render_GET(self, request):
- # no auth here on purpose, to allow anyone to view, even across home
- # servers.
-
- # TODO: A little crude here, we could do this better.
- filename = request.path.decode('ascii').split('/')[-1]
- # be paranoid
- filename = re.sub("[^0-9A-z.-_]", "", filename)
-
- file_path = self.directory + "/" + filename
-
- logger.debug("Searching for %s", file_path)
-
- if os.path.isfile(file_path):
- # filename has the content type
- base64_contentype = filename.split(".")[1]
- content_type = base64.urlsafe_b64decode(base64_contentype)
- logger.info("Sending file %s", file_path)
- f = open(file_path, 'rb')
- request.setHeader('Content-Type', content_type)
-
- # cache for at least a day.
- # XXX: we might want to turn this off for data we don't want to
- # recommend caching as it's sensitive or private - or at least
- # select private. don't bother setting Expires as all our matrix
- # clients are smart enough to be happy with Cache-Control (right?)
- request.setHeader(
- b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
- )
-
- d = FileSender().beginFileTransfer(f, request)
-
- # after the file has been sent, clean up and finish the request
- def cbFinished(ignored):
- f.close()
- finish_request(request)
- d.addCallback(cbFinished)
- else:
- respond_with_json_bytes(
- request,
- 404,
- json.dumps(cs_error("Not found", code=Codes.NOT_FOUND)),
- send_cors=True)
-
- return server.NOT_DONE_YET
-
- def render_OPTIONS(self, request):
- respond_with_json_bytes(request, 200, {}, send_cors=True)
- return server.NOT_DONE_YET
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 2dcc8f74d6..503f2bed98 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -25,11 +25,27 @@ from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
-from synapse.util import logcontext
+from synapse.logging.context import make_deferred_yieldable
from synapse.util.stringutils import is_ascii
logger = logging.getLogger(__name__)
+# list all text content types that will have the charset default to UTF-8 when
+# none is given
+TEXT_CONTENT_TYPES = [
+ "text/css",
+ "text/csv",
+ "text/html",
+ "text/calendar",
+ "text/plain",
+ "text/javascript",
+ "application/json",
+ "application/ld+json",
+ "application/rtf",
+ "image/svg+xml",
+ "text/xml",
+]
+
def parse_media_id(request):
try:
@@ -38,8 +54,8 @@ def parse_media_id(request):
server_name, media_id = request.postpath[:2]
if isinstance(server_name, bytes):
- server_name = server_name.decode('utf-8')
- media_id = media_id.decode('utf8')
+ server_name = server_name.decode("utf-8")
+ media_id = media_id.decode("utf8")
file_name = None
if len(request.postpath) > 2:
@@ -75,9 +91,7 @@ def respond_with_file(request, media_type, file_path, file_size=None, upload_nam
add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
- yield logcontext.make_deferred_yieldable(
- FileSender().beginFileTransfer(f, request)
- )
+ yield make_deferred_yieldable(FileSender().beginFileTransfer(f, request))
finish_request(request)
else:
@@ -98,7 +112,14 @@ def add_file_headers(request, media_type, file_size, upload_name):
def _quote(x):
return urllib.parse.quote(x.encode("utf-8"))
- request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+ # Default to a UTF-8 charset for text content types.
+ # ex, uses UTF-8 for 'text/css' but not 'text/css; charset=UTF-16'
+ if media_type.lower() in TEXT_CONTENT_TYPES:
+ content_type = media_type + "; charset=UTF-8"
+ else:
+ content_type = media_type
+
+ request.setHeader(b"Content-Type", content_type.encode("UTF-8"))
if upload_name:
# RFC6266 section 4.1 [1] defines both `filename` and `filename*`.
#
@@ -120,11 +141,11 @@ def add_file_headers(request, media_type, file_size, upload_name):
# correctly interpret those as of 0.99.2 and (b) they are a bit of a pain and we
# may as well just do the filename* version.
if _can_encode_filename_as_token(upload_name):
- disposition = 'inline; filename=%s' % (upload_name, )
+ disposition = "inline; filename=%s" % (upload_name,)
else:
- disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name), )
+ disposition = "inline; filename*=utf-8''%s" % (_quote(upload_name),)
- request.setHeader(b"Content-Disposition", disposition.encode('ascii'))
+ request.setHeader(b"Content-Disposition", disposition.encode("ascii"))
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
@@ -137,10 +158,25 @@ def add_file_headers(request, media_type, file_size, upload_name):
# separators as defined in RFC2616. SP and HT are handled separately.
# see _can_encode_filename_as_token.
-_FILENAME_SEPARATOR_CHARS = set((
- "(", ")", "<", ">", "@", ",", ";", ":", "\\", '"',
- "/", "[", "]", "?", "=", "{", "}",
-))
+_FILENAME_SEPARATOR_CHARS = {
+ "(",
+ ")",
+ "<",
+ ">",
+ "@",
+ ",",
+ ";",
+ ":",
+ "\\",
+ '"',
+ "/",
+ "[",
+ "]",
+ "?",
+ "=",
+ "{",
+ "}",
+}
def _can_encode_filename_as_token(x):
@@ -180,7 +216,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
respond_404(request)
return
- logger.debug("Responding to media request with responder %s")
+ logger.debug("Responding to media request with responder %s", responder)
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:
@@ -271,7 +307,7 @@ def get_filename_from_headers(headers):
Returns:
A Unicode string of the filename, or None.
"""
- content_disposition = headers.get(b"Content-Disposition", [b''])
+ content_disposition = headers.get(b"Content-Disposition", [b""])
# No header, bail out.
if not content_disposition[0]:
@@ -293,7 +329,7 @@ def get_filename_from_headers(headers):
# Once it is decoded, we can then unquote the %-encoded
# parts strictly into a unicode string.
upload_name = urllib.parse.unquote(
- upload_name_utf8.decode('ascii'), errors="strict"
+ upload_name_utf8.decode("ascii"), errors="strict"
)
except UnicodeDecodeError:
# Incorrect UTF-8.
@@ -302,7 +338,7 @@ def get_filename_from_headers(headers):
# On Python 2, we first unquote the %-encoded parts and then
# decode it strictly using UTF-8.
try:
- upload_name = urllib.parse.unquote(upload_name_utf8).decode('utf8')
+ upload_name = urllib.parse.unquote(upload_name_utf8).decode("utf8")
except UnicodeDecodeError:
pass
@@ -310,7 +346,7 @@ def get_filename_from_headers(headers):
if not upload_name:
upload_name_ascii = params.get(b"filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
- upload_name = upload_name_ascii.decode('ascii')
+ upload_name = upload_name_ascii.decode("ascii")
# This may be None here, indicating we did not find a matching name.
return upload_name
@@ -328,19 +364,19 @@ def _parse_header(line):
Tuple[bytes, dict[bytes, bytes]]:
the main content-type, followed by the parameter dictionary
"""
- parts = _parseparam(b';' + line)
+ parts = _parseparam(b";" + line)
key = next(parts)
pdict = {}
for p in parts:
- i = p.find(b'=')
+ i = p.find(b"=")
if i >= 0:
name = p[:i].strip().lower()
- value = p[i + 1:].strip()
+ value = p[i + 1 :].strip()
# strip double-quotes
if len(value) >= 2 and value[0:1] == value[-1:] == b'"':
value = value[1:-1]
- value = value.replace(b'\\\\', b'\\').replace(b'\\"', b'"')
+ value = value.replace(b"\\\\", b"\\").replace(b'\\"', b'"')
pdict[name] = value
return key, pdict
@@ -357,16 +393,16 @@ def _parseparam(s):
Returns:
Iterable[bytes]: the split input
"""
- while s[:1] == b';':
+ while s[:1] == b";":
s = s[1:]
# look for the next ;
- end = s.find(b';')
+ end = s.find(b";")
# if there is an odd number of " marks between here and the next ;, skip to the
# next ; instead
while end > 0 and (s.count(b'"', 0, end) - s.count(b'\\"', 0, end)) % 2:
- end = s.find(b';', end + 1)
+ end = s.find(b";", end + 1)
if end < 0:
end = len(s)
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 77316033f7..9f747de263 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -14,33 +14,28 @@
# limitations under the License.
#
-from twisted.internet import defer
-from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
-from synapse.http.server import respond_with_json, wrap_json_request_handler
+from synapse.http.server import (
+ DirectServeResource,
+ respond_with_json,
+ wrap_json_request_handler,
+)
-class MediaConfigResource(Resource):
+class MediaConfigResource(DirectServeResource):
isLeaf = True
def __init__(self, hs):
- Resource.__init__(self)
+ super().__init__()
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
- self.limits_dict = {
- "m.upload.size": config.max_upload_size,
- }
-
- def render_GET(self, request):
- self._async_render_GET(request)
- return NOT_DONE_YET
+ self.limits_dict = {"m.upload.size": config.max_upload_size}
@wrap_json_request_handler
- @defer.inlineCallbacks
- def _async_render_GET(self, request):
- yield self.auth.get_user_by_req(request)
+ async def _async_render_GET(self, request):
+ await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
def render_OPTIONS(self, request):
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index bdc5daecc1..66a01559e1 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -14,37 +14,31 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-
import synapse.http.servlet
-from synapse.http.server import set_cors_headers, wrap_json_request_handler
+from synapse.http.server import (
+ DirectServeResource,
+ set_cors_headers,
+ wrap_json_request_handler,
+)
from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__)
-class DownloadResource(Resource):
+class DownloadResource(DirectServeResource):
isLeaf = True
def __init__(self, hs, media_repo):
- Resource.__init__(self)
-
+ super().__init__()
self.media_repo = media_repo
self.server_name = hs.hostname
# this is expected by @wrap_json_request_handler
self.clock = hs.get_clock()
- def render_GET(self, request):
- self._async_render_GET(request)
- return NOT_DONE_YET
-
@wrap_json_request_handler
- @defer.inlineCallbacks
- def _async_render_GET(self, request):
+ async def _async_render_GET(self, request):
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
@@ -54,20 +48,22 @@ class DownloadResource(Resource):
b" plugin-types application/pdf;"
b" style-src 'unsafe-inline';"
b" media-src 'self';"
- b" object-src 'self';"
+ b" object-src 'self';",
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
- yield self.media_repo.get_local_media(request, media_id, name)
+ await self.media_repo.get_local_media(request, media_id, name)
else:
allow_remote = synapse.http.servlet.parse_boolean(
- request, "allow_remote", default=True)
+ request, "allow_remote", default=True
+ )
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
- server_name, media_id,
+ server_name,
+ media_id,
)
respond_404(request)
return
- yield self.media_repo.get_remote_media(request, server_name, media_id, name)
+ await self.media_repo.get_remote_media(request, server_name, media_id, name)
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index c8586fa280..e25c382c9c 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -24,6 +24,7 @@ def _wrap_in_base_path(func):
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
+
@functools.wraps(func)
def _wrapped(self, *args, **kwargs):
path = func(self, *args, **kwargs)
@@ -43,125 +44,102 @@ class MediaFilePaths(object):
def __init__(self, primary_base_path):
self.base_path = primary_base_path
- def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
- height, content_type, method):
+ def default_thumbnail_rel(
+ self, default_top_level, default_sub_type, width, height, content_type, method
+ ):
top_level_type, sub_type = content_type.split("/")
- file_name = "%i-%i-%s-%s-%s" % (
- width, height, top_level_type, sub_type, method
- )
+ file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
- "default_thumbnails", default_top_level,
- default_sub_type, file_name
+ "default_thumbnails", default_top_level, default_sub_type, file_name
)
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
def local_media_filepath_rel(self, media_id):
- return os.path.join(
- "local_content",
- media_id[0:2], media_id[2:4], media_id[4:]
- )
+ return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
- def local_media_thumbnail_rel(self, media_id, width, height, content_type,
- method):
+ def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
top_level_type, sub_type = content_type.split("/")
- file_name = "%i-%i-%s-%s-%s" % (
- width, height, top_level_type, sub_type, method
- )
+ file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
- "local_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
- file_name
+ "local_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], file_name
)
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
def remote_media_filepath_rel(self, server_name, file_id):
return os.path.join(
- "remote_content", server_name,
- file_id[0:2], file_id[2:4], file_id[4:]
+ "remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
)
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
- def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
- content_type, method):
+ def remote_media_thumbnail_rel(
+ self, server_name, file_id, width, height, content_type, method
+ ):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
return os.path.join(
- "remote_thumbnail", server_name,
- file_id[0:2], file_id[2:4], file_id[4:],
- file_name
+ "remote_thumbnail",
+ server_name,
+ file_id[0:2],
+ file_id[2:4],
+ file_id[4:],
+ file_name,
)
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
def remote_media_thumbnail_dir(self, server_name, file_id):
return os.path.join(
- self.base_path, "remote_thumbnail", server_name,
- file_id[0:2], file_id[2:4], file_id[4:],
+ self.base_path,
+ "remote_thumbnail",
+ server_name,
+ file_id[0:2],
+ file_id[2:4],
+ file_id[4:],
)
def url_cache_filepath_rel(self, media_id):
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
- return os.path.join(
- "url_cache",
- media_id[:10], media_id[11:]
- )
+ return os.path.join("url_cache", media_id[:10], media_id[11:])
else:
- return os.path.join(
- "url_cache",
- media_id[0:2], media_id[2:4], media_id[4:],
- )
+ return os.path.join("url_cache", media_id[0:2], media_id[2:4], media_id[4:])
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
def url_cache_filepath_dirs_to_delete(self, media_id):
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
- return [
- os.path.join(
- self.base_path, "url_cache",
- media_id[:10],
- ),
- ]
+ return [os.path.join(self.base_path, "url_cache", media_id[:10])]
else:
return [
- os.path.join(
- self.base_path, "url_cache",
- media_id[0:2], media_id[2:4],
- ),
- os.path.join(
- self.base_path, "url_cache",
- media_id[0:2],
- ),
+ os.path.join(self.base_path, "url_cache", media_id[0:2], media_id[2:4]),
+ os.path.join(self.base_path, "url_cache", media_id[0:2]),
]
- def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
- method):
+ def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
top_level_type, sub_type = content_type.split("/")
- file_name = "%i-%i-%s-%s-%s" % (
- width, height, top_level_type, sub_type, method
- )
+ file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
- "url_cache_thumbnails",
- media_id[:10], media_id[11:],
- file_name
+ "url_cache_thumbnails", media_id[:10], media_id[11:], file_name
)
else:
return os.path.join(
"url_cache_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
- file_name
+ media_id[0:2],
+ media_id[2:4],
+ media_id[4:],
+ file_name,
)
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
@@ -172,13 +150,15 @@ class MediaFilePaths(object):
if NEW_FORMAT_ID_RE.match(media_id):
return os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[:10], media_id[11:],
+ self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
)
else:
return os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
+ self.base_path,
+ "url_cache_thumbnails",
+ media_id[0:2],
+ media_id[2:4],
+ media_id[4:],
)
def url_cache_thumbnail_dirs_to_delete(self, media_id):
@@ -188,26 +168,21 @@ class MediaFilePaths(object):
if NEW_FORMAT_ID_RE.match(media_id):
return [
os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[:10], media_id[11:],
- ),
- os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[:10],
+ self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:]
),
+ os.path.join(self.base_path, "url_cache_thumbnails", media_id[:10]),
]
else:
return [
os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
- ),
- os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2], media_id[2:4],
+ self.base_path,
+ "url_cache_thumbnails",
+ media_id[0:2],
+ media_id[2:4],
+ media_id[4:],
),
os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2],
+ self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4]
),
+ os.path.join(self.base_path, "url_cache_thumbnails", media_id[0:2]),
]
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 8569677355..490b1b45a8 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -18,6 +18,7 @@ import errno
import logging
import os
import shutil
+from typing import Dict, Tuple
from six import iteritems
@@ -33,8 +34,9 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.config._base import ConfigError
+from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util import logcontext
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import random_string
@@ -100,17 +102,16 @@ class MediaRepository(object):
storage_providers.append(provider)
self.media_storage = MediaStorage(
- self.hs, self.primary_base_path, self.filepaths, storage_providers,
+ self.hs, self.primary_base_path, self.filepaths, storage_providers
)
self.clock.looping_call(
- self._start_update_recently_accessed,
- UPDATE_RECENTLY_ACCESSED_TS,
+ self._start_update_recently_accessed, UPDATE_RECENTLY_ACCESSED_TS
)
def _start_update_recently_accessed(self):
return run_as_background_process(
- "update_recently_accessed_media", self._update_recently_accessed,
+ "update_recently_accessed_media", self._update_recently_accessed
)
@defer.inlineCallbacks
@@ -138,8 +139,9 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
- def create_content(self, media_type, upload_name, content, content_length,
- auth_user):
+ def create_content(
+ self, media_type, upload_name, content, content_length, auth_user
+ ):
"""Store uploaded content for a local user and return the mxc URL
Args:
@@ -154,10 +156,7 @@ class MediaRepository(object):
"""
media_id = random_string(24)
- file_info = FileInfo(
- server_name=None,
- file_id=media_id,
- )
+ file_info = FileInfo(server_name=None, file_id=media_id)
fname = yield self.media_storage.store_file(content, file_info)
@@ -172,11 +171,9 @@ class MediaRepository(object):
user_id=auth_user,
)
- yield self._generate_thumbnails(
- None, media_id, media_id, media_type,
- )
+ yield self._generate_thumbnails(None, media_id, media_id, media_type)
- defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
+ return "mxc://%s/%s" % (self.server_name, media_id)
@defer.inlineCallbacks
def get_local_media(self, request, media_id, name):
@@ -205,14 +202,11 @@ class MediaRepository(object):
upload_name = name if name else media_info["upload_name"]
url_cache = media_info["url_cache"]
- file_info = FileInfo(
- None, media_id,
- url_cache=url_cache,
- )
+ file_info = FileInfo(None, media_id, url_cache=url_cache)
responder = yield self.media_storage.fetch_media(file_info)
yield respond_with_responder(
- request, responder, media_type, media_length, upload_name,
+ request, responder, media_type, media_length, upload_name
)
@defer.inlineCallbacks
@@ -232,8 +226,8 @@ class MediaRepository(object):
to request
"""
if (
- self.federation_domain_whitelist is not None and
- server_name not in self.federation_domain_whitelist
+ self.federation_domain_whitelist is not None
+ and server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(server_name)
@@ -244,7 +238,7 @@ class MediaRepository(object):
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
- server_name, media_id,
+ server_name, media_id
)
# We deliberately stream the file outside the lock
@@ -253,7 +247,7 @@ class MediaRepository(object):
media_length = media_info["media_length"]
upload_name = name if name else media_info["upload_name"]
yield respond_with_responder(
- request, responder, media_type, media_length, upload_name,
+ request, responder, media_type, media_length, upload_name
)
else:
respond_404(request)
@@ -272,8 +266,8 @@ class MediaRepository(object):
Deferred[dict]: The media_info of the file
"""
if (
- self.federation_domain_whitelist is not None and
- server_name not in self.federation_domain_whitelist
+ self.federation_domain_whitelist is not None
+ and server_name not in self.federation_domain_whitelist
):
raise FederationDeniedError(server_name)
@@ -282,7 +276,7 @@ class MediaRepository(object):
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
responder, media_info = yield self._get_remote_media_impl(
- server_name, media_id,
+ server_name, media_id
)
# Ensure we actually use the responder so that it releases resources
@@ -290,7 +284,7 @@ class MediaRepository(object):
with responder:
pass
- defer.returnValue(media_info)
+ return media_info
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
@@ -305,9 +299,7 @@ class MediaRepository(object):
Returns:
Deferred[(Responder, media_info)]
"""
- media_info = yield self.store.get_cached_remote_media(
- server_name, media_id
- )
+ media_info = yield self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
# seen the file then reuse the existing ID, otherwise genereate a new
@@ -327,16 +319,14 @@ class MediaRepository(object):
responder = yield self.media_storage.fetch_media(file_info)
if responder:
- defer.returnValue((responder, media_info))
+ return responder, media_info
# Failed to find the file anywhere, lets download it.
- media_info = yield self._download_remote_file(
- server_name, media_id, file_id
- )
+ media_info = yield self._download_remote_file(server_name, media_id, file_id)
responder = yield self.media_storage.fetch_media(file_info)
- defer.returnValue((responder, media_info))
+ return responder, media_info
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id, file_id):
@@ -354,52 +344,62 @@ class MediaRepository(object):
Deferred[MediaInfo]
"""
- file_info = FileInfo(
- server_name=server_name,
- file_id=file_id,
- )
+ file_info = FileInfo(server_name=server_name, file_id=file_id)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
- request_path = "/".join((
- "/_matrix/media/v1/download", server_name, media_id,
- ))
+ request_path = "/".join(
+ ("/_matrix/media/v1/download", server_name, media_id)
+ )
try:
length, headers = yield self.client.get_file(
- server_name, request_path, output_stream=f,
- max_size=self.max_upload_size, args={
+ server_name,
+ request_path,
+ output_stream=f,
+ max_size=self.max_upload_size,
+ args={
# tell the remote server to 404 if it doesn't
# recognise the server_name, to make sure we don't
# end up with a routing loop.
- "allow_remote": "false",
- }
+ "allow_remote": "false"
+ },
)
except RequestSendFailed as e:
- logger.warn("Request failed fetching remote media %s/%s: %r",
- server_name, media_id, e)
+ logger.warning(
+ "Request failed fetching remote media %s/%s: %r",
+ server_name,
+ media_id,
+ e,
+ )
raise SynapseError(502, "Failed to fetch remote media")
except HttpResponseException as e:
- logger.warn("HTTP error fetching remote media %s/%s: %s",
- server_name, media_id, e.response)
+ logger.warning(
+ "HTTP error fetching remote media %s/%s: %s",
+ server_name,
+ media_id,
+ e.response,
+ )
if e.code == twisted.web.http.NOT_FOUND:
raise e.to_synapse_error()
raise SynapseError(502, "Failed to fetch remote media")
except SynapseError:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
+ logger.warning(
+ "Failed to fetch remote media %s/%s", server_name, media_id
+ )
raise
except NotRetryingDestination:
- logger.warn("Not retrying destination %r", server_name)
+ logger.warning("Not retrying destination %r", server_name)
raise SynapseError(502, "Failed to fetch remote media")
except Exception:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
+ logger.exception(
+ "Failed to fetch remote media %s/%s", server_name, media_id
+ )
raise SynapseError(502, "Failed to fetch remote media")
yield finish()
- media_type = headers[b"Content-Type"][0].decode('ascii')
+ media_type = headers[b"Content-Type"][0].decode("ascii")
upload_name = get_filename_from_headers(headers)
time_now_ms = self.clock.time_msec()
@@ -423,24 +423,23 @@ class MediaRepository(object):
"filesystem_id": file_id,
}
- yield self._generate_thumbnails(
- server_name, media_id, file_id, media_type,
- )
+ yield self._generate_thumbnails(server_name, media_id, file_id, media_type)
- defer.returnValue(media_info)
+ return media_info
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
- def _generate_thumbnail(self, thumbnailer, t_width, t_height,
- t_method, t_type):
+ def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
- m_width, m_height, self.max_image_pixels
+ m_width,
+ m_height,
+ self.max_image_pixels,
)
return
@@ -460,17 +459,22 @@ class MediaRepository(object):
return t_byte_source
@defer.inlineCallbacks
- def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
- t_method, t_type, url_cache):
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
- None, media_id, url_cache=url_cache,
- ))
+ def generate_local_exact_thumbnail(
+ self, media_id, t_width, t_height, t_method, t_type, url_cache
+ ):
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ FileInfo(None, media_id, url_cache=url_cache)
+ )
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield logcontext.defer_to_thread(
+ t_byte_source = yield defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
- thumbnailer, t_width, t_height, t_method, t_type
+ thumbnailer,
+ t_width,
+ t_height,
+ t_method,
+ t_type,
)
if t_byte_source:
@@ -487,7 +491,7 @@ class MediaRepository(object):
)
output_path = yield self.media_storage.store_file(
- t_byte_source, file_info,
+ t_byte_source, file_info
)
finally:
t_byte_source.close()
@@ -500,27 +504,32 @@ class MediaRepository(object):
media_id, t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue(output_path)
+ return output_path
@defer.inlineCallbacks
- def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
- t_width, t_height, t_method, t_type):
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
- server_name, file_id, url_cache=False,
- ))
+ def generate_remote_exact_thumbnail(
+ self, server_name, file_id, media_id, t_width, t_height, t_method, t_type
+ ):
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ FileInfo(server_name, file_id, url_cache=False)
+ )
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield logcontext.defer_to_thread(
+ t_byte_source = yield defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
- thumbnailer, t_width, t_height, t_method, t_type
+ thumbnailer,
+ t_width,
+ t_height,
+ t_method,
+ t_type,
)
if t_byte_source:
try:
file_info = FileInfo(
server_name=server_name,
- file_id=media_id,
+ file_id=file_id,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
@@ -529,7 +538,7 @@ class MediaRepository(object):
)
output_path = yield self.media_storage.store_file(
- t_byte_source, file_info,
+ t_byte_source, file_info
)
finally:
t_byte_source.close()
@@ -539,15 +548,22 @@ class MediaRepository(object):
t_len = os.path.getsize(output_path)
yield self.store.store_remote_media_thumbnail(
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
+ server_name,
+ media_id,
+ file_id,
+ t_width,
+ t_height,
+ t_type,
+ t_method,
+ t_len,
)
- defer.returnValue(output_path)
+ return output_path
@defer.inlineCallbacks
- def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
- url_cache=False):
+ def _generate_thumbnails(
+ self, server_name, media_id, file_id, media_type, url_cache=False
+ ):
"""Generate and store thumbnails for an image.
Args:
@@ -566,9 +582,9 @@ class MediaRepository(object):
if not requirements:
return
- input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
- server_name, file_id, url_cache=url_cache,
- ))
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(
+ FileInfo(server_name, file_id, url_cache=url_cache)
+ )
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
@@ -577,19 +593,20 @@ class MediaRepository(object):
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
- m_width, m_height, self.max_image_pixels
+ m_width,
+ m_height,
+ self.max_image_pixels,
)
return
if thumbnailer.transpose_method is not None:
- m_width, m_height = yield logcontext.defer_to_thread(
- self.hs.get_reactor(),
- thumbnailer.transpose
+ m_width, m_height = yield defer_to_thread(
+ self.hs.get_reactor(), thumbnailer.transpose
)
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
- thumbnails = {}
+ thumbnails = {} # type: Dict[Tuple[int, int, str], str]
for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method)
@@ -603,16 +620,12 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
- t_byte_source = yield logcontext.defer_to_thread(
- self.hs.get_reactor(),
- thumbnailer.crop,
- t_width, t_height, t_type,
+ t_byte_source = yield defer_to_thread(
+ self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
- t_byte_source = yield logcontext.defer_to_thread(
- self.hs.get_reactor(),
- thumbnailer.scale,
- t_width, t_height, t_type,
+ t_byte_source = yield defer_to_thread(
+ self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
logger.error("Unrecognized method: %r", t_method)
@@ -634,7 +647,7 @@ class MediaRepository(object):
)
output_path = yield self.media_storage.store_file(
- t_byte_source, file_info,
+ t_byte_source, file_info
)
finally:
t_byte_source.close()
@@ -644,18 +657,21 @@ class MediaRepository(object):
# Write to database
if server_name:
yield self.store.store_remote_media_thumbnail(
- server_name, media_id, file_id,
- t_width, t_height, t_type, t_method, t_len
+ server_name,
+ media_id,
+ file_id,
+ t_width,
+ t_height,
+ t_type,
+ t_method,
+ t_len,
)
else:
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
- defer.returnValue({
- "width": m_width,
- "height": m_height,
- })
+ return {"width": m_width, "height": m_height}
@defer.inlineCallbacks
def delete_old_remote_media(self, before_ts):
@@ -678,7 +694,7 @@ class MediaRepository(object):
try:
os.remove(full_path)
except OSError as e:
- logger.warn("Failed to remove file: %r", full_path)
+ logger.warning("Failed to remove file: %r", full_path)
if e.errno == errno.ENOENT:
pass
else:
@@ -692,7 +708,7 @@ class MediaRepository(object):
yield self.store.delete_remote_media(origin, media_id)
deleted += 1
- defer.returnValue({"deleted": deleted})
+ return {"deleted": deleted}
class MediaRepositoryResource(Resource):
@@ -741,17 +757,21 @@ class MediaRepositoryResource(Resource):
"""
def __init__(self, hs):
- Resource.__init__(self)
+ # If we're not configured to use it, raise if we somehow got here.
+ if not hs.config.can_load_media_repo:
+ raise ConfigError("Synapse is not configured to use a media repo.")
+ super().__init__()
media_repo = hs.get_media_repository()
self.putChild(b"upload", UploadResource(hs, media_repo))
self.putChild(b"download", DownloadResource(hs, media_repo))
- self.putChild(b"thumbnail", ThumbnailResource(
- hs, media_repo, media_repo.media_storage,
- ))
+ self.putChild(
+ b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage)
+ )
if hs.config.url_preview_enabled:
- self.putChild(b"preview_url", PreviewUrlResource(
- hs, media_repo, media_repo.media_storage,
- ))
+ self.putChild(
+ b"preview_url",
+ PreviewUrlResource(hs, media_repo, media_repo.media_storage),
+ )
self.putChild(b"config", MediaConfigResource(hs))
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 896078fe76..683a79c966 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -24,9 +24,8 @@ import six
from twisted.internet import defer
from twisted.protocols.basic import FileSender
-from synapse.util import logcontext
+from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
-from synapse.util.logcontext import make_deferred_yieldable
from ._base import Responder
@@ -65,13 +64,12 @@ class MediaStorage(object):
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- yield logcontext.defer_to_thread(
- self.hs.get_reactor(),
- _write_file_synchronously, source, f,
+ yield defer_to_thread(
+ self.hs.get_reactor(), _write_file_synchronously, source, f
)
yield finish_cb()
- defer.returnValue(fname)
+ return fname
@contextlib.contextmanager
def store_into_file(self, file_info):
@@ -145,14 +143,15 @@ class MediaStorage(object):
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
- defer.returnValue(FileResponder(open(local_path, "rb")))
+ return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
res = yield provider.fetch(path, file_info)
if res:
- defer.returnValue(res)
+ logger.debug("Streaming %s from %s", path, provider)
+ return res
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def ensure_media_is_in_local_cache(self, file_info):
@@ -168,7 +167,7 @@ class MediaStorage(object):
path = self._file_info_to_path(file_info)
local_path = os.path.join(self.local_media_directory, path)
if os.path.exists(local_path):
- defer.returnValue(local_path)
+ return local_path
dirname = os.path.dirname(local_path)
if not os.path.exists(dirname):
@@ -179,10 +178,11 @@ class MediaStorage(object):
if res:
with res:
consumer = BackgroundFileConsumer(
- open(local_path, "wb"), self.hs.get_reactor())
+ open(local_path, "wb"), self.hs.get_reactor()
+ )
yield res.write_to_consumer(consumer)
yield consumer.wait()
- defer.returnValue(local_path)
+ return local_path
raise Exception("file could not be found")
@@ -217,10 +217,10 @@ class MediaStorage(object):
width=file_info.thumbnail_width,
height=file_info.thumbnail_height,
content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method
+ method=file_info.thumbnail_method,
)
return self.filepaths.remote_media_filepath_rel(
- file_info.server_name, file_info.file_id,
+ file_info.server_name, file_info.file_id
)
if file_info.thumbnail:
@@ -229,11 +229,9 @@ class MediaStorage(object):
width=file_info.thumbnail_width,
height=file_info.thumbnail_height,
content_type=file_info.thumbnail_type,
- method=file_info.thumbnail_method
+ method=file_info.thumbnail_method,
)
- return self.filepaths.local_media_filepath_rel(
- file_info.file_id,
- )
+ return self.filepaths.local_media_filepath_rel(file_info.file_id)
def _write_file_synchronously(source, dest):
@@ -255,6 +253,7 @@ class FileResponder(Responder):
open_file (file): A file like object to be streamed ot the client,
is closed when finished streaming.
"""
+
def __init__(self, open_file):
self.open_file = open_file
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 85a7c61a24..07e395cfd1 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -23,6 +23,7 @@ import re
import shutil
import sys
import traceback
+from typing import Dict, Optional
import six
from six import string_types
@@ -32,22 +33,21 @@ from canonicaljson import json
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
from synapse.http.server import (
+ DirectServeResource,
respond_with_json,
respond_with_json_bytes,
wrap_json_request_handler,
)
from synapse.http.servlet import parse_integer, parse_string
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
from ._base import FileInfo
@@ -57,12 +57,15 @@ logger = logging.getLogger(__name__)
_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
+OG_TAG_NAME_MAXLEN = 50
+OG_TAG_VALUE_MAXLEN = 1000
-class PreviewUrlResource(Resource):
+
+class PreviewUrlResource(DirectServeResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
- Resource.__init__(self)
+ super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@@ -75,8 +78,8 @@ class PreviewUrlResource(Resource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
- http_proxy=os.getenv("http_proxy"),
- https_proxy=os.getenv("HTTPS_PROXY"),
+ http_proxy=os.getenvb(b"http_proxy"),
+ https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
@@ -94,22 +97,18 @@ class PreviewUrlResource(Resource):
)
self._cleaner_loop = self.clock.looping_call(
- self._start_expire_url_cache_data, 10 * 1000,
+ self._start_expire_url_cache_data, 10 * 1000
)
def render_OPTIONS(self, request):
+ request.setHeader(b"Allow", b"OPTIONS, GET")
return respond_with_json(request, 200, {}, send_cors=True)
- def render_GET(self, request):
- self._async_render_GET(request)
- return NOT_DONE_YET
-
@wrap_json_request_handler
- @defer.inlineCallbacks
- def _async_render_GET(self, request):
+ async def _async_render_GET(self, request):
# XXX: if get_user_by_req fails, what should we do in an async render?
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
url = parse_string(request, "url")
if b"ts" in request.args:
ts = parse_integer(request, "ts")
@@ -123,16 +122,18 @@ class PreviewUrlResource(Resource):
for attrib in entry:
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
- logger.debug((
- "Matching attrib '%s' with value '%s' against"
- " pattern '%s'"
- ) % (attrib, value, pattern))
+ logger.debug(
+ "Matching attrib '%s' with value '%s' against pattern '%s'",
+ attrib,
+ value,
+ pattern,
+ )
if value is None:
match = False
continue
- if pattern.startswith('^'):
+ if pattern.startswith("^"):
if not re.match(pattern, getattr(url_tuple, attrib)):
match = False
continue
@@ -141,12 +142,9 @@ class PreviewUrlResource(Resource):
match = False
continue
if match:
- logger.warn(
- "URL %s blocked by url_blacklist entry %s", url, entry
- )
+ logger.warning("URL %s blocked by url_blacklist entry %s", url, entry)
raise SynapseError(
- 403, "URL blocked by url pattern blacklist entry",
- Codes.UNKNOWN
+ 403, "URL blocked by url pattern blacklist entry", Codes.UNKNOWN
)
# the in-memory cache:
@@ -158,19 +156,13 @@ class PreviewUrlResource(Resource):
observable = self._cache.get(url)
if not observable:
- download = run_in_background(
- self._do_preview,
- url, requester.user, ts,
- )
- observable = ObservableDeferred(
- download,
- consumeErrors=True
- )
+ download = run_in_background(self._do_preview, url, requester.user, ts)
+ observable = ObservableDeferred(download, consumeErrors=True)
self._cache[url] = observable
else:
logger.info("Returning cached response")
- og = yield make_deferred_yieldable(observable.observe())
+ og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
@defer.inlineCallbacks
@@ -183,55 +175,52 @@ class PreviewUrlResource(Resource):
ts (int):
Returns:
- Deferred[str]: json-encoded og data
+ Deferred[bytes]: json-encoded og data
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
cache_result = yield self.store.get_url_cache(url, ts)
if (
- cache_result and
- cache_result["expires_ts"] > ts and
- cache_result["response_code"] / 100 == 2
+ cache_result
+ and cache_result["expires_ts"] > ts
+ and cache_result["response_code"] / 100 == 2
):
# It may be stored as text in the database, not as bytes (such as
# PostgreSQL). If so, encode it back before handing it on.
og = cache_result["og"]
if isinstance(og, six.text_type):
- og = og.encode('utf8')
- defer.returnValue(og)
- return
+ og = og.encode("utf8")
+ return og
media_info = yield self._download_url(url, user)
- logger.debug("got media_info of '%s'" % media_info)
+ logger.debug("got media_info of '%s'", media_info)
- if _is_media(media_info['media_type']):
- file_id = media_info['filesystem_id']
+ if _is_media(media_info["media_type"]):
+ file_id = media_info["filesystem_id"]
dims = yield self.media_repo._generate_thumbnails(
- None, file_id, file_id, media_info["media_type"],
- url_cache=True,
+ None, file_id, file_id, media_info["media_type"], url_cache=True
)
og = {
- "og:description": media_info['download_name'],
- "og:image": "mxc://%s/%s" % (
- self.server_name, media_info['filesystem_id']
- ),
- "og:image:type": media_info['media_type'],
- "matrix:image:size": media_info['media_length'],
+ "og:description": media_info["download_name"],
+ "og:image": "mxc://%s/%s"
+ % (self.server_name, media_info["filesystem_id"]),
+ "og:image:type": media_info["media_type"],
+ "matrix:image:size": media_info["media_length"],
}
if dims:
- og["og:image:width"] = dims['width']
- og["og:image:height"] = dims['height']
+ og["og:image:width"] = dims["width"]
+ og["og:image:height"] = dims["height"]
else:
- logger.warn("Couldn't get dims for %s" % url)
+ logger.warning("Couldn't get dims for %s" % url)
# define our OG response for this media
- elif _is_html(media_info['media_type']):
+ elif _is_html(media_info["media_type"]):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
- with open(media_info['filename'], 'rb') as file:
+ with open(media_info["filename"], "rb") as file:
body = file.read()
encoding = None
@@ -244,54 +233,64 @@ class PreviewUrlResource(Resource):
# If we find a match, it should take precedence over the
# Content-Type header, so set it here.
if match:
- encoding = match.group(1).decode('ascii')
+ encoding = match.group(1).decode("ascii")
# If we don't find a match, we'll look at the HTTP Content-Type, and
# if that doesn't exist, we'll fall back to UTF-8.
if not encoding:
- match = _content_type_match.match(
- media_info['media_type']
- )
- encoding = match.group(1) if match else "utf-8"
+ content_match = _content_type_match.match(media_info["media_type"])
+ encoding = content_match.group(1) if content_match else "utf-8"
- og = decode_and_calc_og(body, media_info['uri'], encoding)
+ og = decode_and_calc_og(body, media_info["uri"], encoding)
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
- if 'og:image' in og and og['og:image']:
+ if "og:image" in og and og["og:image"]:
image_info = yield self._download_url(
- _rebase_url(og['og:image'], media_info['uri']), user
+ _rebase_url(og["og:image"], media_info["uri"]), user
)
- if _is_media(image_info['media_type']):
+ if _is_media(image_info["media_type"]):
# TODO: make sure we don't choke on white-on-transparent images
- file_id = image_info['filesystem_id']
+ file_id = image_info["filesystem_id"]
dims = yield self.media_repo._generate_thumbnails(
- None, file_id, file_id, image_info["media_type"],
- url_cache=True,
+ None, file_id, file_id, image_info["media_type"], url_cache=True
)
if dims:
- og["og:image:width"] = dims['width']
- og["og:image:height"] = dims['height']
+ og["og:image:width"] = dims["width"]
+ og["og:image:height"] = dims["height"]
else:
- logger.warn("Couldn't get dims for %s" % og["og:image"])
+ logger.warning("Couldn't get dims for %s", og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
- self.server_name, image_info['filesystem_id']
+ self.server_name,
+ image_info["filesystem_id"],
)
- og["og:image:type"] = image_info['media_type']
- og["matrix:image:size"] = image_info['media_length']
+ og["og:image:type"] = image_info["media_type"]
+ og["matrix:image:size"] = image_info["media_length"]
else:
del og["og:image"]
else:
- logger.warn("Failed to find any OG data in %s", url)
+ logger.warning("Failed to find any OG data in %s", url)
og = {}
- logger.debug("Calculated OG for %s as %s" % (url, og))
+ # filter out any stupidly long values
+ keys_to_remove = []
+ for k, v in og.items():
+ # values can be numeric as well as strings, hence the cast to str
+ if len(k) > OG_TAG_NAME_MAXLEN or len(str(v)) > OG_TAG_VALUE_MAXLEN:
+ logger.warning(
+ "Pruning overlong tag %s from OG data", k[:OG_TAG_NAME_MAXLEN]
+ )
+ keys_to_remove.append(k)
+ for k in keys_to_remove:
+ del og[k]
+
+ logger.debug("Calculated OG for %s as %s", url, og)
- jsonog = json.dumps(og).encode('utf8')
+ jsonog = json.dumps(og)
# store OG in history-aware DB cache
yield self.store.store_url_cache(
@@ -304,7 +303,7 @@ class PreviewUrlResource(Resource):
media_info["created_ts"],
)
- defer.returnValue(jsonog)
+ return jsonog.encode("utf8")
@defer.inlineCallbacks
def _download_url(self, url, user):
@@ -312,19 +311,15 @@ class PreviewUrlResource(Resource):
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
- file_id = datetime.date.today().isoformat() + '_' + random_string(16)
+ file_id = datetime.date.today().isoformat() + "_" + random_string(16)
- file_info = FileInfo(
- server_name=None,
- file_id=file_id,
- url_cache=True,
- )
+ file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
- logger.debug("Trying to get url '%s'" % url)
+ logger.debug("Trying to get url '%s'", url)
length, headers, uri, code = yield self.client.get_file(
- url, output_stream=f, max_size=self.max_spider_size,
+ url, output_stream=f, max_size=self.max_spider_size
)
except SynapseError:
# Pass SynapseErrors through directly, so that the servlet
@@ -336,24 +331,25 @@ class PreviewUrlResource(Resource):
# Note: This will also be the case if one of the resolved IP
# addresses is blacklisted
raise SynapseError(
- 502, "DNS resolution failure during URL preview generation",
- Codes.UNKNOWN
+ 502,
+ "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN,
)
except Exception as e:
# FIXME: pass through 404s and other error messages nicely
- logger.warn("Error downloading %s: %r", url, e)
+ logger.warning("Error downloading %s: %r", url, e)
raise SynapseError(
- 500, "Failed to download content: %s" % (
- traceback.format_exception_only(sys.exc_info()[0], e),
- ),
+ 500,
+ "Failed to download content: %s"
+ % (traceback.format_exception_only(sys.exc_info()[0], e),),
Codes.UNKNOWN,
)
yield finish()
try:
if b"Content-Type" in headers:
- media_type = headers[b"Content-Type"][0].decode('ascii')
+ media_type = headers[b"Content-Type"][0].decode("ascii")
else:
media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
@@ -377,7 +373,7 @@ class PreviewUrlResource(Resource):
# therefore not expire it.
raise
- defer.returnValue({
+ return {
"media_type": media_type,
"media_length": length,
"download_name": download_name,
@@ -390,11 +386,11 @@ class PreviewUrlResource(Resource):
# Cache-Control and Expire headers. But for now, assume 1 hour.
"expires": 60 * 60 * 1000,
"etag": headers["ETag"][0] if "ETag" in headers else None,
- })
+ }
def _start_expire_url_cache_data(self):
return run_as_background_process(
- "expire_url_cache_data", self._expire_url_cache_data,
+ "expire_url_cache_data", self._expire_url_cache_data
)
@defer.inlineCallbacks
@@ -407,7 +403,7 @@ class PreviewUrlResource(Resource):
logger.info("Running url preview cache expiry")
- if not (yield self.store.has_completed_background_updates()):
+ if not (yield self.store.db.updates.has_completed_background_updates()):
logger.info("Still running DB updates; skipping expiry")
return
@@ -422,7 +418,7 @@ class PreviewUrlResource(Resource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
@@ -454,7 +450,7 @@ class PreviewUrlResource(Resource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
try:
@@ -470,7 +466,7 @@ class PreviewUrlResource(Resource):
except OSError as e:
# If the path doesn't exist, meh
if e.errno != errno.ENOENT:
- logger.warn("Failed to remove media: %r: %s", media_id, e)
+ logger.warning("Failed to remove media: %r: %s", media_id, e)
continue
removed_media.append(media_id)
@@ -498,7 +494,7 @@ def decode_and_calc_og(body, media_uri, request_encoding=None):
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
+ tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
og = _calc_og(tree, media_uri)
return og
@@ -523,10 +519,14 @@ def _calc_og(tree, media_uri):
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
- og = {}
+ og = {} # type: Dict[str, Optional[str]]
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
- if 'content' in tag.attrib:
- og[tag.attrib['property']] = tag.attrib['content']
+ if "content" in tag.attrib:
+ # if we've got more than 50 tags, someone is taking the piss
+ if len(og) >= 50:
+ logger.warning("Skipping OG for page with too many 'og:' tags")
+ return {}
+ og[tag.attrib["property"]] = tag.attrib["content"]
# TODO: grab article: meta tags too, e.g.:
@@ -537,39 +537,43 @@ def _calc_og(tree, media_uri):
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
- if 'og:title' not in og:
+ if "og:title" not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
if title and title[0].text is not None:
- og['og:title'] = title[0].text.strip()
+ og["og:title"] = title[0].text.strip()
else:
- og['og:title'] = None
+ og["og:title"] = None
- if 'og:image' not in og:
+ if "og:image" not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
- og['og:image'] = _rebase_url(meta_image[0], media_uri)
+ og["og:image"] = _rebase_url(meta_image[0], media_uri)
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
- images = sorted(images, key=lambda i: (
- -1 * float(i.attrib['width']) * float(i.attrib['height'])
- ))
+ images = sorted(
+ images,
+ key=lambda i: (
+ -1 * float(i.attrib["width"]) * float(i.attrib["height"])
+ ),
+ )
if not images:
images = tree.xpath("//img[@src]")
if images:
- og['og:image'] = images[0].attrib['src']
+ og["og:image"] = images[0].attrib["src"]
- if 'og:description' not in og:
+ if "og:description" not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
- "/@content")
+ "/@content"
+ )
if meta_description:
- og['og:description'] = meta_description[0]
+ og["og:description"] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
@@ -590,18 +594,18 @@ def _calc_og(tree, media_uri):
"script",
"noscript",
"style",
- etree.Comment
+ etree.Comment,
)
# Split all the text nodes into paragraphs (by splitting on new
# lines)
text_nodes = (
- re.sub(r'\s+', '\n', el).strip()
+ re.sub(r"\s+", "\n", el).strip()
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
- og['og:description'] = summarize_paragraphs(text_nodes)
+ og["og:description"] = summarize_paragraphs(text_nodes)
else:
- og['og:description'] = summarize_paragraphs([og['og:description']])
+ og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
@@ -638,7 +642,7 @@ def _iterate_over_text(tree, *tags_to_ignore):
[child, child.tail] if child.tail else [child]
for child in el.iterchildren()
),
- elements
+ elements,
)
@@ -649,8 +653,8 @@ def _rebase_url(url, base):
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
- if not url[2].startswith('/'):
- url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
+ if not url[2].startswith("/"):
+ url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
return urlparse.urlunparse(url)
@@ -661,9 +665,8 @@ def _is_media(content_type):
def _is_html(content_type):
content_type = content_type.lower()
- if (
- content_type.startswith("text/html") or
- content_type.startswith("application/xhtml")
+ if content_type.startswith("text/html") or content_type.startswith(
+ "application/xhtml"
):
return True
@@ -673,19 +676,19 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
# first paragraph and then word boundaries.
# TODO: Respect sentences?
- description = ''
+ description = ""
# Keep adding paragraphs until we get to the MIN_SIZE.
for text_node in text_nodes:
if len(description) < min_size:
- text_node = re.sub(r'[\t \r\n]+', ' ', text_node)
- description += text_node + '\n\n'
+ text_node = re.sub(r"[\t \r\n]+", " ", text_node)
+ description += text_node + "\n\n"
else:
break
description = description.strip()
- description = re.sub(r'[\t ]+', ' ', description)
- description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description)
+ description = re.sub(r"[\t ]+", " ", description)
+ description = re.sub(r"[\t \r\n]*[\r\n]+", "\n\n", description)
# If the concatenation of paragraphs to get above MIN_SIZE
# took us over MAX_SIZE, then we need to truncate mid paragraph
@@ -717,5 +720,5 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
# We always add an ellipsis because at the very least
# we chopped mid paragraph.
- description = new_desc.strip() + u"…"
+ description = new_desc.strip() + "…"
return description if description else None
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index d90cbfb56a..858680be26 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -20,8 +20,7 @@ import shutil
from twisted.internet import defer
from synapse.config._base import Config
-from synapse.util import logcontext
-from synapse.util.logcontext import run_in_background
+from synapse.logging.context import defer_to_thread, run_in_background
from .media_storage import FileResponder
@@ -32,6 +31,7 @@ class StorageProvider(object):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
+
def store_file(self, path, file_info):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
@@ -67,15 +67,19 @@ class StorageProviderWrapper(StorageProvider):
backend (StorageProvider)
store_local (bool): Whether to store new local files or not.
store_synchronous (bool): Whether to wait for file to be successfully
- uploaded, or todo the upload in the backgroud.
+ uploaded, or todo the upload in the background.
store_remote (bool): Whether remote media should be uploaded
"""
+
def __init__(self, backend, store_local, store_synchronous, store_remote):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
self.store_remote = store_remote
+ def __str__(self):
+ return "StorageProviderWrapper[%s]" % (self.backend,)
+
def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
return defer.succeed(None)
@@ -92,6 +96,7 @@ class StorageProviderWrapper(StorageProvider):
return self.backend.store_file(path, file_info)
except Exception:
logger.exception("Error storing file")
+
run_in_background(store)
return defer.succeed(None)
@@ -112,6 +117,9 @@ class FileStorageProviderBackend(StorageProvider):
self.cache_directory = hs.config.media_store_path
self.base_directory = config
+ def __str__(self):
+ return "FileStorageProviderBackend[%s]" % (self.base_directory,)
+
def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
@@ -122,9 +130,8 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return logcontext.defer_to_thread(
- self.hs.get_reactor(),
- shutil.copyfile, primary_fname, backup_fname,
+ return defer_to_thread(
+ self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
def fetch(self, path, file_info):
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 35a750923b..d57480f761 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -17,10 +17,12 @@
import logging
from twisted.internet import defer
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-from synapse.http.server import set_cors_headers, wrap_json_request_handler
+from synapse.http.server import (
+ DirectServeResource,
+ set_cors_headers,
+ wrap_json_request_handler,
+)
from synapse.http.servlet import parse_integer, parse_string
from ._base import (
@@ -34,11 +36,11 @@ from ._base import (
logger = logging.getLogger(__name__)
-class ThumbnailResource(Resource):
+class ThumbnailResource(DirectServeResource):
isLeaf = True
def __init__(self, hs, media_repo, media_storage):
- Resource.__init__(self)
+ super().__init__()
self.store = hs.get_datastore()
self.media_repo = media_repo
@@ -47,13 +49,8 @@ class ThumbnailResource(Resource):
self.server_name = hs.hostname
self.clock = hs.get_clock()
- def render_GET(self, request):
- self._async_render_GET(request)
- return NOT_DONE_YET
-
@wrap_json_request_handler
- @defer.inlineCallbacks
- def _async_render_GET(self, request):
+ async def _async_render_GET(self, request):
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@@ -63,30 +60,29 @@ class ThumbnailResource(Resource):
if server_name == self.server_name:
if self.dynamic_thumbnails:
- yield self._select_or_generate_local_thumbnail(
+ await self._select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type
)
else:
- yield self._respond_local_thumbnail(
+ await self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
self.media_repo.mark_recently_accessed(None, media_id)
else:
if self.dynamic_thumbnails:
- yield self._select_or_generate_remote_thumbnail(
- request, server_name, media_id,
- width, height, method, m_type
+ await self._select_or_generate_remote_thumbnail(
+ request, server_name, media_id, width, height, method, m_type
)
else:
- yield self._respond_remote_thumbnail(
- request, server_name, media_id,
- width, height, method, m_type
+ await self._respond_remote_thumbnail(
+ request, server_name, media_id, width, height, method, m_type
)
self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
- def _respond_local_thumbnail(self, request, media_id, width, height,
- method, m_type):
+ def _respond_local_thumbnail(
+ self, request, media_id, width, height, method, m_type
+ ):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
@@ -105,7 +101,8 @@ class ThumbnailResource(Resource):
)
file_info = FileInfo(
- server_name=None, file_id=media_id,
+ server_name=None,
+ file_id=media_id,
url_cache=media_info["url_cache"],
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
@@ -124,9 +121,15 @@ class ThumbnailResource(Resource):
respond_404(request)
@defer.inlineCallbacks
- def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
- desired_height, desired_method,
- desired_type):
+ def _select_or_generate_local_thumbnail(
+ self,
+ request,
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ ):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
@@ -146,7 +149,8 @@ class ThumbnailResource(Resource):
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
- server_name=None, file_id=media_id,
+ server_name=None,
+ file_id=media_id,
url_cache=media_info["url_cache"],
thumbnail=True,
thumbnail_width=info["thumbnail_width"],
@@ -167,24 +171,35 @@ class ThumbnailResource(Resource):
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
- media_id, desired_width, desired_height, desired_method, desired_type,
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
url_cache=media_info["url_cache"],
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- logger.warn("Failed to generate thumbnail")
+ logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
- def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
- desired_width, desired_height,
- desired_method, desired_type):
+ def _select_or_generate_remote_thumbnail(
+ self,
+ request,
+ server_name,
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ ):
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
- server_name, media_id,
+ server_name, media_id
)
file_id = media_info["filesystem_id"]
@@ -197,7 +212,8 @@ class ThumbnailResource(Resource):
if t_w and t_h and t_method and t_type:
file_info = FileInfo(
- server_name=server_name, file_id=media_info["filesystem_id"],
+ server_name=server_name,
+ file_id=media_info["filesystem_id"],
thumbnail=True,
thumbnail_width=info["thumbnail_width"],
thumbnail_height=info["thumbnail_height"],
@@ -217,26 +233,32 @@ class ThumbnailResource(Resource):
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
- server_name, file_id, media_id, desired_width,
- desired_height, desired_method, desired_type
+ server_name,
+ file_id,
+ media_id,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- logger.warn("Failed to generate thumbnail")
+ logger.warning("Failed to generate thumbnail")
respond_404(request)
@defer.inlineCallbacks
- def _respond_remote_thumbnail(self, request, server_name, media_id, width,
- height, method, m_type):
+ def _respond_remote_thumbnail(
+ self, request, server_name, media_id, width, height, method, m_type
+ ):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
- server_name, media_id,
+ server_name, media_id
)
if thumbnail_infos:
@@ -244,7 +266,8 @@ class ThumbnailResource(Resource):
width, height, method, m_type, thumbnail_infos
)
file_info = FileInfo(
- server_name=server_name, file_id=media_info["filesystem_id"],
+ server_name=server_name,
+ file_id=media_info["filesystem_id"],
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
@@ -261,14 +284,20 @@ class ThumbnailResource(Resource):
logger.info("Failed to find any generated thumbnails")
respond_404(request)
- def _select_thumbnail(self, desired_width, desired_height, desired_method,
- desired_type, thumbnail_infos):
+ def _select_thumbnail(
+ self,
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ thumbnail_infos,
+ ):
d_w = desired_width
d_h = desired_height
if desired_method.lower() == "crop":
- info_list = []
- info_list2 = []
+ crop_info_list = []
+ crop_info_list2 = []
for info in thumbnail_infos:
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
@@ -280,19 +309,31 @@ class ThumbnailResource(Resource):
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
- info_list.append((
- aspect_quality, min_quality, size_quality, type_quality,
- length_quality, info
- ))
+ crop_info_list.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
+ )
+ )
else:
- info_list2.append((
- aspect_quality, min_quality, size_quality, type_quality,
- length_quality, info
- ))
- if info_list:
- return min(info_list)[-1]
+ crop_info_list2.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
+ )
+ )
+ if crop_info_list:
+ return min(crop_info_list)[-1]
else:
- return min(info_list2)[-1]
+ return min(crop_info_list2)[-1]
else:
info_list = []
info_list2 = []
@@ -304,13 +345,11 @@ class ThumbnailResource(Resource):
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
- info_list.append((
- size_quality, type_quality, length_quality, info
- ))
+ info_list.append((size_quality, type_quality, length_quality, info))
elif t_method == "scale":
- info_list2.append((
- size_quality, type_quality, length_quality, info
- ))
+ info_list2.append(
+ (size_quality, type_quality, length_quality, info)
+ )
if info_list:
return min(info_list)[-1]
else:
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 3efd0d80fc..c234ea7421 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -28,16 +28,13 @@ EXIF_TRANSPOSE_MAPPINGS = {
5: Image.TRANSPOSE,
6: Image.ROTATE_270,
7: Image.TRANSVERSE,
- 8: Image.ROTATE_90
+ 8: Image.ROTATE_90,
}
class Thumbnailer(object):
- FORMATS = {
- "image/jpeg": "JPEG",
- "image/png": "PNG",
- }
+ FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
def __init__(self, input_path):
self.image = Image.open(input_path)
@@ -81,9 +78,17 @@ class Thumbnailer(object):
"""
if max_width * self.height < max_height * self.width:
- return (max_width, (max_width * self.height) // self.width)
+ return max_width, (max_width * self.height) // self.width
else:
- return ((max_height * self.width) // self.height, max_height)
+ return (max_height * self.width) // self.height, max_height
+
+ def _resize(self, width, height):
+ # 1-bit or 8-bit color palette images need converting to RGB
+ # otherwise they will be scaled using nearest neighbour which
+ # looks awful
+ if self.image.mode in ["1", "P"]:
+ self.image = self.image.convert("RGB")
+ return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width, height, output_type):
"""Rescales the image to the given dimensions.
@@ -91,7 +96,7 @@ class Thumbnailer(object):
Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
"""
- scaled = self.image.resize((width, height), Image.ANTIALIAS)
+ scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)
def crop(self, width, height, output_type):
@@ -110,17 +115,13 @@ class Thumbnailer(object):
"""
if width * self.height > height * self.width:
scaled_height = (width * self.height) // self.width
- scaled_image = self.image.resize(
- (width, scaled_height), Image.ANTIALIAS
- )
+ scaled_image = self._resize(width, scaled_height)
crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
else:
scaled_width = (height * self.width) // self.height
- scaled_image = self.image.resize(
- (scaled_width, height), Image.ANTIALIAS
- )
+ scaled_image = self._resize(scaled_width, height)
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
@@ -128,5 +129,8 @@ class Thumbnailer(object):
def _encode_image(self, output_image, output_type):
output_bytes_io = BytesIO()
- output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
+ fmt = self.FORMATS[output_type]
+ if fmt == "JPEG":
+ output_image = output_image.convert("RGB")
+ output_image.save(output_bytes_io, fmt, quality=80)
return output_bytes_io
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index c1240e1963..83d005812d 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -15,22 +15,24 @@
import logging
-from twisted.internet import defer
-from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import SynapseError
-from synapse.http.server import respond_with_json, wrap_json_request_handler
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import (
+ DirectServeResource,
+ respond_with_json,
+ wrap_json_request_handler,
+)
from synapse.http.servlet import parse_string
logger = logging.getLogger(__name__)
-class UploadResource(Resource):
+class UploadResource(DirectServeResource):
isLeaf = True
def __init__(self, hs, media_repo):
- Resource.__init__(self)
+ super().__init__()
self.media_repo = media_repo
self.filepaths = media_repo.filepaths
@@ -41,62 +43,49 @@ class UploadResource(Resource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
- def render_POST(self, request):
- self._async_render_POST(request)
- return NOT_DONE_YET
-
def render_OPTIONS(self, request):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
@wrap_json_request_handler
- @defer.inlineCallbacks
- def _async_render_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def _async_render_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
- content_length = request.getHeader(b"Content-Length").decode('ascii')
+ content_length = request.getHeader(b"Content-Length").decode("ascii")
if content_length is None:
- raise SynapseError(
- msg="Request must specify a Content-Length", code=400
- )
+ raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:
raise SynapseError(
msg="Upload request body is too large",
code=413,
+ errcode=Codes.TOO_LARGE,
)
upload_name = parse_string(request, b"filename", encoding=None)
if upload_name:
try:
- upload_name = upload_name.decode('utf8')
+ upload_name = upload_name.decode("utf8")
except UnicodeDecodeError:
raise SynapseError(
- msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
- code=400,
+ msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400
)
headers = request.requestHeaders
if headers.hasHeader(b"Content-Type"):
- media_type = headers.getRawHeaders(b"Content-Type")[0].decode('ascii')
+ media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii")
else:
- raise SynapseError(
- msg="Upload request missing 'Content-Type'",
- code=400,
- )
+ raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
# if headers.hasHeader(b"Content-Disposition"):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
- content_uri = yield self.media_repo.create_content(
- media_type, upload_name, request.content,
- content_length, requester.user
+ content_uri = await self.media_repo.create_content(
+ media_type, upload_name, request.content, content_length, requester.user
)
logger.info("Uploaded content with URI %r", content_uri)
- respond_with_json(
- request, 200, {"content_uri": content_uri}, send_cors=True
- )
+ respond_with_json(request, 200, {"content_uri": content_uri}, send_cors=True)
diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/saml2/metadata_resource.py
index e8c680aeb4..1e8526e22e 100644
--- a/synapse/rest/saml2/metadata_resource.py
+++ b/synapse/rest/saml2/metadata_resource.py
@@ -30,7 +30,7 @@ class SAML2MetadataResource(Resource):
def render_GET(self, request):
metadata_xml = saml2.metadata.create_metadata_string(
- configfile=None, config=self.sp_config,
+ configfile=None, config=self.sp_config
)
request.setHeader(b"Content-Type", b"text/xml; charset=utf-8")
return metadata_xml
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index 69fb77b322..a545c13db7 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -13,62 +13,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import logging
-import saml2
-from saml2.client import Saml2Client
+from synapse.http.server import (
+ DirectServeResource,
+ finish_request,
+ wrap_html_request_handler,
+)
-from twisted.web.resource import Resource
-from twisted.web.server import NOT_DONE_YET
-from synapse.api.errors import CodeMessageException
-from synapse.http.server import wrap_html_request_handler
-from synapse.http.servlet import parse_string
-from synapse.rest.client.v1.login import SSOAuthHandler
-
-logger = logging.getLogger(__name__)
-
-
-class SAML2ResponseResource(Resource):
+class SAML2ResponseResource(DirectServeResource):
"""A Twisted web resource which handles the SAML response"""
isLeaf = 1
def __init__(self, hs):
- Resource.__init__(self)
-
- self._saml_client = Saml2Client(hs.config.saml2_sp_config)
- self._sso_auth_handler = SSOAuthHandler(hs)
-
- def render_POST(self, request):
- self._async_render_POST(request)
- return NOT_DONE_YET
+ super().__init__()
+ self._error_html_content = hs.config.saml2_error_html_content
+ self._saml_handler = hs.get_saml_handler()
+
+ async def _async_render_GET(self, request):
+ # We're not expecting any GET request on that resource if everything goes right,
+ # but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
+ # In this case, just tell the user that something went wrong and they should
+ # try to authenticate again.
+ request.setResponseCode(400)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(self._error_html_content),))
+ request.write(self._error_html_content.encode("utf8"))
+ finish_request(request)
@wrap_html_request_handler
- def _async_render_POST(self, request):
- resp_bytes = parse_string(request, 'SAMLResponse', required=True)
- relay_state = parse_string(request, 'RelayState', required=True)
-
- try:
- saml2_auth = self._saml_client.parse_authn_request_response(
- resp_bytes, saml2.BINDING_HTTP_POST,
- )
- except Exception as e:
- logger.warning("Exception parsing SAML2 response", exc_info=1)
- raise CodeMessageException(
- 400, "Unable to parse SAML2 response: %s" % (e,),
- )
-
- if saml2_auth.not_signed:
- raise CodeMessageException(400, "SAML2 response was not signed")
-
- if "uid" not in saml2_auth.ava:
- raise CodeMessageException(400, "uid not in SAML2 response")
-
- username = saml2_auth.ava["uid"][0]
-
- displayName = saml2_auth.ava.get("displayName", [None])[0]
- return self._sso_auth_handler.on_successful_auth(
- username, request, relay_state,
- user_display_name=displayName,
- )
+ async def _async_render_POST(self, request):
+ return await self._saml_handler.handle_saml_response(request)
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index a7fa4f39af..20177b44e7 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -29,23 +29,20 @@ class WellKnownBuilder(object):
Args:
hs (synapse.server.HomeServer):
"""
+
def __init__(self, hs):
self._config = hs.config
def get_well_known(self):
- # if we don't have a public_base_url, we can't help much here.
+ # if we don't have a public_baseurl, we can't help much here.
if self._config.public_baseurl is None:
return None
- result = {
- "m.homeserver": {
- "base_url": self._config.public_baseurl,
- },
- }
+ result = {"m.homeserver": {"base_url": self._config.public_baseurl}}
if self._config.default_identity_server:
result["m.identity_server"] = {
- "base_url": self._config.default_identity_server,
+ "base_url": self._config.default_identity_server
}
return result
@@ -66,7 +63,7 @@ class WellKnownResource(Resource):
if not r:
request.setResponseCode(404)
request.setHeader(b"Content-Type", b"text/plain")
- return b'.well-known not available'
+ return b".well-known not available"
logger.debug("returning: %s", r)
request.setHeader(b"Content-Type", b"application/json")
diff --git a/synapse/rulecheck/domain_rule_checker.py b/synapse/rulecheck/domain_rule_checker.py
index 212cc212cc..6f2a1931c5 100644
--- a/synapse/rulecheck/domain_rule_checker.py
+++ b/synapse/rulecheck/domain_rule_checker.py
@@ -61,19 +61,19 @@ class DomainRuleChecker(object):
self.default = config["default"]
self.can_only_join_rooms_with_invite = config.get(
- "can_only_join_rooms_with_invite", False,
+ "can_only_join_rooms_with_invite", False
)
self.can_only_create_one_to_one_rooms = config.get(
- "can_only_create_one_to_one_rooms", False,
+ "can_only_create_one_to_one_rooms", False
)
self.can_only_invite_during_room_creation = config.get(
- "can_only_invite_during_room_creation", False,
+ "can_only_invite_during_room_creation", False
)
self.can_invite_by_third_party_id = config.get(
- "can_invite_by_third_party_id", True,
+ "can_invite_by_third_party_id", True
)
self.domains_prevented_from_being_invited_to_published_rooms = config.get(
- "domains_prevented_from_being_invited_to_published_rooms", [],
+ "domains_prevented_from_being_invited_to_published_rooms", []
)
def check_event_for_spam(self, event):
@@ -81,8 +81,15 @@ class DomainRuleChecker(object):
"""
return False
- def user_may_invite(self, inviter_userid, invitee_userid, third_party_invite,
- room_id, new_room, published_room=False):
+ def user_may_invite(
+ self,
+ inviter_userid,
+ invitee_userid,
+ third_party_invite,
+ room_id,
+ new_room,
+ published_room=False,
+ ):
"""Implements synapse.events.SpamChecker.user_may_invite
"""
if self.can_only_invite_during_room_creation and not new_room:
@@ -103,15 +110,17 @@ class DomainRuleChecker(object):
return self.default
if (
- published_room and
- invitee_domain in self.domains_prevented_from_being_invited_to_published_rooms
+ published_room
+ and invitee_domain
+ in self.domains_prevented_from_being_invited_to_published_rooms
):
return False
return invitee_domain in self.domain_mapping[inviter_domain]
- def user_may_create_room(self, userid, invite_list, third_party_invite_list,
- cloning):
+ def user_may_create_room(
+ self, userid, invite_list, third_party_invite_list, cloning
+ ):
"""Implements synapse.events.SpamChecker.user_may_create_room
"""
@@ -169,4 +178,4 @@ class DomainRuleChecker(object):
idx = mxid.find(":")
if idx == -1:
raise Exception("Invalid ID: %r" % (mxid,))
- return mxid[idx + 1:]
+ return mxid[idx + 1 :]
diff --git a/synapse/secrets.py b/synapse/secrets.py
index f6280f951c..0b327a0f82 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -29,6 +29,7 @@ if sys.version_info[0:2] >= (3, 6):
def Secrets():
return secrets
+
else:
import os
import binascii
@@ -38,4 +39,4 @@ else:
return os.urandom(nbytes)
def token_hex(self, nbytes=32):
- return binascii.hexlify(self.token_bytes(nbytes)).decode('ascii')
+ return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii")
diff --git a/synapse/server.py b/synapse/server.py
index 0b2e49cb72..98a567efb3 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -25,7 +25,6 @@ import abc
import logging
import os
-from twisted.enterprise import adbapi
from twisted.mail.smtp import sendmail
from synapse.api.auth import Auth
@@ -33,6 +32,7 @@ from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
+from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.crypto.context_factory import RegularPolicyForHTTPS
from synapse.crypto.keyring import Keyring
@@ -50,7 +50,7 @@ from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.sender import FederationSender
from synapse.federation.transport.client import TransportLayerClient
from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer
-from synapse.groups.groups_server import GroupsServerHandler
+from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler
from synapse.handlers import Handlers
from synapse.handlers.account_validity import AccountValidityHandler
from synapse.handlers.acme import AcmeHandler
@@ -62,7 +62,7 @@ from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
-from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.handlers.groups_local import GroupsLocalHandler, GroupsLocalWorkerHandler
from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.handlers.message import EventCreationHandler, MessageHandler
from synapse.handlers.pagination import PaginationHandler
@@ -93,8 +93,11 @@ from synapse.rest.media.v1.media_repository import (
from synapse.secrets import Secrets
from synapse.server_notices.server_notices_manager import ServerNoticesManager
from synapse.server_notices.server_notices_sender import ServerNoticesSender
-from synapse.server_notices.worker_server_notices_sender import WorkerServerNoticesSender
+from synapse.server_notices.worker_server_notices_sender import (
+ WorkerServerNoticesSender,
+)
from synapse.state import StateHandler, StateResolutionHandler
+from synapse.storage import DataStores, Storage
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -129,105 +132,108 @@ class HomeServer(object):
__metaclass__ = abc.ABCMeta
DEPENDENCIES = [
- 'http_client',
- 'db_pool',
- 'federation_client',
- 'federation_server',
- 'handlers',
- 'auth',
- 'room_creation_handler',
- 'state_handler',
- 'state_resolution_handler',
- 'presence_handler',
- 'sync_handler',
- 'typing_handler',
- 'room_list_handler',
- 'acme_handler',
- 'auth_handler',
- 'device_handler',
- 'stats_handler',
- 'e2e_keys_handler',
- 'e2e_room_keys_handler',
- 'event_handler',
- 'event_stream_handler',
- 'initial_sync_handler',
- 'application_service_api',
- 'application_service_scheduler',
- 'application_service_handler',
- 'device_message_handler',
- 'profile_handler',
- 'event_creation_handler',
- 'deactivate_account_handler',
- 'set_password_handler',
- 'notifier',
- 'event_sources',
- 'keyring',
- 'pusherpool',
- 'event_builder_factory',
- 'filtering',
- 'http_client_context_factory',
+ "http_client",
+ "federation_client",
+ "federation_server",
+ "handlers",
+ "auth",
+ "room_creation_handler",
+ "state_handler",
+ "state_resolution_handler",
+ "presence_handler",
+ "sync_handler",
+ "typing_handler",
+ "room_list_handler",
+ "acme_handler",
+ "auth_handler",
+ "device_handler",
+ "stats_handler",
+ "e2e_keys_handler",
+ "e2e_room_keys_handler",
+ "event_handler",
+ "event_stream_handler",
+ "initial_sync_handler",
+ "application_service_api",
+ "application_service_scheduler",
+ "application_service_handler",
+ "device_message_handler",
+ "profile_handler",
+ "event_creation_handler",
+ "deactivate_account_handler",
+ "set_password_handler",
+ "notifier",
+ "event_sources",
+ "keyring",
+ "pusherpool",
+ "event_builder_factory",
+ "filtering",
+ "http_client_context_factory",
"proxied_http_client",
- 'simple_http_client',
- 'media_repository',
- 'media_repository_resource',
- 'federation_transport_client',
- 'federation_sender',
- 'receipts_handler',
- 'macaroon_generator',
- 'tcp_replication',
- 'read_marker_handler',
- 'action_generator',
- 'user_directory_handler',
- 'groups_local_handler',
- 'groups_server_handler',
- 'groups_attestation_signing',
- 'groups_attestation_renewer',
- 'secrets',
- 'spam_checker',
- 'third_party_event_rules',
- 'room_member_handler',
- 'federation_registry',
- 'server_notices_manager',
- 'server_notices_sender',
- 'message_handler',
- 'pagination_handler',
- 'room_context_handler',
- 'sendmail',
- 'registration_handler',
- 'account_validity_handler',
- 'event_client_serializer',
- 'password_policy_handler',
- ]
-
- REQUIRED_ON_MASTER_STARTUP = [
+ "simple_http_client",
+ "proxied_http_client",
+ "media_repository",
+ "media_repository_resource",
+ "federation_transport_client",
+ "federation_sender",
+ "receipts_handler",
+ "macaroon_generator",
+ "tcp_replication",
+ "read_marker_handler",
+ "action_generator",
"user_directory_handler",
- "stats_handler"
+ "groups_local_handler",
+ "groups_server_handler",
+ "groups_attestation_signing",
+ "groups_attestation_renewer",
+ "secrets",
+ "spam_checker",
+ "third_party_event_rules",
+ "room_member_handler",
+ "federation_registry",
+ "server_notices_manager",
+ "server_notices_sender",
+ "message_handler",
+ "pagination_handler",
+ "room_context_handler",
+ "sendmail",
+ "registration_handler",
+ "account_validity_handler",
+ "saml_handler",
+ "event_client_serializer",
+ "storage",
+ "password_policy_handler",
]
+ REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
+
# This is overridden in derived application classes
# (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
- def __init__(self, hostname, reactor=None, **kwargs):
+ def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
"""
Args:
hostname : The hostname for the server.
+ config: The full config for the homeserver.
"""
if not reactor:
from twisted.internet import reactor
self._reactor = reactor
self.hostname = hostname
+ self.config = config
self._building = {}
self._listening_services = []
+ self.start_time = None
self.clock = Clock(reactor)
self.distributor = Distributor()
self.ratelimiter = Ratelimiter()
+ self.admin_redaction_ratelimiter = Ratelimiter()
self.registration_ratelimiter = Ratelimiter()
- self.datastore = None
+ self.datastores = None
# Other kwargs are explicit dependencies
for depname in kwargs:
@@ -235,9 +241,8 @@ class HomeServer(object):
def setup(self):
logger.info("Setting up.")
- with self.get_db_conn() as conn:
- self.datastore = self.DATASTORE_CLASS(conn, self)
- conn.commit()
+ self.start_time = int(self.get_clock().time())
+ self.datastores = DataStores(self.DATASTORE_CLASS, self)
logger.info("Finished setting up.")
def setup_master(self):
@@ -269,7 +274,10 @@ class HomeServer(object):
return self.clock
def get_datastore(self):
- return self.datastore
+ return self.datastores.main
+
+ def get_datastores(self):
+ return self.datastores
def get_config(self):
return self.config
@@ -283,6 +291,9 @@ class HomeServer(object):
def get_registration_ratelimiter(self):
return self.registration_ratelimiter
+ def get_admin_redaction_ratelimiter(self):
+ return self.admin_redaction_ratelimiter
+
def build_federation_client(self):
return FederationClient(self)
@@ -311,8 +322,8 @@ class HomeServer(object):
def build_proxied_http_client(self):
return SimpleHttpClient(
self,
- http_proxy=os.getenv("http_proxy"),
- https_proxy=os.getenv("HTTPS_PROXY"),
+ http_proxy=os.getenvb(b"http_proxy"),
+ https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
def build_room_creation_handler(self):
@@ -417,32 +428,6 @@ class HomeServer(object):
)
return MatrixFederationHttpClient(self, tls_client_options_factory)
- def build_db_pool(self):
- name = self.db_config["name"]
-
- return adbapi.ConnectionPool(
- name,
- cp_reactor=self.get_reactor(),
- **self.db_config.get("args", {})
- )
-
- def get_db_conn(self, run_new_connection=True):
- """Makes a new connection to the database, skipping the db pool
-
- Returns:
- Connection: a connection object implementing the PEP-249 spec
- """
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def build_media_repository_resource(self):
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
@@ -478,10 +463,16 @@ class HomeServer(object):
return UserDirectoryHandler(self)
def build_groups_local_handler(self):
- return GroupsLocalHandler(self)
+ if self.config.worker_app:
+ return GroupsLocalWorkerHandler(self)
+ else:
+ return GroupsLocalHandler(self)
def build_groups_server_handler(self):
- return GroupsServerHandler(self)
+ if self.config.worker_app:
+ return GroupsServerWorkerHandler(self)
+ else:
+ return GroupsServerHandler(self)
def build_groups_attestation_signing(self):
return GroupAttestationSigning(self)
@@ -537,9 +528,17 @@ class HomeServer(object):
def build_account_validity_handler(self):
return AccountValidityHandler(self)
+ def build_saml_handler(self):
+ from synapse.handlers.saml_handler import SamlHandler
+
+ return SamlHandler(self)
+
def build_event_client_serializer(self):
return EventClientSerializer(self)
+ def build_storage(self) -> Storage:
+ return Storage(self, self.datastores)
+
def build_password_policy_handler(self):
return PasswordPolicyHandler(self)
@@ -569,9 +568,7 @@ def _make_dependency_method(depname):
if builder:
# Prevent cyclic dependencies from deadlocking
if depname in hs._building:
- raise ValueError("Cyclic dependency while building %s" % (
- depname,
- ))
+ raise ValueError("Cyclic dependency while building %s" % (depname,))
hs._building[depname] = 1
dep = builder()
@@ -582,9 +579,7 @@ def _make_dependency_method(depname):
return dep
raise NotImplementedError(
- "%s has no %s nor a builder for it" % (
- type(hs).__name__, depname,
- )
+ "%s has no %s nor a builder for it" % (type(hs).__name__, depname)
)
setattr(HomeServer, "get_%s" % (depname), _get)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 69263458db..3844f0e12f 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,7 +1,10 @@
+import twisted.internet
+
import synapse.api.auth
import synapse.config.homeserver
+import synapse.crypto.keyring
+import synapse.federation.federation_server
import synapse.federation.sender
-import synapse.federation.transaction_queue
import synapse.federation.transport.client
import synapse.handlers
import synapse.handlers.auth
@@ -9,10 +12,13 @@ import synapse.handlers.deactivate_account
import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.handlers.message
+import synapse.handlers.presence
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
import synapse.http.client
+import synapse.notifier
+import synapse.replication.tcp.client
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
@@ -23,28 +29,23 @@ class HomeServer(object):
@property
def config(self) -> synapse.config.homeserver.HomeServerConfig:
pass
-
+ @property
+ def hostname(self) -> str:
+ pass
def get_auth(self) -> synapse.api.auth.Auth:
pass
-
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
pass
-
def get_datastore(self) -> synapse.storage.DataStore:
pass
-
def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
pass
-
def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler:
pass
-
def get_handlers(self) -> synapse.handlers.Handlers:
pass
-
def get_state_handler(self) -> synapse.state.StateHandler:
pass
-
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
@@ -55,35 +56,61 @@ class HomeServer(object):
"""Fetch an HTTP client implementation which doesn't do any blacklisting
but does support HTTP_PROXY settings"""
pass
- def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
+ def get_deactivate_account_handler(
+ self,
+ ) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
-
def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
pass
-
def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
pass
-
- def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler:
+ def get_event_creation_handler(
+ self,
+ ) -> synapse.handlers.message.EventCreationHandler:
pass
-
- def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
+ def get_set_password_handler(
+ self,
+ ) -> synapse.handlers.set_password.SetPasswordHandler:
pass
-
def get_federation_sender(self) -> synapse.federation.sender.FederationSender:
pass
-
- def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
+ def get_federation_transport_client(
+ self,
+ ) -> synapse.federation.transport.client.TransportLayerClient:
pass
-
- def get_media_repository_resource(self) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
+ def get_media_repository_resource(
+ self,
+ ) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
pass
-
- def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
+ def get_media_repository(
+ self,
+ ) -> synapse.rest.media.v1.media_repository.MediaRepository:
pass
-
- def get_server_notices_manager(self) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
+ def get_server_notices_manager(
+ self,
+ ) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
pass
-
- def get_server_notices_sender(self) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
+ def get_server_notices_sender(
+ self,
+ ) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
+ pass
+ def get_notifier(self) -> synapse.notifier.Notifier:
+ pass
+ def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler:
+ pass
+ def get_clock(self) -> synapse.util.Clock:
+ pass
+ def get_reactor(self) -> twisted.internet.base.ReactorBase:
+ pass
+ def get_keyring(self) -> synapse.crypto.keyring.Keyring:
+ pass
+ def get_tcp_replication(
+ self,
+ ) -> synapse.replication.tcp.client.ReplicationClientHandler:
+ pass
+ def get_federation_registry(
+ self,
+ ) -> synapse.federation.federation_server.FederationHandlerRegistry:
+ pass
+ def is_mine_id(self, domain_id: str) -> bool:
pass
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 5e3044d164..5736c56032 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -30,6 +30,7 @@ class ConsentServerNotices(object):
"""Keeps track of whether we need to send users server_notices about
privacy policy consent, and sends one if we do.
"""
+
def __init__(self, hs):
"""
@@ -49,12 +50,11 @@ class ConsentServerNotices(object):
if not self._server_notices_manager.is_enabled():
raise ConfigError(
"user_consent configuration requires server notices, but "
- "server notices are not enabled.",
+ "server notices are not enabled."
)
- if 'body' not in self._server_notice_content:
+ if "body" not in self._server_notice_content:
raise ConfigError(
- "user_consent server_notice_consent must contain a 'body' "
- "key.",
+ "user_consent server_notice_consent must contain a 'body' key."
)
self._consent_uri_builder = ConsentURIBuilder(hs.config)
@@ -95,18 +95,14 @@ class ConsentServerNotices(object):
# need to send a message.
try:
consent_uri = self._consent_uri_builder.build_user_consent_uri(
- get_localpart_from_id(user_id),
+ get_localpart_from_id(user_id)
)
content = copy_with_str_subst(
- self._server_notice_content, {
- 'consent_uri': consent_uri,
- },
- )
- yield self._server_notices_manager.send_notice(
- user_id, content,
+ self._server_notice_content, {"consent_uri": consent_uri}
)
+ yield self._server_notices_manager.send_notice(user_id, content)
yield self._store.user_set_consent_server_notice_sent(
- user_id, self._current_consent_version,
+ user_id, self._current_consent_version
)
except SynapseError as e:
logger.error("Error sending server notice about user consent: %s", e)
@@ -128,9 +124,7 @@ def copy_with_str_subst(x, substitutions):
if isinstance(x, string_types):
return x % substitutions
if isinstance(x, dict):
- return {
- k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)
- }
+ return {k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)}
if isinstance(x, (list, tuple)):
return [copy_with_str_subst(y) for y in x]
diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py
index af15cba0ee..9fae2e0afe 100644
--- a/synapse/server_notices/resource_limits_server_notices.py
+++ b/synapse/server_notices/resource_limits_server_notices.py
@@ -20,6 +20,7 @@ from twisted.internet import defer
from synapse.api.constants import (
EventTypes,
+ LimitBlockingTypes,
ServerNoticeLimitReached,
ServerNoticeMsgType,
)
@@ -33,6 +34,7 @@ class ResourceLimitsServerNotices(object):
""" Keeps track of whether the server has reached it's resource limit and
ensures that the client is kept up to date.
"""
+
def __init__(self, hs):
"""
Args:
@@ -69,7 +71,7 @@ class ResourceLimitsServerNotices(object):
return
if not self._server_notices_manager.is_enabled():
- # Don't try and send server notices unles they've been enabled
+ # Don't try and send server notices unless they've been enabled
return
timestamp = yield self._store.user_last_seen_monthly_active(user_id)
@@ -78,66 +80,93 @@ class ResourceLimitsServerNotices(object):
# In practice, not sure we can ever get here
return
- # Determine current state of room
-
room_id = yield self._server_notices_manager.get_notice_room_for_user(user_id)
if not room_id:
- logger.warn("Failed to get server notices room")
+ logger.warning("Failed to get server notices room")
return
yield self._check_and_set_tags(user_id, room_id)
+
+ # Determine current state of room
currently_blocked, ref_events = yield self._is_room_currently_blocked(room_id)
+ limit_msg = None
+ limit_type = None
try:
- # Normally should always pass in user_id if you have it, but in
- # this case are checking what would happen to other users if they
- # were to arrive.
- try:
- yield self._auth.check_auth_blocking()
- is_auth_blocking = False
- except ResourceLimitError as e:
- is_auth_blocking = True
- event_content = e.msg
- event_limit_type = e.limit_type
-
- if currently_blocked and not is_auth_blocking:
- # Room is notifying of a block, when it ought not to be.
- # Remove block notification
- content = {
- "pinned": ref_events
- }
- yield self._server_notices_manager.send_notice(
- user_id, content, EventTypes.Pinned, '',
- )
+ # Normally should always pass in user_id to check_auth_blocking
+ # if you have it, but in this case are checking what would happen
+ # to other users if they were to arrive.
+ yield self._auth.check_auth_blocking()
+ except ResourceLimitError as e:
+ limit_msg = e.msg
+ limit_type = e.limit_type
- elif not currently_blocked and is_auth_blocking:
+ try:
+ if (
+ limit_type == LimitBlockingTypes.MONTHLY_ACTIVE_USER
+ and not self._config.mau_limit_alerting
+ ):
+ # We have hit the MAU limit, but MAU alerting is disabled:
+ # reset room if necessary and return
+ if currently_blocked:
+ self._remove_limit_block_notification(user_id, ref_events)
+ return
+
+ if currently_blocked and not limit_msg:
+ # Room is notifying of a block, when it ought not to be.
+ yield self._remove_limit_block_notification(user_id, ref_events)
+ elif not currently_blocked and limit_msg:
# Room is not notifying of a block, when it ought to be.
- # Add block notification
- content = {
- 'body': event_content,
- 'msgtype': ServerNoticeMsgType,
- 'server_notice_type': ServerNoticeLimitReached,
- 'admin_contact': self._config.admin_contact,
- 'limit_type': event_limit_type
- }
- event = yield self._server_notices_manager.send_notice(
- user_id, content, EventTypes.Message,
+ yield self._apply_limit_block_notification(
+ user_id, limit_msg, limit_type
)
-
- content = {
- "pinned": [
- event.event_id,
- ]
- }
- yield self._server_notices_manager.send_notice(
- user_id, content, EventTypes.Pinned, '',
- )
-
except SynapseError as e:
logger.error("Error sending resource limits server notice: %s", e)
@defer.inlineCallbacks
+ def _remove_limit_block_notification(self, user_id, ref_events):
+ """Utility method to remove limit block notifications from the server
+ notices room.
+
+ Args:
+ user_id (str): user to notify
+ ref_events (list[str]): The event_ids of pinned events that are unrelated to
+ limit blocking and need to be preserved.
+ """
+ content = {"pinned": ref_events}
+ yield self._server_notices_manager.send_notice(
+ user_id, content, EventTypes.Pinned, ""
+ )
+
+ @defer.inlineCallbacks
+ def _apply_limit_block_notification(self, user_id, event_body, event_limit_type):
+ """Utility method to apply limit block notifications in the server
+ notices room.
+
+ Args:
+ user_id (str): user to notify
+ event_body(str): The human readable text that describes the block.
+ event_limit_type(str): Specifies the type of block e.g. monthly active user
+ limit has been exceeded.
+ """
+ content = {
+ "body": event_body,
+ "msgtype": ServerNoticeMsgType,
+ "server_notice_type": ServerNoticeLimitReached,
+ "admin_contact": self._config.admin_contact,
+ "limit_type": event_limit_type,
+ }
+ event = yield self._server_notices_manager.send_notice(
+ user_id, content, EventTypes.Message
+ )
+
+ content = {"pinned": [event.event_id]}
+ yield self._server_notices_manager.send_notice(
+ user_id, content, EventTypes.Pinned, ""
+ )
+
+ @defer.inlineCallbacks
def _check_and_set_tags(self, user_id, room_id):
"""
Since server notices rooms were originally not with tags,
@@ -156,9 +185,7 @@ class ResourceLimitsServerNotices(object):
max_id = yield self._store.add_tag_to_room(
user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
- self._notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
- )
+ self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
@defer.inlineCallbacks
def _is_room_currently_blocked(self, room_id):
@@ -188,7 +215,7 @@ class ResourceLimitsServerNotices(object):
referenced_events = []
if pinned_state_event is not None:
- referenced_events = list(pinned_state_event.content.get('pinned', []))
+ referenced_events = list(pinned_state_event.content.get("pinned", []))
events = yield self._store.get_events(referenced_events)
for event_id, event in iteritems(events):
@@ -200,4 +227,4 @@ class ResourceLimitsServerNotices(object):
if event_id in referenced_events:
referenced_events.remove(event.event_id)
- defer.returnValue((currently_blocked, referenced_events))
+ return currently_blocked, referenced_events
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index c5cc6d728e..f7432c8d2f 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -51,8 +51,7 @@ class ServerNoticesManager(object):
@defer.inlineCallbacks
def send_notice(
- self, user_id, event_content,
- type=EventTypes.Message, state_key=None
+ self, user_id, event_content, type=EventTypes.Message, state_key=None
):
"""Send a notice to the given user
@@ -82,12 +81,12 @@ class ServerNoticesManager(object):
}
if state_key is not None:
- event_dict['state_key'] = state_key
+ event_dict["state_key"] = state_key
res = yield self._event_creation_handler.create_and_send_nonmember_event(
- requester, event_dict, ratelimit=False,
+ requester, event_dict, ratelimit=False
)
- defer.returnValue(res)
+ return res
@cachedInlineCallbacks()
def get_notice_room_for_user(self, user_id):
@@ -104,11 +103,10 @@ class ServerNoticesManager(object):
if not self.is_enabled():
raise Exception("Server notices not enabled")
- assert self._is_mine_id(user_id), \
- "Cannot send server notices to remote users"
+ assert self._is_mine_id(user_id), "Cannot send server notices to remote users"
- rooms = yield self._store.get_rooms_for_user_where_membership_is(
- user_id, [Membership.INVITE, Membership.JOIN],
+ rooms = yield self._store.get_rooms_for_local_user_where_membership_is(
+ user_id, [Membership.INVITE, Membership.JOIN]
)
system_mxid = self._config.server_notices_mxid
for room in rooms:
@@ -122,7 +120,7 @@ class ServerNoticesManager(object):
# we found a room which our user shares with the system notice
# user
logger.info("Using room %s", room.room_id)
- defer.returnValue(room.room_id)
+ return room.room_id
# apparently no existing notice room: create a new one
logger.info("Creating server notices room for %s", user_id)
@@ -132,8 +130,8 @@ class ServerNoticesManager(object):
# avatar, we have to use both.
join_profile = None
if (
- self._config.server_notices_mxid_display_name is not None or
- self._config.server_notices_mxid_avatar_url is not None
+ self._config.server_notices_mxid_display_name is not None
+ or self._config.server_notices_mxid_avatar_url is not None
):
join_profile = {
"displayname": self._config.server_notices_mxid_display_name,
@@ -146,22 +144,18 @@ class ServerNoticesManager(object):
config={
"preset": RoomCreationPreset.PRIVATE_CHAT,
"name": self._config.server_notices_room_name,
- "power_level_content_override": {
- "users_default": -10,
- },
- "invite": (user_id,)
+ "power_level_content_override": {"users_default": -10},
+ "invite": (user_id,),
},
ratelimit=False,
creator_join_profile=join_profile,
)
- room_id = info['room_id']
+ room_id = info["room_id"]
max_id = yield self._store.add_tag_to_room(
- user_id, room_id, SERVER_NOTICE_ROOM_TAG, {},
- )
- self._notifier.on_new_event(
- "account_data_key", max_id, users=[user_id]
+ user_id, room_id, SERVER_NOTICE_ROOM_TAG, {}
)
+ self._notifier.on_new_event("account_data_key", max_id, users=[user_id])
logger.info("Created server notices room %s for %s", room_id, user_id)
- defer.returnValue(room_id)
+ return room_id
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
index 6121b2f267..652bab58e3 100644
--- a/synapse/server_notices/server_notices_sender.py
+++ b/synapse/server_notices/server_notices_sender.py
@@ -24,6 +24,7 @@ class ServerNoticesSender(object):
"""A centralised place which sends server notices automatically when
Certain Events take place
"""
+
def __init__(self, hs):
"""
@@ -32,7 +33,7 @@ class ServerNoticesSender(object):
"""
self._server_notices = (
ConsentServerNotices(hs),
- ResourceLimitsServerNotices(hs)
+ ResourceLimitsServerNotices(hs),
)
@defer.inlineCallbacks
@@ -43,9 +44,7 @@ class ServerNoticesSender(object):
user_id (str): mxid of user who synced
"""
for sn in self._server_notices:
- yield sn.maybe_send_server_notice_to_user(
- user_id,
- )
+ yield sn.maybe_send_server_notice_to_user(user_id)
@defer.inlineCallbacks
def on_user_ip(self, user_id):
@@ -58,6 +57,4 @@ class ServerNoticesSender(object):
# we check for notices to send to the user in on_user_ip as well as
# in on_user_syncing
for sn in self._server_notices:
- yield sn.maybe_send_server_notice_to_user(
- user_id,
- )
+ yield sn.maybe_send_server_notice_to_user(user_id)
diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py
index 4a133026c3..245ec7c64f 100644
--- a/synapse/server_notices/worker_server_notices_sender.py
+++ b/synapse/server_notices/worker_server_notices_sender.py
@@ -17,6 +17,7 @@ from twisted.internet import defer
class WorkerServerNoticesSender(object):
"""Stub impl of ServerNoticesSender which does nothing"""
+
def __init__(self, hs):
"""
Args:
diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py
new file mode 100644
index 0000000000..9b78924d96
--- /dev/null
+++ b/synapse/spam_checker_api/__init__.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.storage.state import StateFilter
+
+MYPY = False
+if MYPY:
+ import synapse.server
+
+logger = logging.getLogger(__name__)
+
+
+class SpamCheckerApi(object):
+ """A proxy object that gets passed to spam checkers so they can get
+ access to rooms and other relevant information.
+ """
+
+ def __init__(self, hs: "synapse.server.HomeServer"):
+ self.hs = hs
+
+ self._store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def get_state_events_in_room(self, room_id: str, types: tuple) -> defer.Deferred:
+ """Gets state events for the given room.
+
+ Args:
+ room_id: The room ID to get state events in.
+ types: The event type and state key (using None
+ to represent 'any') of the room state to acquire.
+
+ Returns:
+ twisted.internet.defer.Deferred[list(synapse.events.FrozenEvent)]:
+ The filtered state events in the room.
+ """
+ state_ids = yield self._store.get_filtered_current_state_ids(
+ room_id=room_id, state_filter=StateFilter.from_types(types)
+ )
+ state = yield self._store.get_events(state_ids.values())
+ return state.values()
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 36684ef9f6..4afefc6b1d 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,27 +16,40 @@
import logging
from collections import namedtuple
+from typing import Dict, Iterable, List, Optional, Set
from six import iteritems, itervalues
import attr
from frozendict import frozendict
+from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
+from synapse.events import EventBase
from synapse.events.snapshot import EventContext
+from synapse.logging.utils import log_function
from synapse.state import v1, v2
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logutils import log_function
-from synapse.util.metrics import Measure
+from synapse.util.metrics import Measure, measure_func
logger = logging.getLogger(__name__)
+# Metrics for number of state groups involved in a resolution.
+state_groups_histogram = Histogram(
+ "synapse_state_number_state_groups_in_resolution",
+ "Number of state groups used when performing a state resolution",
+ buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
@@ -94,12 +107,14 @@ class StateHandler(object):
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ self.state_store = hs.get_storage().state
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
- def get_current_state(self, room_id, event_type=None, state_key="",
- latest_event_ids=None):
+ def get_current_state(
+ self, room_id, event_type=None, state_key="", latest_event_ids=None
+ ):
""" Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
@@ -125,16 +140,16 @@ class StateHandler(object):
event = None
if event_id:
event = yield self.store.get_event(event_id, allow_none=True)
- defer.returnValue(event)
- return
+ return event
- state_map = yield self.store.get_events(list(state.values()),
- get_prev_content=False)
+ state_map = yield self.store.get_events(
+ list(state.values()), get_prev_content=False
+ )
state = {
key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
}
- defer.returnValue(state)
+ return state
@defer.inlineCallbacks
def get_current_state_ids(self, room_id, latest_event_ids=None):
@@ -158,7 +173,7 @@ class StateHandler(object):
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
- defer.returnValue(state)
+ return state
@defer.inlineCallbacks
def get_current_users_in_room(self, room_id, latest_event_ids=None):
@@ -178,27 +193,40 @@ class StateHandler(object):
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
- defer.returnValue(joined_users)
+ return joined_users
@defer.inlineCallbacks
- def get_current_hosts_in_room(self, room_id, latest_event_ids=None):
- if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
- entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ def get_current_hosts_in_room(self, room_id):
+ event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
+
+ @defer.inlineCallbacks
+ def get_hosts_in_room_at_events(self, room_id, event_ids):
+ """Get the hosts that were in a room at the given event ids
+
+ Args:
+ room_id (str):
+ event_ids (list[str]):
+
+ Returns:
+ Deferred[list[str]]: the hosts in the room at the given events
+ """
+ entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
- defer.returnValue(joined_hosts)
+ return joined_hosts
@defer.inlineCallbacks
- def compute_event_context(self, event, old_state=None):
+ def compute_event_context(
+ self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
+ ):
"""Build an EventContext structure for the event.
This works out what the current state should be for the event, and
generates a new state group if necessary.
Args:
- event (synapse.events.EventBase):
- old_state (dict|None): The state at the event if it can't be
+ event:
+ old_state: The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
@@ -210,10 +238,11 @@ class StateHandler(object):
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
+
+ # FIXME: why do we populate current_state_ids? I thought the point was
+ # that we weren't supposed to have any state for outliers?
if old_state:
- prev_state_ids = {
- (s.type, s.state_key): s.event_id for s in old_state
- }
+ prev_state_ids = {(s.type, s.state_key): s.event_id for s in old_state}
if event.is_state():
current_state_ids = dict(prev_state_ids)
key = (event.type, event.state_key)
@@ -228,118 +257,105 @@ class StateHandler(object):
# group for it.
context = EventContext.with_state(
state_group=None,
+ state_group_before_event=None,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
)
- defer.returnValue(context)
+ return context
- if old_state:
- # We already have the state, so we don't need to calculate it.
- # Let's just correctly fill out the context and create a
- # new state group for it.
+ #
+ # first of all, figure out the state before the event
+ #
- prev_state_ids = {
+ if old_state:
+ # if we're given the state before the event, then we use that
+ state_ids_before_event = {
(s.type, s.state_key): s.event_id for s in old_state
}
+ state_group_before_event = None
+ state_group_before_event_prev_group = None
+ deltas_to_state_group_before_event = None
- if event.is_state():
- key = (event.type, event.state_key)
- if key in prev_state_ids:
- replaces = prev_state_ids[key]
- if replaces != event.event_id: # Paranoia check
- event.unsigned["replaces_state"] = replaces
- current_state_ids = dict(prev_state_ids)
- current_state_ids[key] = event.event_id
- else:
- current_state_ids = prev_state_ids
-
- state_group = yield self.store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=None,
- delta_ids=None,
- current_state_ids=current_state_ids,
- )
+ else:
+ # otherwise, we'll need to resolve the state across the prev_events.
+ logger.debug("calling resolve_state_groups from compute_event_context")
- context = EventContext.with_state(
- state_group=state_group,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
+ entry = yield self.resolve_state_groups_for_events(
+ event.room_id, event.prev_event_ids()
)
- defer.returnValue(context)
+ state_ids_before_event = entry.state
+ state_group_before_event = entry.state_group
+ state_group_before_event_prev_group = entry.prev_group
+ deltas_to_state_group_before_event = entry.delta_ids
- logger.debug("calling resolve_state_groups from compute_event_context")
+ #
+ # make sure that we have a state group at that point. If it's not a state event,
+ # that will be the state group for the new event. If it *is* a state event,
+ # it might get rejected (in which case we'll need to persist it with the
+ # previous state group)
+ #
- entry = yield self.resolve_state_groups_for_events(
- event.room_id, event.prev_event_ids(),
- )
+ if not state_group_before_event:
+ state_group_before_event = yield self.state_store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ current_state_ids=state_ids_before_event,
+ )
- prev_state_ids = entry.state
- prev_group = None
- delta_ids = None
+ # XXX: can we update the state cache entry for the new state group? or
+ # could we set a flag on resolve_state_groups_for_events to tell it to
+ # always make a state group?
+
+ #
+ # now if it's not a state event, we're done
+ #
+
+ if not event.is_state():
+ return EventContext.with_state(
+ state_group_before_event=state_group_before_event,
+ state_group=state_group_before_event,
+ current_state_ids=state_ids_before_event,
+ prev_state_ids=state_ids_before_event,
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ )
- if event.is_state():
- # If this is a state event then we need to create a new state
- # group for the state after this event.
+ #
+ # otherwise, we'll need to create a new state group for after the event
+ #
- key = (event.type, event.state_key)
- if key in prev_state_ids:
- replaces = prev_state_ids[key]
+ key = (event.type, event.state_key)
+ if key in state_ids_before_event:
+ replaces = state_ids_before_event[key]
+ if replaces != event.event_id:
event.unsigned["replaces_state"] = replaces
- current_state_ids = dict(prev_state_ids)
- current_state_ids[key] = event.event_id
-
- if entry.state_group:
- # If the state at the event has a state group assigned then
- # we can use that as the prev group
- prev_group = entry.state_group
- delta_ids = {
- key: event.event_id
- }
- elif entry.prev_group:
- # If the state at the event only has a prev group, then we can
- # use that as a prev group too.
- prev_group = entry.prev_group
- delta_ids = dict(entry.delta_ids)
- delta_ids[key] = event.event_id
-
- state_group = yield self.store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=prev_group,
- delta_ids=delta_ids,
- current_state_ids=current_state_ids,
- )
- else:
- current_state_ids = prev_state_ids
- prev_group = entry.prev_group
- delta_ids = entry.delta_ids
-
- if entry.state_group is None:
- entry.state_group = yield self.store.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=entry.prev_group,
- delta_ids=entry.delta_ids,
- current_state_ids=current_state_ids,
- )
- entry.state_id = entry.state_group
-
- state_group = entry.state_group
-
- context = EventContext.with_state(
- state_group=state_group,
- current_state_ids=current_state_ids,
- prev_state_ids=prev_state_ids,
- prev_group=prev_group,
+ state_ids_after_event = dict(state_ids_before_event)
+ state_ids_after_event[key] = event.event_id
+ delta_ids = {key: event.event_id}
+
+ state_group_after_event = yield self.state_store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event,
delta_ids=delta_ids,
+ current_state_ids=state_ids_after_event,
)
- defer.returnValue(context)
+ return EventContext.with_state(
+ state_group=state_group_after_event,
+ state_group_before_event=state_group_before_event,
+ current_state_ids=state_ids_after_event,
+ prev_state_ids=state_ids_before_event,
+ prev_group=state_group_before_event,
+ delta_ids=delta_ids,
+ )
+ @measure_func()
@defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
@@ -360,63 +376,58 @@ class StateHandler(object):
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
- state_groups_ids = yield self.store.get_state_groups_ids(
+ state_groups_ids = yield self.state_store.get_state_groups_ids(
room_id, event_ids
)
if len(state_groups_ids) == 0:
- defer.returnValue(_StateCacheEntry(
- state={},
- state_group=None,
- ))
+ return _StateCacheEntry(state={}, state_group=None)
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
- prev_group, delta_ids = yield self.store.get_state_group_delta(name)
+ prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
- defer.returnValue(_StateCacheEntry(
+ return _StateCacheEntry(
state=state_list,
state_group=name,
prev_group=prev_group,
delta_ids=delta_ids,
- ))
+ )
- room_version = yield self.store.get_room_version(room_id)
+ room_version = yield self.store.get_room_version_id(room_id)
result = yield self._state_resolution_handler.resolve_state_groups(
- room_id, room_version, state_groups_ids, None,
+ room_id,
+ room_version,
+ state_groups_ids,
+ None,
state_res_store=StateResolutionStore(self.store),
)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
- state_set_ids = [{
- (ev.type, ev.state_key): ev.event_id
- for ev in st
- } for st in state_sets]
-
- state_map = {
- ev.event_id: ev
- for st in state_sets
- for ev in st
- }
+ state_set_ids = [
+ {(ev.type, ev.state_key): ev.event_id for ev in st} for st in state_sets
+ ]
+
+ state_map = {ev.event_id: ev for st in state_sets for ev in st}
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
- room_version, state_set_ids,
+ event.room_id,
+ room_version,
+ state_set_ids,
event_map=state_map,
state_res_store=StateResolutionStore(self.store),
)
- new_state = {
- key: state_map[ev_id] for key, ev_id in iteritems(new_state)
- }
+ new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)}
- defer.returnValue(new_state)
+ return new_state
class StateResolutionHandler(object):
@@ -425,6 +436,7 @@ class StateResolutionHandler(object):
Note that the storage layer depends on this handler, so all functions must
be storage-independent.
"""
+
def __init__(self, hs):
self.clock = hs.get_clock()
@@ -444,7 +456,7 @@ class StateResolutionHandler(object):
@defer.inlineCallbacks
@log_function
def resolve_state_groups(
- self, room_id, room_version, state_groups_ids, event_map, state_res_store,
+ self, room_id, room_version, state_groups_ids, event_map, state_res_store
):
"""Resolves conflicts between a set of state groups
@@ -452,7 +464,7 @@ class StateResolutionHandler(object):
not be called for a single state group
Args:
- room_id (str): room we are resolving for (used for logging)
+ room_id (str): room we are resolving for (used for logging and sanity checks)
room_version (str): version of the room
state_groups_ids (dict[int, dict[(str, str), str]]):
map from state group id to the state in that state group
@@ -471,10 +483,7 @@ class StateResolutionHandler(object):
Returns:
Deferred[_StateCacheEntry]: resolved state
"""
- logger.debug(
- "resolve_state_groups state_groups %s",
- state_groups_ids.keys()
- )
+ logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
group_names = frozenset(state_groups_ids.keys())
@@ -482,12 +491,14 @@ class StateResolutionHandler(object):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
- defer.returnValue(cache)
+ return cache
logger.info(
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
)
+ state_groups_histogram.observe(len(state_groups_ids))
+
# start by assuming we won't have any conflicted state, and build up the new
# state map by iterating through the state groups. If we discover a conflict,
# we give up and instead use `resolve_events_with_store`.
@@ -509,6 +520,7 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
+ room_id,
room_version,
list(itervalues(state_groups_ids)),
event_map=event_map,
@@ -526,13 +538,10 @@ class StateResolutionHandler(object):
if self._state_cache is not None:
self._state_cache[group_names] = cache
- defer.returnValue(cache)
+ return cache
-def _make_state_cache_entry(
- new_state,
- state_groups_ids,
-):
+def _make_state_cache_entry(new_state, state_groups_ids):
"""Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed
_StateCacheEntry.
@@ -562,10 +571,7 @@ def _make_state_cache_entry(
old_state_event_ids = set(itervalues(state))
if new_state_event_ids == old_state_event_ids:
# got an exact match.
- return _StateCacheEntry(
- state=new_state,
- state_group=sg,
- )
+ return _StateCacheEntry(state=new_state, state_group=sg)
# TODO: We want to create a state group for this set of events, to
# increase cache hits, but we need to make sure that it doesn't
@@ -576,53 +582,54 @@ def _make_state_cache_entry(
delta_ids = None
for old_group, old_state in iteritems(state_groups_ids):
- n_delta_ids = {
- k: v
- for k, v in iteritems(new_state)
- if old_state.get(k) != v
- }
+ n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group
delta_ids = n_delta_ids
return _StateCacheEntry(
- state=new_state,
- state_group=None,
- prev_group=prev_group,
- delta_ids=delta_ids,
+ state=new_state, state_group=None, prev_group=prev_group, delta_ids=delta_ids
)
-def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
+def resolve_events_with_store(
+ room_id: str,
+ room_version: str,
+ state_sets: List[StateMap[str]],
+ event_map: Optional[Dict[str, EventBase]],
+ state_res_store: "StateResolutionStore",
+):
"""
Args:
- room_version(str): Version of the room
+ room_id: the room we are working in
+
+ room_version: Version of the room
- state_sets(list): List of dicts of (type, state_key) -> event_id,
+ state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
- event_map(dict[str,FrozenEvent]|None):
+ event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
events will be requested via state_map_factory.
- If None, all events will be fetched via state_map_factory.
+ If None, all events will be fetched via state_res_store.
- state_res_store (StateResolutionStore)
+ state_res_store: a place to fetch events from
- Returns
+ Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1:
return v1.resolve_events_with_store(
- state_sets, event_map, state_res_store.get_events,
+ room_id, state_sets, event_map, state_res_store.get_events
)
else:
return v2.resolve_events_with_store(
- room_version, state_sets, event_map, state_res_store,
+ room_id, room_version, state_sets, event_map, state_res_store
)
@@ -650,28 +657,21 @@ class StateResolutionStore(object):
return self.store.get_events(
event_ids,
- check_redacted=False,
+ redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=allow_rejected,
)
- def get_auth_chain(self, event_ids):
- """Gets the full auth chain for a set of events (including rejected
- events).
-
- Includes the given event IDs in the result.
+ def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ """Given sets of state events figure out the auth chain difference (as
+ per state res v2 algorithm).
- Note that:
- 1. All events must be state events.
- 2. For v1 rooms this may not have the full auth chain in the
- presence of rejected events
-
- Args:
- event_ids (list): The event IDs of the events to fetch the auth
- chain for. Must be state events.
+ This equivalent to fetching the full auth chain for each set of state
+ and returning the events that don't appear in each and every auth
+ chain.
Returns:
- Deferred[list[str]]: List of event IDs of the auth chain.
+ Deferred[Set[str]]: Set of event IDs.
"""
- return self.store.get_auth_chain_ids(event_ids, include_given=True)
+ return self.store.get_auth_chain_difference(state_sets)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 29b4e86cfd..9bf98d06f2 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,6 +15,7 @@
import hashlib
import logging
+from typing import Callable, Dict, List, Optional
from six import iteritems, iterkeys, itervalues
@@ -24,6 +25,8 @@ from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
@@ -32,13 +35,20 @@ POWER_KEY = (EventTypes.PowerLevels, "")
@defer.inlineCallbacks
-def resolve_events_with_store(state_sets, event_map, state_map_factory):
+def resolve_events_with_store(
+ room_id: str,
+ state_sets: List[StateMap[str]],
+ event_map: Optional[Dict[str, EventBase]],
+ state_map_factory: Callable,
+):
"""
Args:
- state_sets(list): List of dicts of (type, state_key) -> event_id,
+ room_id: the room we are working in
+
+ state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
- event_map(dict[str,FrozenEvent]|None):
+ event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
@@ -46,34 +56,28 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
If None, all events will be fetched via state_map_factory.
- state_map_factory(func): will be called
+ state_map_factory: will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event.
- Returns
+ Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
- defer.returnValue(state_sets[0])
+ return state_sets[0]
- unconflicted_state, conflicted_state = _seperate(
- state_sets,
- )
+ unconflicted_state, conflicted_state = _seperate(state_sets)
- needed_events = set(
- event_id
- for event_ids in itervalues(conflicted_state)
- for event_id in event_ids
- )
+ needed_events = {
+ event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
+ }
needed_event_count = len(needed_events)
if event_map is not None:
needed_events -= set(iterkeys(event_map))
logger.info(
- "Asking for %d/%d conflicted events",
- len(needed_events),
- needed_event_count,
+ "Asking for %d/%d conflicted events", len(needed_events), needed_event_count
)
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
@@ -82,6 +86,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
if event_map is not None:
state_map.update(event_map)
+ # everything in the state map should be in the right room
+ for event in state_map.values():
+ if event.room_id != room_id:
+ raise Exception(
+ "Attempting to state-resolve for room %s with event %s which is in %s"
+ % (room_id, event.event_id, event.room_id,)
+ )
+
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
#
@@ -97,17 +109,22 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
new_needed_events -= set(iterkeys(event_map))
logger.info(
- "Asking for %d/%d auth events",
- len(new_needed_events),
- new_needed_event_count,
+ "Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
)
state_map_new = yield state_map_factory(new_needed_events)
+ for event in state_map_new.values():
+ if event.room_id != room_id:
+ raise Exception(
+ "Attempting to state-resolve for room %s with event %s which is in %s"
+ % (room_id, event.event_id, event.room_id,)
+ )
+
state_map.update(state_map_new)
- defer.returnValue(_resolve_with_state(
+ return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
- ))
+ )
def _seperate(state_sets):
@@ -173,8 +190,9 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
return auth_events
-def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event_ids,
- state_map):
+def _resolve_with_state(
+ unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
+):
conflicted_state = {}
for key, event_ids in iteritems(conflicted_state_ids):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
@@ -190,9 +208,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ids, auth_event
}
try:
- resolved_state = _resolve_state_events(
- conflicted_state, auth_events
- )
+ resolved_state = _resolve_state_events(conflicted_state, auth_events)
except Exception:
logger.exception("Failed to resolve state")
raise
@@ -218,49 +234,38 @@ def _resolve_state_events(conflicted_state, auth_events):
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
- resolved_state[POWER_KEY] = _resolve_auth_events(
- events, auth_events)
+ resolved_state[POWER_KEY] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
- resolved_state[key] = _resolve_auth_events(
- events,
- auth_events
- )
+ resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
- resolved_state[key] = _resolve_auth_events(
- events,
- auth_events
- )
+ resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
for key, events in iteritems(conflicted_state):
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
- resolved_state[key] = _resolve_normal_events(
- events, auth_events
- )
+ resolved_state[key] = _resolve_normal_events(events, auth_events)
return resolved_state
def _resolve_auth_events(events, auth_events):
- reverse = [i for i in reversed(_ordered_events(events))]
+ reverse = list(reversed(_ordered_events(events)))
- auth_keys = set(
- key
- for event in events
- for key in event_auth.auth_types_for_event(event)
- )
+ auth_keys = {
+ key for event in events for key in event_auth.auth_types_for_event(event)
+ }
new_auth_events = {}
for key in auth_keys:
@@ -276,7 +281,7 @@ def _resolve_auth_events(events, auth_events):
try:
# The signatures have already been checked at this point
event_auth.check(
- RoomVersions.V1.identifier,
+ RoomVersions.V1,
event,
auth_events,
do_sig_check=False,
@@ -294,7 +299,7 @@ def _resolve_normal_events(events, auth_events):
try:
# The signatures have already been checked at this point
event_auth.check(
- RoomVersions.V1.identifier,
+ RoomVersions.V1,
event,
auth_events,
do_sig_check=False,
@@ -313,6 +318,6 @@ def _ordered_events(events):
def key_func(e):
# we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/
- return -int(e.depth), hashlib.sha1(e.event_id.encode('utf-8')).hexdigest()
+ return -int(e.depth), hashlib.sha1(e.event_id.encode("utf-8")).hexdigest()
return sorted(events, key=key_func)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 650995c92c..18484e2fa6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,29 +16,42 @@
import heapq
import itertools
import logging
+from typing import Dict, List, Optional
from six import iteritems, itervalues
from twisted.internet import defer
+import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import EventBase
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
@defer.inlineCallbacks
-def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
+def resolve_events_with_store(
+ room_id: str,
+ room_version: str,
+ state_sets: List[StateMap[str]],
+ event_map: Optional[Dict[str, EventBase]],
+ state_res_store: "synapse.state.StateResolutionStore",
+):
"""Resolves the state using the v2 state resolution algorithm
Args:
- room_version (str): The room version
+ room_id: the room we are working in
- state_sets(list): List of dicts of (type, state_key) -> event_id,
+ room_version: The room version
+
+ state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
- event_map(dict[str,FrozenEvent]|None):
+ event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
@@ -46,9 +59,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
If None, all events will be fetched via state_res_store.
- state_res_store (StateResolutionStore)
+ state_res_store:
- Returns
+ Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
@@ -63,50 +76,57 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
unconflicted_state, conflicted_state = _seperate(state_sets)
if not conflicted_state:
- defer.returnValue(unconflicted_state)
+ return unconflicted_state
logger.debug("%d conflicted state entries", len(conflicted_state))
logger.debug("Calculating auth chain difference")
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
- auth_diff = yield _get_auth_chain_difference(
- state_sets, event_map, state_res_store,
- )
+ auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
- full_conflicted_set = set(itertools.chain(
- itertools.chain.from_iterable(itervalues(conflicted_state)),
- auth_diff,
- ))
+ full_conflicted_set = set(
+ itertools.chain(
+ itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff
+ )
+ )
- events = yield state_res_store.get_events([
- eid for eid in full_conflicted_set
- if eid not in event_map
- ], allow_rejected=True)
+ events = yield state_res_store.get_events(
+ [eid for eid in full_conflicted_set if eid not in event_map],
+ allow_rejected=True,
+ )
event_map.update(events)
- full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
+ # everything in the event map should be in the right room
+ for event in event_map.values():
+ if event.room_id != room_id:
+ raise Exception(
+ "Attempting to state-resolve for room %s with event %s which is in %s"
+ % (room_id, event.event_id, event.room_id,)
+ )
+
+ full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}
logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
# Get and sort all the power events (kicks/bans/etc)
power_events = (
- eid for eid in full_conflicted_set
- if _is_power_event(event_map[eid])
+ eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
)
sorted_power_events = yield _reverse_topological_power_sort(
- power_events,
- event_map,
- state_res_store,
- full_conflicted_set,
+ room_id, power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
resolved_state = yield _iterative_auth_checks(
- room_version, sorted_power_events, unconflicted_state, event_map,
+ room_id,
+ room_version,
+ sorted_power_events,
+ unconflicted_state,
+ event_map,
state_res_store,
)
@@ -116,22 +136,24 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
# events using the mainline of the resolved power level.
leftover_events = [
- ev_id
- for ev_id in full_conflicted_set
- if ev_id not in sorted_power_events
+ ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events
]
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
leftover_events = yield _mainline_sort(
- leftover_events, pl, event_map, state_res_store,
+ room_id, leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks(
- room_version, leftover_events, resolved_state, event_map,
+ room_id,
+ room_version,
+ leftover_events,
+ resolved_state,
+ event_map,
state_res_store,
)
@@ -142,15 +164,16 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
logger.debug("done")
- defer.returnValue(resolved_state)
+ return resolved_state
@defer.inlineCallbacks
-def _get_power_level_for_sender(event_id, event_map, state_res_store):
+def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
"""Return the power level of the sender of the given event according to
their auth events.
Args:
+ room_id (str)
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
@@ -158,33 +181,37 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
Returns:
Deferred[int]
"""
- event = yield _get_event(event_id, event_map, state_res_store)
+ event = yield _get_event(room_id, event_id, event_map, state_res_store)
pl = None
for aid in event.auth_event_ids():
- aev = yield _get_event(aid, event_map, state_res_store)
- if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
+ aev = yield _get_event(
+ room_id, aid, event_map, state_res_store, allow_none=True
+ )
+ if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
pl = aev
break
if pl is None:
# Couldn't find power level. Check if they're the creator of the room
for aid in event.auth_event_ids():
- aev = yield _get_event(aid, event_map, state_res_store)
- if (aev.type, aev.state_key) == (EventTypes.Create, ""):
+ aev = yield _get_event(
+ room_id, aid, event_map, state_res_store, allow_none=True
+ )
+ if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
if aev.content.get("creator") == event.sender:
- defer.returnValue(100)
+ return 100
break
- defer.returnValue(0)
+ return 0
level = pl.content.get("users", {}).get(event.sender)
if level is None:
level = pl.content.get("users_default", 0)
if level is None:
- defer.returnValue(0)
+ return 0
else:
- defer.returnValue(int(level))
+ return int(level)
@defer.inlineCallbacks
@@ -200,34 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
Returns:
Deferred[set[str]]: Set of event IDs
"""
- common = set(itervalues(state_sets[0])).intersection(
- *(itervalues(s) for s in state_sets[1:])
- )
- auth_sets = []
- for state_set in state_sets:
- auth_ids = set(
- eid
- for key, eid in iteritems(state_set)
- if (key[0] in (
- EventTypes.Member,
- EventTypes.ThirdPartyInvite,
- ) or key in (
- (EventTypes.PowerLevels, ''),
- (EventTypes.Create, ''),
- (EventTypes.JoinRules, ''),
- )) and eid not in common
- )
-
- auth_chain = yield state_res_store.get_auth_chain(auth_ids)
- auth_ids.update(auth_chain)
-
- auth_sets.append(auth_ids)
-
- intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
- union = set().union(*auth_sets)
+ difference = yield state_res_store.get_auth_chain_difference(
+ [set(state_set.values()) for state_set in state_sets]
+ )
- defer.returnValue(union - intersection)
+ return difference
def _seperate(state_sets):
@@ -246,7 +251,7 @@ def _seperate(state_sets):
conflicted_state = {}
for key in set(itertools.chain.from_iterable(state_sets)):
- event_ids = set(state_set.get(key) for state_set in state_sets)
+ event_ids = {state_set.get(key) for state_set in state_sets}
if len(event_ids) == 1:
unconflicted_state[key] = event_ids.pop()
else:
@@ -274,21 +279,23 @@ def _is_power_event(event):
return True
if event.type == EventTypes.Member:
- if event.membership in ('leave', 'ban'):
+ if event.membership in ("leave", "ban"):
return event.sender != event.state_key
return False
@defer.inlineCallbacks
-def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
- state_res_store, auth_diff):
+def _add_event_and_auth_chain_to_graph(
+ graph, room_id, event_id, event_map, state_res_store, auth_diff
+):
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
Args:
graph (dict[str, set[str]]): A map from event ID to the events auth
event IDs
+ room_id (str): the room we are working in
event_id (str): Event to add to the graph
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
@@ -300,7 +307,7 @@ def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
eid = state.pop()
graph.setdefault(eid, set())
- event = yield _get_event(eid, event_map, state_res_store)
+ event = yield _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids():
if aid in auth_diff:
if aid not in graph:
@@ -310,11 +317,14 @@ def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
@defer.inlineCallbacks
-def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff):
+def _reverse_topological_power_sort(
+ room_id, event_ids, event_map, state_res_store, auth_diff
+):
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
+ room_id (str): the room we are working in
event_ids (list[str]): The events to sort
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
@@ -327,12 +337,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
graph = {}
for event_id in event_ids:
yield _add_event_and_auth_chain_to_graph(
- graph, event_id, event_map, state_res_store, auth_diff,
+ graph, room_id, event_id, event_map, state_res_store, auth_diff
)
event_to_pl = {}
for event_id in graph:
- pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store)
+ pl = yield _get_power_level_for_sender(
+ room_id, event_id, event_map, state_res_store
+ )
event_to_pl[event_id] = pl
def _get_power_order(event_id):
@@ -342,72 +354,83 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
return -pl, ev.origin_server_ts, event_id
# Note: graph is modified during the sort
- it = lexicographical_topological_sort(
- graph,
- key=_get_power_order,
- )
+ it = lexicographical_topological_sort(graph, key=_get_power_order)
sorted_events = list(it)
- defer.returnValue(sorted_events)
+ return sorted_events
@defer.inlineCallbacks
-def _iterative_auth_checks(room_version, event_ids, base_state, event_map,
- state_res_store):
+def _iterative_auth_checks(
+ room_id, room_version, event_ids, base_state, event_map, state_res_store
+):
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Args:
+ room_id (str)
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to
- base_state (dict[tuple[str, str], str]): The set of state to start with
+ base_state (StateMap[str]): The set of state to start with
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
- Deferred[dict[tuple[str, str], str]]: Returns the final updated state
+ Deferred[StateMap[str]]: Returns the final updated state
"""
resolved_state = base_state.copy()
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
for event_id in event_ids:
event = event_map[event_id]
auth_events = {}
for aid in event.auth_event_ids():
- ev = yield _get_event(aid, event_map, state_res_store)
+ ev = yield _get_event(
+ room_id, aid, event_map, state_res_store, allow_none=True
+ )
- if ev.rejected_reason is None:
- auth_events[(ev.type, ev.state_key)] = ev
+ if not ev:
+ logger.warning(
+ "auth_event id %s for event %s is missing", aid, event_id
+ )
+ else:
+ if ev.rejected_reason is None:
+ auth_events[(ev.type, ev.state_key)] = ev
for key in event_auth.auth_types_for_event(event):
if key in resolved_state:
ev_id = resolved_state[key]
- ev = yield _get_event(ev_id, event_map, state_res_store)
+ ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
try:
event_auth.check(
- room_version, event, auth_events,
+ room_version_obj,
+ event,
+ auth_events,
do_sig_check=False,
- do_size_check=False
+ do_size_check=False,
)
resolved_state[(event.type, event.state_key)] = event_id
except AuthError:
pass
- defer.returnValue(resolved_state)
+ return resolved_state
@defer.inlineCallbacks
-def _mainline_sort(event_ids, resolved_power_event_id, event_map,
- state_res_store):
+def _mainline_sort(
+ room_id, event_ids, resolved_power_event_id, event_map, state_res_store
+):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
+ room_id (str): room we're working in
event_ids (list[str]): Events to sort
resolved_power_event_id (str): The final resolved power level event ID
event_map (dict[str,FrozenEvent])
@@ -420,12 +443,14 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map,
pl = resolved_power_event_id
while pl:
mainline.append(pl)
- pl_ev = yield _get_event(pl, event_map, state_res_store)
+ pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
auth_events = pl_ev.auth_event_ids()
pl = None
for aid in auth_events:
- ev = yield _get_event(aid, event_map, state_res_store)
- if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
+ ev = yield _get_event(
+ room_id, aid, event_map, state_res_store, allow_none=True
+ )
+ if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
pl = aid
break
@@ -436,14 +461,13 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map,
order_map = {}
for ev_id in event_ids:
depth = yield _get_mainline_depth_for_event(
- event_map[ev_id], mainline_map,
- event_map, state_res_store,
+ event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
event_ids.sort(key=lambda ev_id: order_map[ev_id])
- defer.returnValue(event_ids)
+ return event_ids
@defer.inlineCallbacks
@@ -461,43 +485,62 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
Deferred[int]
"""
+ room_id = event.room_id
+
# We do an iterative search, replacing `event with the power level in its
# auth events (if any)
while event:
depth = mainline_map.get(event.event_id)
if depth is not None:
- defer.returnValue(depth)
+ return depth
auth_events = event.auth_event_ids()
event = None
for aid in auth_events:
- aev = yield _get_event(aid, event_map, state_res_store)
- if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
+ aev = yield _get_event(
+ room_id, aid, event_map, state_res_store, allow_none=True
+ )
+ if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
event = aev
break
# Didn't find a power level auth event, so we just return 0
- defer.returnValue(0)
+ return 0
@defer.inlineCallbacks
-def _get_event(event_id, event_map, state_res_store):
+def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
"""Helper function to look up event in event_map, falling back to looking
it up in the store
Args:
+ room_id (str)
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
+ allow_none (bool): if the event is not found, return None rather than raising
+ an exception
Returns:
- Deferred[FrozenEvent]
+ Deferred[Optional[FrozenEvent]]
"""
if event_id not in event_map:
events = yield state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events)
- defer.returnValue(event_map[event_id])
+ event = event_map.get(event_id)
+
+ if event is None:
+ if allow_none:
+ return None
+ raise Exception("Unknown event %s" % (event_id,))
+
+ if event.room_id != room_id:
+ raise Exception(
+ "In state res for room %s, event %s is in %s"
+ % (room_id, event_id, event.room_id)
+ )
+ return event
def lexicographical_topological_sort(graph, key):
diff --git a/synapse/static/client/login/js/login.js b/synapse/static/client/login/js/login.js
index e02663f50e..276c271bbe 100644
--- a/synapse/static/client/login/js/login.js
+++ b/synapse/static/client/login/js/login.js
@@ -62,7 +62,7 @@ var show_login = function() {
$("#sso_flow").show();
}
- if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas) {
+ if (!matrixLogin.serverAcceptsPassword && !matrixLogin.serverAcceptsCas && !matrixLogin.serverAcceptsSso) {
$("#no_login_types").show();
}
};
diff --git a/synapse/static/index.html b/synapse/static/index.html
index d3f1c7dce0..bf46df9097 100644
--- a/synapse/static/index.html
+++ b/synapse/static/index.html
@@ -48,13 +48,13 @@
</div>
<h1>It works! Synapse is running</h1>
<p>Your Synapse server is listening on this port and is ready for messages.</p>
- <p>To use this server you'll need <a href="https://matrix.org/docs/projects/try-matrix-now.html#clients" target="_blank">a Matrix client</a>.
+ <p>To use this server you'll need <a href="https://matrix.org/docs/projects/try-matrix-now.html#clients" target="_blank" rel="noopener noreferrer">a Matrix client</a>.
</p>
<p>Welcome to the Matrix universe :)</p>
<hr>
<p>
<small>
- <a href="https://matrix.org" target="_blank">
+ <a href="https://matrix.org" target="_blank" rel="noopener noreferrer">
matrix.org
</a>
</small>
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 71316f7d09..ec89f645d4 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018,2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,506 +14,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import calendar
-import logging
-import time
-
-from twisted.internet import defer
-
-from synapse.api.constants import PresenceState
-from synapse.storage.devices import DeviceStore
-from synapse.storage.user_erasure_store import UserErasureStore
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-from .account_data import AccountDataStore
-from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
-from .client_ips import ClientIpStore
-from .deviceinbox import DeviceInboxStore
-from .directory import DirectoryStore
-from .e2e_room_keys import EndToEndRoomKeyStore
-from .end_to_end_keys import EndToEndKeyStore
-from .engines import PostgresEngine
-from .event_federation import EventFederationStore
-from .event_push_actions import EventPushActionsStore
-from .events import EventsStore
-from .events_bg_updates import EventsBackgroundUpdatesStore
-from .filtering import FilteringStore
-from .group_server import GroupServerStore
-from .keys import KeyStore
-from .media_repository import MediaRepositoryStore
-from .monthly_active_users import MonthlyActiveUsersStore
-from .openid import OpenIdStore
-from .presence import PresenceStore, UserPresenceState
-from .profile import ProfileStore
-from .push_rule import PushRuleStore
-from .pusher import PusherStore
-from .receipts import ReceiptsStore
-from .registration import RegistrationStore
-from .rejections import RejectionsStore
-from .relations import RelationsStore
-from .room import RoomStore
-from .roommember import RoomMemberStore
-from .search import SearchStore
-from .signatures import SignatureStore
-from .state import StateStore
-from .stats import StatsStore
-from .stream import StreamStore
-from .tags import TagsStore
-from .transactions import TransactionStore
-from .user_directory import UserDirectoryStore
-from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerator
-
-logger = logging.getLogger(__name__)
-
-
-class DataStore(
- EventsBackgroundUpdatesStore,
- RoomMemberStore,
- RoomStore,
- RegistrationStore,
- StreamStore,
- ProfileStore,
- PresenceStore,
- TransactionStore,
- DirectoryStore,
- KeyStore,
- StateStore,
- SignatureStore,
- ApplicationServiceStore,
- EventsStore,
- EventFederationStore,
- MediaRepositoryStore,
- RejectionsStore,
- FilteringStore,
- PusherStore,
- PushRuleStore,
- ApplicationServiceTransactionStore,
- ReceiptsStore,
- EndToEndKeyStore,
- EndToEndRoomKeyStore,
- SearchStore,
- TagsStore,
- AccountDataStore,
- EventPushActionsStore,
- OpenIdStore,
- ClientIpStore,
- DeviceStore,
- DeviceInboxStore,
- UserDirectoryStore,
- GroupServerStore,
- UserErasureStore,
- MonthlyActiveUsersStore,
- StatsStore,
- RelationsStore,
-):
- def __init__(self, db_conn, hs):
- self.hs = hs
- self._clock = hs.get_clock()
- self.database_engine = hs.database_engine
-
- self._stream_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- extra_tables=[("local_invites", "stream_id")],
- )
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
- )
- self._presence_id_gen = StreamIdGenerator(
- db_conn, "presence_stream", "stream_id"
- )
- self._device_inbox_id_gen = StreamIdGenerator(
- db_conn, "device_max_stream_id", "stream_id"
- )
- self._public_room_id_gen = StreamIdGenerator(
- db_conn, "public_room_list_stream", "stream_id"
- )
- self._device_list_id_gen = StreamIdGenerator(
- db_conn, "device_lists_stream", "stream_id"
- )
-
- self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
- self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
- self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
- self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
- self._push_rules_stream_id_gen = ChainedIdGenerator(
- self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
- )
- self._pushers_id_gen = StreamIdGenerator(
- db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
- )
- self._group_updates_id_gen = StreamIdGenerator(
- db_conn, "local_group_updates", "stream_id"
- )
-
- if isinstance(self.database_engine, PostgresEngine):
- self._cache_id_gen = StreamIdGenerator(
- db_conn, "cache_invalidation_stream", "stream_id"
- )
- else:
- self._cache_id_gen = None
-
- self._presence_on_startup = self._get_active_presence(db_conn)
-
- presence_cache_prefill, min_presence_val = self._get_cache_dict(
- db_conn,
- "presence_stream",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self._presence_id_gen.get_current_token(),
- )
- self.presence_stream_cache = StreamChangeCache(
- "PresenceStreamChangeCache",
- min_presence_val,
- prefilled_cache=presence_cache_prefill,
- )
-
- max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
- device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
- db_conn,
- "device_inbox",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=max_device_inbox_id,
- limit=1000,
- )
- self._device_inbox_stream_cache = StreamChangeCache(
- "DeviceInboxStreamChangeCache",
- min_device_inbox_id,
- prefilled_cache=device_inbox_prefill,
- )
- # The federation outbox and the local device inbox uses the same
- # stream_id generator.
- device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
- db_conn,
- "device_federation_outbox",
- entity_column="destination",
- stream_column="stream_id",
- max_value=max_device_inbox_id,
- limit=1000,
- )
- self._device_federation_outbox_stream_cache = StreamChangeCache(
- "DeviceFederationOutboxStreamChangeCache",
- min_device_outbox_id,
- prefilled_cache=device_outbox_prefill,
- )
-
- device_list_max = self._device_list_id_gen.get_current_token()
- self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache", device_list_max
- )
- self._device_list_federation_stream_cache = StreamChangeCache(
- "DeviceListFederationStreamChangeCache", device_list_max
- )
-
- events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
- db_conn,
- "current_state_delta_stream",
- entity_column="room_id",
- stream_column="stream_id",
- max_value=events_max, # As we share the stream id with events token
- limit=1000,
- )
- self._curr_state_delta_stream_cache = StreamChangeCache(
- "_curr_state_delta_stream_cache",
- min_curr_state_delta_id,
- prefilled_cache=curr_state_delta_prefill,
- )
-
- _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
- db_conn,
- "local_group_updates",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self._group_updates_id_gen.get_current_token(),
- limit=1000,
- )
- self._group_updates_stream_cache = StreamChangeCache(
- "_group_updates_stream_cache",
- min_group_updates_id,
- prefilled_cache=_group_updates_prefill,
- )
-
- self._stream_order_on_start = self.get_room_max_stream_ordering()
- self._min_stream_order_on_start = self.get_room_min_stream_ordering()
-
- # Used in _generate_user_daily_visits to keep track of progress
- self._last_user_visit_update = self._get_start_of_day()
-
- super(DataStore, self).__init__(db_conn, hs)
-
- def take_presence_startup_info(self):
- active_on_startup = self._presence_on_startup
- self._presence_on_startup = None
- return active_on_startup
-
- def _get_active_presence(self, db_conn):
- """Fetch non-offline presence from the database so that we can register
- the appropriate time outs.
- """
-
- sql = (
- "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
- " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
- " WHERE state != ?"
- )
- sql = self.database_engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
- txn.execute(sql, (PresenceState.OFFLINE,))
- rows = self.cursor_to_dict(txn)
- txn.close()
-
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- return [UserPresenceState(**row) for row in rows]
-
- def count_daily_users(self):
- """
- Counts the number of users who used this homeserver in the last 24 hours.
- """
-
- def _count_users(txn):
- yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
-
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT user_id FROM user_ips
- WHERE last_seen > ?
- GROUP BY user_id
- ) u
- """
-
- txn.execute(sql, (yesterday,))
- count, = txn.fetchone()
- return count
-
- return self.runInteraction("count_users", _count_users)
-
- def count_r30_users(self):
- """
- Counts the number of 30 day retained users, defined as:-
- * Users who have created their accounts more than 30 days ago
- * Where last seen at most 30 days ago
- * Where account creation and last_seen are > 30 days apart
-
- Returns counts globaly for a given user as well as breaking
- by platform
- """
-
- def _count_r30_users(txn):
- thirty_days_in_secs = 86400 * 30
- now = int(self._clock.time())
- thirty_days_ago_in_secs = now - thirty_days_in_secs
-
- sql = """
- SELECT platform, COALESCE(count(*), 0) FROM (
- SELECT
- users.name, platform, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen,
- CASE
- WHEN user_agent LIKE '%%Android%%' THEN 'android'
- WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
- WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
- WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
- WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
- ELSE 'unknown'
- END
- AS platform
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND users.appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, platform, users.creation_ts
- ) u GROUP BY platform
- """
-
- results = {}
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- for row in txn:
- if row[0] == 'unknown':
- pass
- results[row[0]] = row[1]
-
- sql = """
- SELECT COALESCE(count(*), 0) FROM (
- SELECT users.name, users.creation_ts * 1000,
- MAX(uip.last_seen)
- FROM users
- INNER JOIN (
- SELECT
- user_id,
- last_seen
- FROM user_ips
- ) uip
- ON users.name = uip.user_id
- AND appservice_id is NULL
- AND users.creation_ts < ?
- AND uip.last_seen/1000 > ?
- AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
- GROUP BY users.name, users.creation_ts
- ) u
- """
-
- txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
-
- count, = txn.fetchone()
- results['all'] = count
-
- return results
-
- return self.runInteraction("count_r30_users", _count_r30_users)
-
- def _get_start_of_day(self):
- """
- Returns millisecond unixtime for start of UTC day.
- """
- now = time.gmtime()
- today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
- return today_start * 1000
-
- def generate_user_daily_visits(self):
- """
- Generates daily visit data for use in cohort/ retention analysis
- """
-
- def _generate_user_daily_visits(txn):
- logger.info("Calling _generate_user_daily_visits")
- today_start = self._get_start_of_day()
- a_day_in_milliseconds = 24 * 60 * 60 * 1000
- now = self.clock.time_msec()
-
- sql = """
- INSERT INTO user_daily_visits (user_id, device_id, timestamp)
- SELECT u.user_id, u.device_id, ?
- FROM user_ips AS u
- LEFT JOIN (
- SELECT user_id, device_id, timestamp FROM user_daily_visits
- WHERE timestamp = ?
- ) udv
- ON u.user_id = udv.user_id AND u.device_id=udv.device_id
- INNER JOIN users ON users.name=u.user_id
- WHERE last_seen > ? AND last_seen <= ?
- AND udv.timestamp IS NULL AND users.is_guest=0
- AND users.appservice_id IS NULL
- GROUP BY u.user_id, u.device_id
- """
-
- # This means that the day has rolled over but there could still
- # be entries from the previous day. There is an edge case
- # where if the user logs in at 23:59 and overwrites their
- # last_seen at 00:01 then they will not be counted in the
- # previous day's stats - it is important that the query is run
- # often to minimise this case.
- if today_start > self._last_user_visit_update:
- yesterday_start = today_start - a_day_in_milliseconds
- txn.execute(
- sql,
- (
- yesterday_start,
- yesterday_start,
- self._last_user_visit_update,
- today_start,
- ),
- )
- self._last_user_visit_update = today_start
-
- txn.execute(
- sql, (today_start, today_start, self._last_user_visit_update, now)
- )
- # Update _last_user_visit_update to now. The reason to do this
- # rather just clamping to the beginning of the day is to limit
- # the size of the join - meaning that the query can be run more
- # frequently
- self._last_user_visit_update = now
-
- return self.runInteraction(
- "generate_user_daily_visits", _generate_user_daily_visits
- )
-
- def get_users(self):
- """Function to reterive a list of users in users table.
-
- Args:
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self._simple_select_list(
- table="users",
- keyvalues={},
- retcols=["name", "password_hash", "is_guest", "admin"],
- desc="get_users",
- )
-
- @defer.inlineCallbacks
- def get_users_paginate(self, order, start, limit):
- """Function to reterive a paginated list of users from
- users list. This will return a json object, which contains
- list of users and the total number of users in users table.
-
- Args:
- order (str): column name to order the select by this column
- start (int): start number to begin the query from
- limit (int): number of rows to reterive
- Returns:
- defer.Deferred: resolves to json object {list[dict[str, Any]], count}
- """
- users = yield self.runInteraction(
- "get_users_paginate",
- self._simple_select_list_paginate_txn,
- table="users",
- keyvalues={"is_guest": False},
- orderby=order,
- start=start,
- limit=limit,
- retcols=["name", "password_hash", "is_guest", "admin"],
- )
- count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
- retval = {"users": users, "total": count}
- defer.returnValue(retval)
-
- def search_users(self, term):
- """Function to search users list for one or more users with
- the matched term.
-
- Args:
- term (str): search term
- col (str): column to query term should be matched to
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self._simple_search_list(
- table="users",
- term=term,
- col="name",
- retcols=["name", "password_hash", "is_guest", "admin"],
- desc="search_users",
- )
-
-
-def are_all_users_on_domain(txn, database_engine, domain):
- sql = database_engine.convert_param_style(
- "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
- )
- pat = "%:" + domain
- txn.execute(sql, (pat,))
- num_not_matching = txn.fetchall()[0][0]
- if num_not_matching == 0:
- return True
- return False
+"""
+The storage layer is split up into multiple parts to allow Synapse to run
+against different configurations of databases (e.g. single or multiple
+databases). The `Database` class represents a single physical database. The
+`data_stores` are classes that talk directly to a `Database` instance and have
+associated schemas, background updates, etc. On top of those there are classes
+that provide high level interfaces that combine calls to multiple `data_stores`.
+
+There are also schemas that get applied to every database, regardless of the
+data stores associated with them (e.g. the schema version tables), which are
+stored in `synapse.storage.schema`.
+"""
+
+from synapse.storage.data_stores import DataStores
+from synapse.storage.data_stores.main import DataStore
+from synapse.storage.persist_events import EventsPersistenceStorage
+from synapse.storage.purge_events import PurgeEventsStorage
+from synapse.storage.state import StateGroupStorage
+
+__all__ = ["DataStores", "DataStore"]
+
+
+class Storage(object):
+ """The high level interfaces for talking to various storage layers.
+ """
+
+ def __init__(self, hs, stores: DataStores):
+ # We include the main data store here mainly so that we don't have to
+ # rewrite all the existing code to split it into high vs low level
+ # interfaces.
+ self.main = stores.main
+
+ self.persistence = EventsPersistenceStorage(hs, stores)
+ self.purge_events = PurgeEventsStorage(hs, stores)
+ self.state = StateGroupStorage(hs, stores)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 537696547c..13de5f1f62 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,1382 +14,39 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
import random
-import sys
-import threading
-import time
+from abc import ABCMeta
+from typing import Any, Optional
-from six import PY2, iteritems, iterkeys, itervalues
-from six.moves import builtins, intern, range
+from six import PY2
+from six.moves import builtins
from canonicaljson import json
-from prometheus_client import Histogram
-from twisted.internet import defer
-
-from synapse.api.errors import StoreError
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.types import get_domain_from_id
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import Cache
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
-from synapse.util.stringutils import exception_to_unicode
+from synapse.storage.database import LoggingTransaction # noqa: F401
+from synapse.storage.database import make_in_list_sql_clause # noqa: F401
+from synapse.storage.database import Database
+from synapse.types import Collection, get_domain_from_id
logger = logging.getLogger(__name__)
-try:
- MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
- # python 3 does not have a maximum int value
- MAX_TXN_ID = 2 ** 63 - 1
-
-sql_logger = logging.getLogger("synapse.storage.SQL")
-transaction_logger = logging.getLogger("synapse.storage.txn")
-perf_logger = logging.getLogger("synapse.storage.TIME")
-
-sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
-
-sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
-sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
-
-
-# Unique indexes which have been added in background updates. Maps from table name
-# to the name of the background update which added the unique index to that table.
-#
-# This is used by the upsert logic to figure out which tables are safe to do a proper
-# UPSERT on: until the relevant background update has completed, we
-# have to emulate an upsert by locking the table.
-#
-UNIQUE_INDEX_BACKGROUND_UPDATES = {
- "user_ips": "user_ips_device_unique_index",
- "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
- "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
- "event_search": "event_search_event_id_idx",
-}
-
-# This is a special cache name we use to batch multiple invalidations of caches
-# based on the current state when notifying workers over replication.
-_CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-
-
-class LoggingTransaction(object):
- """An object that almost-transparently proxies for the 'txn' object
- passed to the constructor. Adds logging and metrics to the .execute()
- method."""
-
- __slots__ = [
- "txn",
- "name",
- "database_engine",
- "after_callbacks",
- "exception_callbacks",
- ]
-
- def __init__(
- self, txn, name, database_engine, after_callbacks, exception_callbacks
- ):
- object.__setattr__(self, "txn", txn)
- object.__setattr__(self, "name", name)
- object.__setattr__(self, "database_engine", database_engine)
- object.__setattr__(self, "after_callbacks", after_callbacks)
- object.__setattr__(self, "exception_callbacks", exception_callbacks)
-
- def call_after(self, callback, *args, **kwargs):
- """Call the given callback on the main twisted thread after the
- transaction has finished. Used to invalidate the caches on the
- correct thread.
- """
- self.after_callbacks.append((callback, args, kwargs))
-
- def call_on_exception(self, callback, *args, **kwargs):
- self.exception_callbacks.append((callback, args, kwargs))
-
- def __getattr__(self, name):
- return getattr(self.txn, name)
-
- def __setattr__(self, name, value):
- setattr(self.txn, name, value)
-
- def __iter__(self):
- return self.txn.__iter__()
-
- def execute_batch(self, sql, args):
- if isinstance(self.database_engine, PostgresEngine):
- from psycopg2.extras import execute_batch
-
- self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
- else:
- for val in args:
- self.execute(sql, val)
-
- def execute(self, sql, *args):
- self._do_execute(self.txn.execute, sql, *args)
- def executemany(self, sql, *args):
- self._do_execute(self.txn.executemany, sql, *args)
+# some of our subclasses have abstract methods, so we use the ABCMeta metaclass.
+class SQLBaseStore(metaclass=ABCMeta):
+ """Base class for data stores that holds helper functions.
- def _make_sql_one_line(self, sql):
- "Strip newlines out of SQL so that the loggers in the DB are on one line"
- return " ".join(l.strip() for l in sql.splitlines() if l.strip())
-
- def _do_execute(self, func, sql, *args):
- sql = self._make_sql_one_line(sql)
-
- # TODO(paul): Maybe use 'info' and 'debug' for values?
- sql_logger.debug("[SQL] {%s} %s", self.name, sql)
-
- sql = self.database_engine.convert_param_style(sql)
- if args:
- try:
- sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
- except Exception:
- # Don't let logging failures stop SQL from working
- pass
-
- start = time.time()
-
- try:
- return func(sql, *args)
- except Exception as e:
- logger.debug("[SQL FAIL] {%s} %s", self.name, e)
- raise
- finally:
- secs = time.time() - start
- sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
- sql_query_timer.labels(sql.split()[0]).observe(secs)
-
-
-class PerformanceCounters(object):
- def __init__(self):
- self.current_counters = {}
- self.previous_counters = {}
-
- def update(self, key, start_time, end_time=None):
- if end_time is None:
- end_time = time.time()
- duration = end_time - start_time
- count, cum_time = self.current_counters.get(key, (0, 0))
- count += 1
- cum_time += duration
- self.current_counters[key] = (count, cum_time)
- return end_time
-
- def interval(self, interval_duration, limit=3):
- counters = []
- for name, (count, cum_time) in iteritems(self.current_counters):
- prev_count, prev_time = self.previous_counters.get(name, (0, 0))
- counters.append(
- ((cum_time - prev_time) / interval_duration, count - prev_count, name)
- )
-
- self.previous_counters = dict(self.current_counters)
-
- counters.sort(reverse=True)
-
- top_n_counters = ", ".join(
- "%s(%d): %.3f%%" % (name, count, 100 * ratio)
- for ratio, count, name in counters[:limit]
- )
-
- return top_n_counters
-
-
-class SQLBaseStore(object):
- _TXN_ID = 0
+ Note that multiple instances of this class will exist as there will be one
+ per data store (and not one per physical database).
+ """
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self.hs = hs
self._clock = hs.get_clock()
- self._db_pool = hs.get_db_pool()
-
- self._previous_txn_total_time = 0
- self._current_txn_total_time = 0
- self._previous_loop_ts = 0
-
- # TODO(paul): These can eventually be removed once the metrics code
- # is running in mainline, and we have some nice monitoring frontends
- # to watch it
- self._txn_perf_counters = PerformanceCounters()
- self._get_event_counters = PerformanceCounters()
-
- self._get_event_cache = Cache(
- "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
- )
-
- self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
- self._event_fetch_ongoing = 0
-
- self._pending_ds = []
-
- self.database_engine = hs.database_engine
-
- # A set of tables that are not safe to use native upserts in.
- self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
-
- self._account_validity = self.hs.config.account_validity
-
- # We add the user_directory_search table to the blacklist on SQLite
- # because the existing search table does not have an index, making it
- # unsafe to use native upserts.
- if isinstance(self.database_engine, Sqlite3Engine):
- self._unsafe_to_upsert_tables.add("user_directory_search")
-
- if self.database_engine.can_native_upsert:
- # Check ASAP (and then later, every 1s) to see if we have finished
- # background updates of tables that aren't safe to update.
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "upsert_safety_check",
- self._check_safe_to_upsert,
- )
-
+ self.database_engine = database.engine
+ self.db = database
self.rand = random.SystemRandom()
- if self._account_validity.enabled:
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "account_validity_set_expiration_dates",
- self._set_expiration_date_when_missing,
- )
-
- @defer.inlineCallbacks
- def _check_safe_to_upsert(self):
- """
- Is it safe to use native UPSERT?
-
- If there are background updates, we will need to wait, as they may be
- the addition of indexes that set the UNIQUE constraint that we require.
-
- If the background updates have not completed, wait 15 sec and check again.
- """
- updates = yield self._simple_select_list(
- "background_updates",
- keyvalues=None,
- retcols=["update_name"],
- desc="check_background_updates",
- )
- updates = [x["update_name"] for x in updates]
-
- for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
- if update_name not in updates:
- logger.debug("Now safe to upsert in %s", table)
- self._unsafe_to_upsert_tables.discard(table)
-
- # If there's any updates still running, reschedule to run.
- if updates:
- self._clock.call_later(
- 15.0,
- run_as_background_process,
- "upsert_safety_check",
- self._check_safe_to_upsert,
- )
-
- @defer.inlineCallbacks
- def _set_expiration_date_when_missing(self):
- """
- Retrieves the list of registered users that don't have an expiration date, and
- adds an expiration date for each of them.
- """
-
- def select_users_with_no_expiration_date_txn(txn):
- """Retrieves the list of registered users with no expiration date from the
- database, filtering out deactivated users.
- """
- sql = (
- "SELECT users.name FROM users"
- " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
- )
- txn.execute(sql, [])
-
- res = self.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn,
- user["name"],
- use_delta=True,
- )
-
- yield self.runInteraction(
- "get_users_with_no_expiration_date",
- select_users_with_no_expiration_date_txn,
- )
-
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
- """Sets an expiration date to the account with the given user ID.
-
- Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
- now + validity period. If set to True, this expiration date will be a
- random value in the [now + period - d ; now + period] range, d being a
- delta equal to 10% of the validity period.
- """
- now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
-
- if use_delta:
- expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
- expiration_ts,
- )
-
- self._simple_upsert_txn(
- txn,
- "account_validity",
- keyvalues={"user_id": user_id, },
- values={"expiration_ts_ms": expiration_ts, "email_sent": False, },
- )
-
- def start_profiling(self):
- self._previous_loop_ts = self._clock.time_msec()
-
- def loop():
- curr = self._current_txn_total_time
- prev = self._previous_txn_total_time
- self._previous_txn_total_time = curr
-
- time_now = self._clock.time_msec()
- time_then = self._previous_loop_ts
- self._previous_loop_ts = time_now
-
- ratio = (curr - prev) / (time_now - time_then)
-
- top_three_counters = self._txn_perf_counters.interval(
- time_now - time_then, limit=3
- )
-
- top_3_event_counters = self._get_event_counters.interval(
- time_now - time_then, limit=3
- )
-
- perf_logger.info(
- "Total database time: %.3f%% {%s} {%s}",
- ratio * 100,
- top_three_counters,
- top_3_event_counters,
- )
-
- self._clock.looping_call(loop, 10000)
-
- def _new_transaction(
- self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
- ):
- start = time.time()
- txn_id = self._TXN_ID
-
- # We don't really need these to be unique, so lets stop it from
- # growing really large.
- self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
-
- name = "%s-%x" % (desc, txn_id)
-
- transaction_logger.debug("[TXN START] {%s}", name)
-
- try:
- i = 0
- N = 5
- while True:
- try:
- txn = conn.cursor()
- txn = LoggingTransaction(
- txn,
- name,
- self.database_engine,
- after_callbacks,
- exception_callbacks,
- )
- r = func(txn, *args, **kwargs)
- conn.commit()
- return r
- except self.database_engine.module.OperationalError as e:
- # This can happen if the database disappears mid
- # transaction.
- logger.warning(
- "[TXN OPERROR] {%s} %s %d/%d",
- name,
- exception_to_unicode(e),
- i,
- N,
- )
- if i < N:
- i += 1
- try:
- conn.rollback()
- except self.database_engine.module.Error as e1:
- logger.warning(
- "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
- )
- continue
- raise
- except self.database_engine.module.DatabaseError as e:
- if self.database_engine.is_deadlock(e):
- logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
- if i < N:
- i += 1
- try:
- conn.rollback()
- except self.database_engine.module.Error as e1:
- logger.warning(
- "[TXN EROLL] {%s} %s",
- name,
- exception_to_unicode(e1),
- )
- continue
- raise
- except Exception as e:
- logger.debug("[TXN FAIL] {%s} %s", name, e)
- raise
- finally:
- end = time.time()
- duration = end - start
-
- LoggingContext.current_context().add_database_transaction(duration)
-
- transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
-
- self._current_txn_total_time += duration
- self._txn_perf_counters.update(desc, start, end)
- sql_txn_timer.labels(desc).observe(duration)
-
- @defer.inlineCallbacks
- def runInteraction(self, desc, func, *args, **kwargs):
- """Starts a transaction on the database and runs a given function
-
- Arguments:
- desc (str): description of the transaction, for logging and metrics
- func (func): callback function, which will be called with a
- database transaction (twisted.enterprise.adbapi.Transaction) as
- its first argument, followed by `args` and `kwargs`.
-
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
-
- Returns:
- Deferred: The result of func
- """
- after_callbacks = []
- exception_callbacks = []
-
- if LoggingContext.current_context() == LoggingContext.sentinel:
- logger.warn("Starting db txn '%s' from sentinel context", desc)
-
- try:
- result = yield self.runWithConnection(
- self._new_transaction,
- desc,
- after_callbacks,
- exception_callbacks,
- func,
- *args,
- **kwargs
- )
-
- for after_callback, after_args, after_kwargs in after_callbacks:
- after_callback(*after_args, **after_kwargs)
- except: # noqa: E722, as we reraise the exception this is fine.
- for after_callback, after_args, after_kwargs in exception_callbacks:
- after_callback(*after_args, **after_kwargs)
- raise
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def runWithConnection(self, func, *args, **kwargs):
- """Wraps the .runWithConnection() method on the underlying db_pool.
-
- Arguments:
- func (func): callback function, which will be called with a
- database connection (twisted.enterprise.adbapi.Connection) as
- its first argument, followed by `args` and `kwargs`.
- args (list): positional args to pass to `func`
- kwargs (dict): named args to pass to `func`
-
- Returns:
- Deferred: The result of func
- """
- parent_context = LoggingContext.current_context()
- if parent_context == LoggingContext.sentinel:
- logger.warn(
- "Starting db connection from sentinel context: metrics will be lost"
- )
- parent_context = None
-
- start_time = time.time()
-
- def inner_func(conn, *args, **kwargs):
- with LoggingContext("runWithConnection", parent_context) as context:
- sched_duration_sec = time.time() - start_time
- sql_scheduling_timer.observe(sched_duration_sec)
- context.add_database_scheduled(sched_duration_sec)
-
- if self.database_engine.is_connection_closed(conn):
- logger.debug("Reconnecting closed database connection")
- conn.reconnect()
-
- return func(conn, *args, **kwargs)
-
- with PreserveLoggingContext():
- result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
-
- defer.returnValue(result)
-
- @staticmethod
- def cursor_to_dict(cursor):
- """Converts a SQL cursor into an list of dicts.
-
- Args:
- cursor : The DBAPI cursor which has executed a query.
- Returns:
- A list of dicts where the key is the column header.
- """
- col_headers = list(intern(str(column[0])) for column in cursor.description)
- results = list(dict(zip(col_headers, row)) for row in cursor)
- return results
-
- def _execute(self, desc, decoder, query, *args):
- """Runs a single query for a result set.
-
- Args:
- decoder - The function which can resolve the cursor results to
- something meaningful.
- query - The query string to execute
- *args - Query args.
- Returns:
- The result of decoder(results)
- """
-
- def interaction(txn):
- txn.execute(query, args)
- if decoder:
- return decoder(txn)
- else:
- return txn.fetchall()
-
- return self.runInteraction(desc, interaction)
-
- # "Simple" SQL API methods that operate on a single table with no JOINs,
- # no complex WHERE clauses, just a dict of values for columns.
-
- @defer.inlineCallbacks
- def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
- """Executes an INSERT query on the named table.
-
- Args:
- table : string giving the table name
- values : dict of new column names and values for them
- or_ignore : bool stating whether an exception should be raised
- when a conflicting row already exists. If True, False will be
- returned by the function instead
- desc : string giving a description of the transaction
-
- Returns:
- bool: Whether the row was inserted or not. Only useful when
- `or_ignore` is True
- """
- try:
- yield self.runInteraction(desc, self._simple_insert_txn, table, values)
- except self.database_engine.module.IntegrityError:
- # We have to do or_ignore flag at this layer, since we can't reuse
- # a cursor after we receive an error from the db.
- if not or_ignore:
- raise
- defer.returnValue(False)
- defer.returnValue(True)
-
- @staticmethod
- def _simple_insert_txn(txn, table, values):
- keys, vals = zip(*values.items())
-
- sql = "INSERT INTO %s (%s) VALUES(%s)" % (
- table,
- ", ".join(k for k in keys),
- ", ".join("?" for _ in keys),
- )
-
- txn.execute(sql, vals)
-
- def _simple_insert_many(self, table, values, desc):
- return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
-
- @staticmethod
- def _simple_insert_many_txn(txn, table, values):
- if not values:
- return
-
- # This is a *slight* abomination to get a list of tuples of key names
- # and a list of tuples of value names.
- #
- # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
- # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
- #
- # The sort is to ensure that we don't rely on dictionary iteration
- # order.
- keys, vals = zip(
- *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
- )
-
- for k in keys:
- if k != keys[0]:
- raise RuntimeError("All items must have the same keys")
-
- sql = "INSERT INTO %s (%s) VALUES(%s)" % (
- table,
- ", ".join(k for k in keys[0]),
- ", ".join("?" for _ in keys[0]),
- )
-
- txn.executemany(sql, vals)
-
- @defer.inlineCallbacks
- def _simple_upsert(
- self,
- table,
- keyvalues,
- values,
- insertion_values={},
- desc="_simple_upsert",
- lock=True,
- ):
- """
-
- `lock` should generally be set to True (the default), but can be set
- to False if either of the following are true:
-
- * there is a UNIQUE INDEX on the key columns. In this case a conflict
- will cause an IntegrityError in which case this function will retry
- the update.
-
- * we somehow know that we are the only thread which will be updating
- this table.
-
- Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key columns and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
- Returns:
- Deferred(None or bool): Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
- """
- attempts = 0
- while True:
- try:
- result = yield self.runInteraction(
- desc,
- self._simple_upsert_txn,
- table,
- keyvalues,
- values,
- insertion_values,
- lock=lock,
- )
- defer.returnValue(result)
- except self.database_engine.module.IntegrityError as e:
- attempts += 1
- if attempts >= 5:
- # don't retry forever, because things other than races
- # can cause IntegrityErrors
- raise
-
- # presumably we raced with another transaction: let's retry.
- logger.warn(
- "IntegrityError when upserting into %s; retrying: %s", table, e
- )
-
- def _simple_upsert_txn(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
- """
- Pick the UPSERT method which works best on the platform. Either the
- native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
-
- Args:
- txn: The transaction to use.
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
- Returns:
- None or bool: Native upserts always return None. Emulated
- upserts return True if a new entry was created, False if an existing
- one was updated.
- """
- if (
- self.database_engine.can_native_upsert
- and table not in self._unsafe_to_upsert_tables
- ):
- return self._simple_upsert_txn_native_upsert(
- txn, table, keyvalues, values, insertion_values=insertion_values
- )
- else:
- return self._simple_upsert_txn_emulated(
- txn,
- table,
- keyvalues,
- values,
- insertion_values=insertion_values,
- lock=lock,
- )
-
- def _simple_upsert_txn_emulated(
- self, txn, table, keyvalues, values, insertion_values={}, lock=True
- ):
- """
- Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- lock (bool): True to lock the table when doing the upsert.
- Returns:
- bool: Return True if a new entry was created, False if an existing
- one was updated.
- """
- # We need to lock the table :(, unless we're *really* careful
- if lock:
- self.database_engine.lock_table(txn, table)
-
- def _getwhere(key):
- # If the value we're passing in is None (aka NULL), we need to use
- # IS, not =, as NULL = NULL equals NULL (False).
- if keyvalues[key] is None:
- return "%s IS ?" % (key,)
- else:
- return "%s = ?" % (key,)
-
- if not values:
- # If `values` is empty, then all of the values we care about are in
- # the unique key, so there is nothing to UPDATE. We can just do a
- # SELECT instead to see if it exists.
- sql = "SELECT 1 FROM %s WHERE %s" % (
- table,
- " AND ".join(_getwhere(k) for k in keyvalues),
- )
- sqlargs = list(keyvalues.values())
- txn.execute(sql, sqlargs)
- if txn.fetchall():
- # We have an existing record.
- return False
- else:
- # First try to update.
- sql = "UPDATE %s SET %s WHERE %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in values),
- " AND ".join(_getwhere(k) for k in keyvalues),
- )
- sqlargs = list(values.values()) + list(keyvalues.values())
-
- txn.execute(sql, sqlargs)
- if txn.rowcount > 0:
- # successfully updated at least one row.
- return False
-
- # We didn't find any existing rows, so insert a new one
- allvalues = {}
- allvalues.update(keyvalues)
- allvalues.update(values)
- allvalues.update(insertion_values)
-
- sql = "INSERT INTO %s (%s) VALUES (%s)" % (
- table,
- ", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues),
- )
- txn.execute(sql, list(allvalues.values()))
- # successfully inserted
- return True
-
- def _simple_upsert_txn_native_upsert(
- self, txn, table, keyvalues, values, insertion_values={}
- ):
- """
- Use the native UPSERT functionality in recent PostgreSQL versions.
-
- Args:
- table (str): The table to upsert into
- keyvalues (dict): The unique key tables and their new values
- values (dict): The nonunique columns and their new values
- insertion_values (dict): additional key/values to use only when
- inserting
- Returns:
- None
- """
- allvalues = {}
- allvalues.update(keyvalues)
- allvalues.update(insertion_values)
-
- if not values:
- latter = "NOTHING"
- else:
- allvalues.update(values)
- latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
-
- sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
- table,
- ", ".join(k for k in allvalues),
- ", ".join("?" for _ in allvalues),
- ", ".join(k for k in keyvalues),
- latter,
- )
- txn.execute(sql, list(allvalues.values()))
-
- def _simple_upsert_many_txn(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
- """
- Upsert, many times.
-
- Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
- """
- if (
- self.database_engine.can_native_upsert
- and table not in self._unsafe_to_upsert_tables
- ):
- return self._simple_upsert_many_txn_native_upsert(
- txn, table, key_names, key_values, value_names, value_values
- )
- else:
- return self._simple_upsert_many_txn_emulated(
- txn, table, key_names, key_values, value_names, value_values
- )
-
- def _simple_upsert_many_txn_emulated(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
- """
- Upsert, many times, but without native UPSERT support or batching.
-
- Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
- """
- # No value columns, therefore make a blank list so that the following
- # zip() works correctly.
- if not value_names:
- value_values = [() for x in range(len(key_values))]
-
- for keyv, valv in zip(key_values, value_values):
- _keys = {x: y for x, y in zip(key_names, keyv)}
- _vals = {x: y for x, y in zip(value_names, valv)}
-
- self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
-
- def _simple_upsert_many_txn_native_upsert(
- self, txn, table, key_names, key_values, value_names, value_values
- ):
- """
- Upsert, many times, using batching where possible.
-
- Args:
- table (str): The table to upsert into
- key_names (list[str]): The key column names.
- key_values (list[list]): A list of each row's key column values.
- value_names (list[str]): The value column names. If empty, no
- values will be used, even if value_values is provided.
- value_values (list[list]): A list of each row's value column values.
- Returns:
- None
- """
- allnames = []
- allnames.extend(key_names)
- allnames.extend(value_names)
-
- if not value_names:
- # No value columns, therefore make a blank list so that the
- # following zip() works correctly.
- latter = "NOTHING"
- value_values = [() for x in range(len(key_values))]
- else:
- latter = "UPDATE SET " + ", ".join(
- k + "=EXCLUDED." + k for k in value_names
- )
-
- sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
- table,
- ", ".join(k for k in allnames),
- ", ".join("?" for _ in allnames),
- ", ".join(key_names),
- latter,
- )
-
- args = []
-
- for x, y in zip(key_values, value_values):
- args.append(tuple(x) + tuple(y))
-
- return txn.execute_batch(sql, args)
-
- def _simple_select_one(
- self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
- ):
- """Executes a SELECT query on the named table, which is expected to
- return a single row, returning multiple columns from it.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcols : list of strings giving the names of the columns to return
-
- allow_none : If true, return None instead of failing if the SELECT
- statement returns no rows
- """
- return self.runInteraction(
- desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
- )
-
- def _simple_select_one_onecol(
- self,
- table,
- keyvalues,
- retcol,
- allow_none=False,
- desc="_simple_select_one_onecol",
- ):
- """Executes a SELECT query on the named table, which is expected to
- return a single row, returning a single column from it.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- retcol : string giving the name of the column to return
- """
- return self.runInteraction(
- desc,
- self._simple_select_one_onecol_txn,
- table,
- keyvalues,
- retcol,
- allow_none=allow_none,
- )
-
- @classmethod
- def _simple_select_one_onecol_txn(
- cls, txn, table, keyvalues, retcol, allow_none=False
- ):
- ret = cls._simple_select_onecol_txn(
- txn, table=table, keyvalues=keyvalues, retcol=retcol
- )
-
- if ret:
- return ret[0]
- else:
- if allow_none:
- return None
- else:
- raise StoreError(404, "No row found")
-
- @staticmethod
- def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
- sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
-
- if keyvalues:
- sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
- txn.execute(sql, list(keyvalues.values()))
- else:
- txn.execute(sql)
-
- return [r[0] for r in txn]
-
- def _simple_select_onecol(
- self, table, keyvalues, retcol, desc="_simple_select_onecol"
- ):
- """Executes a SELECT query on the named table, which returns a list
- comprising of the values of the named column from the selected rows.
-
- Args:
- table (str): table name
- keyvalues (dict|None): column names and values to select the rows with
- retcol (str): column whos value we wish to retrieve.
-
- Returns:
- Deferred: Results in a list
- """
- return self.runInteraction(
- desc, self._simple_select_onecol_txn, table, keyvalues, retcol
- )
-
- def _simple_select_list(
- self, table, keyvalues, retcols, desc="_simple_select_list"
- ):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- table (str): the table name
- keyvalues (dict[str, Any] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self.runInteraction(
- desc, self._simple_select_list_txn, table, keyvalues, retcols
- )
-
- @classmethod
- def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- retcols (iterable[str]): the names of the columns to return
- """
- if keyvalues:
- sql = "SELECT %s FROM %s WHERE %s" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
- txn.execute(sql, list(keyvalues.values()))
- else:
- sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
- txn.execute(sql)
-
- return cls.cursor_to_dict(txn)
-
- @defer.inlineCallbacks
- def _simple_select_many_batch(
- self,
- table,
- column,
- iterable,
- retcols,
- keyvalues={},
- desc="_simple_select_many_batch",
- batch_size=100,
- ):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Filters rows by if value of `column` is in `iterable`.
-
- Args:
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
- """
- results = []
-
- if not iterable:
- defer.returnValue(results)
-
- # iterables can not be sliced, so convert it to a list first
- it_list = list(iterable)
-
- chunks = [
- it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
- ]
- for chunk in chunks:
- rows = yield self.runInteraction(
- desc,
- self._simple_select_many_txn,
- table,
- column,
- chunk,
- keyvalues,
- retcols,
- )
-
- results.extend(rows)
-
- defer.returnValue(results)
-
- @classmethod
- def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Filters rows by if value of `column` is in `iterable`.
-
- Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
- retcols : list of strings giving the names of the columns to return
- """
- if not iterable:
- return []
-
- sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
-
- clauses = []
- values = []
- clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
- values.extend(iterable)
-
- for key, value in iteritems(keyvalues):
- clauses.append("%s = ?" % (key,))
- values.append(value)
-
- if clauses:
- sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
-
- txn.execute(sql, values)
- return cls.cursor_to_dict(txn)
-
- def _simple_update(self, table, keyvalues, updatevalues, desc):
- return self.runInteraction(
- desc, self._simple_update_txn, table, keyvalues, updatevalues
- )
-
- @staticmethod
- def _simple_update_txn(txn, table, keyvalues, updatevalues):
- if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
- else:
- where = ""
-
- update_sql = "UPDATE %s SET %s %s" % (
- table,
- ", ".join("%s = ?" % (k,) for k in updatevalues),
- where,
- )
-
- txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
-
- return txn.rowcount
-
- def _simple_update_one(
- self, table, keyvalues, updatevalues, desc="_simple_update_one"
- ):
- """Executes an UPDATE query on the named table, setting new values for
- columns in a row matching the key values.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- updatevalues : dict giving column names and values to update
- retcols : optional list of column names to return
-
- If present, retcols gives a list of column names on which to perform
- a SELECT statement *before* performing the UPDATE statement. The values
- of these will be returned in a dict.
-
- These are performed within the same transaction, allowing an atomic
- get-and-set. This can be used to implement compare-and-set by putting
- the update column in the 'keyvalues' dict as well.
- """
- return self.runInteraction(
- desc, self._simple_update_one_txn, table, keyvalues, updatevalues
- )
-
- @classmethod
- def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
- rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
-
- if rowcount == 0:
- raise StoreError(404, "No row found (%s)" % (table,))
- if rowcount > 1:
- raise StoreError(500, "More than one row matched (%s)" % (table,))
-
- @staticmethod
- def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
- select_sql = "SELECT %s FROM %s WHERE %s" % (
- ", ".join(retcols),
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
-
- txn.execute(select_sql, list(keyvalues.values()))
- row = txn.fetchone()
-
- if not row:
- if allow_none:
- return None
- raise StoreError(404, "No row found (%s)" % (table,))
- if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched (%s)" % (table,))
-
- return dict(zip(retcols, row))
-
- def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
- """Executes a DELETE query on the named table, expecting to delete a
- single row.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- """
- return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
-
- @staticmethod
- def _simple_delete_one_txn(txn, table, keyvalues):
- """Executes a DELETE query on the named table, expecting to delete a
- single row.
-
- Args:
- table : string giving the table name
- keyvalues : dict of column names and values to select the row with
- """
- sql = "DELETE FROM %s WHERE %s" % (
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
-
- txn.execute(sql, list(keyvalues.values()))
- if txn.rowcount == 0:
- raise StoreError(404, "No row found (%s)" % (table,))
- if txn.rowcount > 1:
- raise StoreError(500, "More than one row matched (%s)" % (table,))
-
- def _simple_delete(self, table, keyvalues, desc):
- return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
-
- @staticmethod
- def _simple_delete_txn(txn, table, keyvalues):
- sql = "DELETE FROM %s WHERE %s" % (
- table,
- " AND ".join("%s = ?" % (k,) for k in keyvalues),
- )
-
- txn.execute(sql, list(keyvalues.values()))
- return txn.rowcount
-
- def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
- return self.runInteraction(
- desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
- )
-
- @staticmethod
- def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
- """Executes a DELETE query on the named table.
-
- Filters rows by if value of `column` is in `iterable`.
-
- Args:
- txn : Transaction object
- table : string giving the table name
- column : column name to test for inclusion against `iterable`
- iterable : list
- keyvalues : dict of column names and values to select the rows with
-
- Returns:
- int: Number rows deleted
- """
- if not iterable:
- return 0
-
- sql = "DELETE FROM %s" % table
-
- clauses = []
- values = []
- clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
- values.extend(iterable)
-
- for key, value in iteritems(keyvalues):
- clauses.append("%s = ?" % (key,))
- values.append(value)
-
- if clauses:
- sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
- txn.execute(sql, values)
-
- return txn.rowcount
-
- def _get_cache_dict(
- self, db_conn, table, entity_column, stream_column, max_value, limit=100000
- ):
- # Fetch a mapping of room_id -> max stream position for "recent" rooms.
- # It doesn't really matter how many we get, the StreamChangeCache will
- # do the right thing to ensure it respects the max size of cache.
- sql = (
- "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
- " WHERE %(stream)s > ? - %(limit)s"
- " GROUP BY %(entity)s"
- ) % {
- "table": table,
- "entity": entity_column,
- "stream": stream_column,
- "limit": limit,
- }
-
- sql = self.database_engine.convert_param_style(sql)
-
- txn = db_conn.cursor()
- txn.execute(sql, (int(max_value),))
-
- cache = {row[0]: int(row[1]) for row in txn}
-
- txn.close()
-
- if cache:
- min_val = min(itervalues(cache))
- else:
- min_val = max_value
-
- return cache, min_val
-
- def _invalidate_cache_and_stream(self, txn, cache_func, keys):
- """Invalidates the cache and adds it to the cache stream so slaves
- will know to invalidate their caches.
-
- This should only be used to invalidate caches where slaves won't
- otherwise know from other replication streams that the cache should
- be invalidated.
- """
- txn.call_after(cache_func.invalidate, keys)
- self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
-
- def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
- """Special case invalidation of caches based on current state.
-
- We special case this so that we can batch the cache invalidations into a
- single replication poke.
-
- Args:
- txn
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have changed
- """
- txn.call_after(self._invalidate_state_caches, room_id, members_changed)
-
- # We need to be careful that the size of the `members_changed` list
- # isn't so large that it causes problems sending over replication, so we
- # send them in chunks.
- # Max line length is 16K, and max user ID length is 255, so 50 should
- # be safe.
- for chunk in batch_iter(members_changed, 50):
- keys = itertools.chain([room_id], chunk)
- self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
-
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
@@ -1399,7 +56,7 @@ class SQLBaseStore(object):
members_changed (iterable[str]): The user_ids of members that have
changed
"""
- for host in set(get_domain_from_id(u) for u in members_changed):
+ for host in {get_domain_from_id(u) for u in members_changed}:
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
@@ -1407,242 +64,29 @@ class SQLBaseStore(object):
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
- def _attempt_to_invalidate_cache(self, cache_name, key):
+ def _attempt_to_invalidate_cache(
+ self, cache_name: str, key: Optional[Collection[Any]]
+ ):
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
Args:
- cache_name (str)
- key (tuple)
+ cache_name
+ key: Entry to invalidate. If None then invalidates the entire
+ cache.
"""
+
try:
- getattr(self, cache_name).invalidate(key)
+ if key is None:
+ getattr(self, cache_name).invalidate_all()
+ else:
+ getattr(self, cache_name).invalidate(tuple(key))
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
pass
- def _send_invalidation_to_replication(self, txn, cache_name, keys):
- """Notifies replication that given cache has been invalidated.
-
- Note that this does *not* invalidate the cache locally.
-
- Args:
- txn
- cache_name (str)
- keys (iterable[str])
- """
-
- if isinstance(self.database_engine, PostgresEngine):
- # get_next() returns a context manager which is designed to wrap
- # the transaction. However, we want to only get an ID when we want
- # to use it, here, so we need to call __enter__ manually, and have
- # __exit__ called after the transaction finishes.
- ctx = self._cache_id_gen.get_next()
- stream_id = ctx.__enter__()
- txn.call_on_exception(ctx.__exit__, None, None, None)
- txn.call_after(ctx.__exit__, None, None, None)
- txn.call_after(self.hs.get_notifier().on_new_replication_data)
-
- self._simple_insert_txn(
- txn,
- table="cache_invalidation_stream",
- values={
- "stream_id": stream_id,
- "cache_func": cache_name,
- "keys": list(keys),
- "invalidation_ts": self.clock.time_msec(),
- },
- )
-
- def get_all_updated_caches(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_caches_txn(txn):
- # We purposefully don't bound by the current token, as we want to
- # send across cache invalidations as quickly as possible. Cache
- # invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
- return txn.fetchall()
-
- return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
-
- def get_cache_stream_token(self):
- if self._cache_id_gen:
- return self._cache_id_gen.get_current_token()
- else:
- return 0
-
- def _simple_select_list_paginate(
- self,
- table,
- keyvalues,
- orderby,
- start,
- limit,
- retcols,
- order_direction="ASC",
- desc="_simple_select_list_paginate",
- ):
- """
- Executes a SELECT query on the named table with start and limit,
- of row numbers, which may return zero or number of rows from start to limit,
- returning the result as a list of dicts.
-
- Args:
- table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- return self.runInteraction(
- desc,
- self._simple_select_list_paginate_txn,
- table,
- keyvalues,
- orderby,
- start,
- limit,
- retcols,
- order_direction=order_direction,
- )
-
- @classmethod
- def _simple_select_list_paginate_txn(
- cls,
- txn,
- table,
- keyvalues,
- orderby,
- start,
- limit,
- retcols,
- order_direction="ASC",
- ):
- """
- Executes a SELECT query on the named table with start and limit,
- of row numbers, which may return zero or number of rows from start to limit,
- returning the result as a list of dicts.
-
- Args:
- txn : Transaction object
- table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
- orderby (str): Column to order the results by.
- start (int): Index to begin the query at.
- limit (int): Number of results to return.
- retcols (iterable[str]): the names of the columns to return
- order_direction (str): Whether the results should be ordered "ASC" or "DESC".
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]]
- """
- if order_direction not in ["ASC", "DESC"]:
- raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
-
- if keyvalues:
- where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
- else:
- where_clause = ""
-
- sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
- ", ".join(retcols),
- table,
- where_clause,
- orderby,
- order_direction,
- )
- txn.execute(sql, list(keyvalues.values()) + [limit, start])
-
- return cls.cursor_to_dict(txn)
-
- def get_user_count_txn(self, txn):
- """Get a total number of registered users in the users list.
-
- Args:
- txn : Transaction object
- Returns:
- int : number of users
- """
- sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
- txn.execute(sql_count)
- return txn.fetchone()[0]
-
- def _simple_search_list(
- self, table, term, col, retcols, desc="_simple_search_list"
- ):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
- """
-
- return self.runInteraction(
- desc, self._simple_search_list_txn, table, term, col, retcols
- )
-
- @classmethod
- def _simple_search_list_txn(cls, txn, table, term, col, retcols):
- """Executes a SELECT query on the named table, which may return zero or
- more rows, returning the result as a list of dicts.
-
- Args:
- txn : Transaction object
- table (str): the table name
- term (str | None):
- term for searching the table matched to a column.
- col (str): column to query term should be matched to
- retcols (iterable[str]): the names of the columns to return
- Returns:
- defer.Deferred: resolves to list[dict[str, Any]] or None
- """
- if term:
- sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
- termvalues = ["%%" + term + "%%"]
- txn.execute(sql, termvalues)
- else:
- return 0
-
- return cls.cursor_to_dict(txn)
-
- @property
- def database_engine_name(self):
- return self.database_engine.module.__name__
-
- def get_server_version(self):
- """Returns a string describing the server version number"""
- return self.database_engine.server_version
-
-
-class _RollbackButIsFineException(Exception):
- """ This exception is used to rollback a transaction without implying
- something went wrong.
- """
-
- pass
-
def db_to_json(db_content):
"""
@@ -1664,7 +108,7 @@ def db_to_json(db_content):
# Decode it to a Unicode string before feeding it to json.loads, so we
# consistenty get a Unicode-containing object out.
if isinstance(db_content, (bytes, bytearray)):
- db_content = db_content.decode('utf8')
+ db_content = db_content.decode("utf8")
try:
return json.loads(db_content)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index b8b8273f73..eb1a7e5002 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Optional
from canonicaljson import json
@@ -22,7 +23,6 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines
-from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -74,7 +74,7 @@ class BackgroundUpdatePerformance(object):
return float(self.total_item_count) / float(self.total_duration_ms)
-class BackgroundUpdateStore(SQLBaseStore):
+class BackgroundUpdater(object):
""" Background updates are updates to the database that run in the
background. Each update processes a batch of data at once. We attempt to
limit the impact of each update by monitoring how long each batch takes to
@@ -86,24 +86,26 @@ class BackgroundUpdateStore(SQLBaseStore):
BACKGROUND_UPDATE_INTERVAL_MS = 1000
BACKGROUND_UPDATE_DURATION_MS = 100
- def __init__(self, db_conn, hs):
- super(BackgroundUpdateStore, self).__init__(db_conn, hs)
+ def __init__(self, hs, database):
+ self._clock = hs.get_clock()
+ self.db = database
+
self._background_update_performance = {}
self._background_update_queue = []
self._background_update_handlers = {}
self._all_done = False
def start_doing_background_updates(self):
- run_as_background_process("background_updates", self._run_background_updates)
+ run_as_background_process("background_updates", self.run_background_updates)
- @defer.inlineCallbacks
- def _run_background_updates(self):
+ async def run_background_updates(self, sleep=True):
logger.info("Starting background schema updates")
while True:
- yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
+ if sleep:
+ await self._clock.sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
try:
- result = yield self.do_next_background_update(
+ result = await self.do_next_background_update(
self.BACKGROUND_UPDATE_DURATION_MS
)
except Exception:
@@ -115,7 +117,7 @@ class BackgroundUpdateStore(SQLBaseStore):
" Unscheduling background update task."
)
self._all_done = True
- defer.returnValue(None)
+ return None
@defer.inlineCallbacks
def has_completed_background_updates(self):
@@ -127,63 +129,85 @@ class BackgroundUpdateStore(SQLBaseStore):
# if we've previously determined that there is nothing left to do, that
# is easy
if self._all_done:
- defer.returnValue(True)
+ return True
# obviously, if we have things in our queue, we're not done.
if self._background_update_queue:
- defer.returnValue(False)
+ return False
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
- updates = yield self._simple_select_onecol(
+ updates = yield self.db.simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
- desc="check_background_updates",
+ desc="has_completed_background_updates",
)
if not updates:
self._all_done = True
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
- @defer.inlineCallbacks
- def do_next_background_update(self, desired_duration_ms):
+ async def has_completed_background_update(self, update_name) -> bool:
+ """Check if the given background update has finished running.
+ """
+
+ if self._all_done:
+ return True
+
+ if update_name in self._background_update_queue:
+ return False
+
+ update_exists = await self.db.simple_select_one_onecol(
+ "background_updates",
+ keyvalues={"update_name": update_name},
+ retcol="1",
+ desc="has_completed_background_update",
+ allow_none=True,
+ )
+
+ return not update_exists
+
+ async def do_next_background_update(
+ self, desired_duration_ms: float
+ ) -> Optional[int]:
"""Does some amount of work on the next queued background update
+ Returns once some amount of work is done.
+
Args:
desired_duration_ms(float): How long we want to spend
updating.
Returns:
- A deferred that completes once some amount of work is done.
- The deferred will have a value of None if there is currently
- no more work to do.
+ None if there is no more work to do, otherwise an int
"""
if not self._background_update_queue:
- updates = yield self._simple_select_list(
+ updates = await self.db.simple_select_list(
"background_updates",
keyvalues=None,
retcols=("update_name", "depends_on"),
)
- in_flight = set(update["update_name"] for update in updates)
+ in_flight = {update["update_name"] for update in updates}
for update in updates:
if update["depends_on"] not in in_flight:
- self._background_update_queue.append(update['update_name'])
+ self._background_update_queue.append(update["update_name"])
if not self._background_update_queue:
# no work left to do
- defer.returnValue(None)
+ return None
# pop from the front, and add back to the back
update_name = self._background_update_queue.pop(0)
self._background_update_queue.append(update_name)
- res = yield self._do_background_update(update_name, desired_duration_ms)
- defer.returnValue(res)
+ res = await self._do_background_update(update_name, desired_duration_ms)
+ return res
- @defer.inlineCallbacks
- def _do_background_update(self, update_name, desired_duration_ms):
+ async def _do_background_update(
+ self, update_name: str, desired_duration_ms: float
+ ) -> int:
logger.info("Starting update batch on background update '%s'", update_name)
update_handler = self._background_update_handlers[update_name]
@@ -203,7 +227,7 @@ class BackgroundUpdateStore(SQLBaseStore):
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
- progress_json = yield self._simple_select_one_onecol(
+ progress_json = await self.db.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json",
@@ -212,13 +236,13 @@ class BackgroundUpdateStore(SQLBaseStore):
progress = json.loads(progress_json)
time_start = self._clock.time_msec()
- items_updated = yield update_handler(progress, batch_size)
+ items_updated = await update_handler(progress, batch_size)
time_stop = self._clock.time_msec()
duration_ms = time_stop - time_start
logger.info(
- "Updating %r. Updated %r items in %rms."
+ "Running background update %r. Processed %r items in %rms."
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
update_name,
items_updated,
@@ -231,7 +255,7 @@ class BackgroundUpdateStore(SQLBaseStore):
performance.update(items_updated, duration_ms)
- defer.returnValue(len(self._background_update_performance))
+ return len(self._background_update_performance)
def register_background_update_handler(self, update_name, update_handler):
"""Register a handler for doing a background update.
@@ -241,7 +265,9 @@ class BackgroundUpdateStore(SQLBaseStore):
* A dict of the current progress
* An integer count of the number of items to update in this batch.
- The handler should return a deferred integer count of items updated.
+ The handler should return a deferred or coroutine which returns an integer count
+ of items updated.
+
The handler is responsible for updating the progress of the update.
Args:
@@ -266,7 +292,7 @@ class BackgroundUpdateStore(SQLBaseStore):
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
- defer.returnValue(1)
+ return 1
self.register_background_update_handler(update_name, noop_update)
@@ -357,7 +383,7 @@ class BackgroundUpdateStore(SQLBaseStore):
logger.debug("[SQL] %s", sql)
c.execute(sql)
- if isinstance(self.database_engine, engines.PostgresEngine):
+ if isinstance(self.db.engine, engines.PostgresEngine):
runner = create_index_psql
elif psql_only:
runner = None
@@ -368,9 +394,9 @@ class BackgroundUpdateStore(SQLBaseStore):
def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
- yield self.runWithConnection(runner)
+ yield self.db.runWithConnection(runner)
yield self._end_background_update(update_name)
- defer.returnValue(1)
+ return 1
self.register_background_update_handler(update_name, updater)
@@ -390,7 +416,7 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue = []
progress_json = json.dumps(progress)
- return self._simple_insert(
+ return self.db.simple_insert(
"background_updates",
{"update_name": update_name, "progress_json": progress_json},
)
@@ -406,10 +432,25 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue = [
name for name in self._background_update_queue if name != update_name
]
- return self._simple_delete_one(
+ return self.db.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
+ def _background_update_progress(self, update_name: str, progress: dict):
+ """Update the progress of a background update
+
+ Args:
+ update_name: The name of the background update task
+ progress: The progress of the update.
+ """
+
+ return self.db.runInteraction(
+ "background_update_progress",
+ self._background_update_progress_txn,
+ update_name,
+ progress,
+ )
+
def _background_update_progress_txn(self, txn, update_name, progress):
"""Update the progress of a background update
@@ -421,7 +462,7 @@ class BackgroundUpdateStore(SQLBaseStore):
progress_json = json.dumps(progress)
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},
diff --git a/synapse/storage/data_stores/__init__.py b/synapse/storage/data_stores/__init__.py
new file mode 100644
index 0000000000..e1d03429ca
--- /dev/null
+++ b/synapse/storage/data_stores/__init__.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.storage.data_stores.state import StateGroupDataStore
+from synapse.storage.database import Database, make_conn
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+
+logger = logging.getLogger(__name__)
+
+
+class DataStores(object):
+ """The various data stores.
+
+ These are low level interfaces to physical databases.
+
+ Attributes:
+ main (DataStore)
+ """
+
+ def __init__(self, main_store_class, hs):
+ # Note we pass in the main store class here as workers use a different main
+ # store.
+
+ self.databases = []
+ self.main = None
+ self.state = None
+
+ for database_config in hs.config.database.databases:
+ db_name = database_config.name
+ engine = create_engine(database_config.config)
+
+ with make_conn(database_config, engine) as db_conn:
+ logger.info("Preparing database %r...", db_name)
+
+ engine.check_database(db_conn)
+ prepare_database(
+ db_conn, engine, hs.config, data_stores=database_config.data_stores,
+ )
+
+ database = Database(hs, database_config, engine)
+
+ if "main" in database_config.data_stores:
+ logger.info("Starting 'main' data store")
+
+ # Sanity check we don't try and configure the main store on
+ # multiple databases.
+ if self.main:
+ raise Exception("'main' data store already configured")
+
+ self.main = main_store_class(database, db_conn, hs)
+
+ if "state" in database_config.data_stores:
+ logger.info("Starting 'state' data store")
+
+ # Sanity check we don't try and configure the state store on
+ # multiple databases.
+ if self.state:
+ raise Exception("'state' data store already configured")
+
+ self.state = StateGroupDataStore(database, db_conn, hs)
+
+ db_conn.commit()
+
+ self.databases.append(database)
+
+ logger.info("Database %r prepared", db_name)
+
+ # Sanity check that we have actually configured all the required stores.
+ if not self.main:
+ raise Exception("No 'main' data store configured")
+
+ if not self.state:
+ raise Exception("No 'main' data store configured")
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
new file mode 100644
index 0000000000..acca079f23
--- /dev/null
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -0,0 +1,583 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import calendar
+import logging
+import time
+
+from synapse.api.constants import PresenceState
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import (
+ ChainedIdGenerator,
+ IdGenerator,
+ StreamIdGenerator,
+)
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+from .account_data import AccountDataStore
+from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
+from .cache import CacheInvalidationStore
+from .client_ips import ClientIpStore
+from .deviceinbox import DeviceInboxStore
+from .devices import DeviceStore
+from .directory import DirectoryStore
+from .e2e_room_keys import EndToEndRoomKeyStore
+from .end_to_end_keys import EndToEndKeyStore
+from .event_federation import EventFederationStore
+from .event_push_actions import EventPushActionsStore
+from .events import EventsStore
+from .events_bg_updates import EventsBackgroundUpdatesStore
+from .filtering import FilteringStore
+from .group_server import GroupServerStore
+from .keys import KeyStore
+from .media_repository import MediaRepositoryStore
+from .monthly_active_users import MonthlyActiveUsersStore
+from .openid import OpenIdStore
+from .presence import PresenceStore, UserPresenceState
+from .profile import ProfileStore
+from .push_rule import PushRuleStore
+from .pusher import PusherStore
+from .receipts import ReceiptsStore
+from .registration import RegistrationStore
+from .rejections import RejectionsStore
+from .relations import RelationsStore
+from .room import RoomStore
+from .roommember import RoomMemberStore
+from .search import SearchStore
+from .signatures import SignatureStore
+from .state import StateStore
+from .stats import StatsStore
+from .stream import StreamStore
+from .tags import TagsStore
+from .transactions import TransactionStore
+from .user_directory import UserDirectoryStore
+from .user_erasure_store import UserErasureStore
+
+logger = logging.getLogger(__name__)
+
+
+class DataStore(
+ EventsBackgroundUpdatesStore,
+ RoomMemberStore,
+ RoomStore,
+ RegistrationStore,
+ StreamStore,
+ ProfileStore,
+ PresenceStore,
+ TransactionStore,
+ DirectoryStore,
+ KeyStore,
+ StateStore,
+ SignatureStore,
+ ApplicationServiceStore,
+ EventsStore,
+ EventFederationStore,
+ MediaRepositoryStore,
+ RejectionsStore,
+ FilteringStore,
+ PusherStore,
+ PushRuleStore,
+ ApplicationServiceTransactionStore,
+ ReceiptsStore,
+ EndToEndKeyStore,
+ EndToEndRoomKeyStore,
+ SearchStore,
+ TagsStore,
+ AccountDataStore,
+ EventPushActionsStore,
+ OpenIdStore,
+ ClientIpStore,
+ DeviceStore,
+ DeviceInboxStore,
+ UserDirectoryStore,
+ GroupServerStore,
+ UserErasureStore,
+ MonthlyActiveUsersStore,
+ StatsStore,
+ RelationsStore,
+ CacheInvalidationStore,
+):
+ def __init__(self, database: Database, db_conn, hs):
+ self.hs = hs
+ self._clock = hs.get_clock()
+ self.database_engine = database.engine
+
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ extra_tables=[("local_invites", "stream_id")],
+ )
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ )
+ self._presence_id_gen = StreamIdGenerator(
+ db_conn, "presence_stream", "stream_id"
+ )
+ self._device_inbox_id_gen = StreamIdGenerator(
+ db_conn, "device_max_stream_id", "stream_id"
+ )
+ self._public_room_id_gen = StreamIdGenerator(
+ db_conn, "public_room_list_stream", "stream_id"
+ )
+ self._device_list_id_gen = StreamIdGenerator(
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[("user_signature_stream", "stream_id")],
+ )
+ self._cross_signing_id_gen = StreamIdGenerator(
+ db_conn, "e2e_cross_signing_keys", "stream_id"
+ )
+
+ self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+ self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
+ self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+ self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+ self._push_rules_stream_id_gen = ChainedIdGenerator(
+ self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+ )
+ self._pushers_id_gen = StreamIdGenerator(
+ db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+ )
+ self._group_updates_id_gen = StreamIdGenerator(
+ db_conn, "local_group_updates", "stream_id"
+ )
+
+ if isinstance(self.database_engine, PostgresEngine):
+ self._cache_id_gen = StreamIdGenerator(
+ db_conn, "cache_invalidation_stream", "stream_id"
+ )
+ else:
+ self._cache_id_gen = None
+
+ super(DataStore, self).__init__(database, db_conn, hs)
+
+ self._presence_on_startup = self._get_active_presence(db_conn)
+
+ presence_cache_prefill, min_presence_val = self.db.get_cache_dict(
+ db_conn,
+ "presence_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self._presence_id_gen.get_current_token(),
+ )
+ self.presence_stream_cache = StreamChangeCache(
+ "PresenceStreamChangeCache",
+ min_presence_val,
+ prefilled_cache=presence_cache_prefill,
+ )
+
+ max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
+ device_inbox_prefill, min_device_inbox_id = self.db.get_cache_dict(
+ db_conn,
+ "device_inbox",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=max_device_inbox_id,
+ limit=1000,
+ )
+ self._device_inbox_stream_cache = StreamChangeCache(
+ "DeviceInboxStreamChangeCache",
+ min_device_inbox_id,
+ prefilled_cache=device_inbox_prefill,
+ )
+ # The federation outbox and the local device inbox uses the same
+ # stream_id generator.
+ device_outbox_prefill, min_device_outbox_id = self.db.get_cache_dict(
+ db_conn,
+ "device_federation_outbox",
+ entity_column="destination",
+ stream_column="stream_id",
+ max_value=max_device_inbox_id,
+ limit=1000,
+ )
+ self._device_federation_outbox_stream_cache = StreamChangeCache(
+ "DeviceFederationOutboxStreamChangeCache",
+ min_device_outbox_id,
+ prefilled_cache=device_outbox_prefill,
+ )
+
+ device_list_max = self._device_list_id_gen.get_current_token()
+ self._device_list_stream_cache = StreamChangeCache(
+ "DeviceListStreamChangeCache", device_list_max
+ )
+ self._user_signature_stream_cache = StreamChangeCache(
+ "UserSignatureStreamChangeCache", device_list_max
+ )
+ self._device_list_federation_stream_cache = StreamChangeCache(
+ "DeviceListFederationStreamChangeCache", device_list_max
+ )
+
+ events_max = self._stream_id_gen.get_current_token()
+ curr_state_delta_prefill, min_curr_state_delta_id = self.db.get_cache_dict(
+ db_conn,
+ "current_state_delta_stream",
+ entity_column="room_id",
+ stream_column="stream_id",
+ max_value=events_max, # As we share the stream id with events token
+ limit=1000,
+ )
+ self._curr_state_delta_stream_cache = StreamChangeCache(
+ "_curr_state_delta_stream_cache",
+ min_curr_state_delta_id,
+ prefilled_cache=curr_state_delta_prefill,
+ )
+
+ _group_updates_prefill, min_group_updates_id = self.db.get_cache_dict(
+ db_conn,
+ "local_group_updates",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self._group_updates_id_gen.get_current_token(),
+ limit=1000,
+ )
+ self._group_updates_stream_cache = StreamChangeCache(
+ "_group_updates_stream_cache",
+ min_group_updates_id,
+ prefilled_cache=_group_updates_prefill,
+ )
+
+ self._stream_order_on_start = self.get_room_max_stream_ordering()
+ self._min_stream_order_on_start = self.get_room_min_stream_ordering()
+
+ # Used in _generate_user_daily_visits to keep track of progress
+ self._last_user_visit_update = self._get_start_of_day()
+
+ def take_presence_startup_info(self):
+ active_on_startup = self._presence_on_startup
+ self._presence_on_startup = None
+ return active_on_startup
+
+ def _get_active_presence(self, db_conn):
+ """Fetch non-offline presence from the database so that we can register
+ the appropriate time outs.
+ """
+
+ sql = (
+ "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
+ " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
+ " WHERE state != ?"
+ )
+ sql = self.database_engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (PresenceState.OFFLINE,))
+ rows = self.db.cursor_to_dict(txn)
+ txn.close()
+
+ for row in rows:
+ row["currently_active"] = bool(row["currently_active"])
+
+ return [UserPresenceState(**row) for row in rows]
+
+ def count_daily_users(self):
+ """
+ Counts the number of users who used this homeserver in the last 24 hours.
+ """
+ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
+ return self.db.runInteraction("count_daily_users", self._count_users, yesterday)
+
+ def count_monthly_users(self):
+ """
+ Counts the number of users who used this homeserver in the last 30 days.
+ Note this method is intended for phonehome metrics only and is different
+ from the mau figure in synapse.storage.monthly_active_users which,
+ amongst other things, includes a 3 day grace period before a user counts.
+ """
+ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
+ return self.db.runInteraction(
+ "count_monthly_users", self._count_users, thirty_days_ago
+ )
+
+ def _count_users(self, txn, time_from):
+ """
+ Returns number of users seen in the past time_from period
+ """
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT user_id FROM user_ips
+ WHERE last_seen > ?
+ GROUP BY user_id
+ ) u
+ """
+ txn.execute(sql, (time_from,))
+ (count,) = txn.fetchone()
+ return count
+
+ def count_r30_users(self):
+ """
+ Counts the number of 30 day retained users, defined as:-
+ * Users who have created their accounts more than 30 days ago
+ * Where last seen at most 30 days ago
+ * Where account creation and last_seen are > 30 days apart
+
+ Returns counts globaly for a given user as well as breaking
+ by platform
+ """
+
+ def _count_r30_users(txn):
+ thirty_days_in_secs = 86400 * 30
+ now = int(self._clock.time())
+ thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+ sql = """
+ SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT
+ users.name, platform, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen,
+ CASE
+ WHEN user_agent LIKE '%%Android%%' THEN 'android'
+ WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+ WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+ WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+ WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+ ELSE 'unknown'
+ END
+ AS platform
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND users.appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, platform, users.creation_ts
+ ) u GROUP BY platform
+ """
+
+ results = {}
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ for row in txn:
+ if row[0] == "unknown":
+ pass
+ results[row[0]] = row[1]
+
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT users.name, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, users.creation_ts
+ ) u
+ """
+
+ txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
+
+ (count,) = txn.fetchone()
+ results["all"] = count
+
+ return results
+
+ return self.db.runInteraction("count_r30_users", _count_r30_users)
+
+ def _get_start_of_day(self):
+ """
+ Returns millisecond unixtime for start of UTC day.
+ """
+ now = time.gmtime()
+ today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
+ return today_start * 1000
+
+ def generate_user_daily_visits(self):
+ """
+ Generates daily visit data for use in cohort/ retention analysis
+ """
+
+ def _generate_user_daily_visits(txn):
+ logger.info("Calling _generate_user_daily_visits")
+ today_start = self._get_start_of_day()
+ a_day_in_milliseconds = 24 * 60 * 60 * 1000
+ now = self.clock.time_msec()
+
+ sql = """
+ INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+ SELECT u.user_id, u.device_id, ?
+ FROM user_ips AS u
+ LEFT JOIN (
+ SELECT user_id, device_id, timestamp FROM user_daily_visits
+ WHERE timestamp = ?
+ ) udv
+ ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+ INNER JOIN users ON users.name=u.user_id
+ WHERE last_seen > ? AND last_seen <= ?
+ AND udv.timestamp IS NULL AND users.is_guest=0
+ AND users.appservice_id IS NULL
+ GROUP BY u.user_id, u.device_id
+ """
+
+ # This means that the day has rolled over but there could still
+ # be entries from the previous day. There is an edge case
+ # where if the user logs in at 23:59 and overwrites their
+ # last_seen at 00:01 then they will not be counted in the
+ # previous day's stats - it is important that the query is run
+ # often to minimise this case.
+ if today_start > self._last_user_visit_update:
+ yesterday_start = today_start - a_day_in_milliseconds
+ txn.execute(
+ sql,
+ (
+ yesterday_start,
+ yesterday_start,
+ self._last_user_visit_update,
+ today_start,
+ ),
+ )
+ self._last_user_visit_update = today_start
+
+ txn.execute(
+ sql, (today_start, today_start, self._last_user_visit_update, now)
+ )
+ # Update _last_user_visit_update to now. The reason to do this
+ # rather just clamping to the beginning of the day is to limit
+ # the size of the join - meaning that the query can be run more
+ # frequently
+ self._last_user_visit_update = now
+
+ return self.db.runInteraction(
+ "generate_user_daily_visits", _generate_user_daily_visits
+ )
+
+ def get_users(self):
+ """Function to retrieve a list of users in users table.
+
+ Args:
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.db.simple_select_list(
+ table="users",
+ keyvalues={},
+ retcols=[
+ "name",
+ "password_hash",
+ "is_guest",
+ "admin",
+ "user_type",
+ "deactivated",
+ ],
+ desc="get_users",
+ )
+
+ def get_users_paginate(
+ self, start, limit, name=None, guests=True, deactivated=False
+ ):
+ """Function to retrieve a paginated list of users from
+ users list. This will return a json list of users.
+
+ Args:
+ start (int): start number to begin the query from
+ limit (int): number of rows to retrieve
+ name (string): filter for user names
+ guests (bool): whether to in include guest users
+ deactivated (bool): whether to include deactivated users
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ name_filter = {}
+ if name:
+ name_filter["name"] = "%" + name + "%"
+
+ attr_filter = {}
+ if not guests:
+ attr_filter["is_guest"] = 0
+ if not deactivated:
+ attr_filter["deactivated"] = 0
+
+ return self.db.simple_select_list_paginate(
+ desc="get_users_paginate",
+ table="users",
+ orderby="name",
+ start=start,
+ limit=limit,
+ filters=name_filter,
+ keyvalues=attr_filter,
+ retcols=[
+ "name",
+ "password_hash",
+ "is_guest",
+ "admin",
+ "user_type",
+ "deactivated",
+ ],
+ )
+
+ def search_users(self, term):
+ """Function to search users list for one or more users with
+ the matched term.
+
+ Args:
+ term (str): search term
+ col (str): column to query term should be matched to
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.db.simple_search_list(
+ table="users",
+ term=term,
+ col="name",
+ retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
+ desc="search_users",
+ )
+
+
+def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
+ """Called before upgrading an existing database to check that it is broadly sane
+ compared with the configuration.
+ """
+ domain = config.server_name
+
+ sql = database_engine.convert_param_style(
+ "SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
+ )
+ pat = "%:" + domain
+ cur.execute(sql, (pat,))
+ num_not_matching = cur.fetchall()[0][0]
+ if num_not_matching == 0:
+ return
+
+ raise Exception(
+ "Found users in database not native to %s!\n"
+ "You cannot changed a synapse server_name after it's been configured"
+ % (domain,)
+ )
+
+
+__all__ = ["DataStore", "check_database_before_upgrade"]
diff --git a/synapse/storage/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 8394389073..46b494b334 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -22,6 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -38,13 +39,13 @@ class AccountDataWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
)
- super(AccountDataWorkerStore, self).__init__(db_conn, hs)
+ super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
@@ -67,7 +68,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_user_txn(txn):
- rows = self._simple_select_list_txn(
+ rows = self.db.simple_select_list_txn(
txn,
"account_data",
{"user_id": user_id},
@@ -78,7 +79,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
- rows = self._simple_select_list_txn(
+ rows = self.db.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id},
@@ -90,9 +91,9 @@ class AccountDataWorkerStore(SQLBaseStore):
room_data = by_room.setdefault(row["room_id"], {})
room_data[row["account_data_type"]] = json.loads(row["content"])
- return (global_account_data, by_room)
+ return global_account_data, by_room
- return self.runInteraction(
+ return self.db.runInteraction(
"get_account_data_for_user", get_account_data_for_user_txn
)
@@ -102,7 +103,7 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns:
Deferred: A dict
"""
- result = yield self._simple_select_one_onecol(
+ result = yield self.db.simple_select_one_onecol(
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
@@ -111,9 +112,9 @@ class AccountDataWorkerStore(SQLBaseStore):
)
if result:
- defer.returnValue(json.loads(result))
+ return json.loads(result)
else:
- defer.returnValue(None)
+ return None
@cached(num_args=2)
def get_account_data_for_room(self, user_id, room_id):
@@ -127,7 +128,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_txn(txn):
- rows = self._simple_select_list_txn(
+ rows = self.db.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
@@ -138,7 +139,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
- return self.runInteraction(
+ return self.db.runInteraction(
"get_account_data_for_room", get_account_data_for_room_txn
)
@@ -156,7 +157,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_and_type_txn(txn):
- content_json = self._simple_select_one_onecol_txn(
+ content_json = self.db.simple_select_one_onecol_txn(
txn,
table="room_account_data",
keyvalues={
@@ -170,7 +171,7 @@ class AccountDataWorkerStore(SQLBaseStore):
return json.loads(content_json) if content_json else None
- return self.runInteraction(
+ return self.db.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
)
@@ -184,14 +185,14 @@ class AccountDataWorkerStore(SQLBaseStore):
current_id(int): The position to fetch up to.
Returns:
A deferred pair of lists of tuples of stream_id int, user_id string,
- room_id string, type string, and content string.
+ room_id string, and type string.
"""
if last_room_id == current_id and last_global_id == current_id:
return defer.succeed(([], []))
def get_updated_account_data_txn(txn):
sql = (
- "SELECT stream_id, user_id, account_data_type, content"
+ "SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
@@ -199,15 +200,15 @@ class AccountDataWorkerStore(SQLBaseStore):
global_results = txn.fetchall()
sql = (
- "SELECT stream_id, user_id, room_id, account_data_type, content"
+ "SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
txn.execute(sql, (last_room_id, current_id, limit))
room_results = txn.fetchall()
- return (global_results, room_results)
+ return global_results, room_results
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_account_data_txn", get_updated_account_data_txn
)
@@ -244,15 +245,15 @@ class AccountDataWorkerStore(SQLBaseStore):
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2])
- return (global_account_data, account_data_by_room)
+ return global_account_data, account_data_by_room
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
- return ({}, {})
+ return defer.succeed(({}, {}))
- return self.runInteraction(
+ return self.db.runInteraction(
"get_updated_account_data_for_user", get_updated_account_data_for_user_txn
)
@@ -264,20 +265,18 @@ class AccountDataWorkerStore(SQLBaseStore):
on_invalidate=cache_context.invalidate,
)
if not ignored_account_data:
- defer.returnValue(False)
+ return False
- defer.returnValue(
- ignored_user_id in ignored_account_data.get("ignored_users", {})
- )
+ return ignored_user_id in ignored_account_data.get("ignored_users", {})
class AccountDataStore(AccountDataWorkerStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self._account_data_id_gen = StreamIdGenerator(
db_conn, "account_data_max_stream_id", "stream_id"
)
- super(AccountDataStore, self).__init__(db_conn, hs)
+ super(AccountDataStore, self).__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
@@ -302,9 +301,9 @@ class AccountDataStore(AccountDataWorkerStore):
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
- # on (user_id, room_id, account_data_type) so _simple_upsert will
+ # on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
- yield self._simple_upsert(
+ yield self.db.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
@@ -332,7 +331,7 @@ class AccountDataStore(AccountDataWorkerStore):
)
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def add_account_data_for_user(self, user_id, account_data_type, content):
@@ -348,9 +347,9 @@ class AccountDataStore(AccountDataWorkerStore):
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
- # (user_id, account_data_type) so _simple_upsert will retry if
+ # (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
- yield self._simple_upsert(
+ yield self.db.simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
@@ -373,7 +372,7 @@ class AccountDataStore(AccountDataWorkerStore):
)
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
def _update_max_stream_id(self, next_id):
"""Update the max stream_id
@@ -390,4 +389,4 @@ class AccountDataStore(AccountDataWorkerStore):
)
txn.execute(update_max_id_sql, (next_id, next_id))
- return self.runInteraction("update_account_data_max_stream_id", _update)
+ return self.db.runInteraction("update_account_data_max_stream_id", _update)
diff --git a/synapse/storage/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 9d9b28de13..9c52aa5340 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -22,9 +22,9 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
-from synapse.storage.events_worker import EventsWorkerStore
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
logger = logging.getLogger(__name__)
@@ -49,13 +49,13 @@ def _make_exclusive_regex(services_cache):
class ApplicationServiceWorkerStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
self.services_cache = load_appservices(
hs.hostname, hs.config.app_service_config_files
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
- super(ApplicationServiceWorkerStore, self).__init__(db_conn, hs)
+ super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
def get_app_services(self):
return self.services_cache
@@ -134,8 +134,8 @@ class ApplicationServiceTransactionWorkerStore(
A Deferred which resolves to a list of ApplicationServices, which
may be empty.
"""
- results = yield self._simple_select_list(
- "application_services_state", dict(state=state), ["as_id"]
+ results = yield self.db.simple_select_list(
+ "application_services_state", {"state": state}, ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
as_list = self.get_app_services()
@@ -145,7 +145,7 @@ class ApplicationServiceTransactionWorkerStore(
for service in as_list:
if service.id == res["as_id"]:
services.append(service)
- defer.returnValue(services)
+ return services
@defer.inlineCallbacks
def get_appservice_state(self, service):
@@ -156,17 +156,16 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves to ApplicationServiceState.
"""
- result = yield self._simple_select_one(
+ result = yield self.db.simple_select_one(
"application_services_state",
- dict(as_id=service.id),
+ {"as_id": service.id},
["state"],
allow_none=True,
desc="get_appservice_state",
)
if result:
- defer.returnValue(result.get("state"))
- return
- defer.returnValue(None)
+ return result.get("state")
+ return None
def set_appservice_state(self, service, state):
"""Set the application service state.
@@ -177,8 +176,8 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves when the state was set successfully.
"""
- return self._simple_upsert(
- "application_services_state", dict(as_id=service.id), dict(state=state)
+ return self.db.simple_upsert(
+ "application_services_state", {"as_id": service.id}, {"state": state}
)
def create_appservice_txn(self, service, events):
@@ -218,7 +217,7 @@ class ApplicationServiceTransactionWorkerStore(
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
- return self.runInteraction("create_appservice_txn", _create_appservice_txn)
+ return self.db.runInteraction("create_appservice_txn", _create_appservice_txn)
def complete_appservice_txn(self, txn_id, service):
"""Completes an application service transaction.
@@ -251,19 +250,23 @@ class ApplicationServiceTransactionWorkerStore(
)
# Set current txn_id for AS to 'txn_id'
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
"application_services_state",
- dict(as_id=service.id),
- dict(last_txn=txn_id),
+ {"as_id": service.id},
+ {"last_txn": txn_id},
)
# Delete txn
- self._simple_delete_txn(
- txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
+ self.db.simple_delete_txn(
+ txn,
+ "application_services_txns",
+ {"txn_id": txn_id, "as_id": service.id},
)
- return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
+ return self.db.runInteraction(
+ "complete_appservice_txn", _complete_appservice_txn
+ )
@defer.inlineCallbacks
def get_oldest_unsent_txn(self, service):
@@ -285,7 +288,7 @@ class ApplicationServiceTransactionWorkerStore(
" ORDER BY txn_id ASC LIMIT 1",
(service.id,),
)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
return None
@@ -293,20 +296,18 @@ class ApplicationServiceTransactionWorkerStore(
return entry
- entry = yield self.runInteraction(
+ entry = yield self.db.runInteraction(
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
)
if not entry:
- defer.returnValue(None)
+ return None
event_ids = json.loads(entry["event_ids"])
events = yield self.get_events_as_list(event_ids)
- defer.returnValue(
- AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
- )
+ return AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
def _get_last_txn(self, txn, service_id):
txn.execute(
@@ -325,7 +326,7 @@ class ApplicationServiceTransactionWorkerStore(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
- return self.runInteraction(
+ return self.db.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)
@@ -354,13 +355,13 @@ class ApplicationServiceTransactionWorkerStore(
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.runInteraction(
+ upper_bound, event_ids = yield self.db.runInteraction(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
events = yield self.get_events_as_list(event_ids)
- defer.returnValue((upper_bound, events))
+ return upper_bound, events
class ApplicationServiceTransactionStore(ApplicationServiceTransactionWorkerStore):
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
new file mode 100644
index 0000000000..d4c44dcc75
--- /dev/null
+++ b/synapse/storage/data_stores/main/cache.py
@@ -0,0 +1,172 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import itertools
+import logging
+from typing import Any, Iterable, Optional, Tuple
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.engines import PostgresEngine
+from synapse.util.iterutils import batch_iter
+
+logger = logging.getLogger(__name__)
+
+
+# This is a special cache name we use to batch multiple invalidations of caches
+# based on the current state when notifying workers over replication.
+CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
+
+
+class CacheInvalidationStore(SQLBaseStore):
+ async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
+ """Invalidates the cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+
+ This should only be used to invalidate caches where slaves won't
+ otherwise know from other replication streams that the cache should
+ be invalidated.
+ """
+ cache_func = getattr(self, cache_name, None)
+ if not cache_func:
+ return
+
+ cache_func.invalidate(keys)
+ await self.runInteraction(
+ "invalidate_cache_and_stream",
+ self._send_invalidation_to_replication,
+ cache_func.__name__,
+ keys,
+ )
+
+ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ """Invalidates the cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+
+ This should only be used to invalidate caches where slaves won't
+ otherwise know from other replication streams that the cache should
+ be invalidated.
+ """
+ txn.call_after(cache_func.invalidate, keys)
+ self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
+
+ def _invalidate_all_cache_and_stream(self, txn, cache_func):
+ """Invalidates the entire cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+ """
+
+ txn.call_after(cache_func.invalidate_all)
+ self._send_invalidation_to_replication(txn, cache_func.__name__, None)
+
+ def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+ """Special case invalidation of caches based on current state.
+
+ We special case this so that we can batch the cache invalidations into a
+ single replication poke.
+
+ Args:
+ txn
+ room_id (str): Room where state changed
+ members_changed (iterable[str]): The user_ids of members that have changed
+ """
+ txn.call_after(self._invalidate_state_caches, room_id, members_changed)
+
+ if members_changed:
+ # We need to be careful that the size of the `members_changed` list
+ # isn't so large that it causes problems sending over replication, so we
+ # send them in chunks.
+ # Max line length is 16K, and max user ID length is 255, so 50 should
+ # be safe.
+ for chunk in batch_iter(members_changed, 50):
+ keys = itertools.chain([room_id], chunk)
+ self._send_invalidation_to_replication(
+ txn, CURRENT_STATE_CACHE_NAME, keys
+ )
+ else:
+ # if no members changed, we still need to invalidate the other caches.
+ self._send_invalidation_to_replication(
+ txn, CURRENT_STATE_CACHE_NAME, [room_id]
+ )
+
+ def _send_invalidation_to_replication(
+ self, txn, cache_name: str, keys: Optional[Iterable[Any]]
+ ):
+ """Notifies replication that given cache has been invalidated.
+
+ Note that this does *not* invalidate the cache locally.
+
+ Args:
+ txn
+ cache_name
+ keys: Entry to invalidate. If None will invalidate all.
+ """
+
+ if cache_name == CURRENT_STATE_CACHE_NAME and keys is None:
+ raise Exception(
+ "Can't stream invalidate all with magic current state cache"
+ )
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # get_next() returns a context manager which is designed to wrap
+ # the transaction. However, we want to only get an ID when we want
+ # to use it, here, so we need to call __enter__ manually, and have
+ # __exit__ called after the transaction finishes.
+ ctx = self._cache_id_gen.get_next()
+ stream_id = ctx.__enter__()
+ txn.call_on_exception(ctx.__exit__, None, None, None)
+ txn.call_after(ctx.__exit__, None, None, None)
+ txn.call_after(self.hs.get_notifier().on_new_replication_data)
+
+ if keys is not None:
+ keys = list(keys)
+
+ self.db.simple_insert_txn(
+ txn,
+ table="cache_invalidation_stream",
+ values={
+ "stream_id": stream_id,
+ "cache_func": cache_name,
+ "keys": keys,
+ "invalidation_ts": self.clock.time_msec(),
+ },
+ )
+
+ def get_all_updated_caches(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_caches_txn(txn):
+ # We purposefully don't bound by the current token, as we want to
+ # send across cache invalidations as quickly as possible. Cache
+ # invalidations are idempotent, so duplicates are fine.
+ sql = (
+ "SELECT stream_id, cache_func, keys, invalidation_ts"
+ " FROM cache_invalidation_stream"
+ " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_updated_caches", get_all_updated_caches_txn
+ )
+
+ def get_cache_stream_token(self):
+ if self._cache_id_gen:
+ return self._cache_id_gen.get_current_token()
+ else:
+ return 0
diff --git a/synapse/storage/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index bda68de5be..e1ccb27142 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -19,11 +19,11 @@ from six import iteritems
from twisted.internet import defer
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.util.caches import CACHE_SIZE_FACTOR
-
-from . import background_updates
-from ._base import Cache
+from synapse.util.caches.descriptors import Cache
logger = logging.getLogger(__name__)
@@ -33,46 +33,41 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
-class ClientIpStore(background_updates.BackgroundUpdateStore):
- def __init__(self, db_conn, hs):
-
- self.client_ip_last_seen = Cache(
- name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
- )
+class ClientIpBackgroundUpdateStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- super(ClientIpStore, self).__init__(db_conn, hs)
-
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"user_ips_device_index",
index_name="user_ips_device_id",
table="user_ips",
columns=["user_id", "device_id", "last_seen"],
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"user_ips_last_seen_index",
index_name="user_ips_last_seen",
table="user_ips",
columns=["user_id", "last_seen"],
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"user_ips_last_seen_only_index",
index_name="user_ips_last_seen_only",
table="user_ips",
columns=["last_seen"],
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"user_ips_analyze", self._analyze_user_ip
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"user_ips_remove_dupes", self._remove_user_ip_dupes
)
# Register a unique index
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"user_ips_device_unique_index",
index_name="user_ips_user_token_ip_unique_index",
table="user_ips",
@@ -81,18 +76,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
)
# Drop the old non-unique index
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
)
- # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
- self._batch_row_update = {}
-
- self._client_ip_looper = self._clock.looping_call(
- self._update_client_ips_batch, 5 * 1000
- )
- self.hs.get_reactor().addSystemEventTrigger(
- "before", "shutdown", self._update_client_ips_batch
+ # Update the last seen info in devices.
+ self.db.updates.register_background_update_handler(
+ "devices_last_seen", self._devices_last_seen_update
)
@defer.inlineCallbacks
@@ -102,9 +92,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
txn.close()
- yield self.runWithConnection(f)
- yield self._end_background_update("user_ips_drop_nonunique_index")
- defer.returnValue(1)
+ yield self.db.runWithConnection(f)
+ yield self.db.updates._end_background_update("user_ips_drop_nonunique_index")
+ return 1
@defer.inlineCallbacks
def _analyze_user_ip(self, progress, batch_size):
@@ -117,11 +107,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def user_ips_analyze(txn):
txn.execute("ANALYZE user_ips")
- yield self.runInteraction("user_ips_analyze", user_ips_analyze)
+ yield self.db.runInteraction("user_ips_analyze", user_ips_analyze)
- yield self._end_background_update("user_ips_analyze")
+ yield self.db.updates._end_background_update("user_ips_analyze")
- defer.returnValue(1)
+ return 1
@defer.inlineCallbacks
def _remove_user_ip_dupes(self, progress, batch_size):
@@ -151,7 +141,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
return None
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
- end_last_seen = yield self.runInteraction(
+ end_last_seen = yield self.db.runInteraction(
"user_ips_dups_get_last_seen", get_last_seen
)
@@ -282,16 +272,120 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
(user_id, access_token, ip, device_id, user_agent, last_seen),
)
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
)
- yield self.runInteraction("user_ips_dups_remove", remove)
+ yield self.db.runInteraction("user_ips_dups_remove", remove)
if last:
- yield self._end_background_update("user_ips_remove_dupes")
+ yield self.db.updates._end_background_update("user_ips_remove_dupes")
+
+ return batch_size
+
+ @defer.inlineCallbacks
+ def _devices_last_seen_update(self, progress, batch_size):
+ """Background update to insert last seen info into devices table
+ """
+
+ last_user_id = progress.get("last_user_id", "")
+ last_device_id = progress.get("last_device_id", "")
+
+ def _devices_last_seen_update_txn(txn):
+ # This consists of two queries:
+ #
+ # 1. The sub-query searches for the next N devices and joins
+ # against user_ips to find the max last_seen associated with
+ # that device.
+ # 2. The outer query then joins again against user_ips on
+ # user/device/last_seen. This *should* hopefully only
+ # return one row, but if it does return more than one then
+ # we'll just end up updating the same device row multiple
+ # times, which is fine.
+
+ if self.database_engine.supports_tuple_comparison:
+ where_clause = "(user_id, device_id) > (?, ?)"
+ where_args = [last_user_id, last_device_id]
+ else:
+ # We explicitly do a `user_id >= ? AND (...)` here to ensure
+ # that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
+ # makes it hard for query optimiser to tell that it can use the
+ # index on user_id
+ where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
+ where_args = [last_user_id, last_user_id, last_device_id]
+
+ sql = """
+ SELECT
+ last_seen, ip, user_agent, user_id, device_id
+ FROM (
+ SELECT
+ user_id, device_id, MAX(u.last_seen) AS last_seen
+ FROM devices
+ INNER JOIN user_ips AS u USING (user_id, device_id)
+ WHERE %(where_clause)s
+ GROUP BY user_id, device_id
+ ORDER BY user_id ASC, device_id ASC
+ LIMIT ?
+ ) c
+ INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
+ """ % {
+ "where_clause": where_clause
+ }
+ txn.execute(sql, where_args + [batch_size])
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ sql = """
+ UPDATE devices
+ SET last_seen = ?, ip = ?, user_agent = ?
+ WHERE user_id = ? AND device_id = ?
+ """
+ txn.execute_batch(sql, rows)
+
+ _, _, _, user_id, device_id = rows[-1]
+ self.db.updates._background_update_progress_txn(
+ txn,
+ "devices_last_seen",
+ {"last_user_id": user_id, "last_device_id": device_id},
+ )
+
+ return len(rows)
+
+ updated = yield self.db.runInteraction(
+ "_devices_last_seen_update", _devices_last_seen_update_txn
+ )
+
+ if not updated:
+ yield self.db.updates._end_background_update("devices_last_seen")
+
+ return updated
+
+
+class ClientIpStore(ClientIpBackgroundUpdateStore):
+ def __init__(self, database: Database, db_conn, hs):
+
+ self.client_ip_last_seen = Cache(
+ name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
+ )
+
+ super(ClientIpStore, self).__init__(database, db_conn, hs)
+
+ self.user_ips_max_age = hs.config.user_ips_max_age
- defer.returnValue(batch_size)
+ # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+ self._batch_row_update = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ self.hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", self._update_client_ips_batch
+ )
+
+ if self.user_ips_max_age:
+ self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks
def insert_client_ip(
@@ -314,23 +408,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
+ @wrap_as_background_process("update_client_ips")
def _update_client_ips_batch(self):
# If the DB pool has already terminated, don't try updating
- if not self.hs.get_db_pool().running:
+ if not self.db.is_running():
return
- def update():
- to_update = self._batch_row_update
- self._batch_row_update = {}
- return self.runInteraction(
- "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
- )
+ to_update = self._batch_row_update
+ self._batch_row_update = {}
- return run_as_background_process("update_client_ips", update)
+ return self.db.runInteraction(
+ "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
+ )
def _update_client_ips_batch_txn(self, txn, to_update):
- if "user_ips" in self._unsafe_to_upsert_tables or (
+ if "user_ips" in self.db._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
self.database_engine.lock_table(txn, "user_ips")
@@ -339,7 +432,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="user_ips",
keyvalues={
@@ -354,6 +447,23 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
},
lock=False,
)
+
+ # Technically an access token might not be associated with
+ # a device so we need to check.
+ if device_id:
+ # this is always an update rather than an upsert: the row should
+ # already exist, and if it doesn't, that may be because it has been
+ # deleted, and we don't want to re-create it.
+ self.db.simple_update_txn(
+ txn,
+ table="devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ updatevalues={
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ "ip": ip,
+ },
+ )
except Exception as e:
# Failed to upsert, log and continue
logger.error("Failed to insert client IP %r: %r", entry, e)
@@ -372,19 +482,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
keys giving the column names
"""
- res = yield self.runInteraction(
- "get_last_client_ip_by_device",
- self._get_last_client_ip_by_device_txn,
- user_id,
- device_id,
- retcols=(
- "user_id",
- "access_token",
- "ip",
- "user_agent",
- "device_id",
- "last_seen",
- ),
+ keyvalues = {"user_id": user_id}
+ if device_id is not None:
+ keyvalues["device_id"] = device_id
+
+ res = yield self.db.simple_select_list(
+ table="devices",
+ keyvalues=keyvalues,
+ retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
)
ret = {(d["user_id"], d["device_id"]): d for d in res}
@@ -401,43 +506,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"device_id": did,
"last_seen": last_seen,
}
- defer.returnValue(ret)
-
- @classmethod
- def _get_last_client_ip_by_device_txn(cls, txn, user_id, device_id, retcols):
- where_clauses = []
- bindings = []
- if device_id is None:
- where_clauses.append("user_id = ?")
- bindings.extend((user_id,))
- else:
- where_clauses.append("(user_id = ? AND device_id = ?)")
- bindings.extend((user_id, device_id))
-
- if not where_clauses:
- return []
-
- inner_select = (
- "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
- "WHERE %(where)s "
- "GROUP BY user_id, device_id"
- ) % {"where": " OR ".join(where_clauses)}
-
- sql = (
- "SELECT %(retcols)s FROM user_ips "
- "JOIN (%(inner_select)s) ips ON"
- " user_ips.last_seen = ips.mls AND"
- " user_ips.user_id = ips.user_id AND"
- " (user_ips.device_id = ips.device_id OR"
- " (user_ips.device_id IS NULL AND ips.device_id IS NULL)"
- " )"
- ) % {
- "retcols": ",".join("user_ips." + c for c in retcols),
- "inner_select": inner_select,
- }
-
- txn.execute(sql, bindings)
- return cls.cursor_to_dict(txn)
+ return ret
@defer.inlineCallbacks
def get_user_ip_and_agents(self, user):
@@ -450,7 +519,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
- rows = yield self._simple_select_list(
+ rows = yield self.db.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
@@ -461,14 +530,56 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
for row in rows
)
- defer.returnValue(
- list(
- {
- "access_token": access_token,
- "ip": ip,
- "user_agent": user_agent,
- "last_seen": last_seen,
- }
- for (access_token, ip), (user_agent, last_seen) in iteritems(results)
+ return [
+ {
+ "access_token": access_token,
+ "ip": ip,
+ "user_agent": user_agent,
+ "last_seen": last_seen,
+ }
+ for (access_token, ip), (user_agent, last_seen) in iteritems(results)
+ ]
+
+ @wrap_as_background_process("prune_old_user_ips")
+ async def _prune_old_user_ips(self):
+ """Removes entries in user IPs older than the configured period.
+ """
+
+ if self.user_ips_max_age is None:
+ # Nothing to do
+ return
+
+ if not await self.db.updates.has_completed_background_update(
+ "devices_last_seen"
+ ):
+ # Only start pruning if we have finished populating the devices
+ # last seen info.
+ return
+
+ # We do a slightly funky SQL delete to ensure we don't try and delete
+ # too much at once (as the table may be very large from before we
+ # started pruning).
+ #
+ # This works by finding the max last_seen that is less than the given
+ # time, but has no more than N rows before it, deleting all rows with
+ # a lesser last_seen time. (We COALESCE so that the sub-SELECT always
+ # returns exactly one row).
+ sql = """
+ DELETE FROM user_ips
+ WHERE last_seen <= (
+ SELECT COALESCE(MAX(last_seen), -1)
+ FROM (
+ SELECT last_seen FROM user_ips
+ WHERE last_seen <= ?
+ ORDER BY last_seen ASC
+ LIMIT 5000
+ ) AS u
)
- )
+ """
+
+ timestamp = self.clock.time_msec() - self.user_ips_max_age
+
+ def _prune_old_user_ips_txn(txn):
+ txn.execute(sql, (timestamp,))
+
+ await self.db.runInteraction("_prune_old_user_ips", _prune_old_user_ips_txn)
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 4ea0deea4f..0613b49f4a 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -19,8 +19,9 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
logger = logging.getLogger(__name__)
@@ -66,12 +67,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages.append(json.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
- return (messages, stream_pos)
+ return messages, stream_pos
- return self.runInteraction(
+ return self.db.runInteraction(
"get_new_messages_for_device", get_new_messages_for_device_txn
)
+ @trace
@defer.inlineCallbacks
def delete_messages_for_device(self, user_id, device_id, up_to_stream_id):
"""
@@ -87,12 +89,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), None
)
+
+ set_tag("last_deleted_stream_id", last_deleted_stream_id)
+
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(
user_id, last_deleted_stream_id
)
if not has_changed:
- defer.returnValue(0)
+ log_kv({"message": "No changes in cache since last check"})
+ return 0
def delete_messages_for_device_txn(txn):
sql = (
@@ -103,10 +109,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount
- count = yield self.runInteraction(
+ count = yield self.db.runInteraction(
"delete_messages_for_device", delete_messages_for_device_txn
)
+ log_kv(
+ {"message": "deleted {} messages for device".format(count), "count": count}
+ )
+
# Update the cache, ensuring that we only ever increase the value
last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
@@ -115,8 +125,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
last_deleted_stream_id, up_to_stream_id
)
- defer.returnValue(count)
+ return count
+ @trace
def get_new_device_msgs_for_remote(
self, destination, last_stream_id, current_stream_id, limit
):
@@ -132,16 +143,23 @@ class DeviceInboxWorkerStore(SQLBaseStore):
in the stream the messages got to.
"""
+ set_tag("destination", destination)
+ set_tag("last_stream_id", last_stream_id)
+ set_tag("current_stream_id", current_stream_id)
+ set_tag("limit", limit)
+
has_changed = self._device_federation_outbox_stream_cache.has_entity_changed(
destination, last_stream_id
)
if not has_changed or last_stream_id == current_stream_id:
+ log_kv({"message": "No new messages in stream"})
return defer.succeed(([], current_stream_id))
if limit <= 0:
# This can happen if we run out of room for EDUs in the transaction.
return defer.succeed(([], last_stream_id))
+ @trace
def get_new_messages_for_remote_destination_txn(txn):
sql = (
"SELECT stream_id, messages_json FROM device_federation_outbox"
@@ -156,14 +174,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
stream_pos = row[0]
messages.append(json.loads(row[1]))
if len(messages) < limit:
+ log_kv({"message": "Set stream position to current position"})
stream_pos = current_stream_id
- return (messages, stream_pos)
+ return messages, stream_pos
- return self.runInteraction(
+ return self.db.runInteraction(
"get_new_device_msgs_for_remote",
get_new_messages_for_remote_destination_txn,
)
+ @trace
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
"""Used to delete messages when the remote destination acknowledges
their receipt.
@@ -183,28 +203,48 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
txn.execute(sql, (destination, up_to_stream_id))
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)
-class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
+class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
- def __init__(self, db_conn, hs):
- super(DeviceInboxStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id",
table="device_inbox",
columns=["stream_id", "user_id"],
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
+ @defer.inlineCallbacks
+ def _background_drop_index_device_inbox(self, progress, batch_size):
+ def reindex_txn(conn):
+ txn = conn.cursor()
+ txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
+ txn.close()
+
+ yield self.db.runWithConnection(reindex_txn)
+
+ yield self.db.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID)
+
+ return 1
+
+
+class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
+ DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceInboxStore, self).__init__(database, db_conn, hs)
+
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
@@ -214,6 +254,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
expiry_ms=30 * 60 * 1000,
)
+ @trace
@defer.inlineCallbacks
def add_messages_to_device_inbox(
self, local_messages_by_user_then_device, remote_messages_by_destination
@@ -253,7 +294,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
for user_id in local_messages_by_user_then_device.keys():
@@ -263,7 +304,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
destination, stream_id
)
- defer.returnValue(self._device_inbox_id_gen.get_current_token())
+ return self._device_inbox_id_gen.get_current_token()
@defer.inlineCallbacks
def add_messages_from_remote_to_device_inbox(
@@ -273,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
- already_inserted = self._simple_select_one_txn(
+ already_inserted = self.db.simple_select_one_txn(
txn,
table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
@@ -285,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
# Add an entry for this message_id so that we know we've processed
# it.
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="device_federation_inbox",
values={
@@ -303,7 +344,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
now_ms,
@@ -312,7 +353,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
for user_id in local_messages_by_user_then_device.keys():
self._device_inbox_stream_cache.entity_has_changed(user_id, stream_id)
- defer.returnValue(stream_id)
+ return stream_id
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
@@ -326,7 +367,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
- sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
+ sql = "SELECT device_id FROM devices WHERE user_id = ?"
txn.execute(sql, (user_id,))
message_json = json.dumps(messages_by_device["*"])
for row in txn:
@@ -337,15 +378,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
else:
if not devices:
continue
- sql = (
- "SELECT device_id FROM devices"
- " WHERE user_id = ? AND device_id IN ("
- + ",".join("?" * len(devices))
- + ")"
+
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "device_id", devices
)
+ sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
+
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
- txn.execute(sql, [user_id] + devices)
+ txn.execute(sql, [user_id] + list(args))
for row in txn:
# Only insert into the local inbox if the device exists on
# this server
@@ -411,19 +452,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
return rows
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
-
- @defer.inlineCallbacks
- def _background_drop_index_device_inbox(self, progress, batch_size):
- def reindex_txn(conn):
- txn = conn.cursor()
- txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
- txn.close()
-
- yield self.runWithConnection(reindex_txn)
-
- yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
-
- defer.returnValue(1)
diff --git a/synapse/storage/devices.py b/synapse/storage/data_stores/main/devices.py
index d102e07372..8af5f7de54 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,11 +22,24 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.api.errors import StoreError
+from synapse.api.errors import Codes, StoreError
+from synapse.logging.opentracing import (
+ get_active_span_text_map,
+ set_tag,
+ trace,
+ whitelisted_homeserver,
+)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import Cache, SQLBaseStore, db_to_json
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
+from synapse.storage.database import Database
+from synapse.types import Collection, get_verify_key_from_cross_signing_key
+from synapse.util.caches.descriptors import (
+ Cache,
+ cached,
+ cachedInlineCallbacks,
+ cachedList,
+)
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@@ -35,7 +50,8 @@ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id, device_id):
- """Retrieve a device.
+ """Retrieve a device. Only returns devices that are not marked as
+ hidden.
Args:
user_id (str): The ID of the user which owns the device
@@ -45,16 +61,17 @@ class DeviceWorkerStore(SQLBaseStore):
Raises:
StoreError: if the device is not found
"""
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_device",
)
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
- """Retrieve all of a user's registered devices.
+ """Retrieve all of a user's registered devices. Only returns devices
+ that are not marked as hidden.
Args:
user_id (str):
@@ -63,23 +80,29 @@ class DeviceWorkerStore(SQLBaseStore):
containing "device_id", "user_id" and "display_name" for each
device.
"""
- devices = yield self._simple_select_list(
+ devices = yield self.db.simple_select_list(
table="devices",
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
desc="get_devices_by_user",
)
- defer.returnValue({d["device_id"]: d for d in devices})
+ return {d["device_id"]: d for d in devices}
+ @trace
@defer.inlineCallbacks
- def get_devices_by_remote(self, destination, from_stream_id, limit):
- """Get stream of updates to send to remote servers
+ def get_device_updates_by_remote(self, destination, from_stream_id, limit):
+ """Get a stream of device updates to send to the given remote server.
+ Args:
+ destination (str): The host the device updates are intended for
+ from_stream_id (int): The minimum stream_id to filter updates by, exclusive
+ limit (int): Maximum number of device updates to return
Returns:
- Deferred[tuple[int, list[dict]]]:
+ Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
- response), and the list of updates
+ response), and the list of updates, where each update is a pair of EDU
+ type and EDU contents
"""
now_stream_id = self._device_list_id_gen.get_current_token()
@@ -87,7 +110,7 @@ class DeviceWorkerStore(SQLBaseStore):
destination, int(from_stream_id)
)
if not has_changed:
- defer.returnValue((now_stream_id, []))
+ return now_stream_id, []
# We retrieve n+1 devices from the list of outbound pokes where n is
# our outbound device update limit. We then check if the very last
@@ -99,9 +122,9 @@ class DeviceWorkerStore(SQLBaseStore):
# consider the device update to be too large, and simply skip the
# stream_id; the rationale being that such a large device list update
# is likely an error.
- updates = yield self.runInteraction(
- "get_devices_by_remote",
- self._get_devices_by_remote_txn,
+ updates = yield self.db.runInteraction(
+ "get_device_updates_by_remote",
+ self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
@@ -110,7 +133,38 @@ class DeviceWorkerStore(SQLBaseStore):
# Return an empty list if there are no updates
if not updates:
- defer.returnValue((now_stream_id, []))
+ return now_stream_id, []
+
+ # get the cross-signing keys of the users in the list, so that we can
+ # determine which of the device changes were cross-signing keys
+ users = {r[0] for r in updates}
+ master_key_by_user = {}
+ self_signing_key_by_user = {}
+ for user in users:
+ cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ master_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
+
+ cross_signing_key = yield self.get_e2e_cross_signing_key(
+ user, "self_signing"
+ )
+ if cross_signing_key:
+ key_id, verify_key = get_verify_key_from_cross_signing_key(
+ cross_signing_key
+ )
+ self_signing_key_by_user[user] = {
+ "key_info": cross_signing_key,
+ "device_id": verify_key.version,
+ }
# if we have exceeded the limit, we need to exclude any results with the
# same stream_id as the last row.
@@ -126,16 +180,43 @@ class DeviceWorkerStore(SQLBaseStore):
# (user_id, device_id) entries into a map, with the value being
# the max stream_id across each set of duplicate entries
#
- # maps (user_id, device_id) -> stream_id
+ # maps (user_id, device_id) -> (stream_id, opentracing_context)
# as long as their stream_id does not match that of the last row
+ #
+ # opentracing_context contains the opentracing metadata for the request
+ # that created the poke
+ #
+ # The most recent request's opentracing_context is used as the
+ # context which created the Edu.
+
query_map = {}
- for update in updates:
- if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
+ cross_signing_keys_by_user = {}
+ for user_id, device_id, update_stream_id, update_context in updates:
+ if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break
- key = (update[0], update[1])
- query_map[key] = max(query_map.get(key, 0), update[2])
+ if (
+ user_id in master_key_by_user
+ and device_id == master_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["master_key"] = master_key_by_user[user_id]["key_info"]
+ elif (
+ user_id in self_signing_key_by_user
+ and device_id == self_signing_key_by_user[user_id]["device_id"]
+ ):
+ result = cross_signing_keys_by_user.setdefault(user_id, {})
+ result["self_signing_key"] = self_signing_key_by_user[user_id][
+ "key_info"
+ ]
+ else:
+ key = (user_id, device_id)
+
+ previous_update_stream_id, _ = query_map.get(key, (0, None))
+
+ if update_stream_id > previous_update_stream_id:
+ query_map[key] = (update_stream_id, update_context)
# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
@@ -145,18 +226,22 @@ class DeviceWorkerStore(SQLBaseStore):
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
- if not query_map:
- defer.returnValue((stream_id_cutoff, []))
+ if not query_map and not cross_signing_keys_by_user:
+ return stream_id_cutoff, []
results = yield self._get_device_update_edus_by_remote(
- destination,
- from_stream_id,
- query_map,
+ destination, from_stream_id, query_map
)
- defer.returnValue((now_stream_id, results))
+ # add the updated cross-signing keys to the results list
+ for user_id, result in iteritems(cross_signing_keys_by_user):
+ result["user_id"] = user_id
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ results.append(("org.matrix.signing_key_update", result))
- def _get_devices_by_remote_txn(
+ return now_stream_id, results
+
+ def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
@@ -171,8 +256,9 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
List: List of device updates
"""
+ # get the list of device updates that need to be sent
sql = """
- SELECT user_id, device_id, stream_id FROM device_lists_outbound_pokes
+ SELECT user_id, device_id, stream_id, opentracing_context FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
ORDER BY stream_id
LIMIT ?
@@ -182,27 +268,30 @@ class DeviceWorkerStore(SQLBaseStore):
return list(txn)
@defer.inlineCallbacks
- def _get_device_update_edus_by_remote(
- self, destination, from_stream_id, query_map,
- ):
+ def _get_device_update_edus_by_remote(self, destination, from_stream_id, query_map):
"""Returns a list of device update EDUs as well as E2EE keys
Args:
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
- query_map (Dict[(str, str): int]): Dictionary mapping
- user_id/device_id to update stream_id
+ query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
+ user_id/device_id to update stream_id and the relevent json-encoded
+ opentracing context
Returns:
List[Dict]: List of objects representing an device update EDU
"""
- devices = yield self.runInteraction(
- "_get_e2e_device_keys_txn",
- self._get_e2e_device_keys_txn,
- query_map.keys(),
- include_all_devices=True,
- include_deleted_devices=True,
+ devices = (
+ yield self.db.runInteraction(
+ "_get_e2e_device_keys_txn",
+ self._get_e2e_device_keys_txn,
+ query_map.keys(),
+ include_all_devices=True,
+ include_deleted_devices=True,
+ )
+ if query_map
+ else {}
)
results = []
@@ -210,15 +299,16 @@ class DeviceWorkerStore(SQLBaseStore):
# The prev_id for the first row is always the last row before
# `from_stream_id`
prev_id = yield self._get_last_device_update_for_remote_user(
- destination, user_id, from_stream_id,
+ destination, user_id, from_stream_id
)
for device_id, device in iteritems(user_devices):
- stream_id = query_map[(user_id, device_id)]
+ stream_id, opentracing_context = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
+ "org.matrix.opentracing_context": opentracing_context,
}
prev_id = stream_id
@@ -227,18 +317,25 @@ class DeviceWorkerStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
+
+ if "signatures" in device:
+ for sig_user_id, sigs in device["signatures"].items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
else:
result["deleted"] = True
- results.append(result)
+ results.append(("m.device_list_update", result))
- defer.returnValue(results)
+ return results
def _get_last_device_update_for_remote_user(
- self, destination, user_id, from_stream_id,
+ self, destination, user_id, from_stream_id
):
def f(txn):
prev_sent_id_sql = """
@@ -250,12 +347,12 @@ class DeviceWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return rows[0][0]
- return self.runInteraction("get_last_device_update_for_remote_user", f)
+ return self.db.runInteraction("get_last_device_update_for_remote_user", f)
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"mark_as_sent_devices_by_remote",
self._mark_as_sent_devices_by_remote_txn,
destination,
@@ -299,9 +396,45 @@ class DeviceWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (destination, stream_id))
+ @defer.inlineCallbacks
+ def add_user_signature_change_to_streams(self, from_user_id, user_ids):
+ """Persist that a user has made new signatures
+
+ Args:
+ from_user_id (str): the user who made the signatures
+ user_ids (list[str]): the users who were signed
+ """
+
+ with self._device_list_id_gen.get_next() as stream_id:
+ yield self.db.runInteraction(
+ "add_user_sig_change_to_streams",
+ self._add_user_signature_change_txn,
+ from_user_id,
+ user_ids,
+ stream_id,
+ )
+ return stream_id
+
+ def _add_user_signature_change_txn(self, txn, from_user_id, user_ids, stream_id):
+ txn.call_after(
+ self._user_signature_stream_cache.entity_has_changed,
+ from_user_id,
+ stream_id,
+ )
+ self.db.simple_insert_txn(
+ txn,
+ "user_signature_stream",
+ values={
+ "stream_id": stream_id,
+ "from_user_id": from_user_id,
+ "user_ids": json.dumps(user_ids),
+ },
+ )
+
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
+ @trace
@defer.inlineCallbacks
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
@@ -315,11 +448,17 @@ class DeviceWorkerStore(SQLBaseStore):
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
- user_ids = set(user_id for user_id, _ in query_list)
+ user_ids = {user_id for user_id, _ in query_list}
user_map = yield self.get_device_list_last_stream_id_for_remotes(list(user_ids))
- user_ids_in_cache = set(
- user_id for user_id, stream_id in user_map.items() if stream_id
+
+ # We go and check if any of the users need to have their device lists
+ # resynced. If they do then we remove them from the cached list.
+ users_needing_resync = yield self.get_user_ids_requiring_device_list_resync(
+ user_ids
)
+ user_ids_in_cache = {
+ user_id for user_id, stream_id in user_map.items() if stream_id
+ } - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
@@ -331,31 +470,34 @@ class DeviceWorkerStore(SQLBaseStore):
device = yield self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
- results[user_id] = yield self._get_cached_devices_for_user(user_id)
+ results[user_id] = yield self.get_cached_devices_for_user(user_id)
- defer.returnValue((user_ids_not_in_cache, results))
+ set_tag("in_cache", results)
+ set_tag("not_in_cache", user_ids_not_in_cache)
+
+ return user_ids_not_in_cache, results
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
- content = yield self._simple_select_one_onecol(
+ content = yield self.db.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
desc="_get_cached_user_device",
)
- defer.returnValue(db_to_json(content))
+ return db_to_json(content)
@cachedInlineCallbacks()
- def _get_cached_devices_for_user(self, user_id):
- devices = yield self._simple_select_list(
+ def get_cached_devices_for_user(self, user_id):
+ devices = yield self.db.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
- desc="_get_cached_devices_for_user",
- )
- defer.returnValue(
- {device["device_id"]: db_to_json(device["content"]) for device in devices}
+ desc="get_cached_devices_for_user",
)
+ return {
+ device["device_id"]: db_to_json(device["content"]) for device in devices
+ }
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
@@ -363,7 +505,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
(stream_id, devices)
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn,
user_id,
@@ -385,6 +527,13 @@ class DeviceWorkerStore(SQLBaseStore):
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
+
+ if "signatures" in device:
+ for sig_user_id, sigs in device["signatures"].items():
+ result["keys"].setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
@@ -395,22 +544,72 @@ class DeviceWorkerStore(SQLBaseStore):
return now_stream_id, []
- @defer.inlineCallbacks
- def get_user_whose_devices_changed(self, from_key):
- """Get set of users whose devices have changed since `from_key`.
+ def get_users_whose_devices_changed(self, from_key, user_ids):
+ """Get set of users whose devices have changed since `from_key` that
+ are in the given list of user_ids.
+
+ Args:
+ from_key (str): The device lists stream token
+ user_ids (Iterable[str])
+
+ Returns:
+ Deferred[set[str]]: The set of user_ids whose devices have changed
+ since `from_key`
"""
from_key = int(from_key)
- changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
- if changed is not None:
- defer.returnValue(set(changed))
- sql = """
- SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
- """
- rows = yield self._execute(
- "get_user_whose_devices_changed", None, sql, from_key
+ # Get set of users who *may* have changed. Users not in the returned
+ # list have definitely not changed.
+ to_check = list(
+ self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+ )
+
+ if not to_check:
+ return defer.succeed(set())
+
+ def _get_users_whose_devices_changed_txn(txn):
+ changes = set()
+
+ sql = """
+ SELECT DISTINCT user_id FROM device_lists_stream
+ WHERE stream_id > ?
+ AND
+ """
+
+ for chunk in batch_iter(to_check, 100):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", chunk
+ )
+ txn.execute(sql + clause, (from_key,) + tuple(args))
+ changes.update(user_id for user_id, in txn)
+
+ return changes
+
+ return self.db.runInteraction(
+ "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
)
- defer.returnValue(set(row[0] for row in rows))
+
+ @defer.inlineCallbacks
+ def get_users_whose_signatures_changed(self, user_id, from_key):
+ """Get the users who have new cross-signing signatures made by `user_id` since
+ `from_key`.
+
+ Args:
+ user_id (str): the user who made the signatures
+ from_key (str): The device lists stream token
+ """
+ from_key = int(from_key)
+ if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
+ sql = """
+ SELECT DISTINCT user_ids FROM user_signature_stream
+ WHERE from_user_id = ? AND stream_id > ?
+ """
+ rows = yield self.db.execute(
+ "get_users_whose_signatures_changed", None, sql, user_id, from_key
+ )
+ return {user for row in rows for user in json.loads(row[0])}
+ else:
+ return set()
def get_all_device_list_changes_for_remotes(self, from_key, to_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
@@ -426,7 +625,7 @@ class DeviceWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
- return self._execute(
+ return self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@@ -435,7 +634,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@@ -449,7 +648,7 @@ class DeviceWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
- rows = yield self._simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@@ -460,22 +659,45 @@ class DeviceWorkerStore(SQLBaseStore):
results = {user_id: None for user_id in user_ids}
results.update({row["user_id"]: row["stream_id"] for row in rows})
- defer.returnValue(results)
+ return results
+ @defer.inlineCallbacks
+ def get_user_ids_requiring_device_list_resync(self, user_ids: Collection[str]):
+ """Given a list of remote users return the list of users that we
+ should resync the device lists for.
-class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
- def __init__(self, db_conn, hs):
- super(DeviceStore, self).__init__(db_conn, hs)
+ Returns:
+ Deferred[Set[str]]
+ """
- # Map of (user_id, device_id) -> bool. If there is an entry that implies
- # the device exists.
- self.device_id_exists_cache = Cache(
- name="device_id_exists", keylen=2, max_entries=10000
+ rows = yield self.db.simple_select_many_batch(
+ table="device_lists_remote_resync",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="get_user_ids_requiring_device_list_resync",
+ )
+
+ return {row["user_id"] for row in rows}
+
+ def mark_remote_user_device_cache_as_stale(self, user_id: str):
+ """Records that the server has reason to believe the cache of the devices
+ for the remote users is out of date.
+ """
+ return self.db.simple_upsert(
+ table="device_lists_remote_resync",
+ keyvalues={"user_id": user_id},
+ values={},
+ insertion_values={"added_ts": self._clock.time_msec()},
+ desc="make_remote_user_device_cache_as_stale",
)
- self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
- self.register_background_index_update(
+class DeviceBackgroundUpdateStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+ self.db.updates.register_background_index_update(
"device_lists_stream_idx",
index_name="device_lists_stream_user_id",
table="device_lists_stream",
@@ -483,7 +705,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
# create a unique index on device_lists_remote_cache
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"device_lists_remote_cache_unique_idx",
index_name="device_lists_remote_cache_unique_id",
table="device_lists_remote_cache",
@@ -492,7 +714,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
# And one on device_lists_remote_extremeties
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"device_lists_remote_extremeties_unique_idx",
index_name="device_lists_remote_extremeties_unique_idx",
table="device_lists_remote_extremeties",
@@ -501,12 +723,39 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
# once they complete, we can remove the old non-unique indexes.
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES,
self._drop_device_list_streams_non_unique_indexes,
)
@defer.inlineCallbacks
+ def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
+ def f(conn):
+ txn = conn.cursor()
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
+ txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
+ txn.close()
+
+ yield self.db.runWithConnection(f)
+ yield self.db.updates._end_background_update(
+ DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES
+ )
+ return 1
+
+
+class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(DeviceStore, self).__init__(database, db_conn, hs)
+
+ # Map of (user_id, device_id) -> bool. If there is an entry that implies
+ # the device exists.
+ self.device_id_exists_cache = Cache(
+ name="device_id_exists", keylen=2, max_entries=10000
+ )
+
+ self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
+
+ @defer.inlineCallbacks
def store_device(self, user_id, device_id, initial_device_display_name):
"""Ensure the given device is known; add it to the store if not
@@ -518,24 +767,39 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
defer.Deferred: boolean whether the device was inserted or an
existing device existed with that ID.
+ Raises:
+ StoreError: if the device is already in use
"""
key = (user_id, device_id)
if self.device_id_exists_cache.get(key, None):
- defer.returnValue(False)
+ return False
try:
- inserted = yield self._simple_insert(
+ inserted = yield self.db.simple_insert(
"devices",
values={
"user_id": user_id,
"device_id": device_id,
"display_name": initial_device_display_name,
+ "hidden": False,
},
desc="store_device",
or_ignore=True,
)
+ if not inserted:
+ # if the device already exists, check if it's a real device, or
+ # if the device ID is reserved by something else
+ hidden = yield self.db.simple_select_one_onecol(
+ "devices",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="hidden",
+ )
+ if hidden:
+ raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
self.device_id_exists_cache.prefill(key, True)
- defer.returnValue(inserted)
+ return inserted
+ except StoreError:
+ raise
except Exception as e:
logger.error(
"store_device with device_id=%s(%r) user_id=%s(%r)"
@@ -560,9 +824,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
defer.Deferred
"""
- yield self._simple_delete_one(
+ yield self.db.simple_delete_one(
table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device",
)
@@ -578,18 +842,19 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
defer.Deferred
"""
- yield self._simple_delete_many(
+ yield self.db.simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
- keyvalues={"user_id": user_id},
+ keyvalues={"user_id": user_id, "hidden": False},
desc="delete_devices",
)
for device_id in device_ids:
self.device_id_exists_cache.invalidate((user_id, device_id))
def update_device(self, user_id, device_id, new_display_name=None):
- """Update a device.
+ """Update a device. Only updates the device if it is not marked as
+ hidden.
Args:
user_id (str): The ID of the user which owns the device
@@ -606,9 +871,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="devices",
- keyvalues={"user_id": user_id, "device_id": device_id},
+ keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates,
desc="update_device",
)
@@ -617,7 +882,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
- yield self._simple_delete(
+ yield self.db.simple_delete(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
desc="mark_remote_user_device_list_as_unsubscribed",
@@ -641,7 +906,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
Deferred[None]
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id,
@@ -654,7 +919,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self, txn, user_id, device_id, content, stream_id
):
if content.get("deleted"):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -662,7 +927,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else:
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -673,12 +938,12 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id))
- txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -702,7 +967,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
Returns:
Deferred[None]
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id,
@@ -711,11 +976,11 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
@@ -728,13 +993,13 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
],
)
- txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
+ txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,))
txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,))
txn.call_after(
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -744,13 +1009,20 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
lock=False,
)
+ # If we're replacing the remote user's device list cache presumably
+ # we've done a full resync, so we remove the entry that says we need
+ # to resync
+ self.db.simple_delete_txn(
+ txn, table="device_lists_remote_resync", keyvalues={"user_id": user_id},
+ )
+
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_device_change_to_streams",
self._add_device_change_txn,
user_id,
@@ -758,7 +1030,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
hosts,
stream_id,
)
- defer.returnValue(stream_id)
+ return stream_id
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
@@ -783,7 +1055,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
[(user_id, device_id, stream_id) for device_id in device_ids],
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
@@ -792,7 +1064,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
],
)
- self._simple_insert_many_txn(
+ context = get_active_span_text_map()
+
+ self.db.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
@@ -803,6 +1077,9 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"device_id": device_id,
"sent": False,
"ts": now,
+ "opentracing_context": json.dumps(context)
+ if whitelisted_homeserver(destination)
+ else "{}",
}
for destination in hosts
for device_id in device_ids
@@ -852,19 +1129,7 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
return run_as_background_process(
"prune_old_outbound_device_pokes",
- self.runInteraction,
+ self.db.runInteraction,
"_prune_old_outbound_device_pokes",
_prune_txn,
)
-
- @defer.inlineCallbacks
- def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
- def f(conn):
- txn = conn.cursor()
- txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
- txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
- txn.close()
-
- yield self.runWithConnection(f)
- yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
- defer.returnValue(1)
diff --git a/synapse/storage/directory.py b/synapse/storage/data_stores/main/directory.py
index 201bbd430c..c9e7de7d12 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -18,10 +18,9 @@ from collections import namedtuple
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
-
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
@@ -37,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore):
Deferred: results in namedtuple with keys "room_id" and
"servers" or None if no association can be found
"""
- room_id = yield self._simple_select_one_onecol(
+ room_id = yield self.db.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
@@ -46,10 +45,9 @@ class DirectoryWorkerStore(SQLBaseStore):
)
if not room_id:
- defer.returnValue(None)
- return
+ return None
- servers = yield self._simple_select_onecol(
+ servers = yield self.db.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
@@ -57,13 +55,12 @@ class DirectoryWorkerStore(SQLBaseStore):
)
if not servers:
- defer.returnValue(None)
- return
+ return None
- defer.returnValue(RoomAliasMapping(room_id, room_alias.to_string(), servers))
+ return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@@ -72,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore):
@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
- return self._simple_select_onecol(
+ return self.db.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
@@ -96,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore):
"""
def alias_txn(txn):
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"room_aliases",
{
@@ -106,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore):
},
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[
@@ -120,20 +117,22 @@ class DirectoryStore(DirectoryWorkerStore):
)
try:
- ret = yield self.runInteraction("create_room_alias_association", alias_txn)
+ ret = yield self.db.runInteraction(
+ "create_room_alias_association", alias_txn
+ )
except self.database_engine.module.IntegrityError:
raise SynapseError(
409, "Room alias %s already exists" % room_alias.to_string()
)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def delete_room_alias(self, room_alias):
- room_id = yield self.runInteraction(
+ room_id = yield self.db.runInteraction(
"delete_room_alias", self._delete_room_alias_txn, room_alias
)
- defer.returnValue(room_id)
+ return room_id
def _delete_room_alias_txn(self, txn, room_alias):
txn.execute(
@@ -171,6 +170,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
- return self.runInteraction(
+ return self.db.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)
diff --git a/synapse/storage/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 521936e3b0..84594cf0a9 100644
--- a/synapse/storage/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,55 +19,14 @@ import json
from twisted.internet import defer
from synapse.api.errors import StoreError
-
-from ._base import SQLBaseStore
+from synapse.logging.opentracing import log_kv, trace
+from synapse.storage._base import SQLBaseStore
class EndToEndRoomKeyStore(SQLBaseStore):
@defer.inlineCallbacks
- def get_e2e_room_key(self, user_id, version, room_id, session_id):
- """Get the encrypted E2E room key for a given session from a given
- backup version of room_keys. We only store the 'best' room key for a given
- session at a given time, as determined by the handler.
-
- Args:
- user_id(str): the user whose backup we're querying
- version(str): the version ID of the backup for the set of keys we're querying
- room_id(str): the ID of the room whose keys we're querying.
- This is a bit redundant as it's implied by the session_id, but
- we include for consistency with the rest of the API.
- session_id(str): the session whose room_key we're querying.
-
- Returns:
- A deferred dict giving the session_data and message metadata for
- this room key.
- """
-
- row = yield self._simple_select_one(
- table="e2e_room_keys",
- keyvalues={
- "user_id": user_id,
- "version": version,
- "room_id": room_id,
- "session_id": session_id,
- },
- retcols=(
- "first_message_index",
- "forwarded_count",
- "is_verified",
- "session_data",
- ),
- desc="get_e2e_room_key",
- )
-
- row["session_data"] = json.loads(row["session_data"])
-
- defer.returnValue(row)
-
- @defer.inlineCallbacks
- def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
- """Replaces or inserts the encrypted E2E room key for a given session in
- a given backup
+ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+ """Replaces the encrypted E2E room key for a given session in a given backup
Args:
user_id(str): the user whose backup we're setting
@@ -78,34 +38,73 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
- yield self._simple_upsert(
+ yield self.db.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
+ "version": version,
"room_id": room_id,
"session_id": session_id,
},
- values={
- "version": version,
- "first_message_index": room_key['first_message_index'],
- "forwarded_count": room_key['forwarded_count'],
- "is_verified": room_key['is_verified'],
- "session_data": json.dumps(room_key['session_data']),
+ updatevalues={
+ "first_message_index": room_key["first_message_index"],
+ "forwarded_count": room_key["forwarded_count"],
+ "is_verified": room_key["is_verified"],
+ "session_data": json.dumps(room_key["session_data"]),
},
- lock=False,
+ desc="update_e2e_room_key",
+ )
+
+ @defer.inlineCallbacks
+ def add_e2e_room_keys(self, user_id, version, room_keys):
+ """Bulk add room keys to a given backup.
+
+ Args:
+ user_id (str): the user whose backup we're adding to
+ version (str): the version ID of the backup for the set of keys we're adding to
+ room_keys (iterable[(str, str, dict)]): the keys to add, in the form
+ (roomID, sessionID, keyData)
+ """
+
+ values = []
+ for (room_id, session_id, room_key) in room_keys:
+ values.append(
+ {
+ "user_id": user_id,
+ "version": version,
+ "room_id": room_id,
+ "session_id": session_id,
+ "first_message_index": room_key["first_message_index"],
+ "forwarded_count": room_key["forwarded_count"],
+ "is_verified": room_key["is_verified"],
+ "session_data": json.dumps(room_key["session_data"]),
+ }
+ )
+ log_kv(
+ {
+ "message": "Set room key",
+ "room_id": room_id,
+ "session_id": session_id,
+ "room_key": room_key,
+ }
+ )
+
+ yield self.db.simple_insert_many(
+ table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
+ @trace
@defer.inlineCallbacks
def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
room, or a given session.
Args:
- user_id(str): the user whose backup we're querying
- version(str): the version ID of the backup for the set of keys we're querying
- room_id(str): Optional. the ID of the room whose keys we're querying, if any.
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup for the set of keys we're querying
+ room_id (str): Optional. the ID of the room whose keys we're querying, if any.
If not specified, we return the keys for all the rooms in the backup.
- session_id(str): Optional. the session whose room_key we're querying, if any.
+ session_id (str): Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we return all the keys in this version of
the backup (or for the specified room)
@@ -118,15 +117,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
try:
version = int(version)
except ValueError:
- defer.returnValue({'rooms': {}})
+ return {"rooms": {}}
keyvalues = {"user_id": user_id, "version": version}
if room_id:
- keyvalues['room_id'] = room_id
+ keyvalues["room_id"] = room_id
if session_id:
- keyvalues['session_id'] = session_id
+ keyvalues["session_id"] = session_id
- rows = yield self._simple_select_list(
+ rows = yield self.db.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
@@ -141,18 +140,108 @@ class EndToEndRoomKeyStore(SQLBaseStore):
desc="get_e2e_room_keys",
)
- sessions = {'rooms': {}}
+ sessions = {"rooms": {}}
for row in rows:
- room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}})
- room_entry['sessions'][row['session_id']] = {
+ room_entry = sessions["rooms"].setdefault(row["room_id"], {"sessions": {}})
+ room_entry["sessions"][row["session_id"]] = {
"first_message_index": row["first_message_index"],
"forwarded_count": row["forwarded_count"],
"is_verified": row["is_verified"],
"session_data": json.loads(row["session_data"]),
}
- defer.returnValue(sessions)
+ return sessions
+
+ def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+ """Get multiple room keys at a time. The difference between this function and
+ get_e2e_room_keys is that this function can be used to retrieve
+ multiple specific keys at a time, whereas get_e2e_room_keys is used for
+ getting all the keys in a backup version, all the keys for a room, or a
+ specific key.
+
+ Args:
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup we're querying about
+ room_keys (dict[str, dict[str, iterable[str]]]): a map from
+ room ID -> {"session": [session ids]} indicating the session IDs
+ that we want to query
+
+ Returns:
+ Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
+ """
+
+ return self.db.runInteraction(
+ "get_e2e_room_keys_multi",
+ self._get_e2e_room_keys_multi_txn,
+ user_id,
+ version,
+ room_keys,
+ )
+
+ @staticmethod
+ def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+ if not room_keys:
+ return {}
+
+ where_clauses = []
+ params = [user_id, version]
+ for room_id, room in room_keys.items():
+ sessions = list(room["sessions"])
+ if not sessions:
+ continue
+ params.append(room_id)
+ params.extend(sessions)
+ where_clauses.append(
+ "(room_id = ? AND session_id IN (%s))"
+ % (",".join(["?" for _ in sessions]),)
+ )
+
+ # check if we're actually querying something
+ if not where_clauses:
+ return {}
+
+ sql = """
+ SELECT room_id, session_id, first_message_index, forwarded_count,
+ is_verified, session_data
+ FROM e2e_room_keys
+ WHERE user_id = ? AND version = ? AND (%s)
+ """ % (
+ " OR ".join(where_clauses)
+ )
+
+ txn.execute(sql, params)
+
+ ret = {}
+
+ for row in txn:
+ room_id = row[0]
+ session_id = row[1]
+ ret.setdefault(room_id, {})
+ ret[room_id][session_id] = {
+ "first_message_index": row[2],
+ "forwarded_count": row[3],
+ "is_verified": row[4],
+ "session_data": json.loads(row[5]),
+ }
+
+ return ret
+
+ def count_e2e_room_keys(self, user_id, version):
+ """Get the number of keys in a backup version.
+
+ Args:
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup we're querying about
+ """
+
+ return self.db.simple_select_one_onecol(
+ table="e2e_room_keys",
+ keyvalues={"user_id": user_id, "version": version},
+ retcol="COUNT(*)",
+ desc="count_e2e_room_keys",
+ )
+ @trace
@defer.inlineCallbacks
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
@@ -174,11 +263,11 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues = {"user_id": user_id, "version": int(version)}
if room_id:
- keyvalues['room_id'] = room_id
+ keyvalues["room_id"] = room_id
if session_id:
- keyvalues['session_id'] = session_id
+ keyvalues["session_id"] = session_id
- yield self._simple_delete(
+ yield self.db.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@@ -191,7 +280,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
row = txn.fetchone()
if not row:
- raise StoreError(404, 'No current backup version')
+ raise StoreError(404, "No current backup version")
return row[0]
def get_e2e_room_keys_version_info(self, user_id, version=None):
@@ -209,6 +298,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version(str)
algorithm(str)
auth_data(object): opaque dict supplied by the client
+ etag(int): tag of the keys in the backup
"""
def _get_e2e_room_keys_version_info_txn(txn):
@@ -222,20 +312,23 @@ class EndToEndRoomKeyStore(SQLBaseStore):
# it isn't there.
raise StoreError(404, "No row found")
- result = self._simple_select_one_txn(
+ result = self.db.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
- retcols=("version", "algorithm", "auth_data"),
+ retcols=("version", "algorithm", "auth_data", "etag"),
)
result["auth_data"] = json.loads(result["auth_data"])
result["version"] = str(result["version"])
+ if result["etag"] is None:
+ result["etag"] = 0
return result
- return self.runInteraction(
+ return self.db.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
+ @trace
def create_e2e_room_keys_version(self, user_id, info):
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
@@ -255,11 +348,11 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
current_version = txn.fetchone()[0]
if current_version is None:
- current_version = '0'
+ current_version = "0"
new_version = str(int(current_version) + 1)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="e2e_room_keys_versions",
values={
@@ -272,26 +365,40 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return new_version
- return self.runInteraction(
+ return self.db.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
- def update_e2e_room_keys_version(self, user_id, version, info):
+ @trace
+ def update_e2e_room_keys_version(
+ self, user_id, version, info=None, version_etag=None
+ ):
"""Update a given backup version
Args:
user_id(str): the user whose backup version we're updating
version(str): the version ID of the backup version we're updating
- info(dict): the new backup version info to store
+ info (dict): the new backup version info to store. If None, then
+ the backup version info is not updated
+ version_etag (Optional[int]): etag of the keys in the backup. If
+ None, then the etag is not updated
"""
+ updatevalues = {}
- return self._simple_update(
- table="e2e_room_keys_versions",
- keyvalues={"user_id": user_id, "version": version},
- updatevalues={"auth_data": json.dumps(info["auth_data"])},
- desc="update_e2e_room_keys_version",
- )
+ if info is not None and "auth_data" in info:
+ updatevalues["auth_data"] = json.dumps(info["auth_data"])
+ if version_etag is not None:
+ updatevalues["etag"] = version_etag
+
+ if updatevalues:
+ return self.db.simple_update(
+ table="e2e_room_keys_versions",
+ keyvalues={"user_id": user_id, "version": version},
+ updatevalues=updatevalues,
+ desc="update_e2e_room_keys_version",
+ )
+ @trace
def delete_e2e_room_keys_version(self, user_id, version=None):
"""Delete a given backup version of the user's room keys.
Doesn't delete their actual key data.
@@ -308,16 +415,24 @@ class EndToEndRoomKeyStore(SQLBaseStore):
def _delete_e2e_room_keys_version_txn(txn):
if version is None:
this_version = self._get_current_version(txn, user_id)
+ if this_version is None:
+ raise StoreError(404, "No current backup version")
else:
this_version = version
- return self._simple_update_one_txn(
+ self.db.simple_delete_txn(
+ txn,
+ table="e2e_room_keys",
+ keyvalues={"user_id": user_id, "version": this_version},
+ )
+
+ return self.db.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
new file mode 100644
index 0000000000..001a53f9b4
--- /dev/null
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -0,0 +1,771 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Dict, List
+
+from six import iteritems
+
+from canonicaljson import encode_canonical_json, json
+
+from twisted.enterprise.adbapi import Connection
+from twisted.internet import defer
+
+from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.util.caches.descriptors import cached, cachedList
+
+
+class EndToEndKeyWorkerStore(SQLBaseStore):
+ @trace
+ @defer.inlineCallbacks
+ def get_e2e_device_keys(
+ self, query_list, include_all_devices=False, include_deleted_devices=False
+ ):
+ """Fetch a list of device keys.
+ Args:
+ query_list(list): List of pairs of user_ids and device_ids.
+ include_all_devices (bool): whether to include entries for devices
+ that don't have device keys
+ include_deleted_devices (bool): whether to include null entries for
+ devices which no longer exist (but were in the query_list).
+ This option only takes effect if include_all_devices is true.
+ Returns:
+ Dict mapping from user-id to dict mapping from device_id to
+ key data. The key data will be a dict in the same format as the
+ DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
+ """
+ set_tag("query_list", query_list)
+ if not query_list:
+ return {}
+
+ results = yield self.db.runInteraction(
+ "get_e2e_device_keys",
+ self._get_e2e_device_keys_txn,
+ query_list,
+ include_all_devices,
+ include_deleted_devices,
+ )
+
+ # Build the result structure, un-jsonify the results, and add the
+ # "unsigned" section
+ rv = {}
+ for user_id, device_keys in iteritems(results):
+ rv[user_id] = {}
+ for device_id, device_info in iteritems(device_keys):
+ r = db_to_json(device_info.pop("key_json"))
+ r["unsigned"] = {}
+ display_name = device_info["device_display_name"]
+ if display_name is not None:
+ r["unsigned"]["device_display_name"] = display_name
+ if "signatures" in device_info:
+ for sig_user_id, sigs in device_info["signatures"].items():
+ r.setdefault("signatures", {}).setdefault(
+ sig_user_id, {}
+ ).update(sigs)
+ rv[user_id][device_id] = r
+
+ return rv
+
+ @trace
+ def _get_e2e_device_keys_txn(
+ self, txn, query_list, include_all_devices=False, include_deleted_devices=False
+ ):
+ set_tag("include_all_devices", include_all_devices)
+ set_tag("include_deleted_devices", include_deleted_devices)
+
+ query_clauses = []
+ query_params = []
+ signature_query_clauses = []
+ signature_query_params = []
+
+ if include_all_devices is False:
+ include_deleted_devices = False
+
+ if include_deleted_devices:
+ deleted_devices = set(query_list)
+
+ for (user_id, device_id) in query_list:
+ query_clause = "user_id = ?"
+ query_params.append(user_id)
+ signature_query_clause = "target_user_id = ?"
+ signature_query_params.append(user_id)
+
+ if device_id is not None:
+ query_clause += " AND device_id = ?"
+ query_params.append(device_id)
+ signature_query_clause += " AND target_device_id = ?"
+ signature_query_params.append(device_id)
+
+ signature_query_clause += " AND user_id = ?"
+ signature_query_params.append(user_id)
+
+ query_clauses.append(query_clause)
+ signature_query_clauses.append(signature_query_clause)
+
+ sql = (
+ "SELECT user_id, device_id, "
+ " d.display_name AS device_display_name, "
+ " k.key_json"
+ " FROM devices d"
+ " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
+ " WHERE %s AND NOT d.hidden"
+ ) % (
+ "LEFT" if include_all_devices else "INNER",
+ " OR ".join("(" + q + ")" for q in query_clauses),
+ )
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ result = {}
+ for row in rows:
+ if include_deleted_devices:
+ deleted_devices.remove((row["user_id"], row["device_id"]))
+ result.setdefault(row["user_id"], {})[row["device_id"]] = row
+
+ if include_deleted_devices:
+ for user_id, device_id in deleted_devices:
+ result.setdefault(user_id, {})[device_id] = None
+
+ # get signatures on the device
+ signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
+ " OR ".join("(" + q + ")" for q in signature_query_clauses)
+ )
+
+ txn.execute(signature_sql, signature_query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ # add each cross-signing signature to the correct device in the result dict.
+ for row in rows:
+ signing_user_id = row["user_id"]
+ signing_key_id = row["key_id"]
+ target_user_id = row["target_user_id"]
+ target_device_id = row["target_device_id"]
+ signature = row["signature"]
+
+ target_user_result = result.get(target_user_id)
+ if not target_user_result:
+ continue
+
+ target_device_result = target_user_result.get(target_device_id)
+ if not target_device_result:
+ # note that target_device_result will be None for deleted devices.
+ continue
+
+ target_device_signatures = target_device_result.setdefault("signatures", {})
+ signing_user_signatures = target_device_signatures.setdefault(
+ signing_user_id, {}
+ )
+ signing_user_signatures[signing_key_id] = signature
+
+ log_kv(result)
+ return result
+
+ @defer.inlineCallbacks
+ def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+ """Retrieve a number of one-time keys for a user
+
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ key_ids(list[str]): list of key ids (excluding algorithm) to
+ retrieve
+
+ Returns:
+ deferred resolving to Dict[(str, str), str]: map from (algorithm,
+ key_id) to json string for key
+ """
+
+ rows = yield self.db.simple_select_many_batch(
+ table="e2e_one_time_keys_json",
+ column="key_id",
+ iterable=key_ids,
+ retcols=("algorithm", "key_id", "key_json"),
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ desc="add_e2e_one_time_keys_check",
+ )
+ result = {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
+ log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
+ return result
+
+ @defer.inlineCallbacks
+ def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+ """Insert some new one time keys for a device. Errors if any of the
+ keys already exist.
+
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ time_now(long): insertion time to record (ms since epoch)
+ new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
+ (algorithm, key_id, key json)
+ """
+
+ def _add_e2e_one_time_keys(txn):
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("new_keys", new_keys)
+ # We are protected from race between lookup and insertion due to
+ # a unique constraint. If there is a race of two calls to
+ # `add_e2e_one_time_keys` then they'll conflict and we will only
+ # insert one set.
+ self.db.simple_insert_many_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ values=[
+ {
+ "user_id": user_id,
+ "device_id": device_id,
+ "algorithm": algorithm,
+ "key_id": key_id,
+ "ts_added_ms": time_now,
+ "key_json": json_bytes,
+ }
+ for algorithm, key_id, json_bytes in new_keys
+ ],
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ yield self.db.runInteraction(
+ "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
+ )
+
+ @cached(max_entries=10000)
+ def count_e2e_one_time_keys(self, user_id, device_id):
+ """ Count the number of one time keys the server has for a device
+ Returns:
+ Dict mapping from algorithm to number of keys for that algorithm.
+ """
+
+ def _count_e2e_one_time_keys(txn):
+ sql = (
+ "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ?"
+ " GROUP BY algorithm"
+ )
+ txn.execute(sql, (user_id, device_id))
+ result = {}
+ for algorithm, key_count in txn:
+ result[algorithm] = key_count
+ return result
+
+ return self.db.runInteraction(
+ "count_e2e_one_time_keys", _count_e2e_one_time_keys
+ )
+
+ def _get_e2e_cross_signing_key_txn(self, txn, user_id, key_type, from_user_id=None):
+ """Returns a user's cross-signing key.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ user_id (str): the user whose key is being requested
+ key_type (str): the type of key that is being requested: either 'master'
+ for a master key, 'self_signing' for a self-signing key, or
+ 'user_signing' for a user-signing key
+ from_user_id (str): if specified, signatures made by this user on
+ the key will be included in the result
+
+ Returns:
+ dict of the key data or None if not found
+ """
+ sql = (
+ "SELECT keydata "
+ " FROM e2e_cross_signing_keys "
+ " WHERE user_id = ? AND keytype = ? ORDER BY stream_id DESC LIMIT 1"
+ )
+ txn.execute(sql, (user_id, key_type))
+ row = txn.fetchone()
+ if not row:
+ return None
+ key = json.loads(row[0])
+
+ device_id = None
+ for k in key["keys"].values():
+ device_id = k
+
+ if from_user_id is not None:
+ sql = (
+ "SELECT key_id, signature "
+ " FROM e2e_cross_signing_signatures "
+ " WHERE user_id = ? "
+ " AND target_user_id = ? "
+ " AND target_device_id = ? "
+ )
+ txn.execute(sql, (from_user_id, user_id, device_id))
+ row = txn.fetchone()
+ if row:
+ key.setdefault("signatures", {}).setdefault(from_user_id, {})[
+ row[0]
+ ] = row[1]
+
+ return key
+
+ def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
+ """Returns a user's cross-signing key.
+
+ Args:
+ user_id (str): the user whose key is being requested
+ key_type (str): the type of key that is being requested: either 'master'
+ for a master key, 'self_signing' for a self-signing key, or
+ 'user_signing' for a user-signing key
+ from_user_id (str): if specified, signatures made by this user on
+ the self-signing key will be included in the result
+
+ Returns:
+ dict of the key data or None if not found
+ """
+ return self.db.runInteraction(
+ "get_e2e_cross_signing_key",
+ self._get_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ from_user_id,
+ )
+
+ @cached(num_args=1)
+ def _get_bare_e2e_cross_signing_keys(self, user_id):
+ """Dummy function. Only used to make a cache for
+ _get_bare_e2e_cross_signing_keys_bulk.
+ """
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_bare_e2e_cross_signing_keys",
+ list_name="user_ids",
+ num_args=1,
+ )
+ def _get_bare_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str]
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing keys for a set of users. The output of this
+ function should be passed to _get_e2e_cross_signing_signatures_txn if
+ the signatures for the calling user need to be fetched.
+
+ Args:
+ user_ids (list[str]): the users whose keys are being requested
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. If a user's cross-signing keys were not found, either
+ their user ID will not be in the dict, or their user ID will map
+ to None.
+
+ """
+ return self.db.runInteraction(
+ "get_bare_e2e_cross_signing_keys_bulk",
+ self._get_bare_e2e_cross_signing_keys_bulk_txn,
+ user_ids,
+ )
+
+ def _get_bare_e2e_cross_signing_keys_bulk_txn(
+ self, txn: Connection, user_ids: List[str],
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing keys for a set of users. The output of this
+ function should be passed to _get_e2e_cross_signing_signatures_txn if
+ the signatures for the calling user need to be fetched.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ user_ids (list[str]): the users whose keys are being requested
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. If a user's cross-signing keys were not found, their user
+ ID will not be in the dict.
+
+ """
+ result = {}
+
+ batch_size = 100
+ chunks = [
+ user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
+ ]
+ for user_chunk in chunks:
+ sql = """
+ SELECT k.user_id, k.keytype, k.keydata, k.stream_id
+ FROM e2e_cross_signing_keys k
+ INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
+ FROM e2e_cross_signing_keys
+ GROUP BY user_id, keytype) s
+ USING (user_id, stream_id, keytype)
+ WHERE k.user_id IN (%s)
+ """ % (
+ ",".join("?" for u in user_chunk),
+ )
+ query_params = []
+ query_params.extend(user_chunk)
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ for row in rows:
+ user_id = row["user_id"]
+ key_type = row["keytype"]
+ key = json.loads(row["keydata"])
+ user_info = result.setdefault(user_id, {})
+ user_info[key_type] = key
+
+ return result
+
+ def _get_e2e_cross_signing_signatures_txn(
+ self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+ ) -> Dict[str, Dict[str, dict]]:
+ """Returns the cross-signing signatures made by a user on a set of keys.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ keys (dict[str, dict[str, dict]]): a map of user ID to key type to
+ key data. This dict will be modified to add signatures.
+ from_user_id (str): fetch the signatures made by this user
+
+ Returns:
+ dict[str, dict[str, dict]]: mapping from user ID to key type to key
+ data. The return value will be the same as the keys argument,
+ with the modifications included.
+ """
+
+ # find out what cross-signing keys (a.k.a. devices) we need to get
+ # signatures for. This is a map of (user_id, device_id) to key type
+ # (device_id is the key's public part).
+ devices = {}
+
+ for user_id, user_info in keys.items():
+ if user_info is None:
+ continue
+ for key_type, key in user_info.items():
+ device_id = None
+ for k in key["keys"].values():
+ device_id = k
+ devices[(user_id, device_id)] = key_type
+
+ device_list = list(devices)
+
+ # split into batches
+ batch_size = 100
+ chunks = [
+ device_list[i : i + batch_size]
+ for i in range(0, len(device_list), batch_size)
+ ]
+ for user_chunk in chunks:
+ sql = """
+ SELECT target_user_id, target_device_id, key_id, signature
+ FROM e2e_cross_signing_signatures
+ WHERE user_id = ?
+ AND (%s)
+ """ % (
+ " OR ".join(
+ "(target_user_id = ? AND target_device_id = ?)" for d in devices
+ )
+ )
+ query_params = [from_user_id]
+ for item in devices:
+ # item is a (user_id, device_id) tuple
+ query_params.extend(item)
+
+ txn.execute(sql, query_params)
+ rows = self.db.cursor_to_dict(txn)
+
+ # and add the signatures to the appropriate keys
+ for row in rows:
+ key_id = row["key_id"]
+ target_user_id = row["target_user_id"]
+ target_device_id = row["target_device_id"]
+ key_type = devices[(target_user_id, target_device_id)]
+ # We need to copy everything, because the result may have come
+ # from the cache. dict.copy only does a shallow copy, so we
+ # need to recursively copy the dicts that will be modified.
+ user_info = keys[target_user_id] = keys[target_user_id].copy()
+ target_user_key = user_info[key_type] = user_info[key_type].copy()
+ if "signatures" in target_user_key:
+ signatures = target_user_key["signatures"] = target_user_key[
+ "signatures"
+ ].copy()
+ if from_user_id in signatures:
+ user_sigs = signatures[from_user_id] = signatures[from_user_id]
+ user_sigs[key_id] = row["signature"]
+ else:
+ signatures[from_user_id] = {key_id: row["signature"]}
+ else:
+ target_user_key["signatures"] = {
+ from_user_id: {key_id: row["signature"]}
+ }
+
+ return keys
+
+ @defer.inlineCallbacks
+ def get_e2e_cross_signing_keys_bulk(
+ self, user_ids: List[str], from_user_id: str = None
+ ) -> defer.Deferred:
+ """Returns the cross-signing keys for a set of users.
+
+ Args:
+ user_ids (list[str]): the users whose keys are being requested
+ from_user_id (str): if specified, signatures made by this user on
+ the self-signing keys will be included in the result
+
+ Returns:
+ Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
+ key data. If a user's cross-signing keys were not found, either
+ their user ID will not be in the dict, or their user ID will map
+ to None.
+ """
+
+ result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+
+ if from_user_id:
+ result = yield self.db.runInteraction(
+ "get_e2e_cross_signing_signatures",
+ self._get_e2e_cross_signing_signatures_txn,
+ result,
+ from_user_id,
+ )
+
+ return result
+
+ def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
+ """Return a list of changes from the user signature stream to notify remotes.
+ Note that the user signature stream represents when a user signs their
+ device with their user-signing key, which is not published to other
+ users or servers, so no `destination` is needed in the returned
+ list. However, this is needed to poke workers.
+
+ Args:
+ from_key (int): the stream ID to start at (exclusive)
+ to_key (int): the stream ID to end at (inclusive)
+
+ Returns:
+ Deferred[list[(int,str)]] a list of `(stream_id, user_id)`
+ """
+ sql = """
+ SELECT MAX(stream_id) AS stream_id, from_user_id AS user_id
+ FROM user_signature_stream
+ WHERE ? < stream_id AND stream_id <= ?
+ GROUP BY user_id
+ """
+ return self.db.execute(
+ "get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
+ )
+
+
+class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
+ def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
+ """Stores device keys for a device. Returns whether there was a change
+ or the keys were already in the database.
+ """
+
+ def _set_e2e_device_keys_txn(txn):
+ set_tag("user_id", user_id)
+ set_tag("device_id", device_id)
+ set_tag("time_now", time_now)
+ set_tag("device_keys", device_keys)
+
+ old_key_json = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ retcol="key_json",
+ allow_none=True,
+ )
+
+ # In py3 we need old_key_json to match new_key_json type. The DB
+ # returns unicode while encode_canonical_json returns bytes.
+ new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+
+ if old_key_json == new_key_json:
+ log_kv({"Message": "Device key already stored."})
+ return False
+
+ self.db.simple_upsert_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ values={"ts_added_ms": time_now, "key_json": new_key_json},
+ )
+ log_kv({"message": "Device keys stored."})
+ return True
+
+ return self.db.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
+
+ def claim_e2e_one_time_keys(self, query_list):
+ """Take a list of one time keys out of the database"""
+
+ @trace
+ def _claim_e2e_one_time_keys(txn):
+ sql = (
+ "SELECT key_id, key_json FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " LIMIT 1"
+ )
+ result = {}
+ delete = []
+ for user_id, device_id, algorithm in query_list:
+ user_result = result.setdefault(user_id, {})
+ device_result = user_result.setdefault(device_id, {})
+ txn.execute(sql, (user_id, device_id, algorithm))
+ for key_id, key_json in txn:
+ device_result[algorithm + ":" + key_id] = key_json
+ delete.append((user_id, device_id, algorithm, key_id))
+ sql = (
+ "DELETE FROM e2e_one_time_keys_json"
+ " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
+ " AND key_id = ?"
+ )
+ for user_id, device_id, algorithm, key_id in delete:
+ log_kv(
+ {
+ "message": "Executing claim e2e_one_time_keys transaction on database."
+ }
+ )
+ txn.execute(sql, (user_id, device_id, algorithm, key_id))
+ log_kv({"message": "finished executing and invalidating cache"})
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+ return result
+
+ return self.db.runInteraction(
+ "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
+ )
+
+ def delete_e2e_keys_by_device(self, user_id, device_id):
+ def delete_e2e_keys_by_device_txn(txn):
+ log_kv(
+ {
+ "message": "Deleting keys for device",
+ "device_id": device_id,
+ "user_id": user_id,
+ }
+ )
+ self.db.simple_delete_txn(
+ txn,
+ table="e2e_device_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self.db.simple_delete_txn(
+ txn,
+ table="e2e_one_time_keys_json",
+ keyvalues={"user_id": user_id, "device_id": device_id},
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.count_e2e_one_time_keys, (user_id, device_id)
+ )
+
+ return self.db.runInteraction(
+ "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
+ )
+
+ def _set_e2e_cross_signing_key_txn(self, txn, user_id, key_type, key):
+ """Set a user's cross-signing key.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection): db connection
+ user_id (str): the user to set the signing key for
+ key_type (str): the type of key that is being set: either 'master'
+ for a master key, 'self_signing' for a self-signing key, or
+ 'user_signing' for a user-signing key
+ key (dict): the key data
+ """
+ # the 'key' dict will look something like:
+ # {
+ # "user_id": "@alice:example.com",
+ # "usage": ["self_signing"],
+ # "keys": {
+ # "ed25519:base64+self+signing+public+key": "base64+self+signing+public+key",
+ # },
+ # "signatures": {
+ # "@alice:example.com": {
+ # "ed25519:base64+master+public+key": "base64+signature"
+ # }
+ # }
+ # }
+ # The "keys" property must only have one entry, which will be the public
+ # key, so we just grab the first value in there
+ pubkey = next(iter(key["keys"].values()))
+
+ # The cross-signing keys need to occupy the same namespace as devices,
+ # since signatures are identified by device ID. So add an entry to the
+ # device table to make sure that we don't have a collision with device
+ # IDs.
+ # We only need to do this for local users, since remote servers should be
+ # responsible for checking this for their own users.
+ if self.hs.is_mine_id(user_id):
+ self.db.simple_insert_txn(
+ txn,
+ "devices",
+ values={
+ "user_id": user_id,
+ "device_id": pubkey,
+ "display_name": key_type + " signing key",
+ "hidden": True,
+ },
+ )
+
+ # and finally, store the key itself
+ with self._cross_signing_id_gen.get_next() as stream_id:
+ self.db.simple_insert_txn(
+ txn,
+ "e2e_cross_signing_keys",
+ values={
+ "user_id": user_id,
+ "keytype": key_type,
+ "keydata": json.dumps(key),
+ "stream_id": stream_id,
+ },
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
+ )
+
+ def set_e2e_cross_signing_key(self, user_id, key_type, key):
+ """Set a user's cross-signing key.
+
+ Args:
+ user_id (str): the user to set the user-signing key for
+ key_type (str): the type of cross-signing key to set
+ key (dict): the key data
+ """
+ return self.db.runInteraction(
+ "add_e2e_cross_signing_key",
+ self._set_e2e_cross_signing_key_txn,
+ user_id,
+ key_type,
+ key,
+ )
+
+ def store_e2e_cross_signing_signatures(self, user_id, signatures):
+ """Stores cross-signing signatures.
+
+ Args:
+ user_id (str): the user who made the signatures
+ signatures (iterable[SignatureListItem]): signatures to add
+ """
+ return self.db.simple_insert_many(
+ "e2e_cross_signing_signatures",
+ [
+ {
+ "user_id": user_id,
+ "key_id": item.signing_key_id,
+ "target_user_id": item.target_user_id,
+ "target_device_id": item.target_device_id,
+ "signature": item.signature,
+ }
+ for item in signatures
+ ],
+ "add_e2e_signing_key",
+ )
diff --git a/synapse/storage/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 09e39c2c28..62d4e9f599 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -12,22 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import itertools
import logging
-import random
+from typing import Dict, List, Optional, Set, Tuple
-from six.moves import range
from six.moves.queue import Empty, PriorityQueue
-from unpaddedbase64 import encode_base64
-
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.signatures import SignatureWorkerStore
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.signatures import SignatureWorkerStore
+from synapse.storage.database import Database
from synapse.util.caches.descriptors import cached
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@@ -47,37 +47,55 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_ids, include_given=include_given
).addCallback(self.get_events_as_list)
- def get_auth_chain_ids(self, event_ids, include_given=False):
+ def get_auth_chain_ids(
+ self,
+ event_ids: List[str],
+ include_given: bool = False,
+ ignore_events: Optional[Set[str]] = None,
+ ):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
- event_ids (list): state events
- include_given (bool): include the given events in result
+ event_ids: state events
+ include_given: include the given events in result
+ ignore_events: Set of events to exclude from the returned auth
+ chain. This is useful if the caller will just discard the
+ given events anyway, and saves us from figuring out their auth
+ chains if not required.
Returns:
list of event_ids
"""
- return self.runInteraction(
- "get_auth_chain_ids", self._get_auth_chain_ids_txn, event_ids, include_given
+ return self.db.runInteraction(
+ "get_auth_chain_ids",
+ self._get_auth_chain_ids_txn,
+ event_ids,
+ include_given,
+ ignore_events,
)
- def _get_auth_chain_ids_txn(self, txn, event_ids, include_given):
+ def _get_auth_chain_ids_txn(self, txn, event_ids, include_given, ignore_events):
+ if ignore_events is None:
+ ignore_events = set()
+
if include_given:
results = set(event_ids)
else:
results = set()
- base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
+ base_sql = "SELECT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
new_front = set()
- front_list = list(front)
- chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
- for chunk in chunks:
- txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk)
- new_front.update([r[0] for r in txn])
+ for chunk in batch_iter(front, 100):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", chunk
+ )
+ txn.execute(base_sql + clause, args)
+ new_front.update(r[0] for r in txn)
+ new_front -= ignore_events
new_front -= results
front = new_front
@@ -85,13 +103,161 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
+ def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ """Given sets of state events figure out the auth chain difference (as
+ per state res v2 algorithm).
+
+ This equivalent to fetching the full auth chain for each set of state
+ and returning the events that don't appear in each and every auth
+ chain.
+
+ Returns:
+ Deferred[Set[str]]
+ """
+
+ return self.db.runInteraction(
+ "get_auth_chain_difference",
+ self._get_auth_chain_difference_txn,
+ state_sets,
+ )
+
+ def _get_auth_chain_difference_txn(
+ self, txn, state_sets: List[Set[str]]
+ ) -> Set[str]:
+
+ # Algorithm Description
+ # ~~~~~~~~~~~~~~~~~~~~~
+ #
+ # The idea here is to basically walk the auth graph of each state set in
+ # tandem, keeping track of which auth events are reachable by each state
+ # set. If we reach an auth event we've already visited (via a different
+ # state set) then we mark that auth event and all ancestors as reachable
+ # by the state set. This requires that we keep track of the auth chains
+ # in memory.
+ #
+ # Doing it in a such a way means that we can stop early if all auth
+ # events we're currently walking are reachable by all state sets.
+ #
+ # *Note*: We can't stop walking an event's auth chain if it is reachable
+ # by all state sets. This is because other auth chains we're walking
+ # might be reachable only via the original auth chain. For example,
+ # given the following auth chain:
+ #
+ # A -> C -> D -> E
+ # / /
+ # B -´---------´
+ #
+ # and state sets {A} and {B} then walking the auth chains of A and B
+ # would immediately show that C is reachable by both. However, if we
+ # stopped at C then we'd only reach E via the auth chain of B and so E
+ # would errornously get included in the returned difference.
+ #
+ # The other thing that we do is limit the number of auth chains we walk
+ # at once, due to practical limits (i.e. we can only query the database
+ # with a limited set of parameters). We pick the auth chains we walk
+ # each iteration based on their depth, in the hope that events with a
+ # lower depth are likely reachable by those with higher depths.
+ #
+ # We could use any ordering that we believe would give a rough
+ # topological ordering, e.g. origin server timestamp. If the ordering
+ # chosen is not topological then the algorithm still produces the right
+ # result, but perhaps a bit more inefficiently. This is why it is safe
+ # to use "depth" here.
+
+ initial_events = set(state_sets[0]).union(*state_sets[1:])
+
+ # Dict from events in auth chains to which sets *cannot* reach them.
+ # I.e. if the set is empty then all sets can reach the event.
+ event_to_missing_sets = {
+ event_id: {i for i, a in enumerate(state_sets) if event_id not in a}
+ for event_id in initial_events
+ }
+
+ # We need to get the depth of the initial events for sorting purposes.
+ sql = """
+ SELECT depth, event_id FROM events
+ WHERE %s
+ ORDER BY depth ASC
+ """
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", initial_events
+ )
+ txn.execute(sql % (clause,), args)
+
+ # The sorted list of events whose auth chains we should walk.
+ search = txn.fetchall() # type: List[Tuple[int, str]]
+
+ # Map from event to its auth events
+ event_to_auth_events = {} # type: Dict[str, Set[str]]
+
+ base_sql = """
+ SELECT a.event_id, auth_id, depth
+ FROM event_auth AS a
+ INNER JOIN events AS e ON (e.event_id = a.auth_id)
+ WHERE
+ """
+
+ while search:
+ # Check whether all our current walks are reachable by all state
+ # sets. If so we can bail.
+ if all(not event_to_missing_sets[eid] for _, eid in search):
+ break
+
+ # Fetch the auth events and their depths of the N last events we're
+ # currently walking
+ search, chunk = search[:-100], search[-100:]
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "a.event_id", [e_id for _, e_id in chunk]
+ )
+ txn.execute(base_sql + clause, args)
+
+ for event_id, auth_event_id, auth_event_depth in txn:
+ event_to_auth_events.setdefault(event_id, set()).add(auth_event_id)
+
+ sets = event_to_missing_sets.get(auth_event_id)
+ if sets is None:
+ # First time we're seeing this event, so we add it to the
+ # queue of things to fetch.
+ search.append((auth_event_depth, auth_event_id))
+
+ # Assume that this event is unreachable from any of the
+ # state sets until proven otherwise
+ sets = event_to_missing_sets[auth_event_id] = set(
+ range(len(state_sets))
+ )
+ else:
+ # We've previously seen this event, so look up its auth
+ # events and recursively mark all ancestors as reachable
+ # by the current event's state set.
+ a_ids = event_to_auth_events.get(auth_event_id)
+ while a_ids:
+ new_aids = set()
+ for a_id in a_ids:
+ event_to_missing_sets[a_id].intersection_update(
+ event_to_missing_sets[event_id]
+ )
+
+ b = event_to_auth_events.get(a_id)
+ if b:
+ new_aids.update(b)
+
+ a_ids = new_aids
+
+ # Mark that the auth event is reachable by the approriate sets.
+ sets.intersection_update(event_to_missing_sets[event_id])
+
+ search.sort()
+
+ # Return all events where not all sets can reach them.
+ return {eid for eid, n in event_to_missing_sets.items() if n}
+
def get_oldest_events_in_room(self, room_id):
- return self.runInteraction(
+ return self.db.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
- return self.runInteraction(
+ return self.db.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
@@ -122,7 +288,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns
Deferred[int]
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
@@ -131,20 +297,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
if not rows:
- defer.returnValue(0)
+ return 0
else:
- defer.returnValue(max(row["depth"] for row in rows))
+ return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
- return self._simple_select_onecol_txn(
+ return self.db.simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
- @defer.inlineCallbacks
- def get_prev_events_for_room(self, room_id):
+ def get_prev_events_for_room(self, room_id: str):
"""
Gets a subset of the current forward extremities in the given room.
@@ -155,80 +320,87 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
room_id (str): room_id
Returns:
- Deferred[list[(str, dict[str, str], int)]]
- for each event, a tuple of (event_id, hashes, depth)
- where *hashes* is a map from algorithm to hash.
+ Deferred[List[str]]: the event ids of the forward extremites
+
"""
- res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
- if len(res) > 10:
- # Sort by reverse depth, so we point to the most recent.
- res.sort(key=lambda a: -a[2])
- # we use half of the limit for the actual most recent events, and
- # the other half to randomly point to some of the older events, to
- # make sure that we don't completely ignore the older events.
- res = res[0:5] + random.sample(res[5:], 5)
+ return self.db.runInteraction(
+ "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
+ )
- defer.returnValue(res)
+ def _get_prev_events_for_room_txn(self, txn, room_id: str):
+ # we just use the 10 newest events. Older events will become
+ # prev_events of future events.
- def get_latest_event_ids_and_hashes_in_room(self, room_id):
+ sql = """
+ SELECT e.event_id FROM event_forward_extremities AS f
+ INNER JOIN events AS e USING (event_id)
+ WHERE f.room_id = ?
+ ORDER BY e.depth DESC
+ LIMIT 10
"""
- Gets the current forward extremities in the given room
+
+ txn.execute(sql, (room_id,))
+
+ return [row[0] for row in txn]
+
+ def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
+ """Get the top rooms with at least N extremities.
Args:
- room_id (str): room_id
+ min_count (int): The minimum number of extremities
+ limit (int): The maximum number of rooms to return.
+ room_id_filter (iterable[str]): room_ids to exclude from the results
Returns:
- Deferred[list[(str, dict[str, str], int)]]
- for each event, a tuple of (event_id, hashes, depth)
- where *hashes* is a map from algorithm to hash.
+ Deferred[list]: At most `limit` room IDs that have at least
+ `min_count` extremities, sorted by extremity count.
"""
- return self.runInteraction(
- "get_latest_event_ids_and_hashes_in_room",
- self._get_latest_event_ids_and_hashes_in_room,
- room_id,
+ def _get_rooms_with_many_extremities_txn(txn):
+ where_clause = "1=1"
+ if room_id_filter:
+ where_clause = "room_id NOT IN (%s)" % (
+ ",".join("?" for _ in room_id_filter),
+ )
+
+ sql = """
+ SELECT room_id FROM event_forward_extremities
+ WHERE %s
+ GROUP BY room_id
+ HAVING count(*) > ?
+ ORDER BY count(*) DESC
+ LIMIT ?
+ """ % (
+ where_clause,
+ )
+
+ query_args = list(itertools.chain(room_id_filter, [min_count, limit]))
+ txn.execute(sql, query_args)
+ return [room_id for room_id, in txn]
+
+ return self.db.runInteraction(
+ "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
)
@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
- return self._simple_select_onecol(
+ return self.db.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
desc="get_latest_event_ids_in_room",
)
- def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
- sql = (
- "SELECT e.event_id, e.depth FROM events as e "
- "INNER JOIN event_forward_extremities as f "
- "ON e.event_id = f.event_id "
- "AND e.room_id = f.room_id "
- "WHERE f.room_id = ?"
- )
-
- txn.execute(sql, (room_id,))
-
- results = []
- for event_id, depth in txn.fetchall():
- hashes = self._get_event_reference_hashes_txn(txn, event_id)
- prev_hashes = {
- k: encode_base64(v) for k, v in hashes.items() if k == "sha256"
- }
- results.append((event_id, prev_hashes, depth))
-
- return results
-
def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
def _get_min_depth_interaction(self, txn, room_id):
- min_depth = self._simple_select_one_onecol_txn(
+ min_depth = self.db.simple_select_one_onecol_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -294,7 +466,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
@@ -309,7 +481,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit (int)
"""
return (
- self.runInteraction(
+ self.db.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
@@ -321,9 +493,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
def _get_backfill_events(self, txn, room_id, event_list, limit):
- logger.debug(
- "_get_backfill_events: %s, %s, %s", room_id, repr(event_list), limit
- )
+ logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
event_results = set()
@@ -342,7 +512,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
queue = PriorityQueue()
for event_id in event_list:
- depth = self._simple_select_one_onecol_txn(
+ depth = self.db.simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -374,7 +544,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
@defer.inlineCallbacks
def get_missing_events(self, room_id, earliest_events, latest_events, limit):
- ids = yield self.runInteraction(
+ ids = yield self.db.runInteraction(
"get_missing_events",
self._get_missing_events,
room_id,
@@ -383,7 +553,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
limit,
)
events = yield self.get_events_as_list(ids)
- defer.returnValue(events)
+ return events
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
@@ -404,7 +574,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
query, (room_id, event_id, False, limit - len(event_results))
)
- new_results = set(t[0] for t in txn) - seen_events
+ new_results = {t[0] for t in txn} - seen_events
new_front |= new_results
seen_events |= new_results
@@ -427,7 +597,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
Deferred[list[str]]
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.db.simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
@@ -435,7 +605,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
desc="get_successor_events",
)
- defer.returnValue([row["event_id"] for row in rows])
+ return [row["event_id"] for row in rows]
class EventFederationStore(EventFederationWorkerStore):
@@ -450,10 +620,10 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
- def __init__(self, db_conn, hs):
- super(EventFederationStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventFederationStore, self).__init__(database, db_conn, hs)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
)
@@ -464,10 +634,10 @@ class EventFederationStore(EventFederationWorkerStore):
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
min_depth = self._get_min_depth_interaction(txn, room_id)
- if min_depth and depth >= min_depth:
+ if min_depth is not None and depth >= min_depth:
return
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -479,7 +649,7 @@ class EventFederationStore(EventFederationWorkerStore):
For the given event, update the event edges table and forward and
backward extremities tables.
"""
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_edges",
values=[
@@ -563,13 +733,13 @@ class EventFederationStore(EventFederationWorkerStore):
return run_as_background_process(
"delete_old_forward_extrem_cache",
- self.runInteraction,
+ self.db.runInteraction,
"_delete_old_forward_extrem_cache",
_delete_old_forward_extrem_cache_txn,
)
def clean_room_for_join(self, room_id):
- return self.runInteraction(
+ return self.db.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)
@@ -613,17 +783,17 @@ class EventFederationStore(EventFederationWorkerStore):
"max_stream_id_exclusive": min_stream_id,
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.EVENT_AUTH_STATE_ONLY, new_progress
)
return min_stream_id >= target_min_stream_id
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_AUTH_STATE_ONLY, delete_event_auth
)
if not result:
- yield self._end_background_update(self.EVENT_AUTH_STATE_ONLY)
+ yield self.db.updates._end_background_update(self.EVENT_AUTH_STATE_ONLY)
- defer.returnValue(batch_size)
+ return batch_size
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index a729f3e067..8eed590929 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -24,6 +24,7 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore
+from synapse.storage.database import Database
from synapse.util.caches.descriptors import cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -68,8 +69,8 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(EventPushActionsWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
self.stream_ordering_month_ago = None
@@ -79,8 +80,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",
database_engine=self.database_engine,
- after_callbacks=[],
- exception_callbacks=[],
)
self._find_stream_orderings_for_times_txn(cur)
cur.close()
@@ -95,14 +94,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
- ret = yield self.runInteraction(
+ ret = yield self.db.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
- defer.returnValue(ret)
+ return ret
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
@@ -179,8 +178,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn]
- ret = yield self.runInteraction("get_push_action_users_in_range", f)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("get_push_action_users_in_range", f)
+ return ret
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_http(
@@ -231,7 +230,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = yield self.runInteraction(
+ after_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt
)
@@ -259,7 +258,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = yield self.runInteraction(
+ no_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt
)
@@ -277,11 +276,11 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by stream_ordering, oldest first.
- notifs.sort(key=lambda r: r['stream_ordering'])
+ notifs.sort(key=lambda r: r["stream_ordering"])
# Take only up to the limit. We have to stop at the limit because
# one of the subqueries may have hit the limit.
- defer.returnValue(notifs[:limit])
+ return notifs[:limit]
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range_for_email(
@@ -331,7 +330,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- after_read_receipt = yield self.runInteraction(
+ after_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt
)
@@ -359,7 +358,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, args)
return txn.fetchall()
- no_read_receipt = yield self.runInteraction(
+ no_read_receipt = yield self.db.runInteraction(
"get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt
)
@@ -379,10 +378,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
# contain results from the first query, correctly ordered, followed
# by results from the second query, but we want them all ordered
# by received_ts (most recent first)
- notifs.sort(key=lambda r: -(r['received_ts'] or 0))
+ notifs.sort(key=lambda r: -(r["received_ts"] or 0))
# Now return the first `limit`
- defer.returnValue(notifs[:limit])
+ return notifs[:limit]
def get_if_maybe_push_in_range_for_user(self, user_id, min_stream_ordering):
"""A fast check to see if there might be something to push for the
@@ -409,7 +408,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, min_stream_ordering))
return bool(txn.fetchone())
- return self.runInteraction(
+ return self.db.runInteraction(
"get_if_maybe_push_in_range_for_user",
_get_if_maybe_push_in_range_for_user_txn,
)
@@ -443,7 +442,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _add_push_actions_to_staging_txn(txn):
- # We don't use _simple_insert_many here to avoid the overhead
+ # We don't use simple_insert_many here to avoid the overhead
# of generating lists of dicts.
sql = """
@@ -460,7 +459,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
- return self.runInteraction(
+ return self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
@@ -474,12 +473,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
try:
- res = yield self._simple_delete(
+ res = yield self.db.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
)
- defer.returnValue(res)
+ return res
except Exception:
# this method is called from an exception handler, so propagating
# another exception here really isn't helpful - there's nothing
@@ -491,7 +490,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _find_stream_orderings_for_times(self):
return run_as_background_process(
"event_push_action_stream_orderings",
- self.runInteraction,
+ self.db.runInteraction,
"_find_stream_orderings_for_times",
self._find_stream_orderings_for_times_txn,
)
@@ -527,7 +526,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
Deferred[int]: stream ordering of the first event received on/after
the timestamp
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"_find_first_stream_ordering_after_ts_txn",
self._find_first_stream_ordering_after_ts_txn,
ts,
@@ -609,21 +608,38 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return range_end
+ @defer.inlineCallbacks
+ def get_time_of_last_push_action_before(self, stream_ordering):
+ def f(txn):
+ sql = (
+ "SELECT e.received_ts"
+ " FROM event_push_actions AS ep"
+ " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
+ " WHERE ep.stream_ordering > ?"
+ " ORDER BY ep.stream_ordering ASC"
+ " LIMIT 1"
+ )
+ txn.execute(sql, (stream_ordering,))
+ return txn.fetchone()
+
+ result = yield self.db.runInteraction("get_time_of_last_push_action_before", f)
+ return result[0] if result else None
+
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
- def __init__(self, db_conn, hs):
- super(EventPushActionsStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventPushActionsStore, self).__init__(database, db_conn, hs)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
index_name="event_push_actions_u_highlight",
table="event_push_actions",
columns=["user_id", "stream_ordering"],
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"event_push_actions_highlights_index",
index_name="event_push_actions_highlights_index",
table="event_push_actions",
@@ -679,7 +695,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
for event, _ in events_and_contexts:
- user_ids = self._simple_select_onecol_txn(
+ user_ids = self.db.simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
keyvalues={"event_id": event.event_id},
@@ -729,29 +745,12 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
" LIMIT ?" % (before_clause,)
)
txn.execute(sql, args)
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- push_actions = yield self.runInteraction("get_push_actions_for_user", f)
+ push_actions = yield self.db.runInteraction("get_push_actions_for_user", f)
for pa in push_actions:
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
- defer.returnValue(push_actions)
-
- @defer.inlineCallbacks
- def get_time_of_last_push_action_before(self, stream_ordering):
- def f(txn):
- sql = (
- "SELECT e.received_ts"
- " FROM event_push_actions AS ep"
- " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id"
- " WHERE ep.stream_ordering > ?"
- " ORDER BY ep.stream_ordering ASC"
- " LIMIT 1"
- )
- txn.execute(sql, (stream_ordering,))
- return txn.fetchone()
-
- result = yield self.runInteraction("get_time_of_last_push_action_before", f)
- defer.returnValue(result[0] if result else None)
+ return push_actions
@defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self):
@@ -759,8 +758,10 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
- result = yield self.runInteraction("get_latest_push_action_stream_ordering", f)
- defer.returnValue(result[0] or 0)
+ result = yield self.db.runInteraction(
+ "get_latest_push_action_stream_ordering", f
+ )
+ return result[0] or 0
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here
@@ -832,7 +833,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
while True:
logger.info("Rotating notifications")
- caught_up = yield self.runInteraction(
+ caught_up = yield self.db.runInteraction(
"_rotate_notifs", self._rotate_notifs_txn
)
if caught_up:
@@ -846,7 +847,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
the archiving process has caught up or not.
"""
- old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -865,7 +866,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
stream_row = txn.fetchone()
if stream_row:
- offset_stream_ordering, = stream_row
+ (offset_stream_ordering,) = stream_row
rotate_to_stream_ordering = min(
self.stream_ordering_day_ago, offset_stream_ordering
)
@@ -882,7 +883,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
- old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.db.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -914,7 +915,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
# existing table.
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_push_summary",
values=[
diff --git a/synapse/storage/events.py b/synapse/storage/data_stores/main/events.py
index bc3e6de3bf..d593ef47b8 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -17,8 +17,9 @@
import itertools
import logging
-from collections import OrderedDict, deque, namedtuple
+from collections import Counter as c_counter, OrderedDict, namedtuple
from functools import wraps
+from typing import Dict, List, Tuple
from six import iteritems, text_type
from six.moves import range
@@ -29,24 +30,25 @@ from prometheus_client import Counter
from twisted.internet import defer
import synapse.metrics
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.errors import SynapseError
+from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
+from synapse.events.utils import prune_event_dict
+from synapse.logging.utils import log_function
+from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.state import StateResolutionStore
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.storage.event_federation import EventFederationStore
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.storage.state import StateGroupWorkerStore
-from synapse.types import RoomStreamToken, get_domain_from_id
-from synapse.util import batch_iter
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.storage._base import make_in_list_sql_clause
+from synapse.storage.data_stores.main.event_federation import EventFederationStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.state import StateGroupWorkerStore
+from synapse.storage.database import Database, LoggingTransaction
+from synapse.storage.persist_events import DeltaState
+from synapse.types import RoomStreamToken, StateMap, get_domain_from_id
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
-from synapse.util.logutils import log_function
-from synapse.util.metrics import Measure
+from synapse.util.iterutils import batch_iter
logger = logging.getLogger(__name__)
@@ -57,22 +59,6 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
-# The number of times we are recalculating the current state
-state_delta_counter = Counter("synapse_storage_events_state_delta", "")
-
-# The number of times we are recalculating state when there is only a
-# single forward extremity
-state_delta_single_event_counter = Counter(
- "synapse_storage_events_state_delta_single_event", ""
-)
-
-# The number of times we are reculating state when we could have resonably
-# calculated the delta when we calculated the state for an event we were
-# persisting.
-state_delta_reuse_delta_counter = Counter(
- "synapse_storage_events_state_delta_reuse_delta", ""
-)
-
def encode_json(json_object):
"""
@@ -84,110 +70,6 @@ def encode_json(json_object):
return out
-class _EventPeristenceQueue(object):
- """Queues up events so that they can be persisted in bulk with only one
- concurrent transaction per room.
- """
-
- _EventPersistQueueItem = namedtuple(
- "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
- )
-
- def __init__(self):
- self._event_persist_queues = {}
- self._currently_persisting_rooms = set()
-
- def add_to_queue(self, room_id, events_and_contexts, backfilled):
- """Add events to the queue, with the given persist_event options.
-
- NB: due to the normal usage pattern of this method, it does *not*
- follow the synapse logcontext rules, and leaves the logcontext in
- place whether or not the returned deferred is ready.
-
- Args:
- room_id (str):
- events_and_contexts (list[(EventBase, EventContext)]):
- backfilled (bool):
-
- Returns:
- defer.Deferred: a deferred which will resolve once the events are
- persisted. Runs its callbacks *without* a logcontext.
- """
- queue = self._event_persist_queues.setdefault(room_id, deque())
- if queue:
- # if the last item in the queue has the same `backfilled` setting,
- # we can just add these new events to that item.
- end_item = queue[-1]
- if end_item.backfilled == backfilled:
- end_item.events_and_contexts.extend(events_and_contexts)
- return end_item.deferred.observe()
-
- deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
-
- queue.append(
- self._EventPersistQueueItem(
- events_and_contexts=events_and_contexts,
- backfilled=backfilled,
- deferred=deferred,
- )
- )
-
- return deferred.observe()
-
- def handle_queue(self, room_id, per_item_callback):
- """Attempts to handle the queue for a room if not already being handled.
-
- The given callback will be invoked with for each item in the queue,
- of type _EventPersistQueueItem. The per_item_callback will continuously
- be called with new items, unless the queue becomnes empty. The return
- value of the function will be given to the deferreds waiting on the item,
- exceptions will be passed to the deferreds as well.
-
- This function should therefore be called whenever anything is added
- to the queue.
-
- If another callback is currently handling the queue then it will not be
- invoked.
- """
-
- if room_id in self._currently_persisting_rooms:
- return
-
- self._currently_persisting_rooms.add(room_id)
-
- @defer.inlineCallbacks
- def handle_queue_loop():
- try:
- queue = self._get_drainining_queue(room_id)
- for item in queue:
- try:
- ret = yield per_item_callback(item)
- except Exception:
- with PreserveLoggingContext():
- item.deferred.errback()
- else:
- with PreserveLoggingContext():
- item.deferred.callback(ret)
- finally:
- queue = self._event_persist_queues.pop(room_id, None)
- if queue:
- self._event_persist_queues[room_id] = queue
- self._currently_persisting_rooms.discard(room_id)
-
- # set handle_queue_loop off in the background
- run_as_background_process("persist_events", handle_queue_loop)
-
- def _get_drainining_queue(self, room_id):
- queue = self._event_persist_queues.setdefault(room_id, deque())
-
- try:
- while True:
- yield queue.popleft()
- except IndexError:
- # Queue has been drained.
- pass
-
-
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@@ -203,11 +85,11 @@ def _retry_on_integrity_error(func):
@defer.inlineCallbacks
def f(self, *args, **kwargs):
try:
- res = yield func(self, *args, **kwargs)
+ res = yield func(self, *args, delete_existing=False, **kwargs)
except self.database_engine.module.IntegrityError:
logger.exception("IntegrityError, retrying.")
res = yield func(self, *args, delete_existing=True, **kwargs)
- defer.returnValue(res)
+ return res
return f
@@ -215,106 +97,101 @@ def _retry_on_integrity_error(func):
# inherits from EventFederationStore so that we can call _update_backward_extremities
# and _handle_mult_prev_events (though arguably those could both be moved in here)
class EventsStore(
- StateGroupWorkerStore,
- EventFederationStore,
- EventsWorkerStore,
- BackgroundUpdateStore,
+ StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
):
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventsStore, self).__init__(database, db_conn, hs)
- def __init__(self, db_conn, hs):
- super(EventsStore, self).__init__(db_conn, hs)
-
- self._event_persist_queue = _EventPeristenceQueue()
- self._state_resolution_handler = hs.get_state_resolution_handler()
+ # Collect metrics on the number of forward extremities that exist.
+ # Counter of number of extremities to count
+ self._current_forward_extremities_amount = c_counter()
- @defer.inlineCallbacks
- def persist_events(self, events_and_contexts, backfilled=False):
- """
- Write events to the database
- Args:
- events_and_contexts: list of tuples of (event, context)
- backfilled (bool): Whether the results are retrieved from federation
- via backfill or not. Used to determine if they're "new" events
- which might update the current state etc.
+ BucketCollector(
+ "synapse_forward_extremities",
+ lambda: self._current_forward_extremities_amount,
+ buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
+ )
- Returns:
- Deferred[int]: the stream ordering of the latest persisted event
- """
- partitioned = {}
- for event, ctx in events_and_contexts:
- partitioned.setdefault(event.room_id, []).append((event, ctx))
-
- deferreds = []
- for room_id, evs_ctxs in iteritems(partitioned):
- d = self._event_persist_queue.add_to_queue(
- room_id, evs_ctxs, backfilled=backfilled
+ # Read the extrems every 60 minutes
+ def read_forward_extremities():
+ # run as a background process to make sure that the database transactions
+ # have a logcontext to report to
+ return run_as_background_process(
+ "read_forward_extremities", self._read_forward_extremities
)
- deferreds.append(d)
- for room_id in partitioned:
- self._maybe_start_persisting(room_id)
+ hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
- yield make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True)
- )
+ def _censor_redactions():
+ return run_as_background_process(
+ "_censor_redactions", self._censor_redactions
+ )
- max_persisted_id = yield self._stream_id_gen.get_current_token()
+ if self.hs.config.redaction_retention_period is not None:
+ hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
- defer.returnValue(max_persisted_id)
+ self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+ self.is_mine_id = hs.is_mine_id
@defer.inlineCallbacks
- @log_function
- def persist_event(self, event, context, backfilled=False):
- """
-
- Args:
- event (EventBase):
- context (EventContext):
- backfilled (bool):
-
- Returns:
- Deferred: resolves to (int, int): the stream ordering of ``event``,
- and the stream ordering of the latest persisted event
- """
- deferred = self._event_persist_queue.add_to_queue(
- event.room_id, [(event, context)], backfilled=backfilled
- )
-
- self._maybe_start_persisting(event.room_id)
-
- yield make_deferred_yieldable(deferred)
-
- max_persisted_id = yield self._stream_id_gen.get_current_token()
- defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
-
- def _maybe_start_persisting(self, room_id):
- @defer.inlineCallbacks
- def persisting_queue(item):
- with Measure(self._clock, "persist_events"):
- yield self._persist_events(
- item.events_and_contexts, backfilled=item.backfilled
- )
+ def _read_forward_extremities(self):
+ def fetch(txn):
+ txn.execute(
+ """
+ select count(*) c from event_forward_extremities
+ group by room_id
+ """
+ )
+ return txn.fetchall()
- self._event_persist_queue.handle_queue(room_id, persisting_queue)
+ res = yield self.db.runInteraction("read_forward_extremities", fetch)
+ self._current_forward_extremities_amount = c_counter([x[0] for x in res])
@_retry_on_integrity_error
@defer.inlineCallbacks
- def _persist_events(
- self, events_and_contexts, backfilled=False, delete_existing=False
+ def _persist_events_and_state_updates(
+ self,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ current_state_for_room: Dict[str, StateMap[str]],
+ state_delta_for_room: Dict[str, DeltaState],
+ new_forward_extremeties: Dict[str, List[str]],
+ backfilled: bool = False,
+ delete_existing: bool = False,
):
- """Persist events to db
+ """Persist a set of events alongside updates to the current state and
+ forward extremities tables.
Args:
- events_and_contexts (list[(EventBase, EventContext)]):
- backfilled (bool):
- delete_existing (bool):
+ events_and_contexts:
+ current_state_for_room: Map from room_id to the current state of
+ the room based on forward extremities
+ state_delta_for_room: Map from room_id to the delta to apply to
+ room state
+ new_forward_extremities: Map from room_id to list of event IDs
+ that are the new forward extremities of the room.
+ backfilled
+ delete_existing
Returns:
Deferred: resolves when the events have been persisted
"""
- if not events_and_contexts:
- return
+ # We want to calculate the stream orderings as late as possible, as
+ # we only notify after all events with a lesser stream ordering have
+ # been persisted. I.e. if we spend 10s inside the with block then
+ # that will delay all subsequent events from being notified about.
+ # Hence why we do it down here rather than wrapping the entire
+ # function.
+ #
+ # Its safe to do this after calculating the state deltas etc as we
+ # only need to protect the *persistence* of the events. This is to
+ # ensure that queries of the form "fetch events since X" don't
+ # return events and stream positions after events that are still in
+ # flight, as otherwise subsequent requests "fetch event since Y"
+ # will not return those events.
+ #
+ # Note: Multiple instances of this function cannot be in flight at
+ # the same time for the same room.
if backfilled:
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
@@ -328,216 +205,44 @@ class EventsStore(
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
- chunks = [
- events_and_contexts[x : x + 100]
- for x in range(0, len(events_and_contexts), 100)
- ]
-
- for chunk in chunks:
- # We can't easily parallelize these since different chunks
- # might contain the same event. :(
-
- # NB: Assumes that we are only persisting events for one room
- # at a time.
-
- # map room_id->list[event_ids] giving the new forward
- # extremities in each room
- new_forward_extremeties = {}
-
- # map room_id->(type,state_key)->event_id tracking the full
- # state in each room after adding these events.
- # This is simply used to prefill the get_current_state_ids
- # cache
- current_state_for_room = {}
-
- # map room_id->(to_delete, to_insert) where to_delete is a list
- # of type/state keys to remove from current state, and to_insert
- # is a map (type,key)->event_id giving the state delta in each
- # room
- state_delta_for_room = {}
-
- if not backfilled:
- with Measure(self._clock, "_calculate_state_and_extrem"):
- # Work out the new "current state" for each room.
- # We do this by working out what the new extremities are and then
- # calculating the state from that.
- events_by_room = {}
- for event, context in chunk:
- events_by_room.setdefault(event.room_id, []).append(
- (event, context)
- )
-
- for room_id, ev_ctx_rm in iteritems(events_by_room):
- latest_event_ids = yield self.get_latest_event_ids_in_room(
- room_id
- )
- new_latest_event_ids = yield self._calculate_new_extremities(
- room_id, ev_ctx_rm, latest_event_ids
- )
-
- latest_event_ids = set(latest_event_ids)
- if new_latest_event_ids == latest_event_ids:
- # No change in extremities, so no change in state
- continue
-
- # there should always be at least one forward extremity.
- # (except during the initial persistence of the send_join
- # results, in which case there will be no existing
- # extremities, so we'll `continue` above and skip this bit.)
- assert new_latest_event_ids, "No forward extremities left!"
-
- new_forward_extremeties[room_id] = new_latest_event_ids
-
- len_1 = (
- len(latest_event_ids) == 1
- and len(new_latest_event_ids) == 1
- )
- if len_1:
- all_single_prev_not_state = all(
- len(event.prev_event_ids()) == 1
- and not event.is_state()
- for event, ctx in ev_ctx_rm
- )
- # Don't bother calculating state if they're just
- # a long chain of single ancestor non-state events.
- if all_single_prev_not_state:
- continue
-
- state_delta_counter.inc()
- if len(new_latest_event_ids) == 1:
- state_delta_single_event_counter.inc()
-
- # This is a fairly handwavey check to see if we could
- # have guessed what the delta would have been when
- # processing one of these events.
- # What we're interested in is if the latest extremities
- # were the same when we created the event as they are
- # now. When this server creates a new event (as opposed
- # to receiving it over federation) it will use the
- # forward extremities as the prev_events, so we can
- # guess this by looking at the prev_events and checking
- # if they match the current forward extremities.
- for ev, _ in ev_ctx_rm:
- prev_event_ids = set(ev.prev_event_ids())
- if latest_event_ids == prev_event_ids:
- state_delta_reuse_delta_counter.inc()
- break
-
- logger.info("Calculating state delta for room %s", room_id)
- with Measure(
- self._clock, "persist_events.get_new_state_after_events"
- ):
- res = yield self._get_new_state_after_events(
- room_id,
- ev_ctx_rm,
- latest_event_ids,
- new_latest_event_ids,
- )
- current_state, delta_ids = res
-
- # If either are not None then there has been a change,
- # and we need to work out the delta (or use that
- # given)
- if delta_ids is not None:
- # If there is a delta we know that we've
- # only added or replaced state, never
- # removed keys entirely.
- state_delta_for_room[room_id] = ([], delta_ids)
- elif current_state is not None:
- with Measure(
- self._clock, "persist_events.calculate_state_delta"
- ):
- delta = yield self._calculate_state_delta(
- room_id, current_state
- )
- state_delta_for_room[room_id] = delta
-
- # If we have the current_state then lets prefill
- # the cache with it.
- if current_state is not None:
- current_state_for_room[room_id] = current_state
-
- yield self.runInteraction(
- "persist_events",
- self._persist_events_txn,
- events_and_contexts=chunk,
- backfilled=backfilled,
- delete_existing=delete_existing,
- state_delta_for_room=state_delta_for_room,
- new_forward_extremeties=new_forward_extremeties,
- )
- persist_event_counter.inc(len(chunk))
-
- if not backfilled:
- # backfilled events have negative stream orderings, so we don't
- # want to set the event_persisted_position to that.
- synapse.metrics.event_persisted_position.set(
- chunk[-1][0].internal_metadata.stream_ordering
- )
-
- for event, context in chunk:
- if context.app_service:
- origin_type = "local"
- origin_entity = context.app_service.id
- elif self.hs.is_mine_id(event.sender):
- origin_type = "local"
- origin_entity = "*client*"
- else:
- origin_type = "remote"
- origin_entity = get_domain_from_id(event.sender)
-
- event_counter.labels(event.type, origin_type, origin_entity).inc()
-
- for room_id, new_state in iteritems(current_state_for_room):
- self.get_current_state_ids.prefill((room_id,), new_state)
-
- for room_id, latest_event_ids in iteritems(new_forward_extremeties):
- self.get_latest_event_ids_in_room.prefill(
- (room_id,), list(latest_event_ids)
- )
-
- @defer.inlineCallbacks
- def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
- """Calculates the new forward extremities for a room given events to
- persist.
-
- Assumes that we are only persisting events for one room at a time.
- """
-
- # we're only interested in new events which aren't outliers and which aren't
- # being rejected.
- new_events = [
- event
- for event, ctx in event_contexts
- if not event.internal_metadata.is_outlier()
- and not ctx.rejected
- and not event.internal_metadata.is_soft_failed()
- ]
-
- # start with the existing forward extremities
- result = set(latest_event_ids)
+ yield self.db.runInteraction(
+ "persist_events",
+ self._persist_events_txn,
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ delete_existing=delete_existing,
+ state_delta_for_room=state_delta_for_room,
+ new_forward_extremeties=new_forward_extremeties,
+ )
+ persist_event_counter.inc(len(events_and_contexts))
- # add all the new events to the list
- result.update(event.event_id for event in new_events)
+ if not backfilled:
+ # backfilled events have negative stream orderings, so we don't
+ # want to set the event_persisted_position to that.
+ synapse.metrics.event_persisted_position.set(
+ events_and_contexts[-1][0].internal_metadata.stream_ordering
+ )
- # Now remove all events which are prev_events of any of the new events
- result.difference_update(
- e_id for event in new_events for e_id in event.prev_event_ids()
- )
+ for event, context in events_and_contexts:
+ if context.app_service:
+ origin_type = "local"
+ origin_entity = context.app_service.id
+ elif self.hs.is_mine_id(event.sender):
+ origin_type = "local"
+ origin_entity = "*client*"
+ else:
+ origin_type = "remote"
+ origin_entity = get_domain_from_id(event.sender)
- # Remove any events which are prev_events of any existing events.
- existing_prevs = yield self._get_events_which_are_prevs(result)
- result.difference_update(existing_prevs)
+ event_counter.labels(event.type, origin_type, origin_entity).inc()
- # Finally handle the case where the new events have soft-failed prev
- # events. If they do we need to remove them and their prev events,
- # otherwise we end up with dangling extremities.
- existing_prevs = yield self._get_prevs_before_rejected(
- e_id for event in new_events for e_id in event.prev_event_ids()
- )
- result.difference_update(existing_prevs)
+ for room_id, new_state in iteritems(current_state_for_room):
+ self.get_current_state_ids.prefill((room_id,), new_state)
- defer.returnValue(result)
+ for room_id, latest_event_ids in iteritems(new_forward_extremeties):
+ self.get_latest_event_ids_in_room.prefill(
+ (room_id,), list(latest_event_ids)
+ )
@defer.inlineCallbacks
def _get_events_which_are_prevs(self, event_ids):
@@ -560,28 +265,24 @@ class EventsStore(
LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id)
WHERE
- prev_event_id IN (%s)
- AND NOT events.outlier
+ NOT events.outlier
AND rejections.event_id IS NULL
- """ % (
- ",".join("?" for _ in batch),
- )
+ AND
+ """
- txn.execute(sql, batch)
- results.extend(
- r[0]
- for r in txn
- if not json.loads(r[1]).get("soft_failed")
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "prev_event_id", batch
)
+ txn.execute(sql + clause, args)
+ results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
+
for chunk in batch_iter(event_ids, 100):
- yield self.runInteraction(
- "_get_events_which_are_prevs",
- _get_events_which_are_prevs_txn,
- chunk,
+ yield self.db.runInteraction(
+ "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
)
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def _get_prevs_before_rejected(self, event_ids):
@@ -620,13 +321,15 @@ class EventsStore(
LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id)
WHERE
- event_id IN (%s)
- AND NOT events.outlier
- """ % (
- ",".join("?" for _ in to_recursively_check),
+ NOT events.outlier
+ AND
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "event_id", to_recursively_check
)
- txn.execute(sql, to_recursively_check)
+ txn.execute(sql + clause, args)
to_recursively_check = []
for event_id, prev_event_id, metadata, rejected in txn:
@@ -639,205 +342,21 @@ class EventsStore(
existing_prevs.add(prev_event_id)
for chunk in batch_iter(event_ids, 100):
- yield self.runInteraction(
- "_get_prevs_before_rejected",
- _get_prevs_before_rejected_txn,
- chunk,
+ yield self.db.runInteraction(
+ "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
)
- defer.returnValue(existing_prevs)
-
- @defer.inlineCallbacks
- def _get_new_state_after_events(
- self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
- ):
- """Calculate the current state dict after adding some new events to
- a room
-
- Args:
- room_id (str):
- room to which the events are being added. Used for logging etc
-
- events_context (list[(EventBase, EventContext)]):
- events and contexts which are being added to the room
-
- old_latest_event_ids (iterable[str]):
- the old forward extremities for the room.
-
- new_latest_event_ids (iterable[str]):
- the new forward extremities for the room.
-
- Returns:
- Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
- Returns a tuple of two state maps, the first being the full new current
- state and the second being the delta to the existing current state.
- If both are None then there has been no change.
-
- If there has been a change then we only return the delta if its
- already been calculated. Conversely if we do know the delta then
- the new current state is only returned if we've already calculated
- it.
- """
- # map from state_group to ((type, key) -> event_id) state map
- state_groups_map = {}
-
- # Map from (prev state group, new state group) -> delta state dict
- state_group_deltas = {}
-
- for ev, ctx in events_context:
- if ctx.state_group is None:
- # This should only happen for outlier events.
- if not ev.internal_metadata.is_outlier():
- raise Exception(
- "Context for new event %s has no state "
- "group" % (ev.event_id,)
- )
- continue
-
- if ctx.state_group in state_groups_map:
- continue
-
- # We're only interested in pulling out state that has already
- # been cached in the context. We'll pull stuff out of the DB later
- # if necessary.
- current_state_ids = ctx.get_cached_current_state_ids()
- if current_state_ids is not None:
- state_groups_map[ctx.state_group] = current_state_ids
-
- if ctx.prev_group:
- state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
-
- # We need to map the event_ids to their state groups. First, let's
- # check if the event is one we're persisting, in which case we can
- # pull the state group from its context.
- # Otherwise we need to pull the state group from the database.
-
- # Set of events we need to fetch groups for. (We know none of the old
- # extremities are going to be in events_context).
- missing_event_ids = set(old_latest_event_ids)
-
- event_id_to_state_group = {}
- for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding.
- for ev, ctx in events_context:
- if event_id == ev.event_id and ctx.state_group is not None:
- event_id_to_state_group[event_id] = ctx.state_group
- break
- else:
- # If we couldn't find it, then we'll need to pull
- # the state from the database
- missing_event_ids.add(event_id)
-
- if missing_event_ids:
- # Now pull out the state groups for any missing events from DB
- event_to_groups = yield self._get_state_group_for_events(missing_event_ids)
- event_id_to_state_group.update(event_to_groups)
-
- # State groups of old_latest_event_ids
- old_state_groups = set(
- event_id_to_state_group[evid] for evid in old_latest_event_ids
- )
-
- # State groups of new_latest_event_ids
- new_state_groups = set(
- event_id_to_state_group[evid] for evid in new_latest_event_ids
- )
-
- # If they old and new groups are the same then we don't need to do
- # anything.
- if old_state_groups == new_state_groups:
- defer.returnValue((None, None))
-
- if len(new_state_groups) == 1 and len(old_state_groups) == 1:
- # If we're going from one state group to another, lets check if
- # we have a delta for that transition. If we do then we can just
- # return that.
-
- new_state_group = next(iter(new_state_groups))
- old_state_group = next(iter(old_state_groups))
-
- delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
- if delta_ids is not None:
- # We have a delta from the existing to new current state,
- # so lets just return that. If we happen to already have
- # the current state in memory then lets also return that,
- # but it doesn't matter if we don't.
- new_state = state_groups_map.get(new_state_group)
- defer.returnValue((new_state, delta_ids))
-
- # Now that we have calculated new_state_groups we need to get
- # their state IDs so we can resolve to a single state set.
- missing_state = new_state_groups - set(state_groups_map)
- if missing_state:
- group_to_state = yield self._get_state_for_groups(missing_state)
- state_groups_map.update(group_to_state)
-
- if len(new_state_groups) == 1:
- # If there is only one state group, then we know what the current
- # state is.
- defer.returnValue((state_groups_map[new_state_groups.pop()], None))
-
- # Ok, we need to defer to the state handler to resolve our state sets.
-
- state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
-
- events_map = {ev.event_id: ev for ev, _ in events_context}
-
- # We need to get the room version, which is in the create event.
- # Normally that'd be in the database, but its also possible that we're
- # currently trying to persist it.
- room_version = None
- for ev, _ in events_context:
- if ev.type == EventTypes.Create and ev.state_key == "":
- room_version = ev.content.get("room_version", "1")
- break
-
- if not room_version:
- room_version = yield self.get_room_version(room_id)
-
- logger.debug("calling resolve_state_groups from preserve_events")
- res = yield self._state_resolution_handler.resolve_state_groups(
- room_id,
- room_version,
- state_groups,
- events_map,
- state_res_store=StateResolutionStore(self),
- )
-
- defer.returnValue((res.state, None))
-
- @defer.inlineCallbacks
- def _calculate_state_delta(self, room_id, current_state):
- """Calculate the new state deltas for a room.
-
- Assumes that we are only persisting events for one room at a time.
-
- Returns:
- tuple[list, dict] (to_delete, to_insert): where to_delete are the
- type/state_keys to remove from current_state_events and `to_insert`
- are the updates to current_state_events.
- """
- existing_state = yield self.get_current_state_ids(room_id)
-
- to_delete = [key for key in existing_state if key not in current_state]
-
- to_insert = {
- key: ev_id
- for key, ev_id in iteritems(current_state)
- if ev_id != existing_state.get(key)
- }
-
- defer.returnValue((to_delete, to_insert))
+ return existing_prevs
@log_function
def _persist_events_txn(
self,
- txn,
- events_and_contexts,
- backfilled,
- delete_existing=False,
- state_delta_for_room={},
- new_forward_extremeties={},
+ txn: LoggingTransaction,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ backfilled: bool,
+ delete_existing: bool = False,
+ state_delta_for_room: Dict[str, DeltaState] = {},
+ new_forward_extremeties: Dict[str, List[str]] = {},
):
"""Insert some number of room events into the necessary database tables.
@@ -846,21 +365,16 @@ class EventsStore(
whether the event was rejected.
Args:
- txn (twisted.enterprise.adbapi.Connection): db connection
- events_and_contexts (list[(EventBase, EventContext)]):
- events to persist
- backfilled (bool): True if the events were backfilled
- delete_existing (bool): True to purge existing table rows for the
- events from the database. This is useful when retrying due to
+ txn
+ events_and_contexts: events to persist
+ backfilled: True if the events were backfilled
+ delete_existing True to purge existing table rows for the events
+ from the database. This is useful when retrying due to
IntegrityError.
- state_delta_for_room (dict[str, (list, dict)]):
- The current-state delta for each room. For each room, a tuple
- (to_delete, to_insert), being a list of type/state keys to be
- removed from the current state, and a state set to be added to
- the current state.
- new_forward_extremeties (dict[str, list[str]]):
- The new forward extremities for each room. For each room, a
- list of the event ids which are the forward extremities.
+ state_delta_for_room: The current-state delta for each room.
+ new_forward_extremetie: The new forward extremities for each room.
+ For each room, a list of the event ids which are the forward
+ extremities.
"""
all_events_and_contexts = events_and_contexts
@@ -868,8 +382,6 @@ class EventsStore(
min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
-
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremeties,
@@ -912,7 +424,7 @@ class EventsStore(
# event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@@ -943,88 +455,135 @@ class EventsStore(
backfilled=backfilled,
)
- def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
- for room_id, current_state_tuple in iteritems(state_delta_by_room):
- to_delete, to_insert = current_state_tuple
-
- # First we add entries to the current_state_delta_stream. We
- # do this before updating the current_state_events table so
- # that we can use it to calculate the `prev_event_id`. (This
- # allows us to not have to pull out the existing state
- # unnecessarily).
- #
- # The stream_id for the update is chosen to be the minimum of the stream_ids
- # for the batch of the events that we are persisting; that means we do not
- # end up in a situation where workers see events before the
- # current_state_delta updates.
- #
- sql = """
- INSERT INTO current_state_delta_stream
- (stream_id, room_id, type, state_key, event_id, prev_event_id)
- SELECT ?, ?, ?, ?, ?, (
- SELECT event_id FROM current_state_events
- WHERE room_id = ? AND type = ? AND state_key = ?
+ # We call this last as it assumes we've inserted the events into
+ # room_memberships, where applicable.
+ self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+
+ def _update_current_state_txn(
+ self,
+ txn: LoggingTransaction,
+ state_delta_by_room: Dict[str, DeltaState],
+ stream_id: int,
+ ):
+ for room_id, delta_state in iteritems(state_delta_by_room):
+ to_delete = delta_state.to_delete
+ to_insert = delta_state.to_insert
+
+ if delta_state.no_longer_in_room:
+ # Server is no longer in the room so we delete the room from
+ # current_state_events, being careful we've already updated the
+ # rooms.room_version column (which gets populated in a
+ # background task).
+ self._upsert_room_version_txn(txn, room_id)
+
+ # Before deleting we populate the current_state_delta_stream
+ # so that async background tasks get told what happened.
+ sql = """
+ INSERT INTO current_state_delta_stream
+ (stream_id, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, room_id, type, state_key, null, event_id
+ FROM current_state_events
+ WHERE room_id = ?
+ """
+ txn.execute(sql, (stream_id, room_id))
+
+ self.db.simple_delete_txn(
+ txn, table="current_state_events", keyvalues={"room_id": room_id},
)
- """
- txn.executemany(
- sql,
- (
- (
- stream_id,
- room_id,
- etype,
- state_key,
- None,
- room_id,
- etype,
- state_key,
+ else:
+ # We're still in the room, so we update the current state as normal.
+
+ # First we add entries to the current_state_delta_stream. We
+ # do this before updating the current_state_events table so
+ # that we can use it to calculate the `prev_event_id`. (This
+ # allows us to not have to pull out the existing state
+ # unnecessarily).
+ #
+ # The stream_id for the update is chosen to be the minimum of the stream_ids
+ # for the batch of the events that we are persisting; that means we do not
+ # end up in a situation where workers see events before the
+ # current_state_delta updates.
+ #
+ sql = """
+ INSERT INTO current_state_delta_stream
+ (stream_id, room_id, type, state_key, event_id, prev_event_id)
+ SELECT ?, ?, ?, ?, ?, (
+ SELECT event_id FROM current_state_events
+ WHERE room_id = ? AND type = ? AND state_key = ?
)
- for etype, state_key in to_delete
- # We sanity check that we're deleting rather than updating
- if (etype, state_key) not in to_insert
- ),
- )
- txn.executemany(
- sql,
- (
+ """
+ txn.executemany(
+ sql,
(
- stream_id,
- room_id,
- etype,
- state_key,
- ev_id,
- room_id,
- etype,
- state_key,
- )
- for (etype, state_key), ev_id in iteritems(to_insert)
- ),
- )
+ (
+ stream_id,
+ room_id,
+ etype,
+ state_key,
+ to_insert.get((etype, state_key)),
+ room_id,
+ etype,
+ state_key,
+ )
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
+ )
+ # Now we actually update the current_state_events table
- # Now we actually update the current_state_events table
+ txn.executemany(
+ "DELETE FROM current_state_events"
+ " WHERE room_id = ? AND type = ? AND state_key = ?",
+ (
+ (room_id, etype, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ ),
+ )
- txn.executemany(
- "DELETE FROM current_state_events"
- " WHERE room_id = ? AND type = ? AND state_key = ?",
- (
- (room_id, etype, state_key)
- for etype, state_key in itertools.chain(to_delete, to_insert)
- ),
- )
+ # We include the membership in the current state table, hence we do
+ # a lookup when we insert. This assumes that all events have already
+ # been inserted into room_memberships.
+ txn.executemany(
+ """INSERT INTO current_state_events
+ (room_id, type, state_key, event_id, membership)
+ VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+ """,
+ [
+ (room_id, key[0], key[1], ev_id, ev_id)
+ for key, ev_id in iteritems(to_insert)
+ ],
+ )
- self._simple_insert_many_txn(
- txn,
- table="current_state_events",
- values=[
- {
- "event_id": ev_id,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- }
- for key, ev_id in iteritems(to_insert)
- ],
- )
+ # We now update `local_current_membership`. We do this regardless
+ # of whether we're still in the room or not to handle the case where
+ # e.g. we just got banned (where we need to record that fact here).
+
+ # Note: Do we really want to delete rows here (that we do not
+ # subsequently reinsert below)? While technically correct it means
+ # we have no record of the fact the user *was* a member of the
+ # room but got, say, state reset out of it.
+ if to_delete or to_insert:
+ txn.executemany(
+ "DELETE FROM local_current_membership"
+ " WHERE room_id = ? AND user_id = ?",
+ (
+ (room_id, state_key)
+ for etype, state_key in itertools.chain(to_delete, to_insert)
+ if etype == EventTypes.Member and self.is_mine_id(state_key)
+ ),
+ )
+
+ if to_insert:
+ txn.executemany(
+ """INSERT INTO local_current_membership
+ (room_id, user_id, event_id, membership)
+ VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
+ """,
+ [
+ (room_id, key[1], ev_id, ev_id)
+ for key, ev_id in to_insert.items()
+ if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+ ],
+ )
txn.call_after(
self._curr_state_delta_stream_cache.entity_has_changed,
@@ -1039,11 +598,11 @@ class EventsStore(
# We find out which membership events we may have deleted
# and which we have added, then we invlidate the caches for all
# those users.
- members_changed = set(
+ members_changed = {
state_key
for ev_type, state_key in itertools.chain(to_delete, to_insert)
if ev_type == EventTypes.Member
- )
+ }
for member in members_changed:
txn.call_after(
@@ -1052,16 +611,45 @@ class EventsStore(
self._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+ def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
+ """Update the room version in the database based off current state
+ events.
+
+ This is used when we're about to delete current state and we want to
+ ensure that the `rooms.room_version` column is up to date.
+ """
+
+ sql = """
+ SELECT json FROM event_json
+ INNER JOIN current_state_events USING (room_id, event_id)
+ WHERE room_id = ? AND type = ? AND state_key = ?
+ """
+ txn.execute(sql, (room_id, EventTypes.Create, ""))
+ row = txn.fetchone()
+ if row:
+ event_json = json.loads(row[0])
+ content = event_json.get("content", {})
+ creator = content.get("creator")
+ room_version_id = content.get("room_version", RoomVersions.V1.identifier)
+
+ self.db.simple_upsert_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={"room_version": room_version_id},
+ insertion_values={"is_public": False, "creator": creator},
+ )
+
def _update_forward_extremities_txn(
self, txn, new_forward_extremities, max_stream_order
):
for room_id, new_extrem in iteritems(new_forward_extremities):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
@@ -1074,7 +662,7 @@ class EventsStore(
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
@@ -1191,16 +779,14 @@ class EventsStore(
metadata_json = encode_json(event.internal_metadata.get_dict())
- sql = (
- "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?"
- )
+ sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="ex_outlier_stream",
values={
@@ -1210,7 +796,7 @@ class EventsStore(
},
)
- sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?"
+ sql = "UPDATE events SET outlier = ? WHERE event_id = ?"
txn.execute(sql, (False, event.event_id))
# Update the event_backward_extremities table now that this
@@ -1236,15 +822,11 @@ class EventsStore(
"event_reference_hashes",
"event_search",
"event_to_state_groups",
- "guest_access",
- "history_visibility",
"local_invites",
- "room_names",
"state_events",
"rejections",
"redactions",
"room_memberships",
- "topics",
):
txn.executemany(
"DELETE FROM %s WHERE event_id = ?" % (table,),
@@ -1276,7 +858,7 @@ class EventsStore(
d.pop("redacted_because", None)
return d
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="event_json",
values=[
@@ -1293,7 +875,7 @@ class EventsStore(
],
)
- self._simple_insert_many_txn(
+ self.db.simple_insert_many_txn(
txn,
table="events",
values=[
@@ -1318,6 +900,18 @@ class EventsStore(
],
)
+ for event, _ in events_and_contexts:
+ if not event.internal_metadata.is_redacted():
+ # If we're persisting an unredacted event we go and ensure
+ # that we mark any redactions that reference this event as
+ # requiring censoring.
+ self.db.simple_update_txn(
+ txn,
+ table="redactions",
+ keyvalues={"redacts": event.event_id},
+ updatevalues={"have_censored": False},
+ )
+
def _store_rejected_events_txn(self, txn, events_and_contexts):
"""Add rows to the 'rejections' table for received events which were
rejected
@@ -1388,29 +982,36 @@ class EventsStore(
for event, _ in events_and_contexts:
if event.type == EventTypes.Name:
- # Insert into the room_names and event_search tables.
+ # Insert into the event_search table.
self._store_room_name_txn(txn, event)
elif event.type == EventTypes.Topic:
- # Insert into the topics table and event_search table.
+ # Insert into the event_search table.
self._store_room_topic_txn(txn, event)
elif event.type == EventTypes.Message:
# Insert into the event_search table.
self._store_room_message_txn(txn, event)
- elif event.type == EventTypes.Redaction:
+ elif event.type == EventTypes.Redaction and event.redacts is not None:
# Insert into the redactions table.
self._store_redaction(txn, event)
- elif event.type == EventTypes.RoomHistoryVisibility:
- # Insert into the event_search table.
- self._store_history_visibility_txn(txn, event)
- elif event.type == EventTypes.GuestAccess:
- # Insert into the event_search table.
- self._store_guest_access_txn(txn, event)
elif event.type == EventTypes.Retention:
# Update the room_retention table.
self._store_retention_policy_for_room_txn(txn, event)
self._handle_event_relations(txn, event)
+ # Store the labels for this event.
+ labels = event.content.get(EventContentFields.LABELS)
+ if labels:
+ self.insert_labels_for_event_txn(
+ txn, event.event_id, labels, event.room_id, event.depth
+ )
+
+ if self._ephemeral_messages_enabled:
+ # If there's an expiry timestamp on the event, store it.
+ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+ if isinstance(expiry_ts, int) and not event.is_state():
+ self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
+
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@@ -1446,7 +1047,7 @@ class EventsStore(
state_values.append(vals)
- self._simple_insert_many_txn(txn, table="state_events", values=state_values)
+ self.db.simple_insert_many_txn(txn, table="state_events", values=state_values)
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@@ -1469,11 +1070,15 @@ class EventsStore(
" FROM events as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"] * len(ev_map)),)
+ " WHERE "
+ )
- txn.execute(sql, list(ev_map))
- rows = self.cursor_to_dict(txn)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "e.event_id", list(ev_map)
+ )
+
+ txn.execute(sql + clause, args)
+ rows = self.db.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
if not row["rejects"] and not row["redacts"]:
@@ -1490,9 +1095,118 @@ class EventsStore(
def _store_redaction(self, txn, event):
# invalidate the cache for the redacted event
txn.call_after(self._invalidate_get_event_cache, event.redacts)
- txn.execute(
- "INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
- (event.event_id, event.redacts),
+
+ self.db.simple_insert_txn(
+ txn,
+ table="redactions",
+ values={
+ "event_id": event.event_id,
+ "redacts": event.redacts,
+ "received_ts": self._clock.time_msec(),
+ },
+ )
+
+ async def _censor_redactions(self):
+ """Censors all redactions older than the configured period that haven't
+ been censored yet.
+
+ By censor we mean update the event_json table with the redacted event.
+ """
+
+ if self.hs.config.redaction_retention_period is None:
+ return
+
+ if not (
+ await self.db.updates.has_completed_background_update(
+ "redactions_have_censored_ts_idx"
+ )
+ ):
+ # We don't want to run this until the appropriate index has been
+ # created.
+ return
+
+ before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period
+
+ # We fetch all redactions that:
+ # 1. point to an event we have,
+ # 2. has a received_ts from before the cut off, and
+ # 3. we haven't yet censored.
+ #
+ # This is limited to 100 events to ensure that we don't try and do too
+ # much at once. We'll get called again so this should eventually catch
+ # up.
+ sql = """
+ SELECT redactions.event_id, redacts FROM redactions
+ LEFT JOIN events AS original_event ON (
+ redacts = original_event.event_id
+ )
+ WHERE NOT have_censored
+ AND redactions.received_ts <= ?
+ ORDER BY redactions.received_ts ASC
+ LIMIT ?
+ """
+
+ rows = await self.db.execute(
+ "_censor_redactions_fetch", None, sql, before_ts, 100
+ )
+
+ updates = []
+
+ for redaction_id, event_id in rows:
+ redaction_event = await self.get_event(redaction_id, allow_none=True)
+ original_event = await self.get_event(
+ event_id, allow_rejected=True, allow_none=True
+ )
+
+ # The SQL above ensures that we have both the redaction and
+ # original event, so if the `get_event` calls return None it
+ # means that the redaction wasn't allowed. Either way we know that
+ # the result won't change so we mark the fact that we've checked.
+ if (
+ redaction_event
+ and original_event
+ and original_event.internal_metadata.is_redacted()
+ ):
+ # Redaction was allowed
+ pruned_json = encode_json(
+ prune_event_dict(
+ original_event.room_version, original_event.get_dict()
+ )
+ )
+ else:
+ # Redaction wasn't allowed
+ pruned_json = None
+
+ updates.append((redaction_id, event_id, pruned_json))
+
+ def _update_censor_txn(txn):
+ for redaction_id, event_id, pruned_json in updates:
+ if pruned_json:
+ self._censor_event_txn(txn, event_id, pruned_json)
+
+ self.db.simple_update_one_txn(
+ txn,
+ table="redactions",
+ keyvalues={"event_id": redaction_id},
+ updatevalues={"have_censored": True},
+ )
+
+ await self.db.runInteraction("_update_censor_txn", _update_censor_txn)
+
+ def _censor_event_txn(self, txn, event_id, pruned_json):
+ """Censor an event by replacing its JSON in the event_json table with the
+ provided pruned JSON.
+
+ Args:
+ txn (LoggingTransaction): The database transaction.
+ event_id (str): The ID of the event to censor.
+ pruned_json (str): The pruned JSON
+ """
+ self.db.simple_update_one_txn(
+ txn,
+ table="event_json",
+ keyvalues={"event_id": event_id},
+ updatevalues={"json": pruned_json},
)
@defer.inlineCallbacks
@@ -1511,11 +1225,11 @@ class EventsStore(
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_messages", _count_messages)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("count_messages", _count_messages)
+ return ret
@defer.inlineCallbacks
def count_daily_sent_messages(self):
@@ -1532,11 +1246,11 @@ class EventsStore(
"""
txn.execute(sql, (like_clause, self.stream_ordering_day_ago))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_daily_sent_messages", _count_messages)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("count_daily_sent_messages", _count_messages)
+ return ret
@defer.inlineCallbacks
def count_daily_active_rooms(self):
@@ -1547,11 +1261,11 @@ class EventsStore(
AND stream_ordering > ?
"""
txn.execute(sql, (self.stream_ordering_day_ago,))
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_daily_active_rooms", _count)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("count_daily_active_rooms", _count)
+ return ret
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
@@ -1602,7 +1316,7 @@ class EventsStore(
return new_event_updates
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
)
@@ -1647,7 +1361,7 @@ class EventsStore(
return new_event_updates
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
)
@@ -1740,7 +1454,7 @@ class EventsStore(
backward_ex_outliers,
)
- return self.runInteraction("get_all_new_events", get_all_new_events_txn)
+ return self.db.runInteraction("get_all_new_events", get_all_new_events_txn)
def purge_history(self, room_id, token, delete_local_events):
"""Deletes room history before a certain point
@@ -1754,9 +1468,13 @@ class EventsStore(
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
+
+ Returns:
+ Deferred[set[int]]: The set of state groups that are referenced by
+ deleted events.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
@@ -1854,7 +1572,7 @@ class EventsStore(
# We do joins against events_to_purge for e.g. calculating state
# groups to purge, etc., so lets make an index.
- txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)")
+ txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
txn.execute("SELECT event_id, should_delete FROM events_to_purge")
event_rows = txn.fetchall()
@@ -1890,11 +1608,10 @@ class EventsStore(
[(room_id, event_id) for event_id, in new_backwards_extrems],
)
- logger.info("[purge] finding redundant state groups")
+ logger.info("[purge] finding state groups referenced by deleted events")
# Get all state groups that are referenced by events that are to be
- # deleted. We then go and check if they are referenced by other events
- # or state groups, and if not we delete them.
+ # deleted.
txn.execute(
"""
SELECT DISTINCT state_group FROM events_to_purge
@@ -1902,65 +1619,11 @@ class EventsStore(
"""
)
- referenced_state_groups = set(sg for sg, in txn)
+ referenced_state_groups = {sg for sg, in txn}
logger.info(
"[purge] found %i referenced state groups", len(referenced_state_groups)
)
- logger.info("[purge] finding state groups that can be deleted")
-
- _ = self._find_unreferenced_groups_during_purge(txn, referenced_state_groups)
- state_groups_to_delete, remaining_state_groups = _
-
- logger.info(
- "[purge] found %i state groups to delete", len(state_groups_to_delete)
- )
-
- logger.info(
- "[purge] de-delta-ing %i remaining state groups",
- len(remaining_state_groups),
- )
-
- # Now we turn the state groups that reference to-be-deleted state
- # groups to non delta versions.
- for sg in remaining_state_groups:
- logger.info("[purge] de-delta-ing remaining state group %s", sg)
- curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
- curr_state = curr_state[sg]
-
- self._simple_delete_txn(
- txn, table="state_groups_state", keyvalues={"state_group": sg}
- )
-
- self._simple_delete_txn(
- txn, table="state_group_edges", keyvalues={"state_group": sg}
- )
-
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": sg,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(curr_state)
- ],
- )
-
- logger.info("[purge] removing redundant state groups")
- txn.executemany(
- "DELETE FROM state_groups_state WHERE state_group = ?",
- ((sg,) for sg in state_groups_to_delete),
- )
- txn.executemany(
- "DELETE FROM state_groups WHERE id = ?",
- ((sg,) for sg in state_groups_to_delete),
- )
-
logger.info("[purge] removing events from event_to_state_groups")
txn.execute(
"DELETE FROM event_to_state_groups "
@@ -2032,7 +1695,7 @@ class EventsStore(
""",
(room_id,),
)
- min_depth, = txn.fetchone()
+ (min_depth,) = txn.fetchone()
logger.info("[purge] updating room_depth to %d", min_depth)
@@ -2047,98 +1710,135 @@ class EventsStore(
logger.info("[purge] done")
- def _find_unreferenced_groups_during_purge(self, txn, state_groups):
- """Used when purging history to figure out which state groups can be
- deleted and which need to be de-delta'ed (due to one of its prev groups
- being scheduled for deletion).
+ return referenced_state_groups
+
+ def purge_room(self, room_id):
+ """Deletes all record of a room
Args:
- txn
- state_groups (set[int]): Set of state groups referenced by events
- that are going to be deleted.
+ room_id (str)
Returns:
- tuple[set[int], set[int]]: The set of state groups that can be
- deleted and the set of state groups that need to be de-delta'ed
+ Deferred[List[int]]: The list of state groups to delete.
"""
- # Graph of state group -> previous group
- graph = {}
-
- # Set of events that we have found to be referenced by events
- referenced_groups = set()
-
- # Set of state groups we've already seen
- state_groups_seen = set(state_groups)
-
- # Set of state groups to handle next.
- next_to_search = set(state_groups)
- while next_to_search:
- # We bound size of groups we're looking up at once, to stop the
- # SQL query getting too big
- if len(next_to_search) < 100:
- current_search = next_to_search
- next_to_search = set()
- else:
- current_search = set(itertools.islice(next_to_search, 100))
- next_to_search -= current_search
- # Check if state groups are referenced
- sql = """
- SELECT DISTINCT state_group FROM event_to_state_groups
- LEFT JOIN events_to_purge AS ep USING (event_id)
- WHERE state_group IN (%s) AND ep.event_id IS NULL
- """ % (
- ",".join("?" for _ in current_search),
- )
- txn.execute(sql, list(current_search))
+ return self.db.runInteraction("purge_room", self._purge_room_txn, room_id)
- referenced = set(sg for sg, in txn)
- referenced_groups |= referenced
+ def _purge_room_txn(self, txn, room_id):
+ # First we fetch all the state groups that should be deleted, before
+ # we delete that information.
+ txn.execute(
+ """
+ SELECT DISTINCT state_group FROM events
+ INNER JOIN event_to_state_groups USING(event_id)
+ WHERE events.room_id = ?
+ """,
+ (room_id,),
+ )
- # We don't continue iterating up the state group graphs for state
- # groups that are referenced.
- current_search -= referenced
+ state_groups = [row[0] for row in txn]
- rows = self._simple_select_many_txn(
- txn,
- table="state_group_edges",
- column="prev_state_group",
- iterable=current_search,
- keyvalues={},
- retcols=("prev_state_group", "state_group"),
+ # Now we delete tables which lack an index on room_id but have one on event_id
+ for table in (
+ "event_auth",
+ "event_edges",
+ "event_push_actions_staging",
+ "event_reference_hashes",
+ "event_relations",
+ "event_to_state_groups",
+ "redactions",
+ "rejections",
+ "state_events",
+ ):
+ logger.info("[purge] removing %s from %s", room_id, table)
+
+ txn.execute(
+ """
+ DELETE FROM %s WHERE event_id IN (
+ SELECT event_id FROM events WHERE room_id=?
+ )
+ """
+ % (table,),
+ (room_id,),
)
- prevs = set(row["state_group"] for row in rows)
- # We don't bother re-handling groups we've already seen
- prevs -= state_groups_seen
- next_to_search |= prevs
- state_groups_seen |= prevs
+ # and finally, the tables with an index on room_id (or no useful index)
+ for table in (
+ "current_state_events",
+ "event_backward_extremities",
+ "event_forward_extremities",
+ "event_json",
+ "event_push_actions",
+ "event_search",
+ "events",
+ "group_rooms",
+ "public_room_list_stream",
+ "receipts_graph",
+ "receipts_linearized",
+ "room_aliases",
+ "room_depth",
+ "room_memberships",
+ "room_stats_state",
+ "room_stats_current",
+ "room_stats_historical",
+ "room_stats_earliest_token",
+ "rooms",
+ "stream_ordering_to_exterm",
+ "users_in_public_rooms",
+ "users_who_share_private_rooms",
+ # no useful index, but let's clear them anyway
+ "appservice_room_list",
+ "e2e_room_keys",
+ "event_push_summary",
+ "pusher_throttle",
+ "group_summary_rooms",
+ "local_invites",
+ "room_account_data",
+ "room_tags",
+ "local_current_membership",
+ ):
+ logger.info("[purge] removing %s from %s", room_id, table)
+ txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
- for row in rows:
- # Note: Each state group can have at most one prev group
- graph[row["state_group"]] = row["prev_state_group"]
+ # Other tables we do NOT need to clear out:
+ #
+ # - blocked_rooms
+ # This is important, to make sure that we don't accidentally rejoin a blocked
+ # room after it was purged
+ #
+ # - user_directory
+ # This has a room_id column, but it is unused
+ #
+
+ # Other tables that we might want to consider clearing out include:
+ #
+ # - event_reports
+ # Given that these are intended for abuse management my initial
+ # inclination is to leave them in place.
+ #
+ # - current_state_delta_stream
+ # - ex_outlier_stream
+ # - room_tags_revisions
+ # The problem with these is that they are largeish and there is no room_id
+ # index on them. In any case we should be clearing out 'stream' tables
+ # periodically anyway (#5888)
- to_delete = state_groups_seen - referenced_groups
+ # TODO: we could probably usefully do a bunch of cache invalidation here
- to_dedelta = set()
- for sg in referenced_groups:
- prev_sg = graph.get(sg)
- if prev_sg and prev_sg in to_delete:
- to_dedelta.add(sg)
+ logger.info("[purge] done")
- return to_delete, to_dedelta
+ return state_groups
- @defer.inlineCallbacks
- def is_event_after(self, event_id1, event_id2):
+ async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
- to_1, so_1 = yield self._get_event_ordering(event_id1)
- to_2, so_2 = yield self._get_event_ordering(event_id2)
- defer.returnValue((to_1, so_1) > (to_2, so_2))
+ to_1, so_1 = await self._get_event_ordering(event_id1)
+ to_2, so_2 = await self._get_event_ordering(event_id2)
+ return (to_1, so_1) > (to_2, so_2)
@cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
- res = yield self._simple_select_one(
+ res = yield self.db.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@@ -2148,9 +1848,7 @@ class EventsStore(
if not res:
raise SynapseError(404, "Could not find event %s" % (event_id,))
- defer.returnValue(
- (int(res["topological_ordering"]), int(res["stream_ordering"]))
- )
+ return (int(res["topological_ordering"]), int(res["stream_ordering"]))
def get_all_updated_current_state_deltas(self, from_token, to_token, limit):
def get_all_updated_current_state_deltas_txn(txn):
@@ -2163,11 +1861,135 @@ class EventsStore(
txn.execute(sql, (from_token, to_token, limit))
return txn.fetchall()
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_current_state_deltas",
get_all_updated_current_state_deltas_txn,
)
+ def insert_labels_for_event_txn(
+ self, txn, event_id, labels, room_id, topological_ordering
+ ):
+ """Store the mapping between an event's ID and its labels, with one row per
+ (event_id, label) tuple.
+
+ Args:
+ txn (LoggingTransaction): The transaction to execute.
+ event_id (str): The event's ID.
+ labels (list[str]): A list of text labels.
+ room_id (str): The ID of the room the event was sent to.
+ topological_ordering (int): The position of the event in the room's topology.
+ """
+ return self.db.simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": room_id,
+ "topological_ordering": topological_ordering,
+ }
+ for label in labels
+ ],
+ )
+
+ def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+ """Save the expiry timestamp associated with a given event ID.
+
+ Args:
+ txn (LoggingTransaction): The database transaction to use.
+ event_id (str): The event ID the expiry timestamp is associated with.
+ expiry_ts (int): The timestamp at which to expire (delete) the event.
+ """
+ return self.db.simple_insert_txn(
+ txn=txn,
+ table="event_expiry",
+ values={"event_id": event_id, "expiry_ts": expiry_ts},
+ )
+
+ @defer.inlineCallbacks
+ def expire_event(self, event_id):
+ """Retrieve and expire an event that has expired, and delete its associated
+ expiry timestamp. If the event can't be retrieved, delete its associated
+ timestamp so we don't try to expire it again in the future.
+
+ Args:
+ event_id (str): The ID of the event to delete.
+ """
+ # Try to retrieve the event's content from the database or the event cache.
+ event = yield self.get_event(event_id)
+
+ def delete_expired_event_txn(txn):
+ # Delete the expiry timestamp associated with this event from the database.
+ self._delete_event_expiry_txn(txn, event_id)
+
+ if not event:
+ # If we can't find the event, log a warning and delete the expiry date
+ # from the database so that we don't try to expire it again in the
+ # future.
+ logger.warning(
+ "Can't expire event %s because we don't have it.", event_id
+ )
+ return
+
+ # Prune the event's dict then convert it to JSON.
+ pruned_json = encode_json(
+ prune_event_dict(event.room_version, event.get_dict())
+ )
+
+ # Update the event_json table to replace the event's JSON with the pruned
+ # JSON.
+ self._censor_event_txn(txn, event.event_id, pruned_json)
+
+ # We need to invalidate the event cache entry for this event because we
+ # changed its content in the database. We can't call
+ # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+ # right type.
+ txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+ # Send that invalidation to replication so that other workers also invalidate
+ # the event cache.
+ self._send_invalidation_to_replication(
+ txn, "_get_event_cache", (event.event_id,)
+ )
+
+ yield self.db.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+ def _delete_event_expiry_txn(self, txn, event_id):
+ """Delete the expiry timestamp associated with an event ID without deleting the
+ actual event.
+
+ Args:
+ txn (LoggingTransaction): The transaction to use to perform the deletion.
+ event_id (str): The event ID to delete the associated expiry timestamp of.
+ """
+ return self.db.simple_delete_txn(
+ txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+ )
+
+ def get_next_event_to_expire(self):
+ """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+ table, or None if there's no more event to expire.
+
+ Returns: Deferred[Optional[Tuple[str, int]]]
+ A tuple containing the event ID as its first element and an expiry timestamp
+ as its second one, if there's at least one row in the event_expiry table.
+ None otherwise.
+ """
+
+ def get_next_event_to_expire_txn(txn):
+ txn.execute(
+ """
+ SELECT event_id, expiry_ts FROM event_expiry
+ ORDER BY expiry_ts ASC LIMIT 1
+ """
+ )
+
+ return txn.fetchone()
+
+ return self.db.runInteraction(
+ desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+ )
+
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 75c1935bf3..f54c8b1ee0 100644
--- a/synapse/storage/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -21,29 +21,31 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.api.constants import EventContentFields
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
logger = logging.getLogger(__name__)
-class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
+class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
- def __init__(self, db_conn, hs):
- super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"event_contains_url_index",
index_name="event_contains_url_index",
table="events",
@@ -54,7 +56,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
# an event_id index on event_search is useful for the purge_history
# api. Plus it means we get to enforce some integrity with a UNIQUE
# clause
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"event_search_event_id_idx",
index_name="event_search_event_id_idx",
table="event_search",
@@ -63,9 +65,37 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
psql_only=True,
)
- self.register_background_update_handler(
- self.DELETE_SOFT_FAILED_EXTREMITIES,
- self._cleanup_extremities_bg_update,
+ self.db.updates.register_background_update_handler(
+ self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
+ )
+
+ self.db.updates.register_background_update_handler(
+ "redactions_received_ts", self._redactions_received_ts
+ )
+
+ # This index gets deleted in `event_fix_redactions_bytes` update
+ self.db.updates.register_background_index_update(
+ "event_fix_redactions_bytes_create_index",
+ index_name="redactions_censored_redacts",
+ table="redactions",
+ columns=["redacts"],
+ where_clause="have_censored",
+ )
+
+ self.db.updates.register_background_update_handler(
+ "event_fix_redactions_bytes", self._event_fix_redactions_bytes
+ )
+
+ self.db.updates.register_background_update_handler(
+ "event_store_labels", self._event_store_labels
+ )
+
+ self.db.updates.register_background_index_update(
+ "redactions_have_censored_ts_idx",
+ index_name="redactions_have_censored_ts",
+ table="redactions",
+ columns=["received_ts"],
+ where_clause="NOT have_censored",
)
@defer.inlineCallbacks
@@ -123,20 +153,22 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(rows),
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
- yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+ yield self.db.updates._end_background_update(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
+ )
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _background_reindex_origin_server_ts(self, progress, batch_size):
@@ -167,7 +199,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
- ev_rows = self._simple_select_many_txn(
+ ev_rows = self.db.simple_select_many_txn(
txn,
table="event_json",
column="event_id",
@@ -200,20 +232,22 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(rows_to_update),
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows_to_update)
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
- yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
+ yield self.db.updates._end_background_update(
+ self.EVENT_ORIGIN_SERVER_TS_NAME
+ )
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _cleanup_extremities_bg_update(self, progress, batch_size):
@@ -269,7 +303,8 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
LEFT JOIN events USING (event_id)
LEFT JOIN event_json USING (event_id)
LEFT JOIN rejections USING (event_id)
- """, (batch_size,)
+ """,
+ (batch_size,),
)
for prev_event_id, event_id, metadata, rejected, outlier in txn:
@@ -308,12 +343,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
INNER JOIN event_json USING (event_id)
LEFT JOIN rejections USING (event_id)
WHERE
- prev_event_id IN (%s)
- AND NOT events.outlier
- """ % (
- ",".join("?" for _ in to_check),
+ NOT events.outlier
+ AND
+ """
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "prev_event_id", to_check
)
- txn.execute(sql, to_check)
+ txn.execute(sql + clause, list(args))
for prev_event_id, event_id, metadata, rejected in txn:
if event_id in graph:
@@ -342,7 +378,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
to_delete.intersection_update(original_set)
- deleted = self._simple_delete_many_txn(
+ deleted = self.db.simple_delete_many_txn(
txn=txn,
table="event_forward_extremities",
column="event_id",
@@ -358,22 +394,21 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if deleted:
# We now need to invalidate the caches of these rooms
- rows = self._simple_select_many_txn(
+ rows = self.db.simple_select_many_txn(
txn,
table="events",
column="event_id",
iterable=to_delete,
keyvalues={},
- retcols=("room_id",)
+ retcols=("room_id",),
)
- room_ids = set(row["room_id"] for row in rows)
+ room_ids = {row["room_id"] for row in rows}
for room_id in room_ids:
txn.call_after(
- self.get_latest_event_ids_in_room.invalidate,
- (room_id,)
+ self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
- self._simple_delete_many_txn(
+ self.db.simple_delete_many_txn(
txn=txn,
table="_extremities_to_check",
column="event_id",
@@ -383,19 +418,172 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
return len(original_set)
- num_handled = yield self.runInteraction(
- "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn,
+ num_handled = yield self.db.runInteraction(
+ "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn
)
if not num_handled:
- yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES)
+ yield self.db.updates._end_background_update(
+ self.DELETE_SOFT_FAILED_EXTREMITIES
+ )
def _drop_table_txn(txn):
txn.execute("DROP TABLE _extremities_to_check")
- yield self.runInteraction(
- "_cleanup_extremities_bg_update_drop_table",
- _drop_table_txn,
+ yield self.db.runInteraction(
+ "_cleanup_extremities_bg_update_drop_table", _drop_table_txn
+ )
+
+ return num_handled
+
+ @defer.inlineCallbacks
+ def _redactions_received_ts(self, progress, batch_size):
+ """Handles filling out the `received_ts` column in redactions.
+ """
+ last_event_id = progress.get("last_event_id", "")
+
+ def _redactions_received_ts_txn(txn):
+ # Fetch the set of event IDs that we want to update
+ sql = """
+ SELECT event_id FROM redactions
+ WHERE event_id > ?
+ ORDER BY event_id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_event_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ (upper_event_id,) = rows[-1]
+
+ # Update the redactions with the received_ts.
+ #
+ # Note: Not all events have an associated received_ts, so we
+ # fallback to using origin_server_ts. If we for some reason don't
+ # have an origin_server_ts, lets just use the current timestamp.
+ #
+ # We don't want to leave it null, as then we'll never try and
+ # censor those redactions.
+ sql = """
+ UPDATE redactions
+ SET received_ts = (
+ SELECT COALESCE(received_ts, origin_server_ts, ?) FROM events
+ WHERE events.event_id = redactions.event_id
+ )
+ WHERE ? <= event_id AND event_id <= ?
+ """
+
+ txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id))
+
+ self.db.updates._background_update_progress_txn(
+ txn, "redactions_received_ts", {"last_event_id": upper_event_id}
+ )
+
+ return len(rows)
+
+ count = yield self.db.runInteraction(
+ "_redactions_received_ts", _redactions_received_ts_txn
+ )
+
+ if not count:
+ yield self.db.updates._end_background_update("redactions_received_ts")
+
+ return count
+
+ @defer.inlineCallbacks
+ def _event_fix_redactions_bytes(self, progress, batch_size):
+ """Undoes hex encoded censored redacted event JSON.
+ """
+
+ def _event_fix_redactions_bytes_txn(txn):
+ # This update is quite fast due to new index.
+ txn.execute(
+ """
+ UPDATE event_json
+ SET
+ json = convert_from(json::bytea, 'utf8')
+ FROM redactions
+ WHERE
+ redactions.have_censored
+ AND event_json.event_id = redactions.redacts
+ AND json NOT LIKE '{%';
+ """
)
- defer.returnValue(num_handled)
+ txn.execute("DROP INDEX redactions_censored_redacts")
+
+ yield self.db.runInteraction(
+ "_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
+ )
+
+ yield self.db.updates._end_background_update("event_fix_redactions_bytes")
+
+ return 1
+
+ @defer.inlineCallbacks
+ def _event_store_labels(self, progress, batch_size):
+ """Background update handler which will store labels for existing events."""
+ last_event_id = progress.get("last_event_id", "")
+
+ def _event_store_labels_txn(txn):
+ txn.execute(
+ """
+ SELECT event_id, json FROM event_json
+ LEFT JOIN event_labels USING (event_id)
+ WHERE event_id > ? AND label IS NULL
+ ORDER BY event_id LIMIT ?
+ """,
+ (last_event_id, batch_size),
+ )
+
+ results = list(txn)
+
+ nbrows = 0
+ last_row_event_id = ""
+ for (event_id, event_json_raw) in results:
+ try:
+ event_json = json.loads(event_json_raw)
+
+ self.db.simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": event_json["room_id"],
+ "topological_ordering": event_json["depth"],
+ }
+ for label in event_json["content"].get(
+ EventContentFields.LABELS, []
+ )
+ if isinstance(label, str)
+ ],
+ )
+ except Exception as e:
+ logger.warning(
+ "Unable to load event %s (no labels will be imported): %s",
+ event_id,
+ e,
+ )
+
+ nbrows += 1
+ last_row_event_id = event_id
+
+ self.db.updates._background_update_progress_txn(
+ txn, "event_store_labels", {"last_event_id": last_row_event_id}
+ )
+
+ return nbrows
+
+ num_rows = yield self.db.runInteraction(
+ desc="event_store_labels", func=_event_store_labels_txn
+ )
+
+ if not num_rows:
+ yield self.db.updates._end_background_update("event_store_labels")
+
+ return num_rows
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
new file mode 100644
index 0000000000..ca237c6f12
--- /dev/null
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -0,0 +1,965 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import itertools
+import logging
+import threading
+from collections import namedtuple
+from typing import List, Optional
+
+from canonicaljson import json
+from constantly import NamedConstant, Names
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import NotFoundError
+from synapse.api.room_versions import (
+ KNOWN_ROOM_VERSIONS,
+ EventFormatVersions,
+ RoomVersions,
+)
+from synapse.events import make_event_from_dict
+from synapse.events.utils import prune_event
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
+from synapse.types import get_domain_from_id
+from synapse.util.caches.descriptors import Cache
+from synapse.util.iterutils import batch_iter
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+# These values are used in the `enqueus_event` and `_do_fetch` methods to
+# control how we batch/bulk fetch events from the database.
+# The values are plucked out of thing air to make initial sync run faster
+# on jki.re
+# TODO: Make these configurable.
+EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
+EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
+EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
+
+
+_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
+
+
+class EventRedactBehaviour(Names):
+ """
+ What to do when retrieving a redacted event from the database.
+ """
+
+ AS_IS = NamedConstant()
+ REDACT = NamedConstant()
+ BLOCK = NamedConstant()
+
+
+class EventsWorkerStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(EventsWorkerStore, self).__init__(database, db_conn, hs)
+
+ self._get_event_cache = Cache(
+ "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+ )
+
+ self._event_fetch_lock = threading.Condition()
+ self._event_fetch_list = []
+ self._event_fetch_ongoing = 0
+
+ def get_received_ts(self, event_id):
+ """Get received_ts (when it was persisted) for the event.
+
+ Raises an exception for unknown events.
+
+ Args:
+ event_id (str)
+
+ Returns:
+ Deferred[int|None]: Timestamp in milliseconds, or None for events
+ that were persisted before received_ts was implemented.
+ """
+ return self.db.simple_select_one_onecol(
+ table="events",
+ keyvalues={"event_id": event_id},
+ retcol="received_ts",
+ desc="get_received_ts",
+ )
+
+ def get_received_ts_by_stream_pos(self, stream_ordering):
+ """Given a stream ordering get an approximate timestamp of when it
+ happened.
+
+ This is done by simply taking the received ts of the first event that
+ has a stream ordering greater than or equal to the given stream pos.
+ If none exists returns the current time, on the assumption that it must
+ have happened recently.
+
+ Args:
+ stream_ordering (int)
+
+ Returns:
+ Deferred[int]
+ """
+
+ def _get_approximate_received_ts_txn(txn):
+ sql = """
+ SELECT received_ts FROM events
+ WHERE stream_ordering >= ?
+ LIMIT 1
+ """
+
+ txn.execute(sql, (stream_ordering,))
+ row = txn.fetchone()
+ if row and row[0]:
+ ts = row[0]
+ else:
+ ts = self.clock.time_msec()
+
+ return ts
+
+ return self.db.runInteraction(
+ "get_approximate_received_ts", _get_approximate_received_ts_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_event(
+ self,
+ event_id: str,
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ allow_none: bool = False,
+ check_room_id: Optional[str] = None,
+ ):
+ """Get an event from the database by event_id.
+
+ Args:
+ event_id: The event_id of the event to fetch
+
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events (behave as per allow_none
+ if the event is redacted)
+
+ get_prev_content: If True and event is a state event,
+ include the previous states content in the unsigned field.
+
+ allow_rejected: If True, return rejected events. Otherwise,
+ behave as per allow_none.
+
+ allow_none: If True, return None if no event found, if
+ False throw a NotFoundError
+
+ check_room_id: if not None, check the room of the found event.
+ If there is a mismatch, behave as per allow_none.
+
+ Returns:
+ Deferred[EventBase|None]
+ """
+ if not isinstance(event_id, str):
+ raise TypeError("Invalid event event_id %r" % (event_id,))
+
+ events = yield self.get_events_as_list(
+ [event_id],
+ redact_behaviour=redact_behaviour,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ event = events[0] if events else None
+
+ if event is not None and check_room_id is not None:
+ if event.room_id != check_room_id:
+ event = None
+
+ if event is None and not allow_none:
+ raise NotFoundError("Could not find event %s" % (event_id,))
+
+ return event
+
+ @defer.inlineCallbacks
+ def get_events(
+ self,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ ):
+ """Get events from the database
+
+ Args:
+ event_ids: The event_ids of the events to fetch
+
+ redact_behaviour: Determine what to do with a redacted event. Possible
+ values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events (omit them from the response)
+
+ get_prev_content: If True and event is a state event,
+ include the previous states content in the unsigned field.
+
+ allow_rejected: If True, return rejected events. Otherwise,
+ omits rejeted events from the response.
+
+ Returns:
+ Deferred : Dict from event_id to event.
+ """
+ events = yield self.get_events_as_list(
+ event_ids,
+ redact_behaviour=redact_behaviour,
+ get_prev_content=get_prev_content,
+ allow_rejected=allow_rejected,
+ )
+
+ return {e.event_id: e for e in events}
+
+ @defer.inlineCallbacks
+ def get_events_as_list(
+ self,
+ event_ids: List[str],
+ redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+ get_prev_content: bool = False,
+ allow_rejected: bool = False,
+ ):
+ """Get events from the database and return in a list in the same order
+ as given by `event_ids` arg.
+
+ Unknown events will be omitted from the response.
+
+ Args:
+ event_ids: The event_ids of the events to fetch
+
+ redact_behaviour: Determine what to do with a redacted event. Possible values:
+ * AS_IS - Return the full event body with no redacted content
+ * REDACT - Return the event but with a redacted body
+ * DISALLOW - Do not return redacted events (omit them from the response)
+
+ get_prev_content: If True and event is a state event,
+ include the previous states content in the unsigned field.
+
+ allow_rejected: If True, return rejected events. Otherwise,
+ omits rejected events from the response.
+
+ Returns:
+ Deferred[list[EventBase]]: List of events fetched from the database. The
+ events are in the same order as `event_ids` arg.
+
+ Note that the returned list may be smaller than the list of event
+ IDs if not all events could be fetched.
+ """
+
+ if not event_ids:
+ return []
+
+ # there may be duplicates so we cast the list to a set
+ event_entry_map = yield self._get_events_from_cache_or_db(
+ set(event_ids), allow_rejected=allow_rejected
+ )
+
+ events = []
+ for event_id in event_ids:
+ entry = event_entry_map.get(event_id, None)
+ if not entry:
+ continue
+
+ if not allow_rejected:
+ assert not entry.event.rejected_reason, (
+ "rejected event returned from _get_events_from_cache_or_db despite "
+ "allow_rejected=False"
+ )
+
+ # We may not have had the original event when we received a redaction, so
+ # we have to recheck auth now.
+
+ if not allow_rejected and entry.event.type == EventTypes.Redaction:
+ if entry.event.redacts is None:
+ # A redacted redaction doesn't have a `redacts` key, in
+ # which case lets just withhold the event.
+ #
+ # Note: Most of the time if the redactions has been
+ # redacted we still have the un-redacted event in the DB
+ # and so we'll still see the `redacts` key. However, this
+ # isn't always true e.g. if we have censored the event.
+ logger.debug(
+ "Withholding redaction event %s as we don't have redacts key",
+ event_id,
+ )
+ continue
+
+ redacted_event_id = entry.event.redacts
+ event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
+ original_event_entry = event_map.get(redacted_event_id)
+ if not original_event_entry:
+ # we don't have the redacted event (or it was rejected).
+ #
+ # We assume that the redaction isn't authorized for now; if the
+ # redacted event later turns up, the redaction will be re-checked,
+ # and if it is found valid, the original will get redacted before it
+ # is served to the client.
+ logger.debug(
+ "Withholding redaction event %s since we don't (yet) have the "
+ "original %s",
+ event_id,
+ redacted_event_id,
+ )
+ continue
+
+ original_event = original_event_entry.event
+ if original_event.type == EventTypes.Create:
+ # we never serve redactions of Creates to clients.
+ logger.info(
+ "Withholding redaction %s of create event %s",
+ event_id,
+ redacted_event_id,
+ )
+ continue
+
+ if original_event.room_id != entry.event.room_id:
+ logger.info(
+ "Withholding redaction %s of event %s from a different room",
+ event_id,
+ redacted_event_id,
+ )
+ continue
+
+ if entry.event.internal_metadata.need_to_check_redaction():
+ original_domain = get_domain_from_id(original_event.sender)
+ redaction_domain = get_domain_from_id(entry.event.sender)
+ if original_domain != redaction_domain:
+ # the senders don't match, so this is forbidden
+ logger.info(
+ "Withholding redaction %s whose sender domain %s doesn't "
+ "match that of redacted event %s %s",
+ event_id,
+ redaction_domain,
+ redacted_event_id,
+ original_domain,
+ )
+ continue
+
+ # Update the cache to save doing the checks again.
+ entry.event.internal_metadata.recheck_redaction = False
+
+ event = entry.event
+
+ if entry.redacted_event:
+ if redact_behaviour == EventRedactBehaviour.BLOCK:
+ # Skip this event
+ continue
+ elif redact_behaviour == EventRedactBehaviour.REDACT:
+ event = entry.redacted_event
+
+ events.append(event)
+
+ if get_prev_content:
+ if "replaces_state" in event.unsigned:
+ prev = yield self.get_event(
+ event.unsigned["replaces_state"],
+ get_prev_content=False,
+ allow_none=True,
+ )
+ if prev:
+ event.unsigned = dict(event.unsigned)
+ event.unsigned["prev_content"] = prev.content
+ event.unsigned["prev_sender"] = prev.sender
+
+ return events
+
+ @defer.inlineCallbacks
+ def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
+ """Fetch a bunch of events from the cache or the database.
+
+ If events are pulled from the database, they will be cached for future lookups.
+
+ Unknown events are omitted from the response.
+
+ Args:
+
+ event_ids (Iterable[str]): The event_ids of the events to fetch
+
+ allow_rejected (bool): Whether to include rejected events. If False,
+ rejected events are omitted from the response.
+
+ Returns:
+ Deferred[Dict[str, _EventCacheEntry]]:
+ map from event id to result
+ """
+ event_entry_map = self._get_events_from_cache(
+ event_ids, allow_rejected=allow_rejected
+ )
+
+ missing_events_ids = [e for e in event_ids if e not in event_entry_map]
+
+ if missing_events_ids:
+ log_ctx = LoggingContext.current_context()
+ log_ctx.record_event_fetch(len(missing_events_ids))
+
+ # Note that _get_events_from_db is also responsible for turning db rows
+ # into FrozenEvents (via _get_event_from_row), which involves seeing if
+ # the events have been redacted, and if so pulling the redaction event out
+ # of the database to check it.
+ #
+ missing_events = yield self._get_events_from_db(
+ missing_events_ids, allow_rejected=allow_rejected
+ )
+
+ event_entry_map.update(missing_events)
+
+ return event_entry_map
+
+ def _invalidate_get_event_cache(self, event_id):
+ self._get_event_cache.invalidate((event_id,))
+
+ def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
+ """Fetch events from the caches
+
+ Args:
+ events (Iterable[str]): list of event_ids to fetch
+ allow_rejected (bool): Whether to return events that were rejected
+ update_metrics (bool): Whether to update the cache hit ratio metrics
+
+ Returns:
+ dict of event_id -> _EventCacheEntry for each event_id in cache. If
+ allow_rejected is `False` then there will still be an entry but it
+ will be `None`
+ """
+ event_map = {}
+
+ for event_id in events:
+ ret = self._get_event_cache.get(
+ (event_id,), None, update_metrics=update_metrics
+ )
+ if not ret:
+ continue
+
+ if allow_rejected or not ret.event.rejected_reason:
+ event_map[event_id] = ret
+ else:
+ event_map[event_id] = None
+
+ return event_map
+
+ def _do_fetch(self, conn):
+ """Takes a database connection and waits for requests for events from
+ the _event_fetch_list queue.
+ """
+ i = 0
+ while True:
+ with self._event_fetch_lock:
+ event_list = self._event_fetch_list
+ self._event_fetch_list = []
+
+ if not event_list:
+ single_threaded = self.database_engine.single_threaded
+ if single_threaded or i > EVENT_QUEUE_ITERATIONS:
+ self._event_fetch_ongoing -= 1
+ return
+ else:
+ self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
+ i += 1
+ continue
+ i = 0
+
+ self._fetch_event_list(conn, event_list)
+
+ def _fetch_event_list(self, conn, event_list):
+ """Handle a load of requests from the _event_fetch_list queue
+
+ Args:
+ conn (twisted.enterprise.adbapi.Connection): database connection
+
+ event_list (list[Tuple[list[str], Deferred]]):
+ The fetch requests. Each entry consists of a list of event
+ ids to be fetched, and a deferred to be completed once the
+ events have been fetched.
+
+ The deferreds are callbacked with a dictionary mapping from event id
+ to event row. Note that it may well contain additional events that
+ were not part of this request.
+ """
+ with Measure(self._clock, "_fetch_event_list"):
+ try:
+ events_to_fetch = {
+ event_id for events, _ in event_list for event_id in events
+ }
+
+ row_dict = self.db.new_transaction(
+ conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
+ )
+
+ # We only want to resolve deferreds from the main thread
+ def fire():
+ for _, d in event_list:
+ d.callback(row_dict)
+
+ with PreserveLoggingContext():
+ self.hs.get_reactor().callFromThread(fire)
+ except Exception as e:
+ logger.exception("do_fetch")
+
+ # We only want to resolve deferreds from the main thread
+ def fire(evs, exc):
+ for _, d in evs:
+ if not d.called:
+ with PreserveLoggingContext():
+ d.errback(exc)
+
+ with PreserveLoggingContext():
+ self.hs.get_reactor().callFromThread(fire, event_list, e)
+
+ @defer.inlineCallbacks
+ def _get_events_from_db(self, event_ids, allow_rejected=False):
+ """Fetch a bunch of events from the database.
+
+ Returned events will be added to the cache for future lookups.
+
+ Unknown events are omitted from the response.
+
+ Args:
+ event_ids (Iterable[str]): The event_ids of the events to fetch
+
+ allow_rejected (bool): Whether to include rejected events. If False,
+ rejected events are omitted from the response.
+
+ Returns:
+ Deferred[Dict[str, _EventCacheEntry]]:
+ map from event id to result. May return extra events which
+ weren't asked for.
+ """
+ fetched_events = {}
+ events_to_fetch = event_ids
+
+ while events_to_fetch:
+ row_map = yield self._enqueue_events(events_to_fetch)
+
+ # we need to recursively fetch any redactions of those events
+ redaction_ids = set()
+ for event_id in events_to_fetch:
+ row = row_map.get(event_id)
+ fetched_events[event_id] = row
+ if row:
+ redaction_ids.update(row["redactions"])
+
+ events_to_fetch = redaction_ids.difference(fetched_events.keys())
+ if events_to_fetch:
+ logger.debug("Also fetching redaction events %s", events_to_fetch)
+
+ # build a map from event_id to EventBase
+ event_map = {}
+ for event_id, row in fetched_events.items():
+ if not row:
+ continue
+ assert row["event_id"] == event_id
+
+ rejected_reason = row["rejected_reason"]
+
+ if not allow_rejected and rejected_reason:
+ continue
+
+ d = json.loads(row["json"])
+ internal_metadata = json.loads(row["internal_metadata"])
+
+ format_version = row["format_version"]
+ if format_version is None:
+ # This means that we stored the event before we had the concept
+ # of a event format version, so it must be a V1 event.
+ format_version = EventFormatVersions.V1
+
+ room_version_id = row["room_version_id"]
+
+ if not room_version_id:
+ # this should only happen for out-of-band membership events
+ if not internal_metadata.get("out_of_band_membership"):
+ logger.warning(
+ "Room %s for event %s is unknown", d["room_id"], event_id
+ )
+ continue
+
+ # take a wild stab at the room version based on the event format
+ if format_version == EventFormatVersions.V1:
+ room_version = RoomVersions.V1
+ elif format_version == EventFormatVersions.V2:
+ room_version = RoomVersions.V3
+ else:
+ room_version = RoomVersions.V5
+ else:
+ room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
+ if not room_version:
+ logger.error(
+ "Event %s in room %s has unknown room version %s",
+ event_id,
+ d["room_id"],
+ room_version_id,
+ )
+ continue
+
+ if room_version.event_format != format_version:
+ logger.error(
+ "Event %s in room %s with version %s has wrong format: "
+ "expected %s, was %s",
+ event_id,
+ d["room_id"],
+ room_version_id,
+ room_version.event_format,
+ format_version,
+ )
+ continue
+
+ original_ev = make_event_from_dict(
+ event_dict=d,
+ room_version=room_version,
+ internal_metadata_dict=internal_metadata,
+ rejected_reason=rejected_reason,
+ )
+
+ event_map[event_id] = original_ev
+
+ # finally, we can decide whether each one nededs redacting, and build
+ # the cache entries.
+ result_map = {}
+ for event_id, original_ev in event_map.items():
+ redactions = fetched_events[event_id]["redactions"]
+ redacted_event = self._maybe_redact_event_row(
+ original_ev, redactions, event_map
+ )
+
+ cache_entry = _EventCacheEntry(
+ event=original_ev, redacted_event=redacted_event
+ )
+
+ self._get_event_cache.prefill((event_id,), cache_entry)
+ result_map[event_id] = cache_entry
+
+ return result_map
+
+ @defer.inlineCallbacks
+ def _enqueue_events(self, events):
+ """Fetches events from the database using the _event_fetch_list. This
+ allows batch and bulk fetching of events - it allows us to fetch events
+ without having to create a new transaction for each request for events.
+
+ Args:
+ events (Iterable[str]): events to be fetched.
+
+ Returns:
+ Deferred[Dict[str, Dict]]: map from event id to row data from the database.
+ May contain events that weren't requested.
+ """
+
+ events_d = defer.Deferred()
+ with self._event_fetch_lock:
+ self._event_fetch_list.append((events, events_d))
+
+ self._event_fetch_lock.notify()
+
+ if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
+ self._event_fetch_ongoing += 1
+ should_start = True
+ else:
+ should_start = False
+
+ if should_start:
+ run_as_background_process(
+ "fetch_events", self.db.runWithConnection, self._do_fetch
+ )
+
+ logger.debug("Loading %d events: %s", len(events), events)
+ with PreserveLoggingContext():
+ row_map = yield events_d
+ logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
+
+ return row_map
+
+ def _fetch_event_rows(self, txn, event_ids):
+ """Fetch event rows from the database
+
+ Events which are not found are omitted from the result.
+
+ The returned per-event dicts contain the following keys:
+
+ * event_id (str)
+
+ * json (str): json-encoded event structure
+
+ * internal_metadata (str): json-encoded internal metadata dict
+
+ * format_version (int|None): The format of the event. Hopefully one
+ of EventFormatVersions. 'None' means the event predates
+ EventFormatVersions (so the event is format V1).
+
+ * room_version_id (str|None): The version of the room which contains the event.
+ Hopefully one of RoomVersions.
+
+ Due to historical reasons, there may be a few events in the database which
+ do not have an associated room; in this case None will be returned here.
+
+ * rejected_reason (str|None): if the event was rejected, the reason
+ why.
+
+ * redactions (List[str]): a list of event-ids which (claim to) redact
+ this event.
+
+ Args:
+ txn (twisted.enterprise.adbapi.Connection):
+ event_ids (Iterable[str]): event IDs to fetch
+
+ Returns:
+ Dict[str, Dict]: a map from event id to event info.
+ """
+ event_dict = {}
+ for evs in batch_iter(event_ids, 200):
+ sql = """\
+ SELECT
+ e.event_id,
+ e.internal_metadata,
+ e.json,
+ e.format_version,
+ r.room_version,
+ rej.reason
+ FROM event_json as e
+ LEFT JOIN rooms r USING (room_id)
+ LEFT JOIN rejections as rej USING (event_id)
+ WHERE """
+
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "e.event_id", evs
+ )
+
+ txn.execute(sql + clause, args)
+
+ for row in txn:
+ event_id = row[0]
+ event_dict[event_id] = {
+ "event_id": event_id,
+ "internal_metadata": row[1],
+ "json": row[2],
+ "format_version": row[3],
+ "room_version_id": row[4],
+ "rejected_reason": row[5],
+ "redactions": [],
+ }
+
+ # check for redactions
+ redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
+
+ clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs)
+
+ txn.execute(redactions_sql + clause, args)
+
+ for (redacter, redacted) in txn:
+ d = event_dict.get(redacted)
+ if d:
+ d["redactions"].append(redacter)
+
+ return event_dict
+
+ def _maybe_redact_event_row(self, original_ev, redactions, event_map):
+ """Given an event object and a list of possible redacting event ids,
+ determine whether to honour any of those redactions and if so return a redacted
+ event.
+
+ Args:
+ original_ev (EventBase):
+ redactions (iterable[str]): list of event ids of potential redaction events
+ event_map (dict[str, EventBase]): other events which have been fetched, in
+ which we can look up the redaaction events. Map from event id to event.
+
+ Returns:
+ Deferred[EventBase|None]: if the event should be redacted, a pruned
+ event object. Otherwise, None.
+ """
+ if original_ev.type == "m.room.create":
+ # we choose to ignore redactions of m.room.create events.
+ return None
+
+ for redaction_id in redactions:
+ redaction_event = event_map.get(redaction_id)
+ if not redaction_event or redaction_event.rejected_reason:
+ # we don't have the redaction event, or the redaction event was not
+ # authorized.
+ logger.debug(
+ "%s was redacted by %s but redaction not found/authed",
+ original_ev.event_id,
+ redaction_id,
+ )
+ continue
+
+ if redaction_event.room_id != original_ev.room_id:
+ logger.debug(
+ "%s was redacted by %s but redaction was in a different room!",
+ original_ev.event_id,
+ redaction_id,
+ )
+ continue
+
+ # Starting in room version v3, some redactions need to be
+ # rechecked if we didn't have the redacted event at the
+ # time, so we recheck on read instead.
+ if redaction_event.internal_metadata.need_to_check_redaction():
+ expected_domain = get_domain_from_id(original_ev.sender)
+ if get_domain_from_id(redaction_event.sender) == expected_domain:
+ # This redaction event is allowed. Mark as not needing a recheck.
+ redaction_event.internal_metadata.recheck_redaction = False
+ else:
+ # Senders don't match, so the event isn't actually redacted
+ logger.debug(
+ "%s was redacted by %s but the senders don't match",
+ original_ev.event_id,
+ redaction_id,
+ )
+ continue
+
+ logger.debug("Redacting %s due to %s", original_ev.event_id, redaction_id)
+
+ # we found a good redaction event. Redact!
+ redacted_event = prune_event(original_ev)
+ redacted_event.unsigned["redacted_by"] = redaction_id
+
+ # It's fine to add the event directly, since get_pdu_json
+ # will serialise this field correctly
+ redacted_event.unsigned["redacted_because"] = redaction_event
+
+ return redacted_event
+
+ # no valid redaction found for this event
+ return None
+
+ @defer.inlineCallbacks
+ def have_events_in_timeline(self, event_ids):
+ """Given a list of event ids, check if we have already processed and
+ stored them as non outliers.
+ """
+ rows = yield self.db.simple_select_many_batch(
+ table="events",
+ retcols=("event_id",),
+ column="event_id",
+ iterable=list(event_ids),
+ keyvalues={"outlier": False},
+ desc="have_events_in_timeline",
+ )
+
+ return {r["event_id"] for r in rows}
+
+ @defer.inlineCallbacks
+ def have_seen_events(self, event_ids):
+ """Given a list of event ids, check if we have already processed them.
+
+ Args:
+ event_ids (iterable[str]):
+
+ Returns:
+ Deferred[set[str]]: The events we have already seen.
+ """
+ results = set()
+
+ def have_seen_events_txn(txn, chunk):
+ sql = "SELECT event_id FROM events as e WHERE "
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "e.event_id", chunk
+ )
+ txn.execute(sql + clause, args)
+ for (event_id,) in txn:
+ results.add(event_id)
+
+ # break the input up into chunks of 100
+ input_iterator = iter(event_ids)
+ for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
+ yield self.db.runInteraction(
+ "have_seen_events", have_seen_events_txn, chunk
+ )
+ return results
+
+ def _get_total_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_total_state_event_counts.
+ """
+ # We join against the events table as that has an index on room_id
+ sql = """
+ SELECT COUNT(*) FROM state_events
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id=?
+ """
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_total_state_event_counts(self, room_id):
+ """
+ Gets the total number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.db.runInteraction(
+ "get_total_state_event_counts",
+ self._get_total_state_event_counts_txn,
+ room_id,
+ )
+
+ def _get_current_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_current_state_event_counts.
+ """
+ sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?"
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_current_state_event_counts(self, room_id):
+ """
+ Gets the current number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.db.runInteraction(
+ "get_current_state_event_counts",
+ self._get_current_state_event_counts_txn,
+ room_id,
+ )
+
+ @defer.inlineCallbacks
+ def get_room_complexity(self, room_id):
+ """
+ Get a rough approximation of the complexity of the room. This is used by
+ remote servers to decide whether they wish to join the room or not.
+ Higher complexity value indicates that being in the room will consume
+ more resources.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[dict[str:int]] of complexity version to complexity.
+ """
+ state_events = yield self.get_current_state_event_counts(room_id)
+
+ # Call this one "v1", so we can introduce new ones as we want to develop
+ # it.
+ complexity_v1 = round(state_events / 500, 2)
+
+ return {"v1": complexity_v1}
diff --git a/synapse/storage/filtering.py b/synapse/storage/data_stores/main/filtering.py
index b195dc66a0..342d6622a4 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/data_stores/main/filtering.py
@@ -15,13 +15,10 @@
from canonicaljson import encode_canonical_json
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
+from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from ._base import SQLBaseStore, db_to_json
-
class FilteringStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
@@ -33,7 +30,7 @@ class FilteringStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
- def_json = yield self._simple_select_one_onecol(
+ def_json = yield self.db.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
@@ -41,7 +38,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(db_to_json(def_json))
+ return db_to_json(def_json)
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
@@ -53,12 +50,12 @@ class FilteringStore(SQLBaseStore):
"SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?"
)
- txn.execute(sql, (user_localpart, def_json))
+ txn.execute(sql, (user_localpart, bytearray(def_json)))
filter_id_response = txn.fetchone()
if filter_id_response is not None:
return filter_id_response[0]
- sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
+ sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:
@@ -70,8 +67,8 @@ class FilteringStore(SQLBaseStore):
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)"
)
- txn.execute(sql, (user_localpart, filter_id, def_json))
+ txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
return filter_id
- return self.runInteraction("add_user_filter", _do_txn)
+ return self.db.runInteraction("add_user_filter", _do_txn)
diff --git a/synapse/storage/group_server.py b/synapse/storage/data_stores/main/group_server.py
index dce6a43ac1..0963e6c250 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -19,8 +19,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
-
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
@@ -28,23 +27,9 @@ _DEFAULT_CATEGORY_ID = ""
_DEFAULT_ROLE_ID = ""
-class GroupServerStore(SQLBaseStore):
- def set_group_join_policy(self, group_id, join_policy):
- """Set the join policy of a group.
-
- join_policy can be one of:
- * "invite"
- * "open"
- """
- return self._simple_update_one(
- table="groups",
- keyvalues={"group_id": group_id},
- updatevalues={"join_policy": join_policy},
- desc="set_group_join_policy",
- )
-
+class GroupServerWorkerStore(SQLBaseStore):
def get_group(self, group_id):
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -66,7 +51,7 @@ class GroupServerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
- return self._simple_select_list(
+ return self.db.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
@@ -76,7 +61,7 @@ class GroupServerStore(SQLBaseStore):
def get_invited_users_in_group(self, group_id):
# TODO: Pagination
- return self._simple_select_onecol(
+ return self.db.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
@@ -90,7 +75,7 @@ class GroupServerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
- return self._simple_select_list(
+ return self.db.simple_select_list(
table="group_rooms",
keyvalues=keyvalues,
retcols=("room_id", "is_public"),
@@ -154,10 +139,372 @@ class GroupServerStore(SQLBaseStore):
return rooms, categories
- return self.runInteraction("get_rooms_for_summary", _get_rooms_for_summary_txn)
+ return self.db.runInteraction(
+ "get_rooms_for_summary", _get_rooms_for_summary_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_group_categories(self, group_id):
+ rows = yield self.db.simple_select_list(
+ table="group_room_categories",
+ keyvalues={"group_id": group_id},
+ retcols=("category_id", "is_public", "profile"),
+ desc="get_group_categories",
+ )
+
+ return {
+ row["category_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
+ }
+
+ @defer.inlineCallbacks
+ def get_group_category(self, group_id, category_id):
+ category = yield self.db.simple_select_one(
+ table="group_room_categories",
+ keyvalues={"group_id": group_id, "category_id": category_id},
+ retcols=("is_public", "profile"),
+ desc="get_group_category",
+ )
+
+ category["profile"] = json.loads(category["profile"])
+
+ return category
+
+ @defer.inlineCallbacks
+ def get_group_roles(self, group_id):
+ rows = yield self.db.simple_select_list(
+ table="group_roles",
+ keyvalues={"group_id": group_id},
+ retcols=("role_id", "is_public", "profile"),
+ desc="get_group_roles",
+ )
+
+ return {
+ row["role_id"]: {
+ "is_public": row["is_public"],
+ "profile": json.loads(row["profile"]),
+ }
+ for row in rows
+ }
+
+ @defer.inlineCallbacks
+ def get_group_role(self, group_id, role_id):
+ role = yield self.db.simple_select_one(
+ table="group_roles",
+ keyvalues={"group_id": group_id, "role_id": role_id},
+ retcols=("is_public", "profile"),
+ desc="get_group_role",
+ )
+
+ role["profile"] = json.loads(role["profile"])
+
+ return role
+
+ def get_local_groups_for_room(self, room_id):
+ """Get all of the local group that contain a given room
+ Args:
+ room_id (str): The ID of a room
+ Returns:
+ Deferred[list[str]]: A twisted.Deferred containing a list of group ids
+ containing this room
+ """
+ return self.db.simple_select_onecol(
+ table="group_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="group_id",
+ desc="get_local_groups_for_room",
+ )
+
+ def get_users_for_summary_by_role(self, group_id, include_private=False):
+ """Get the users and roles that should be included in a summary request
+
+ Returns ([users], [roles])
+ """
+
+ def _get_users_for_summary_txn(txn):
+ keyvalues = {"group_id": group_id}
+ if not include_private:
+ keyvalues["is_public"] = True
+
+ sql = """
+ SELECT user_id, is_public, role_id, user_order
+ FROM group_summary_users
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ users = [
+ {
+ "user_id": row[0],
+ "is_public": row[1],
+ "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
+ "order": row[3],
+ }
+ for row in txn
+ ]
+
+ sql = """
+ SELECT role_id, is_public, profile, role_order
+ FROM group_summary_roles
+ INNER JOIN group_roles USING (group_id, role_id)
+ WHERE group_id = ?
+ """
+
+ if not include_private:
+ sql += " AND is_public = ?"
+ txn.execute(sql, (group_id, True))
+ else:
+ txn.execute(sql, (group_id,))
+
+ roles = {
+ row[0]: {
+ "is_public": row[1],
+ "profile": json.loads(row[2]),
+ "order": row[3],
+ }
+ for row in txn
+ }
+
+ return users, roles
+
+ return self.db.runInteraction(
+ "get_users_for_summary_by_role", _get_users_for_summary_txn
+ )
+
+ def is_user_in_group(self, user_id, group_id):
+ return self.db.simple_select_one_onecol(
+ table="group_users",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="is_user_in_group",
+ ).addCallback(lambda r: bool(r))
+
+ def is_user_admin_in_group(self, group_id, user_id):
+ return self.db.simple_select_one_onecol(
+ table="group_users",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcol="is_admin",
+ allow_none=True,
+ desc="is_user_admin_in_group",
+ )
+
+ def is_user_invited_to_local_group(self, group_id, user_id):
+ """Has the group server invited a user?
+ """
+ return self.db.simple_select_one_onecol(
+ table="group_invites",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcol="user_id",
+ desc="is_user_invited_to_local_group",
+ allow_none=True,
+ )
+
+ def get_users_membership_info_in_group(self, group_id, user_id):
+ """Get a dict describing the membership of a user in a group.
+
+ Example if joined:
+
+ {
+ "membership": "join",
+ "is_public": True,
+ "is_privileged": False,
+ }
+
+ Returns an empty dict if the user is not join/invite/etc
+ """
+
+ def _get_users_membership_in_group_txn(txn):
+ row = self.db.simple_select_one_txn(
+ txn,
+ table="group_users",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcols=("is_admin", "is_public"),
+ allow_none=True,
+ )
+
+ if row:
+ return {
+ "membership": "join",
+ "is_public": row["is_public"],
+ "is_privileged": row["is_admin"],
+ }
+
+ row = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="group_invites",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcol="user_id",
+ allow_none=True,
+ )
+
+ if row:
+ return {"membership": "invite"}
+
+ return {}
+
+ return self.db.runInteraction(
+ "get_users_membership_info_in_group", _get_users_membership_in_group_txn
+ )
+
+ def get_publicised_groups_for_user(self, user_id):
+ """Get all groups a user is publicising
+ """
+ return self.db.simple_select_onecol(
+ table="local_group_membership",
+ keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
+ retcol="group_id",
+ desc="get_publicised_groups_for_user",
+ )
+
+ def get_attestations_need_renewals(self, valid_until_ms):
+ """Get all attestations that need to be renewed until givent time
+ """
+
+ def _get_attestations_need_renewals_txn(txn):
+ sql = """
+ SELECT group_id, user_id FROM group_attestations_renewals
+ WHERE valid_until_ms <= ?
+ """
+ txn.execute(sql, (valid_until_ms,))
+ return self.db.cursor_to_dict(txn)
+
+ return self.db.runInteraction(
+ "get_attestations_need_renewals", _get_attestations_need_renewals_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_remote_attestation(self, group_id, user_id):
+ """Get the attestation that proves the remote agrees that the user is
+ in the group.
+ """
+ row = yield self.db.simple_select_one(
+ table="group_attestations_remote",
+ keyvalues={"group_id": group_id, "user_id": user_id},
+ retcols=("valid_until_ms", "attestation_json"),
+ desc="get_remote_attestation",
+ allow_none=True,
+ )
+
+ now = int(self._clock.time_msec())
+ if row and now < row["valid_until_ms"]:
+ return json.loads(row["attestation_json"])
+
+ return None
+
+ def get_joined_groups(self, user_id):
+ return self.db.simple_select_onecol(
+ table="local_group_membership",
+ keyvalues={"user_id": user_id, "membership": "join"},
+ retcol="group_id",
+ desc="get_joined_groups",
+ )
+
+ def get_all_groups_for_user(self, user_id, now_token):
+ def _get_all_groups_for_user_txn(txn):
+ sql = """
+ SELECT group_id, type, membership, u.content
+ FROM local_group_updates AS u
+ INNER JOIN local_group_membership USING (group_id, user_id)
+ WHERE user_id = ? AND membership != 'leave'
+ AND stream_id <= ?
+ """
+ txn.execute(sql, (user_id, now_token))
+ return [
+ {
+ "group_id": row[0],
+ "type": row[1],
+ "membership": row[2],
+ "content": json.loads(row[3]),
+ }
+ for row in txn
+ ]
+
+ return self.db.runInteraction(
+ "get_all_groups_for_user", _get_all_groups_for_user_txn
+ )
+
+ def get_groups_changes_for_user(self, user_id, from_token, to_token):
+ from_token = int(from_token)
+ has_changed = self._group_updates_stream_cache.has_entity_changed(
+ user_id, from_token
+ )
+ if not has_changed:
+ return defer.succeed([])
+
+ def _get_groups_changes_for_user_txn(txn):
+ sql = """
+ SELECT group_id, membership, type, u.content
+ FROM local_group_updates AS u
+ INNER JOIN local_group_membership USING (group_id, user_id)
+ WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
+ """
+ txn.execute(sql, (user_id, from_token, to_token))
+ return [
+ {
+ "group_id": group_id,
+ "membership": membership,
+ "type": gtype,
+ "content": json.loads(content_json),
+ }
+ for group_id, membership, gtype, content_json in txn
+ ]
+
+ return self.db.runInteraction(
+ "get_groups_changes_for_user", _get_groups_changes_for_user_txn
+ )
+
+ def get_all_groups_changes(self, from_token, to_token, limit):
+ from_token = int(from_token)
+ has_changed = self._group_updates_stream_cache.has_any_entity_changed(
+ from_token
+ )
+ if not has_changed:
+ return defer.succeed([])
+
+ def _get_all_groups_changes_txn(txn):
+ sql = """
+ SELECT stream_id, group_id, user_id, type, content
+ FROM local_group_updates
+ WHERE ? < stream_id AND stream_id <= ?
+ LIMIT ?
+ """
+ txn.execute(sql, (from_token, to_token, limit))
+ return [
+ (stream_id, group_id, user_id, gtype, json.loads(content_json))
+ for stream_id, group_id, user_id, gtype, content_json in txn
+ ]
+
+ return self.db.runInteraction(
+ "get_all_groups_changes", _get_all_groups_changes_txn
+ )
+
+
+class GroupServerStore(GroupServerWorkerStore):
+ def set_group_join_policy(self, group_id, join_policy):
+ """Set the join policy of a group.
+
+ join_policy can be one of:
+ * "invite"
+ * "open"
+ """
+ return self.db.simple_update_one(
+ table="groups",
+ keyvalues={"group_id": group_id},
+ updatevalues={"join_policy": join_policy},
+ desc="set_group_join_policy",
+ )
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
- return self.runInteraction(
+ return self.db.runInteraction(
"add_room_to_summary",
self._add_room_to_summary_txn,
group_id,
@@ -181,7 +528,7 @@ class GroupServerStore(SQLBaseStore):
an order of 1 will put the room first. Otherwise, the room gets
added to the end.
"""
- room_in_group = self._simple_select_one_onecol_txn(
+ room_in_group = self.db.simple_select_one_onecol_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
@@ -194,7 +541,7 @@ class GroupServerStore(SQLBaseStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
else:
- cat_exists = self._simple_select_one_onecol_txn(
+ cat_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -205,7 +552,7 @@ class GroupServerStore(SQLBaseStore):
raise SynapseError(400, "Category doesn't exist")
# TODO: Check category is part of summary already
- cat_exists = self._simple_select_one_onecol_txn(
+ cat_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_summary_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -225,7 +572,7 @@ class GroupServerStore(SQLBaseStore):
(group_id, category_id, group_id, category_id),
)
- existing = self._simple_select_one_txn(
+ existing = self.db.simple_select_one_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -250,7 +597,7 @@ class GroupServerStore(SQLBaseStore):
WHERE group_id = ? AND category_id = ?
"""
txn.execute(sql, (group_id, category_id))
- order, = txn.fetchone()
+ (order,) = txn.fetchone()
if existing:
to_update = {}
@@ -258,7 +605,7 @@ class GroupServerStore(SQLBaseStore):
to_update["room_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self._simple_update_txn(
+ self.db.simple_update_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -272,7 +619,7 @@ class GroupServerStore(SQLBaseStore):
if is_public is None:
is_public = True
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_summary_rooms",
values={
@@ -288,7 +635,7 @@ class GroupServerStore(SQLBaseStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
- return self._simple_delete(
+ return self.db.simple_delete(
table="group_summary_rooms",
keyvalues={
"group_id": group_id,
@@ -298,38 +645,6 @@ class GroupServerStore(SQLBaseStore):
desc="remove_room_from_summary",
)
- @defer.inlineCallbacks
- def get_group_categories(self, group_id):
- rows = yield self._simple_select_list(
- table="group_room_categories",
- keyvalues={"group_id": group_id},
- retcols=("category_id", "is_public", "profile"),
- desc="get_group_categories",
- )
-
- defer.returnValue(
- {
- row["category_id"]: {
- "is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
- }
- for row in rows
- }
- )
-
- @defer.inlineCallbacks
- def get_group_category(self, group_id, category_id):
- category = yield self._simple_select_one(
- table="group_room_categories",
- keyvalues={"group_id": group_id, "category_id": category_id},
- retcols=("is_public", "profile"),
- desc="get_group_category",
- )
-
- category["profile"] = json.loads(category["profile"])
-
- defer.returnValue(category)
-
def upsert_group_category(self, group_id, category_id, profile, is_public):
"""Add/update room category for group
"""
@@ -346,7 +661,7 @@ class GroupServerStore(SQLBaseStore):
else:
update_values["is_public"] = is_public
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
@@ -355,44 +670,12 @@ class GroupServerStore(SQLBaseStore):
)
def remove_group_category(self, group_id, category_id):
- return self._simple_delete(
+ return self.db.simple_delete(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
desc="remove_group_category",
)
- @defer.inlineCallbacks
- def get_group_roles(self, group_id):
- rows = yield self._simple_select_list(
- table="group_roles",
- keyvalues={"group_id": group_id},
- retcols=("role_id", "is_public", "profile"),
- desc="get_group_roles",
- )
-
- defer.returnValue(
- {
- row["role_id"]: {
- "is_public": row["is_public"],
- "profile": json.loads(row["profile"]),
- }
- for row in rows
- }
- )
-
- @defer.inlineCallbacks
- def get_group_role(self, group_id, role_id):
- role = yield self._simple_select_one(
- table="group_roles",
- keyvalues={"group_id": group_id, "role_id": role_id},
- retcols=("is_public", "profile"),
- desc="get_group_role",
- )
-
- role["profile"] = json.loads(role["profile"])
-
- defer.returnValue(role)
-
def upsert_group_role(self, group_id, role_id, profile, is_public):
"""Add/remove user role
"""
@@ -409,7 +692,7 @@ class GroupServerStore(SQLBaseStore):
else:
update_values["is_public"] = is_public
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
@@ -418,14 +701,14 @@ class GroupServerStore(SQLBaseStore):
)
def remove_group_role(self, group_id, role_id):
- return self._simple_delete(
+ return self.db.simple_delete(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
desc="remove_group_role",
)
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
- return self.runInteraction(
+ return self.db.runInteraction(
"add_user_to_summary",
self._add_user_to_summary_txn,
group_id,
@@ -449,7 +732,7 @@ class GroupServerStore(SQLBaseStore):
an order of 1 will put the user first. Otherwise, the user gets
added to the end.
"""
- user_in_group = self._simple_select_one_onecol_txn(
+ user_in_group = self.db.simple_select_one_onecol_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -462,7 +745,7 @@ class GroupServerStore(SQLBaseStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
else:
- role_exists = self._simple_select_one_onecol_txn(
+ role_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -473,7 +756,7 @@ class GroupServerStore(SQLBaseStore):
raise SynapseError(400, "Role doesn't exist")
# TODO: Check role is part of the summary already
- role_exists = self._simple_select_one_onecol_txn(
+ role_exists = self.db.simple_select_one_onecol_txn(
txn,
table="group_summary_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -493,7 +776,7 @@ class GroupServerStore(SQLBaseStore):
(group_id, role_id, group_id, role_id),
)
- existing = self._simple_select_one_txn(
+ existing = self.db.simple_select_one_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@@ -514,7 +797,7 @@ class GroupServerStore(SQLBaseStore):
WHERE group_id = ? AND role_id = ?
"""
txn.execute(sql, (group_id, role_id))
- order, = txn.fetchone()
+ (order,) = txn.fetchone()
if existing:
to_update = {}
@@ -522,7 +805,7 @@ class GroupServerStore(SQLBaseStore):
to_update["user_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self._simple_update_txn(
+ self.db.simple_update_txn(
txn,
table="group_summary_users",
keyvalues={
@@ -536,7 +819,7 @@ class GroupServerStore(SQLBaseStore):
if is_public is None:
is_public = True
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_summary_users",
values={
@@ -552,158 +835,21 @@ class GroupServerStore(SQLBaseStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
- return self._simple_delete(
+ return self.db.simple_delete(
table="group_summary_users",
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
desc="remove_user_from_summary",
)
- def get_users_for_summary_by_role(self, group_id, include_private=False):
- """Get the users and roles that should be included in a summary request
-
- Returns ([users], [roles])
- """
-
- def _get_users_for_summary_txn(txn):
- keyvalues = {"group_id": group_id}
- if not include_private:
- keyvalues["is_public"] = True
-
- sql = """
- SELECT user_id, is_public, role_id, user_order
- FROM group_summary_users
- WHERE group_id = ?
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
- else:
- txn.execute(sql, (group_id,))
-
- users = [
- {
- "user_id": row[0],
- "is_public": row[1],
- "role_id": row[2] if row[2] != _DEFAULT_ROLE_ID else None,
- "order": row[3],
- }
- for row in txn
- ]
-
- sql = """
- SELECT role_id, is_public, profile, role_order
- FROM group_summary_roles
- INNER JOIN group_roles USING (group_id, role_id)
- WHERE group_id = ?
- """
-
- if not include_private:
- sql += " AND is_public = ?"
- txn.execute(sql, (group_id, True))
- else:
- txn.execute(sql, (group_id,))
-
- roles = {
- row[0]: {
- "is_public": row[1],
- "profile": json.loads(row[2]),
- "order": row[3],
- }
- for row in txn
- }
-
- return users, roles
-
- return self.runInteraction(
- "get_users_for_summary_by_role", _get_users_for_summary_txn
- )
-
- def is_user_in_group(self, user_id, group_id):
- return self._simple_select_one_onecol(
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- allow_none=True,
- desc="is_user_in_group",
- ).addCallback(lambda r: bool(r))
-
- def is_user_admin_in_group(self, group_id, user_id):
- return self._simple_select_one_onecol(
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="is_admin",
- allow_none=True,
- desc="is_user_admin_in_group",
- )
-
def add_group_invite(self, group_id, user_id):
"""Record that the group server has invited a user
"""
- return self._simple_insert(
+ return self.db.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
)
- def is_user_invited_to_local_group(self, group_id, user_id):
- """Has the group server invited a user?
- """
- return self._simple_select_one_onecol(
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- desc="is_user_invited_to_local_group",
- allow_none=True,
- )
-
- def get_users_membership_info_in_group(self, group_id, user_id):
- """Get a dict describing the membership of a user in a group.
-
- Example if joined:
-
- {
- "membership": "join",
- "is_public": True,
- "is_privileged": False,
- }
-
- Returns an empty dict if the user is not join/invite/etc
- """
-
- def _get_users_membership_in_group_txn(txn):
- row = self._simple_select_one_txn(
- txn,
- table="group_users",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcols=("is_admin", "is_public"),
- allow_none=True,
- )
-
- if row:
- return {
- "membership": "join",
- "is_public": row["is_public"],
- "is_privileged": row["is_admin"],
- }
-
- row = self._simple_select_one_onecol_txn(
- txn,
- table="group_invites",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcol="user_id",
- allow_none=True,
- )
-
- if row:
- return {"membership": "invite"}
-
- return {}
-
- return self.runInteraction(
- "get_users_membership_info_in_group", _get_users_membership_in_group_txn
- )
-
def add_user_to_group(
self,
group_id,
@@ -728,7 +874,7 @@ class GroupServerStore(SQLBaseStore):
"""
def _add_user_to_group_txn(txn):
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_users",
values={
@@ -739,14 +885,14 @@ class GroupServerStore(SQLBaseStore):
},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
if local_attestation:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -756,7 +902,7 @@ class GroupServerStore(SQLBaseStore):
},
)
if remote_attestation:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
@@ -767,49 +913,49 @@ class GroupServerStore(SQLBaseStore):
},
)
- return self.runInteraction("add_user_to_group", _add_user_to_group_txn)
+ return self.db.runInteraction("add_user_to_group", _add_user_to_group_txn)
def remove_user_from_group(self, group_id, user_id):
def _remove_user_from_group_txn(txn):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn
)
def add_room_to_group(self, group_id, room_id, is_public):
- return self._simple_insert(
+ return self.db.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
)
def update_room_in_group_visibility(self, group_id, room_id, is_public):
- return self._simple_update(
+ return self.db.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
@@ -818,36 +964,26 @@ class GroupServerStore(SQLBaseStore):
def remove_room_from_group(self, group_id, room_id):
def _remove_room_from_group_txn(txn):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_summary_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn
)
- def get_publicised_groups_for_user(self, user_id):
- """Get all groups a user is publicising
- """
- return self._simple_select_onecol(
- table="local_group_membership",
- keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
- retcol="group_id",
- desc="get_publicised_groups_for_user",
- )
-
def update_group_publicity(self, group_id, user_id, publicise):
"""Update whether the user is publicising their membership of the group
"""
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
@@ -883,12 +1019,12 @@ class GroupServerStore(SQLBaseStore):
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="local_group_membership",
values={
@@ -901,7 +1037,7 @@ class GroupServerStore(SQLBaseStore):
},
)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="local_group_updates",
values={
@@ -920,7 +1056,7 @@ class GroupServerStore(SQLBaseStore):
if membership == "join":
if local_attestation:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -930,7 +1066,7 @@ class GroupServerStore(SQLBaseStore):
},
)
if remote_attestation:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
@@ -941,12 +1077,12 @@ class GroupServerStore(SQLBaseStore):
},
)
else:
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -955,18 +1091,18 @@ class GroupServerStore(SQLBaseStore):
return next_id
with self._group_updates_id_gen.get_next() as next_id:
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
next_id,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
):
- yield self._simple_insert(
+ yield self.db.simple_insert(
table="groups",
values={
"group_id": group_id,
@@ -981,33 +1117,17 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def update_group_profile(self, group_id, profile):
- yield self._simple_update_one(
+ yield self.db.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues=profile,
desc="update_group_profile",
)
- def get_attestations_need_renewals(self, valid_until_ms):
- """Get all attestations that need to be renewed until givent time
- """
-
- def _get_attestations_need_renewals_txn(txn):
- sql = """
- SELECT group_id, user_id FROM group_attestations_renewals
- WHERE valid_until_ms <= ?
- """
- txn.execute(sql, (valid_until_ms,))
- return self.cursor_to_dict(txn)
-
- return self.runInteraction(
- "get_attestations_need_renewals", _get_attestations_need_renewals_txn
- )
-
def update_attestation_renewal(self, group_id, user_id, attestation):
"""Update an attestation that we have renewed
"""
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
@@ -1017,7 +1137,7 @@ class GroupServerStore(SQLBaseStore):
def update_remote_attestion(self, group_id, user_id, attestation):
"""Update an attestation that a remote has renewed
"""
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
@@ -1036,118 +1156,12 @@ class GroupServerStore(SQLBaseStore):
group_id (str)
user_id (str)
"""
- return self._simple_delete(
+ return self.db.simple_delete(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
desc="remove_attestation_renewal",
)
- @defer.inlineCallbacks
- def get_remote_attestation(self, group_id, user_id):
- """Get the attestation that proves the remote agrees that the user is
- in the group.
- """
- row = yield self._simple_select_one(
- table="group_attestations_remote",
- keyvalues={"group_id": group_id, "user_id": user_id},
- retcols=("valid_until_ms", "attestation_json"),
- desc="get_remote_attestation",
- allow_none=True,
- )
-
- now = int(self._clock.time_msec())
- if row and now < row["valid_until_ms"]:
- defer.returnValue(json.loads(row["attestation_json"]))
-
- defer.returnValue(None)
-
- def get_joined_groups(self, user_id):
- return self._simple_select_onecol(
- table="local_group_membership",
- keyvalues={"user_id": user_id, "membership": "join"},
- retcol="group_id",
- desc="get_joined_groups",
- )
-
- def get_all_groups_for_user(self, user_id, now_token):
- def _get_all_groups_for_user_txn(txn):
- sql = """
- SELECT group_id, type, membership, u.content
- FROM local_group_updates AS u
- INNER JOIN local_group_membership USING (group_id, user_id)
- WHERE user_id = ? AND membership != 'leave'
- AND stream_id <= ?
- """
- txn.execute(sql, (user_id, now_token))
- return [
- {
- "group_id": row[0],
- "type": row[1],
- "membership": row[2],
- "content": json.loads(row[3]),
- }
- for row in txn
- ]
-
- return self.runInteraction(
- "get_all_groups_for_user", _get_all_groups_for_user_txn
- )
-
- def get_groups_changes_for_user(self, user_id, from_token, to_token):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_entity_changed(
- user_id, from_token
- )
- if not has_changed:
- return []
-
- def _get_groups_changes_for_user_txn(txn):
- sql = """
- SELECT group_id, membership, type, u.content
- FROM local_group_updates AS u
- INNER JOIN local_group_membership USING (group_id, user_id)
- WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
- """
- txn.execute(sql, (user_id, from_token, to_token))
- return [
- {
- "group_id": group_id,
- "membership": membership,
- "type": gtype,
- "content": json.loads(content_json),
- }
- for group_id, membership, gtype, content_json in txn
- ]
-
- return self.runInteraction(
- "get_groups_changes_for_user", _get_groups_changes_for_user_txn
- )
-
- def get_all_groups_changes(self, from_token, to_token, limit):
- from_token = int(from_token)
- has_changed = self._group_updates_stream_cache.has_any_entity_changed(
- from_token
- )
- if not has_changed:
- return []
-
- def _get_all_groups_changes_txn(txn):
- sql = """
- SELECT stream_id, group_id, user_id, type, content
- FROM local_group_updates
- WHERE ? < stream_id AND stream_id <= ?
- LIMIT ?
- """
- txn.execute(sql, (from_token, to_token, limit))
- return [
- (stream_id, group_id, user_id, gtype, json.loads(content_json))
- for stream_id, group_id, user_id, gtype, content_json in txn
- ]
-
- return self.runInteraction(
- "get_all_groups_changes", _get_all_groups_changes_txn
- )
-
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
@@ -1178,12 +1192,8 @@ class GroupServerStore(SQLBaseStore):
]
for table in tables:
- self._simple_delete_txn(
- txn,
- table=table,
- keyvalues={"group_id": group_id},
+ self.db.simple_delete_txn(
+ txn, table=table, keyvalues={"group_id": group_id}
)
- return self.runInteraction(
- "delete_group", _delete_group_txn
- )
+ return self.db.runInteraction("delete_group", _delete_group_txn)
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
new file mode 100644
index 0000000000..ba89c68c9f
--- /dev/null
+++ b/synapse/storage/data_stores/main/keys.py
@@ -0,0 +1,214 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+
+import six
+
+from signedjson.key import decode_verify_key_bytes
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.keys import FetchKeyResult
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
+
+logger = logging.getLogger(__name__)
+
+# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
+# despite being deprecated and removed in favor of memoryview
+if six.PY2:
+ db_binary_type = six.moves.builtins.buffer
+else:
+ db_binary_type = memoryview
+
+
+class KeyStore(SQLBaseStore):
+ """Persistence for signature verification keys
+ """
+
+ @cached()
+ def _get_server_verify_key(self, server_name_and_key_id):
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
+ )
+ def get_server_verify_keys(self, server_name_and_key_ids):
+ """
+ Args:
+ server_name_and_key_ids (iterable[Tuple[str, str]]):
+ iterable of (server_name, key-id) tuples to fetch keys for
+
+ Returns:
+ Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
+ map from (server_name, key_id) -> FetchKeyResult, or None if the key is
+ unknown
+ """
+ keys = {}
+
+ def _get_keys(txn, batch):
+ """Processes a batch of keys to fetch, and adds the result to `keys`."""
+
+ # batch_iter always returns tuples so it's safe to do len(batch)
+ sql = (
+ "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
+ "FROM server_signature_keys WHERE 1=0"
+ ) + " OR (server_name=? AND key_id=?)" * len(batch)
+
+ txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
+
+ for row in txn:
+ server_name, key_id, key_bytes, ts_valid_until_ms = row
+
+ if ts_valid_until_ms is None:
+ # Old keys may be stored with a ts_valid_until_ms of null,
+ # in which case we treat this as if it was set to `0`, i.e.
+ # it won't match key requests that define a minimum
+ # `ts_valid_until_ms`.
+ ts_valid_until_ms = 0
+
+ res = FetchKeyResult(
+ verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
+ valid_until_ts=ts_valid_until_ms,
+ )
+ keys[(server_name, key_id)] = res
+
+ def _txn(txn):
+ for batch in batch_iter(server_name_and_key_ids, 50):
+ _get_keys(txn, batch)
+ return keys
+
+ return self.db.runInteraction("get_server_verify_keys", _txn)
+
+ def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ """Stores NACL verification keys for remote servers.
+ Args:
+ from_server (str): Where the verification keys were looked up
+ ts_added_ms (int): The time to record that the key was added
+ verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ keys to be stored. Each entry is a triplet of
+ (server_name, key_id, key).
+ """
+ key_values = []
+ value_values = []
+ invalidations = []
+ for server_name, key_id, fetch_result in verify_keys:
+ key_values.append((server_name, key_id))
+ value_values.append(
+ (
+ from_server,
+ ts_added_ms,
+ fetch_result.valid_until_ts,
+ db_binary_type(fetch_result.verify_key.encode()),
+ )
+ )
+ # invalidate takes a tuple corresponding to the params of
+ # _get_server_verify_key. _get_server_verify_key only takes one
+ # param, which is itself the 2-tuple (server_name, key_id).
+ invalidations.append((server_name, key_id))
+
+ def _invalidate(res):
+ f = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ f((i,))
+ return res
+
+ return self.db.runInteraction(
+ "store_server_verify_keys",
+ self.db.simple_upsert_many_txn,
+ table="server_signature_keys",
+ key_names=("server_name", "key_id"),
+ key_values=key_values,
+ value_names=(
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "verify_key",
+ ),
+ value_values=value_values,
+ ).addCallback(_invalidate)
+
+ def store_server_keys_json(
+ self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
+ ):
+ """Stores the JSON bytes for a set of keys from a server
+ The JSON should be signed by the originating server, the intermediate
+ server, and by this server. Updates the value for the
+ (server_name, key_id, from_server) triplet if one already existed.
+ Args:
+ server_name (str): The name of the server.
+ key_id (str): The identifer of the key this JSON is for.
+ from_server (str): The server this JSON was fetched from.
+ ts_now_ms (int): The time now in milliseconds.
+ ts_valid_until_ms (int): The time when this json stops being valid.
+ key_json (bytes): The encoded JSON.
+ """
+ return self.db.simple_upsert(
+ table="server_keys_json",
+ keyvalues={
+ "server_name": server_name,
+ "key_id": key_id,
+ "from_server": from_server,
+ },
+ values={
+ "server_name": server_name,
+ "key_id": key_id,
+ "from_server": from_server,
+ "ts_added_ms": ts_now_ms,
+ "ts_valid_until_ms": ts_expires_ms,
+ "key_json": db_binary_type(key_json_bytes),
+ },
+ desc="store_server_keys_json",
+ )
+
+ def get_server_keys_json(self, server_keys):
+ """Retrive the key json for a list of server_keys and key ids.
+ If no keys are found for a given server, key_id and source then
+ that server, key_id, and source triplet entry will be an empty list.
+ The JSON is returned as a byte array so that it can be efficiently
+ used in an HTTP response.
+ Args:
+ server_keys (list): List of (server_name, key_id, source) triplets.
+ Returns:
+ Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
+ Dict mapping (server_name, key_id, source) triplets to lists of dicts
+ """
+
+ def _get_server_keys_json_txn(txn):
+ results = {}
+ for server_name, key_id, from_server in server_keys:
+ keyvalues = {"server_name": server_name}
+ if key_id is not None:
+ keyvalues["key_id"] = key_id
+ if from_server is not None:
+ keyvalues["from_server"] = from_server
+ rows = self.db.simple_select_list_txn(
+ txn,
+ "server_keys_json",
+ keyvalues=keyvalues,
+ retcols=(
+ "key_id",
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "key_json",
+ ),
+ )
+ results[(server_name, key_id, from_server)] = rows
+ return results
+
+ return self.db.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
diff --git a/synapse/storage/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 3ecf47e7a7..80ca36dedf 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -12,29 +12,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
-class MediaRepositoryStore(BackgroundUpdateStore):
- """Persistence for attachments and avatars"""
-
- def __init__(self, db_conn, hs):
- super(MediaRepositoryStore, self).__init__(db_conn, hs)
+class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(MediaRepositoryBackgroundUpdateStore, self).__init__(
+ database, db_conn, hs
+ )
- self.register_background_index_update(
- update_name='local_media_repository_url_idx',
- index_name='local_media_repository_url_idx',
- table='local_media_repository',
- columns=['created_ts'],
- where_clause='url_cache IS NOT NULL',
+ self.db.updates.register_background_index_update(
+ update_name="local_media_repository_url_idx",
+ index_name="local_media_repository_url_idx",
+ table="local_media_repository",
+ columns=["created_ts"],
+ where_clause="url_cache IS NOT NULL",
)
+
+class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
+ """Persistence for attachments and avatars"""
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
+
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:
None if the media_id doesn't exist.
"""
- return self._simple_select_one(
+ return self.db.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -59,7 +67,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
user_id,
url_cache=None,
):
- return self._simple_insert(
+ return self.db.simple_insert(
"local_media_repository",
{
"media_id": media_id,
@@ -108,23 +116,23 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return dict(
zip(
(
- 'response_code',
- 'etag',
- 'expires_ts',
- 'og',
- 'media_id',
- 'download_ts',
+ "response_code",
+ "etag",
+ "expires_ts",
+ "og",
+ "media_id",
+ "download_ts",
),
row,
)
)
- return self.runInteraction("get_url_cache", get_url_cache_txn)
+ return self.db.runInteraction("get_url_cache", get_url_cache_txn)
def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
- return self._simple_insert(
+ return self.db.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
@@ -139,7 +147,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
def get_local_media_thumbnails(self, media_id):
- return self._simple_select_list(
+ return self.db.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@@ -161,7 +169,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self._simple_insert(
+ return self.db.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
@@ -175,7 +183,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
)
def get_cached_remote_media(self, origin, media_id):
- return self._simple_select_one(
+ return self.db.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -200,7 +208,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
upload_name,
filesystem_id,
):
- return self._simple_insert(
+ return self.db.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
@@ -245,10 +253,12 @@ class MediaRepositoryStore(BackgroundUpdateStore):
txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
- return self.runInteraction("update_cached_last_access_time", update_cache_txn)
+ return self.db.runInteraction(
+ "update_cached_last_access_time", update_cache_txn
+ )
def get_remote_media_thumbnails(self, origin, media_id):
- return self._simple_select_list(
+ return self.db.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@@ -273,7 +283,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self._simple_insert(
+ return self.db.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
@@ -295,24 +305,24 @@ class MediaRepositoryStore(BackgroundUpdateStore):
" WHERE last_access_ts < ?"
)
- return self._execute(
- "get_remote_media_before", self.cursor_to_dict, sql, before_ts
+ return self.db.execute(
+ "get_remote_media_before", self.db.cursor_to_dict, sql, before_ts
)
def delete_remote_media(self, media_origin, media_id):
def delete_remote_media_txn(txn):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
"remote_media_cache",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+ return self.db.runInteraction("delete_remote_media", delete_remote_media_txn)
def get_expired_url_cache(self, now_ts):
sql = (
@@ -326,18 +336,20 @@ class MediaRepositoryStore(BackgroundUpdateStore):
txn.execute(sql, (now_ts,))
return [row[0] for row in txn]
- return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+ return self.db.runInteraction(
+ "get_expired_url_cache", _get_expired_url_cache_txn
+ )
def delete_url_cache(self, media_ids):
if len(media_ids) == 0:
return
- sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+ return self.db.runInteraction("delete_url_cache", _delete_url_cache_txn)
def get_url_cache_media_before(self, before_ts):
sql = (
@@ -351,7 +363,7 @@ class MediaRepositoryStore(BackgroundUpdateStore):
txn.execute(sql, (before_ts,))
return [row[0] for row in txn]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_url_cache_media_before", _get_url_cache_media_before_txn
)
@@ -360,14 +372,14 @@ class MediaRepositoryStore(BackgroundUpdateStore):
return
def _delete_url_cache_media_txn(txn):
- sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- return self.runInteraction(
+ return self.db.runInteraction(
"delete_url_cache_media", _delete_url_cache_media_txn
)
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index 8aa8abc470..925bc5691b 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -16,10 +16,10 @@ import logging
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.util.caches.descriptors import cached
-from ._base import SQLBaseStore
-
logger = logging.getLogger(__name__)
# Number of msec of granularity to store the monthly_active_user timestamp
@@ -27,15 +27,105 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
-class MonthlyActiveUsersStore(SQLBaseStore):
- def __init__(self, dbconn, hs):
- super(MonthlyActiveUsersStore, self).__init__(None, hs)
+class MonthlyActiveUsersWorkerStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
- self.reserved_users = ()
+
+ @cached(num_args=0)
+ def get_monthly_active_count(self):
+ """Generates current count of monthly active users
+
+ Returns:
+ Defered[int]: Number of current monthly active users
+ """
+
+ def _count_users(txn):
+ sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+ txn.execute(sql)
+ (count,) = txn.fetchone()
+ return count
+
+ return self.db.runInteraction("count_users", _count_users)
+
+ @cached(num_args=0)
+ def get_monthly_active_count_by_service(self):
+ """Generates current count of monthly active users broken down by service.
+ A service is typically an appservice but also includes native matrix users.
+ Since the `monthly_active_users` table is populated from the `user_ips` table
+ `config.track_appservice_user_ips` must be set to `true` for this
+ method to return anything other than native matrix users.
+
+ Returns:
+ Deferred[dict]: dict that includes a mapping between app_service_id
+ and the number of occurrences.
+
+ """
+
+ def _count_users_by_service(txn):
+ sql = """
+ SELECT COALESCE(appservice_id, 'native'), COALESCE(count(*), 0)
+ FROM monthly_active_users
+ LEFT JOIN users ON monthly_active_users.user_id=users.name
+ GROUP BY appservice_id;
+ """
+
+ txn.execute(sql)
+ result = txn.fetchall()
+ return dict(result)
+
+ return self.db.runInteraction("count_users_by_service", _count_users_by_service)
+
+ @defer.inlineCallbacks
+ def get_registered_reserved_users(self):
+ """Of the reserved threepids defined in config, which are associated
+ with registered users?
+
+ Returns:
+ Defered[list]: Real reserved users
+ """
+ users = []
+
+ for tp in self.hs.config.mau_limits_reserved_threepids[
+ : self.hs.config.max_mau_value
+ ]:
+ user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ tp["medium"], tp["address"]
+ )
+ if user_id:
+ users.append(user_id)
+
+ return users
+
+ @cached(num_args=1)
+ def user_last_seen_monthly_active(self, user_id):
+ """
+ Checks if a given user is part of the monthly active user group
+ Arguments:
+ user_id (str): user to add/update
+ Return:
+ Deferred[int] : timestamp since last seen, None if never seen
+
+ """
+
+ return self.db.simple_select_one_onecol(
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ retcol="timestamp",
+ allow_none=True,
+ desc="user_last_seen_monthly_active",
+ )
+
+
+class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+
# Do not add more reserved users than the total allowable number
- self._new_transaction(
- dbconn,
+ # cur = LoggingTransaction(
+ self.db.new_transaction(
+ db_conn,
"initialise_mau_threepids",
[],
[],
@@ -51,7 +141,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn (cursor):
threepids (list[dict]): List of threepid dicts to reserve
"""
- reserved_user_list = []
for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
@@ -59,11 +148,15 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if user_id:
is_support = self.is_support_user_txn(txn, user_id)
if not is_support:
- self.upsert_monthly_active_user_txn(txn, user_id)
- reserved_user_list.append(user_id)
+ # We do this manually here to avoid hitting #6791
+ self.db.simple_upsert_txn(
+ txn,
+ table="monthly_active_users",
+ keyvalues={"user_id": user_id},
+ values={"timestamp": int(self._clock.time_msec())},
+ )
else:
logger.warning("mau limit reserved threepid %s not found in db" % tp)
- self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks
def reap_monthly_active_users(self):
@@ -74,8 +167,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Deferred[]
"""
- def _reap_users(txn):
- # Purge stale users
+ def _reap_users(txn, reserved_users):
+ """
+ Args:
+ reserved_users (tuple): reserved users to preserve
+ """
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
query_args = [thirty_days_ago]
@@ -83,20 +179,19 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
- if len(self.reserved_users) > 0:
+ if len(reserved_users) > 0:
# questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s'
- questionmarks = '?' * len(self.reserved_users)
+ question_marks = ",".join("?" * len(reserved_users))
- query_args.extend(self.reserved_users)
- sql = base_sql + """ AND user_id NOT IN ({})""".format(
- ','.join(questionmarks)
- )
+ query_args.extend(reserved_users)
+ sql = base_sql + " AND user_id NOT IN ({})".format(question_marks)
else:
sql = base_sql
txn.execute(sql, query_args)
+ max_mau_value = self.hs.config.max_mau_value
if self.hs.config.limit_usage_by_mau:
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
@@ -106,74 +201,64 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# While Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can
# support
- safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
- # Must be greater than zero for postgres
- safe_guard = safe_guard if safe_guard > 0 else 0
- query_args = [safe_guard]
-
- base_sql = """
- DELETE FROM monthly_active_users
- WHERE user_id NOT IN (
- SELECT user_id FROM monthly_active_users
- ORDER BY timestamp DESC
- LIMIT ?
+ if len(reserved_users) == 0:
+ sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ ORDER BY timestamp DESC
+ LIMIT ?
)
- """
+ """
+ txn.execute(sql, (max_mau_value,))
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
- if len(self.reserved_users) > 0:
- query_args.extend(self.reserved_users)
- sql = base_sql + """ AND user_id NOT IN ({})""".format(
- ','.join(questionmarks)
- )
else:
- sql = base_sql
- txn.execute(sql, query_args)
-
- yield self.runInteraction("reap_monthly_active_users", _reap_users)
- # It seems poor to invalidate the whole cache, Postgres supports
- # 'Returning' which would allow me to invalidate only the
- # specific users, but sqlite has no way to do this and instead
- # I would need to SELECT and the DELETE which without locking
- # is racy.
- # Have resolved to invalidate the whole cache for now and do
- # something about it if and when the perf becomes significant
- self.user_last_seen_monthly_active.invalidate_all()
- self.get_monthly_active_count.invalidate_all()
-
- @cached(num_args=0)
- def get_monthly_active_count(self):
- """Generates current count of monthly active users
-
- Returns:
- Defered[int]: Number of current monthly active users
- """
-
- def _count_users(txn):
- sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
-
- txn.execute(sql)
- count, = txn.fetchone()
- return count
-
- return self.runInteraction("count_users", _count_users)
+ # Must be >= 0 for postgres
+ num_of_non_reserved_users_to_remove = max(
+ max_mau_value - len(reserved_users), 0
+ )
- @defer.inlineCallbacks
- def get_registered_reserved_users_count(self):
- """Of the reserved threepids defined in config, how many are associated
- with registered users?
+ # It is important to filter reserved users twice to guard
+ # against the case where the reserved user is present in the
+ # SELECT, meaning that a legitmate mau is deleted.
+ sql = """
+ DELETE FROM monthly_active_users
+ WHERE user_id NOT IN (
+ SELECT user_id FROM monthly_active_users
+ WHERE user_id NOT IN ({})
+ ORDER BY timestamp DESC
+ LIMIT ?
+ )
+ AND user_id NOT IN ({})
+ """.format(
+ question_marks, question_marks
+ )
- Returns:
- Defered[int]: Number of real reserved users
- """
- count = 0
- for tp in self.hs.config.mau_limits_reserved_threepids:
- user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- tp["medium"], tp["address"]
+ query_args = [
+ *reserved_users,
+ num_of_non_reserved_users_to_remove,
+ *reserved_users,
+ ]
+
+ txn.execute(sql, query_args)
+
+ # It seems poor to invalidate the whole cache, Postgres supports
+ # 'Returning' which would allow me to invalidate only the
+ # specific users, but sqlite has no way to do this and instead
+ # I would need to SELECT and the DELETE which without locking
+ # is racy.
+ # Have resolved to invalidate the whole cache for now and do
+ # something about it if and when the perf becomes significant
+ self._invalidate_all_cache_and_stream(
+ txn, self.user_last_seen_monthly_active
)
- if user_id:
- count = count + 1
- defer.returnValue(count)
+ self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+
+ reserved_users = yield self.get_registered_reserved_users()
+ yield self.db.runInteraction(
+ "reap_monthly_active_users", _reap_users, reserved_users
+ )
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
@@ -195,27 +280,13 @@ class MonthlyActiveUsersStore(SQLBaseStore):
if is_support:
return
- yield self.runInteraction(
+ yield self.db.runInteraction(
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
)
- user_in_mau = self.user_last_seen_monthly_active.cache.get(
- (user_id,), None, update_metrics=False
- )
- if user_in_mau is None:
- self.get_monthly_active_count.invalidate(())
-
- self.user_last_seen_monthly_active.invalidate((user_id,))
-
def upsert_monthly_active_user_txn(self, txn, user_id):
"""Updates or inserts monthly active user member
- Note that, after calling this method, it will generally be necessary
- to invalidate the caches on user_last_seen_monthly_active and
- get_monthly_active_count. We can't do that here, because we are running
- in a database thread rather than the main thread, and we can't call
- txn.call_after because txn may not be a LoggingTransaction.
-
We consciously do not call is_support_txn from this method because it
is not possible to cache the response. is_support_txn will be false in
almost all cases, so it seems reasonable to call it only for
@@ -239,33 +310,22 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
- is_insert = self._simple_upsert_txn(
+ is_insert = self.db.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
values={"timestamp": int(self._clock.time_msec())},
)
- return is_insert
-
- @cached(num_args=1)
- def user_last_seen_monthly_active(self, user_id):
- """
- Checks if a given user is part of the monthly active user group
- Arguments:
- user_id (str): user to add/update
- Return:
- Deferred[int] : timestamp since last seen, None if never seen
-
- """
-
- return self._simple_select_one_onecol(
- table="monthly_active_users",
- keyvalues={"user_id": user_id},
- retcol="timestamp",
- allow_none=True,
- desc="user_last_seen_monthly_active",
+ self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ())
+ self._invalidate_cache_and_stream(
+ txn, self.get_monthly_active_count_by_service, ()
)
+ self._invalidate_cache_and_stream(
+ txn, self.user_last_seen_monthly_active, (user_id,)
+ )
+
+ return is_insert
@defer.inlineCallbacks
def populate_monthly_active_users(self, user_id):
diff --git a/synapse/storage/openid.py b/synapse/storage/data_stores/main/openid.py
index b3318045ee..cc21437e92 100644
--- a/synapse/storage/openid.py
+++ b/synapse/storage/data_stores/main/openid.py
@@ -1,9 +1,9 @@
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
- return self._simple_insert(
+ return self.db.simple_insert(
table="open_id_tokens",
values={
"token": token,
@@ -28,4 +28,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
- return self.runInteraction("get_user_id_for_token", get_user_id_for_token_txn)
+ return self.db.runInteraction(
+ "get_user_id_for_token", get_user_id_for_token_txn
+ )
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
new file mode 100644
index 0000000000..604c8b7ddd
--- /dev/null
+++ b/synapse/storage/data_stores/main/presence.py
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.presence import UserPresenceState
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
+
+
+class PresenceStore(SQLBaseStore):
+ @defer.inlineCallbacks
+ def update_presence(self, presence_states):
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
+ len(presence_states)
+ )
+
+ with stream_ordering_manager as stream_orderings:
+ yield self.db.runInteraction(
+ "update_presence",
+ self._update_presence_txn,
+ stream_orderings,
+ presence_states,
+ )
+
+ return stream_orderings[-1], self._presence_id_gen.get_current_token()
+
+ def _update_presence_txn(self, txn, stream_orderings, presence_states):
+ for stream_id, state in zip(stream_orderings, presence_states):
+ txn.call_after(
+ self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
+ )
+ txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
+
+ # Actually insert new rows
+ self.db.simple_insert_many_txn(
+ txn,
+ table="presence_stream",
+ values=[
+ {
+ "stream_id": stream_id,
+ "user_id": state.user_id,
+ "state": state.state,
+ "last_active_ts": state.last_active_ts,
+ "last_federation_update_ts": state.last_federation_update_ts,
+ "last_user_sync_ts": state.last_user_sync_ts,
+ "status_msg": state.status_msg,
+ "currently_active": state.currently_active,
+ }
+ for state in presence_states
+ ],
+ )
+
+ # Delete old rows to stop database from getting really big
+ sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
+
+ for states in batch_iter(presence_states, 50):
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "user_id", [s.user_id for s in states]
+ )
+ txn.execute(sql + clause, [stream_id] + list(args))
+
+ def get_all_presence_updates(self, last_id, current_id):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_presence_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, user_id, state, last_active_ts,"
+ " last_federation_update_ts, last_user_sync_ts, status_msg,"
+ " currently_active"
+ " FROM presence_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ )
+ txn.execute(sql, (last_id, current_id))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_presence_updates", get_all_presence_updates_txn
+ )
+
+ @cached()
+ def _get_presence_for_user(self, user_id):
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_presence_for_user",
+ list_name="user_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
+ def get_presence_for_users(self, user_ids):
+ rows = yield self.db.simple_select_many_batch(
+ table="presence_stream",
+ column="user_id",
+ iterable=user_ids,
+ keyvalues={},
+ retcols=(
+ "user_id",
+ "state",
+ "last_active_ts",
+ "last_federation_update_ts",
+ "last_user_sync_ts",
+ "status_msg",
+ "currently_active",
+ ),
+ desc="get_presence_for_users",
+ )
+
+ for row in rows:
+ row["currently_active"] = bool(row["currently_active"])
+
+ return {row["user_id"]: UserPresenceState(**row) for row in rows}
+
+ def get_current_presence_token(self):
+ return self._presence_id_gen.get_current_token()
+
+ def allow_presence_visible(self, observed_localpart, observer_userid):
+ return self.db.simple_insert(
+ table="presence_allow_inbound",
+ values={
+ "observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid,
+ },
+ desc="allow_presence_visible",
+ or_ignore=True,
+ )
+
+ def disallow_presence_visible(self, observed_localpart, observer_userid):
+ return self.db.simple_delete_one(
+ table="presence_allow_inbound",
+ keyvalues={
+ "observed_user_id": observed_localpart,
+ "observer_user_id": observer_userid,
+ },
+ desc="disallow_presence_visible",
+ )
diff --git a/synapse/storage/profile.py b/synapse/storage/data_stores/main/profile.py
index 38524f2545..2a97991d23 100644
--- a/synapse/storage/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -17,10 +17,9 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
-from synapse.storage.roommember import ProfileInfo
-
-from . import background_updates
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.roommember import ProfileInfo
+from synapse.util.caches.descriptors import cached
BATCH_SIZE = 100
@@ -29,7 +28,7 @@ class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_profileinfo(self, user_localpart):
try:
- profile = yield self._simple_select_one(
+ profile = yield self.db.simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
@@ -38,27 +37,26 @@ class ProfileWorkerStore(SQLBaseStore):
except StoreError as e:
if e.code == 404:
# no match
- defer.returnValue(ProfileInfo(None, None))
- return
+ return ProfileInfo(None, None)
else:
raise
- defer.returnValue(
- ProfileInfo(
- avatar_url=profile['avatar_url'], display_name=profile['displayname']
- )
+ return ProfileInfo(
+ avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
+ @cached(max_entries=5000)
def get_profile_displayname(self, user_localpart):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)
+ @cached(max_entries=5000)
def get_profile_avatar_url(self, user_localpart):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
@@ -68,18 +66,15 @@ class ProfileWorkerStore(SQLBaseStore):
def get_latest_profile_replication_batch_number(self):
def f(txn):
txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
- rows = self.cursor_to_dict(txn)
- return rows[0]['maxbatch']
- return self.runInteraction(
- "get_latest_profile_replication_batch_number", f,
- )
+ rows = self.db.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return self.db.runInteraction("get_latest_profile_replication_batch_number", f)
def get_profile_batch(self, batchnum):
- return self._simple_select_list(
+ return self.db.simple_select_list(
table="profiles",
- keyvalues={
- "batch": batchnum,
- },
+ keyvalues={"batch": batchnum},
retcols=("user_id", "displayname", "avatar_url", "active"),
desc="get_profile_batch",
)
@@ -95,27 +90,29 @@ class ProfileWorkerStore(SQLBaseStore):
)
txn.execute(sql, (BATCH_SIZE,))
return txn.rowcount
- return self.runInteraction("assign_profile_batch", f)
+
+ return self.db.runInteraction("assign_profile_batch", f)
def get_replication_hosts(self):
def f(txn):
- txn.execute("SELECT host, last_synced_batch FROM profile_replication_status")
- rows = self.cursor_to_dict(txn)
- return {r['host']: r['last_synced_batch'] for r in rows}
- return self.runInteraction("get_replication_hosts", f)
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.db.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return self.db.runInteraction("get_replication_hosts", f)
def update_replication_batch_for_host(self, host, last_synced_batch):
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="profile_replication_status",
keyvalues={"host": host},
- values={
- "last_synced_batch": last_synced_batch,
- },
+ values={"last_synced_batch": last_synced_batch},
desc="update_replication_batch_for_host",
)
def get_from_remote_profile_cache(self, user_id):
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@@ -123,55 +120,57 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_from_remote_profile_cache",
)
+ def create_profile(self, user_localpart):
+ return self.db.simple_insert(
+ table="profiles", values={"user_id": user_localpart}, desc="create_profile"
+ )
+
def set_profile_displayname(self, user_localpart, new_displayname, batchnum):
- return self._simple_upsert(
+ # Invalidate the read cache for this user
+ self.get_profile_displayname.invalidate((user_localpart,))
+
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- values={
- "displayname": new_displayname,
- "batch": batchnum,
- },
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
- lock=False # we can do this because user_id has a unique index
+ lock=False, # we can do this because user_id has a unique index
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url, batchnum):
- return self._simple_upsert(
+ # Invalidate the read cache for this user
+ self.get_profile_avatar_url.invalidate((user_localpart,))
+
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- values={
- "avatar_url": new_avatar_url,
- "batch": batchnum,
- },
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
- lock=False # we can do this because user_id has a unique index
+ lock=False, # we can do this because user_id has a unique index
)
def set_profile_active(self, user_localpart, active, hide, batchnum):
- values = {
- "active": int(active),
- "batch": batchnum,
- }
+ values = {"active": int(active), "batch": batchnum}
if not active and not hide:
# we are deactivating for real (not in hide mode)
# so clear the profile.
values["avatar_url"] = None
values["displayname"] = None
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
values=values,
desc="set_profile_active",
- lock=False # we can do this because user_id has a unique index
+ lock=False, # we can do this because user_id has a unique index
)
-class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore):
- def __init__(self, db_conn, hs):
+class ProfileStore(ProfileWorkerStore):
+ def __init__(self, database, db_conn, hs):
- super(ProfileStore, self).__init__(db_conn, hs)
+ super(ProfileStore, self).__init__(database, db_conn, hs)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"profile_replication_status_host_index",
index_name="profile_replication_status_idx",
table="profile_replication_status",
@@ -185,7 +184,7 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -197,7 +196,7 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self._simple_update(
+ return self.db.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -215,7 +214,7 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
"""
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
- yield self._simple_delete(
+ yield self.db.simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
@@ -234,9 +233,9 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
txn.execute(sql, (last_checked,))
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
@@ -245,7 +244,7 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
@@ -254,9 +253,9 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
)
if res:
- defer.returnValue(True)
+ return True
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",
@@ -265,4 +264,4 @@ class ProfileStore(ProfileWorkerStore, background_updates.BackgroundUpdateStore)
)
if res:
- defer.returnValue(True)
+ return True
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
new file mode 100644
index 0000000000..62ac88d9f2
--- /dev/null
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -0,0 +1,714 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import logging
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.push.baserules import list_with_base_rules
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
+from synapse.storage.data_stores.main.pusher import PusherWorkerStore
+from synapse.storage.data_stores.main.receipts import ReceiptsWorkerStore
+from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
+from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
+from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+logger = logging.getLogger(__name__)
+
+
+def _load_rules(rawrules, enabled_map):
+ ruleslist = []
+ for rawrule in rawrules:
+ rule = dict(rawrule)
+ rule["conditions"] = json.loads(rawrule["conditions"])
+ rule["actions"] = json.loads(rawrule["actions"])
+ ruleslist.append(rule)
+
+ # We're going to be mutating this a lot, so do a deep copy
+ rules = list(list_with_base_rules(ruleslist))
+
+ for i, rule in enumerate(rules):
+ rule_id = rule["rule_id"]
+ if rule_id in enabled_map:
+ if rule.get("enabled", True) != bool(enabled_map[rule_id]):
+ # Rules are cached across users.
+ rule = dict(rule)
+ rule["enabled"] = bool(enabled_map[rule_id])
+ rules[i] = rule
+
+ return rules
+
+
+class PushRulesWorkerStore(
+ ApplicationServiceWorkerStore,
+ ReceiptsWorkerStore,
+ PusherWorkerStore,
+ RoomMemberWorkerStore,
+ SQLBaseStore,
+):
+ """This is an abstract base class where subclasses must implement
+ `get_max_push_rules_stream_id` which can be called in the initializer.
+ """
+
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
+ # the abstract methods being implemented.
+ __metaclass__ = abc.ABCMeta
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+
+ push_rules_prefill, push_rules_id = self.db.get_cache_dict(
+ db_conn,
+ "push_rules_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=self.get_max_push_rules_stream_id(),
+ )
+
+ self.push_rules_stream_cache = StreamChangeCache(
+ "PushRulesStreamChangeCache",
+ push_rules_id,
+ prefilled_cache=push_rules_prefill,
+ )
+
+ @abc.abstractmethod
+ def get_max_push_rules_stream_id(self):
+ """Get the position of the push rules stream.
+
+ Returns:
+ int
+ """
+ raise NotImplementedError()
+
+ @cachedInlineCallbacks(max_entries=5000)
+ def get_push_rules_for_user(self, user_id):
+ rows = yield self.db.simple_select_list(
+ table="push_rules",
+ keyvalues={"user_name": user_id},
+ retcols=(
+ "user_name",
+ "rule_id",
+ "priority_class",
+ "priority",
+ "conditions",
+ "actions",
+ ),
+ desc="get_push_rules_enabled_for_user",
+ )
+
+ rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+
+ enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
+
+ rules = _load_rules(rows, enabled_map)
+
+ return rules
+
+ @cachedInlineCallbacks(max_entries=5000)
+ def get_push_rules_enabled_for_user(self, user_id):
+ results = yield self.db.simple_select_list(
+ table="push_rules_enable",
+ keyvalues={"user_name": user_id},
+ retcols=("user_name", "rule_id", "enabled"),
+ desc="get_push_rules_enabled_for_user",
+ )
+ return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
+
+ def have_push_rules_changed_for_user(self, user_id, last_id):
+ if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
+ return defer.succeed(False)
+ else:
+
+ def have_push_rules_changed_txn(txn):
+ sql = (
+ "SELECT COUNT(stream_id) FROM push_rules_stream"
+ " WHERE user_id = ? AND ? < stream_id"
+ )
+ txn.execute(sql, (user_id, last_id))
+ (count,) = txn.fetchone()
+ return bool(count)
+
+ return self.db.runInteraction(
+ "have_push_rules_changed", have_push_rules_changed_txn
+ )
+
+ @cachedList(
+ cached_method_name="get_push_rules_for_user",
+ list_name="user_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
+ def bulk_get_push_rules(self, user_ids):
+ if not user_ids:
+ return {}
+
+ results = {user_id: [] for user_id in user_ids}
+
+ rows = yield self.db.simple_select_many_batch(
+ table="push_rules",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("*",),
+ desc="bulk_get_push_rules",
+ )
+
+ rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
+
+ for row in rows:
+ results.setdefault(row["user_name"], []).append(row)
+
+ enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
+
+ for user_id, rules in results.items():
+ results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
+
+ return results
+
+ @defer.inlineCallbacks
+ def copy_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
+ """Copy a single push rule from one room to another for a specific user.
+
+ Args:
+ new_room_id (str): ID of the new room.
+ user_id (str): ID of user the push rule belongs to.
+ rule (Dict): A push rule.
+ """
+ # Create new rule id
+ rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+ new_rule_id = rule_id_scope + "/" + new_room_id
+
+ # Change room id in each condition
+ for condition in rule.get("conditions", []):
+ if condition.get("key") == "room_id":
+ condition["pattern"] = new_room_id
+
+ # Add the rule for the new room
+ yield self.add_push_rule(
+ user_id=user_id,
+ rule_id=new_rule_id,
+ priority_class=rule["priority_class"],
+ conditions=rule["conditions"],
+ actions=rule["actions"],
+ )
+
+ @defer.inlineCallbacks
+ def copy_push_rules_from_room_to_room_for_user(
+ self, old_room_id, new_room_id, user_id
+ ):
+ """Copy all of the push rules from one room to another for a specific
+ user.
+
+ Args:
+ old_room_id (str): ID of the old room.
+ new_room_id (str): ID of the new room.
+ user_id (str): ID of user to copy push rules for.
+ """
+ # Retrieve push rules for this user
+ user_push_rules = yield self.get_push_rules_for_user(user_id)
+
+ # Get rules relating to the old room and copy them to the new room
+ for rule in user_push_rules:
+ conditions = rule.get("conditions", [])
+ if any(
+ (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
+ for c in conditions
+ ):
+ yield self.copy_push_rule_from_room_to_room(new_room_id, user_id, rule)
+
+ @defer.inlineCallbacks
+ def bulk_get_push_rules_for_room(self, event, context):
+ state_group = context.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ current_state_ids = yield context.get_current_state_ids()
+ result = yield self._bulk_get_push_rules_for_room(
+ event.room_id, state_group, current_state_ids, event=event
+ )
+ return result
+
+ @cachedInlineCallbacks(num_args=2, cache_context=True)
+ def _bulk_get_push_rules_for_room(
+ self, room_id, state_group, current_state_ids, cache_context, event=None
+ ):
+ # We don't use `state_group`, its there so that we can cache based
+ # on it. However, its important that its never None, since two current_state's
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ # We also will want to generate notifs for other people in the room so
+ # their unread countss are correct in the event stream, but to avoid
+ # generating them for bot / AS users etc, we only do so for people who've
+ # sent a read receipt into the room.
+
+ users_in_room = yield self._get_joined_users_from_context(
+ room_id,
+ state_group,
+ current_state_ids,
+ on_invalidate=cache_context.invalidate,
+ event=event,
+ )
+
+ # We ignore app service users for now. This is so that we don't fill
+ # up the `get_if_users_have_pushers` cache with AS entries that we
+ # know don't have pushers, nor even read receipts.
+ local_users_in_room = {
+ u
+ for u in users_in_room
+ if self.hs.is_mine_id(u)
+ and not self.get_if_app_services_interested_in_user(u)
+ }
+
+ # users in the room who have pushers need to get push rules run because
+ # that's how their pushers work
+ if_users_with_pushers = yield self.get_if_users_have_pushers(
+ local_users_in_room, on_invalidate=cache_context.invalidate
+ )
+ user_ids = {
+ uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
+ }
+
+ users_with_receipts = yield self.get_users_with_read_receipts_in_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+
+ # any users with pushers must be ours: they have pushers
+ for uid in users_with_receipts:
+ if uid in local_users_in_room:
+ user_ids.add(uid)
+
+ rules_by_user = yield self.bulk_get_push_rules(
+ user_ids, on_invalidate=cache_context.invalidate
+ )
+
+ rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
+
+ return rules_by_user
+
+ @cachedList(
+ cached_method_name="get_push_rules_enabled_for_user",
+ list_name="user_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
+ def bulk_get_push_rules_enabled(self, user_ids):
+ if not user_ids:
+ return {}
+
+ results = {user_id: {} for user_id in user_ids}
+
+ rows = yield self.db.simple_select_many_batch(
+ table="push_rules_enable",
+ column="user_name",
+ iterable=user_ids,
+ retcols=("user_name", "rule_id", "enabled"),
+ desc="bulk_get_push_rules_enabled",
+ )
+ for row in rows:
+ enabled = bool(row["enabled"])
+ results.setdefault(row["user_name"], {})[row["rule_id"]] = enabled
+ return results
+
+
+class PushRuleStore(PushRulesWorkerStore):
+ @defer.inlineCallbacks
+ def add_push_rule(
+ self,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions,
+ actions,
+ before=None,
+ after=None,
+ ):
+ conditions_json = json.dumps(conditions)
+ actions_json = json.dumps(actions)
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ if before or after:
+ yield self.db.runInteraction(
+ "_add_push_rule_relative_txn",
+ self._add_push_rule_relative_txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
+ before,
+ after,
+ )
+ else:
+ yield self.db.runInteraction(
+ "_add_push_rule_highest_priority_txn",
+ self._add_push_rule_highest_priority_txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
+ )
+
+ def _add_push_rule_relative_txn(
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
+ before,
+ after,
+ ):
+ # Lock the table since otherwise we'll have annoying races between the
+ # SELECT here and the UPSERT below.
+ self.database_engine.lock_table(txn, "push_rules")
+
+ relative_to_rule = before or after
+
+ res = self.db.simple_select_one_txn(
+ txn,
+ table="push_rules",
+ keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
+ retcols=["priority_class", "priority"],
+ allow_none=True,
+ )
+
+ if not res:
+ raise RuleNotFoundException(
+ "before/after rule not found: %s" % (relative_to_rule,)
+ )
+
+ base_priority_class = res["priority_class"]
+ base_rule_priority = res["priority"]
+
+ if base_priority_class != priority_class:
+ raise InconsistentRuleException(
+ "Given priority class does not match class of relative rule"
+ )
+
+ if before:
+ # Higher priority rules are executed first, So adding a rule before
+ # a rule means giving it a higher priority than that rule.
+ new_rule_priority = base_rule_priority + 1
+ else:
+ # We increment the priority of the existing rules to make space for
+ # the new rule. Therefore if we want this rule to appear after
+ # an existing rule we give it the priority of the existing rule,
+ # and then increment the priority of the existing rule.
+ new_rule_priority = base_rule_priority
+
+ sql = (
+ "UPDATE push_rules SET priority = priority + 1"
+ " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
+ )
+
+ txn.execute(sql, (user_id, priority_class, new_rule_priority))
+
+ self._upsert_push_rule_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ new_rule_priority,
+ conditions_json,
+ actions_json,
+ )
+
+ def _add_push_rule_highest_priority_txn(
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ conditions_json,
+ actions_json,
+ ):
+ # Lock the table since otherwise we'll have annoying races between the
+ # SELECT here and the UPSERT below.
+ self.database_engine.lock_table(txn, "push_rules")
+
+ # find the highest priority rule in that class
+ sql = (
+ "SELECT COUNT(*), MAX(priority) FROM push_rules"
+ " WHERE user_name = ? and priority_class = ?"
+ )
+ txn.execute(sql, (user_id, priority_class))
+ res = txn.fetchall()
+ (how_many, highest_prio) = res[0]
+
+ new_prio = 0
+ if how_many > 0:
+ new_prio = highest_prio + 1
+
+ self._upsert_push_rule_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ new_prio,
+ conditions_json,
+ actions_json,
+ )
+
+ def _upsert_push_rule_txn(
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ priority,
+ conditions_json,
+ actions_json,
+ update_stream=True,
+ ):
+ """Specialised version of simple_upsert_txn that picks a push_rule_id
+ using the _push_rule_id_gen if it needs to insert the rule. It assumes
+ that the "push_rules" table is locked"""
+
+ sql = (
+ "UPDATE push_rules"
+ " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
+ " WHERE user_name = ? AND rule_id = ?"
+ )
+
+ txn.execute(
+ sql,
+ (priority_class, priority, conditions_json, actions_json, user_id, rule_id),
+ )
+
+ if txn.rowcount == 0:
+ # We didn't update a row with the given rule_id so insert one
+ push_rule_id = self._push_rule_id_gen.get_next()
+
+ self.db.simple_insert_txn(
+ txn,
+ table="push_rules",
+ values={
+ "id": push_rule_id,
+ "user_name": user_id,
+ "rule_id": rule_id,
+ "priority_class": priority_class,
+ "priority": priority,
+ "conditions": conditions_json,
+ "actions": actions_json,
+ },
+ )
+
+ if update_stream:
+ self._insert_push_rules_update_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ op="ADD",
+ data={
+ "priority_class": priority_class,
+ "priority": priority,
+ "conditions": conditions_json,
+ "actions": actions_json,
+ },
+ )
+
+ @defer.inlineCallbacks
+ def delete_push_rule(self, user_id, rule_id):
+ """
+ Delete a push rule. Args specify the row to be deleted and can be
+ any of the columns in the push_rule table, but below are the
+ standard ones
+
+ Args:
+ user_id (str): The matrix ID of the push rule owner
+ rule_id (str): The rule_id of the rule to be deleted
+ """
+
+ def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+ self.db.simple_delete_one_txn(
+ txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
+ )
+
+ self._insert_push_rules_update_txn(
+ txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
+ )
+
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.db.runInteraction(
+ "delete_push_rule",
+ delete_push_rule_txn,
+ stream_id,
+ event_stream_ordering,
+ )
+
+ @defer.inlineCallbacks
+ def set_push_rule_enabled(self, user_id, rule_id, enabled):
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.db.runInteraction(
+ "_set_push_rule_enabled_txn",
+ self._set_push_rule_enabled_txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ enabled,
+ )
+
+ def _set_push_rule_enabled_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+ ):
+ new_id = self._push_rules_enable_id_gen.get_next()
+ self.db.simple_upsert_txn(
+ txn,
+ "push_rules_enable",
+ {"user_name": user_id, "rule_id": rule_id},
+ {"enabled": 1 if enabled else 0},
+ {"id": new_id},
+ )
+
+ self._insert_push_rules_update_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ op="ENABLE" if enabled else "DISABLE",
+ )
+
+ @defer.inlineCallbacks
+ def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
+ actions_json = json.dumps(actions)
+
+ def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
+ if is_default_rule:
+ # Add a dummy rule to the rules table with the user specified
+ # actions.
+ priority_class = -1
+ priority = 1
+ self._upsert_push_rule_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ priority_class,
+ priority,
+ "[]",
+ actions_json,
+ update_stream=False,
+ )
+ else:
+ self.db.simple_update_one_txn(
+ txn,
+ "push_rules",
+ {"user_name": user_id, "rule_id": rule_id},
+ {"actions": actions_json},
+ )
+
+ self._insert_push_rules_update_txn(
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ op="ACTIONS",
+ data={"actions": actions_json},
+ )
+
+ with self._push_rules_stream_id_gen.get_next() as ids:
+ stream_id, event_stream_ordering = ids
+ yield self.db.runInteraction(
+ "set_push_rule_actions",
+ set_push_rule_actions_txn,
+ stream_id,
+ event_stream_ordering,
+ )
+
+ def _insert_push_rules_update_txn(
+ self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
+ ):
+ values = {
+ "stream_id": stream_id,
+ "event_stream_ordering": event_stream_ordering,
+ "user_id": user_id,
+ "rule_id": rule_id,
+ "op": op,
+ }
+ if data is not None:
+ values.update(data)
+
+ self.db.simple_insert_txn(txn, "push_rules_stream", values=values)
+
+ txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
+ txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
+ txn.call_after(
+ self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
+ )
+
+ def get_all_push_rule_updates(self, last_id, current_id, limit):
+ """Get all the push rules changes that have happend on the server"""
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_push_rule_updates_txn(txn):
+ sql = (
+ "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
+ " op, priority_class, priority, conditions, actions"
+ " FROM push_rules_stream"
+ " WHERE ? < stream_id AND stream_id <= ?"
+ " ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, current_id, limit))
+ return txn.fetchall()
+
+ return self.db.runInteraction(
+ "get_all_push_rule_updates", get_all_push_rule_updates_txn
+ )
+
+ def get_push_rules_stream_token(self):
+ """Get the position of the push rules stream.
+ Returns a pair of a stream id for the push_rules stream and the
+ room stream ordering it corresponds to."""
+ return self._push_rules_stream_id_gen.get_current_token()
+
+ def get_max_push_rules_stream_id(self):
+ return self.get_push_rules_stream_token()[0]
diff --git a/synapse/storage/pusher.py b/synapse/storage/data_stores/main/pusher.py
index 1567e1df48..547b9d69cb 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -15,55 +15,45 @@
# limitations under the License.
import logging
-
-import six
+from typing import Iterable, Iterator
from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from ._base import SQLBaseStore
-
logger = logging.getLogger(__name__)
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
-
class PusherWorkerStore(SQLBaseStore):
- def _decode_pushers_rows(self, rows):
+ def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+ """JSON-decode the data in the rows returned from the `pushers` table
+
+ Drops any rows whose data cannot be decoded
+ """
for r in rows:
- dataJson = r['data']
- r['data'] = None
+ dataJson = r["data"]
try:
- if isinstance(dataJson, db_binary_type):
- dataJson = str(dataJson).decode("UTF8")
-
- r['data'] = json.loads(dataJson)
+ r["data"] = json.loads(dataJson)
except Exception as e:
- logger.warn(
+ logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
- r['id'],
+ r["id"],
dataJson,
e.args[0],
)
- pass
-
- if isinstance(r['pushkey'], db_binary_type):
- r['pushkey'] = str(r['pushkey']).decode("UTF8")
+ continue
- return rows
+ yield r
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
- ret = yield self._simple_select_one_onecol(
+ ret = yield self.db.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
- defer.returnValue(ret is not None)
+ return ret is not None
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
@@ -73,7 +63,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
- ret = yield self._simple_select_list(
+ ret = yield self.db.simple_select_list(
"pushers",
keyvalues,
[
@@ -95,18 +85,18 @@ class PusherWorkerStore(SQLBaseStore):
],
desc="get_pushers_by",
)
- defer.returnValue(self._decode_pushers_rows(ret))
+ return self._decode_pushers_rows(ret)
@defer.inlineCallbacks
def get_all_pushers(self):
def get_pushers(txn):
txn.execute("SELECT * FROM pushers")
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- rows = yield self.runInteraction("get_all_pushers", get_pushers)
- defer.returnValue(rows)
+ rows = yield self.db.runInteraction("get_all_pushers", get_pushers)
+ return rows
def get_all_updated_pushers(self, last_id, current_id, limit):
if last_id == current_id:
@@ -133,9 +123,9 @@ class PusherWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
deleted = txn.fetchall()
- return (updated, deleted)
+ return updated, deleted
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@@ -178,7 +168,7 @@ class PusherWorkerStore(SQLBaseStore):
return results
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_pushers_rows", get_all_updated_pushers_rows_txn
)
@@ -194,18 +184,96 @@ class PusherWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_if_users_have_pushers(self, user_ids):
- rows = yield self._simple_select_many_batch(
- table='pushers',
- column='user_name',
+ rows = yield self.db.simple_select_many_batch(
+ table="pushers",
+ column="user_name",
iterable=user_ids,
- retcols=['user_name'],
- desc='get_if_users_have_pushers',
+ retcols=["user_name"],
+ desc="get_if_users_have_pushers",
)
result = {user_id: False for user_id in user_ids}
- result.update({r['user_name']: True for r in rows})
+ result.update({r["user_name"]: True for r in rows})
+
+ return result
+
+ @defer.inlineCallbacks
+ def update_pusher_last_stream_ordering(
+ self, app_id, pushkey, user_id, last_stream_ordering
+ ):
+ yield self.db.simple_update_one(
+ "pushers",
+ {"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ {"last_stream_ordering": last_stream_ordering},
+ desc="update_pusher_last_stream_ordering",
+ )
+
+ @defer.inlineCallbacks
+ def update_pusher_last_stream_ordering_and_success(
+ self, app_id, pushkey, user_id, last_stream_ordering, last_success
+ ):
+ """Update the last stream ordering position we've processed up to for
+ the given pusher.
- defer.returnValue(result)
+ Args:
+ app_id (str)
+ pushkey (str)
+ last_stream_ordering (int)
+ last_success (int)
+
+ Returns:
+ Deferred[bool]: True if the pusher still exists; False if it has been deleted.
+ """
+ updated = yield self.db.simple_update(
+ table="pushers",
+ keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ updatevalues={
+ "last_stream_ordering": last_stream_ordering,
+ "last_success": last_success,
+ },
+ desc="update_pusher_last_stream_ordering_and_success",
+ )
+
+ return bool(updated)
+
+ @defer.inlineCallbacks
+ def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
+ yield self.db.simple_update(
+ table="pushers",
+ keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
+ updatevalues={"failing_since": failing_since},
+ desc="update_pusher_failing_since",
+ )
+
+ @defer.inlineCallbacks
+ def get_throttle_params_by_room(self, pusher_id):
+ res = yield self.db.simple_select_list(
+ "pusher_throttle",
+ {"pusher": pusher_id},
+ ["room_id", "last_sent_ts", "throttle_ms"],
+ desc="get_throttle_params_by_room",
+ )
+
+ params_by_room = {}
+ for row in res:
+ params_by_room[row["room_id"]] = {
+ "last_sent_ts": row["last_sent_ts"],
+ "throttle_ms": row["throttle_ms"],
+ }
+
+ return params_by_room
+
+ @defer.inlineCallbacks
+ def set_throttle_params(self, pusher_id, room_id, params):
+ # no need to lock because `pusher_throttle` has a primary key on
+ # (pusher, room_id) so simple_upsert will retry
+ yield self.db.simple_upsert(
+ "pusher_throttle",
+ {"pusher": pusher_id, "room_id": room_id},
+ params,
+ desc="set_throttle_params",
+ lock=False,
+ )
class PusherStore(PusherWorkerStore):
@@ -230,8 +298,8 @@ class PusherStore(PusherWorkerStore):
):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
- # (app_id, pushkey, user_name) so _simple_upsert will retry
- yield self._simple_upsert(
+ # (app_id, pushkey, user_name) so simple_upsert will retry
+ yield self.db.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -241,7 +309,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
- "data": encode_canonical_json(data),
+ "data": bytearray(encode_canonical_json(data)),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
@@ -256,7 +324,7 @@ class PusherStore(PusherWorkerStore):
if user_has_pusher is not True:
# invalidate, since we the user might not have had a pusher before
- yield self.runInteraction(
+ yield self.db.runInteraction(
"add_pusher",
self._invalidate_cache_and_stream,
self.get_if_user_has_pusher,
@@ -270,7 +338,7 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
- self._simple_delete_one_txn(
+ self.db.simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -279,7 +347,7 @@ class PusherStore(PusherWorkerStore):
# it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter.
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="deleted_pushers",
values={
@@ -291,68 +359,4 @@ class PusherStore(PusherWorkerStore):
)
with self._pushers_id_gen.get_next() as stream_id:
- yield self.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
-
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering(
- self, app_id, pushkey, user_id, last_stream_ordering
- ):
- yield self._simple_update_one(
- "pushers",
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
- {'last_stream_ordering': last_stream_ordering},
- desc="update_pusher_last_stream_ordering",
- )
-
- @defer.inlineCallbacks
- def update_pusher_last_stream_ordering_and_success(
- self, app_id, pushkey, user_id, last_stream_ordering, last_success
- ):
- yield self._simple_update_one(
- "pushers",
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
- {
- 'last_stream_ordering': last_stream_ordering,
- 'last_success': last_success,
- },
- desc="update_pusher_last_stream_ordering_and_success",
- )
-
- @defer.inlineCallbacks
- def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self._simple_update_one(
- "pushers",
- {'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
- {'failing_since': failing_since},
- desc="update_pusher_failing_since",
- )
-
- @defer.inlineCallbacks
- def get_throttle_params_by_room(self, pusher_id):
- res = yield self._simple_select_list(
- "pusher_throttle",
- {"pusher": pusher_id},
- ["room_id", "last_sent_ts", "throttle_ms"],
- desc="get_throttle_params_by_room",
- )
-
- params_by_room = {}
- for row in res:
- params_by_room[row["room_id"]] = {
- "last_sent_ts": row["last_sent_ts"],
- "throttle_ms": row["throttle_ms"],
- }
-
- defer.returnValue(params_by_room)
-
- @defer.inlineCallbacks
- def set_throttle_params(self, pusher_id, room_id, params):
- # no need to lock because `pusher_throttle` has a primary key on
- # (pusher, room_id) so _simple_upsert will retry
- yield self._simple_upsert(
- "pusher_throttle",
- {"pusher": pusher_id, "room_id": room_id},
- params,
- desc="set_throttle_params",
- lock=False,
- )
+ yield self.db.runInteraction("delete_pusher", delete_pusher_txn, stream_id)
diff --git a/synapse/storage/receipts.py b/synapse/storage/data_stores/main/receipts.py
index a1647e50a1..0d932a0672 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -21,12 +21,12 @@ from canonicaljson import json
from twisted.internet import defer
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.database import Database
+from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from ._base import SQLBaseStore
-from .util.id_generators import StreamIdGenerator
-
logger = logging.getLogger(__name__)
@@ -39,8 +39,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
- def __init__(self, db_conn, hs):
- super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -58,11 +58,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def get_users_with_read_receipts_in_room(self, room_id):
receipts = yield self.get_receipts_for_room(room_id, "m.read")
- defer.returnValue(set(r['user_id'] for r in receipts))
+ return {r["user_id"] for r in receipts}
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
- return self._simple_select_list(
+ return self.db.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
@@ -71,7 +71,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@@ -85,14 +85,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self._simple_select_list(
+ rows = yield self.db.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
desc="get_receipts_for_user",
)
- defer.returnValue({row["room_id"]: row["event_id"] for row in rows})
+ return {row["room_id"]: row["event_id"] for row in rows}
@defer.inlineCallbacks
def get_receipts_for_user_with_orderings(self, user_id, receipt_type):
@@ -109,17 +109,15 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id,))
return txn.fetchall()
- rows = yield self.runInteraction("get_receipts_for_user_with_orderings", f)
- defer.returnValue(
- {
- row[0]: {
- "event_id": row[1],
- "topological_ordering": row[2],
- "stream_ordering": row[3],
- }
- for row in rows
+ rows = yield self.db.runInteraction("get_receipts_for_user_with_orderings", f)
+ return {
+ row[0]: {
+ "event_id": row[1],
+ "topological_ordering": row[2],
+ "stream_ordering": row[3],
}
- )
+ for row in rows
+ }
@defer.inlineCallbacks
def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
@@ -147,7 +145,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
room_ids, to_key, from_key=from_key
)
- defer.returnValue([ev for res in results.values() for ev in res])
+ return [ev for res in results.values() for ev in res]
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients.
@@ -190,14 +188,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (room_id, to_key))
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
return rows
- rows = yield self.runInteraction("get_linearized_receipts_for_room", f)
+ rows = yield self.db.runInteraction("get_linearized_receipts_for_room", f)
if not rows:
- defer.returnValue([])
+ return []
content = {}
for row in rows:
@@ -205,9 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
row["user_id"]
] = json.loads(row["data"])
- defer.returnValue(
- [{"type": "m.receipt", "room_id": room_id, "content": content}]
- )
+ return [{"type": "m.receipt", "room_id": room_id, "content": content}]
@cachedList(
cached_method_name="_get_linearized_receipts_for_room",
@@ -217,32 +213,36 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
- defer.returnValue({})
+ return {}
def f(txn):
if from_key:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
- ) % (",".join(["?"] * len(room_ids)))
- args = list(room_ids)
- args.extend([from_key, to_key])
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id > ? AND stream_id <= ? AND
+ """
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
- txn.execute(sql, args)
+ txn.execute(sql + clause, [from_key, to_key] + list(args))
else:
- sql = (
- "SELECT * FROM receipts_linearized WHERE"
- " room_id IN (%s) AND stream_id <= ?"
- ) % (",".join(["?"] * len(room_ids)))
+ sql = """
+ SELECT * FROM receipts_linearized WHERE
+ stream_id <= ? AND
+ """
- args = list(room_ids)
- args.append(to_key)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
- txn.execute(sql, args)
+ txn.execute(sql + clause, [to_key] + list(args))
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- txn_results = yield self.runInteraction("_get_linearized_receipts_for_rooms", f)
+ txn_results = yield self.db.runInteraction(
+ "_get_linearized_receipts_for_rooms", f
+ )
results = {}
for row in txn_results:
@@ -264,7 +264,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
room_id: [results[room_id]] if room_id in results else []
for room_id in room_ids
}
- defer.returnValue(results)
+ return results
def get_all_updated_receipts(self, last_id, current_id, limit=None):
if last_id == current_id:
@@ -283,9 +283,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return (r[0:5] + (json.loads(r[5]),) for r in txn)
+ return [r[0:5] + (json.loads(r[5]),) for r in txn]
- return self.runInteraction(
+ return self.db.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
)
@@ -316,14 +316,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
class ReceiptsStore(ReceiptsWorkerStore):
- def __init__(self, db_conn, hs):
+ def __init__(self, database: Database, db_conn, hs):
# We instantiate this first as the ReceiptsWorkerStore constructor
# needs to be able to call get_max_receipt_stream_id
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
)
- super(ReceiptsStore, self).__init__(db_conn, hs)
+ super(ReceiptsStore, self).__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
@@ -338,7 +338,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
- res = self._simple_select_one_txn(
+ res = self.db.simple_select_one_txn(
txn,
table="events",
retcols=["stream_ordering", "received_ts"],
@@ -391,7 +391,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type),
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="receipts_linearized",
keyvalues={
@@ -401,7 +401,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
},
)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="receipts_linearized",
values={
@@ -437,26 +437,32 @@ class ReceiptsStore(ReceiptsWorkerStore):
# we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn):
- query = (
- "SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
- " SELECT max(stream_ordering) WHERE event_id IN (%s)"
- ")"
- ) % (",".join(["?"] * len(event_ids)))
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "event_id", event_ids
+ )
+
+ sql = """
+ SELECT event_id WHERE room_id = ? AND stream_ordering IN (
+ SELECT max(stream_ordering) WHERE %s
+ )
+ """ % (
+ clause,
+ )
- txn.execute(query, [room_id] + event_ids)
+ txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
- linearized_event_id = yield self.runInteraction(
+ linearized_event_id = yield self.db.runInteraction(
"insert_receipt_conv", graph_to_linear
)
stream_id_manager = self._receipts_id_gen.get_next()
with stream_id_manager as stream_id:
- event_ts = yield self.runInteraction(
+ event_ts = yield self.db.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
room_id,
@@ -468,7 +474,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
)
if event_ts is None:
- defer.returnValue(None)
+ return None
now = self._clock.time_msec()
logger.debug(
@@ -482,10 +488,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
max_persisted_id = self._receipts_id_gen.get_current_token()
- defer.returnValue((stream_id, max_persisted_id))
+ return stream_id, max_persisted_id
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
- return self.runInteraction(
+ return self.db.runInteraction(
"insert_graph_receipt",
self.insert_graph_receipt_txn,
room_id,
@@ -511,7 +517,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
@@ -520,7 +526,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
},
)
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="receipts_graph",
values={
diff --git a/synapse/storage/registration.py b/synapse/storage/data_stores/main/registration.py
index 028848cf89..035fe348b0 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -19,14 +19,15 @@ import logging
import re
from six import iterkeys
-from six.moves import range
from twisted.internet import defer
+from twisted.internet.defer import Deferred
from synapse.api.constants import UserTypes
-from synapse.api.errors import Codes, StoreError, ThreepidValidationError
-from synapse.storage import background_updates
+from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
from synapse.types import UserID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -36,25 +37,28 @@ logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore):
- def __init__(self, db_conn, hs):
- super(RegistrationWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
self.config = hs.config
self.clock = hs.get_clock()
@cached()
def get_user_by_id(self, user_id):
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
"name",
"password_hash",
"is_guest",
+ "admin",
"consent_version",
"consent_server_notice_sent",
"appservice_id",
"creation_ts",
+ "user_type",
+ "deactivated",
],
allow_none=True,
desc="get_user_by_id",
@@ -74,12 +78,12 @@ class RegistrationWorkerStore(SQLBaseStore):
info = yield self.get_user_by_id(user_id)
if not info:
- defer.returnValue(False)
+ return False
now = self.clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
- defer.returnValue(is_trial)
+ return is_trial
@cached()
def get_user_by_access_token(self, token):
@@ -89,9 +93,10 @@ class RegistrationWorkerStore(SQLBaseStore):
token (str): The access token of a user.
Returns:
defer.Deferred: None, if the token did not match, otherwise dict
- including the keys `name`, `is_guest`, `device_id`, `token_id`.
+ including the keys `name`, `is_guest`, `device_id`, `token_id`,
+ `valid_until_ms`.
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_user_by_access_token", self._query_for_auth, token
)
@@ -106,18 +111,19 @@ class RegistrationWorkerStore(SQLBaseStore):
otherwise int representation of the timestamp (as a number of
milliseconds since epoch).
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
allow_none=True,
desc="get_expiration_ts_for_user",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
- def set_account_validity_for_user(self, user_id, expiration_ts, email_sent,
- renewal_token=None):
+ def set_account_validity_for_user(
+ self, user_id, expiration_ts, email_sent, renewal_token=None
+ ):
"""Updates the account validity properties of the given account, with the
given values.
@@ -131,8 +137,9 @@ class RegistrationWorkerStore(SQLBaseStore):
renewal_token (str): Renewal token the user can use to extend the validity
of their account. Defaults to no token.
"""
+
def set_account_validity_for_user_txn(txn):
- self._simple_update_txn(
+ self.db.simple_update_txn(
txn=txn,
table="account_validity",
keyvalues={"user_id": user_id},
@@ -143,12 +150,11 @@ class RegistrationWorkerStore(SQLBaseStore):
},
)
self._invalidate_cache_and_stream(
- txn, self.get_expiration_ts_for_user, (user_id,),
+ txn, self.get_expiration_ts_for_user, (user_id,)
)
- yield self.runInteraction(
- "set_account_validity_for_user",
- set_account_validity_for_user_txn,
+ yield self.db.runInteraction(
+ "set_account_validity_for_user", set_account_validity_for_user_txn
)
@defer.inlineCallbacks
@@ -158,6 +164,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[list[str]]: List of expired user IDs
"""
+
def get_expired_users_txn(txn, now_ms):
sql = """
SELECT user_id from account_validity
@@ -167,10 +174,8 @@ class RegistrationWorkerStore(SQLBaseStore):
rows = txn.fetchall()
return [row[0] for row in rows]
- res = yield self.runInteraction(
- "get_expired_users",
- get_expired_users_txn,
- self.clock.time_msec(),
+ res = yield self.db.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
)
defer.returnValue(res)
@@ -186,7 +191,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Raises:
StoreError: The provided token is already set for another user.
"""
- yield self._simple_update_one(
+ yield self.db.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
@@ -203,14 +208,14 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The ID of the user to which the token belongs.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
desc="get_user_from_renewal_token",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def get_renewal_token_for_user(self, user_id):
@@ -222,14 +227,14 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The renewal token associated with this user ID.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
desc="get_renewal_token_for_user",
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def get_users_expiring_soon(self):
@@ -240,6 +245,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred: Resolves to a list[dict[user_id (str), expiration_ts_ms (int)]]
"""
+
def select_users_txn(txn, now_ms, renew_at):
sql = (
"SELECT user_id, expiration_ts_ms FROM account_validity"
@@ -247,15 +253,16 @@ class RegistrationWorkerStore(SQLBaseStore):
)
values = [False, now_ms, renew_at]
txn.execute(sql, values)
- return self.cursor_to_dict(txn)
+ return self.db.cursor_to_dict(txn)
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"get_users_expiring_soon",
select_users_txn,
- self.clock.time_msec(), self.config.account_validity.renew_at,
+ self.clock.time_msec(),
+ self.config.account_validity.renew_at,
)
- defer.returnValue(res)
+ return res
@defer.inlineCallbacks
def set_renewal_mail_status(self, user_id, email_sent):
@@ -267,7 +274,7 @@ class RegistrationWorkerStore(SQLBaseStore):
email_sent (bool): Flag which indicates whether a renewal email has been sent
to this user.
"""
- yield self._simple_update_one(
+ yield self.db.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
@@ -282,7 +289,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Args:
user_id (str): ID of the user to remove from the account validity table.
"""
- yield self._simple_delete_one(
+ yield self.db.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
@@ -290,7 +297,15 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def is_server_admin(self, user):
- res = yield self._simple_select_one_onecol(
+ """Determines if a user is an admin of this homeserver.
+
+ Args:
+ user (UserID): user ID of the user to test
+
+ Returns (bool):
+ true iff the user is a server admin, false otherwise.
+ """
+ res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
@@ -298,25 +313,59 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="is_server_admin",
)
- defer.returnValue(res if res else False)
+ return bool(res) if res else False
+
+ def set_server_admin(self, user, admin):
+ """Sets whether a user is an admin of this homeserver.
+
+ Args:
+ user (UserID): user ID of the user to test
+ admin (bool): true iff the user is to be a server admin,
+ false otherwise.
+ """
+
+ def set_server_admin_txn(txn):
+ self.db.simple_update_one_txn(
+ txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user.to_string(),)
+ )
+
+ return self.db.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id,"
- " access_tokens.device_id"
+ " access_tokens.device_id, access_tokens.valid_until_ms"
" FROM users"
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
" WHERE token = ?"
)
txn.execute(sql, (token,))
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]
return None
@cachedInlineCallbacks()
+ def is_real_user(self, user_id):
+ """Determines if the user is a real user, ie does not have a 'user_type'.
+
+ Args:
+ user_id (str): user id to test
+
+ Returns:
+ Deferred[bool]: True if user 'user_type' is null or empty string
+ """
+ res = yield self.db.runInteraction(
+ "is_real_user", self.is_real_user_txn, user_id
+ )
+ return res
+
+ @cachedInlineCallbacks()
def is_support_user(self, user_id):
"""Determines if the user is of type UserTypes.SUPPORT
@@ -326,13 +375,23 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if user is of type UserTypes.SUPPORT
"""
- res = yield self.runInteraction(
+ res = yield self.db.runInteraction(
"is_support_user", self.is_support_user_txn, user_id
)
- defer.returnValue(res)
+ return res
+
+ def is_real_user_txn(self, txn, user_id):
+ res = self.db.simple_select_one_onecol_txn(
+ txn=txn,
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="user_type",
+ allow_none=True,
+ )
+ return res is None
def is_support_user_txn(self, txn, user_id):
- res = self._simple_select_one_onecol_txn(
+ res = self.db.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -347,13 +406,31 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def f(txn):
- sql = (
- "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
- )
+ sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
txn.execute(sql, (user_id,))
return dict(txn)
- return self.runInteraction("get_users_by_id_case_insensitive", f)
+ return self.db.runInteraction("get_users_by_id_case_insensitive", f)
+
+ async def get_user_by_external_id(
+ self, auth_provider: str, external_id: str
+ ) -> str:
+ """Look up a user by their external auth id
+
+ Args:
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+
+ Returns:
+ str|None: the mxid of the user, or None if they are not known
+ """
+ return await self.db.simple_select_one_onecol(
+ table="user_external_ids",
+ keyvalues={"auth_provider": auth_provider, "external_id": external_id},
+ retcol="user_id",
+ allow_none=True,
+ desc="get_user_by_external_id",
+ )
@defer.inlineCallbacks
def count_all_users(self):
@@ -361,13 +438,13 @@ class RegistrationWorkerStore(SQLBaseStore):
def _count_users(txn):
txn.execute("SELECT COUNT(*) AS users FROM users")
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if rows:
return rows[0]["users"]
return 0
- ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("count_users", _count_users)
+ return ret
def count_daily_user_type(self):
"""
@@ -392,13 +469,13 @@ class RegistrationWorkerStore(SQLBaseStore):
WHERE creation_ts > ?
) AS t GROUP BY user_type
"""
- results = {'native': 0, 'guest': 0, 'bridged': 0}
+ results = {"native": 0, "guest": 0, "bridged": 0}
txn.execute(sql, (yesterday,))
for row in txn:
results[row[0]] = row[1]
return results
- return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+ return self.db.runInteraction("count_daily_user_type", _count_daily_user_type)
@defer.inlineCallbacks
def count_nonbridged_users(self):
@@ -409,63 +486,61 @@ class RegistrationWorkerStore(SQLBaseStore):
WHERE appservice_id IS NULL
"""
)
- count, = txn.fetchone()
+ (count,) = txn.fetchone()
return count
- ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ ret = yield self.db.runInteraction("count_users", _count_users)
+ return ret
+
+ @defer.inlineCallbacks
+ def count_real_users(self):
+ """Counts all users without a special user_type registered on the homeserver."""
+
+ def _count_users(txn):
+ txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
+ rows = self.db.cursor_to_dict(txn)
+ if rows:
+ return rows[0]["users"]
+ return 0
+
+ ret = yield self.db.runInteraction("count_real_users", _count_users)
+ return ret
@defer.inlineCallbacks
def find_next_generated_user_id_localpart(self):
"""
Gets the localpart of the next generated user ID.
- Generated user IDs are integers, and we aim for them to be as small as
- we can. Unfortunately, it's possible some of them are already taken by
- existing users, and there may be gaps in the already taken range. This
- function returns the start of the first allocatable gap. This is to
- avoid the case of ID 10000000 being pre-allocated, so us wasting the
- first (and shortest) many generated user IDs.
+ Generated user IDs are integers, so we find the largest integer user ID
+ already taken and return that plus one.
"""
def _find_next_generated_user_id(txn):
- txn.execute("SELECT name FROM users")
+ # We bound between '@0' and '@a' to avoid pulling the entire table
+ # out.
+ txn.execute("SELECT name FROM users WHERE '@0' <= name AND name < '@a'")
regex = re.compile(r"^@(\d+):")
- found = set()
+ max_found = 0
for (user_id,) in txn:
match = regex.search(user_id)
if match:
- found.add(int(match.group(1)))
- for i in range(len(found) + 1):
- if i not in found:
- return i
+ max_found = max(int(match.group(1)), max_found)
+
+ return max_found + 1
- defer.returnValue(
+ return (
(
- yield self.runInteraction(
+ yield self.db.runInteraction(
"find_next_generated_user_id", _find_next_generated_user_id
)
)
)
@defer.inlineCallbacks
- def get_3pid_guest_access_token(self, medium, address):
- ret = yield self._simple_select_one(
- "threepid_guest_access_tokens",
- {"medium": medium, "address": address},
- ["guest_access_token"],
- True,
- 'get_3pid_guest_access_token',
- )
- if ret:
- defer.returnValue(ret["guest_access_token"])
- defer.returnValue(None)
-
- @defer.inlineCallbacks
- def get_user_id_by_threepid(self, medium, address, require_verified=False):
+ def get_user_id_by_threepid(self, medium, address):
"""Returns user id from threepid
Args:
@@ -475,10 +550,10 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[str|None]: user id or None if no user id/threepid mapping exists
"""
- user_id = yield self.runInteraction(
+ user_id = yield self.db.runInteraction(
"get_user_id_by_threepid", self.get_user_id_by_threepid_txn, medium, address
)
- defer.returnValue(user_id)
+ return user_id
def get_user_id_by_threepid_txn(self, txn, medium, address):
"""Returns user id from threepid
@@ -491,20 +566,20 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
- ret = self._simple_select_one_txn(
+ ret = self.db.simple_select_one_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
- ['user_id'],
+ ["user_id"],
True,
)
if ret:
- return ret['user_id']
+ return ret["user_id"]
return None
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self._simple_upsert(
+ yield self.db.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
@@ -512,18 +587,31 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
- ret = yield self._simple_select_list(
+ ret = yield self.db.simple_select_list(
"user_threepids",
{"user_id": user_id},
- ['medium', 'address', 'validated_at', 'added_at'],
- 'user_get_threepids',
+ ["medium", "address", "validated_at", "added_at"],
+ "user_get_threepids",
)
- defer.returnValue(ret)
+ return ret
def user_delete_threepid(self, user_id, medium, address):
- return self._simple_delete(
+ return self.db.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
+ desc="user_delete_threepid",
+ )
+
+ def user_delete_threepids(self, user_id: str):
+ """Delete all threepid this user has bound
+
+ Args:
+ user_id: The user id to delete all threepids of
+
+ """
+ return self.db.simple_delete(
+ "user_threepids",
+ keyvalues={"user_id": user_id},
desc="user_delete_threepids",
)
@@ -543,7 +631,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
# We need to use an upsert, in case they user had already bound the
# threepid
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -556,6 +644,26 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="add_user_bound_threepid",
)
+ def user_get_bound_threepids(self, user_id):
+ """Get the threepids that a user has bound to an identity server through the homeserver
+ The homeserver remembers where binds to an identity server occurred. Using this
+ method can retrieve those threepids.
+
+ Args:
+ user_id (str): The ID of the user to retrieve threepids for
+
+ Returns:
+ Deferred[list[dict]]: List of dictionaries containing the following:
+ medium (str): The medium of the threepid (e.g "email")
+ address (str): The address of the threepid (e.g "bob@example.com")
+ """
+ return self.db.simple_select_list(
+ table="user_threepid_id_server",
+ keyvalues={"user_id": user_id},
+ retcols=["medium", "address"],
+ desc="user_get_bound_threepids",
+ )
+
def remove_user_bound_threepid(self, user_id, medium, address, id_server):
"""The server proxied an unbind request to the given identity server on
behalf of the given user, so we remove the mapping of threepid to
@@ -570,7 +678,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred
"""
- return self._simple_delete(
+ return self.db.simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -593,69 +701,170 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[list[str]]: Resolves to a list of identity servers
"""
- return self._simple_select_onecol(
+ return self.db.simple_select_onecol(
table="user_threepid_id_server",
- keyvalues={
- "user_id": user_id,
- "medium": medium,
- "address": address,
- },
+ keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
desc="get_id_servers_user_bound",
)
+ @cachedInlineCallbacks()
+ def get_user_deactivated_status(self, user_id):
+ """Retrieve the value for the `deactivated` property for the provided user.
-class RegistrationStore(
- RegistrationWorkerStore, background_updates.BackgroundUpdateStore
-):
- def __init__(self, db_conn, hs):
- super(RegistrationStore, self).__init__(db_conn, hs)
+ Args:
+ user_id (str): The ID of the user to retrieve the status for.
+
+ Returns:
+ defer.Deferred(bool): The requested value.
+ """
+
+ res = yield self.db.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="deactivated",
+ desc="get_user_deactivated_status",
+ )
+
+ # Convert the integer into a boolean.
+ return res == 1
+
+ def get_threepid_validation_session(
+ self, medium, client_secret, address=None, sid=None, validated=True
+ ):
+ """Gets a session_id and last_send_attempt (if available) for a
+ combination of validation metadata
+
+ Args:
+ medium (str|None): The medium of the 3PID
+ address (str|None): The address of the 3PID
+ sid (str|None): The ID of the validation session
+ client_secret (str): A unique string provided by the client to help identify this
+ validation attempt
+ validated (bool|None): Whether sessions should be filtered by
+ whether they have been validated already or not. None to
+ perform no filtering
+
+ Returns:
+ Deferred[dict|None]: A dict containing the following:
+ * address - address of the 3pid
+ * medium - medium of the 3pid
+ * client_secret - a secret provided by the client for this validation session
+ * session_id - ID of the validation session
+ * send_attempt - a number serving to dedupe send attempts for this session
+ * validated_at - timestamp of when this session was validated if so
+
+ Otherwise None if a validation session is not found
+ """
+ if not client_secret:
+ raise SynapseError(
+ 400, "Missing parameter: client_secret", errcode=Codes.MISSING_PARAM
+ )
+
+ keyvalues = {"client_secret": client_secret}
+ if medium:
+ keyvalues["medium"] = medium
+ if address:
+ keyvalues["address"] = address
+ if sid:
+ keyvalues["session_id"] = sid
+
+ assert address or sid
+
+ def get_threepid_validation_session_txn(txn):
+ sql = """
+ SELECT address, session_id, medium, client_secret,
+ last_send_attempt, validated_at
+ FROM threepid_validation_session WHERE %s
+ """ % (
+ " AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),
+ )
+
+ if validated is not None:
+ sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
+
+ sql += " LIMIT 1"
+
+ txn.execute(sql, list(keyvalues.values()))
+ rows = self.db.cursor_to_dict(txn)
+ if not rows:
+ return None
+
+ return rows[0]
+
+ return self.db.runInteraction(
+ "get_threepid_validation_session", get_threepid_validation_session_txn
+ )
+
+ def delete_threepid_session(self, session_id):
+ """Removes a threepid validation session from the database. This can
+ be done after validation has been performed and whatever action was
+ waiting on it has been carried out
+
+ Args:
+ session_id (str): The ID of the session to delete
+ """
+
+ def delete_threepid_session_txn(txn):
+ self.db.simple_delete_txn(
+ txn,
+ table="threepid_validation_token",
+ keyvalues={"session_id": session_id},
+ )
+ self.db.simple_delete_txn(
+ txn,
+ table="threepid_validation_session",
+ keyvalues={"session_id": session_id},
+ )
+
+ return self.db.runInteraction(
+ "delete_threepid_session", delete_threepid_session_txn
+ )
+
+
+class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.clock = hs.get_clock()
+ self.config = hs.config
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"access_tokens_device_index",
index_name="access_tokens_device_id",
table="access_tokens",
columns=["user_id", "device_id"],
)
- self.register_background_index_update(
+ self.db.updates.register_background_index_update(
"users_creation_ts",
index_name="users_creation_ts",
table="users",
columns=["creation_ts"],
)
- self._account_validity = hs.config.account_validity
-
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
- self.register_noop_background_update("refresh_tokens_device_index")
+ self.db.updates.register_noop_background_update("refresh_tokens_device_index")
- self.register_background_update_handler(
- "user_threepids_grandfather", self._bg_user_threepids_grandfather,
+ self.db.updates.register_background_update_handler(
+ "user_threepids_grandfather", self._bg_user_threepids_grandfather
)
- self.register_background_update_handler(
- "users_set_deactivated_flag", self._backgroud_update_set_deactivated_flag,
- )
-
- # Create a background job for culling expired 3PID validity tokens
- hs.get_clock().looping_call(
- self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS,
+ self.db.updates.register_background_update_handler(
+ "users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
@defer.inlineCallbacks
- def _backgroud_update_set_deactivated_flag(self, progress, batch_size):
+ def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
"""
last_user = progress.get("user_id", "")
- def _backgroud_update_set_deactivated_flag_txn(txn):
+ def _background_update_set_deactivated_flag_txn(txn):
txn.execute(
"""
SELECT
@@ -676,10 +885,10 @@ class RegistrationStore(
(last_user, batch_size),
)
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
- return True
+ return True, 0
rows_processed_nb = 0
@@ -690,49 +899,111 @@ class RegistrationStore(
logger.info("Marked %d rows as deactivated", rows_processed_nb)
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, "users_set_deactivated_flag", {"user_id": rows[-1]["name"]}
)
if batch_size > len(rows):
- return True
+ return True, len(rows)
else:
- return False
+ return False, len(rows)
- end = yield self.runInteraction(
- "users_set_deactivated_flag",
- _backgroud_update_set_deactivated_flag_txn,
+ end, nb_processed = yield self.db.runInteraction(
+ "users_set_deactivated_flag", _background_update_set_deactivated_flag_txn
)
if end:
- yield self._end_background_update("users_set_deactivated_flag")
+ yield self.db.updates._end_background_update("users_set_deactivated_flag")
- defer.returnValue(batch_size)
+ return nb_processed
@defer.inlineCallbacks
- def add_access_token_to_user(self, user_id, token, device_id=None):
+ def _bg_user_threepids_grandfather(self, progress, batch_size):
+ """We now track which identity servers a user binds their 3PID to, so
+ we need to handle the case of existing bindings where we didn't track
+ this.
+
+ We do this by grandfathering in existing user threepids assuming that
+ they used one of the server configured trusted identity servers.
+ """
+ id_servers = set(self.config.trusted_third_party_id_servers)
+
+ def _bg_user_threepids_grandfather_txn(txn):
+ sql = """
+ INSERT INTO user_threepid_id_server
+ (user_id, medium, address, id_server)
+ SELECT user_id, medium, address, ?
+ FROM user_threepids
+ """
+
+ txn.executemany(sql, [(id_server,) for id_server in id_servers])
+
+ if id_servers:
+ yield self.db.runInteraction(
+ "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
+ )
+
+ yield self.db.updates._end_background_update("user_threepids_grandfather")
+
+ return 1
+
+
+class RegistrationStore(RegistrationBackgroundUpdateStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RegistrationStore, self).__init__(database, db_conn, hs)
+
+ self._account_validity = hs.config.account_validity
+
+ if self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "account_validity_set_expiration_dates",
+ self._set_expiration_date_when_missing,
+ )
+
+ # Create a background job for culling expired 3PID validity tokens
+ def start_cull():
+ # run as a background process to make sure that the database transactions
+ # have a logcontext to report to
+ return run_as_background_process(
+ "cull_expired_threepid_validation_tokens",
+ self.cull_expired_threepid_validation_tokens,
+ )
+
+ hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
+
+ @defer.inlineCallbacks
+ def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
"""Adds an access token for the given user.
Args:
user_id (str): The user ID.
token (str): The new access token to add.
device_id (str): ID of the device to associate with the access
- token
+ token
+ valid_until_ms (int|None): when the token is valid until. None for
+ no expiry.
Raises:
StoreError if there was a problem adding this.
"""
next_id = self._access_tokens_id_gen.get_next()
- yield self._simple_insert(
+ yield self.db.simple_insert(
"access_tokens",
- {"id": next_id, "user_id": user_id, "token": token, "device_id": device_id},
+ {
+ "id": next_id,
+ "user_id": user_id,
+ "token": token,
+ "device_id": device_id,
+ "valid_until_ms": valid_until_ms,
+ },
desc="add_access_token_to_user",
)
- def register(
+ def register_user(
self,
user_id,
- token=None,
password_hash=None,
was_guest=False,
make_guest=False,
@@ -745,9 +1016,6 @@ class RegistrationStore(
Args:
user_id (str): The desired user ID to register.
- token (str): The desired access token to use for this user. If this
- is not None, the given access token is associated with the user
- id.
password_hash (str): Optional. The password hash for this user.
was_guest (bool): Optional. Whether this is a guest account being
upgraded to a non-guest account.
@@ -763,11 +1031,10 @@ class RegistrationStore(
Raises:
StoreError if the user_id could not be registered.
"""
- return self.runInteraction(
- "register",
- self._register,
+ return self.db.runInteraction(
+ "register_user",
+ self._register_user,
user_id,
- token,
password_hash,
was_guest,
make_guest,
@@ -777,11 +1044,10 @@ class RegistrationStore(
user_type,
)
- def _register(
+ def _register_user(
self,
txn,
user_id,
- token,
password_hash,
was_guest,
make_guest,
@@ -794,14 +1060,12 @@ class RegistrationStore(
now = int(self.clock.time())
- next_id = self._access_tokens_id_gen.get_next()
-
try:
if was_guest:
# Ensure that the guest user actually exists
# ``allow_none=False`` makes this raise an exception
# if the row isn't in the database.
- self._simple_select_one_txn(
+ self.db.simple_select_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -809,7 +1073,7 @@ class RegistrationStore(
allow_none=False,
)
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -823,7 +1087,7 @@ class RegistrationStore(
},
)
else:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
"users",
values={
@@ -843,14 +1107,6 @@ class RegistrationStore(
if self._account_validity.enabled:
self.set_expiration_date_for_user_txn(txn, user_id)
- if token:
- # it's possible for this to get a conflict, but only for a single user
- # since tokens are namespaced based on their user ID
- txn.execute(
- "INSERT INTO access_tokens(id, user_id, token)" " VALUES (?,?,?)",
- (next_id, user_id, token),
- )
-
if create_profile_with_displayname:
# set a default displayname serverside to avoid ugly race
# between auto-joins and clients trying to set displaynames
@@ -862,9 +1118,40 @@ class RegistrationStore(
(user_id_obj.localpart, create_profile_with_displayname),
)
+ if self.hs.config.stats_enabled:
+ # we create a new completed user statistics row
+
+ # we don't strictly need current_token since this user really can't
+ # have any state deltas before now (as it is a new user), but still,
+ # we include it for completeness.
+ current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
+ self._update_stats_delta_txn(
+ txn, now, "user", user_id, {}, complete_with_stream_id=current_token
+ )
+
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
+ def record_user_external_id(
+ self, auth_provider: str, external_id: str, user_id: str
+ ) -> Deferred:
+ """Record a mapping from an external user id to a mxid
+
+ Args:
+ auth_provider: identifier for the remote auth provider
+ external_id: id on that system
+ user_id: complete mxid that it is mapped to
+ """
+ return self.db.simple_insert(
+ table="user_external_ids",
+ values={
+ "auth_provider": auth_provider,
+ "external_id": external_id,
+ "user_id": user_id,
+ },
+ desc="record_user_external_id",
+ )
+
def user_set_password_hash(self, user_id, password_hash):
"""
NB. This does *not* evict any cache because the one use for this
@@ -873,12 +1160,14 @@ class RegistrationStore(
"""
def user_set_password_hash_txn(txn):
- self._simple_update_one_txn(
- txn, 'users', {'name': user_id}, {'password_hash': password_hash}
+ self.db.simple_update_one_txn(
+ txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.runInteraction("user_set_password_hash", user_set_password_hash_txn)
+ return self.db.runInteraction(
+ "user_set_password_hash", user_set_password_hash_txn
+ )
def user_set_consent_version(self, user_id, consent_version):
"""Updates the user table to record privacy policy consent
@@ -893,15 +1182,15 @@ class RegistrationStore(
"""
def f(txn):
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
- table='users',
- keyvalues={'name': user_id},
- updatevalues={'consent_version': consent_version},
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"consent_version": consent_version},
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.runInteraction("user_set_consent_version", f)
+ return self.db.runInteraction("user_set_consent_version", f)
def user_set_consent_server_notice_sent(self, user_id, consent_version):
"""Updates the user table to record that we have sent the user a server
@@ -917,15 +1206,15 @@ class RegistrationStore(
"""
def f(txn):
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
- table='users',
- keyvalues={'name': user_id},
- updatevalues={'consent_server_notice_sent': consent_version},
+ table="users",
+ keyvalues={"name": user_id},
+ updatevalues={"consent_server_notice_sent": consent_version},
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
- return self.runInteraction("user_set_consent_server_notice_sent", f)
+ return self.db.runInteraction("user_set_consent_server_notice_sent", f)
def user_delete_access_tokens(self, user_id, except_token_id=None, device_id=None):
"""
@@ -971,11 +1260,11 @@ class RegistrationStore(
return tokens_and_devices
- return self.runInteraction("user_delete_access_tokens", f)
+ return self.db.runInteraction("user_delete_access_tokens", f)
def delete_access_token(self, access_token):
def f(txn):
- self._simple_delete_one_txn(
+ self.db.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@@ -983,11 +1272,11 @@ class RegistrationStore(
txn, self.get_user_by_access_token, (access_token,)
)
- return self.runInteraction("delete_access_token", f)
+ return self.db.runInteraction("delete_access_token", f)
@cachedInlineCallbacks()
def is_guest(self, user_id):
- res = yield self._simple_select_one_onecol(
+ res = yield self.db.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@@ -995,48 +1284,14 @@ class RegistrationStore(
desc="is_guest",
)
- defer.returnValue(res if res else False)
-
- @defer.inlineCallbacks
- def save_or_get_3pid_guest_access_token(
- self, medium, address, access_token, inviter_user_id
- ):
- """
- Gets the 3pid's guest access token if exists, else saves access_token.
-
- Args:
- medium (str): Medium of the 3pid. Must be "email".
- address (str): 3pid address.
- access_token (str): The access token to persist if none is
- already persisted.
- inviter_user_id (str): User ID of the inviter.
-
- Returns:
- deferred str: Whichever access token is persisted at the end
- of this function call.
- """
-
- def insert(txn):
- txn.execute(
- "INSERT INTO threepid_guest_access_tokens "
- "(medium, address, guest_access_token, first_inviter) "
- "VALUES (?, ?, ?, ?)",
- (medium, address, access_token, inviter_user_id),
- )
-
- try:
- yield self.runInteraction("save_3pid_guest_access_token", insert)
- defer.returnValue(access_token)
- except self.database_engine.module.IntegrityError:
- ret = yield self.get_3pid_guest_access_token(medium, address)
- defer.returnValue(ret)
+ return res if res else False
def add_user_pending_deactivation(self, user_id):
"""
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
- return self._simple_insert(
+ return self.db.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
@@ -1049,7 +1304,7 @@ class RegistrationStore(
"""
# XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it.
- return self._simple_delete(
+ return self.db.simple_delete(
"users_pending_deactivation",
keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
@@ -1060,7 +1315,7 @@ class RegistrationStore(
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -1068,104 +1323,7 @@ class RegistrationStore(
desc="get_users_pending_deactivation",
)
- @defer.inlineCallbacks
- def _bg_user_threepids_grandfather(self, progress, batch_size):
- """We now track which identity servers a user binds their 3PID to, so
- we need to handle the case of existing bindings where we didn't track
- this.
-
- We do this by grandfathering in existing user threepids assuming that
- they used one of the server configured trusted identity servers.
- """
- id_servers = set(self.config.trusted_third_party_id_servers)
-
- def _bg_user_threepids_grandfather_txn(txn):
- sql = """
- INSERT INTO user_threepid_id_server
- (user_id, medium, address, id_server)
- SELECT user_id, medium, address, ?
- FROM user_threepids
- """
-
- txn.executemany(sql, [(id_server,) for id_server in id_servers])
-
- if id_servers:
- yield self.runInteraction(
- "_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn,
- )
-
- yield self._end_background_update("user_threepids_grandfather")
-
- defer.returnValue(1)
-
- def get_threepid_validation_session(
- self,
- medium,
- client_secret,
- address=None,
- sid=None,
- validated=True,
- ):
- """Gets a session_id and last_send_attempt (if available) for a
- client_secret/medium/(address|session_id) combo
-
- Args:
- medium (str|None): The medium of the 3PID
- address (str|None): The address of the 3PID
- sid (str|None): The ID of the validation session
- client_secret (str|None): A unique string provided by the client to
- help identify this validation attempt
- validated (bool|None): Whether sessions should be filtered by
- whether they have been validated already or not. None to
- perform no filtering
-
- Returns:
- deferred {str, int}|None: A dict containing the
- latest session_id and send_attempt count for this 3PID.
- Otherwise None if there hasn't been a previous attempt
- """
- keyvalues = {
- "medium": medium,
- "client_secret": client_secret,
- }
- if address:
- keyvalues["address"] = address
- if sid:
- keyvalues["session_id"] = sid
-
- assert(address or sid)
-
- def get_threepid_validation_session_txn(txn):
- sql = """
- SELECT address, session_id, medium, client_secret,
- last_send_attempt, validated_at
- FROM threepid_validation_session WHERE %s
- """ % (" AND ".join("%s = ?" % k for k in iterkeys(keyvalues)),)
-
- if validated is not None:
- sql += " AND validated_at IS " + ("NOT NULL" if validated else "NULL")
-
- sql += " LIMIT 1"
-
- txn.execute(sql, list(keyvalues.values()))
- rows = self.cursor_to_dict(txn)
- if not rows:
- return None
-
- return rows[0]
-
- return self.runInteraction(
- "get_threepid_validation_session",
- get_threepid_validation_session_txn,
- )
-
- def validate_threepid_session(
- self,
- session_id,
- client_secret,
- token,
- current_ts,
- ):
+ def validate_threepid_session(self, session_id, client_secret, token, current_ts):
"""Attempt to validate a threepid session using a token
Args:
@@ -1176,13 +1334,18 @@ class RegistrationStore(
current_ts (int): The current unix time in milliseconds. Used for
checking token expiry status
+ Raises:
+ ThreepidValidationError: if a matching validation token was not found or has
+ expired
+
Returns:
deferred str|None: A str representing a link to redirect the user
to if there is one.
"""
+
# Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn):
- row = self._simple_select_one_txn(
+ row = self.db.simple_select_one_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1197,10 +1360,10 @@ class RegistrationStore(
if retrieved_client_secret != client_secret:
raise ThreepidValidationError(
- 400, "This client_secret does not match the provided session_id",
+ 400, "This client_secret does not match the provided session_id"
)
- row = self._simple_select_one_txn(
+ row = self.db.simple_select_one_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token},
@@ -1210,7 +1373,7 @@ class RegistrationStore(
if not row:
raise ThreepidValidationError(
- 400, "Validation token not found or has expired",
+ 400, "Validation token not found or has expired"
)
expires = row["expires"]
next_link = row["next_link"]
@@ -1221,11 +1384,11 @@ class RegistrationStore(
if expires <= current_ts:
raise ThreepidValidationError(
- 400, "This token has expired. Please request a new one",
+ 400, "This token has expired. Please request a new one"
)
# Looks good. Validate the session
- self._simple_update_txn(
+ self.db.simple_update_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1235,9 +1398,8 @@ class RegistrationStore(
return next_link
# Return next_link if it exists
- return self.runInteraction(
- "validate_threepid_session_txn",
- validate_threepid_session_txn,
+ return self.db.runInteraction(
+ "validate_threepid_session_txn", validate_threepid_session_txn
)
def upsert_threepid_validation_session(
@@ -1269,7 +1431,7 @@ class RegistrationStore(
if validated_at:
insertion_values["validated_at"] = validated_at
- return self._simple_upsert(
+ return self.db.simple_upsert(
table="threepid_validation_session",
keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt},
@@ -1304,9 +1466,10 @@ class RegistrationStore(
token_expires (int): The timestamp for which after the token
will no longer be valid
"""
+
def start_or_continue_validation_session_txn(txn):
# Create or update a validation session
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1319,7 +1482,7 @@ class RegistrationStore(
)
# Create a new validation token with this session ID
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="threepid_validation_token",
values={
@@ -1330,13 +1493,14 @@ class RegistrationStore(
},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"start_or_continue_validation_session",
start_or_continue_validation_session_txn,
)
def cull_expired_threepid_validation_tokens(self):
"""Remove threepid validation tokens with expiry dates that have passed"""
+
def cull_expired_threepid_validation_tokens_txn(txn, ts):
sql = """
DELETE FROM threepid_validation_token WHERE
@@ -1344,80 +1508,91 @@ class RegistrationStore(
"""
return txn.execute(sql, (ts,))
- return self.runInteraction(
+ return self.db.runInteraction(
"cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(),
)
- def delete_threepid_session(self, session_id):
- """Removes a threepid validation session from the database. This can
- be done after validation has been performed and whatever action was
- waiting on it has been carried out
+ @defer.inlineCallbacks
+ def set_user_deactivated_status(self, user_id, deactivated):
+ """Set the `deactivated` property for the provided user to the provided value.
Args:
- session_id (str): The ID of the session to delete
+ user_id (str): The ID of the user to set the status for.
+ deactivated (bool): The value to set for `deactivated`.
"""
- def delete_threepid_session_txn(txn):
- self._simple_delete_txn(
- txn,
- table="threepid_validation_token",
- keyvalues={"session_id": session_id},
- )
- self._simple_delete_txn(
- txn,
- table="threepid_validation_session",
- keyvalues={"session_id": session_id},
- )
- return self.runInteraction(
- "delete_threepid_session",
- delete_threepid_session_txn,
+ yield self.db.runInteraction(
+ "set_user_deactivated_status",
+ self.set_user_deactivated_status_txn,
+ user_id,
+ deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
- txn, self.get_user_deactivated_status, (user_id,),
+ txn, self.get_user_deactivated_status, (user_id,)
)
@defer.inlineCallbacks
- def set_user_deactivated_status(self, user_id, deactivated):
- """Set the `deactivated` property for the provided user to the provided value.
-
- Args:
- user_id (str): The ID of the user to set the status for.
- deactivated (bool): The value to set for `deactivated`.
+ def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
"""
- yield self.runInteraction(
- "set_user_deactivated_status",
- self.set_user_deactivated_status_txn,
- user_id, deactivated,
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database, filtering out deactivated users.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+ )
+ txn.execute(sql, [])
+
+ res = self.db.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn, user["name"], use_delta=True
+ )
+
+ yield self.db.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
)
- @cachedInlineCallbacks()
- def get_user_deactivated_status(self, user_id):
- """Retrieve the value for the `deactivated` property for the provided user.
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
Args:
- user_id (str): The ID of the user to retrieve the status for.
-
- Returns:
- defer.Deferred(bool): The requested value.
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d ; now + period] range, d being a
+ delta equal to 10% of the validity period.
"""
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
- res = yield self._simple_select_one_onecol(
- table="users",
- keyvalues={"name": user_id},
- retcol="deactivated",
- desc="get_user_deactivated_status",
- )
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
- # Convert the integer into a boolean.
- defer.returnValue(res == 1)
+ self.db.simple_upsert_txn(
+ txn,
+ "account_validity",
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+ )
diff --git a/synapse/storage/rejections.py b/synapse/storage/data_stores/main/rejections.py
index f4c1c2a457..1c07c7a425 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -15,14 +15,14 @@
import logging
-from ._base import SQLBaseStore
+from synapse.storage._base import SQLBaseStore
logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def _store_rejections_txn(self, txn, event_id, reason):
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="rejections",
values={
@@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore):
)
def get_rejection_reason(self, event_id):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
new file mode 100644
index 0000000000..046c2b4845
--- /dev/null
+++ b/synapse/storage/data_stores/main/relations.py
@@ -0,0 +1,385 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+
+from synapse.api.constants import RelationTypes
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.stream import generate_pagination_where_clause
+from synapse.storage.relations import (
+ AggregationPaginationToken,
+ PaginationChunk,
+ RelationPaginationToken,
+)
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+class RelationsWorkerStore(SQLBaseStore):
+ @cached(tree=True)
+ def get_relations_for_event(
+ self,
+ event_id,
+ relation_type=None,
+ event_type=None,
+ aggregation_key=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of relations for an event, ordered by topological ordering.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ relation_type (str|None): Only fetch events with this relation
+ type, if given.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ aggregation_key (str|None): Only fetch events with this aggregation
+ key, if given.
+ limit (int): Only fetch the most recent `limit` events.
+ direction (str): Whether to fetch the most recent first (`"b"`) or
+ the oldest first (`"f"`).
+ from_token (RelationPaginationToken|None): Fetch rows from the given
+ token, or from the start if None.
+ to_token (RelationPaginationToken|None): Fetch rows up to the given
+ token, or up to the end if None.
+
+ Returns:
+ Deferred[PaginationChunk]: List of event IDs that match relations
+ requested. The rows are of the form `{"event_id": "..."}`.
+ """
+
+ where_clause = ["relates_to_id = ?"]
+ where_args = [event_id]
+
+ if relation_type is not None:
+ where_clause.append("relation_type = ?")
+ where_args.append(relation_type)
+
+ if event_type is not None:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ if aggregation_key:
+ where_clause.append("aggregation_key = ?")
+ where_args.append(aggregation_key)
+
+ pagination_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if pagination_clause:
+ where_clause.append(pagination_clause)
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ sql = """
+ SELECT event_id, topological_ordering, stream_ordering
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
+
+ def _get_recent_references_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ last_topo_id = None
+ last_stream_id = None
+ events = []
+ for row in txn:
+ events.append({"event_id": row[0]})
+ last_topo_id = row[1]
+ last_stream_id = row[2]
+
+ next_batch = None
+ if len(events) > limit and last_topo_id and last_stream_id:
+ next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.db.runInteraction(
+ "get_recent_references_for_event", _get_recent_references_for_event_txn
+ )
+
+ @cached(tree=True)
+ def get_aggregation_groups_for_event(
+ self,
+ event_id,
+ event_type=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of annotations on the event, grouped by event type and
+ aggregation key, sorted by count.
+
+ This is used e.g. to get the what and how many reactions have happend
+ on an event.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ limit (int): Only fetch the `limit` groups.
+ direction (str): Whether to fetch the highest count first (`"b"`) or
+ the lowest count first (`"f"`).
+ from_token (AggregationPaginationToken|None): Fetch rows from the
+ given token, or from the start if None.
+ to_token (AggregationPaginationToken|None): Fetch rows up to the
+ given token, or up to the end if None.
+
+
+ Returns:
+ Deferred[PaginationChunk]: List of groups of annotations that
+ match. Each row is a dict with `type`, `key` and `count` fields.
+ """
+
+ where_clause = ["relates_to_id = ?", "relation_type = ?"]
+ where_args = [event_id, RelationTypes.ANNOTATION]
+
+ if event_type:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ having_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("COUNT(*)", "MAX(stream_ordering)"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ if having_clause:
+ having_clause = "HAVING " + having_clause
+ else:
+ having_clause = ""
+
+ sql = """
+ SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE {where_clause}
+ GROUP BY relation_type, type, aggregation_key
+ {having_clause}
+ ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ LIMIT ?
+ """.format(
+ where_clause=" AND ".join(where_clause),
+ order=order,
+ having_clause=having_clause,
+ )
+
+ def _get_aggregation_groups_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ next_batch = None
+ events = []
+ for row in txn:
+ events.append({"type": row[0], "key": row[1], "count": row[2]})
+ next_batch = AggregationPaginationToken(row[2], row[3])
+
+ if len(events) <= limit:
+ next_batch = None
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.db.runInteraction(
+ "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+ )
+
+ @cachedInlineCallbacks()
+ def get_applicable_edit(self, event_id):
+ """Get the most recent edit (if any) that has happened for the given
+ event.
+
+ Correctly handles checking whether edits were allowed to happen.
+
+ Args:
+ event_id (str): The original event ID
+
+ Returns:
+ Deferred[EventBase|None]: Returns the most recent edit, if any.
+ """
+
+ # We only allow edits for `m.room.message` events that have the same sender
+ # and event type. We can't assert these things during regular event auth so
+ # we have to do the checks post hoc.
+
+ # Fetches latest edit that has the same type and sender as the
+ # original, and is an `m.room.message`.
+ sql = """
+ SELECT edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by edit.origin_server_ts DESC, edit.event_id DESC
+ LIMIT 1
+ """
+
+ def _get_applicable_edit_txn(txn):
+ txn.execute(sql, (event_id, RelationTypes.REPLACE))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ edit_id = yield self.db.runInteraction(
+ "get_applicable_edit", _get_applicable_edit_txn
+ )
+
+ if not edit_id:
+ return
+
+ edit_event = yield self.get_event(edit_id, allow_none=True)
+ return edit_event
+
+ def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+ """Check if a user has already annotated an event with the same key
+ (e.g. already liked an event).
+
+ Args:
+ parent_id (str): The event being annotated
+ event_type (str): The event type of the annotation
+ aggregation_key (str): The aggregation key of the annotation
+ sender (str): The sender of the annotation
+
+ Returns:
+ Deferred[bool]
+ """
+
+ sql = """
+ SELECT 1 FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND type = ?
+ AND sender = ?
+ AND aggregation_key = ?
+ LIMIT 1;
+ """
+
+ def _get_if_user_has_annotated_event(txn):
+ txn.execute(
+ sql,
+ (
+ parent_id,
+ RelationTypes.ANNOTATION,
+ event_type,
+ sender,
+ aggregation_key,
+ ),
+ )
+
+ return bool(txn.fetchone())
+
+ return self.db.runInteraction(
+ "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
+ )
+
+
+class RelationsStore(RelationsWorkerStore):
+ def _handle_event_relations(self, txn, event):
+ """Handles inserting relation data during peristence of events
+
+ Args:
+ txn
+ event (EventBase)
+ """
+ relation = event.content.get("m.relates_to")
+ if not relation:
+ # No relations
+ return
+
+ rel_type = relation.get("rel_type")
+ if rel_type not in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
+ # Unknown relation type
+ return
+
+ parent_id = relation.get("event_id")
+ if not parent_id:
+ # Invalid relation
+ return
+
+ aggregation_key = relation.get("key")
+
+ self.db.simple_insert_txn(
+ txn,
+ table="event_relations",
+ values={
+ "event_id": event.event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ "aggregation_key": aggregation_key,
+ },
+ )
+
+ txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
+ txn.call_after(
+ self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+ )
+
+ if rel_type == RelationTypes.REPLACE:
+ txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
+
+ def _handle_redaction(self, txn, redacted_event_id):
+ """Handles receiving a redaction and checking whether we need to remove
+ any redacted relations from the database.
+
+ Args:
+ txn
+ redacted_event_id (str): The event that was redacted.
+ """
+
+ self.db.simple_delete_txn(
+ txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
+ )
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
new file mode 100644
index 0000000000..511316938d
--- /dev/null
+++ b/synapse/storage/data_stores/main/room.py
@@ -0,0 +1,1404 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import collections
+import logging
+import re
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Dict, List, Optional, Tuple
+
+from six import integer_types
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import StoreError
+from synapse.api.room_versions import RoomVersion, RoomVersions
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.search import SearchStore
+from synapse.storage.database import Database, LoggingTransaction
+from synapse.types import ThirdPartyInstanceID
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+OpsLevel = collections.namedtuple(
+ "OpsLevel", ("ban_level", "kick_level", "redact_level")
+)
+
+RatelimitOverride = collections.namedtuple(
+ "RatelimitOverride", ("messages_per_second", "burst_count")
+)
+
+
+class RoomSortOrder(Enum):
+ """
+ Enum to define the sorting method used when returning rooms with get_rooms_paginate
+
+ ALPHABETICAL = sort rooms alphabetically by name
+ SIZE = sort rooms by membership size, highest to lowest
+ """
+
+ ALPHABETICAL = "alphabetical"
+ SIZE = "size"
+
+
+class RoomWorkerStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+
+ self.config = hs.config
+
+ def get_room(self, room_id):
+ """Retrieve a room.
+
+ Args:
+ room_id (str): The ID of the room to retrieve.
+ Returns:
+ A dict containing the room information, or None if the room is unknown.
+ """
+ return self.db.simple_select_one(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcols=("room_id", "is_public", "creator"),
+ desc="get_room",
+ allow_none=True,
+ )
+
+ def get_public_room_ids(self):
+ return self.db.simple_select_onecol(
+ table="rooms",
+ keyvalues={"is_public": True},
+ retcol="room_id",
+ desc="get_public_room_ids",
+ )
+
+ def count_public_rooms(self, network_tuple, ignore_non_federatable):
+ """Counts the number of public rooms as tracked in the room_stats_current
+ and room_stats_state table.
+
+ Args:
+ network_tuple (ThirdPartyInstanceID|None)
+ ignore_non_federatable (bool): If true filters out non-federatable rooms
+ """
+
+ def _count_public_rooms_txn(txn):
+ query_args = []
+
+ if network_tuple:
+ if network_tuple.appservice_id:
+ published_sql = """
+ SELECT room_id from appservice_room_list
+ WHERE appservice_id = ? AND network_id = ?
+ """
+ query_args.append(network_tuple.appservice_id)
+ query_args.append(network_tuple.network_id)
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ """
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ UNION SELECT room_id from appservice_room_list
+ """
+
+ sql = """
+ SELECT
+ COALESCE(COUNT(*), 0)
+ FROM (
+ %(published_sql)s
+ ) published
+ INNER JOIN room_stats_state USING (room_id)
+ INNER JOIN room_stats_current USING (room_id)
+ WHERE
+ (
+ join_rules = 'public' OR history_visibility = 'world_readable'
+ )
+ AND joined_members > 0
+ """ % {
+ "published_sql": published_sql
+ }
+
+ txn.execute(sql, query_args)
+ return txn.fetchone()[0]
+
+ return self.db.runInteraction("count_public_rooms", _count_public_rooms_txn)
+
+ @defer.inlineCallbacks
+ def get_largest_public_rooms(
+ self,
+ network_tuple: Optional[ThirdPartyInstanceID],
+ search_filter: Optional[dict],
+ limit: Optional[int],
+ bounds: Optional[Tuple[int, str]],
+ forwards: bool,
+ ignore_non_federatable: bool = False,
+ ):
+ """Gets the largest public rooms (where largest is in terms of joined
+ members, as tracked in the statistics table).
+
+ Args:
+ network_tuple
+ search_filter
+ limit: Maxmimum number of rows to return, unlimited otherwise.
+ bounds: An uppoer or lower bound to apply to result set if given,
+ consists of a joined member count and room_id (these are
+ excluded from result set).
+ forwards: true iff going forwards, going backwards otherwise
+ ignore_non_federatable: If true filters out non-federatable rooms.
+
+ Returns:
+ Rooms in order: biggest number of joined users first.
+ We then arbitrarily use the room_id as a tie breaker.
+
+ """
+
+ where_clauses = []
+ query_args = []
+
+ if network_tuple:
+ if network_tuple.appservice_id:
+ published_sql = """
+ SELECT room_id from appservice_room_list
+ WHERE appservice_id = ? AND network_id = ?
+ """
+ query_args.append(network_tuple.appservice_id)
+ query_args.append(network_tuple.network_id)
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ """
+ else:
+ published_sql = """
+ SELECT room_id FROM rooms WHERE is_public
+ UNION SELECT room_id from appservice_room_list
+ """
+
+ # Work out the bounds if we're given them, these bounds look slightly
+ # odd, but are designed to help query planner use indices by pulling
+ # out a common bound.
+ if bounds:
+ last_joined_members, last_room_id = bounds
+ if forwards:
+ where_clauses.append(
+ """
+ joined_members <= ? AND (
+ joined_members < ? OR room_id < ?
+ )
+ """
+ )
+ else:
+ where_clauses.append(
+ """
+ joined_members >= ? AND (
+ joined_members > ? OR room_id > ?
+ )
+ """
+ )
+
+ query_args += [last_joined_members, last_joined_members, last_room_id]
+
+ if ignore_non_federatable:
+ where_clauses.append("is_federatable")
+
+ if search_filter and search_filter.get("generic_search_term", None):
+ search_term = "%" + search_filter["generic_search_term"] + "%"
+
+ where_clauses.append(
+ """
+ (
+ LOWER(name) LIKE ?
+ OR LOWER(topic) LIKE ?
+ OR LOWER(canonical_alias) LIKE ?
+ )
+ """
+ )
+ query_args += [
+ search_term.lower(),
+ search_term.lower(),
+ search_term.lower(),
+ ]
+
+ where_clause = ""
+ if where_clauses:
+ where_clause = " AND " + " AND ".join(where_clauses)
+
+ sql = """
+ SELECT
+ room_id, name, topic, canonical_alias, joined_members,
+ avatar, history_visibility, joined_members, guest_access
+ FROM (
+ %(published_sql)s
+ ) published
+ INNER JOIN room_stats_state USING (room_id)
+ INNER JOIN room_stats_current USING (room_id)
+ WHERE
+ (
+ join_rules = 'public' OR history_visibility = 'world_readable'
+ )
+ AND joined_members > 0
+ %(where_clause)s
+ ORDER BY joined_members %(dir)s, room_id %(dir)s
+ """ % {
+ "published_sql": published_sql,
+ "where_clause": where_clause,
+ "dir": "DESC" if forwards else "ASC",
+ }
+
+ if limit is not None:
+ query_args.append(limit)
+
+ sql += """
+ LIMIT ?
+ """
+
+ def _get_largest_public_rooms_txn(txn):
+ txn.execute(sql, query_args)
+
+ results = self.db.cursor_to_dict(txn)
+
+ if not forwards:
+ results.reverse()
+
+ return results
+
+ ret_val = yield self.db.runInteraction(
+ "get_largest_public_rooms", _get_largest_public_rooms_txn
+ )
+ defer.returnValue(ret_val)
+
+ @cached(max_entries=10000)
+ def is_room_blocked(self, room_id):
+ return self.db.simple_select_one_onecol(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="1",
+ allow_none=True,
+ desc="is_room_blocked",
+ )
+
+ @defer.inlineCallbacks
+ def is_room_published(self, room_id):
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id (str)
+ Returns:
+ bool: Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = yield self.get_room(room_id)
+ if not room_info:
+ defer.returnValue(False)
+
+ # Check the is_public value
+ defer.returnValue(room_info.get("is_public", False))
+
+ async def get_rooms_paginate(
+ self,
+ start: int,
+ limit: int,
+ order_by: RoomSortOrder,
+ reverse_order: bool,
+ search_term: Optional[str],
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ """Function to retrieve a paginated list of rooms as json.
+
+ Args:
+ start: offset in the list
+ limit: maximum amount of rooms to retrieve
+ order_by: the sort order of the returned list
+ reverse_order: whether to reverse the room list
+ search_term: a string to filter room names by
+ Returns:
+ A list of room dicts and an integer representing the total number of
+ rooms that exist given this query
+ """
+ # Filter room names by a string
+ where_statement = ""
+ if search_term:
+ where_statement = "WHERE state.name LIKE ?"
+
+ # Our postgres db driver converts ? -> %s in SQL strings as that's the
+ # placeholder for postgres.
+ # HOWEVER, if you put a % into your SQL then everything goes wibbly.
+ # To get around this, we're going to surround search_term with %'s
+ # before giving it to the database in python instead
+ search_term = "%" + search_term + "%"
+
+ # Set ordering
+ if RoomSortOrder(order_by) == RoomSortOrder.SIZE:
+ order_by_column = "curr.joined_members"
+ order_by_asc = False
+ elif RoomSortOrder(order_by) == RoomSortOrder.ALPHABETICAL:
+ # Sort alphabetically
+ order_by_column = "state.name"
+ order_by_asc = True
+ else:
+ raise StoreError(
+ 500, "Incorrect value for order_by provided: %s" % order_by
+ )
+
+ # Whether to return the list in reverse order
+ if reverse_order:
+ # Flip the boolean
+ order_by_asc = not order_by_asc
+
+ # Create one query for getting the limited number of events that the user asked
+ # for, and another query for getting the total number of events that could be
+ # returned. Thus allowing us to see if there are more events to paginate through
+ info_sql = """
+ SELECT state.room_id, state.name, state.canonical_alias, curr.joined_members
+ FROM room_stats_state state
+ INNER JOIN room_stats_current curr USING (room_id)
+ %s
+ ORDER BY %s %s
+ LIMIT ?
+ OFFSET ?
+ """ % (
+ where_statement,
+ order_by_column,
+ "ASC" if order_by_asc else "DESC",
+ )
+
+ # Use a nested SELECT statement as SQL can't count(*) with an OFFSET
+ count_sql = """
+ SELECT count(*) FROM (
+ SELECT room_id FROM room_stats_state state
+ %s
+ ) AS get_room_ids
+ """ % (
+ where_statement,
+ )
+
+ def _get_rooms_paginate_txn(txn):
+ # Execute the data query
+ sql_values = (limit, start)
+ if search_term:
+ # Add the search term into the WHERE clause
+ sql_values = (search_term,) + sql_values
+ txn.execute(info_sql, sql_values)
+
+ # Refactor room query data into a structured dictionary
+ rooms = []
+ for room in txn:
+ rooms.append(
+ {
+ "room_id": room[0],
+ "name": room[1],
+ "canonical_alias": room[2],
+ "joined_members": room[3],
+ }
+ )
+
+ # Execute the count query
+
+ # Add the search term into the WHERE clause if present
+ sql_values = (search_term,) if search_term else ()
+ txn.execute(count_sql, sql_values)
+
+ room_count = txn.fetchone()
+ return rooms, room_count[0]
+
+ return await self.db.runInteraction(
+ "get_rooms_paginate", _get_rooms_paginate_txn,
+ )
+
+ @cachedInlineCallbacks(max_entries=10000)
+ def get_ratelimit_for_user(self, user_id):
+ """Check if there are any overrides for ratelimiting for the given
+ user
+
+ Args:
+ user_id (str)
+
+ Returns:
+ RatelimitOverride if there is an override, else None. If the contents
+ of RatelimitOverride are None or 0 then ratelimitng has been
+ disabled for that user entirely.
+ """
+ row = yield self.db.simple_select_one(
+ table="ratelimit_override",
+ keyvalues={"user_id": user_id},
+ retcols=("messages_per_second", "burst_count"),
+ allow_none=True,
+ desc="get_ratelimit_for_user",
+ )
+
+ if row:
+ return RatelimitOverride(
+ messages_per_second=row["messages_per_second"],
+ burst_count=row["burst_count"],
+ )
+ else:
+ return None
+
+ @cachedInlineCallbacks()
+ def get_retention_policy_for_room(self, room_id):
+ """Get the retention policy for a given room.
+
+ If no retention policy has been found for this room, returns a policy defined
+ by the configured default policy (which has None as both the 'min_lifetime' and
+ the 'max_lifetime' if no default policy has been defined in the server's
+ configuration).
+
+ Args:
+ room_id (str): The ID of the room to get the retention policy of.
+
+ Returns:
+ dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+ """
+ # If the room retention feature is disabled, return a policy with no minimum nor
+ # maximum, in order not to filter out events we should filter out when sending to
+ # the client.
+ if not self.config.retention_enabled:
+ defer.returnValue({"min_lifetime": None, "max_lifetime": None})
+
+ def get_retention_policy_for_room_txn(txn):
+ txn.execute(
+ """
+ SELECT min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ WHERE room_id = ?;
+ """,
+ (room_id,),
+ )
+
+ return self.db.cursor_to_dict(txn)
+
+ ret = yield self.db.runInteraction(
+ "get_retention_policy_for_room", get_retention_policy_for_room_txn,
+ )
+
+ # If we don't know this room ID, ret will be None, in this case return the default
+ # policy.
+ if not ret:
+ defer.returnValue(
+ {
+ "min_lifetime": self.config.retention_default_min_lifetime,
+ "max_lifetime": self.config.retention_default_max_lifetime,
+ }
+ )
+
+ row = ret[0]
+
+ # If one of the room's policy's attributes isn't defined, use the matching
+ # attribute from the default policy.
+ # The default values will be None if no default policy has been defined, or if one
+ # of the attributes is missing from the default policy.
+ if row["min_lifetime"] is None:
+ row["min_lifetime"] = self.config.retention_default_min_lifetime
+
+ if row["max_lifetime"] is None:
+ row["max_lifetime"] = self.config.retention_default_max_lifetime
+
+ defer.returnValue(row)
+
+ def get_media_mxcs_in_room(self, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ room_id (str)
+
+ Returns:
+ The local and remote media as a lists of tuples where the key is
+ the hostname and the value is the media ID.
+ """
+
+ def _get_media_mxcs_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ # Convert the IDs to MXC URIs
+ for media_id in local_mxcs:
+ local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
+ for hostname, media_id in remote_mxcs:
+ remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
+
+ return self.db.runInteraction(
+ "get_media_ids_in_room", _get_media_mxcs_in_room_txn
+ )
+
+ def quarantine_media_ids_in_room(self, room_id, quarantined_by):
+ """For a room loops through all events with media and quarantines
+ the associated media
+ """
+
+ logger.info("Quarantining media in room: %s", room_id)
+
+ def _quarantine_media_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ total_media_quarantined = 0
+
+ # Now update all the tables to set the quarantined_by flag
+
+ txn.executemany(
+ """
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """,
+ ((quarantined_by, media_id) for media_id in local_mxcs),
+ )
+
+ txn.executemany(
+ """
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_mxcs
+ ),
+ )
+
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
+
+ return total_media_quarantined
+
+ return self.db.runInteraction(
+ "quarantine_media_in_room", _quarantine_media_in_room_txn
+ )
+
+ def _get_media_mxcs_in_room_txn(self, txn, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ txn (cursor)
+ room_id (str)
+
+ Returns:
+ The local and remote media as a lists of tuples where the key is
+ the hostname and the value is the media ID.
+ """
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+ sql = """
+ SELECT stream_ordering, json FROM events
+ JOIN event_json USING (room_id, event_id)
+ WHERE room_id = ?
+ %(where_clause)s
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ txn.execute(sql % {"where_clause": ""}, (room_id, True, False, 100))
+
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ while True:
+ next_token = None
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ event_json = json.loads(content_json)
+ content = event_json["content"]
+ content_url = content.get("url")
+ thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+ for url in (content_url, thumbnail_url):
+ if not url:
+ continue
+ matches = mxc_re.match(url)
+ if matches:
+ hostname = matches.group(1)
+ media_id = matches.group(2)
+ if hostname == self.hs.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+
+ if next_token is None:
+ # We've gone through the whole room, so we're finished.
+ break
+
+ txn.execute(
+ sql % {"where_clause": "AND stream_ordering < ?"},
+ (room_id, next_token, True, False, 100),
+ )
+
+ return local_media_mxcs, remote_media_mxcs
+
+ def quarantine_media_by_id(
+ self, server_name: str, media_id: str, quarantined_by: str,
+ ):
+ """quarantines a single local or remote media id
+
+ Args:
+ server_name: The name of the server that holds this media
+ media_id: The ID of the media to be quarantined
+ quarantined_by: The user ID that initiated the quarantine request
+ """
+ logger.info("Quarantining media: %s/%s", server_name, media_id)
+ is_local = server_name == self.config.server_name
+
+ def _quarantine_media_by_id_txn(txn):
+ local_mxcs = [media_id] if is_local else []
+ remote_mxcs = [(server_name, media_id)] if not is_local else []
+
+ return self._quarantine_media_txn(
+ txn, local_mxcs, remote_mxcs, quarantined_by
+ )
+
+ return self.db.runInteraction(
+ "quarantine_media_by_user", _quarantine_media_by_id_txn
+ )
+
+ def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
+ """quarantines all local media associated with a single user
+
+ Args:
+ user_id: The ID of the user to quarantine media of
+ quarantined_by: The ID of the user who made the quarantine request
+ """
+
+ def _quarantine_media_by_user_txn(txn):
+ local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
+ return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
+
+ return self.db.runInteraction(
+ "quarantine_media_by_user", _quarantine_media_by_user_txn
+ )
+
+ def _get_media_ids_by_user_txn(self, txn, user_id: str, filter_quarantined=True):
+ """Retrieves local media IDs by a given user
+
+ Args:
+ txn (cursor)
+ user_id: The ID of the user to retrieve media IDs of
+
+ Returns:
+ The local and remote media as a lists of tuples where the key is
+ the hostname and the value is the media ID.
+ """
+ # Local media
+ sql = """
+ SELECT media_id
+ FROM local_media_repository
+ WHERE user_id = ?
+ """
+ if filter_quarantined:
+ sql += "AND quarantined_by IS NULL"
+ txn.execute(sql, (user_id,))
+
+ local_media_ids = [row[0] for row in txn]
+
+ # TODO: Figure out all remote media a user has referenced in a message
+
+ return local_media_ids
+
+ def _quarantine_media_txn(
+ self,
+ txn,
+ local_mxcs: List[str],
+ remote_mxcs: List[Tuple[str, str]],
+ quarantined_by: str,
+ ) -> int:
+ """Quarantine local and remote media items
+
+ Args:
+ txn (cursor)
+ local_mxcs: A list of local mxc URLs
+ remote_mxcs: A list of (remote server, media id) tuples representing
+ remote mxc URLs
+ quarantined_by: The ID of the user who initiated the quarantine request
+ Returns:
+ The total number of media items quarantined
+ """
+ total_media_quarantined = 0
+
+ # Update all the tables to set the quarantined_by flag
+ txn.executemany(
+ """
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """,
+ ((quarantined_by, media_id) for media_id in local_mxcs),
+ )
+
+ txn.executemany(
+ """
+ UPDATE remote_media_cache
+ SET quarantined_by = ?
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ ((quarantined_by, origin, media_id) for origin, media_id in remote_mxcs),
+ )
+
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
+
+ return total_media_quarantined
+
+
+class RoomBackgroundUpdateStore(SQLBaseStore):
+ REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
+ ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+ self.config = hs.config
+
+ self.db.updates.register_background_update_handler(
+ "insert_room_retention", self._background_insert_retention,
+ )
+
+ self.db.updates.register_background_update_handler(
+ self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE,
+ self._remove_tombstoned_rooms_from_directory,
+ )
+
+ self.db.updates.register_background_update_handler(
+ self.ADD_ROOMS_ROOM_VERSION_COLUMN,
+ self._background_add_rooms_room_version_column,
+ )
+
+ @defer.inlineCallbacks
+ def _background_insert_retention(self, progress, batch_size):
+ """Retrieves a list of all rooms within a range and inserts an entry for each of
+ them into the room_retention table.
+ NULLs the property's columns if missing from the retention event in the room's
+ state (or NULLs all of them if there's no retention event in the room's state),
+ so that we fall back to the server's retention policy.
+ """
+
+ last_room = progress.get("room_id", "")
+
+ def _background_insert_retention_txn(txn):
+ txn.execute(
+ """
+ SELECT state.room_id, state.event_id, events.json
+ FROM current_state_events as state
+ LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
+ WHERE state.room_id > ? AND state.type = '%s'
+ ORDER BY state.room_id ASC
+ LIMIT ?;
+ """
+ % EventTypes.Retention,
+ (last_room, batch_size),
+ )
+
+ rows = self.db.cursor_to_dict(txn)
+
+ if not rows:
+ return True
+
+ for row in rows:
+ if not row["json"]:
+ retention_policy = {}
+ else:
+ ev = json.loads(row["json"])
+ retention_policy = json.dumps(ev["content"])
+
+ self.db.simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": row["room_id"],
+ "event_id": row["event_id"],
+ "min_lifetime": retention_policy.get("min_lifetime"),
+ "max_lifetime": retention_policy.get("max_lifetime"),
+ },
+ )
+
+ logger.info("Inserted %d rows into room_retention", len(rows))
+
+ self.db.updates._background_update_progress_txn(
+ txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ )
+
+ if batch_size > len(rows):
+ return True
+ else:
+ return False
+
+ end = yield self.db.runInteraction(
+ "insert_room_retention", _background_insert_retention_txn,
+ )
+
+ if end:
+ yield self.db.updates._end_background_update("insert_room_retention")
+
+ defer.returnValue(batch_size)
+
+ async def _background_add_rooms_room_version_column(
+ self, progress: dict, batch_size: int
+ ):
+ """Background update to go and add room version inforamtion to `rooms`
+ table from `current_state_events` table.
+ """
+
+ last_room_id = progress.get("room_id", "")
+
+ def _background_add_rooms_room_version_column_txn(txn: LoggingTransaction):
+ sql = """
+ SELECT room_id, json FROM current_state_events
+ INNER JOIN event_json USING (room_id, event_id)
+ WHERE room_id > ? AND type = 'm.room.create' AND state_key = ''
+ ORDER BY room_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_room_id, batch_size))
+
+ updates = []
+ for room_id, event_json in txn:
+ event_dict = json.loads(event_json)
+ room_version_id = event_dict.get("content", {}).get(
+ "room_version", RoomVersions.V1.identifier
+ )
+
+ creator = event_dict.get("content").get("creator")
+
+ updates.append((room_id, creator, room_version_id))
+
+ if not updates:
+ return True
+
+ new_last_room_id = ""
+ for room_id, creator, room_version_id in updates:
+ # We upsert here just in case we don't already have a row,
+ # mainly for paranoia as much badness would happen if we don't
+ # insert the row and then try and get the room version for the
+ # room.
+ self.db.simple_upsert_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={"room_version": room_version_id},
+ insertion_values={"is_public": False, "creator": creator},
+ )
+ new_last_room_id = room_id
+
+ self.db.updates._background_update_progress_txn(
+ txn, self.ADD_ROOMS_ROOM_VERSION_COLUMN, {"room_id": new_last_room_id}
+ )
+
+ return False
+
+ end = await self.db.runInteraction(
+ "_background_add_rooms_room_version_column",
+ _background_add_rooms_room_version_column_txn,
+ )
+
+ if end:
+ await self.db.updates._end_background_update(
+ self.ADD_ROOMS_ROOM_VERSION_COLUMN
+ )
+
+ return batch_size
+
+ async def _remove_tombstoned_rooms_from_directory(
+ self, progress, batch_size
+ ) -> int:
+ """Removes any rooms with tombstone events from the room directory
+
+ Nowadays this is handled by the room upgrade handler, but we may have some
+ that got left behind
+ """
+
+ last_room = progress.get("room_id", "")
+
+ def _get_rooms(txn):
+ txn.execute(
+ """
+ SELECT room_id
+ FROM rooms r
+ INNER JOIN current_state_events cse USING (room_id)
+ WHERE room_id > ? AND r.is_public
+ AND cse.type = '%s' AND cse.state_key = ''
+ ORDER BY room_id ASC
+ LIMIT ?;
+ """
+ % EventTypes.Tombstone,
+ (last_room, batch_size),
+ )
+
+ return [row[0] for row in txn]
+
+ rooms = await self.db.runInteraction(
+ "get_tombstoned_directory_rooms", _get_rooms
+ )
+
+ if not rooms:
+ await self.db.updates._end_background_update(
+ self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE
+ )
+ return 0
+
+ for room_id in rooms:
+ logger.info("Removing tombstoned room %s from the directory", room_id)
+ await self.set_room_is_public(room_id, False)
+
+ await self.db.updates._background_update_progress(
+ self.REMOVE_TOMESTONED_ROOMS_BG_UPDATE, {"room_id": rooms[-1]}
+ )
+
+ return len(rooms)
+
+ @abstractmethod
+ def set_room_is_public(self, room_id, is_public):
+ # this will need to be implemented if a background update is performed with
+ # existing (tombstoned, public) rooms in the database.
+ #
+ # It's overridden by RoomStore for the synapse master.
+ raise NotImplementedError()
+
+
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomStore, self).__init__(database, db_conn, hs)
+
+ self.config = hs.config
+
+ async def upsert_room_on_join(self, room_id: str, room_version: RoomVersion):
+ """Ensure that the room is stored in the table
+
+ Called when we join a room over federation, and overwrites any room version
+ currently in the table.
+ """
+ await self.db.simple_upsert(
+ desc="upsert_room_on_join",
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={"room_version": room_version.identifier},
+ insertion_values={"is_public": False, "creator": ""},
+ # rooms has a unique constraint on room_id, so no need to lock when doing an
+ # emulated upsert.
+ lock=False,
+ )
+
+ @defer.inlineCallbacks
+ def store_room(
+ self,
+ room_id: str,
+ room_creator_user_id: str,
+ is_public: bool,
+ room_version: RoomVersion,
+ ):
+ """Stores a room.
+
+ Args:
+ room_id: The desired room ID, can be None.
+ room_creator_user_id: The user ID of the room creator.
+ is_public: True to indicate that this room should appear in
+ public room lists.
+ room_version: The version of the room
+ Raises:
+ StoreError if the room could not be stored.
+ """
+ try:
+
+ def store_room_txn(txn, next_id):
+ self.db.simple_insert_txn(
+ txn,
+ "rooms",
+ {
+ "room_id": room_id,
+ "creator": room_creator_user_id,
+ "is_public": is_public,
+ "room_version": room_version.identifier,
+ },
+ )
+ if is_public:
+ self.db.simple_insert_txn(
+ txn,
+ table="public_room_list_stream",
+ values={
+ "stream_id": next_id,
+ "room_id": room_id,
+ "visibility": is_public,
+ },
+ )
+
+ with self._public_room_id_gen.get_next() as next_id:
+ yield self.db.runInteraction("store_room_txn", store_room_txn, next_id)
+ except Exception as e:
+ logger.error("store_room with room_id=%s failed: %s", room_id, e)
+ raise StoreError(500, "Problem creating room.")
+
+ async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion):
+ """
+ When we receive an invite over federation, store the version of the room if we
+ don't already know the room version.
+ """
+ await self.db.simple_upsert(
+ desc="maybe_store_room_on_invite",
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ values={},
+ insertion_values={
+ "room_version": room_version.identifier,
+ "is_public": False,
+ "creator": "",
+ },
+ # rooms has a unique constraint on room_id, so no need to lock when doing an
+ # emulated upsert.
+ lock=False,
+ )
+
+ @defer.inlineCallbacks
+ def set_room_is_public(self, room_id, is_public):
+ def set_room_is_public_txn(txn, next_id):
+ self.db.simple_update_one_txn(
+ txn,
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"is_public": is_public},
+ )
+
+ entries = self.db.simple_select_list_txn(
+ txn,
+ table="public_room_list_stream",
+ keyvalues={
+ "room_id": room_id,
+ "appservice_id": None,
+ "network_id": None,
+ },
+ retcols=("stream_id", "visibility"),
+ )
+
+ entries.sort(key=lambda r: r["stream_id"])
+
+ add_to_stream = True
+ if entries:
+ add_to_stream = bool(entries[-1]["visibility"]) != is_public
+
+ if add_to_stream:
+ self.db.simple_insert_txn(
+ txn,
+ table="public_room_list_stream",
+ values={
+ "stream_id": next_id,
+ "room_id": room_id,
+ "visibility": is_public,
+ "appservice_id": None,
+ "network_id": None,
+ },
+ )
+
+ with self._public_room_id_gen.get_next() as next_id:
+ yield self.db.runInteraction(
+ "set_room_is_public", set_room_is_public_txn, next_id
+ )
+ self.hs.get_notifier().on_new_replication_data()
+
+ @defer.inlineCallbacks
+ def set_room_is_public_appservice(
+ self, room_id, appservice_id, network_id, is_public
+ ):
+ """Edit the appservice/network specific public room list.
+
+ Each appservice can have a number of published room lists associated
+ with them, keyed off of an appservice defined `network_id`, which
+ basically represents a single instance of a bridge to a third party
+ network.
+
+ Args:
+ room_id (str)
+ appservice_id (str)
+ network_id (str)
+ is_public (bool): Whether to publish or unpublish the room from the
+ list.
+ """
+
+ def set_room_is_public_appservice_txn(txn, next_id):
+ if is_public:
+ try:
+ self.db.simple_insert_txn(
+ txn,
+ table="appservice_room_list",
+ values={
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ "room_id": room_id,
+ },
+ )
+ except self.database_engine.module.IntegrityError:
+ # We've already inserted, nothing to do.
+ return
+ else:
+ self.db.simple_delete_txn(
+ txn,
+ table="appservice_room_list",
+ keyvalues={
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ "room_id": room_id,
+ },
+ )
+
+ entries = self.db.simple_select_list_txn(
+ txn,
+ table="public_room_list_stream",
+ keyvalues={
+ "room_id": room_id,
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ },
+ retcols=("stream_id", "visibility"),
+ )
+
+ entries.sort(key=lambda r: r["stream_id"])
+
+ add_to_stream = True
+ if entries:
+ add_to_stream = bool(entries[-1]["visibility"]) != is_public
+
+ if add_to_stream:
+ self.db.simple_insert_txn(
+ txn,
+ table="public_room_list_stream",
+ values={
+ "stream_id": next_id,
+ "room_id": room_id,
+ "visibility": is_public,
+ "appservice_id": appservice_id,
+ "network_id": network_id,
+ },
+ )
+
+ with self._public_room_id_gen.get_next() as next_id:
+ yield self.db.runInteraction(
+ "set_room_is_public_appservice",
+ set_room_is_public_appservice_txn,
+ next_id,
+ )
+ self.hs.get_notifier().on_new_replication_data()
+
+ def get_room_count(self):
+ """Retrieve a list of all rooms
+ """
+
+ def f(txn):
+ sql = "SELECT count(*) FROM rooms"
+ txn.execute(sql)
+ row = txn.fetchone()
+ return row[0] or 0
+
+ return self.db.runInteraction("get_rooms", f)
+
+ def _store_room_topic_txn(self, txn, event):
+ if hasattr(event, "content") and "topic" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.topic", event.content["topic"]
+ )
+
+ def _store_room_name_txn(self, txn, event):
+ if hasattr(event, "content") and "name" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.name", event.content["name"]
+ )
+
+ def _store_room_message_txn(self, txn, event):
+ if hasattr(event, "content") and "body" in event.content:
+ self.store_event_search_txn(
+ txn, event, "content.body", event.content["body"]
+ )
+
+ def _store_retention_policy_for_room_txn(self, txn, event):
+ if hasattr(event, "content") and (
+ "min_lifetime" in event.content or "max_lifetime" in event.content
+ ):
+ if (
+ "min_lifetime" in event.content
+ and not isinstance(event.content.get("min_lifetime"), integer_types)
+ ) or (
+ "max_lifetime" in event.content
+ and not isinstance(event.content.get("max_lifetime"), integer_types)
+ ):
+ # Ignore the event if one of the value isn't an integer.
+ return
+
+ self.db.simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ "min_lifetime": event.content.get("min_lifetime"),
+ "max_lifetime": event.content.get("max_lifetime"),
+ },
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_retention_policy_for_room, (event.room_id,)
+ )
+
+ def add_event_report(
+ self, room_id, event_id, user_id, reason, content, received_ts
+ ):
+ next_id = self._event_reports_id_gen.get_next()
+ return self.db.simple_insert(
+ table="event_reports",
+ values={
+ "id": next_id,
+ "received_ts": received_ts,
+ "room_id": room_id,
+ "event_id": event_id,
+ "user_id": user_id,
+ "reason": reason,
+ "content": json.dumps(content),
+ },
+ desc="add_event_report",
+ )
+
+ def get_current_public_room_stream_id(self):
+ return self._public_room_id_gen.get_current_token()
+
+ def get_all_new_public_rooms(self, prev_id, current_id, limit):
+ def get_all_new_public_rooms(txn):
+ sql = """
+ SELECT stream_id, room_id, visibility, appservice_id, network_id
+ FROM public_room_list_stream
+ WHERE stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id ASC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (prev_id, current_id, limit))
+ return txn.fetchall()
+
+ if prev_id == current_id:
+ return defer.succeed([])
+
+ return self.db.runInteraction(
+ "get_all_new_public_rooms", get_all_new_public_rooms
+ )
+
+ @defer.inlineCallbacks
+ def block_room(self, room_id, user_id):
+ """Marks the room as blocked. Can be called multiple times.
+
+ Args:
+ room_id (str): Room to block
+ user_id (str): Who blocked it
+
+ Returns:
+ Deferred
+ """
+ yield self.db.simple_upsert(
+ table="blocked_rooms",
+ keyvalues={"room_id": room_id},
+ values={},
+ insertion_values={"user_id": user_id},
+ desc="block_room",
+ )
+ yield self.db.runInteraction(
+ "block_room_invalidation",
+ self._invalidate_cache_and_stream,
+ self.is_room_blocked,
+ (room_id,),
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_for_retention_period_in_range(
+ self, min_ms, max_ms, include_null=False
+ ):
+ """Retrieves all of the rooms within the given retention range.
+
+ Optionally includes the rooms which don't have a retention policy.
+
+ Args:
+ min_ms (int|None): Duration in milliseconds that define the lower limit of
+ the range to handle (exclusive). If None, doesn't set a lower limit.
+ max_ms (int|None): Duration in milliseconds that define the upper limit of
+ the range to handle (inclusive). If None, doesn't set an upper limit.
+ include_null (bool): Whether to include rooms which retention policy is NULL
+ in the returned set.
+
+ Returns:
+ dict[str, dict]: The rooms within this range, along with their retention
+ policy. The key is "room_id", and maps to a dict describing the retention
+ policy associated with this room ID. The keys for this nested dict are
+ "min_lifetime" (int|None), and "max_lifetime" (int|None).
+ """
+
+ def get_rooms_for_retention_period_in_range_txn(txn):
+ range_conditions = []
+ args = []
+
+ if min_ms is not None:
+ range_conditions.append("max_lifetime > ?")
+ args.append(min_ms)
+
+ if max_ms is not None:
+ range_conditions.append("max_lifetime <= ?")
+ args.append(max_ms)
+
+ # Do a first query which will retrieve the rooms that have a retention policy
+ # in their current state.
+ sql = """
+ SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ """
+
+ if len(range_conditions):
+ sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+ if include_null:
+ sql += " OR max_lifetime IS NULL"
+
+ txn.execute(sql, args)
+
+ rows = self.db.cursor_to_dict(txn)
+ rooms_dict = {}
+
+ for row in rows:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": row["min_lifetime"],
+ "max_lifetime": row["max_lifetime"],
+ }
+
+ if include_null:
+ # If required, do a second query that retrieves all of the rooms we know
+ # of so we can handle rooms with no retention policy.
+ sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+ txn.execute(sql)
+
+ rows = self.db.cursor_to_dict(txn)
+
+ # If a room isn't already in the dict (i.e. it doesn't have a retention
+ # policy in its state), add it with a null policy.
+ for row in rows:
+ if row["room_id"] not in rooms_dict:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": None,
+ "max_lifetime": None,
+ }
+
+ return rooms_dict
+
+ rooms = yield self.db.runInteraction(
+ "get_rooms_for_retention_period_in_range",
+ get_rooms_for_retention_period_in_range_txn,
+ )
+
+ defer.returnValue(rooms)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
new file mode 100644
index 0000000000..d5bd0cb5cf
--- /dev/null
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -0,0 +1,1265 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Iterable, List, Set
+
+from six import iteritems, itervalues
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.metrics import LaterGauge
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import (
+ LoggingTransaction,
+ SQLBaseStore,
+ make_in_list_sql_clause,
+)
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
+from synapse.storage.engines import Sqlite3Engine
+from synapse.storage.roommember import (
+ GetRoomsForUserWithStreamOrdering,
+ MemberSummary,
+ ProfileInfo,
+ RoomsForUser,
+)
+from synapse.types import Collection, get_domain_from_id
+from synapse.util.async_helpers import Linearizer
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.metrics import Measure
+from synapse.util.stringutils import to_ascii
+
+logger = logging.getLogger(__name__)
+
+
+_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
+_CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
+
+
+class RoomMemberWorkerStore(EventsWorkerStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
+
+ # Is the current_state_events.membership up to date? Or is the
+ # background update still running?
+ self._current_state_events_membership_up_to_date = False
+
+ txn = LoggingTransaction(
+ db_conn.cursor(),
+ name="_check_safe_current_state_events_membership_updated",
+ database_engine=self.database_engine,
+ )
+ self._check_safe_current_state_events_membership_updated_txn(txn)
+ txn.close()
+
+ if self.hs.config.metrics_flags.known_servers:
+ self._known_servers_count = 1
+ self.hs.get_clock().looping_call(
+ run_as_background_process,
+ 60 * 1000,
+ "_count_known_servers",
+ self._count_known_servers,
+ )
+ self.hs.get_clock().call_later(
+ 1000,
+ run_as_background_process,
+ "_count_known_servers",
+ self._count_known_servers,
+ )
+ LaterGauge(
+ "synapse_federation_known_servers",
+ "",
+ [],
+ lambda: self._known_servers_count,
+ )
+
+ @defer.inlineCallbacks
+ def _count_known_servers(self):
+ """
+ Count the servers that this server knows about.
+
+ The statistic is stored on the class for the
+ `synapse_federation_known_servers` LaterGauge to collect.
+ """
+
+ def _transact(txn):
+ if isinstance(self.database_engine, Sqlite3Engine):
+ query = """
+ SELECT COUNT(DISTINCT substr(out.user_id, pos+1))
+ FROM (
+ SELECT rm.user_id as user_id, instr(rm.user_id, ':')
+ AS pos FROM room_memberships as rm
+ INNER JOIN current_state_events as c ON rm.event_id = c.event_id
+ WHERE c.type = 'm.room.member'
+ ) as out
+ """
+ else:
+ query = """
+ SELECT COUNT(DISTINCT split_part(state_key, ':', 2))
+ FROM current_state_events
+ WHERE type = 'm.room.member' AND membership = 'join';
+ """
+ txn.execute(query)
+ return list(txn)[0][0]
+
+ count = yield self.db.runInteraction("get_known_servers", _transact)
+
+ # We always know about ourselves, even if we have nothing in
+ # room_memberships (for example, the server is new).
+ self._known_servers_count = max([count, 1])
+ return self._known_servers_count
+
+ def _check_safe_current_state_events_membership_updated_txn(self, txn):
+ """Checks if it is safe to assume the new current_state_events
+ membership column is up to date
+ """
+
+ pending_update = self.db.simple_select_one_txn(
+ txn,
+ table="background_updates",
+ keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
+ retcols=["update_name"],
+ allow_none=True,
+ )
+
+ self._current_state_events_membership_up_to_date = not pending_update
+
+ # If the update is still running, reschedule to run.
+ if pending_update:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "_check_safe_current_state_events_membership_updated",
+ self.db.runInteraction,
+ "_check_safe_current_state_events_membership_updated",
+ self._check_safe_current_state_events_membership_updated_txn,
+ )
+
+ @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
+ def get_hosts_in_room(self, room_id, cache_context):
+ """Returns the set of all hosts currently in the room
+ """
+ user_ids = yield self.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
+ return hosts
+
+ @cached(max_entries=100000, iterable=True)
+ def get_users_in_room(self, room_id):
+ return self.db.runInteraction(
+ "get_users_in_room", self.get_users_in_room_txn, room_id
+ )
+
+ def get_users_in_room_txn(self, txn, room_id):
+ # If we can assume current_state_events.membership is up to date
+ # then we can avoid a join, which is a Very Good Thing given how
+ # frequently this function gets called.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT state_key FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ? AND membership = ?
+ """
+ else:
+ sql = """
+ SELECT state_key FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?
+ """
+
+ txn.execute(sql, (room_id, Membership.JOIN))
+ return [to_ascii(r[0]) for r in txn]
+
+ @cached(max_entries=100000)
+ def get_room_summary(self, room_id):
+ """ Get the details of a room roughly suitable for use by the room
+ summary extension to /sync. Useful when lazy loading room members.
+ Args:
+ room_id (str): The room ID to query
+ Returns:
+ Deferred[dict[str, MemberSummary]:
+ dict of membership states, pointing to a MemberSummary named tuple.
+ """
+
+ def _get_room_summary_txn(txn):
+ # first get counts.
+ # We do this all in one transaction to keep the cache small.
+ # FIXME: get rid of this when we have room_stats
+
+ # If we can assume current_state_events.membership is up to date
+ # then we can avoid a join, which is a Very Good Thing given how
+ # frequently this function gets called.
+ if self._current_state_events_membership_up_to_date:
+ # Note, rejected events will have a null membership field, so
+ # we we manually filter them out.
+ sql = """
+ SELECT count(*), membership FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ AND membership IS NOT NULL
+ GROUP BY membership
+ """
+ else:
+ sql = """
+ SELECT count(*), m.membership FROM room_memberships as m
+ INNER JOIN current_state_events as c
+ ON m.event_id = c.event_id
+ AND m.room_id = c.room_id
+ AND m.user_id = c.state_key
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ res = {}
+ for count, membership in txn:
+ summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
+
+ # we order by membership and then fairly arbitrarily by event_id so
+ # heroes are consistent
+ if self._current_state_events_membership_up_to_date:
+ # Note, rejected events will have a null membership field, so
+ # we we manually filter them out.
+ sql = """
+ SELECT state_key, membership, event_id
+ FROM current_state_events
+ WHERE type = 'm.room.member' AND room_id = ?
+ AND membership IS NOT NULL
+ ORDER BY
+ CASE membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ event_id ASC
+ LIMIT ?
+ """
+ else:
+ sql = """
+ SELECT c.state_key, m.membership, c.event_id
+ FROM room_memberships as m
+ INNER JOIN current_state_events as c USING (room_id, event_id)
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ ORDER BY
+ CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
+ c.event_id ASC
+ LIMIT ?
+ """
+
+ # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
+ txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
+ for user_id, membership, event_id in txn:
+ summary = res[to_ascii(membership)]
+ # we will always have a summary for this membership type at this
+ # point given the summary currently contains the counts.
+ members = summary.members
+ members.append((to_ascii(user_id), to_ascii(event_id)))
+
+ return res
+
+ return self.db.runInteraction("get_room_summary", _get_room_summary_txn)
+
+ def _get_user_counts_in_room_txn(self, txn, room_id):
+ """
+ Get the user count in a room by membership.
+
+ Args:
+ room_id (str)
+ membership (Membership)
+
+ Returns:
+ Deferred[int]
+ """
+ sql = """
+ SELECT m.membership, count(*) FROM room_memberships as m
+ INNER JOIN current_state_events as c USING(event_id)
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ return {row[0]: row[1] for row in txn}
+
+ @cached()
+ def get_invited_rooms_for_local_user(self, user_id):
+ """ Get all the rooms the *local* user is invited to
+
+ Args:
+ user_id (str): The user ID.
+ Returns:
+ A deferred list of RoomsForUser.
+ """
+
+ return self.get_rooms_for_local_user_where_membership_is(
+ user_id, [Membership.INVITE]
+ )
+
+ @defer.inlineCallbacks
+ def get_invite_for_local_user_in_room(self, user_id, room_id):
+ """Gets the invite for the given *local* user and room
+
+ Args:
+ user_id (str)
+ room_id (str)
+
+ Returns:
+ Deferred: Resolves to either a RoomsForUser or None if no invite was
+ found.
+ """
+ invites = yield self.get_invited_rooms_for_local_user(user_id)
+ for invite in invites:
+ if invite.room_id == room_id:
+ return invite
+ return None
+
+ @defer.inlineCallbacks
+ def get_rooms_for_local_user_where_membership_is(self, user_id, membership_list):
+ """ Get all the rooms for this *local* user where the membership for this user
+ matches one in the membership list.
+
+ Filters out forgotten rooms.
+
+ Args:
+ user_id (str): The user ID.
+ membership_list (list): A list of synapse.api.constants.Membership
+ values which the user must be in.
+
+ Returns:
+ Deferred[list[RoomsForUser]]
+ """
+ if not membership_list:
+ return defer.succeed(None)
+
+ rooms = yield self.db.runInteraction(
+ "get_rooms_for_local_user_where_membership_is",
+ self._get_rooms_for_local_user_where_membership_is_txn,
+ user_id,
+ membership_list,
+ )
+
+ # Now we filter out forgotten rooms
+ forgotten_rooms = yield self.get_forgotten_rooms_for_user(user_id)
+ return [room for room in rooms if room.room_id not in forgotten_rooms]
+
+ def _get_rooms_for_local_user_where_membership_is_txn(
+ self, txn, user_id, membership_list
+ ):
+ # Paranoia check.
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'get_rooms_for_local_user_where_membership_is' on non-local user %r"
+ % (user_id,),
+ )
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "c.membership", membership_list
+ )
+
+ sql = """
+ SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
+ FROM local_current_membership AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ user_id = ?
+ AND %s
+ """ % (
+ clause,
+ )
+
+ txn.execute(sql, (user_id, *args))
+ results = [RoomsForUser(**r) for r in self.db.cursor_to_dict(txn)]
+
+ return results
+
+ @cached(max_entries=500000, iterable=True)
+ def get_rooms_for_user_with_stream_ordering(self, user_id):
+ """Returns a set of room_ids the user is currently joined to.
+
+ If a remote user only returns rooms this server is currently
+ participating in.
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
+ the rooms the user is in currently, along with the stream ordering
+ of the most recent join for that user and room.
+ """
+ return self.db.runInteraction(
+ "get_rooms_for_user_with_stream_ordering",
+ self._get_rooms_for_user_with_stream_ordering_txn,
+ user_id,
+ )
+
+ def _get_rooms_for_user_with_stream_ordering_txn(self, txn, user_id):
+ # We use `current_state_events` here and not `local_current_membership`
+ # as a) this gets called with remote users and b) this only gets called
+ # for rooms the server is participating in.
+ if self._current_state_events_membership_up_to_date:
+ sql = """
+ SELECT room_id, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND state_key = ?
+ AND c.membership = ?
+ """
+ else:
+ sql = """
+ SELECT room_id, e.stream_ordering
+ FROM current_state_events AS c
+ INNER JOIN room_memberships AS m USING (room_id, event_id)
+ INNER JOIN events AS e USING (room_id, event_id)
+ WHERE
+ c.type = 'm.room.member'
+ AND state_key = ?
+ AND m.membership = ?
+ """
+
+ txn.execute(sql, (user_id, Membership.JOIN))
+ results = frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+
+ return results
+
+ async def get_users_server_still_shares_room_with(
+ self, user_ids: Collection[str]
+ ) -> Set[str]:
+ """Given a list of users return the set that the server still share a
+ room with.
+ """
+
+ if not user_ids:
+ return set()
+
+ def _get_users_server_still_shares_room_with_txn(txn):
+ sql = """
+ SELECT state_key FROM current_state_events
+ WHERE
+ type = 'm.room.member'
+ AND membership = 'join'
+ AND %s
+ GROUP BY state_key
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "state_key", user_ids
+ )
+
+ txn.execute(sql % (clause,), args)
+
+ return {row[0] for row in txn}
+
+ return await self.db.runInteraction(
+ "get_users_server_still_shares_room_with",
+ _get_users_server_still_shares_room_with_txn,
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_for_user(self, user_id, on_invalidate=None):
+ """Returns a set of room_ids the user is currently joined to.
+
+ If a remote user only returns rooms this server is currently
+ participating in.
+ """
+ rooms = yield self.get_rooms_for_user_with_stream_ordering(
+ user_id, on_invalidate=on_invalidate
+ )
+ return frozenset(r.room_id for r in rooms)
+
+ @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
+ def get_users_who_share_room_with_user(self, user_id, cache_context):
+ """Returns the set of users who share a room with `user_id`
+ """
+ room_ids = yield self.get_rooms_for_user(
+ user_id, on_invalidate=cache_context.invalidate
+ )
+
+ user_who_share_room = set()
+ for room_id in room_ids:
+ user_ids = yield self.get_users_in_room(
+ room_id, on_invalidate=cache_context.invalidate
+ )
+ user_who_share_room.update(user_ids)
+
+ return user_who_share_room
+
+ @defer.inlineCallbacks
+ def get_joined_users_from_context(self, event, context):
+ state_group = context.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ current_state_ids = yield context.get_current_state_ids()
+ result = yield self._get_joined_users_from_context(
+ event.room_id, state_group, current_state_ids, event=event, context=context
+ )
+ return result
+
+ @defer.inlineCallbacks
+ def get_joined_users_from_state(self, room_id, state_entry):
+ state_group = state_entry.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ with Measure(self._clock, "get_joined_users_from_state"):
+ return (
+ yield self._get_joined_users_from_context(
+ room_id, state_group, state_entry.state, context=state_entry
+ )
+ )
+
+ @cachedInlineCallbacks(
+ num_args=2, cache_context=True, iterable=True, max_entries=100000
+ )
+ def _get_joined_users_from_context(
+ self,
+ room_id,
+ state_group,
+ current_state_ids,
+ cache_context,
+ event=None,
+ context=None,
+ ):
+ # We don't use `state_group`, it's there so that we can cache based
+ # on it. However, it's important that it's never None, since two current_states
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ users_in_room = {}
+ member_event_ids = [
+ e_id
+ for key, e_id in iteritems(current_state_ids)
+ if key[0] == EventTypes.Member
+ ]
+
+ if context is not None:
+ # If we have a context with a delta from a previous state group,
+ # check if we also have the result from the previous group in cache.
+ # If we do then we can reuse that result and simply update it with
+ # any membership changes in `delta_ids`
+ if context.prev_group and context.delta_ids:
+ prev_res = self._get_joined_users_from_context.cache.get(
+ (room_id, context.prev_group), None
+ )
+ if prev_res and isinstance(prev_res, dict):
+ users_in_room = dict(prev_res)
+ member_event_ids = [
+ e_id
+ for key, e_id in iteritems(context.delta_ids)
+ if key[0] == EventTypes.Member
+ ]
+ for etype, state_key in context.delta_ids:
+ users_in_room.pop(state_key, None)
+
+ # We check if we have any of the member event ids in the event cache
+ # before we ask the DB
+
+ # We don't update the event cache hit ratio as it completely throws off
+ # the hit ratio counts. After all, we don't populate the cache if we
+ # miss it here
+ event_map = self._get_events_from_cache(
+ member_event_ids, allow_rejected=False, update_metrics=False
+ )
+
+ missing_member_event_ids = []
+ for event_id in member_event_ids:
+ ev_entry = event_map.get(event_id)
+ if ev_entry:
+ if ev_entry.event.membership == Membership.JOIN:
+ users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
+ display_name=to_ascii(
+ ev_entry.event.content.get("displayname", None)
+ ),
+ avatar_url=to_ascii(
+ ev_entry.event.content.get("avatar_url", None)
+ ),
+ )
+ else:
+ missing_member_event_ids.append(event_id)
+
+ if missing_member_event_ids:
+ event_to_memberships = yield self._get_joined_profiles_from_event_ids(
+ missing_member_event_ids
+ )
+ users_in_room.update((row for row in event_to_memberships.values() if row))
+
+ if event is not None and event.type == EventTypes.Member:
+ if event.membership == Membership.JOIN:
+ if event.event_id in member_event_ids:
+ users_in_room[to_ascii(event.state_key)] = ProfileInfo(
+ display_name=to_ascii(event.content.get("displayname", None)),
+ avatar_url=to_ascii(event.content.get("avatar_url", None)),
+ )
+
+ return users_in_room
+
+ @cached(max_entries=10000)
+ def _get_joined_profile_from_event_id(self, event_id):
+ raise NotImplementedError()
+
+ @cachedList(
+ cached_method_name="_get_joined_profile_from_event_id",
+ list_name="event_ids",
+ inlineCallbacks=True,
+ )
+ def _get_joined_profiles_from_event_ids(self, event_ids):
+ """For given set of member event_ids check if they point to a join
+ event and if so return the associated user and profile info.
+
+ Args:
+ event_ids (Iterable[str]): The member event IDs to lookup
+
+ Returns:
+ Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+ to `user_id` and ProfileInfo (or None if not join event).
+ """
+
+ rows = yield self.db.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=event_ids,
+ retcols=("user_id", "display_name", "avatar_url", "event_id"),
+ keyvalues={"membership": Membership.JOIN},
+ batch_size=500,
+ desc="_get_membership_from_event_ids",
+ )
+
+ return {
+ row["event_id"]: (
+ row["user_id"],
+ ProfileInfo(
+ avatar_url=row["avatar_url"], display_name=row["display_name"]
+ ),
+ )
+ for row in rows
+ }
+
+ @cachedInlineCallbacks(max_entries=10000)
+ def is_host_joined(self, room_id, host):
+ if "%" in host or "_" in host:
+ raise Exception("Invalid host name")
+
+ sql = """
+ SELECT state_key FROM current_state_events AS c
+ INNER JOIN room_memberships AS m USING (event_id)
+ WHERE m.membership = 'join'
+ AND type = 'm.room.member'
+ AND c.room_id = ?
+ AND state_key LIKE ?
+ LIMIT 1
+ """
+
+ # We do need to be careful to ensure that host doesn't have any wild cards
+ # in it, but we checked above for known ones and we'll check below that
+ # the returned user actually has the correct domain.
+ like_clause = "%:" + host
+
+ rows = yield self.db.execute("is_host_joined", None, sql, room_id, like_clause)
+
+ if not rows:
+ return False
+
+ user_id = rows[0][0]
+ if get_domain_from_id(user_id) != host:
+ # This can only happen if the host name has something funky in it
+ raise Exception("Invalid host name")
+
+ return True
+
+ @cachedInlineCallbacks()
+ def was_host_joined(self, room_id, host):
+ """Check whether the server is or ever was in the room.
+
+ Args:
+ room_id (str)
+ host (str)
+
+ Returns:
+ Deferred: Resolves to True if the host is/was in the room, otherwise
+ False.
+ """
+ if "%" in host or "_" in host:
+ raise Exception("Invalid host name")
+
+ sql = """
+ SELECT user_id FROM room_memberships
+ WHERE room_id = ?
+ AND user_id LIKE ?
+ AND membership = 'join'
+ LIMIT 1
+ """
+
+ # We do need to be careful to ensure that host doesn't have any wild cards
+ # in it, but we checked above for known ones and we'll check below that
+ # the returned user actually has the correct domain.
+ like_clause = "%:" + host
+
+ rows = yield self.db.execute("was_host_joined", None, sql, room_id, like_clause)
+
+ if not rows:
+ return False
+
+ user_id = rows[0][0]
+ if get_domain_from_id(user_id) != host:
+ # This can only happen if the host name has something funky in it
+ raise Exception("Invalid host name")
+
+ return True
+
+ @defer.inlineCallbacks
+ def get_joined_hosts(self, room_id, state_entry):
+ state_group = state_entry.state_group
+ if not state_group:
+ # If state_group is None it means it has yet to be assigned a
+ # state group, i.e. we need to make sure that calls with a state_group
+ # of None don't hit previous cached calls with a None state_group.
+ # To do this we set the state_group to a new object as object() != object()
+ state_group = object()
+
+ with Measure(self._clock, "get_joined_hosts"):
+ return (
+ yield self._get_joined_hosts(
+ room_id, state_group, state_entry.state, state_entry=state_entry
+ )
+ )
+
+ @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
+ # @defer.inlineCallbacks
+ def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
+ # We don't use `state_group`, its there so that we can cache based
+ # on it. However, its important that its never None, since two current_state's
+ # with a state_group of None are likely to be different.
+ # See bulk_get_push_rules_for_room for how we work around this.
+ assert state_group is not None
+
+ cache = yield self._get_joined_hosts_cache(room_id)
+ joined_hosts = yield cache.get_destinations(state_entry)
+
+ return joined_hosts
+
+ @cached(max_entries=10000)
+ def _get_joined_hosts_cache(self, room_id):
+ return _JoinedHostsCache(self, room_id)
+
+ @cachedInlineCallbacks(num_args=2)
+ def did_forget(self, user_id, room_id):
+ """Returns whether user_id has elected to discard history for room_id.
+
+ Returns False if they have since re-joined."""
+
+ def f(txn):
+ sql = (
+ "SELECT"
+ " COUNT(*)"
+ " FROM"
+ " room_memberships"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ " AND"
+ " forgotten = 0"
+ )
+ txn.execute(sql, (user_id, room_id))
+ rows = txn.fetchall()
+ return rows[0][0]
+
+ count = yield self.db.runInteraction("did_forget_membership", f)
+ return count == 0
+
+ @cached()
+ def get_forgotten_rooms_for_user(self, user_id):
+ """Gets all rooms the user has forgotten.
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[set[str]]
+ """
+
+ def _get_forgotten_rooms_for_user_txn(txn):
+ # This is a slightly convoluted query that first looks up all rooms
+ # that the user has forgotten in the past, then rechecks that list
+ # to see if any have subsequently been updated. This is done so that
+ # we can use a partial index on `forgotten = 1` on the assumption
+ # that few users will actually forget many rooms.
+ #
+ # Note that a room is considered "forgotten" if *all* membership
+ # events for that user and room have the forgotten field set (as
+ # when a user forgets a room we update all rows for that user and
+ # room, not just the current one).
+ sql = """
+ SELECT room_id, (
+ SELECT count(*) FROM room_memberships
+ WHERE room_id = m.room_id AND user_id = m.user_id AND forgotten = 0
+ ) AS count
+ FROM room_memberships AS m
+ WHERE user_id = ? AND forgotten = 1
+ GROUP BY room_id, user_id;
+ """
+ txn.execute(sql, (user_id,))
+ return {row[0] for row in txn if row[1] == 0}
+
+ return self.db.runInteraction(
+ "get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_user_has_been_in(self, user_id):
+ """Get all rooms that the user has ever been in.
+
+ Args:
+ user_id (str)
+
+ Returns:
+ Deferred[set[str]]: Set of room IDs.
+ """
+
+ room_ids = yield self.db.simple_select_onecol(
+ table="room_memberships",
+ keyvalues={"membership": Membership.JOIN, "user_id": user_id},
+ retcol="room_id",
+ desc="get_rooms_user_has_been_in",
+ )
+
+ return set(room_ids)
+
+ def get_membership_from_event_ids(
+ self, member_event_ids: Iterable[str]
+ ) -> List[dict]:
+ """Get user_id and membership of a set of event IDs.
+ """
+
+ return self.db.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids,
+ retcols=("user_id", "membership", "event_id"),
+ keyvalues={},
+ batch_size=500,
+ desc="get_membership_from_event_ids",
+ )
+
+ async def is_local_host_in_room_ignoring_users(
+ self, room_id: str, ignore_users: Collection[str]
+ ) -> bool:
+ """Check if there are any local users, excluding those in the given
+ list, in the room.
+ """
+
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "user_id", ignore_users
+ )
+
+ sql = """
+ SELECT 1 FROM local_current_membership
+ WHERE
+ room_id = ? AND membership = ?
+ AND NOT (%s)
+ LIMIT 1
+ """ % (
+ clause,
+ )
+
+ def _is_local_host_in_room_ignoring_users_txn(txn):
+ txn.execute(sql, (room_id, Membership.JOIN, *args))
+
+ return bool(txn.fetchone())
+
+ return await self.db.runInteraction(
+ "is_local_host_in_room_ignoring_users",
+ _is_local_host_in_room_ignoring_users_txn,
+ )
+
+
+class RoomMemberBackgroundUpdateStore(SQLBaseStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ self.db.updates.register_background_update_handler(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
+ )
+ self.db.updates.register_background_update_handler(
+ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+ self._background_current_state_membership,
+ )
+ self.db.updates.register_background_index_update(
+ "room_membership_forgotten_idx",
+ index_name="room_memberships_user_room_forgotten",
+ table="room_memberships",
+ columns=["user_id", "room_id"],
+ where_clause="forgotten = 1",
+ )
+
+ @defer.inlineCallbacks
+ def _background_add_membership_profile(self, progress, batch_size):
+ target_min_stream_id = progress.get(
+ "target_min_stream_id_inclusive", self._min_stream_order_on_start
+ )
+ max_stream_id = progress.get(
+ "max_stream_id_exclusive", self._stream_order_on_start + 1
+ )
+
+ INSERT_CLUMP_SIZE = 1000
+
+ def add_membership_profile_txn(txn):
+ sql = """
+ SELECT stream_ordering, event_id, events.room_id, event_json.json
+ FROM events
+ INNER JOIN event_json USING (event_id)
+ INNER JOIN room_memberships USING (event_id)
+ WHERE ? <= stream_ordering AND stream_ordering < ?
+ AND type = 'm.room.member'
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+
+ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+ rows = self.db.cursor_to_dict(txn)
+ if not rows:
+ return 0
+
+ min_stream_id = rows[-1]["stream_ordering"]
+
+ to_update = []
+ for row in rows:
+ event_id = row["event_id"]
+ room_id = row["room_id"]
+ try:
+ event_json = json.loads(row["json"])
+ content = event_json["content"]
+ except Exception:
+ continue
+
+ display_name = content.get("displayname", None)
+ avatar_url = content.get("avatar_url", None)
+
+ if display_name or avatar_url:
+ to_update.append((display_name, avatar_url, event_id, room_id))
+
+ to_update_sql = """
+ UPDATE room_memberships SET display_name = ?, avatar_url = ?
+ WHERE event_id = ? AND room_id = ?
+ """
+ for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
+ clump = to_update[index : index + INSERT_CLUMP_SIZE]
+ txn.executemany(to_update_sql, clump)
+
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ }
+
+ self.db.updates._background_update_progress_txn(
+ txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
+ )
+
+ return len(rows)
+
+ result = yield self.db.runInteraction(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
+ )
+
+ if not result:
+ yield self.db.updates._end_background_update(
+ _MEMBERSHIP_PROFILE_UPDATE_NAME
+ )
+
+ return result
+
+ @defer.inlineCallbacks
+ def _background_current_state_membership(self, progress, batch_size):
+ """Update the new membership column on current_state_events.
+
+ This works by iterating over all rooms in alphebetical order.
+ """
+
+ def _background_current_state_membership_txn(txn, last_processed_room):
+ processed = 0
+ while processed < batch_size:
+ txn.execute(
+ """
+ SELECT MIN(room_id) FROM current_state_events WHERE room_id > ?
+ """,
+ (last_processed_room,),
+ )
+ row = txn.fetchone()
+ if not row or not row[0]:
+ return processed, True
+
+ (next_room,) = row
+
+ sql = """
+ UPDATE current_state_events
+ SET membership = (
+ SELECT membership FROM room_memberships
+ WHERE event_id = current_state_events.event_id
+ )
+ WHERE room_id = ?
+ """
+ txn.execute(sql, (next_room,))
+ processed += txn.rowcount
+
+ last_processed_room = next_room
+
+ self.db.updates._background_update_progress_txn(
+ txn,
+ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME,
+ {"last_processed_room": last_processed_room},
+ )
+
+ return processed, False
+
+ # If we haven't got a last processed room then just use the empty
+ # string, which will compare before all room IDs correctly.
+ last_processed_room = progress.get("last_processed_room", "")
+
+ row_count, finished = yield self.db.runInteraction(
+ "_background_current_state_membership_update",
+ _background_current_state_membership_txn,
+ last_processed_room,
+ )
+
+ if finished:
+ yield self.db.updates._end_background_update(
+ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME
+ )
+
+ return row_count
+
+
+class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(RoomMemberStore, self).__init__(database, db_conn, hs)
+
+ def _store_room_members_txn(self, txn, events, backfilled):
+ """Store a room member in the database.
+ """
+ self.db.simple_insert_many_txn(
+ txn,
+ table="room_memberships",
+ values=[
+ {
+ "event_id": event.event_id,
+ "user_id": event.state_key,
+ "sender": event.user_id,
+ "room_id": event.room_id,
+ "membership": event.membership,
+ "display_name": event.content.get("displayname", None),
+ "avatar_url": event.content.get("avatar_url", None),
+ }
+ for event in events
+ ],
+ )
+
+ for event in events:
+ txn.call_after(
+ self._membership_stream_cache.entity_has_changed,
+ event.state_key,
+ event.internal_metadata.stream_ordering,
+ )
+ txn.call_after(
+ self.get_invited_rooms_for_local_user.invalidate, (event.state_key,)
+ )
+
+ # We update the local_invites table only if the event is "current",
+ # i.e., its something that has just happened. If the event is an
+ # outlier it is only current if its an "out of band membership",
+ # like a remote invite or a rejection of a remote invite.
+ is_new_state = not backfilled and (
+ not event.internal_metadata.is_outlier()
+ or event.internal_metadata.is_out_of_band_membership()
+ )
+ is_mine = self.hs.is_mine_id(event.state_key)
+ if is_new_state and is_mine:
+ if event.membership == Membership.INVITE:
+ self.db.simple_insert_txn(
+ txn,
+ table="local_invites",
+ values={
+ "event_id": event.event_id,
+ "invitee": event.state_key,
+ "inviter": event.sender,
+ "room_id": event.room_id,
+ "stream_id": event.internal_metadata.stream_ordering,
+ },
+ )
+ else:
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ txn.execute(
+ sql,
+ (
+ event.internal_metadata.stream_ordering,
+ event.event_id,
+ event.room_id,
+ event.state_key,
+ ),
+ )
+
+ # We also update the `local_current_membership` table with
+ # latest invite info. This will usually get updated by the
+ # `current_state_events` handling, unless its an outlier.
+ if event.internal_metadata.is_outlier():
+ # This should only happen for out of band memberships, so
+ # we add a paranoia check.
+ assert event.internal_metadata.is_out_of_band_membership()
+
+ self.db.simple_upsert_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={
+ "room_id": event.room_id,
+ "user_id": event.state_key,
+ },
+ values={
+ "event_id": event.event_id,
+ "membership": event.membership,
+ },
+ )
+
+ @defer.inlineCallbacks
+ def locally_reject_invite(self, user_id, room_id):
+ sql = (
+ "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
+ " room_id = ? AND invitee = ? AND locally_rejected is NULL"
+ " AND replaced_by is NULL"
+ )
+
+ def f(txn, stream_ordering):
+ txn.execute(sql, (stream_ordering, True, room_id, user_id))
+
+ # We also clear this entry from `local_current_membership`.
+ # Ideally we'd point to a leave event, but we don't have one, so
+ # nevermind.
+ self.db.simple_delete_txn(
+ txn,
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "user_id": user_id},
+ )
+
+ with self._stream_id_gen.get_next() as stream_ordering:
+ yield self.db.runInteraction("locally_reject_invite", f, stream_ordering)
+
+ def forget(self, user_id, room_id):
+ """Indicate that user_id wishes to discard history for room_id."""
+
+ def f(txn):
+ sql = (
+ "UPDATE"
+ " room_memberships"
+ " SET"
+ " forgotten = 1"
+ " WHERE"
+ " user_id = ?"
+ " AND"
+ " room_id = ?"
+ )
+ txn.execute(sql, (user_id, room_id))
+
+ self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
+ self._invalidate_cache_and_stream(
+ txn, self.get_forgotten_rooms_for_user, (user_id,)
+ )
+
+ return self.db.runInteraction("forget_membership", f)
+
+
+class _JoinedHostsCache(object):
+ """Cache for joined hosts in a room that is optimised to handle updates
+ via state deltas.
+ """
+
+ def __init__(self, store, room_id):
+ self.store = store
+ self.room_id = room_id
+
+ self.hosts_to_joined_users = {}
+
+ self.state_group = object()
+
+ self.linearizer = Linearizer("_JoinedHostsCache")
+
+ self._len = 0
+
+ @defer.inlineCallbacks
+ def get_destinations(self, state_entry):
+ """Get set of destinations for a state entry
+
+ Args:
+ state_entry(synapse.state._StateCacheEntry)
+ """
+ if state_entry.state_group == self.state_group:
+ return frozenset(self.hosts_to_joined_users)
+
+ with (yield self.linearizer.queue(())):
+ if state_entry.state_group == self.state_group:
+ pass
+ elif state_entry.prev_group == self.state_group:
+ for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
+ if typ != EventTypes.Member:
+ continue
+
+ host = intern_string(get_domain_from_id(state_key))
+ user_id = state_key
+ known_joins = self.hosts_to_joined_users.setdefault(host, set())
+
+ event = yield self.store.get_event(event_id)
+ if event.membership == Membership.JOIN:
+ known_joins.add(user_id)
+ else:
+ known_joins.discard(user_id)
+
+ if not known_joins:
+ self.hosts_to_joined_users.pop(host, None)
+ else:
+ joined_users = yield self.store.get_joined_users_from_state(
+ self.room_id, state_entry
+ )
+
+ self.hosts_to_joined_users = {}
+ for user_id in joined_users:
+ host = intern_string(get_domain_from_id(user_id))
+ self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+ if state_entry.state_group:
+ self.state_group = state_entry.state_group
+ else:
+ self.state_group = object()
+ self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
+ return frozenset(self.hosts_to_joined_users)
+
+ def __len__(self):
+ return self._len
diff --git a/synapse/storage/schema/delta/12/v12.sql b/synapse/storage/data_stores/main/schema/delta/12/v12.sql
index 5964c5aaac..5964c5aaac 100644
--- a/synapse/storage/schema/delta/12/v12.sql
+++ b/synapse/storage/data_stores/main/schema/delta/12/v12.sql
diff --git a/synapse/storage/schema/delta/13/v13.sql b/synapse/storage/data_stores/main/schema/delta/13/v13.sql
index f8649e5d99..f8649e5d99 100644
--- a/synapse/storage/schema/delta/13/v13.sql
+++ b/synapse/storage/data_stores/main/schema/delta/13/v13.sql
diff --git a/synapse/storage/schema/delta/14/v14.sql b/synapse/storage/data_stores/main/schema/delta/14/v14.sql
index a831920da6..a831920da6 100644
--- a/synapse/storage/schema/delta/14/v14.sql
+++ b/synapse/storage/data_stores/main/schema/delta/14/v14.sql
diff --git a/synapse/storage/schema/delta/15/appservice_txns.sql b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql
index e4f5e76aec..e4f5e76aec 100644
--- a/synapse/storage/schema/delta/15/appservice_txns.sql
+++ b/synapse/storage/data_stores/main/schema/delta/15/appservice_txns.sql
diff --git a/synapse/storage/schema/delta/15/presence_indices.sql b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql
index 6b8d0f1ca7..6b8d0f1ca7 100644
--- a/synapse/storage/schema/delta/15/presence_indices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/15/presence_indices.sql
diff --git a/synapse/storage/schema/delta/15/v15.sql b/synapse/storage/data_stores/main/schema/delta/15/v15.sql
index 9523d2bcc3..9523d2bcc3 100644
--- a/synapse/storage/schema/delta/15/v15.sql
+++ b/synapse/storage/data_stores/main/schema/delta/15/v15.sql
diff --git a/synapse/storage/schema/delta/16/events_order_index.sql b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql
index a48f215170..a48f215170 100644
--- a/synapse/storage/schema/delta/16/events_order_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/events_order_index.sql
diff --git a/synapse/storage/schema/delta/16/remote_media_cache_index.sql b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql
index 7a15265cb1..7a15265cb1 100644
--- a/synapse/storage/schema/delta/16/remote_media_cache_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/remote_media_cache_index.sql
diff --git a/synapse/storage/schema/delta/16/remove_duplicates.sql b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql
index 65c97b5e2f..65c97b5e2f 100644
--- a/synapse/storage/schema/delta/16/remove_duplicates.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/remove_duplicates.sql
diff --git a/synapse/storage/schema/delta/16/room_alias_index.sql b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql
index f82486132b..f82486132b 100644
--- a/synapse/storage/schema/delta/16/room_alias_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/room_alias_index.sql
diff --git a/synapse/storage/schema/delta/16/unique_constraints.sql b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql
index 5b8de52c33..5b8de52c33 100644
--- a/synapse/storage/schema/delta/16/unique_constraints.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/unique_constraints.sql
diff --git a/synapse/storage/schema/delta/16/users.sql b/synapse/storage/data_stores/main/schema/delta/16/users.sql
index cd0709250d..cd0709250d 100644
--- a/synapse/storage/schema/delta/16/users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/16/users.sql
diff --git a/synapse/storage/schema/delta/17/drop_indexes.sql b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql
index 7c9a90e27f..7c9a90e27f 100644
--- a/synapse/storage/schema/delta/17/drop_indexes.sql
+++ b/synapse/storage/data_stores/main/schema/delta/17/drop_indexes.sql
diff --git a/synapse/storage/schema/delta/17/server_keys.sql b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql
index 70b247a06b..70b247a06b 100644
--- a/synapse/storage/schema/delta/17/server_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/17/server_keys.sql
diff --git a/synapse/storage/schema/delta/17/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql
index c17715ac80..c17715ac80 100644
--- a/synapse/storage/schema/delta/17/user_threepids.sql
+++ b/synapse/storage/data_stores/main/schema/delta/17/user_threepids.sql
diff --git a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql
index 6e0871c92b..6e0871c92b 100644
--- a/synapse/storage/schema/delta/18/server_keys_bigger_ints.sql
+++ b/synapse/storage/data_stores/main/schema/delta/18/server_keys_bigger_ints.sql
diff --git a/synapse/storage/schema/delta/19/event_index.sql b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql
index 18b97b4332..18b97b4332 100644
--- a/synapse/storage/schema/delta/19/event_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/19/event_index.sql
diff --git a/synapse/storage/schema/delta/20/dummy.sql b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql
index e0ac49d1ec..e0ac49d1ec 100644
--- a/synapse/storage/schema/delta/20/dummy.sql
+++ b/synapse/storage/data_stores/main/schema/delta/20/dummy.sql
diff --git a/synapse/storage/schema/delta/20/pushers.py b/synapse/storage/data_stores/main/schema/delta/20/pushers.py
index 147496a38b..3edfcfd783 100644
--- a/synapse/storage/schema/delta/20/pushers.py
+++ b/synapse/storage/data_stores/main/schema/delta/20/pushers.py
@@ -29,7 +29,8 @@ logger = logging.getLogger(__name__)
def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
- cur.execute("""
+ cur.execute(
+ """
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
@@ -48,27 +49,34 @@ def run_create(cur, database_engine, *args, **kwargs):
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
)
- """)
- cur.execute("""SELECT
+ """
+ )
+ cur.execute(
+ """SELECT
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
FROM pushers
- """)
+ """
+ )
count = 0
for row in cur.fetchall():
row = list(row)
row[8] = bytes(row[8]).decode("utf-8")
row[11] = bytes(row[11]).decode("utf-8")
- cur.execute(database_engine.convert_param_style("""
+ cur.execute(
+ database_engine.convert_param_style(
+ """
INSERT into pushers2 (
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
- ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
- row
+ ) values (%s)"""
+ % (",".join(["?" for _ in range(len(row))]))
+ ),
+ row,
)
count += 1
cur.execute("DROP TABLE pushers")
diff --git a/synapse/storage/schema/delta/21/end_to_end_keys.sql b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql
index 4c2fb20b77..4c2fb20b77 100644
--- a/synapse/storage/schema/delta/21/end_to_end_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/21/end_to_end_keys.sql
diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql
index d070845477..d070845477 100644
--- a/synapse/storage/schema/delta/21/receipts.sql
+++ b/synapse/storage/data_stores/main/schema/delta/21/receipts.sql
diff --git a/synapse/storage/schema/delta/22/receipts_index.sql b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql
index bfc0b3bcaa..bfc0b3bcaa 100644
--- a/synapse/storage/schema/delta/22/receipts_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/22/receipts_index.sql
diff --git a/synapse/storage/schema/delta/22/user_threepids_unique.sql b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql
index 87edfa454c..87edfa454c 100644
--- a/synapse/storage/schema/delta/22/user_threepids_unique.sql
+++ b/synapse/storage/data_stores/main/schema/delta/22/user_threepids_unique.sql
diff --git a/synapse/storage/schema/delta/24/stats_reporting.sql b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql
index acea7483bd..acea7483bd 100644
--- a/synapse/storage/schema/delta/24/stats_reporting.sql
+++ b/synapse/storage/data_stores/main/schema/delta/24/stats_reporting.sql
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/data_stores/main/schema/delta/25/fts.py
index 4b2ffd35fd..4b2ffd35fd 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/data_stores/main/schema/delta/25/fts.py
diff --git a/synapse/storage/schema/delta/25/guest_access.sql b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql
index 1ea389b471..1ea389b471 100644
--- a/synapse/storage/schema/delta/25/guest_access.sql
+++ b/synapse/storage/data_stores/main/schema/delta/25/guest_access.sql
diff --git a/synapse/storage/schema/delta/25/history_visibility.sql b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql
index f468fc1897..f468fc1897 100644
--- a/synapse/storage/schema/delta/25/history_visibility.sql
+++ b/synapse/storage/data_stores/main/schema/delta/25/history_visibility.sql
diff --git a/synapse/storage/schema/delta/25/tags.sql b/synapse/storage/data_stores/main/schema/delta/25/tags.sql
index 7a32ce68e4..7a32ce68e4 100644
--- a/synapse/storage/schema/delta/25/tags.sql
+++ b/synapse/storage/data_stores/main/schema/delta/25/tags.sql
diff --git a/synapse/storage/schema/delta/26/account_data.sql b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql
index e395de2b5e..e395de2b5e 100644
--- a/synapse/storage/schema/delta/26/account_data.sql
+++ b/synapse/storage/data_stores/main/schema/delta/26/account_data.sql
diff --git a/synapse/storage/schema/delta/27/account_data.sql b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql
index bf0558b5b3..bf0558b5b3 100644
--- a/synapse/storage/schema/delta/27/account_data.sql
+++ b/synapse/storage/data_stores/main/schema/delta/27/account_data.sql
diff --git a/synapse/storage/schema/delta/27/forgotten_memberships.sql b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql
index e2094f37fe..e2094f37fe 100644
--- a/synapse/storage/schema/delta/27/forgotten_memberships.sql
+++ b/synapse/storage/data_stores/main/schema/delta/27/forgotten_memberships.sql
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/data_stores/main/schema/delta/27/ts.py
index 414f9f5aa0..414f9f5aa0 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/data_stores/main/schema/delta/27/ts.py
diff --git a/synapse/storage/schema/delta/28/event_push_actions.sql b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql
index 4d519849df..4d519849df 100644
--- a/synapse/storage/schema/delta/28/event_push_actions.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/event_push_actions.sql
diff --git a/synapse/storage/schema/delta/28/events_room_stream.sql b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql
index 36609475f1..36609475f1 100644
--- a/synapse/storage/schema/delta/28/events_room_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/events_room_stream.sql
diff --git a/synapse/storage/schema/delta/28/public_roms_index.sql b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql
index 6c1fd68c5b..6c1fd68c5b 100644
--- a/synapse/storage/schema/delta/28/public_roms_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/public_roms_index.sql
diff --git a/synapse/storage/schema/delta/28/receipts_user_id_index.sql b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql
index cb84c69baa..cb84c69baa 100644
--- a/synapse/storage/schema/delta/28/receipts_user_id_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/receipts_user_id_index.sql
diff --git a/synapse/storage/schema/delta/28/upgrade_times.sql b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql
index 3e4a9ab455..3e4a9ab455 100644
--- a/synapse/storage/schema/delta/28/upgrade_times.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/upgrade_times.sql
diff --git a/synapse/storage/schema/delta/28/users_is_guest.sql b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql
index 21d2b420bf..21d2b420bf 100644
--- a/synapse/storage/schema/delta/28/users_is_guest.sql
+++ b/synapse/storage/data_stores/main/schema/delta/28/users_is_guest.sql
diff --git a/synapse/storage/schema/delta/29/push_actions.sql b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql
index 84b21cf813..84b21cf813 100644
--- a/synapse/storage/schema/delta/29/push_actions.sql
+++ b/synapse/storage/data_stores/main/schema/delta/29/push_actions.sql
diff --git a/synapse/storage/schema/delta/30/alias_creator.sql b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql
index c9d0dde638..c9d0dde638 100644
--- a/synapse/storage/schema/delta/30/alias_creator.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/alias_creator.sql
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/data_stores/main/schema/delta/30/as_users.py
index ef7ec34346..9b95411fb6 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/data_stores/main/schema/delta/30/as_users.py
@@ -40,9 +40,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
logger.warning("Could not get app_service_config_files from config")
pass
- appservices = load_appservices(
- config.server_name, config_files
- )
+ appservices = load_appservices(config.server_name, config_files)
owned = {}
@@ -53,20 +51,19 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
if user_id in owned.keys():
logger.error(
"user_id %s was owned by more than one application"
- " service (IDs %s and %s); assigning arbitrarily to %s" %
- (user_id, owned[user_id], appservice.id, owned[user_id])
+ " service (IDs %s and %s); assigning arbitrarily to %s"
+ % (user_id, owned[user_id], appservice.id, owned[user_id])
)
owned.setdefault(appservice.id, []).append(user_id)
for as_id, user_ids in owned.items():
n = 100
- user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n))
+ user_chunks = (user_ids[i : i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
database_engine.convert_param_style(
- "UPDATE users SET appservice_id = ? WHERE name IN (%s)" % (
- ",".join("?" for _ in chunk),
- )
+ "UPDATE users SET appservice_id = ? WHERE name IN (%s)"
+ % (",".join("?" for _ in chunk),)
),
- [as_id] + chunk
+ [as_id] + chunk,
)
diff --git a/synapse/storage/schema/delta/30/deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql
index 712c454aa1..712c454aa1 100644
--- a/synapse/storage/schema/delta/30/deleted_pushers.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/deleted_pushers.sql
diff --git a/synapse/storage/schema/delta/30/presence_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql
index 606bbb037d..606bbb037d 100644
--- a/synapse/storage/schema/delta/30/presence_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/presence_stream.sql
diff --git a/synapse/storage/schema/delta/30/public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql
index f09db4faa6..f09db4faa6 100644
--- a/synapse/storage/schema/delta/30/public_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/public_rooms.sql
diff --git a/synapse/storage/schema/delta/30/push_rule_stream.sql b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql
index 735aa8d5f6..735aa8d5f6 100644
--- a/synapse/storage/schema/delta/30/push_rule_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/push_rule_stream.sql
diff --git a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql
index 0dd2f1360c..0dd2f1360c 100644
--- a/synapse/storage/schema/delta/30/threepid_guest_access_tokens.sql
+++ b/synapse/storage/data_stores/main/schema/delta/30/threepid_guest_access_tokens.sql
diff --git a/synapse/storage/schema/delta/31/invites.sql b/synapse/storage/data_stores/main/schema/delta/31/invites.sql
index 2c57846d5a..2c57846d5a 100644
--- a/synapse/storage/schema/delta/31/invites.sql
+++ b/synapse/storage/data_stores/main/schema/delta/31/invites.sql
diff --git a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql
index 9efb4280eb..9efb4280eb 100644
--- a/synapse/storage/schema/delta/31/local_media_repository_url_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/31/local_media_repository_url_cache.sql
diff --git a/synapse/storage/schema/delta/31/pushers.py b/synapse/storage/data_stores/main/schema/delta/31/pushers.py
index 93367fa09e..9bb504aad5 100644
--- a/synapse/storage/schema/delta/31/pushers.py
+++ b/synapse/storage/data_stores/main/schema/delta/31/pushers.py
@@ -24,12 +24,13 @@ logger = logging.getLogger(__name__)
def token_to_stream_ordering(token):
- return int(token[1:].split('_')[0])
+ return int(token[1:].split("_")[0])
def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table, delta 31...")
- cur.execute("""
+ cur.execute(
+ """
CREATE TABLE IF NOT EXISTS pushers2 (
id BIGINT PRIMARY KEY,
user_name TEXT NOT NULL,
@@ -48,26 +49,33 @@ def run_create(cur, database_engine, *args, **kwargs):
failing_since BIGINT,
UNIQUE (app_id, pushkey, user_name)
)
- """)
- cur.execute("""SELECT
+ """
+ )
+ cur.execute(
+ """SELECT
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_token, last_success,
failing_since
FROM pushers
- """)
+ """
+ )
count = 0
for row in cur.fetchall():
row = list(row)
row[12] = token_to_stream_ordering(row[12])
- cur.execute(database_engine.convert_param_style("""
+ cur.execute(
+ database_engine.convert_param_style(
+ """
INSERT into pushers2 (
id, user_name, access_token, profile_tag, kind,
app_id, app_display_name, device_display_name,
pushkey, ts, lang, data, last_stream_ordering, last_success,
failing_since
- ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
- row
+ ) values (%s)"""
+ % (",".join(["?" for _ in range(len(row))]))
+ ),
+ row,
)
count += 1
cur.execute("DROP TABLE pushers")
diff --git a/synapse/storage/schema/delta/31/pushers_index.sql b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql
index a82add88fd..a82add88fd 100644
--- a/synapse/storage/schema/delta/31/pushers_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/31/pushers_index.sql
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
index 7d8ca5f93f..7d8ca5f93f 100644
--- a/synapse/storage/schema/delta/31/search_update.py
+++ b/synapse/storage/data_stores/main/schema/delta/31/search_update.py
diff --git a/synapse/storage/schema/delta/32/events.sql b/synapse/storage/data_stores/main/schema/delta/32/events.sql
index 1dd0f9e170..1dd0f9e170 100644
--- a/synapse/storage/schema/delta/32/events.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/events.sql
diff --git a/synapse/storage/schema/delta/32/openid.sql b/synapse/storage/data_stores/main/schema/delta/32/openid.sql
index 36f37b11c8..36f37b11c8 100644
--- a/synapse/storage/schema/delta/32/openid.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/openid.sql
diff --git a/synapse/storage/schema/delta/32/pusher_throttle.sql b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql
index d86d30c13c..d86d30c13c 100644
--- a/synapse/storage/schema/delta/32/pusher_throttle.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/pusher_throttle.sql
diff --git a/synapse/storage/schema/delta/32/remove_indices.sql b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
index 4219cdd06a..2de50d408c 100644
--- a/synapse/storage/schema/delta/32/remove_indices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/remove_indices.sql
@@ -20,7 +20,6 @@ DROP INDEX IF EXISTS events_room_id; -- Prefix of events_room_stream
DROP INDEX IF EXISTS events_order; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_topological_ordering; -- Prefix of events_order_topo_stream_room
DROP INDEX IF EXISTS events_stream_ordering; -- Duplicate of PRIMARY KEY
-DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_to_state_groups_id; -- Duplicate of PRIMARY KEY
DROP INDEX IF EXISTS event_push_actions_room_id_event_id_user_id_profile_tag; -- Duplicate of UNIQUE CONSTRAINT
diff --git a/synapse/storage/schema/delta/32/reports.sql b/synapse/storage/data_stores/main/schema/delta/32/reports.sql
index d13609776f..d13609776f 100644
--- a/synapse/storage/schema/delta/32/reports.sql
+++ b/synapse/storage/data_stores/main/schema/delta/32/reports.sql
diff --git a/synapse/storage/schema/delta/33/access_tokens_device_index.sql b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql
index 61ad3fe3e8..61ad3fe3e8 100644
--- a/synapse/storage/schema/delta/33/access_tokens_device_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/access_tokens_device_index.sql
diff --git a/synapse/storage/schema/delta/33/devices.sql b/synapse/storage/data_stores/main/schema/delta/33/devices.sql
index eca7268d82..eca7268d82 100644
--- a/synapse/storage/schema/delta/33/devices.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/devices.sql
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql
index aa4a3b9f2f..aa4a3b9f2f 100644
--- a/synapse/storage/schema/delta/33/devices_for_e2e_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys.sql
diff --git a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
index 6671573398..6671573398 100644
--- a/synapse/storage/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/devices_for_e2e_keys_clear_unknown_device.sql
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
index bff1256a7b..bff1256a7b 100644
--- a/synapse/storage/schema/delta/33/event_fields.py
+++ b/synapse/storage/data_stores/main/schema/delta/33/event_fields.py
diff --git a/synapse/storage/schema/delta/33/remote_media_ts.py b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py
index 9754d3ccfb..a26057dfb6 100644
--- a/synapse/storage/schema/delta/33/remote_media_ts.py
+++ b/synapse/storage/data_stores/main/schema/delta/33/remote_media_ts.py
@@ -26,5 +26,5 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
database_engine.convert_param_style(
"UPDATE remote_media_cache SET last_access_ts = ?"
),
- (int(time.time() * 1000),)
+ (int(time.time() * 1000),),
)
diff --git a/synapse/storage/schema/delta/33/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql
index 473f75a78e..473f75a78e 100644
--- a/synapse/storage/schema/delta/33/user_ips_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/33/user_ips_index.sql
diff --git a/synapse/storage/schema/delta/34/appservice_stream.sql b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql
index 69e16eda0f..69e16eda0f 100644
--- a/synapse/storage/schema/delta/34/appservice_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/34/appservice_stream.sql
diff --git a/synapse/storage/schema/delta/34/cache_stream.py b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py
index cf09e43e2b..cf09e43e2b 100644
--- a/synapse/storage/schema/delta/34/cache_stream.py
+++ b/synapse/storage/data_stores/main/schema/delta/34/cache_stream.py
diff --git a/synapse/storage/schema/delta/34/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql
index e68844c74a..e68844c74a 100644
--- a/synapse/storage/schema/delta/34/device_inbox.sql
+++ b/synapse/storage/data_stores/main/schema/delta/34/device_inbox.sql
diff --git a/synapse/storage/schema/delta/34/push_display_name_rename.sql b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql
index 0d9fe1a99a..0d9fe1a99a 100644
--- a/synapse/storage/schema/delta/34/push_display_name_rename.sql
+++ b/synapse/storage/data_stores/main/schema/delta/34/push_display_name_rename.sql
diff --git a/synapse/storage/schema/delta/34/received_txn_purge.py b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py
index 67d505e68b..67d505e68b 100644
--- a/synapse/storage/schema/delta/34/received_txn_purge.py
+++ b/synapse/storage/data_stores/main/schema/delta/34/received_txn_purge.py
diff --git a/synapse/storage/schema/delta/35/contains_url.sql b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql
index 6cd123027b..6cd123027b 100644
--- a/synapse/storage/schema/delta/35/contains_url.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/contains_url.sql
diff --git a/synapse/storage/schema/delta/35/device_outbox.sql b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql
index 17e6c43105..17e6c43105 100644
--- a/synapse/storage/schema/delta/35/device_outbox.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/device_outbox.sql
diff --git a/synapse/storage/schema/delta/35/device_stream_id.sql b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql
index 7ab7d942e2..7ab7d942e2 100644
--- a/synapse/storage/schema/delta/35/device_stream_id.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/device_stream_id.sql
diff --git a/synapse/storage/schema/delta/35/event_push_actions_index.sql b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql
index 2e836d8e9c..2e836d8e9c 100644
--- a/synapse/storage/schema/delta/35/event_push_actions_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/event_push_actions_index.sql
diff --git a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql
index dd2bf2e28a..dd2bf2e28a 100644
--- a/synapse/storage/schema/delta/35/public_room_list_change_stream.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/public_room_list_change_stream.sql
diff --git a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql
index 2b945d8a57..2b945d8a57 100644
--- a/synapse/storage/schema/delta/35/stream_order_to_extrem.sql
+++ b/synapse/storage/data_stores/main/schema/delta/35/stream_order_to_extrem.sql
diff --git a/synapse/storage/schema/delta/36/readd_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql
index 90d8fd18f9..90d8fd18f9 100644
--- a/synapse/storage/schema/delta/36/readd_public_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/36/readd_public_rooms.sql
diff --git a/synapse/storage/schema/delta/37/remove_auth_idx.py b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py
index a377884169..a377884169 100644
--- a/synapse/storage/schema/delta/37/remove_auth_idx.py
+++ b/synapse/storage/data_stores/main/schema/delta/37/remove_auth_idx.py
diff --git a/synapse/storage/schema/delta/37/user_threepids.sql b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql
index cf7a90dd10..cf7a90dd10 100644
--- a/synapse/storage/schema/delta/37/user_threepids.sql
+++ b/synapse/storage/data_stores/main/schema/delta/37/user_threepids.sql
diff --git a/synapse/storage/schema/delta/38/postgres_fts_gist.sql b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql
index 515e6b8e84..515e6b8e84 100644
--- a/synapse/storage/schema/delta/38/postgres_fts_gist.sql
+++ b/synapse/storage/data_stores/main/schema/delta/38/postgres_fts_gist.sql
diff --git a/synapse/storage/schema/delta/39/appservice_room_list.sql b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql
index 74bdc49073..74bdc49073 100644
--- a/synapse/storage/schema/delta/39/appservice_room_list.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/appservice_room_list.sql
diff --git a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql
index 00be801e90..00be801e90 100644
--- a/synapse/storage/schema/delta/39/device_federation_stream_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/device_federation_stream_idx.sql
diff --git a/synapse/storage/schema/delta/39/event_push_index.sql b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql
index de2ad93e5c..de2ad93e5c 100644
--- a/synapse/storage/schema/delta/39/event_push_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/event_push_index.sql
diff --git a/synapse/storage/schema/delta/39/federation_out_position.sql b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql
index 5af814290b..5af814290b 100644
--- a/synapse/storage/schema/delta/39/federation_out_position.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/federation_out_position.sql
diff --git a/synapse/storage/schema/delta/39/membership_profile.sql b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql
index 1bf911c8ab..1bf911c8ab 100644
--- a/synapse/storage/schema/delta/39/membership_profile.sql
+++ b/synapse/storage/data_stores/main/schema/delta/39/membership_profile.sql
diff --git a/synapse/storage/schema/delta/40/current_state_idx.sql b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql
index 7ffa189f39..7ffa189f39 100644
--- a/synapse/storage/schema/delta/40/current_state_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/current_state_idx.sql
diff --git a/synapse/storage/schema/delta/40/device_inbox.sql b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql
index b9fe1f0480..b9fe1f0480 100644
--- a/synapse/storage/schema/delta/40/device_inbox.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/device_inbox.sql
diff --git a/synapse/storage/schema/delta/40/device_list_streams.sql b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql
index dd6dcb65f1..dd6dcb65f1 100644
--- a/synapse/storage/schema/delta/40/device_list_streams.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/device_list_streams.sql
diff --git a/synapse/storage/schema/delta/40/event_push_summary.sql b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql
index 3918f0b794..3918f0b794 100644
--- a/synapse/storage/schema/delta/40/event_push_summary.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/event_push_summary.sql
diff --git a/synapse/storage/schema/delta/40/pushers.sql b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql
index 054a223f14..054a223f14 100644
--- a/synapse/storage/schema/delta/40/pushers.sql
+++ b/synapse/storage/data_stores/main/schema/delta/40/pushers.sql
diff --git a/synapse/storage/schema/delta/41/device_list_stream_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql
index b7bee8b692..b7bee8b692 100644
--- a/synapse/storage/schema/delta/41/device_list_stream_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/device_list_stream_idx.sql
diff --git a/synapse/storage/schema/delta/41/device_outbound_index.sql b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql
index 62f0b9892b..62f0b9892b 100644
--- a/synapse/storage/schema/delta/41/device_outbound_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/device_outbound_index.sql
diff --git a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql
index 5d9cfecf36..5d9cfecf36 100644
--- a/synapse/storage/schema/delta/41/event_search_event_id_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/event_search_event_id_idx.sql
diff --git a/synapse/storage/schema/delta/41/ratelimit.sql b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql
index a194bf0238..a194bf0238 100644
--- a/synapse/storage/schema/delta/41/ratelimit.sql
+++ b/synapse/storage/data_stores/main/schema/delta/41/ratelimit.sql
diff --git a/synapse/storage/schema/delta/42/current_state_delta.sql b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql
index d28851aff8..d28851aff8 100644
--- a/synapse/storage/schema/delta/42/current_state_delta.sql
+++ b/synapse/storage/data_stores/main/schema/delta/42/current_state_delta.sql
diff --git a/synapse/storage/schema/delta/42/device_list_last_id.sql b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql
index 9ab8c14fa3..9ab8c14fa3 100644
--- a/synapse/storage/schema/delta/42/device_list_last_id.sql
+++ b/synapse/storage/data_stores/main/schema/delta/42/device_list_last_id.sql
diff --git a/synapse/storage/schema/delta/42/event_auth_state_only.sql b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql
index b8821ac759..b8821ac759 100644
--- a/synapse/storage/schema/delta/42/event_auth_state_only.sql
+++ b/synapse/storage/data_stores/main/schema/delta/42/event_auth_state_only.sql
diff --git a/synapse/storage/schema/delta/42/user_dir.py b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py
index 506f326f4d..506f326f4d 100644
--- a/synapse/storage/schema/delta/42/user_dir.py
+++ b/synapse/storage/data_stores/main/schema/delta/42/user_dir.py
diff --git a/synapse/storage/schema/delta/43/blocked_rooms.sql b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql
index 0e3cd143ff..0e3cd143ff 100644
--- a/synapse/storage/schema/delta/43/blocked_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/blocked_rooms.sql
diff --git a/synapse/storage/schema/delta/43/quarantine_media.sql b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql
index 630907ec4f..630907ec4f 100644
--- a/synapse/storage/schema/delta/43/quarantine_media.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/quarantine_media.sql
diff --git a/synapse/storage/schema/delta/43/url_cache.sql b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql
index 45ebe020da..45ebe020da 100644
--- a/synapse/storage/schema/delta/43/url_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/url_cache.sql
diff --git a/synapse/storage/schema/delta/43/user_share.sql b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql
index ee7062abe4..ee7062abe4 100644
--- a/synapse/storage/schema/delta/43/user_share.sql
+++ b/synapse/storage/data_stores/main/schema/delta/43/user_share.sql
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql
index b12f9b2ebf..b12f9b2ebf 100644
--- a/synapse/storage/schema/delta/44/expire_url_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/44/expire_url_cache.sql
diff --git a/synapse/storage/schema/delta/45/group_server.sql b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql
index b2333848a0..b2333848a0 100644
--- a/synapse/storage/schema/delta/45/group_server.sql
+++ b/synapse/storage/data_stores/main/schema/delta/45/group_server.sql
diff --git a/synapse/storage/schema/delta/45/profile_cache.sql b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql
index e5ddc84df0..e5ddc84df0 100644
--- a/synapse/storage/schema/delta/45/profile_cache.sql
+++ b/synapse/storage/data_stores/main/schema/delta/45/profile_cache.sql
diff --git a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql
index 68c48a89a9..68c48a89a9 100644
--- a/synapse/storage/schema/delta/46/drop_refresh_tokens.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/drop_refresh_tokens.sql
diff --git a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql
index bb307889c1..bb307889c1 100644
--- a/synapse/storage/schema/delta/46/drop_unique_deleted_pushers.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/drop_unique_deleted_pushers.sql
diff --git a/synapse/storage/schema/delta/46/group_server.sql b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql
index 097679bc9a..097679bc9a 100644
--- a/synapse/storage/schema/delta/46/group_server.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/group_server.sql
diff --git a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql
index bbfc7f5d1a..bbfc7f5d1a 100644
--- a/synapse/storage/schema/delta/46/local_media_repository_url_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/local_media_repository_url_idx.sql
diff --git a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql
index cb0d5a2576..cb0d5a2576 100644
--- a/synapse/storage/schema/delta/46/user_dir_null_room_ids.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_null_room_ids.sql
diff --git a/synapse/storage/schema/delta/46/user_dir_typos.sql b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql
index d9505f8da1..d9505f8da1 100644
--- a/synapse/storage/schema/delta/46/user_dir_typos.sql
+++ b/synapse/storage/data_stores/main/schema/delta/46/user_dir_typos.sql
diff --git a/synapse/storage/schema/delta/47/last_access_media.sql b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql
index f505fb22b5..f505fb22b5 100644
--- a/synapse/storage/schema/delta/47/last_access_media.sql
+++ b/synapse/storage/data_stores/main/schema/delta/47/last_access_media.sql
diff --git a/synapse/storage/schema/delta/47/postgres_fts_gin.sql b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql
index 31d7a817eb..31d7a817eb 100644
--- a/synapse/storage/schema/delta/47/postgres_fts_gin.sql
+++ b/synapse/storage/data_stores/main/schema/delta/47/postgres_fts_gin.sql
diff --git a/synapse/storage/schema/delta/47/push_actions_staging.sql b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql
index edccf4a96f..edccf4a96f 100644
--- a/synapse/storage/schema/delta/47/push_actions_staging.sql
+++ b/synapse/storage/data_stores/main/schema/delta/47/push_actions_staging.sql
diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql
index 5237491506..5237491506 100644
--- a/synapse/storage/schema/delta/48/add_user_consent.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_consent.sql
diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql
index 9248b0b24a..9248b0b24a 100644
--- a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/add_user_ips_last_seen_index.sql
diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql
index e9013a6969..e9013a6969 100644
--- a/synapse/storage/schema/delta/48/deactivated_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/deactivated_users.sql
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py
index 2233af87d7..49f5f2c003 100644
--- a/synapse/storage/schema/delta/48/group_unique_indexes.py
+++ b/synapse/storage/data_stores/main/schema/delta/48/group_unique_indexes.py
@@ -38,16 +38,22 @@ def run_create(cur, database_engine, *args, **kwargs):
rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
# remove duplicates from group_users & group_invites tables
- cur.execute("""
+ cur.execute(
+ """
DELETE FROM group_users WHERE %s NOT IN (
SELECT min(%s) FROM group_users GROUP BY group_id, user_id
);
- """ % (rowid, rowid))
- cur.execute("""
+ """
+ % (rowid, rowid)
+ )
+ cur.execute(
+ """
DELETE FROM group_invites WHERE %s NOT IN (
SELECT min(%s) FROM group_invites GROUP BY group_id, user_id
);
- """ % (rowid, rowid))
+ """
+ % (rowid, rowid)
+ )
for statement in get_statements(FIX_INDEXES.splitlines()):
cur.execute(statement)
diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql
index ce26eaf0c9..ce26eaf0c9 100644
--- a/synapse/storage/schema/delta/48/groups_joinable.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/groups_joinable.sql
diff --git a/synapse/storage/schema/delta/48/profiles_batch.sql b/synapse/storage/data_stores/main/schema/delta/48/profiles_batch.sql
index e744c02fe8..e744c02fe8 100644
--- a/synapse/storage/schema/delta/48/profiles_batch.sql
+++ b/synapse/storage/data_stores/main/schema/delta/48/profiles_batch.sql
diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql
index 14dcf18d73..14dcf18d73 100644
--- a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
+++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_consent_server_notice_sent.sql
diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql
index 3dd478196f..3dd478196f 100644
--- a/synapse/storage/schema/delta/49/add_user_daily_visits.sql
+++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_daily_visits.sql
diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
index 3a4ed59b5b..3a4ed59b5b 100644
--- a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/49/add_user_ips_last_seen_only_index.sql
diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql
index c93ae47532..c93ae47532 100644
--- a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/50/add_creation_ts_users_index.sql
diff --git a/synapse/storage/schema/delta/50/erasure_store.sql b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql
index 5d8641a9ab..5d8641a9ab 100644
--- a/synapse/storage/schema/delta/50/erasure_store.sql
+++ b/synapse/storage/data_stores/main/schema/delta/50/erasure_store.sql
diff --git a/synapse/storage/schema/delta/50/make_event_content_nullable.py b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py
index 6dd467b6c5..b1684a8441 100644
--- a/synapse/storage/schema/delta/50/make_event_content_nullable.py
+++ b/synapse/storage/data_stores/main/schema/delta/50/make_event_content_nullable.py
@@ -65,14 +65,18 @@ def run_create(cur, database_engine, *args, **kwargs):
def run_upgrade(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
- cur.execute("""
+ cur.execute(
+ """
ALTER TABLE events ALTER COLUMN content DROP NOT NULL;
- """)
+ """
+ )
return
# sqlite is an arse about this. ref: https://www.sqlite.org/lang_altertable.html
- cur.execute("SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'")
+ cur.execute(
+ "SELECT sql FROM sqlite_master WHERE tbl_name='events' AND type='table'"
+ )
(oldsql,) = cur.fetchone()
sql = oldsql.replace("content TEXT NOT NULL", "content TEXT")
@@ -86,7 +90,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute("PRAGMA writable_schema=ON")
cur.execute(
"UPDATE sqlite_master SET sql=? WHERE tbl_name='events' AND type='table'",
- (sql, ),
+ (sql,),
)
cur.execute("PRAGMA schema_version=%i" % (oldver + 1,))
cur.execute("PRAGMA writable_schema=OFF")
diff --git a/synapse/storage/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/data_stores/main/schema/delta/50/profiles_deactivated_users.sql
index c8893ecbe8..c8893ecbe8 100644
--- a/synapse/storage/schema/delta/50/profiles_deactivated_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/50/profiles_deactivated_users.sql
diff --git a/synapse/storage/schema/delta/51/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql
index c0e66a697d..c0e66a697d 100644
--- a/synapse/storage/schema/delta/51/e2e_room_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/51/e2e_room_keys.sql
diff --git a/synapse/storage/schema/delta/51/monthly_active_users.sql b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql
index c9d537d5a3..c9d537d5a3 100644
--- a/synapse/storage/schema/delta/51/monthly_active_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/51/monthly_active_users.sql
diff --git a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql
index 91e03d13e1..91e03d13e1 100644
--- a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/52/add_event_to_state_group_index.sql
diff --git a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql
index bfa49e6f92..bfa49e6f92 100644
--- a/synapse/storage/schema/delta/52/device_list_streams_unique_idx.sql
+++ b/synapse/storage/data_stores/main/schema/delta/52/device_list_streams_unique_idx.sql
diff --git a/synapse/storage/schema/delta/52/e2e_room_keys.sql b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql
index db687cccae..db687cccae 100644
--- a/synapse/storage/schema/delta/52/e2e_room_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/52/e2e_room_keys.sql
diff --git a/synapse/storage/schema/delta/53/add_user_type_to_users.sql b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql
index 88ec2f83e5..88ec2f83e5 100644
--- a/synapse/storage/schema/delta/53/add_user_type_to_users.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/add_user_type_to_users.sql
diff --git a/synapse/storage/schema/delta/53/drop_sent_transactions.sql b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql
index e372f5a44a..e372f5a44a 100644
--- a/synapse/storage/schema/delta/53/drop_sent_transactions.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/drop_sent_transactions.sql
diff --git a/synapse/storage/schema/delta/53/event_format_version.sql b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql
index 1d977c2834..1d977c2834 100644
--- a/synapse/storage/schema/delta/53/event_format_version.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/event_format_version.sql
diff --git a/synapse/storage/schema/delta/53/user_dir_populate.sql b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql
index ffcc896b58..ffcc896b58 100644
--- a/synapse/storage/schema/delta/53/user_dir_populate.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_dir_populate.sql
diff --git a/synapse/storage/schema/delta/53/user_ips_index.sql b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql
index b812c5794f..b812c5794f 100644
--- a/synapse/storage/schema/delta/53/user_ips_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_ips_index.sql
diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql
index 5831b1a6f8..5831b1a6f8 100644
--- a/synapse/storage/schema/delta/53/user_share.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_share.sql
diff --git a/synapse/storage/schema/delta/53/user_threepid_id.sql b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql
index 80c2c573b6..80c2c573b6 100644
--- a/synapse/storage/schema/delta/53/user_threepid_id.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/user_threepid_id.sql
diff --git a/synapse/storage/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql
index f7827ca6d2..f7827ca6d2 100644
--- a/synapse/storage/schema/delta/53/users_in_public_rooms.sql
+++ b/synapse/storage/data_stores/main/schema/delta/53/users_in_public_rooms.sql
diff --git a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql
index 0adb2ad55e..0adb2ad55e 100644
--- a/synapse/storage/schema/delta/54/account_validity_with_renewal.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/account_validity_with_renewal.sql
diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql
index c01aa9d2d9..c01aa9d2d9 100644
--- a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/add_validity_to_server_keys.sql
diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql
index b062ec840c..b062ec840c 100644
--- a/synapse/storage/schema/delta/54/delete_forward_extremities.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/delete_forward_extremities.sql
diff --git a/synapse/storage/schema/delta/54/drop_legacy_tables.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql
index dbbe682697..dbbe682697 100644
--- a/synapse/storage/schema/delta/54/drop_legacy_tables.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/drop_legacy_tables.sql
diff --git a/synapse/storage/schema/delta/54/drop_presence_list.sql b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql
index e6ee70c623..e6ee70c623 100644
--- a/synapse/storage/schema/delta/54/drop_presence_list.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/drop_presence_list.sql
diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/data_stores/main/schema/delta/54/relations.sql
index 134862b870..134862b870 100644
--- a/synapse/storage/schema/delta/54/relations.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/relations.sql
diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/data_stores/main/schema/delta/54/stats.sql
index 652e58308e..652e58308e 100644
--- a/synapse/storage/schema/delta/54/stats.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/stats.sql
diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql
index 3b2d48447f..3b2d48447f 100644
--- a/synapse/storage/schema/delta/54/stats2.sql
+++ b/synapse/storage/data_stores/main/schema/delta/54/stats2.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql
new file mode 100644
index 0000000000..4590604bfd
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/55/access_token_expiry.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- when this access token can be used until, in ms since the epoch. NULL means the token
+-- never expires.
+ALTER TABLE access_tokens ADD COLUMN valid_until_ms BIGINT;
diff --git a/synapse/storage/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/data_stores/main/schema/delta/55/profile_replication_status_index.sql
index 18a0f7e10c..18a0f7e10c 100644
--- a/synapse/storage/schema/delta/55/profile_replication_status_index.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/profile_replication_status_index.sql
diff --git a/synapse/storage/schema/delta/55/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/55/room_retention.sql
index ee6cdf7a14..ee6cdf7a14 100644
--- a/synapse/storage/schema/delta/55/room_retention.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/room_retention.sql
diff --git a/synapse/storage/schema/delta/55/track_threepid_validations.sql b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql
index a8eced2e0a..a8eced2e0a 100644
--- a/synapse/storage/schema/delta/55/track_threepid_validations.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/track_threepid_validations.sql
diff --git a/synapse/storage/schema/delta/55/users_alter_deactivated.sql b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql
index dabdde489b..dabdde489b 100644
--- a/synapse/storage/schema/delta/55/users_alter_deactivated.sql
+++ b/synapse/storage/data_stores/main/schema/delta/55/users_alter_deactivated.sql
diff --git a/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql
new file mode 100644
index 0000000000..41807eb1e7
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/add_spans_to_device_lists.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Opentracing context data for inclusion in the device_list_update EDUs, as a
+ * json-encoded dictionary. NULL if opentracing is disabled (or not enabled for this destination).
+ */
+ALTER TABLE device_lists_outbound_pokes ADD opentracing_context TEXT;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql
new file mode 100644
index 0000000000..473018676f
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We add membership to current state so that we don't need to join against
+-- room_memberships, which can be surprisingly costly (we do such queries
+-- very frequently).
+-- This will be null for non-membership events and the content.membership key
+-- for membership events. (Will also be null for membership events until the
+-- background update job has finished).
+ALTER TABLE current_state_events ADD membership TEXT;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql
new file mode 100644
index 0000000000..3133d42d4a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/current_state_events_membership_mk2.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We add membership to current state so that we don't need to join against
+-- room_memberships, which can be surprisingly costly (we do such queries
+-- very frequently).
+-- This will be null for non-membership events and the content.membership key
+-- for membership events. (Will also be null for membership events until the
+-- background update job has finished).
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('current_state_events_membership', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
new file mode 100644
index 0000000000..1d2ddb1b1a
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/delete_keys_from_deleted_backups.sql
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* delete room keys that belong to deleted room key version, or to room key
+ * versions that don't exist (anymore)
+ */
+DELETE FROM e2e_room_keys
+WHERE version NOT IN (
+ SELECT version
+ FROM e2e_room_keys_versions
+ WHERE e2e_room_keys.user_id = e2e_room_keys_versions.user_id
+ AND e2e_room_keys_versions.deleted = 0
+);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql
new file mode 100644
index 0000000000..f00889290b
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_failure_ts.sql
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Record the timestamp when a given server started failing
+ */
+ALTER TABLE destinations ADD failure_ts BIGINT;
+
+/* as a rough approximation, we assume that the server started failing at
+ * retry_interval before the last retry
+ */
+UPDATE destinations SET failure_ts = retry_last_ts - retry_interval
+ WHERE retry_last_ts > 0;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
new file mode 100644
index 0000000000..b9bbb18a91
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/destinations_retry_interval_type.sql.postgres
@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We want to store large retry intervals so we upgrade the column from INT
+-- to BIGINT. We don't need to do this on SQLite.
+ALTER TABLE destinations ALTER retry_interval SET DATA TYPE BIGINT;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
new file mode 100644
index 0000000000..c2f557fde9
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/device_stream_id_insert.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This line already existed in deltas/35/device_stream_id but was not included in the
+-- 54 full schema SQL. Add some SQL here to insert the missing row if it does not exist
+INSERT INTO device_max_stream_id (stream_id) SELECT 0 WHERE NOT EXISTS (
+ SELECT * from device_max_stream_id
+);
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql
new file mode 100644
index 0000000000..dfa902d0ba
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/devices_last_seen.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 Matrix.org Foundation CIC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Track last seen information for a device in the devices table, rather
+-- than relying on it being in the user_ips table (which we want to be able
+-- to purge old entries from)
+ALTER TABLE devices ADD COLUMN last_seen BIGINT;
+ALTER TABLE devices ADD COLUMN ip TEXT;
+ALTER TABLE devices ADD COLUMN user_agent TEXT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('devices_last_seen', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql
new file mode 100644
index 0000000000..9f09922c67
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/drop_unused_event_tables.sql
@@ -0,0 +1,20 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- these tables are never used.
+DROP TABLE IF EXISTS room_names;
+DROP TABLE IF EXISTS topics;
+DROP TABLE IF EXISTS history_visibility;
+DROP TABLE IF EXISTS guest_access;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
new file mode 100644
index 0000000000..81a36a8b1d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_expiry (
+ event_id TEXT PRIMARY KEY,
+ expiry_ts BIGINT NOT NULL
+);
+
+CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
new file mode 100644
index 0000000000..5e29c1da19
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels.sql
@@ -0,0 +1,30 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- room_id and topoligical_ordering are denormalised from the events table in order to
+-- make the index work.
+CREATE TABLE IF NOT EXISTS event_labels (
+ event_id TEXT,
+ label TEXT,
+ room_id TEXT NOT NULL,
+ topological_ordering BIGINT NOT NULL,
+ PRIMARY KEY(event_id, label)
+);
+
+
+-- This index enables an event pagination looking for a particular label to index the
+-- event_labels table first, which is much quicker than scanning the events table and then
+-- filtering by label, if the label is rarely used relative to the size of the room.
+CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql
new file mode 100644
index 0000000000..5f5e0499ae
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_labels_background_update.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('event_store_labels', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql
new file mode 100644
index 0000000000..014cb3b538
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/fix_room_keys_index.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 Matrix.org Foundation CIC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- version is supposed to be part of the room keys index
+CREATE UNIQUE INDEX e2e_room_keys_with_version_idx ON e2e_room_keys(user_id, version, room_id, session_id);
+DROP INDEX IF EXISTS e2e_room_keys_idx;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql
new file mode 100644
index 0000000000..67f8b20297
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- device list needs to know which ones are "real" devices, and which ones are
+-- just used to avoid collisions
+ALTER TABLE devices ADD COLUMN hidden BOOLEAN DEFAULT FALSE;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
new file mode 100644
index 0000000000..e8b1fd35d8
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/hidden_devices_fix.sql.sqlite
@@ -0,0 +1,42 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Change the hidden column from a default value of FALSE to a default value of
+ * 0, because sqlite3 prior to 3.23.0 caused the hidden column to contain the
+ * string 'FALSE', which is truthy.
+ *
+ * Since sqlite doesn't allow us to just change the default value, we have to
+ * recreate the table, copy the data, fix the rows that have incorrect data, and
+ * replace the old table with the new table.
+ */
+
+CREATE TABLE IF NOT EXISTS devices2 (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ display_name TEXT,
+ last_seen BIGINT,
+ ip TEXT,
+ user_agent TEXT,
+ hidden BOOLEAN DEFAULT 0,
+ CONSTRAINT device_uniqueness UNIQUE (user_id, device_id)
+);
+
+INSERT INTO devices2 SELECT * FROM devices;
+
+UPDATE devices2 SET hidden = 0 WHERE hidden = 'FALSE';
+
+DROP TABLE devices;
+
+ALTER TABLE devices2 RENAME TO devices;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
new file mode 100644
index 0000000000..4f24c1405d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/nuke_empty_communities_from_db.sql
@@ -0,0 +1,29 @@
+/* Copyright 2019 Werner Sembach
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Groups/communities now get deleted when the last member leaves. This is a one time cleanup to remove old groups/communities that were already empty before that change was made.
+DELETE FROM group_attestations_remote WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_attestations_renewals WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_invites WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_roles WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_room_categories WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_rooms WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM group_summary_users WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_membership WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM local_group_updates WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
+DELETE FROM groups WHERE group_id IN (SELECT group_id FROM groups WHERE NOT EXISTS (SELECT group_id FROM group_users WHERE group_id = groups.group_id));
diff --git a/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql
new file mode 100644
index 0000000000..7be31ffebb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/public_room_list_idx.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX public_room_list_stream_network ON public_room_list_stream (appservice_id, network_id, room_id);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
new file mode 100644
index 0000000000..ea95db0ed7
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE redactions ADD COLUMN have_censored BOOL NOT NULL DEFAULT false;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
new file mode 100644
index 0000000000..49ce35d794
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor2.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE redactions ADD COLUMN received_ts BIGINT;
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('redactions_received_ts', '{}');
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('redactions_have_censored_ts_idx', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
new file mode 100644
index 0000000000..67471f3ef5
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor3_fix_update.sql.postgres
@@ -0,0 +1,25 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- There was a bug where we may have updated censored redactions as bytes,
+-- which can (somehow) cause json to be inserted hex encoded. These updates go
+-- and undoes any such hex encoded JSON.
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('event_fix_redactions_bytes_create_index', '{}');
+
+INSERT into background_updates (update_name, progress_json, depends_on)
+ VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
new file mode 100644
index 0000000000..b7550f6f4e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/redaction_censor4.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+DROP INDEX IF EXISTS redactions_have_censored;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql b/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
new file mode 100644
index 0000000000..aeb17813d3
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/remove_tombstoned_rooms_from_directory.sql
@@ -0,0 +1,18 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Now that #6232 is a thing, we can remove old rooms from the directory.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('remove_tombstoned_rooms_from_directory', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
new file mode 100644
index 0000000000..7d70dd071e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- store the current etag of backup version
+ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql
new file mode 100644
index 0000000000..92ab1f5e65
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_membership_idx.sql
@@ -0,0 +1,18 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Adds an index on room_memberships for fetching all forgotten rooms for a user
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('room_membership_forgotten_idx', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
new file mode 100644
index 0000000000..5c5fffcafb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
@@ -0,0 +1,56 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- cross-signing keys
+CREATE TABLE IF NOT EXISTS e2e_cross_signing_keys (
+ user_id TEXT NOT NULL,
+ -- the type of cross-signing key (master, user_signing, or self_signing)
+ keytype TEXT NOT NULL,
+ -- the full key information, as a json-encoded dict
+ keydata TEXT NOT NULL,
+ -- for keeping the keys in order, so that we can fetch the latest one
+ stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX e2e_cross_signing_keys_idx ON e2e_cross_signing_keys(user_id, keytype, stream_id);
+
+-- cross-signing signatures
+CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures (
+ -- user who did the signing
+ user_id TEXT NOT NULL,
+ -- key used to sign
+ key_id TEXT NOT NULL,
+ -- user who was signed
+ target_user_id TEXT NOT NULL,
+ -- device/key that was signed
+ target_device_id TEXT NOT NULL,
+ -- the actual signature
+ signature TEXT NOT NULL
+);
+
+-- replaced by the index created in signing_keys_nonunique_signatures.sql
+-- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
+
+-- stream of user signature updates
+CREATE TABLE IF NOT EXISTS user_signature_stream (
+ -- uses the same stream ID as device list stream
+ stream_id BIGINT NOT NULL,
+ -- user who did the signing
+ from_user_id TEXT NOT NULL,
+ -- list of users who were signed, as a JSON array
+ user_ids TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX user_signature_stream_idx ON user_signature_stream(stream_id);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
new file mode 100644
index 0000000000..0aa90ebf0c
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* The cross-signing signatures index should not be a unique index, because a
+ * user may upload multiple signatures for the same target user. The previous
+ * index was unique, so delete it if it's there and create a new non-unique
+ * index. */
+
+DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx; CREATE INDEX IF NOT
+EXISTS e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
new file mode 100644
index 0000000000..163529c071
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/stats_separated.sql
@@ -0,0 +1,152 @@
+/* Copyright 2018 New Vector Ltd
+ * Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+----- First clean up from previous versions of room stats.
+
+-- First remove old stats stuff
+DROP TABLE IF EXISTS room_stats;
+DROP TABLE IF EXISTS room_state;
+DROP TABLE IF EXISTS room_stats_state;
+DROP TABLE IF EXISTS user_stats;
+DROP TABLE IF EXISTS room_stats_earliest_tokens;
+DROP TABLE IF EXISTS _temp_populate_stats_position;
+DROP TABLE IF EXISTS _temp_populate_stats_rooms;
+DROP TABLE IF EXISTS stats_stream_pos;
+
+-- Unschedule old background updates if they're still scheduled
+DELETE FROM background_updates WHERE update_name IN (
+ 'populate_stats_createtables',
+ 'populate_stats_process_rooms',
+ 'populate_stats_process_users',
+ 'populate_stats_cleanup'
+);
+
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_stats_process_rooms', '{}', '');
+
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_stats_process_users', '{}', 'populate_stats_process_rooms');
+
+----- Create tables for our version of room stats.
+
+-- single-row table to track position of incremental updates
+DROP TABLE IF EXISTS stats_incremental_position;
+CREATE TABLE stats_incremental_position (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_id BIGINT NOT NULL,
+ CHECK (Lock='X')
+);
+
+-- insert a null row and make sure it is the only one.
+INSERT INTO stats_incremental_position (
+ stream_id
+) SELECT COALESCE(MAX(stream_ordering), 0) from events;
+
+-- represents PRESENT room statistics for a room
+-- only holds absolute fields
+DROP TABLE IF EXISTS room_stats_current;
+CREATE TABLE room_stats_current (
+ room_id TEXT NOT NULL PRIMARY KEY,
+
+ -- These are absolute counts
+ current_state_events INT NOT NULL,
+ joined_members INT NOT NULL,
+ invited_members INT NOT NULL,
+ left_members INT NOT NULL,
+ banned_members INT NOT NULL,
+
+ local_users_in_room INT NOT NULL,
+
+ -- The maximum delta stream position that this row takes into account.
+ completed_delta_stream_id BIGINT NOT NULL
+);
+
+
+-- represents HISTORICAL room statistics for a room
+DROP TABLE IF EXISTS room_stats_historical;
+CREATE TABLE room_stats_historical (
+ room_id TEXT NOT NULL,
+ -- These stats cover the time from (end_ts - bucket_size)...end_ts (in ms).
+ -- Note that end_ts is quantised.
+ end_ts BIGINT NOT NULL,
+ bucket_size BIGINT NOT NULL,
+
+ -- These stats are absolute counts
+ current_state_events BIGINT NOT NULL,
+ joined_members BIGINT NOT NULL,
+ invited_members BIGINT NOT NULL,
+ left_members BIGINT NOT NULL,
+ banned_members BIGINT NOT NULL,
+ local_users_in_room BIGINT NOT NULL,
+
+ -- These stats are per time slice
+ total_events BIGINT NOT NULL,
+ total_event_bytes BIGINT NOT NULL,
+
+ PRIMARY KEY (room_id, end_ts)
+);
+
+-- We use this index to speed up deletion of ancient room stats.
+CREATE INDEX room_stats_historical_end_ts ON room_stats_historical (end_ts);
+
+-- represents PRESENT statistics for a user
+-- only holds absolute fields
+DROP TABLE IF EXISTS user_stats_current;
+CREATE TABLE user_stats_current (
+ user_id TEXT NOT NULL PRIMARY KEY,
+
+ joined_rooms BIGINT NOT NULL,
+
+ -- The maximum delta stream position that this row takes into account.
+ completed_delta_stream_id BIGINT NOT NULL
+);
+
+-- represents HISTORICAL statistics for a user
+DROP TABLE IF EXISTS user_stats_historical;
+CREATE TABLE user_stats_historical (
+ user_id TEXT NOT NULL,
+ end_ts BIGINT NOT NULL,
+ bucket_size BIGINT NOT NULL,
+
+ joined_rooms BIGINT NOT NULL,
+
+ invites_sent BIGINT NOT NULL,
+ rooms_created BIGINT NOT NULL,
+ total_events BIGINT NOT NULL,
+ total_event_bytes BIGINT NOT NULL,
+
+ PRIMARY KEY (user_id, end_ts)
+);
+
+-- We use this index to speed up deletion of ancient user stats.
+CREATE INDEX user_stats_historical_end_ts ON user_stats_historical (end_ts);
+
+
+CREATE TABLE room_stats_state (
+ room_id TEXT NOT NULL,
+ name TEXT,
+ canonical_alias TEXT,
+ join_rules TEXT,
+ history_visibility TEXT,
+ encryption TEXT,
+ avatar TEXT,
+ guest_access TEXT,
+ is_federatable BOOLEAN,
+ topic TEXT
+);
+
+CREATE UNIQUE INDEX room_stats_state_room ON room_stats_state(room_id);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py
new file mode 100644
index 0000000000..1de8b54961
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/unique_user_filter_index.py
@@ -0,0 +1,52 @@
+import logging
+
+from synapse.storage.engines import PostgresEngine
+
+logger = logging.getLogger(__name__)
+
+
+"""
+This migration updates the user_filters table as follows:
+
+ - drops any (user_id, filter_id) duplicates
+ - makes the columns NON-NULLable
+ - turns the index into a UNIQUE index
+"""
+
+
+def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ select_clause = """
+ SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
+ FROM user_filters
+ """
+ else:
+ select_clause = """
+ SELECT * FROM user_filters GROUP BY user_id, filter_id
+ """
+ sql = """
+ DROP TABLE IF EXISTS user_filters_migration;
+ DROP INDEX IF EXISTS user_filters_unique;
+ CREATE TABLE user_filters_migration (
+ user_id TEXT NOT NULL,
+ filter_id BIGINT NOT NULL,
+ filter_json BYTEA NOT NULL
+ );
+ INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
+ %s;
+ CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration
+ (user_id, filter_id);
+ DROP TABLE user_filters;
+ ALTER TABLE user_filters_migration RENAME TO user_filters;
+ """ % (
+ select_clause,
+ )
+
+ if isinstance(database_engine, PostgresEngine):
+ cur.execute(sql)
+ else:
+ cur.executescript(sql)
diff --git a/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql
new file mode 100644
index 0000000000..91390c4527
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/user_external_ids.sql
@@ -0,0 +1,24 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * a table which records mappings from external auth providers to mxids
+ */
+CREATE TABLE IF NOT EXISTS user_external_ids (
+ auth_provider TEXT NOT NULL,
+ external_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ UNIQUE (auth_provider, external_id)
+);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql
new file mode 100644
index 0000000000..149f8be8b6
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/users_in_public_rooms_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 Matrix.org Foundation CIC
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this was apparently forgotten when the table was created back in delta 53.
+CREATE INDEX users_in_public_rooms_r_idx ON users_in_public_rooms(room_id);
diff --git a/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql b/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql
new file mode 100644
index 0000000000..aec06c8261
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/delete_old_current_state_events.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add background update to go and delete current state events for rooms the
+-- server is no longer in.
+--
+-- this relies on the 'membership' column of current_state_events, so make sure
+-- that's populated first!
+INSERT into background_updates (update_name, progress_json, depends_on)
+ VALUES ('delete_old_current_state_events', '{}', 'current_state_events_membership');
diff --git a/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql b/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql
new file mode 100644
index 0000000000..c3b6de2099
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/device_list_remote_cache_stale.sql
@@ -0,0 +1,25 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Records whether the server thinks that the remote users cached device lists
+-- may be out of date (e.g. if we have received a to device message from a
+-- device we don't know about).
+CREATE TABLE IF NOT EXISTS device_lists_remote_resync (
+ user_id TEXT NOT NULL,
+ added_ts BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX device_lists_remote_resync_idx ON device_lists_remote_resync (user_id);
+CREATE INDEX device_lists_remote_resync_ts_idx ON device_lists_remote_resync (added_ts);
diff --git a/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
new file mode 100644
index 0000000000..63b5acdcf7
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/local_current_membership.py
@@ -0,0 +1,98 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+# We create a new table called `local_current_membership` that stores the latest
+# membership state of local users in rooms, which helps track leaves/bans/etc
+# even if the server has left the room (and so has deleted the room from
+# `current_state_events`). This will also include outstanding invites for local
+# users for rooms the server isn't in.
+#
+# If the server isn't and hasn't been in the room then it will only include
+# outsstanding invites, and not e.g. pre-emptive bans of local users.
+#
+# If the server later rejoins a room `local_current_membership` can simply be
+# replaced with the new current state of the room (which results in the
+# equivalent behaviour as if the server had remained in the room).
+
+
+def run_upgrade(cur, database_engine, config, *args, **kwargs):
+ # We need to do the insert in `run_upgrade` section as we don't have access
+ # to `config` in `run_create`.
+
+ # This upgrade may take a bit of time for large servers (e.g. one minute for
+ # matrix.org) but means we avoid a lots of book keeping required to do it as
+ # a background update.
+
+ # We check if the `current_state_events.membership` is up to date by
+ # checking if the relevant background update has finished. If it has
+ # finished we can avoid doing a join against `room_memberships`, which
+ # speesd things up.
+ cur.execute(
+ """SELECT 1 FROM background_updates
+ WHERE update_name = 'current_state_events_membership'
+ """
+ )
+ current_state_membership_up_to_date = not bool(cur.fetchone())
+
+ # Cheekily drop and recreate indices, as that is faster.
+ cur.execute("DROP INDEX local_current_membership_idx")
+ cur.execute("DROP INDEX local_current_membership_room_idx")
+
+ if current_state_membership_up_to_date:
+ sql = """
+ INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+ SELECT c.room_id, state_key AS user_id, event_id, c.membership
+ FROM current_state_events AS c
+ WHERE type = 'm.room.member' AND c.membership IS NOT NULL AND state_key LIKE ?
+ """
+ else:
+ # We can't rely on the membership column, so we need to join against
+ # `room_memberships`.
+ sql = """
+ INSERT INTO local_current_membership (room_id, user_id, event_id, membership)
+ SELECT c.room_id, state_key AS user_id, event_id, r.membership
+ FROM current_state_events AS c
+ INNER JOIN room_memberships AS r USING (event_id)
+ WHERE type = 'm.room.member' AND state_key LIKE ?
+ """
+ sql = database_engine.convert_param_style(sql)
+ cur.execute(sql, ("%:" + config.server_name,))
+
+ cur.execute(
+ "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+ )
+ cur.execute(
+ "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+ )
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ cur.execute(
+ """
+ CREATE TABLE local_current_membership (
+ room_id TEXT NOT NULL,
+ user_id TEXT NOT NULL,
+ event_id TEXT NOT NULL,
+ membership TEXT NOT NULL
+ )"""
+ )
+
+ cur.execute(
+ "CREATE UNIQUE INDEX local_current_membership_idx ON local_current_membership(user_id, room_id)"
+ )
+ cur.execute(
+ "CREATE INDEX local_current_membership_room_idx ON local_current_membership(room_id)"
+ )
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql
new file mode 100644
index 0000000000..352a66f5b0
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column.sql
@@ -0,0 +1,24 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- We want to start storing the room version independently of
+-- `current_state_events` so that we can delete stale entries from it without
+-- losing the information.
+ALTER TABLE rooms ADD COLUMN room_version TEXT;
+
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('add_rooms_room_version_column', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres
new file mode 100644
index 0000000000..c601cff6de
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.postgres
@@ -0,0 +1,35 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- when we first added the room_version column, it was populated via a background
+-- update. We now need it to be populated before synapse starts, so we populate
+-- any remaining rows with a NULL room version now. For servers which have completed
+-- the background update, this will be pretty quick.
+
+-- the following query will set room_version to NULL if no create event is found for
+-- the room in current_state_events, and will set it to '1' if a create event with no
+-- room_version is found.
+
+UPDATE rooms SET room_version=(
+ SELECT COALESCE(json::json->'content'->>'room_version','1')
+ FROM current_state_events cse INNER JOIN event_json ej USING (event_id)
+ WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key=''
+) WHERE rooms.room_version IS NULL;
+
+-- we still allow the background update to complete: it has the useful side-effect of
+-- populating `rooms` with any missing rooms (based on the current_state_events table).
+
+-- see also rooms_version_column_2.sql.sqlite which has a copy of the above query, using
+-- sqlite syntax for the json extraction.
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite
new file mode 100644
index 0000000000..335c6f2074
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_2.sql.sqlite
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- see rooms_version_column_2.sql.postgres for details of what's going on here.
+
+UPDATE rooms SET room_version=(
+ SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1')
+ FROM current_state_events cse INNER JOIN event_json ej USING (event_id)
+ WHERE cse.room_id=rooms.room_id AND cse.type='m.room.create' AND cse.state_key=''
+) WHERE rooms.room_version IS NULL;
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
new file mode 100644
index 0000000000..92aaadde0d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.postgres
@@ -0,0 +1,39 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- When we first added the room_version column to the rooms table, it was populated from
+-- the current_state_events table. However, there was an issue causing a background
+-- update to clean up the current_state_events table for rooms where the server is no
+-- longer participating, before that column could be populated. Therefore, some rooms had
+-- a NULL room_version.
+
+-- The rooms_version_column_2.sql.* delta files were introduced to make the populating
+-- synchronous instead of running it in a background update, which fixed this issue.
+-- However, all of the instances of Synapse installed or updated in the meantime got
+-- their rooms table corrupted with NULL room_versions.
+
+-- This query fishes out the room versions from the create event using the state_events
+-- table instead of the current_state_events one, as the former still have all of the
+-- create events.
+
+UPDATE rooms SET room_version=(
+ SELECT COALESCE(json::json->'content'->>'room_version','1')
+ FROM state_events se INNER JOIN event_json ej USING (event_id)
+ WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
+ LIMIT 1
+) WHERE rooms.room_version IS NULL;
+
+-- see also rooms_version_column_3.sql.sqlite which has a copy of the above query, using
+-- sqlite syntax for the json extraction.
diff --git a/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
new file mode 100644
index 0000000000..e19dab97cb
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/57/rooms_version_column_3.sql.sqlite
@@ -0,0 +1,23 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- see rooms_version_column_3.sql.postgres for details of what's going on here.
+
+UPDATE rooms SET room_version=(
+ SELECT COALESCE(json_extract(ej.json, '$.content.room_version'), '1')
+ FROM state_events se INNER JOIN event_json ej USING (event_id)
+ WHERE se.room_id=rooms.room_id AND se.type='m.room.create' AND se.state_key=''
+ LIMIT 1
+) WHERE rooms.room_version IS NULL;
diff --git a/synapse/storage/schema/full_schemas/16/application_services.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql
index 883fcd10b2..883fcd10b2 100644
--- a/synapse/storage/schema/full_schemas/16/application_services.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/application_services.sql
diff --git a/synapse/storage/schema/full_schemas/16/event_edges.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql
index 10ce2aa7a0..10ce2aa7a0 100644
--- a/synapse/storage/schema/full_schemas/16/event_edges.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_edges.sql
diff --git a/synapse/storage/schema/full_schemas/16/event_signatures.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql
index 95826da431..95826da431 100644
--- a/synapse/storage/schema/full_schemas/16/event_signatures.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/event_signatures.sql
diff --git a/synapse/storage/schema/full_schemas/16/im.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql
index a1a2aa8e5b..a1a2aa8e5b 100644
--- a/synapse/storage/schema/full_schemas/16/im.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/im.sql
diff --git a/synapse/storage/schema/full_schemas/16/keys.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql
index 11cdffdbb3..11cdffdbb3 100644
--- a/synapse/storage/schema/full_schemas/16/keys.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/keys.sql
diff --git a/synapse/storage/schema/full_schemas/16/media_repository.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql
index 8f3759bb2a..8f3759bb2a 100644
--- a/synapse/storage/schema/full_schemas/16/media_repository.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/media_repository.sql
diff --git a/synapse/storage/schema/full_schemas/16/presence.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql
index 01d2d8f833..01d2d8f833 100644
--- a/synapse/storage/schema/full_schemas/16/presence.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/presence.sql
diff --git a/synapse/storage/schema/full_schemas/16/profiles.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql
index c04f4747d9..c04f4747d9 100644
--- a/synapse/storage/schema/full_schemas/16/profiles.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/profiles.sql
diff --git a/synapse/storage/schema/full_schemas/16/push.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql
index e44465cf45..e44465cf45 100644
--- a/synapse/storage/schema/full_schemas/16/push.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/push.sql
diff --git a/synapse/storage/schema/full_schemas/16/redactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql
index 318f0d9aa5..318f0d9aa5 100644
--- a/synapse/storage/schema/full_schemas/16/redactions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/redactions.sql
diff --git a/synapse/storage/schema/full_schemas/16/room_aliases.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql
index d47da3b12f..d47da3b12f 100644
--- a/synapse/storage/schema/full_schemas/16/room_aliases.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/room_aliases.sql
diff --git a/synapse/storage/schema/full_schemas/16/state.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql
index 96391a8f0e..96391a8f0e 100644
--- a/synapse/storage/schema/full_schemas/16/state.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/state.sql
diff --git a/synapse/storage/schema/full_schemas/16/transactions.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql
index 17e67bedac..17e67bedac 100644
--- a/synapse/storage/schema/full_schemas/16/transactions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/transactions.sql
diff --git a/synapse/storage/schema/full_schemas/16/users.sql b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql
index f013aa8b18..f013aa8b18 100644
--- a/synapse/storage/schema/full_schemas/16/users.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/16/users.sql
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.postgres b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
index 01a2b0e024..20c5af2eb7 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.postgres
@@ -70,15 +70,6 @@ CREATE TABLE appservice_stream_position (
);
-
-CREATE TABLE background_updates (
- update_name text NOT NULL,
- progress_json text NOT NULL,
- depends_on text
-);
-
-
-
CREATE TABLE blocked_rooms (
room_id text NOT NULL,
user_id text NOT NULL
@@ -993,40 +984,6 @@ CREATE TABLE state_events (
-CREATE TABLE state_group_edges (
- state_group bigint NOT NULL,
- prev_state_group bigint NOT NULL
-);
-
-
-
-CREATE SEQUENCE state_group_id_seq
- START WITH 1
- INCREMENT BY 1
- NO MINVALUE
- NO MAXVALUE
- CACHE 1;
-
-
-
-CREATE TABLE state_groups (
- id bigint NOT NULL,
- room_id text NOT NULL,
- event_id text NOT NULL
-);
-
-
-
-CREATE TABLE state_groups_state (
- state_group bigint NOT NULL,
- room_id text NOT NULL,
- type text NOT NULL,
- state_key text NOT NULL,
- event_id text NOT NULL
-);
-
-
-
CREATE TABLE stats_stream_pos (
lock character(1) DEFAULT 'X'::bpchar NOT NULL,
stream_id bigint,
@@ -1211,11 +1168,6 @@ ALTER TABLE ONLY appservice_stream_position
-ALTER TABLE ONLY background_updates
- ADD CONSTRAINT background_updates_uniqueness UNIQUE (update_name);
-
-
-
ALTER TABLE ONLY current_state_events
ADD CONSTRAINT current_state_events_event_id_key UNIQUE (event_id);
@@ -1505,12 +1457,6 @@ ALTER TABLE ONLY state_events
ADD CONSTRAINT state_events_event_id_key UNIQUE (event_id);
-
-ALTER TABLE ONLY state_groups
- ADD CONSTRAINT state_groups_pkey PRIMARY KEY (id);
-
-
-
ALTER TABLE ONLY stats_stream_pos
ADD CONSTRAINT stats_stream_pos_lock_key UNIQUE (lock);
@@ -1955,18 +1901,6 @@ CREATE UNIQUE INDEX room_stats_room_ts ON room_stats USING btree (room_id, ts);
-CREATE INDEX state_group_edges_idx ON state_group_edges USING btree (state_group);
-
-
-
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges USING btree (prev_state_group);
-
-
-
-CREATE INDEX state_groups_state_type_idx ON state_groups_state USING btree (state_group, type, state_key);
-
-
-
CREATE INDEX stream_ordering_to_exterm_idx ON stream_ordering_to_exterm USING btree (stream_ordering);
@@ -2060,6 +1994,3 @@ CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_room
CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms USING btree (user_id, other_user_id, room_id);
-
-
-
diff --git a/synapse/storage/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
index f1a71627f0..e28ec3fa45 100644
--- a/synapse/storage/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/full.sql.sqlite
@@ -42,8 +42,6 @@ CREATE INDEX ev_edges_id ON event_edges(event_id);
CREATE INDEX ev_edges_prev_id ON event_edges(prev_event_id);
CREATE TABLE room_depth( room_id TEXT NOT NULL, min_depth INTEGER NOT NULL, UNIQUE (room_id) );
CREATE INDEX room_depth_room ON room_depth(room_id);
-CREATE TABLE state_groups( id BIGINT PRIMARY KEY, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
-CREATE TABLE state_groups_state( state_group BIGINT NOT NULL, room_id TEXT NOT NULL, type TEXT NOT NULL, state_key TEXT NOT NULL, event_id TEXT NOT NULL );
CREATE TABLE event_to_state_groups( event_id TEXT NOT NULL, state_group BIGINT NOT NULL, UNIQUE (event_id) );
CREATE TABLE local_media_repository ( media_id TEXT, media_type TEXT, media_length INTEGER, created_ts BIGINT, upload_name TEXT, user_id TEXT, quarantined_by TEXT, url_cache TEXT, last_access_ts BIGINT, UNIQUE (media_id) );
CREATE TABLE local_media_repository_thumbnails ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type ) );
@@ -67,7 +65,6 @@ CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id );
CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id );
CREATE TABLE IF NOT EXISTS "user_threepids" ( user_id TEXT NOT NULL, medium TEXT NOT NULL, address TEXT NOT NULL, validated_at BIGINT NOT NULL, added_at BIGINT NOT NULL, CONSTRAINT medium_address UNIQUE (medium, address) );
CREATE INDEX user_threepids_user_id ON user_threepids(user_id);
-CREATE TABLE background_updates( update_name TEXT NOT NULL, progress_json TEXT NOT NULL, depends_on TEXT, CONSTRAINT background_updates_uniqueness UNIQUE (update_name) );
CREATE VIRTUAL TABLE event_search USING fts4 ( event_id, room_id, sender, key, value )
/* event_search(event_id,room_id,sender,"key",value) */;
CREATE TABLE IF NOT EXISTS 'event_search_content'(docid INTEGER PRIMARY KEY, 'c0event_id', 'c1room_id', 'c2sender', 'c3key', 'c4value');
@@ -121,9 +118,6 @@ CREATE TABLE device_max_stream_id ( stream_id BIGINT NOT NULL );
CREATE TABLE public_room_list_stream ( stream_id BIGINT NOT NULL, room_id TEXT NOT NULL, visibility BOOLEAN NOT NULL , appservice_id TEXT, network_id TEXT);
CREATE INDEX public_room_list_stream_idx on public_room_list_stream( stream_id );
CREATE INDEX public_room_list_stream_rm_idx on public_room_list_stream( room_id, stream_id );
-CREATE TABLE state_group_edges( state_group BIGINT NOT NULL, prev_state_group BIGINT NOT NULL );
-CREATE INDEX state_group_edges_idx ON state_group_edges(state_group);
-CREATE INDEX state_group_edges_prev_idx ON state_group_edges(prev_state_group);
CREATE TABLE stream_ordering_to_exterm ( stream_ordering BIGINT NOT NULL, room_id TEXT NOT NULL, event_id TEXT NOT NULL );
CREATE INDEX stream_ordering_to_exterm_idx on stream_ordering_to_exterm( stream_ordering );
CREATE INDEX stream_ordering_to_exterm_rm_idx on stream_ordering_to_exterm( room_id, stream_ordering );
@@ -257,6 +251,5 @@ CREATE INDEX user_ips_last_seen_only ON user_ips (last_seen);
CREATE INDEX users_creation_ts ON users (creation_ts);
CREATE INDEX event_to_state_groups_sg_index ON event_to_state_groups (state_group);
CREATE UNIQUE INDEX device_lists_remote_cache_unique_id ON device_lists_remote_cache (user_id, device_id);
-CREATE INDEX state_groups_state_type_idx ON state_groups_state(state_group, type, state_key);
CREATE UNIQUE INDEX device_lists_remote_extremeties_unique_idx ON device_lists_remote_extremeties (user_id);
CREATE UNIQUE INDEX user_ips_user_token_ip_unique_index ON user_ips (user_id, access_token, ip);
diff --git a/synapse/storage/schema/full_schemas/54/stream_positions.sql b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
index c265fd20e2..91d21b2921 100644
--- a/synapse/storage/schema/full_schemas/54/stream_positions.sql
+++ b/synapse/storage/data_stores/main/schema/full_schemas/54/stream_positions.sql
@@ -5,3 +5,4 @@ INSERT INTO federation_stream_position (type, stream_id) SELECT 'events', coales
INSERT INTO user_directory_stream_pos (stream_id) VALUES (0);
INSERT INTO stats_stream_pos (stream_id) VALUES (0);
INSERT INTO event_push_summary_stream_ordering (stream_ordering) VALUES (0);
+-- device_max_stream_id is handled separately in 56/device_stream_id_insert.sql
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..c00f287190
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,21 @@
+# Synapse Database Schemas
+
+These schemas are used as a basis to create brand new Synapse databases, on both
+SQLite3 and Postgres.
+
+## Building full schema dumps
+
+If you want to recreate these schemas, they need to be made from a database that
+has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce new
+`full.sql.postgres ` and `full.sql.sqlite` files.
+
+Ensure postgres is installed and your user has the ability to run bash commands
+such as `createdb`, then call
+
+ ./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+
+There are currently two folders with full-schema snapshots. `16` is a snapshot
+from 2015, for historical reference. The other contains the most recent full
+schema snapshot.
diff --git a/synapse/storage/search.py b/synapse/storage/data_stores/main/search.py
index ff49eaae02..47ebb8a214 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -24,35 +24,36 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from .background_updates import BackgroundUpdateStore
-
logger = logging.getLogger(__name__)
SearchEntry = namedtuple(
- 'SearchEntry',
- ['key', 'value', 'event_id', 'room_id', 'stream_ordering', 'origin_server_ts'],
+ "SearchEntry",
+ ["key", "value", "event_id", "room_id", "stream_ordering", "origin_server_ts"],
)
-class SearchStore(BackgroundUpdateStore):
+class SearchBackgroundUpdateStore(SQLBaseStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
- def __init__(self, db_conn, hs):
- super(SearchStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
if not hs.config.enable_search:
return
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
)
@@ -61,9 +62,11 @@ class SearchStore(BackgroundUpdateStore):
# a GIN index. However, it's possible that some people might still have
# the background update queued, so we register a handler to clear the
# background update.
- self.register_noop_background_update(self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME)
+ self.db.updates.register_noop_background_update(
+ self.EVENT_SEARCH_USE_GIST_POSTGRES_NAME
+ )
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
)
@@ -93,7 +96,7 @@ class SearchStore(BackgroundUpdateStore):
# store_search_entries_txn with a generator function, but that
# would mean having two cursors open on the database at once.
# Instead we just build a list of results.
- rows = self.cursor_to_dict(txn)
+ rows = self.db.cursor_to_dict(txn)
if not rows:
return 0
@@ -153,20 +156,20 @@ class SearchStore(BackgroundUpdateStore):
"rows_inserted": rows_inserted + len(event_search_rows),
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.EVENT_SEARCH_UPDATE_NAME, progress
)
return len(event_search_rows)
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
)
if not result:
- yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
+ yield self.db.updates._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def _background_reindex_gin_search(self, progress, batch_size):
@@ -196,7 +199,7 @@ class SearchStore(BackgroundUpdateStore):
" ON event_search USING GIN (vector)"
)
except psycopg2.ProgrammingError as e:
- logger.warn(
+ logger.warning(
"Ignoring error %r when trying to switch from GIST to GIN", e
)
@@ -206,17 +209,19 @@ class SearchStore(BackgroundUpdateStore):
conn.set_session(autocommit=False)
if isinstance(self.database_engine, PostgresEngine):
- yield self.runWithConnection(create_index)
+ yield self.db.runWithConnection(create_index)
- yield self._end_background_update(self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME)
- defer.returnValue(1)
+ yield self.db.updates._end_background_update(
+ self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
+ )
+ return 1
@defer.inlineCallbacks
def _background_reindex_search_order(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- have_added_index = progress['have_added_indexes']
+ have_added_index = progress["have_added_indexes"]
if not have_added_index:
@@ -237,14 +242,14 @@ class SearchStore(BackgroundUpdateStore):
)
conn.set_session(autocommit=False)
- yield self.runWithConnection(create_index)
+ yield self.db.runWithConnection(create_index)
pg = dict(progress)
pg["have_added_indexes"] = True
- yield self.runInteraction(
+ yield self.db.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
- self._background_update_progress_txn,
+ self.db.updates._background_update_progress_txn,
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
pg,
)
@@ -274,43 +279,22 @@ class SearchStore(BackgroundUpdateStore):
"have_added_indexes": True,
}
- self._background_update_progress_txn(
+ self.db.updates._background_update_progress_txn(
txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
)
return len(rows), True
- num_rows, finished = yield self.runInteraction(
+ num_rows, finished = yield self.db.runInteraction(
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
)
if not finished:
- yield self._end_background_update(self.EVENT_SEARCH_ORDER_UPDATE_NAME)
-
- defer.returnValue(num_rows)
-
- def store_event_search_txn(self, txn, event, key, value):
- """Add event to the search table
+ yield self.db.updates._end_background_update(
+ self.EVENT_SEARCH_ORDER_UPDATE_NAME
+ )
- Args:
- txn (cursor):
- event (EventBase):
- key (str):
- value (str):
- """
- self.store_search_entries_txn(
- txn,
- (
- SearchEntry(
- key=key,
- value=value,
- event_id=event.event_id,
- room_id=event.room_id,
- stream_ordering=event.internal_metadata.stream_ordering,
- origin_server_ts=event.origin_server_ts,
- ),
- ),
- )
+ return num_rows
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
@@ -341,29 +325,7 @@ class SearchStore(BackgroundUpdateStore):
for entry in entries
)
- # inserts to a GIN index are normally batched up into a pending
- # list, and then all committed together once the list gets to a
- # certain size. The trouble with that is that postgres (pre-9.5)
- # uses work_mem to determine the length of the list, and work_mem
- # is typically very large.
- #
- # We therefore reduce work_mem while we do the insert.
- #
- # (postgres 9.5 uses the separate gin_pending_list_limit setting,
- # so doesn't suffer the same problem, but changing work_mem will
- # be harmless)
- #
- # Note that we don't need to worry about restoring it on
- # exception, because exceptions will cause the transaction to be
- # rolled back, including the effects of the SET command.
- #
- # Also: we use SET rather than SET LOCAL because there's lots of
- # other stuff going on in this transaction, which want to have the
- # normal work_mem setting.
-
- txn.execute("SET work_mem='256kB'")
txn.executemany(sql, args)
- txn.execute("RESET work_mem")
elif isinstance(self.database_engine, Sqlite3Engine):
sql = (
@@ -380,6 +342,34 @@ class SearchStore(BackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
+
+class SearchStore(SearchBackgroundUpdateStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(SearchStore, self).__init__(database, db_conn, hs)
+
+ def store_event_search_txn(self, txn, event, key, value):
+ """Add event to the search table
+
+ Args:
+ txn (cursor):
+ event (EventBase):
+ key (str):
+ value (str):
+ """
+ self.store_search_entries_txn(
+ txn,
+ (
+ SearchEntry(
+ key=key,
+ value=value,
+ event_id=event.event_id,
+ room_id=event.room_id,
+ stream_ordering=event.internal_metadata.stream_ordering,
+ origin_server_ts=event.origin_server_ts,
+ ),
+ ),
+ )
+
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
@@ -395,15 +385,17 @@ class SearchStore(BackgroundUpdateStore):
"""
clauses = []
- search_query = search_query = _parse_query(self.database_engine, search_term)
+ search_query = _parse_query(self.database_engine, search_term)
args = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
- clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
- args.extend(room_ids)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+ clauses = [clause]
local_clauses = []
for key in keys:
@@ -456,11 +448,18 @@ class SearchStore(BackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args)
+ results = yield self.db.execute(
+ "search_msgs", self.db.cursor_to_dict, sql, *args
+ )
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
@@ -470,23 +469,21 @@ class SearchStore(BackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = yield self._execute(
- "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+ count_results = yield self.db.execute(
+ "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
- defer.returnValue(
- {
- "results": [
- {"event": event_map[r["event_id"]], "rank": r["rank"]}
- for r in results
- if r["event_id"] in event_map
- ],
- "highlights": highlights,
- "count": count,
- }
- )
+ return {
+ "results": [
+ {"event": event_map[r["event_id"]], "rank": r["rank"]}
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ "count": count,
+ }
@defer.inlineCallbacks
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
@@ -504,15 +501,17 @@ class SearchStore(BackgroundUpdateStore):
"""
clauses = []
- search_query = search_query = _parse_query(self.database_engine, search_term)
+ search_query = _parse_query(self.database_engine, search_term)
args = []
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
- clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
- args.extend(room_ids)
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", room_ids
+ )
+ clauses = [clause]
local_clauses = []
for key in keys:
@@ -601,11 +600,18 @@ class SearchStore(BackgroundUpdateStore):
args.append(limit)
- results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args)
+ results = yield self.db.execute(
+ "search_rooms", self.db.cursor_to_dict, sql, *args
+ )
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self.get_events_as_list([r["event_id"] for r in results])
+ # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+ # search results (which is a data leak)
+ events = yield self.get_events_as_list(
+ [r["event_id"] for r in results],
+ redact_behaviour=EventRedactBehaviour.BLOCK,
+ )
event_map = {ev.event_id: ev for ev in events}
@@ -615,28 +621,26 @@ class SearchStore(BackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = yield self._execute(
- "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
+ count_results = yield self.db.execute(
+ "search_rooms_count", self.db.cursor_to_dict, count_sql, *count_args
)
count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
- defer.returnValue(
- {
- "results": [
- {
- "event": event_map[r["event_id"]],
- "rank": r["rank"],
- "pagination_token": "%s,%s"
- % (r["origin_server_ts"], r["stream_ordering"]),
- }
- for r in results
- if r["event_id"] in event_map
- ],
- "highlights": highlights,
- "count": count,
- }
- )
+ return {
+ "results": [
+ {
+ "event": event_map[r["event_id"]],
+ "rank": r["rank"],
+ "pagination_token": "%s,%s"
+ % (r["origin_server_ts"], r["stream_ordering"]),
+ }
+ for r in results
+ if r["event_id"] in event_map
+ ],
+ "highlights": highlights,
+ "count": count,
+ }
def _find_highlights_in_postgres(self, search_query, events):
"""Given a list of events and a search term, return a list of words
@@ -689,7 +693,7 @@ class SearchStore(BackgroundUpdateStore):
)
)
txn.execute(query, (value, search_query))
- headline, = txn.fetchall()[0]
+ (headline,) = txn.fetchall()[0]
# Now we need to pick the possible highlights out of the haedline
# result.
@@ -703,7 +707,7 @@ class SearchStore(BackgroundUpdateStore):
return highlight_words
- return self.runInteraction("_find_highlights", f)
+ return self.db.runInteraction("_find_highlights", f)
def _to_postgres_options(options_dict):
diff --git a/synapse/storage/signatures.py b/synapse/storage/data_stores/main/signatures.py
index 6bd81e84ad..563216b63c 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -20,10 +20,9 @@ from unpaddedbase64 import encode_base64
from twisted.internet import defer
from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
-from ._base import SQLBaseStore
-
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
@@ -49,7 +48,7 @@ class SignatureWorkerStore(SQLBaseStore):
for event_id in event_ids
}
- return self.runInteraction("get_event_reference_hashes", f)
+ return self.db.runInteraction("get_event_reference_hashes", f)
@defer.inlineCallbacks
def add_event_hashes(self, event_ids):
@@ -59,7 +58,7 @@ class SignatureWorkerStore(SQLBaseStore):
for e_id, h in hashes.items()
}
- defer.returnValue(list(hashes.items()))
+ return list(hashes.items())
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
@@ -99,4 +98,4 @@ class SignatureStore(SignatureWorkerStore):
}
)
- self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
+ self.db.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
new file mode 100644
index 0000000000..3a3b9a8e72
--- /dev/null
+++ b/synapse/storage/data_stores/main/state.py
@@ -0,0 +1,505 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import collections.abc
+import logging
+from collections import namedtuple
+from typing import Iterable, Tuple
+
+from six import iteritems
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.data_stores.main.roommember import RoomMemberWorkerStore
+from synapse.storage.database import Database
+from synapse.storage.state import StateFilter
+from synapse.util.caches import intern_string
+from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.stringutils import to_ascii
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class _GetStateGroupDelta(
+ namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
+ """Return type of get_state_group_delta that implements __len__, which lets
+ us use the itrable flag when caching
+ """
+
+ __slots__ = []
+
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
+
+
+# this inherits from EventsWorkerStore because it calls self.get_events
+class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
+ """The parts of StateGroupStore that can be called from workers.
+ """
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
+
+ async def get_room_version(self, room_id: str) -> RoomVersion:
+ """Get the room_version of a given room
+
+ Raises:
+ NotFoundError: if the room is unknown
+
+ UnsupportedRoomVersionError: if the room uses an unknown room version.
+ Typically this happens if support for the room's version has been
+ removed from Synapse.
+ """
+ room_version_id = await self.get_room_version_id(room_id)
+ v = KNOWN_ROOM_VERSIONS.get(room_version_id)
+
+ if not v:
+ raise UnsupportedRoomVersionError(
+ "Room %s uses a room version %s which is no longer supported"
+ % (room_id, room_version_id)
+ )
+
+ return v
+
+ @cached(max_entries=10000)
+ async def get_room_version_id(self, room_id: str) -> str:
+ """Get the room_version of a given room
+
+ Raises:
+ NotFoundError: if the room is unknown
+ """
+
+ # First we try looking up room version from the database, but for old
+ # rooms we might not have added the room version to it yet so we fall
+ # back to previous behaviour and look in current state events.
+
+ # We really should have an entry in the rooms table for every room we
+ # care about, but let's be a bit paranoid (at least while the background
+ # update is happening) to avoid breaking existing rooms.
+ version = await self.db.simple_select_one_onecol(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ retcol="room_version",
+ desc="get_room_version",
+ allow_none=True,
+ )
+
+ if version is not None:
+ return version
+
+ # Retrieve the room's create event
+ create_event = await self.get_create_event_for_room(room_id)
+ return create_event.content.get("room_version", "1")
+
+ @defer.inlineCallbacks
+ def get_room_predecessor(self, room_id):
+ """Get the predecessor of an upgraded room if it exists.
+ Otherwise return None.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[dict|None]: A dictionary containing the structure of the predecessor
+ field from the room's create event. The structure is subject to other servers,
+ but it is expected to be:
+ * room_id (str): The room ID of the predecessor room
+ * event_id (str): The ID of the tombstone event in the predecessor room
+
+ None if a predecessor key is not found, or is not a dictionary.
+
+ Raises:
+ NotFoundError if the given room is unknown
+ """
+ # Retrieve the room's create event
+ create_event = yield self.get_create_event_for_room(room_id)
+
+ # Retrieve the predecessor key of the create event
+ predecessor = create_event.content.get("predecessor", None)
+
+ # Ensure the key is a dictionary
+ if not isinstance(predecessor, collections.abc.Mapping):
+ return None
+
+ return predecessor
+
+ @defer.inlineCallbacks
+ def get_create_event_for_room(self, room_id):
+ """Get the create state event for a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[EventBase]: The room creation event.
+
+ Raises:
+ NotFoundError if the room is unknown
+ """
+ state_ids = yield self.get_current_state_ids(room_id)
+ create_id = state_ids.get((EventTypes.Create, ""))
+
+ # If we can't find the create event, assume we've hit a dead end
+ if not create_id:
+ raise NotFoundError("Unknown room %s" % (room_id,))
+
+ # Retrieve the room's create event and return
+ create_event = yield self.get_event(create_id)
+ return create_event
+
+ @cached(max_entries=100000, iterable=True)
+ def get_current_state_ids(self, room_id):
+ """Get the current state event ids for a room based on the
+ current_state_events table.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ deferred: dict of (type, state_key) -> event_id
+ """
+
+ def _get_current_state_ids_txn(txn):
+ txn.execute(
+ """SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ?
+ """,
+ (room_id,),
+ )
+
+ return {
+ (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
+ }
+
+ return self.db.runInteraction(
+ "get_current_state_ids", _get_current_state_ids_txn
+ )
+
+ # FIXME: how should this be cached?
+ def get_filtered_current_state_ids(
+ self, room_id: str, state_filter: StateFilter = StateFilter.all()
+ ):
+ """Get the current state event of a given type for a room based on the
+ current_state_events table. This may not be as up-to-date as the result
+ of doing a fresh state resolution as per state_handler.get_current_state
+
+ Args:
+ room_id
+ state_filter: The state filter used to fetch state
+ from the database.
+
+ Returns:
+ defer.Deferred[StateMap[str]]: Map from type/state_key to event ID.
+ """
+
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ if not where_clause:
+ # We delegate to the cached version
+ return self.get_current_state_ids(room_id)
+
+ def _get_filtered_current_state_ids_txn(txn):
+ results = {}
+ sql = """
+ SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ?
+ """
+
+ if where_clause:
+ sql += " AND (%s)" % (where_clause,)
+
+ args = [room_id]
+ args.extend(where_args)
+ txn.execute(sql, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (intern_string(typ), intern_string(state_key))
+ results[key] = event_id
+
+ return results
+
+ return self.db.runInteraction(
+ "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
+ )
+
+ @defer.inlineCallbacks
+ def get_canonical_alias_for_room(self, room_id):
+ """Get canonical alias for room, if any
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[str|None]: The canonical alias, if any
+ """
+
+ state = yield self.get_filtered_current_state_ids(
+ room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
+ )
+
+ event_id = state.get((EventTypes.CanonicalAlias, ""))
+ if not event_id:
+ return
+
+ event = yield self.get_event(event_id, allow_none=True)
+ if not event:
+ return
+
+ return event.content.get("canonical_alias")
+
+ @cached(max_entries=50000)
+ def _get_state_group_for_event(self, event_id):
+ return self.db.simple_select_one_onecol(
+ table="event_to_state_groups",
+ keyvalues={"event_id": event_id},
+ retcol="state_group",
+ allow_none=True,
+ desc="_get_state_group_for_event",
+ )
+
+ @cachedList(
+ cached_method_name="_get_state_group_for_event",
+ list_name="event_ids",
+ num_args=1,
+ inlineCallbacks=True,
+ )
+ def _get_state_group_for_events(self, event_ids):
+ """Returns mapping event_id -> state_group
+ """
+ rows = yield self.db.simple_select_many_batch(
+ table="event_to_state_groups",
+ column="event_id",
+ iterable=event_ids,
+ keyvalues={},
+ retcols=("event_id", "state_group"),
+ desc="_get_state_group_for_events",
+ )
+
+ return {row["event_id"]: row["state_group"] for row in rows}
+
+ @defer.inlineCallbacks
+ def get_referenced_state_groups(self, state_groups):
+ """Check if the state groups are referenced by events.
+
+ Args:
+ state_groups (Iterable[int])
+
+ Returns:
+ Deferred[set[int]]: The subset of state groups that are
+ referenced.
+ """
+
+ rows = yield self.db.simple_select_many_batch(
+ table="event_to_state_groups",
+ column="state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("DISTINCT state_group",),
+ desc="get_referenced_state_groups",
+ )
+
+ return {row["state_group"] for row in rows}
+
+
+class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
+
+ CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+ EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
+ DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+
+ self.server_name = hs.hostname
+
+ self.db.updates.register_background_index_update(
+ self.CURRENT_STATE_INDEX_UPDATE_NAME,
+ index_name="current_state_events_member_index",
+ table="current_state_events",
+ columns=["state_key"],
+ where_clause="type='m.room.member'",
+ )
+ self.db.updates.register_background_index_update(
+ self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
+ index_name="event_to_state_groups_sg_index",
+ table="event_to_state_groups",
+ columns=["state_group"],
+ )
+ self.db.updates.register_background_update_handler(
+ self.DELETE_CURRENT_STATE_UPDATE_NAME, self._background_remove_left_rooms,
+ )
+
+ async def _background_remove_left_rooms(self, progress, batch_size):
+ """Background update to delete rows from `current_state_events` and
+ `event_forward_extremities` tables of rooms that the server is no
+ longer joined to.
+ """
+
+ last_room_id = progress.get("last_room_id", "")
+
+ def _background_remove_left_rooms_txn(txn):
+ sql = """
+ SELECT DISTINCT room_id FROM current_state_events
+ WHERE room_id > ? ORDER BY room_id LIMIT ?
+ """
+
+ txn.execute(sql, (last_room_id, batch_size))
+ room_ids = [row[0] for row in txn]
+ if not room_ids:
+ return True, set()
+
+ sql = """
+ SELECT room_id
+ FROM current_state_events
+ WHERE
+ room_id > ? AND room_id <= ?
+ AND type = 'm.room.member'
+ AND membership = 'join'
+ AND state_key LIKE ?
+ GROUP BY room_id
+ """
+
+ txn.execute(sql, (last_room_id, room_ids[-1], "%:" + self.server_name))
+
+ joined_room_ids = {row[0] for row in txn}
+
+ left_rooms = set(room_ids) - joined_room_ids
+
+ logger.info("Deleting current state left rooms: %r", left_rooms)
+
+ # First we get all users that we still think were joined to the
+ # room. This is so that we can mark those device lists as
+ # potentially stale, since there may have been a period where the
+ # server didn't share a room with the remote user and therefore may
+ # have missed any device updates.
+ rows = self.db.simple_select_many_txn(
+ txn,
+ table="current_state_events",
+ column="room_id",
+ iterable=left_rooms,
+ keyvalues={"type": EventTypes.Member, "membership": Membership.JOIN},
+ retcols=("state_key",),
+ )
+
+ potentially_left_users = {row["state_key"] for row in rows}
+
+ # Now lets actually delete the rooms from the DB.
+ self.db.simple_delete_many_txn(
+ txn,
+ table="current_state_events",
+ column="room_id",
+ iterable=left_rooms,
+ keyvalues={},
+ )
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="event_forward_extremities",
+ column="room_id",
+ iterable=left_rooms,
+ keyvalues={},
+ )
+
+ self.db.updates._background_update_progress_txn(
+ txn,
+ self.DELETE_CURRENT_STATE_UPDATE_NAME,
+ {"last_room_id": room_ids[-1]},
+ )
+
+ return False, potentially_left_users
+
+ finished, potentially_left_users = await self.db.runInteraction(
+ "_background_remove_left_rooms", _background_remove_left_rooms_txn
+ )
+
+ if finished:
+ await self.db.updates._end_background_update(
+ self.DELETE_CURRENT_STATE_UPDATE_NAME
+ )
+
+ # Now go and check if we still share a room with the remote users in
+ # the deleted rooms. If not mark their device lists as stale.
+ joined_users = await self.get_users_server_still_shares_room_with(
+ potentially_left_users
+ )
+
+ for user_id in potentially_left_users - joined_users:
+ await self.mark_remote_user_device_list_as_unsubscribed(user_id)
+
+ return batch_size
+
+
+class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
+ """ Keeps track of the state at a given event.
+
+ This is done by the concept of `state groups`. Every event is a assigned
+ a state group (identified by an arbitrary string), which references a
+ collection of state events. The current state of an event is then the
+ collection of state events referenced by the event's state group.
+
+ Hence, every change in the current state causes a new state group to be
+ generated. However, if no change happens (e.g., if we get a message event
+ with only one parent it inherits the state group from its parent.)
+
+ There are three tables:
+ * `state_groups`: Stores group name, first event with in the group and
+ room id.
+ * `event_to_state_groups`: Maps events to state groups.
+ * `state_groups_state`: Maps state group to state events.
+ """
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateStore, self).__init__(database, db_conn, hs)
+
+ def _store_event_state_mappings_txn(
+ self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
+ ):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
+
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.state_group_before_event
+ continue
+
+ state_groups[event.event_id] = context.state_group
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="event_to_state_groups",
+ values=[
+ {"state_group": state_group_id, "event_id": event_id}
+ for event_id, state_group_id in iteritems(state_groups)
+ ],
+ )
+
+ for event_id, state_group_id in iteritems(state_groups):
+ txn.call_after(
+ self._get_state_group_for_event.prefill, (event_id,), state_group_id
+ )
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
index 5fdb442104..725e12507f 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/data_stores/main/state_deltas.py
@@ -15,13 +15,15 @@
import logging
+from twisted.internet import defer
+
from synapse.storage._base import SQLBaseStore
logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
- def get_current_state_deltas(self, prev_stream_id):
+ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
@@ -36,15 +38,27 @@ class StateDeltasStore(SQLBaseStore):
Args:
prev_stream_id (int): point to get changes since (exclusive)
+ max_stream_id (int): the point that we know has been correctly persisted
+ - ie, an upper limit to return changes from.
Returns:
- Deferred[list[dict]]: results
+ Deferred[tuple[int, list[dict]]: A tuple consisting of:
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
"""
prev_stream_id = int(prev_stream_id)
+
+ # check we're not going backwards
+ assert prev_stream_id <= max_stream_id
+
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
):
- return []
+ # if the CSDs haven't changed between prev_stream_id and now, we
+ # know for certain that they haven't changed between prev_stream_id and
+ # max_stream_id.
+ return defer.succeed((max_stream_id, []))
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
@@ -54,21 +68,29 @@ class StateDeltasStore(SQLBaseStore):
sql = """
SELECT stream_id, count(*)
FROM current_state_delta_stream
- WHERE stream_id > ?
+ WHERE stream_id > ? AND stream_id <= ?
GROUP BY stream_id
ORDER BY stream_id ASC
LIMIT 100
"""
- txn.execute(sql, (prev_stream_id,))
+ txn.execute(sql, (prev_stream_id, max_stream_id))
total = 0
- max_stream_id = prev_stream_id
- for max_stream_id, count in txn:
+
+ for stream_id, count in txn:
total += count
if total > 100:
# We arbitarily limit to 100 entries to ensure we don't
# select toooo many.
+ logger.debug(
+ "Clipping current_state_delta_stream rows to stream_id %i",
+ stream_id,
+ )
+ clipped_stream_id = stream_id
break
+ else:
+ # if there's no problem, we may as well go right up to the max_stream_id
+ clipped_stream_id = max_stream_id
# Now actually get the deltas
sql = """
@@ -77,15 +99,15 @@ class StateDeltasStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
- txn.execute(sql, (prev_stream_id, max_stream_id))
- return self.cursor_to_dict(txn)
+ txn.execute(sql, (prev_stream_id, clipped_stream_id))
+ return clipped_stream_id, self.db.cursor_to_dict(txn)
- return self.runInteraction(
+ return self.db.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
)
def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
- return self._simple_select_one_onecol_txn(
+ return self.db.simple_select_one_onecol_txn(
txn,
table="current_state_delta_stream",
keyvalues={},
@@ -93,7 +115,7 @@ class StateDeltasStore(SQLBaseStore):
)
def get_max_stream_id_in_current_state_deltas(self):
- return self.runInteraction(
+ return self.db.runInteraction(
"get_max_stream_id_in_current_state_deltas",
self._get_max_stream_id_in_current_state_deltas_txn,
)
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
new file mode 100644
index 0000000000..380c1ec7da
--- /dev/null
+++ b/synapse/storage/data_stores/main/stats.py
@@ -0,0 +1,857 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018, 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from itertools import chain
+
+from twisted.internet import defer
+from twisted.internet.defer import DeferredLock
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.util.caches.descriptors import cached
+
+logger = logging.getLogger(__name__)
+
+# these fields track absolutes (e.g. total number of rooms on the server)
+# You can think of these as Prometheus Gauges.
+# You can draw these stats on a line graph.
+# Example: number of users in a room
+ABSOLUTE_STATS_FIELDS = {
+ "room": (
+ "current_state_events",
+ "joined_members",
+ "invited_members",
+ "left_members",
+ "banned_members",
+ "local_users_in_room",
+ ),
+ "user": ("joined_rooms",),
+}
+
+# these fields are per-timeslice and so should be reset to 0 upon a new slice
+# You can draw these stats on a histogram.
+# Example: number of events sent locally during a time slice
+PER_SLICE_FIELDS = {
+ "room": ("total_events", "total_event_bytes"),
+ "user": ("invites_sent", "rooms_created", "total_events", "total_event_bytes"),
+}
+
+TYPE_TO_TABLE = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
+
+# these are the tables (& ID columns) which contain our actual subjects
+TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
+
+
+class StatsStore(StateDeltasStore):
+ def __init__(self, database: Database, db_conn, hs):
+ super(StatsStore, self).__init__(database, db_conn, hs)
+
+ self.server_name = hs.hostname
+ self.clock = self.hs.get_clock()
+ self.stats_enabled = hs.config.stats_enabled
+ self.stats_bucket_size = hs.config.stats_bucket_size
+
+ self.stats_delta_processing_lock = DeferredLock()
+
+ self.db.updates.register_background_update_handler(
+ "populate_stats_process_rooms", self._populate_stats_process_rooms
+ )
+ self.db.updates.register_background_update_handler(
+ "populate_stats_process_users", self._populate_stats_process_users
+ )
+ # we no longer need to perform clean-up, but we will give ourselves
+ # the potential to reintroduce it in the future – so documentation
+ # will still encourage the use of this no-op handler.
+ self.db.updates.register_noop_background_update("populate_stats_cleanup")
+ self.db.updates.register_noop_background_update("populate_stats_prepare")
+
+ def quantise_stats_time(self, ts):
+ """
+ Quantises a timestamp to be a multiple of the bucket size.
+
+ Args:
+ ts (int): the timestamp to quantise, in milliseconds since the Unix
+ Epoch
+
+ Returns:
+ int: a timestamp which
+ - is divisible by the bucket size;
+ - is no later than `ts`; and
+ - is the largest such timestamp.
+ """
+ return (ts // self.stats_bucket_size) * self.stats_bucket_size
+
+ @defer.inlineCallbacks
+ def _populate_stats_process_users(self, progress, batch_size):
+ """
+ This is a background update which regenerates statistics for users.
+ """
+ if not self.stats_enabled:
+ yield self.db.updates._end_background_update("populate_stats_process_users")
+ return 1
+
+ last_user_id = progress.get("last_user_id", "")
+
+ def _get_next_batch(txn):
+ sql = """
+ SELECT DISTINCT name FROM users
+ WHERE name > ?
+ ORDER BY name ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_user_id, batch_size))
+ return [r for r, in txn]
+
+ users_to_work_on = yield self.db.runInteraction(
+ "_populate_stats_process_users", _get_next_batch
+ )
+
+ # No more rooms -- complete the transaction.
+ if not users_to_work_on:
+ yield self.db.updates._end_background_update("populate_stats_process_users")
+ return 1
+
+ for user_id in users_to_work_on:
+ yield self._calculate_and_set_initial_state_for_user(user_id)
+ progress["last_user_id"] = user_id
+
+ yield self.db.runInteraction(
+ "populate_stats_process_users",
+ self.db.updates._background_update_progress_txn,
+ "populate_stats_process_users",
+ progress,
+ )
+
+ return len(users_to_work_on)
+
+ @defer.inlineCallbacks
+ def _populate_stats_process_rooms(self, progress, batch_size):
+ """
+ This is a background update which regenerates statistics for rooms.
+ """
+ if not self.stats_enabled:
+ yield self.db.updates._end_background_update("populate_stats_process_rooms")
+ return 1
+
+ last_room_id = progress.get("last_room_id", "")
+
+ def _get_next_batch(txn):
+ sql = """
+ SELECT DISTINCT room_id FROM current_state_events
+ WHERE room_id > ?
+ ORDER BY room_id ASC
+ LIMIT ?
+ """
+ txn.execute(sql, (last_room_id, batch_size))
+ return [r for r, in txn]
+
+ rooms_to_work_on = yield self.db.runInteraction(
+ "populate_stats_rooms_get_batch", _get_next_batch
+ )
+
+ # No more rooms -- complete the transaction.
+ if not rooms_to_work_on:
+ yield self.db.updates._end_background_update("populate_stats_process_rooms")
+ return 1
+
+ for room_id in rooms_to_work_on:
+ yield self._calculate_and_set_initial_state_for_room(room_id)
+ progress["last_room_id"] = room_id
+
+ yield self.db.runInteraction(
+ "_populate_stats_process_rooms",
+ self.db.updates._background_update_progress_txn,
+ "populate_stats_process_rooms",
+ progress,
+ )
+
+ return len(rooms_to_work_on)
+
+ def get_stats_positions(self):
+ """
+ Returns the stats processor positions.
+ """
+ return self.db.simple_select_one_onecol(
+ table="stats_incremental_position",
+ keyvalues={},
+ retcol="stream_id",
+ desc="stats_incremental_position",
+ )
+
+ def update_room_state(self, room_id, fields):
+ """
+ Args:
+ room_id (str)
+ fields (dict[str:Any])
+ """
+
+ # For whatever reason some of the fields may contain null bytes, which
+ # postgres isn't a fan of, so we replace those fields with null.
+ for col in (
+ "join_rules",
+ "history_visibility",
+ "encryption",
+ "name",
+ "topic",
+ "avatar",
+ "canonical_alias",
+ ):
+ field = fields.get(col)
+ if field and "\0" in field:
+ fields[col] = None
+
+ return self.db.simple_upsert(
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ values=fields,
+ desc="update_room_state",
+ )
+
+ def get_statistics_for_subject(self, stats_type, stats_id, start, size=100):
+ """
+ Get statistics for a given subject.
+
+ Args:
+ stats_type (str): The type of subject
+ stats_id (str): The ID of the subject (e.g. room_id or user_id)
+ start (int): Pagination start. Number of entries, not timestamp.
+ size (int): How many entries to return.
+
+ Returns:
+ Deferred[list[dict]], where the dict has the keys of
+ ABSOLUTE_STATS_FIELDS[stats_type], and "bucket_size" and "end_ts".
+ """
+ return self.db.runInteraction(
+ "get_statistics_for_subject",
+ self._get_statistics_for_subject_txn,
+ stats_type,
+ stats_id,
+ start,
+ size,
+ )
+
+ def _get_statistics_for_subject_txn(
+ self, txn, stats_type, stats_id, start, size=100
+ ):
+ """
+ Transaction-bound version of L{get_statistics_for_subject}.
+ """
+
+ table, id_col = TYPE_TO_TABLE[stats_type]
+ selected_columns = list(
+ ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
+ )
+
+ slice_list = self.db.simple_select_list_paginate_txn(
+ txn,
+ table + "_historical",
+ "end_ts",
+ start,
+ size,
+ retcols=selected_columns + ["bucket_size", "end_ts"],
+ keyvalues={id_col: stats_id},
+ order_direction="DESC",
+ )
+
+ return slice_list
+
+ @cached()
+ def get_earliest_token_for_stats(self, stats_type, id):
+ """
+ Fetch the "earliest token". This is used by the room stats delta
+ processor to ignore deltas that have been processed between the
+ start of the background task and any particular room's stats
+ being calculated.
+
+ Returns:
+ Deferred[int]
+ """
+ table, id_col = TYPE_TO_TABLE[stats_type]
+
+ return self.db.simple_select_one_onecol(
+ "%s_current" % (table,),
+ keyvalues={id_col: id},
+ retcol="completed_delta_stream_id",
+ allow_none=True,
+ )
+
+ def bulk_update_stats_delta(self, ts, updates, stream_id):
+ """Bulk update stats tables for a given stream_id and updates the stats
+ incremental position.
+
+ Args:
+ ts (int): Current timestamp in ms
+ updates(dict[str, dict[str, dict[str, Counter]]]): The updates to
+ commit as a mapping stats_type -> stats_id -> field -> delta.
+ stream_id (int): Current position.
+
+ Returns:
+ Deferred
+ """
+
+ def _bulk_update_stats_delta_txn(txn):
+ for stats_type, stats_updates in updates.items():
+ for stats_id, fields in stats_updates.items():
+ logger.debug(
+ "Updating %s stats for %s: %s", stats_type, stats_id, fields
+ )
+ self._update_stats_delta_txn(
+ txn,
+ ts=ts,
+ stats_type=stats_type,
+ stats_id=stats_id,
+ fields=fields,
+ complete_with_stream_id=stream_id,
+ )
+
+ self.db.simple_update_one_txn(
+ txn,
+ table="stats_incremental_position",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ )
+
+ return self.db.runInteraction(
+ "bulk_update_stats_delta", _bulk_update_stats_delta_txn
+ )
+
+ def update_stats_delta(
+ self,
+ ts,
+ stats_type,
+ stats_id,
+ fields,
+ complete_with_stream_id,
+ absolute_field_overrides=None,
+ ):
+ """
+ Updates the statistics for a subject, with a delta (difference/relative
+ change).
+
+ Args:
+ ts (int): timestamp of the change
+ stats_type (str): "room" or "user" – the kind of subject
+ stats_id (str): the subject's ID (room ID or user ID)
+ fields (dict[str, int]): Deltas of stats values.
+ complete_with_stream_id (int, optional):
+ If supplied, converts an incomplete row into a complete row,
+ with the supplied stream_id marked as the stream_id where the
+ row was completed.
+ absolute_field_overrides (dict[str, int]): Current stats values
+ (i.e. not deltas) of absolute fields.
+ Does not work with per-slice fields.
+ """
+
+ return self.db.runInteraction(
+ "update_stats_delta",
+ self._update_stats_delta_txn,
+ ts,
+ stats_type,
+ stats_id,
+ fields,
+ complete_with_stream_id=complete_with_stream_id,
+ absolute_field_overrides=absolute_field_overrides,
+ )
+
+ def _update_stats_delta_txn(
+ self,
+ txn,
+ ts,
+ stats_type,
+ stats_id,
+ fields,
+ complete_with_stream_id,
+ absolute_field_overrides=None,
+ ):
+ if absolute_field_overrides is None:
+ absolute_field_overrides = {}
+
+ table, id_col = TYPE_TO_TABLE[stats_type]
+
+ quantised_ts = self.quantise_stats_time(int(ts))
+ end_ts = quantised_ts + self.stats_bucket_size
+
+ # Lets be paranoid and check that all the given field names are known
+ abs_field_names = ABSOLUTE_STATS_FIELDS[stats_type]
+ slice_field_names = PER_SLICE_FIELDS[stats_type]
+ for field in chain(fields.keys(), absolute_field_overrides.keys()):
+ if field not in abs_field_names and field not in slice_field_names:
+ # guard against potential SQL injection dodginess
+ raise ValueError(
+ "%s is not a recognised field"
+ " for stats type %s" % (field, stats_type)
+ )
+
+ # Per slice fields do not get added to the _current table
+
+ # This calculates the deltas (`field = field + ?` values)
+ # for absolute fields,
+ # * defaulting to 0 if not specified
+ # (required for the INSERT part of upserting to work)
+ # * omitting overrides specified in `absolute_field_overrides`
+ deltas_of_absolute_fields = {
+ key: fields.get(key, 0)
+ for key in abs_field_names
+ if key not in absolute_field_overrides
+ }
+
+ # Keep the delta stream ID field up to date
+ absolute_field_overrides = absolute_field_overrides.copy()
+ absolute_field_overrides["completed_delta_stream_id"] = complete_with_stream_id
+
+ # first upsert the `_current` table
+ self._upsert_with_additive_relatives_txn(
+ txn=txn,
+ table=table + "_current",
+ keyvalues={id_col: stats_id},
+ absolutes=absolute_field_overrides,
+ additive_relatives=deltas_of_absolute_fields,
+ )
+
+ per_slice_additive_relatives = {
+ key: fields.get(key, 0) for key in slice_field_names
+ }
+ self._upsert_copy_from_table_with_additive_relatives_txn(
+ txn=txn,
+ into_table=table + "_historical",
+ keyvalues={id_col: stats_id},
+ extra_dst_insvalues={"bucket_size": self.stats_bucket_size},
+ extra_dst_keyvalues={"end_ts": end_ts},
+ additive_relatives=per_slice_additive_relatives,
+ src_table=table + "_current",
+ copy_columns=abs_field_names,
+ )
+
+ def _upsert_with_additive_relatives_txn(
+ self, txn, table, keyvalues, absolutes, additive_relatives
+ ):
+ """Used to update values in the stats tables.
+
+ This is basically a slightly convoluted upsert that *adds* to any
+ existing rows.
+
+ Args:
+ txn
+ table (str): Table name
+ keyvalues (dict[str, any]): Row-identifying key values
+ absolutes (dict[str, any]): Absolute (set) fields
+ additive_relatives (dict[str, int]): Fields that will be added onto
+ if existing row present.
+ """
+ if self.database_engine.can_native_upsert:
+ absolute_updates = [
+ "%(field)s = EXCLUDED.%(field)s" % {"field": field}
+ for field in absolutes.keys()
+ ]
+
+ relative_updates = [
+ "%(field)s = EXCLUDED.%(field)s + %(table)s.%(field)s"
+ % {"table": table, "field": field}
+ for field in additive_relatives.keys()
+ ]
+
+ insert_cols = []
+ qargs = []
+
+ for (key, val) in chain(
+ keyvalues.items(), absolutes.items(), additive_relatives.items()
+ ):
+ insert_cols.append(key)
+ qargs.append(val)
+
+ sql = """
+ INSERT INTO %(table)s (%(insert_cols_cs)s)
+ VALUES (%(insert_vals_qs)s)
+ ON CONFLICT (%(key_columns)s) DO UPDATE SET %(updates)s
+ """ % {
+ "table": table,
+ "insert_cols_cs": ", ".join(insert_cols),
+ "insert_vals_qs": ", ".join(
+ ["?"] * (len(keyvalues) + len(absolutes) + len(additive_relatives))
+ ),
+ "key_columns": ", ".join(keyvalues),
+ "updates": ", ".join(chain(absolute_updates, relative_updates)),
+ }
+
+ txn.execute(sql, qargs)
+ else:
+ self.database_engine.lock_table(txn, table)
+ retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
+ current_row = self.db.simple_select_one_txn(
+ txn, table, keyvalues, retcols, allow_none=True
+ )
+ if current_row is None:
+ merged_dict = {**keyvalues, **absolutes, **additive_relatives}
+ self.db.simple_insert_txn(txn, table, merged_dict)
+ else:
+ for (key, val) in additive_relatives.items():
+ current_row[key] += val
+ current_row.update(absolutes)
+ self.db.simple_update_one_txn(txn, table, keyvalues, current_row)
+
+ def _upsert_copy_from_table_with_additive_relatives_txn(
+ self,
+ txn,
+ into_table,
+ keyvalues,
+ extra_dst_keyvalues,
+ extra_dst_insvalues,
+ additive_relatives,
+ src_table,
+ copy_columns,
+ ):
+ """Updates the historic stats table with latest updates.
+
+ This involves copying "absolute" fields from the `_current` table, and
+ adding relative fields to any existing values.
+
+ Args:
+ txn: Transaction
+ into_table (str): The destination table to UPSERT the row into
+ keyvalues (dict[str, any]): Row-identifying key values
+ extra_dst_keyvalues (dict[str, any]): Additional keyvalues
+ for `into_table`.
+ extra_dst_insvalues (dict[str, any]): Additional values to insert
+ on new row creation for `into_table`.
+ additive_relatives (dict[str, any]): Fields that will be added onto
+ if existing row present. (Must be disjoint from copy_columns.)
+ src_table (str): The source table to copy from
+ copy_columns (iterable[str]): The list of columns to copy
+ """
+ if self.database_engine.can_native_upsert:
+ ins_columns = chain(
+ keyvalues,
+ copy_columns,
+ additive_relatives,
+ extra_dst_keyvalues,
+ extra_dst_insvalues,
+ )
+ sel_exprs = chain(
+ keyvalues,
+ copy_columns,
+ (
+ "?"
+ for _ in chain(
+ additive_relatives, extra_dst_keyvalues, extra_dst_insvalues
+ )
+ ),
+ )
+ keyvalues_where = ("%s = ?" % f for f in keyvalues)
+
+ sets_cc = ("%s = EXCLUDED.%s" % (f, f) for f in copy_columns)
+ sets_ar = (
+ "%s = EXCLUDED.%s + %s.%s" % (f, f, into_table, f)
+ for f in additive_relatives
+ )
+
+ sql = """
+ INSERT INTO %(into_table)s (%(ins_columns)s)
+ SELECT %(sel_exprs)s
+ FROM %(src_table)s
+ WHERE %(keyvalues_where)s
+ ON CONFLICT (%(keyvalues)s)
+ DO UPDATE SET %(sets)s
+ """ % {
+ "into_table": into_table,
+ "ins_columns": ", ".join(ins_columns),
+ "sel_exprs": ", ".join(sel_exprs),
+ "keyvalues_where": " AND ".join(keyvalues_where),
+ "src_table": src_table,
+ "keyvalues": ", ".join(
+ chain(keyvalues.keys(), extra_dst_keyvalues.keys())
+ ),
+ "sets": ", ".join(chain(sets_cc, sets_ar)),
+ }
+
+ qargs = list(
+ chain(
+ additive_relatives.values(),
+ extra_dst_keyvalues.values(),
+ extra_dst_insvalues.values(),
+ keyvalues.values(),
+ )
+ )
+ txn.execute(sql, qargs)
+ else:
+ self.database_engine.lock_table(txn, into_table)
+ src_row = self.db.simple_select_one_txn(
+ txn, src_table, keyvalues, copy_columns
+ )
+ all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
+ dest_current_row = self.db.simple_select_one_txn(
+ txn,
+ into_table,
+ keyvalues=all_dest_keyvalues,
+ retcols=list(chain(additive_relatives.keys(), copy_columns)),
+ allow_none=True,
+ )
+
+ if dest_current_row is None:
+ merged_dict = {
+ **keyvalues,
+ **extra_dst_keyvalues,
+ **extra_dst_insvalues,
+ **src_row,
+ **additive_relatives,
+ }
+ self.db.simple_insert_txn(txn, into_table, merged_dict)
+ else:
+ for (key, val) in additive_relatives.items():
+ src_row[key] = dest_current_row[key] + val
+ self.db.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
+
+ def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
+ """Fetches the counts of events in the given range of stream IDs.
+
+ Args:
+ min_pos (int)
+ max_pos (int)
+
+ Returns:
+ Deferred[dict[str, dict[str, int]]]: Mapping of room ID to field
+ changes.
+ """
+
+ return self.db.runInteraction(
+ "stats_incremental_total_events_and_bytes",
+ self.get_changes_room_total_events_and_bytes_txn,
+ min_pos,
+ max_pos,
+ )
+
+ def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
+ """Gets the total_events and total_event_bytes counts for rooms and
+ senders, in a range of stream_orderings (including backfilled events).
+
+ Args:
+ txn
+ low_pos (int): Low stream ordering
+ high_pos (int): High stream ordering
+
+ Returns:
+ tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
+ room and user deltas for total_events/total_event_bytes in the
+ format of `stats_id` -> fields
+ """
+
+ if low_pos >= high_pos:
+ # nothing to do here.
+ return {}, {}
+
+ if isinstance(self.database_engine, PostgresEngine):
+ new_bytes_expression = "OCTET_LENGTH(json)"
+ else:
+ new_bytes_expression = "LENGTH(CAST(json AS BLOB))"
+
+ sql = """
+ SELECT events.room_id, COUNT(*) AS new_events, SUM(%s) AS new_bytes
+ FROM events INNER JOIN event_json USING (event_id)
+ WHERE (? < stream_ordering AND stream_ordering <= ?)
+ OR (? <= stream_ordering AND stream_ordering <= ?)
+ GROUP BY events.room_id
+ """ % (
+ new_bytes_expression,
+ )
+
+ txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos))
+
+ room_deltas = {
+ room_id: {"total_events": new_events, "total_event_bytes": new_bytes}
+ for room_id, new_events, new_bytes in txn
+ }
+
+ sql = """
+ SELECT events.sender, COUNT(*) AS new_events, SUM(%s) AS new_bytes
+ FROM events INNER JOIN event_json USING (event_id)
+ WHERE (? < stream_ordering AND stream_ordering <= ?)
+ OR (? <= stream_ordering AND stream_ordering <= ?)
+ GROUP BY events.sender
+ """ % (
+ new_bytes_expression,
+ )
+
+ txn.execute(sql, (low_pos, high_pos, -high_pos, -low_pos))
+
+ user_deltas = {
+ user_id: {"total_events": new_events, "total_event_bytes": new_bytes}
+ for user_id, new_events, new_bytes in txn
+ if self.hs.is_mine_id(user_id)
+ }
+
+ return room_deltas, user_deltas
+
+ @defer.inlineCallbacks
+ def _calculate_and_set_initial_state_for_room(self, room_id):
+ """Calculate and insert an entry into room_stats_current.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[tuple[dict, dict, int]]: A tuple of room state, membership
+ counts and stream position.
+ """
+
+ def _fetch_current_state_stats(txn):
+ pos = self.get_room_max_stream_ordering()
+
+ rows = self.db.simple_select_many_txn(
+ txn,
+ table="current_state_events",
+ column="type",
+ iterable=[
+ EventTypes.Create,
+ EventTypes.JoinRules,
+ EventTypes.RoomHistoryVisibility,
+ EventTypes.RoomEncryption,
+ EventTypes.Name,
+ EventTypes.Topic,
+ EventTypes.RoomAvatar,
+ EventTypes.CanonicalAlias,
+ ],
+ keyvalues={"room_id": room_id, "state_key": ""},
+ retcols=["event_id"],
+ )
+
+ event_ids = [row["event_id"] for row in rows]
+
+ txn.execute(
+ """
+ SELECT membership, count(*) FROM current_state_events
+ WHERE room_id = ? AND type = 'm.room.member'
+ GROUP BY membership
+ """,
+ (room_id,),
+ )
+ membership_counts = {membership: cnt for membership, cnt in txn}
+
+ txn.execute(
+ """
+ SELECT COALESCE(count(*), 0) FROM current_state_events
+ WHERE room_id = ?
+ """,
+ (room_id,),
+ )
+
+ (current_state_events_count,) = txn.fetchone()
+
+ users_in_room = self.get_users_in_room_txn(txn, room_id)
+
+ return (
+ event_ids,
+ membership_counts,
+ current_state_events_count,
+ users_in_room,
+ pos,
+ )
+
+ (
+ event_ids,
+ membership_counts,
+ current_state_events_count,
+ users_in_room,
+ pos,
+ ) = yield self.db.runInteraction(
+ "get_initial_state_for_room", _fetch_current_state_stats
+ )
+
+ state_event_map = yield self.get_events(event_ids, get_prev_content=False)
+
+ room_state = {
+ "join_rules": None,
+ "history_visibility": None,
+ "encryption": None,
+ "name": None,
+ "topic": None,
+ "avatar": None,
+ "canonical_alias": None,
+ "is_federatable": True,
+ }
+
+ for event in state_event_map.values():
+ if event.type == EventTypes.JoinRules:
+ room_state["join_rules"] = event.content.get("join_rule")
+ elif event.type == EventTypes.RoomHistoryVisibility:
+ room_state["history_visibility"] = event.content.get(
+ "history_visibility"
+ )
+ elif event.type == EventTypes.RoomEncryption:
+ room_state["encryption"] = event.content.get("algorithm")
+ elif event.type == EventTypes.Name:
+ room_state["name"] = event.content.get("name")
+ elif event.type == EventTypes.Topic:
+ room_state["topic"] = event.content.get("topic")
+ elif event.type == EventTypes.RoomAvatar:
+ room_state["avatar"] = event.content.get("url")
+ elif event.type == EventTypes.CanonicalAlias:
+ room_state["canonical_alias"] = event.content.get("alias")
+ elif event.type == EventTypes.Create:
+ room_state["is_federatable"] = (
+ event.content.get("m.federate", True) is True
+ )
+
+ yield self.update_room_state(room_id, room_state)
+
+ local_users_in_room = [u for u in users_in_room if self.hs.is_mine_id(u)]
+
+ yield self.update_stats_delta(
+ ts=self.clock.time_msec(),
+ stats_type="room",
+ stats_id=room_id,
+ fields={},
+ complete_with_stream_id=pos,
+ absolute_field_overrides={
+ "current_state_events": current_state_events_count,
+ "joined_members": membership_counts.get(Membership.JOIN, 0),
+ "invited_members": membership_counts.get(Membership.INVITE, 0),
+ "left_members": membership_counts.get(Membership.LEAVE, 0),
+ "banned_members": membership_counts.get(Membership.BAN, 0),
+ "local_users_in_room": len(local_users_in_room),
+ },
+ )
+
+ @defer.inlineCallbacks
+ def _calculate_and_set_initial_state_for_user(self, user_id):
+ def _calculate_and_set_initial_state_for_user_txn(txn):
+ pos = self._get_max_stream_id_in_current_state_deltas_txn(txn)
+
+ txn.execute(
+ """
+ SELECT COUNT(distinct room_id) FROM current_state_events
+ WHERE type = 'm.room.member' AND state_key = ?
+ AND membership = 'join'
+ """,
+ (user_id,),
+ )
+ (count,) = txn.fetchone()
+ return count, pos
+
+ joined_rooms, pos = yield self.db.runInteraction(
+ "calculate_and_set_initial_state_for_user",
+ _calculate_and_set_initial_state_for_user_txn,
+ )
+
+ yield self.update_stats_delta(
+ ts=self.clock.time_msec(),
+ stats_type="user",
+ stats_id=user_id,
+ fields={},
+ complete_with_stream_id=pos,
+ absolute_field_overrides={"joined_rooms": joined_rooms},
+ )
diff --git a/synapse/storage/stream.py b/synapse/storage/data_stores/main/stream.py
index 6f7f65d96b..ada5cce6c2 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -41,12 +44,13 @@ from six.moves import range
from twisted.internet import defer
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine
-from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -65,7 +69,7 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
- direction, column_names, from_token, to_token, engine,
+ direction, column_names, from_token, to_token, engine
):
"""Creates an SQL expression to bound the columns by the pagination
tokens.
@@ -153,7 +157,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
str
"""
- assert(bound in (">", "<", ">=", "<="))
+ assert bound in (">", "<", ">=", "<=")
name1, name2 = column_names
val1, val2 = values
@@ -169,11 +173,7 @@ def _make_generic_sql_bound(bound, column_names, values, engine):
# Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
# as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
# use the later form when running against postgres.
- return "((%d,%d) %s (%s,%s))" % (
- val1, val2,
- bound,
- name1, name2,
- )
+ return "((%d,%d) %s (%s,%s))" % (val1, val2, bound, name1, name2)
# We want to generate queries of e.g. the form:
#
@@ -233,6 +233,14 @@ def filter_to_clause(event_filter):
clauses.append("contains_url = ?")
args.append(event_filter.contains_url)
+ # We're only applying the "labels" filter on the database query, because applying the
+ # "not_labels" filter via a SQL query is non-trivial. Instead, we let
+ # event_filter.check_fields apply it, which is not as efficient but makes the
+ # implementation simpler.
+ if event_filter.labels:
+ clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
+ args.extend(event_filter.labels)
+
return " AND ".join(clauses), args
@@ -244,11 +252,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
- def __init__(self, db_conn, hs):
- super(StreamWorkerStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(StreamWorkerStore, self).__init__(database, db_conn, hs)
events_max = self.get_room_max_stream_ordering()
- event_cache_prefill, min_event_val = self._get_cache_dict(
+ event_cache_prefill, min_event_val = self.db.get_cache_dict(
db_conn,
"events",
entity_column="room_id",
@@ -276,7 +284,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_stream_for_rooms(
- self, room_ids, from_key, to_key, limit=0, order='DESC'
+ self, room_ids, from_key, to_key, limit=0, order="DESC"
):
"""Get new room events in stream ordering since `from_key`.
@@ -304,7 +312,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
if not room_ids:
- defer.returnValue({})
+ return {}
results = {}
room_ids = list(room_ids)
@@ -327,7 +335,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
results.update(dict(zip(rm_ids, res)))
- defer.returnValue(results)
+ return results
def get_rooms_that_changed(self, room_ids, from_key):
"""Given a list of rooms and a token, return rooms where there may have
@@ -338,15 +346,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
from_key (str): The room_key portion of a StreamToken
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
- return set(
+ return {
room_id
for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
- )
+ }
@defer.inlineCallbacks
def get_room_events_stream_for_room(
- self, room_id, from_key, to_key, limit=0, order='DESC'
+ self, room_id, from_key, to_key, limit=0, order="DESC"
):
"""Get new room events in stream ordering since `from_key`.
@@ -368,7 +376,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
the chunk of events returned.
"""
if from_key == to_key:
- defer.returnValue(([], from_key))
+ return [], from_key
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -378,7 +386,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
if not has_changed:
- defer.returnValue(([], from_key))
+ return [], from_key
def f(txn):
sql = (
@@ -393,10 +401,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
- rows = yield self.runInteraction("get_room_events_stream_for_room", f)
+ rows = yield self.db.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self.get_events_as_list([
- r.event_id for r in rows], get_prev_content=True,
+ ret = yield self.get_events_as_list(
+ [r.event_id for r in rows], get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=from_id is None)
@@ -411,7 +419,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# get.
key = from_key
- defer.returnValue((ret, key))
+ return ret, key
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
@@ -419,14 +427,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
- defer.returnValue([])
+ return []
if from_id:
has_changed = self._membership_stream_cache.has_entity_changed(
user_id, int(from_id)
)
if not has_changed:
- defer.returnValue([])
+ return []
def f(txn):
sql = (
@@ -443,15 +451,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return rows
- rows = yield self.runInteraction("get_membership_changes_for_user", f)
+ rows = yield self.db.runInteraction("get_membership_changes_for_user", f)
ret = yield self.get_events_as_list(
- [r.event_id for r in rows], get_prev_content=True,
+ [r.event_id for r in rows], get_prev_content=True
)
self._set_before_and_after(ret, rows, topo_order=False)
- defer.returnValue(ret)
+ return ret
@defer.inlineCallbacks
def get_recent_events_for_room(self, room_id, limit, end_token):
@@ -481,7 +489,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(events, rows)
- defer.returnValue((events, token))
+ return (events, token)
@defer.inlineCallbacks
def get_recent_event_ids_for_room(self, room_id, limit, end_token):
@@ -500,11 +508,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
# Allow a zero limit here, and no-op.
if limit == 0:
- defer.returnValue(([], end_token))
+ return [], end_token
end_token = RoomStreamToken.parse(end_token)
- rows, token = yield self.runInteraction(
+ rows, token = yield self.db.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
@@ -515,10 +523,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# We want to return the results in ascending order.
rows.reverse()
- defer.returnValue((rows, token))
+ return rows, token
- def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
- """Gets details of the first event in a room at or after a stream ordering
+ def get_room_event_before_stream_ordering(self, room_id, stream_ordering):
+ """Gets details of the first event in a room at or before a stream ordering
Args:
room_id (str):
@@ -533,15 +541,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
sql = (
"SELECT stream_ordering, topological_ordering, event_id"
" FROM events"
- " WHERE room_id = ? AND stream_ordering >= ?"
+ " WHERE room_id = ? AND stream_ordering <= ?"
" AND NOT outlier"
- " ORDER BY stream_ordering"
+ " ORDER BY stream_ordering DESC"
" LIMIT 1"
)
txn.execute(sql, (room_id, stream_ordering))
return txn.fetchone()
- return self.runInteraction("get_room_event_after_stream_ordering", _f)
+ return self.db.runInteraction("get_room_event_before_stream_ordering", _f)
@defer.inlineCallbacks
def get_room_events_max_id(self, room_id=None):
@@ -553,12 +561,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
token = yield self.get_room_max_stream_ordering()
if room_id is None:
- defer.returnValue("s%d" % (token,))
+ return "s%d" % (token,)
else:
- topo = yield self.runInteraction(
+ topo = yield self.db.runInteraction(
"_get_max_topological_txn", self._get_max_topological_txn, room_id
)
- defer.returnValue("t%d-%d" % (topo, token))
+ return "t%d-%d" % (topo, token)
def get_stream_token_for_event(self, event_id):
"""The stream token for an event
@@ -569,7 +577,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "s%d" stream token.
"""
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,))
@@ -582,7 +590,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "t%d-%d" topological token.
"""
- return self._simple_select_one(
+ return self.db.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
@@ -606,13 +614,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
- return self._execute(
+ return self.db.execute(
"get_max_topological_token", None, sql, room_id, stream_key
).addCallback(lambda r: r[0][0] if r else 0)
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
- "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
+ "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
)
@@ -660,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- results = yield self.runInteraction(
+ results = yield self.db.runInteraction(
"get_events_around",
self._get_events_around_txn,
room_id,
@@ -671,21 +679,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
events_before = yield self.get_events_as_list(
- [e for e in results["before"]["event_ids"]], get_prev_content=True
+ list(results["before"]["event_ids"]), get_prev_content=True
)
events_after = yield self.get_events_as_list(
- [e for e in results["after"]["event_ids"]], get_prev_content=True
+ list(results["after"]["event_ids"]), get_prev_content=True
)
- defer.returnValue(
- {
- "events_before": events_before,
- "events_after": events_after,
- "start": results["before"]["token"],
- "end": results["after"]["token"],
- }
- )
+ return {
+ "events_before": events_before,
+ "events_after": events_after,
+ "start": results["before"]["token"],
+ "end": results["after"]["token"],
+ }
def _get_events_around_txn(
self, txn, room_id, event_id, before_limit, after_limit, event_filter
@@ -704,7 +710,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- results = self._simple_select_one_txn(
+ results = self.db.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -725,7 +731,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
before_token,
- direction='b',
+ direction="b",
limit=before_limit,
event_filter=event_filter,
)
@@ -735,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
txn,
room_id,
after_token,
- direction='f',
+ direction="f",
limit=after_limit,
event_filter=event_filter,
)
@@ -783,16 +789,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, [row[1] for row in rows]
- upper_bound, event_ids = yield self.runInteraction(
+ upper_bound, event_ids = yield self.db.runInteraction(
"get_all_new_events_stream", get_all_new_events_stream_txn
)
events = yield self.get_events_as_list(event_ids)
- defer.returnValue((upper_bound, events))
+ return upper_bound, events
def get_federation_out_pos(self, typ):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
keyvalues={"type": typ},
@@ -800,7 +806,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
def update_federation_out_pos(self, typ, stream_id):
- return self._simple_update_one(
+ return self.db.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ},
updatevalues={"stream_id": stream_id},
@@ -816,7 +822,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
room_id,
from_token,
to_token=None,
- direction='b',
+ direction="b",
limit=-1,
event_filter=None,
):
@@ -837,7 +843,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
as a list of _EventDictReturn and a token that points to the end
- of the result set.
+ of the result set. If no events are returned then the end of the
+ stream has been reached (i.e. there are no events between
+ `from_token` and `to_token`), or `limit` is zero.
"""
assert int(limit) >= 0
@@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
args = [False, room_id]
- if direction == 'b':
+ if direction == "b":
order = "DESC"
else:
order = "ASC"
@@ -867,13 +875,38 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args.append(int(limit))
- sql = (
- "SELECT event_id, topological_ordering, stream_ordering"
- " FROM events"
- " WHERE outlier = ? AND room_id = ? AND %(bounds)s"
- " ORDER BY topological_ordering %(order)s,"
- " stream_ordering %(order)s LIMIT ?"
- ) % {"bounds": bounds, "order": order}
+ select_keywords = "SELECT"
+ join_clause = ""
+ if event_filter and event_filter.labels:
+ # If we're not filtering on a label, then joining on event_labels will
+ # return as many row for a single event as the number of labels it has. To
+ # avoid this, only join if we're filtering on at least one label.
+ join_clause = """
+ LEFT JOIN event_labels
+ USING (event_id, room_id, topological_ordering)
+ """
+ if len(event_filter.labels) > 1:
+ # Using DISTINCT in this SELECT query is quite expensive, because it
+ # requires the engine to sort on the entire (not limited) result set,
+ # i.e. the entire events table. We only need to use it when we're
+ # filtering on more than two labels, because that's the only scenario
+ # in which we can possibly to get multiple times the same event ID in
+ # the results.
+ select_keywords += "DISTINCT"
+
+ sql = """
+ %(select_keywords)s event_id, topological_ordering, stream_ordering
+ FROM events
+ %(join_clause)s
+ WHERE outlier = ? AND room_id = ? AND %(bounds)s
+ ORDER BY topological_ordering %(order)s,
+ stream_ordering %(order)s LIMIT ?
+ """ % {
+ "select_keywords": select_keywords,
+ "join_clause": join_clause,
+ "bounds": bounds,
+ "order": order,
+ }
txn.execute(sql, args)
@@ -882,7 +915,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if rows:
topo = rows[-1].topological_ordering
toke = rows[-1].stream_ordering
- if direction == 'b':
+ if direction == "b":
# Tokens are positions between events.
# This token points *after* the last event in the chunk.
# We need it to point to the event before it in the chunk
@@ -898,7 +931,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def paginate_room_events(
- self, room_id, from_key, to_key=None, direction='b', limit=-1, event_filter=None
+ self, room_id, from_key, to_key=None, direction="b", limit=-1, event_filter=None
):
"""Returns list of events before or after a given token.
@@ -909,22 +942,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
only those before
direction(char): Either 'b' or 'f' to indicate whether we are
paginating forwards or backwards from `from_key`.
- limit (int): The maximum number of events to return. Zero or less
- means no limit.
+ limit (int): The maximum number of events to return.
event_filter (Filter|None): If provided filters the events to
those that match the filter.
Returns:
- tuple[list[dict], str]: Returns the results as a list of dicts and
- a token that points to the end of the result set. The dicts have
- the keys "event_id", "topological_ordering" and "stream_orderign".
+ tuple[list[FrozenEvent], str]: Returns the results as a list of
+ events and a token that points to the end of the result set. If no
+ events are returned then the end of the stream has been reached
+ (i.e. there are no events between `from_key` and `to_key`).
"""
from_key = RoomStreamToken.parse(from_key)
if to_key:
to_key = RoomStreamToken.parse(to_key)
- rows, token = yield self.runInteraction(
+ rows, token = yield self.db.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
@@ -941,7 +974,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._set_before_and_after(events, rows)
- defer.returnValue((events, token))
+ return (events, token)
class StreamStore(StreamWorkerStore):
diff --git a/synapse/storage/tags.py b/synapse/storage/data_stores/main/tags.py
index e88f8ea35f..2aa1bafd48 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -22,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
-from synapse.storage.account_data import AccountDataWorkerStore
+from synapse.storage.data_stores.main.account_data import AccountDataWorkerStore
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag strings to tag content.
"""
- deferred = self._simple_select_list(
+ deferred = self.db.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
@@ -66,7 +66,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
room_id string, tag string and content string.
"""
if last_id == current_id:
- defer.returnValue([])
+ return []
def get_all_updated_tags_txn(txn):
sql = (
@@ -78,14 +78,12 @@ class TagsWorkerStore(AccountDataWorkerStore):
txn.execute(sql, (last_id, current_id, limit))
return txn.fetchall()
- tag_ids = yield self.runInteraction(
+ tag_ids = yield self.db.runInteraction(
"get_all_updated_tags", get_all_updated_tags_txn
)
def get_tag_content(txn, tag_ids):
- sql = (
- "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
- )
+ sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id))
@@ -100,14 +98,14 @@ class TagsWorkerStore(AccountDataWorkerStore):
batch_size = 50
results = []
for i in range(0, len(tag_ids), batch_size):
- tags = yield self.runInteraction(
+ tags = yield self.db.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
tag_ids[i : i + batch_size],
)
results.extend(tags)
- defer.returnValue(results)
+ return results
@defer.inlineCallbacks
def get_updated_tags(self, user_id, stream_id):
@@ -135,9 +133,11 @@ class TagsWorkerStore(AccountDataWorkerStore):
user_id, int(stream_id)
)
if not changed:
- defer.returnValue({})
+ return {}
- room_ids = yield self.runInteraction("get_updated_tags", get_updated_tags_txn)
+ room_ids = yield self.db.runInteraction(
+ "get_updated_tags", get_updated_tags_txn
+ )
results = {}
if room_ids:
@@ -145,7 +145,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
for room_id in room_ids:
results[room_id] = tags_by_room.get(room_id, {})
- defer.returnValue(results)
+ return results
def get_tags_for_room(self, user_id, room_id):
"""Get all the tags for the given room
@@ -155,7 +155,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A deferred list of string tags.
"""
- return self._simple_select_list(
+ return self.db.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
@@ -180,7 +180,7 @@ class TagsStore(TagsWorkerStore):
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
@@ -189,12 +189,12 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
- yield self.runInteraction("add_tag", add_tag_txn, next_id)
+ yield self.db.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
@defer.inlineCallbacks
def remove_tag_from_room(self, user_id, room_id, tag):
@@ -212,12 +212,12 @@ class TagsStore(TagsWorkerStore):
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id:
- yield self.runInteraction("remove_tag", remove_tag_txn, next_id)
+ yield self.db.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_current_token()
- defer.returnValue(result)
+ return result
def _update_revision_txn(self, txn, user_id, room_id, next_id):
"""Update the latest revision of the tags for the given user and room.
diff --git a/synapse/storage/transactions.py b/synapse/storage/data_stores/main/transactions.py
index b1188f6bcb..5b07c2fbc0 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -23,10 +23,10 @@ from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
-from ._base import SQLBaseStore, db_to_json
-
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
@@ -53,8 +53,8 @@ class TransactionStore(SQLBaseStore):
"""A collection of queries for handling PDUs.
"""
- def __init__(self, db_conn, hs):
- super(TransactionStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(TransactionStore, self).__init__(database, db_conn, hs)
self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
@@ -78,7 +78,7 @@ class TransactionStore(SQLBaseStore):
this transaction or a 2-tuple of (int, dict)
"""
- return self.runInteraction(
+ return self.db.runInteraction(
"get_received_txn_response",
self._get_received_txn_response,
transaction_id,
@@ -86,7 +86,7 @@ class TransactionStore(SQLBaseStore):
)
def _get_received_txn_response(self, txn, transaction_id, origin):
- result = self._simple_select_one_txn(
+ result = self.db.simple_select_one_txn(
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
@@ -120,7 +120,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
- return self._simple_insert(
+ return self.db.simple_insert(
table="received_transactions",
values={
"transaction_id": transaction_id,
@@ -133,34 +133,6 @@ class TransactionStore(SQLBaseStore):
desc="set_received_txn_response",
)
- def prep_send_transaction(self, transaction_id, destination, origin_server_ts):
- """Persists an outgoing transaction and calculates the values for the
- previous transaction id list.
-
- This should be called before sending the transaction so that it has the
- correct value for the `prev_ids` key.
-
- Args:
- transaction_id (str)
- destination (str)
- origin_server_ts (int)
-
- Returns:
- list: A list of previous transaction ids.
- """
- return defer.succeed([])
-
- def delivered_txn(self, transaction_id, destination, code, response_dict):
- """Persists the response for an outgoing transaction.
-
- Args:
- transaction_id (str)
- destination (str)
- code (int)
- response_json (str)
- """
- pass
-
@defer.inlineCallbacks
def get_destination_retry_timings(self, destination):
"""Gets the current retry timings (if any) for a given destination.
@@ -175,9 +147,9 @@ class TransactionStore(SQLBaseStore):
result = self._destination_retry_cache.get(destination, SENTINEL)
if result is not SENTINEL:
- defer.returnValue(result)
+ return result
- result = yield self.runInteraction(
+ result = yield self.db.runInteraction(
"get_destination_retry_timings",
self._get_destination_retry_timings,
destination,
@@ -186,14 +158,14 @@ class TransactionStore(SQLBaseStore):
# We don't hugely care about race conditions between getting and
# invalidating the cache, since we time out fairly quickly anyway.
self._destination_retry_cache[destination] = result
- defer.returnValue(result)
+ return result
def _get_destination_retry_timings(self, txn, destination):
- result = self._simple_select_one_txn(
+ result = self.db.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
- retcols=("destination", "retry_last_ts", "retry_interval"),
+ retcols=("destination", "failure_ts", "retry_last_ts", "retry_interval"),
allow_none=True,
)
@@ -202,82 +174,91 @@ class TransactionStore(SQLBaseStore):
else:
return None
- def set_destination_retry_timings(self, destination, retry_last_ts, retry_interval):
+ def set_destination_retry_timings(
+ self, destination, failure_ts, retry_last_ts, retry_interval
+ ):
"""Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occuring.
Args:
destination (str)
+ failure_ts (int|None) - when the server started failing (ms since epoch)
retry_last_ts (int) - time of last retry attempt in unix epoch ms
retry_interval (int) - how long until next retry in ms
"""
self._destination_retry_cache.pop(destination, None)
- return self.runInteraction(
+ return self.db.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
+ failure_ts,
retry_last_ts,
retry_interval,
)
def _set_destination_retry_timings(
- self, txn, destination, retry_last_ts, retry_interval
+ self, txn, destination, failure_ts, retry_last_ts, retry_interval
):
+
+ if self.database_engine.can_native_upsert:
+ # Upsert retry time interval if retry_interval is zero (i.e. we're
+ # resetting it) or greater than the existing retry interval.
+
+ sql = """
+ INSERT INTO destinations (
+ destination, failure_ts, retry_last_ts, retry_interval
+ )
+ VALUES (?, ?, ?, ?)
+ ON CONFLICT (destination) DO UPDATE SET
+ failure_ts = EXCLUDED.failure_ts,
+ retry_last_ts = EXCLUDED.retry_last_ts,
+ retry_interval = EXCLUDED.retry_interval
+ WHERE
+ EXCLUDED.retry_interval = 0
+ OR destinations.retry_interval < EXCLUDED.retry_interval
+ """
+
+ txn.execute(sql, (destination, failure_ts, retry_last_ts, retry_interval))
+
+ return
+
self.database_engine.lock_table(txn, "destinations")
# We need to be careful here as the data may have changed from under us
# due to a worker setting the timings.
- prev_row = self._simple_select_one_txn(
+ prev_row = self.db.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
- retcols=("retry_last_ts", "retry_interval"),
+ retcols=("failure_ts", "retry_last_ts", "retry_interval"),
allow_none=True,
)
if not prev_row:
- self._simple_insert_txn(
+ self.db.simple_insert_txn(
txn,
table="destinations",
values={
"destination": destination,
+ "failure_ts": failure_ts,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
},
)
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
- self._simple_update_one_txn(
+ self.db.simple_update_one_txn(
txn,
"destinations",
keyvalues={"destination": destination},
updatevalues={
+ "failure_ts": failure_ts,
"retry_last_ts": retry_last_ts,
"retry_interval": retry_interval,
},
)
- def get_destinations_needing_retry(self):
- """Get all destinations which are due a retry for sending a transaction.
-
- Returns:
- list: A list of dicts
- """
-
- return self.runInteraction(
- "get_destinations_needing_retry", self._get_destinations_needing_retry
- )
-
- def _get_destinations_needing_retry(self, txn):
- query = (
- "SELECT * FROM destinations"
- " WHERE retry_last_ts > 0 and retry_next_ts < ?"
- )
-
- txn.execute(query, (self._clock.time_msec(),))
- return self.cursor_to_dict(txn)
-
def _start_cleanup_transactions(self):
return run_as_background_process(
"cleanup_transactions", self._cleanup_transactions
@@ -290,4 +271,6 @@ class TransactionStore(SQLBaseStore):
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
- return self.runInteraction("_cleanup_transactions", _cleanup_transactions_txn)
+ return self.db.runInteraction(
+ "_cleanup_transactions", _cleanup_transactions_txn
+ )
diff --git a/synapse/storage/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 83466e25d9..6b8130bf0f 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -19,10 +19,10 @@ import re
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
-from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.storage.data_stores.main.state import StateFilter
+from synapse.storage.data_stores.main.state_deltas import StateDeltasStore
+from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-from synapse.storage.state import StateFilter
-from synapse.storage.state_deltas import StateDeltasStore
from synapse.types import get_domain_from_id, get_localpart_from_id
from synapse.util.caches.descriptors import cached
@@ -32,30 +32,30 @@ logger = logging.getLogger(__name__)
TEMP_TABLE = "_temp_populate_user_directory"
-class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
+class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
# How many records do we calculate before sending it to
# add_users_who_share_private_rooms?
SHARE_PRIVATE_WORKING_SET = 500
- def __init__(self, db_conn, hs):
- super(UserDirectoryStore, self).__init__(db_conn, hs)
+ def __init__(self, database: Database, db_conn, hs):
+ super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
self.server_name = hs.hostname
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"populate_user_directory_createtables",
self._populate_user_directory_createtables,
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"populate_user_directory_process_rooms",
self._populate_user_directory_process_rooms,
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"populate_user_directory_process_users",
self._populate_user_directory_process_users,
)
- self.register_background_update_handler(
+ self.db.updates.register_background_update_handler(
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
)
@@ -85,7 +85,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"""
txn.execute(sql)
rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
- self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
# If search all users is on, get all the users we want to add.
@@ -100,23 +100,25 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
txn.execute("SELECT name FROM users")
users = [{"user_id": x[0]} for x in txn.fetchall()]
- self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.db.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
new_pos = yield self.get_max_stream_id_in_current_state_deltas()
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory_temp_build", _make_staging_area
)
- yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+ yield self.db.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
- yield self._end_background_update("populate_user_directory_createtables")
- defer.returnValue(1)
+ yield self.db.updates._end_background_update(
+ "populate_user_directory_createtables"
+ )
+ return 1
@defer.inlineCallbacks
def _populate_user_directory_cleanup(self, progress, batch_size):
"""
Update the user directory stream position, then clean up the old tables.
"""
- position = yield self._simple_select_one_onecol(
+ position = yield self.db.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
)
yield self.update_user_directory_stream_pos(position)
@@ -126,12 +128,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory_cleanup", _delete_staging_area
)
- yield self._end_background_update("populate_user_directory_cleanup")
- defer.returnValue(1)
+ yield self.db.updates._end_background_update("populate_user_directory_cleanup")
+ return 1
@defer.inlineCallbacks
def _populate_user_directory_process_rooms(self, progress, batch_size):
@@ -170,16 +172,18 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
return rooms_to_work_on
- rooms_to_work_on = yield self.runInteraction(
+ rooms_to_work_on = yield self.db.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
- yield self._end_background_update("populate_user_directory_process_rooms")
- defer.returnValue(1)
+ yield self.db.updates._end_background_update(
+ "populate_user_directory_process_rooms"
+ )
+ return 1
- logger.info(
+ logger.debug(
"Processing the next %d rooms of %d remaining"
% (len(rooms_to_work_on), progress["remaining"])
)
@@ -243,12 +247,12 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
to_insert.clear()
# We've finished a room. Delete it from the table.
- yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+ yield self.db.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory",
- self._background_update_progress_txn,
+ self.db.updates._background_update_progress_txn,
"populate_user_directory_process_rooms",
progress,
)
@@ -257,9 +261,9 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
if processed_event_count > batch_size:
# Don't process any more rooms, we've hit our batch size.
- defer.returnValue(processed_event_count)
+ return processed_event_count
- defer.returnValue(processed_event_count)
+ return processed_event_count
@defer.inlineCallbacks
def _populate_user_directory_process_users(self, progress, batch_size):
@@ -267,8 +271,10 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
If search_all_users is enabled, add all of the users to the user directory.
"""
if not self.hs.config.user_directory_search_all_users:
- yield self._end_background_update("populate_user_directory_process_users")
- defer.returnValue(1)
+ yield self.db.updates._end_background_update(
+ "populate_user_directory_process_users"
+ )
+ return 1
def _get_next_batch(txn):
sql = "SELECT user_id FROM %s LIMIT %s" % (
@@ -291,16 +297,18 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
return users_to_work_on
- users_to_work_on = yield self.runInteraction(
+ users_to_work_on = yield self.db.runInteraction(
"populate_user_directory_temp_read", _get_next_batch
)
# No more users -- complete the transaction.
if not users_to_work_on:
- yield self._end_background_update("populate_user_directory_process_users")
- defer.returnValue(1)
+ yield self.db.updates._end_background_update(
+ "populate_user_directory_process_users"
+ )
+ return 1
- logger.info(
+ logger.debug(
"Processing the next %d users of %d remaining"
% (len(users_to_work_on), progress["remaining"])
)
@@ -312,17 +320,17 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
)
# We've finished processing a user. Delete it from the table.
- yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
+ yield self.db.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
# Update the remaining counter.
progress["remaining"] -= 1
- yield self.runInteraction(
+ yield self.db.runInteraction(
"populate_user_directory",
- self._background_update_progress_txn,
+ self.db.updates._background_update_progress_txn,
"populate_user_directory_process_users",
progress,
)
- defer.returnValue(len(users_to_work_on))
+ return len(users_to_work_on)
@defer.inlineCallbacks
def is_room_world_readable_or_publicly_joinable(self, room_id):
@@ -344,16 +352,16 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
if join_rule_ev:
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
- defer.returnValue(True)
+ return True
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
if hist_vis_id:
hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
if hist_vis_ev:
if hist_vis_ev.content.get("history_visibility") == "world_readable":
- defer.returnValue(True)
+ return True
- defer.returnValue(False)
+ return False
def update_profile_in_user_dir(self, user_id, display_name, avatar_url):
"""
@@ -361,7 +369,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"""
def _update_profile_in_user_dir_txn(txn):
- new_entry = self._simple_upsert_txn(
+ new_entry = self.db.simple_upsert_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
@@ -435,7 +443,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
)
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name) if display_name else user_id
- self._simple_upsert_txn(
+ self.db.simple_upsert_txn(
txn,
table="user_directory_search",
keyvalues={"user_id": user_id},
@@ -448,59 +456,10 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
- return self.runInteraction(
+ return self.db.runInteraction(
"update_profile_in_user_dir", _update_profile_in_user_dir_txn
)
- def remove_from_user_dir(self, user_id):
- def _remove_from_user_dir_txn(txn):
- self._simple_delete_txn(
- txn, table="user_directory", keyvalues={"user_id": user_id}
- )
- self._simple_delete_txn(
- txn, table="user_directory_search", keyvalues={"user_id": user_id}
- )
- self._simple_delete_txn(
- txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
- )
- self._simple_delete_txn(
- txn,
- table="users_who_share_private_rooms",
- keyvalues={"user_id": user_id},
- )
- self._simple_delete_txn(
- txn,
- table="users_who_share_private_rooms",
- keyvalues={"other_user_id": user_id},
- )
- txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
-
- return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
-
- @defer.inlineCallbacks
- def get_users_in_dir_due_to_room(self, room_id):
- """Get all user_ids that are in the room directory because they're
- in the given room_id
- """
- user_ids_share_pub = yield self._simple_select_onecol(
- table="users_in_public_rooms",
- keyvalues={"room_id": room_id},
- retcol="user_id",
- desc="get_users_in_dir_due_to_room",
- )
-
- user_ids_share_priv = yield self._simple_select_onecol(
- table="users_who_share_private_rooms",
- keyvalues={"room_id": room_id},
- retcol="other_user_id",
- desc="get_users_in_dir_due_to_room",
- )
-
- user_ids = set(user_ids_share_pub)
- user_ids.update(user_ids_share_priv)
-
- defer.returnValue(user_ids)
-
def add_users_who_share_private_room(self, room_id, user_id_tuples):
"""Insert entries into the users_who_share_private_rooms table. The first
user should be a local user.
@@ -511,7 +470,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"""
def _add_users_who_share_room_txn(txn):
- self._simple_upsert_many_txn(
+ self.db.simple_upsert_many_txn(
txn,
table="users_who_share_private_rooms",
key_names=["user_id", "other_user_id", "room_id"],
@@ -523,7 +482,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
value_values=None,
)
- return self.runInteraction(
+ return self.db.runInteraction(
"add_users_who_share_room", _add_users_who_share_room_txn
)
@@ -538,7 +497,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
def _add_users_in_public_rooms_txn(txn):
- self._simple_upsert_many_txn(
+ self.db.simple_upsert_many_txn(
txn,
table="users_in_public_rooms",
key_names=["user_id", "room_id"],
@@ -547,10 +506,102 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
value_values=None,
)
- return self.runInteraction(
+ return self.db.runInteraction(
"add_users_in_public_rooms", _add_users_in_public_rooms_txn
)
+ def delete_all_from_user_dir(self):
+ """Delete the entire user directory
+ """
+
+ def _delete_all_from_user_dir_txn(txn):
+ txn.execute("DELETE FROM user_directory")
+ txn.execute("DELETE FROM user_directory_search")
+ txn.execute("DELETE FROM users_in_public_rooms")
+ txn.execute("DELETE FROM users_who_share_private_rooms")
+ txn.call_after(self.get_user_in_directory.invalidate_all)
+
+ return self.db.runInteraction(
+ "delete_all_from_user_dir", _delete_all_from_user_dir_txn
+ )
+
+ @cached()
+ def get_user_in_directory(self, user_id):
+ return self.db.simple_select_one(
+ table="user_directory",
+ keyvalues={"user_id": user_id},
+ retcols=("display_name", "avatar_url"),
+ allow_none=True,
+ desc="get_user_in_directory",
+ )
+
+ def update_user_directory_stream_pos(self, stream_id):
+ return self.db.simple_update_one(
+ table="user_directory_stream_pos",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ desc="update_user_directory_stream_pos",
+ )
+
+
+class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
+
+ # How many records do we calculate before sending it to
+ # add_users_who_share_private_rooms?
+ SHARE_PRIVATE_WORKING_SET = 500
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(UserDirectoryStore, self).__init__(database, db_conn, hs)
+
+ def remove_from_user_dir(self, user_id):
+ def _remove_from_user_dir_txn(txn):
+ self.db.simple_delete_txn(
+ txn, table="user_directory", keyvalues={"user_id": user_id}
+ )
+ self.db.simple_delete_txn(
+ txn, table="user_directory_search", keyvalues={"user_id": user_id}
+ )
+ self.db.simple_delete_txn(
+ txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
+ )
+ self.db.simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"user_id": user_id},
+ )
+ self.db.simple_delete_txn(
+ txn,
+ table="users_who_share_private_rooms",
+ keyvalues={"other_user_id": user_id},
+ )
+ txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+
+ return self.db.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
+
+ @defer.inlineCallbacks
+ def get_users_in_dir_due_to_room(self, room_id):
+ """Get all user_ids that are in the room directory because they're
+ in the given room_id
+ """
+ user_ids_share_pub = yield self.db.simple_select_onecol(
+ table="users_in_public_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids_share_priv = yield self.db.simple_select_onecol(
+ table="users_who_share_private_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="other_user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids = set(user_ids_share_pub)
+ user_ids.update(user_ids_share_priv)
+
+ return user_ids
+
def remove_user_who_share_room(self, user_id, room_id):
"""
Deletes entries in the users_who_share_*_rooms table. The first
@@ -562,23 +613,23 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
"""
def _remove_user_who_share_room_txn(txn):
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id, "room_id": room_id},
)
- self._simple_delete_txn(
+ self.db.simple_delete_txn(
txn,
table="users_in_public_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- return self.runInteraction(
+ return self.db.runInteraction(
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
@@ -593,14 +644,14 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
Returns:
list: user_id
"""
- rows = yield self._simple_select_onecol(
+ rows = yield self.db.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
desc="get_rooms_user_is_in",
)
- pub_rows = yield self._simple_select_onecol(
+ pub_rows = yield self.db.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
@@ -609,7 +660,7 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
users = set(pub_rows)
users.update(rows)
- defer.returnValue(list(users))
+ return list(users)
@defer.inlineCallbacks
def get_rooms_in_common_for_users(self, user_id, other_user_id):
@@ -618,66 +669,33 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
sql = """
SELECT room_id FROM (
SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
+ INNER JOIN room_memberships AS m USING (event_id)
WHERE type = 'm.room.member'
- AND membership = 'join'
+ AND m.membership = 'join'
AND state_key = ?
) AS f1 INNER JOIN (
SELECT c.room_id FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
+ INNER JOIN room_memberships AS m USING (event_id)
WHERE type = 'm.room.member'
- AND membership = 'join'
+ AND m.membership = 'join'
AND state_key = ?
) f2 USING (room_id)
"""
- rows = yield self._execute(
+ rows = yield self.db.execute(
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
)
- defer.returnValue([room_id for room_id, in rows])
-
- def delete_all_from_user_dir(self):
- """Delete the entire user directory
- """
-
- def _delete_all_from_user_dir_txn(txn):
- txn.execute("DELETE FROM user_directory")
- txn.execute("DELETE FROM user_directory_search")
- txn.execute("DELETE FROM users_in_public_rooms")
- txn.execute("DELETE FROM users_who_share_private_rooms")
- txn.call_after(self.get_user_in_directory.invalidate_all)
-
- return self.runInteraction(
- "delete_all_from_user_dir", _delete_all_from_user_dir_txn
- )
-
- @cached()
- def get_user_in_directory(self, user_id):
- return self._simple_select_one(
- table="user_directory",
- keyvalues={"user_id": user_id},
- retcols=("display_name", "avatar_url"),
- allow_none=True,
- desc="get_user_in_directory",
- )
+ return [room_id for room_id, in rows]
def get_user_directory_stream_pos(self):
- return self._simple_select_one_onecol(
+ return self.db.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
desc="get_user_directory_stream_pos",
)
- def update_user_directory_stream_pos(self, stream_id):
- return self._simple_update_one(
- table="user_directory_stream_pos",
- keyvalues={},
- updatevalues={"stream_id": stream_id},
- desc="update_user_directory_stream_pos",
- )
-
@defer.inlineCallbacks
def search_user_dir(self, user_id, search_term, limit):
"""Searches for users in directory
@@ -776,13 +794,13 @@ class UserDirectoryStore(StateDeltasStore, BackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = yield self._execute(
- "search_user_dir", self.cursor_to_dict, sql, *args
+ results = yield self.db.execute(
+ "search_user_dir", self.db.cursor_to_dict, sql, *args
)
limited = len(results) > limit
- defer.returnValue({"limited": limited, "results": results})
+ return {"limited": limited, "results": results}
def _parse_query_sqlite(search_term):
diff --git a/synapse/storage/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index 1815fdc0dd..ec6b8a4ffd 100644
--- a/synapse/storage/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -12,9 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import operator
-from twisted.internet import defer
+import operator
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedList
@@ -32,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if the user has requested erasure
"""
- return self._simple_select_onecol(
+ return self.db.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
@@ -57,17 +56,17 @@ class UserErasureWorkerStore(SQLBaseStore):
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- def _get_erased_users(txn):
- txn.execute(
- "SELECT user_id FROM erased_users WHERE user_id IN (%s)"
- % (",".join("?" * len(user_ids))),
- user_ids,
- )
- return set(r[0] for r in txn)
+ rows = yield self.db.simple_select_many_batch(
+ table="erased_users",
+ column="user_id",
+ iterable=user_ids,
+ retcols=("user_id",),
+ desc="are_users_erased",
+ )
+ erased_users = {row["user_id"] for row in rows}
- erased_users = yield self.runInteraction("are_users_erased", _get_erased_users)
- res = dict((u, u in erased_users) for u in user_ids)
- defer.returnValue(res)
+ res = {u: u in erased_users for u in user_ids}
+ return res
class UserErasureStore(UserErasureWorkerStore):
@@ -89,4 +88,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
- return self.runInteraction("mark_user_erased", f)
+ return self.db.runInteraction("mark_user_erased", f)
diff --git a/synapse/metrics/resource.py b/synapse/storage/data_stores/state/__init__.py
index 9789359077..86e09f6229 100644
--- a/synapse/metrics/resource.py
+++ b/synapse/storage/data_stores/state/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,8 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from prometheus_client.twisted import MetricsResource
-
-METRICS_PREFIX = "/_synapse/metrics"
-
-__all__ = ["MetricsResource", "METRICS_PREFIX"]
+from synapse.storage.data_stores.state.store import StateGroupDataStore # noqa: F401
diff --git a/synapse/storage/data_stores/state/bg_updates.py b/synapse/storage/data_stores/state/bg_updates.py
new file mode 100644
index 0000000000..e8edaf9f7b
--- /dev/null
+++ b/synapse/storage/data_stores/state/bg_updates.py
@@ -0,0 +1,374 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import iteritems
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import Database
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.state import StateFilter
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class StateGroupBackgroundUpdateStore(SQLBaseStore):
+ """Defines functions related to state groups needed to run the state backgroud
+ updates.
+ """
+
+ def _count_state_group_hops_txn(self, txn, state_group):
+ """Given a state group, count how many hops there are in the tree.
+
+ This is used to ensure the delta chains don't get too long.
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT count(*) FROM state;
+ """
+
+ txn.execute(sql, (state_group,))
+ row = txn.fetchone()
+ if row and row[0]:
+ return row[0]
+ else:
+ return 0
+ else:
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ next_group = state_group
+ count = 0
+
+ while next_group:
+ next_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+ if next_group:
+ count += 1
+
+ return count
+
+ def _get_state_groups_from_groups_txn(
+ self, txn, groups, state_filter=StateFilter.all()
+ ):
+ results = {group: {} for group in groups}
+
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # Temporarily disable sequential scans in this transaction. This is
+ # a temporary hack until we can add the right indices in
+ txn.execute("SET LOCAL enable_seqscan=off")
+
+ # The below query walks the state_group tree so that the "state"
+ # table includes all state_groups in the tree. It then joins
+ # against `state_groups_state` to fetch the latest state.
+ # It assumes that previous state groups are always numerically
+ # lesser.
+ # The PARTITION is used to get the event_id in the greatest state
+ # group for the given type, state_key.
+ # This may return multiple rows per (type, state_key), but last_value
+ # should be the same.
+ sql = """
+ WITH RECURSIVE state(state_group) AS (
+ VALUES(?::bigint)
+ UNION ALL
+ SELECT prev_state_group FROM state_group_edges e, state s
+ WHERE s.state_group = e.state_group
+ )
+ SELECT DISTINCT type, state_key, last_value(event_id) OVER (
+ PARTITION BY type, state_key ORDER BY state_group ASC
+ ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
+ ) AS event_id FROM state_groups_state
+ WHERE state_group IN (
+ SELECT state_group FROM state
+ )
+ """
+
+ for group in groups:
+ args = [group]
+ args.extend(where_args)
+
+ txn.execute(sql + where_clause, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (typ, state_key)
+ results[group][key] = event_id
+ else:
+ max_entries_returned = state_filter.max_entries_returned()
+
+ # We don't use WITH RECURSIVE on sqlite3 as there are distributions
+ # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
+ for group in groups:
+ next_group = group
+
+ while next_group:
+ # We did this before by getting the list of group ids, and
+ # then passing that list to sqlite to get latest event for
+ # each (type, state_key). However, that was terribly slow
+ # without the right indices (which we can't add until
+ # after we finish deduping state, which requires this func)
+ args = [next_group]
+ args.extend(where_args)
+
+ txn.execute(
+ "SELECT type, state_key, event_id FROM state_groups_state"
+ " WHERE state_group = ? " + where_clause,
+ args,
+ )
+ results[group].update(
+ ((typ, state_key), event_id)
+ for typ, state_key, event_id in txn
+ if (typ, state_key) not in results[group]
+ )
+
+ # If the number of entries in the (type,state_key)->event_id dict
+ # matches the number of (type,state_keys) types we were searching
+ # for, then we must have found them all, so no need to go walk
+ # further down the tree... UNLESS our types filter contained
+ # wildcards (i.e. Nones) in which case we have to do an exhaustive
+ # search
+ if (
+ max_entries_returned is not None
+ and len(results[group]) == max_entries_returned
+ ):
+ break
+
+ next_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": next_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ return results
+
+
+class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
+
+ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+ STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+ STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ self.db.updates.register_background_update_handler(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+ self._background_deduplicate_state,
+ )
+ self.db.updates.register_background_update_handler(
+ self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
+ )
+ self.db.updates.register_background_index_update(
+ self.STATE_GROUPS_ROOM_INDEX_UPDATE_NAME,
+ index_name="state_groups_room_id_idx",
+ table="state_groups",
+ columns=["room_id"],
+ )
+
+ @defer.inlineCallbacks
+ def _background_deduplicate_state(self, progress, batch_size):
+ """This background update will slowly deduplicate state by reencoding
+ them as deltas.
+ """
+ last_state_group = progress.get("last_state_group", 0)
+ rows_inserted = progress.get("rows_inserted", 0)
+ max_group = progress.get("max_group", None)
+
+ BATCH_SIZE_SCALE_FACTOR = 100
+
+ batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
+
+ if max_group is None:
+ rows = yield self.db.execute(
+ "_background_deduplicate_state",
+ None,
+ "SELECT coalesce(max(id), 0) FROM state_groups",
+ )
+ max_group = rows[0][0]
+
+ def reindex_txn(txn):
+ new_last_state_group = last_state_group
+ for count in range(batch_size):
+ txn.execute(
+ "SELECT id, room_id FROM state_groups"
+ " WHERE ? < id AND id <= ?"
+ " ORDER BY id ASC"
+ " LIMIT 1",
+ (new_last_state_group, max_group),
+ )
+ row = txn.fetchone()
+ if row:
+ state_group, room_id = row
+
+ if not row or not state_group:
+ return True, count
+
+ txn.execute(
+ "SELECT state_group FROM state_group_edges"
+ " WHERE state_group = ?",
+ (state_group,),
+ )
+
+ # If we reach a point where we've already started inserting
+ # edges we should stop.
+ if txn.fetchall():
+ return True, count
+
+ txn.execute(
+ "SELECT coalesce(max(id), 0) FROM state_groups"
+ " WHERE id < ? AND room_id = ?",
+ (state_group, room_id),
+ )
+ (prev_group,) = txn.fetchone()
+ new_last_state_group = state_group
+
+ if prev_group:
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if potential_hops >= MAX_STATE_DELTA_HOPS:
+ # We want to ensure chains are at most this long,#
+ # otherwise read performance degrades.
+ continue
+
+ prev_state = self._get_state_groups_from_groups_txn(
+ txn, [prev_group]
+ )
+ prev_state = prev_state[prev_group]
+
+ curr_state = self._get_state_groups_from_groups_txn(
+ txn, [state_group]
+ )
+ curr_state = curr_state[state_group]
+
+ if not set(prev_state.keys()) - set(curr_state.keys()):
+ # We can only do a delta if the current has a strict super set
+ # of keys
+
+ delta_state = {
+ key: value
+ for key, value in iteritems(curr_state)
+ if prev_state.get(key, None) != value
+ }
+
+ self.db.simple_delete_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": state_group},
+ )
+
+ self.db.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={
+ "state_group": state_group,
+ "prev_state_group": prev_group,
+ },
+ )
+
+ self.db.simple_delete_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(delta_state)
+ ],
+ )
+
+ progress = {
+ "last_state_group": state_group,
+ "rows_inserted": rows_inserted + batch_size,
+ "max_group": max_group,
+ }
+
+ self.db.updates._background_update_progress_txn(
+ txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
+ )
+
+ return False, batch_size
+
+ finished, result = yield self.db.runInteraction(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
+ )
+
+ if finished:
+ yield self.db.updates._end_background_update(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
+ )
+
+ return result * BATCH_SIZE_SCALE_FACTOR
+
+ @defer.inlineCallbacks
+ def _background_index_state(self, progress, batch_size):
+ def reindex_txn(conn):
+ conn.rollback()
+ if isinstance(self.database_engine, PostgresEngine):
+ # postgres insists on autocommit for the index
+ conn.set_session(autocommit=True)
+ try:
+ txn = conn.cursor()
+ txn.execute(
+ "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
+ " ON state_groups_state(state_group, type, state_key)"
+ )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+ finally:
+ conn.set_session(autocommit=False)
+ else:
+ txn = conn.cursor()
+ txn.execute(
+ "CREATE INDEX state_groups_state_type_idx"
+ " ON state_groups_state(state_group, type, state_key)"
+ )
+ txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
+
+ yield self.db.runWithConnection(reindex_txn)
+
+ yield self.db.updates._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
+
+ return 1
diff --git a/synapse/storage/schema/delta/23/drop_state_index.sql b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
index ae09fa0065..ae09fa0065 100644
--- a/synapse/storage/schema/delta/23/drop_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/23/drop_state_index.sql
diff --git a/synapse/storage/schema/delta/30/state_stream.sql b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
index e85699e82e..e85699e82e 100644
--- a/synapse/storage/schema/delta/30/state_stream.sql
+++ b/synapse/storage/data_stores/state/schema/delta/30/state_stream.sql
diff --git a/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
new file mode 100644
index 0000000000..1450313bfa
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/32/remove_state_indices.sql
@@ -0,0 +1,19 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+-- The following indices are redundant, other indices are equivalent or
+-- supersets
+DROP INDEX IF EXISTS state_groups_id; -- Duplicate of PRIMARY KEY
diff --git a/synapse/storage/schema/delta/35/add_state_index.sql b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
index 0fce26345b..33980d02f0 100644
--- a/synapse/storage/schema/delta/35/add_state_index.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/add_state_index.sql
@@ -13,8 +13,5 @@
* limitations under the License.
*/
-
-ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
-
INSERT into background_updates (update_name, progress_json, depends_on)
VALUES ('state_group_state_type_index', '{}', 'state_group_state_deduplication');
diff --git a/synapse/storage/schema/delta/35/state.sql b/synapse/storage/data_stores/state/schema/delta/35/state.sql
index 0f1fa68a89..0f1fa68a89 100644
--- a/synapse/storage/schema/delta/35/state.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state.sql
diff --git a/synapse/storage/schema/delta/35/state_dedupe.sql b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
index 97e5067ef4..97e5067ef4 100644
--- a/synapse/storage/schema/delta/35/state_dedupe.sql
+++ b/synapse/storage/data_stores/state/schema/delta/35/state_dedupe.sql
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
index f6766501d2..9fd1ccf6f7 100644
--- a/synapse/storage/schema/delta/47/state_group_seq.py
+++ b/synapse/storage/data_stores/state/schema/delta/47/state_group_seq.py
@@ -27,10 +27,7 @@ def run_create(cur, database_engine, *args, **kwargs):
else:
start_val = row[0] + 1
- cur.execute(
- "CREATE SEQUENCE state_group_id_seq START WITH %s",
- (start_val, ),
- )
+ cur.execute("CREATE SEQUENCE state_group_id_seq START WITH %s", (start_val,))
def run_upgrade(*args, **kwargs):
diff --git a/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
new file mode 100644
index 0000000000..7916ef18b2
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/delta/56/state_group_room_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('state_groups_room_id_idx', '{}');
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
new file mode 100644
index 0000000000..35f97d6b3d
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/full.sql
@@ -0,0 +1,37 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE state_groups (
+ id BIGINT PRIMARY KEY,
+ room_id TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_groups_state (
+ state_group BIGINT NOT NULL,
+ room_id TEXT NOT NULL,
+ type TEXT NOT NULL,
+ state_key TEXT NOT NULL,
+ event_id TEXT NOT NULL
+);
+
+CREATE TABLE state_group_edges (
+ state_group BIGINT NOT NULL,
+ prev_state_group BIGINT NOT NULL
+);
+
+CREATE INDEX state_group_edges_idx ON state_group_edges (state_group);
+CREATE INDEX state_group_edges_prev_idx ON state_group_edges (prev_state_group);
+CREATE INDEX state_groups_state_type_idx ON state_groups_state (state_group, type, state_key);
diff --git a/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
new file mode 100644
index 0000000000..fcd926c9fb
--- /dev/null
+++ b/synapse/storage/data_stores/state/schema/full_schemas/54/sequence.sql.postgres
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE state_group_id_seq
+ START WITH 1
+ INCREMENT BY 1
+ NO MINVALUE
+ NO MAXVALUE
+ CACHE 1;
diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py
new file mode 100644
index 0000000000..57a5267663
--- /dev/null
+++ b/synapse/storage/data_stores/state/store.py
@@ -0,0 +1,644 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from collections import namedtuple
+from typing import Dict, Iterable, List, Set, Tuple
+
+from six import iteritems
+from six.moves import range
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
+from synapse.storage.database import Database
+from synapse.storage.state import StateFilter
+from synapse.types import StateMap
+from synapse.util.caches import get_cache_factor_for
+from synapse.util.caches.descriptors import cached
+from synapse.util.caches.dictionary_cache import DictionaryCache
+
+logger = logging.getLogger(__name__)
+
+
+MAX_STATE_DELTA_HOPS = 100
+
+
+class _GetStateGroupDelta(
+ namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
+):
+ """Return type of get_state_group_delta that implements __len__, which lets
+ us use the itrable flag when caching
+ """
+
+ __slots__ = []
+
+ def __len__(self):
+ return len(self.delta_ids) if self.delta_ids else 0
+
+
+class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
+ """A data store for fetching/storing state groups.
+ """
+
+ def __init__(self, database: Database, db_conn, hs):
+ super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+
+ # Originally the state store used a single DictionaryCache to cache the
+ # event IDs for the state types in a given state group to avoid hammering
+ # on the state_group* tables.
+ #
+ # The point of using a DictionaryCache is that it can cache a subset
+ # of the state events for a given state group (i.e. a subset of the keys for a
+ # given dict which is an entry in the cache for a given state group ID).
+ #
+ # However, this poses problems when performing complicated queries
+ # on the store - for instance: "give me all the state for this group, but
+ # limit members to this subset of users", as DictionaryCache's API isn't
+ # rich enough to say "please cache any of these fields, apart from this subset".
+ # This is problematic when lazy loading members, which requires this behaviour,
+ # as without it the cache has no choice but to speculatively load all
+ # state events for the group, which negates the efficiency being sought.
+ #
+ # Rather than overcomplicating DictionaryCache's API, we instead split the
+ # state_group_cache into two halves - one for tracking non-member events,
+ # and the other for tracking member_events. This means that lazy loading
+ # queries can be made in a cache-friendly manner by querying both caches
+ # separately and then merging the result. So for the example above, you
+ # would query the members cache for a specific subset of state keys
+ # (which DictionaryCache will handle efficiently and fine) and the non-members
+ # cache for all state (which DictionaryCache will similarly handle fine)
+ # and then just merge the results together.
+ #
+ # We size the non-members cache to be smaller than the members cache as the
+ # vast majority of state in Matrix (today) is member events.
+
+ self._state_group_cache = DictionaryCache(
+ "*stateGroupCache*",
+ # TODO: this hasn't been tuned yet
+ 50000 * get_cache_factor_for("stateGroupCache"),
+ )
+ self._state_group_members_cache = DictionaryCache(
+ "*stateGroupMembersCache*",
+ 500000 * get_cache_factor_for("stateGroupMembersCache"),
+ )
+
+ @cached(max_entries=10000, iterable=True)
+ def get_state_group_delta(self, state_group):
+ """Given a state group try to return a previous group and a delta between
+ the old and the new.
+
+ Returns:
+ (prev_group, delta_ids), where both may be None.
+ """
+
+ def _get_state_group_delta_txn(txn):
+ prev_group = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_group_edges",
+ keyvalues={"state_group": state_group},
+ retcol="prev_state_group",
+ allow_none=True,
+ )
+
+ if not prev_group:
+ return _GetStateGroupDelta(None, None)
+
+ delta_ids = self.db.simple_select_list_txn(
+ txn,
+ table="state_groups_state",
+ keyvalues={"state_group": state_group},
+ retcols=("type", "state_key", "event_id"),
+ )
+
+ return _GetStateGroupDelta(
+ prev_group,
+ {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
+ )
+
+ return self.db.runInteraction(
+ "get_state_group_delta", _get_state_group_delta_txn
+ )
+
+ @defer.inlineCallbacks
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ):
+ """Returns the state groups for a given set of groups from the
+ database, filtering on types of state events.
+
+ Args:
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+ """
+ results = {}
+
+ chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
+ for chunk in chunks:
+ res = yield self.db.runInteraction(
+ "_get_state_groups_from_groups",
+ self._get_state_groups_from_groups_txn,
+ chunk,
+ state_filter,
+ )
+ results.update(res)
+
+ return results
+
+ def _get_state_for_group_using_cache(self, cache, group, state_filter):
+ """Checks if group is in cache. See `_get_state_for_groups`
+
+ Args:
+ cache(DictionaryCache): the state group cache to use
+ group(int): The state group to lookup
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
+ Returns 2-tuple (`state_dict`, `got_all`).
+ `got_all` is a bool indicating if we successfully retrieved all
+ requests state from the cache, if False we need to query the DB for the
+ missing state.
+ """
+ is_all, known_absent, state_dict_ids = cache.get(group)
+
+ if is_all or state_filter.is_full():
+ # Either we have everything or want everything, either way
+ # `is_all` tells us whether we've gotten everything.
+ return state_filter.filter_state(state_dict_ids), is_all
+
+ # tracks whether any of our requested types are missing from the cache
+ missing_types = False
+
+ if state_filter.has_wildcards():
+ # We don't know if we fetched all the state keys for the types in
+ # the filter that are wildcards, so we have to assume that we may
+ # have missed some.
+ missing_types = True
+ else:
+ # There aren't any wild cards, so `concrete_types()` returns the
+ # complete list of event types we're wanting.
+ for key in state_filter.concrete_types():
+ if key not in state_dict_ids and key not in known_absent:
+ missing_types = True
+ break
+
+ return state_filter.filter_state(state_dict_ids), not missing_types
+
+ @defer.inlineCallbacks
+ def _get_state_for_groups(
+ self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ ):
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key
+
+ Args:
+ groups: list of state groups for which we want
+ to get the state.
+ state_filter: The state filter used to fetch state
+ from the database.
+ Returns:
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
+ """
+
+ member_filter, non_member_filter = state_filter.get_member_split()
+
+ # Now we look them up in the member and non-member caches
+ (
+ non_member_state,
+ incomplete_groups_nm,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache, state_filter=non_member_filter
+ )
+
+ (
+ member_state,
+ incomplete_groups_m,
+ ) = yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_members_cache, state_filter=member_filter
+ )
+
+ state = dict(non_member_state)
+ for group in groups:
+ state[group].update(member_state[group])
+
+ # Now fetch any missing groups from the database
+
+ incomplete_groups = incomplete_groups_m | incomplete_groups_nm
+
+ if not incomplete_groups:
+ return state
+
+ cache_sequence_nm = self._state_group_cache.sequence
+ cache_sequence_m = self._state_group_members_cache.sequence
+
+ # Help the cache hit ratio by expanding the filter a bit
+ db_state_filter = state_filter.return_expanded()
+
+ group_to_state_dict = yield self._get_state_groups_from_groups(
+ list(incomplete_groups), state_filter=db_state_filter
+ )
+
+ # Now lets update the caches
+ self._insert_into_cache(
+ group_to_state_dict,
+ db_state_filter,
+ cache_seq_num_members=cache_sequence_m,
+ cache_seq_num_non_members=cache_sequence_nm,
+ )
+
+ # And finally update the result dict, by filtering out any extra
+ # stuff we pulled out of the database.
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ # We just replace any existing entries, as we will have loaded
+ # everything we need from the database anyway.
+ state[group] = state_filter.filter_state(group_state_dict)
+
+ return state
+
+ def _get_state_for_groups_using_cache(
+ self, groups: Iterable[int], cache: DictionaryCache, state_filter: StateFilter
+ ) -> Tuple[Dict[int, StateMap[str]], Set[int]]:
+ """Gets the state at each of a list of state groups, optionally
+ filtering by type/state_key, querying from a specific cache.
+
+ Args:
+ groups: list of state groups for which we want to get the state.
+ cache: the cache of group ids to state dicts which
+ we will pass through - either the normal state cache or the
+ specific members state cache.
+ state_filter: The state filter used to fetch state from the
+ database.
+
+ Returns:
+ Tuple of dict of state_group_id to state map of entries in the
+ cache, and the state group ids either missing from the cache or
+ incomplete.
+ """
+ results = {}
+ incomplete_groups = set()
+ for group in set(groups):
+ state_dict_ids, got_all = self._get_state_for_group_using_cache(
+ cache, group, state_filter
+ )
+ results[group] = state_dict_ids
+
+ if not got_all:
+ incomplete_groups.add(group)
+
+ return results, incomplete_groups
+
+ def _insert_into_cache(
+ self,
+ group_to_state_dict,
+ state_filter,
+ cache_seq_num_members,
+ cache_seq_num_non_members,
+ ):
+ """Inserts results from querying the database into the relevant cache.
+
+ Args:
+ group_to_state_dict (dict): The new entries pulled from database.
+ Map from state group to state dict
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ cache_seq_num_members (int): Sequence number of member cache since
+ last lookup in cache
+ cache_seq_num_non_members (int): Sequence number of member cache since
+ last lookup in cache
+ """
+
+ # We need to work out which types we've fetched from the DB for the
+ # member vs non-member caches. This should be as accurate as possible,
+ # but can be an underestimate (e.g. when we have wild cards)
+
+ member_filter, non_member_filter = state_filter.get_member_split()
+ if member_filter.is_full():
+ # We fetched all member events
+ member_types = None
+ else:
+ # `concrete_types()` will only return a subset when there are wild
+ # cards in the filter, but that's fine.
+ member_types = member_filter.concrete_types()
+
+ if non_member_filter.is_full():
+ # We fetched all non member events
+ non_member_types = None
+ else:
+ non_member_types = non_member_filter.concrete_types()
+
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ state_dict_members = {}
+ state_dict_non_members = {}
+
+ for k, v in iteritems(group_state_dict):
+ if k[0] == EventTypes.Member:
+ state_dict_members[k] = v
+ else:
+ state_dict_non_members[k] = v
+
+ self._state_group_members_cache.update(
+ cache_seq_num_members,
+ key=group,
+ value=state_dict_members,
+ fetched_keys=member_types,
+ )
+
+ self._state_group_cache.update(
+ cache_seq_num_non_members,
+ key=group,
+ value=state_dict_non_members,
+ fetched_keys=non_member_types,
+ )
+
+ def store_state_group(
+ self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ ):
+ """Store a new set of state, returning a newly assigned state group.
+
+ Args:
+ event_id (str): The event ID for which the state was calculated
+ room_id (str)
+ prev_group (int|None): A previous state group for the room, optional.
+ delta_ids (dict|None): The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids (dict): The state to store. Map of (type, state_key)
+ to event_id.
+
+ Returns:
+ Deferred[int]: The state group ID
+ """
+
+ def _store_state_group_txn(txn):
+ if current_state_ids is None:
+ # AFAIK, this can never happen
+ raise Exception("current_state_ids cannot be None")
+
+ state_group = self.database_engine.get_next_state_group_id(txn)
+
+ self.db.simple_insert_txn(
+ txn,
+ table="state_groups",
+ values={"id": state_group, "room_id": room_id, "event_id": event_id},
+ )
+
+ # We persist as a delta if we can, while also ensuring the chain
+ # of deltas isn't tooo long, as otherwise read performance degrades.
+ if prev_group:
+ is_in_db = self.db.simple_select_one_onecol_txn(
+ txn,
+ table="state_groups",
+ keyvalues={"id": prev_group},
+ retcol="id",
+ allow_none=True,
+ )
+ if not is_in_db:
+ raise Exception(
+ "Trying to persist state with unpersisted prev_group: %r"
+ % (prev_group,)
+ )
+
+ potential_hops = self._count_state_group_hops_txn(txn, prev_group)
+ if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ self.db.simple_insert_txn(
+ txn,
+ table="state_group_edges",
+ values={"state_group": state_group, "prev_state_group": prev_group},
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(delta_ids)
+ ],
+ )
+ else:
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": state_group,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(current_state_ids)
+ ],
+ )
+
+ # Prefill the state group caches with this group.
+ # It's fine to use the sequence like this as the state group map
+ # is immutable. (If the map wasn't immutable then this prefill could
+ # race with another update)
+
+ current_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] == EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_members_cache.update,
+ self._state_group_members_cache.sequence,
+ key=state_group,
+ value=dict(current_member_state_ids),
+ )
+
+ current_non_member_state_ids = {
+ s: ev
+ for (s, ev) in iteritems(current_state_ids)
+ if s[0] != EventTypes.Member
+ }
+ txn.call_after(
+ self._state_group_cache.update,
+ self._state_group_cache.sequence,
+ key=state_group,
+ value=dict(current_non_member_state_ids),
+ )
+
+ return state_group
+
+ return self.db.runInteraction("store_state_group", _store_state_group_txn)
+
+ def purge_unreferenced_state_groups(
+ self, room_id: str, state_groups_to_delete
+ ) -> defer.Deferred:
+ """Deletes no longer referenced state groups and de-deltas any state
+ groups that reference them.
+
+ Args:
+ room_id: The room the state groups belong to (must all be in the
+ same room).
+ state_groups_to_delete (Collection[int]): Set of all state groups
+ to delete.
+ """
+
+ return self.db.runInteraction(
+ "purge_unreferenced_state_groups",
+ self._purge_unreferenced_state_groups,
+ room_id,
+ state_groups_to_delete,
+ )
+
+ def _purge_unreferenced_state_groups(self, txn, room_id, state_groups_to_delete):
+ logger.info(
+ "[purge] found %i state groups to delete", len(state_groups_to_delete)
+ )
+
+ rows = self.db.simple_select_many_txn(
+ txn,
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ retcols=("state_group",),
+ )
+
+ remaining_state_groups = {
+ row["state_group"]
+ for row in rows
+ if row["state_group"] not in state_groups_to_delete
+ }
+
+ logger.info(
+ "[purge] de-delta-ing %i remaining state groups",
+ len(remaining_state_groups),
+ )
+
+ # Now we turn the state groups that reference to-be-deleted state
+ # groups to non delta versions.
+ for sg in remaining_state_groups:
+ logger.info("[purge] de-delta-ing remaining state group %s", sg)
+ curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
+ curr_state = curr_state[sg]
+
+ self.db.simple_delete_txn(
+ txn, table="state_groups_state", keyvalues={"state_group": sg}
+ )
+
+ self.db.simple_delete_txn(
+ txn, table="state_group_edges", keyvalues={"state_group": sg}
+ )
+
+ self.db.simple_insert_many_txn(
+ txn,
+ table="state_groups_state",
+ values=[
+ {
+ "state_group": sg,
+ "room_id": room_id,
+ "type": key[0],
+ "state_key": key[1],
+ "event_id": state_id,
+ }
+ for key, state_id in iteritems(curr_state)
+ ],
+ )
+
+ logger.info("[purge] removing redundant state groups")
+ txn.executemany(
+ "DELETE FROM state_groups_state WHERE state_group = ?",
+ ((sg,) for sg in state_groups_to_delete),
+ )
+ txn.executemany(
+ "DELETE FROM state_groups WHERE id = ?",
+ ((sg,) for sg in state_groups_to_delete),
+ )
+
+ @defer.inlineCallbacks
+ def get_previous_state_groups(self, state_groups):
+ """Fetch the previous groups of the given state groups.
+
+ Args:
+ state_groups (Iterable[int])
+
+ Returns:
+ Deferred[dict[int, int]]: mapping from state group to previous
+ state group.
+ """
+
+ rows = yield self.db.simple_select_many_batch(
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=state_groups,
+ keyvalues={},
+ retcols=("prev_state_group", "state_group"),
+ desc="get_previous_state_groups",
+ )
+
+ return {row["state_group"]: row["prev_state_group"] for row in rows}
+
+ def purge_room_state(self, room_id, state_groups_to_delete):
+ """Deletes all record of a room from state tables
+
+ Args:
+ room_id (str):
+ state_groups_to_delete (list[int]): State groups to delete
+ """
+
+ return self.db.runInteraction(
+ "purge_room_state",
+ self._purge_room_state_txn,
+ room_id,
+ state_groups_to_delete,
+ )
+
+ def _purge_room_state_txn(self, txn, room_id, state_groups_to_delete):
+ # first we have to delete the state groups states
+ logger.info("[purge] removing %s from state_groups_state", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_groups_state",
+ column="state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ # ... and the state group edges
+ logger.info("[purge] removing %s from state_group_edges", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_group_edges",
+ column="state_group",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
+
+ # ... and the state groups
+ logger.info("[purge] removing %s from state_groups", room_id)
+
+ self.db.simple_delete_many_txn(
+ txn,
+ table="state_groups",
+ column="id",
+ iterable=state_groups_to_delete,
+ keyvalues={},
+ )
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
new file mode 100644
index 0000000000..e61595336c
--- /dev/null
+++ b/synapse/storage/database.py
@@ -0,0 +1,1560 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import time
+from time import monotonic as monotonic_time
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
+
+from six import iteritems, iterkeys, itervalues
+from six.moves import intern, range
+
+from prometheus_client import Histogram
+
+from twisted.enterprise import adbapi
+from twisted.internet import defer
+
+from synapse.api.errors import StoreError
+from synapse.config.database import DatabaseConnectionConfig
+from synapse.logging.context import (
+ LoggingContext,
+ LoggingContextOrSentinel,
+ make_deferred_yieldable,
+)
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.background_updates import BackgroundUpdater
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Connection, Cursor
+from synapse.util.stringutils import exception_to_unicode
+
+logger = logging.getLogger(__name__)
+
+# python 3 does not have a maximum int value
+MAX_TXN_ID = 2 ** 63 - 1
+
+sql_logger = logging.getLogger("synapse.storage.SQL")
+transaction_logger = logging.getLogger("synapse.storage.txn")
+perf_logger = logging.getLogger("synapse.storage.TIME")
+
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
+
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
+
+
+# Unique indexes which have been added in background updates. Maps from table name
+# to the name of the background update which added the unique index to that table.
+#
+# This is used by the upsert logic to figure out which tables are safe to do a proper
+# UPSERT on: until the relevant background update has completed, we
+# have to emulate an upsert by locking the table.
+#
+UNIQUE_INDEX_BACKGROUND_UPDATES = {
+ "user_ips": "user_ips_device_unique_index",
+ "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
+ "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
+ "event_search": "event_search_event_id_idx",
+}
+
+
+def make_pool(
+ reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> adbapi.ConnectionPool:
+ """Get the connection pool for the database.
+ """
+
+ return adbapi.ConnectionPool(
+ db_config.config["name"],
+ cp_reactor=reactor,
+ cp_openfun=engine.on_new_connection,
+ **db_config.config.get("args", {})
+ )
+
+
+def make_conn(
+ db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> Connection:
+ """Make a new connection to the database and return it.
+
+ Returns:
+ Connection
+ """
+
+ db_params = {
+ k: v
+ for k, v in db_config.config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = engine.module.connect(**db_params)
+ engine.on_new_connection(db_conn)
+ return db_conn
+
+
+# The type of entry which goes on our after_callbacks and exception_callbacks lists.
+#
+# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
+# that mypy sees the type but the runtime python doesn't.
+_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+
+
+class LoggingTransaction:
+ """An object that almost-transparently proxies for the 'txn' object
+ passed to the constructor. Adds logging and metrics to the .execute()
+ method.
+
+ Args:
+ txn: The database transcation object to wrap.
+ name: The name of this transactions for logging.
+ database_engine
+ after_callbacks: A list that callbacks will be appended to
+ that have been added by `call_after` which should be run on
+ successful completion of the transaction. None indicates that no
+ callbacks should be allowed to be scheduled to run.
+ exception_callbacks: A list that callbacks will be appended
+ to that have been added by `call_on_exception` which should be run
+ if transaction ends with an error. None indicates that no callbacks
+ should be allowed to be scheduled to run.
+ """
+
+ __slots__ = [
+ "txn",
+ "name",
+ "database_engine",
+ "after_callbacks",
+ "exception_callbacks",
+ ]
+
+ def __init__(
+ self,
+ txn: Cursor,
+ name: str,
+ database_engine: BaseDatabaseEngine,
+ after_callbacks: Optional[List[_CallbackListEntry]] = None,
+ exception_callbacks: Optional[List[_CallbackListEntry]] = None,
+ ):
+ self.txn = txn
+ self.name = name
+ self.database_engine = database_engine
+ self.after_callbacks = after_callbacks
+ self.exception_callbacks = exception_callbacks
+
+ def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+ """Call the given callback on the main twisted thread after the
+ transaction has finished. Used to invalidate the caches on the
+ correct thread.
+ """
+ # if self.after_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.after_callbacks is not None
+ self.after_callbacks.append((callback, args, kwargs))
+
+ def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+ # if self.exception_callbacks is None, that means that whatever constructed the
+ # LoggingTransaction isn't expecting there to be any callbacks; assert that
+ # is not the case.
+ assert self.exception_callbacks is not None
+ self.exception_callbacks.append((callback, args, kwargs))
+
+ def fetchall(self) -> List[Tuple]:
+ return self.txn.fetchall()
+
+ def fetchone(self) -> Tuple:
+ return self.txn.fetchone()
+
+ def __iter__(self) -> Iterator[Tuple]:
+ return self.txn.__iter__()
+
+ @property
+ def rowcount(self) -> int:
+ return self.txn.rowcount
+
+ @property
+ def description(self) -> Any:
+ return self.txn.description
+
+ def execute_batch(self, sql, args):
+ if isinstance(self.database_engine, PostgresEngine):
+ from psycopg2.extras import execute_batch # type: ignore
+
+ self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
+ else:
+ for val in args:
+ self.execute(sql, val)
+
+ def execute(self, sql: str, *args: Any):
+ self._do_execute(self.txn.execute, sql, *args)
+
+ def executemany(self, sql: str, *args: Any):
+ self._do_execute(self.txn.executemany, sql, *args)
+
+ def _make_sql_one_line(self, sql):
+ "Strip newlines out of SQL so that the loggers in the DB are on one line"
+ return " ".join(l.strip() for l in sql.splitlines() if l.strip())
+
+ def _do_execute(self, func, sql, *args):
+ sql = self._make_sql_one_line(sql)
+
+ # TODO(paul): Maybe use 'info' and 'debug' for values?
+ sql_logger.debug("[SQL] {%s} %s", self.name, sql)
+
+ sql = self.database_engine.convert_param_style(sql)
+ if args:
+ try:
+ sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
+ except Exception:
+ # Don't let logging failures stop SQL from working
+ pass
+
+ start = time.time()
+
+ try:
+ return func(sql, *args)
+ except Exception as e:
+ logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+ raise
+ finally:
+ secs = time.time() - start
+ sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+ sql_query_timer.labels(sql.split()[0]).observe(secs)
+
+ def close(self):
+ self.txn.close()
+
+
+class PerformanceCounters(object):
+ def __init__(self):
+ self.current_counters = {}
+ self.previous_counters = {}
+
+ def update(self, key, duration_secs):
+ count, cum_time = self.current_counters.get(key, (0, 0))
+ count += 1
+ cum_time += duration_secs
+ self.current_counters[key] = (count, cum_time)
+
+ def interval(self, interval_duration_secs, limit=3):
+ counters = []
+ for name, (count, cum_time) in iteritems(self.current_counters):
+ prev_count, prev_time = self.previous_counters.get(name, (0, 0))
+ counters.append(
+ (
+ (cum_time - prev_time) / interval_duration_secs,
+ count - prev_count,
+ name,
+ )
+ )
+
+ self.previous_counters = dict(self.current_counters)
+
+ counters.sort(reverse=True)
+
+ top_n_counters = ", ".join(
+ "%s(%d): %.3f%%" % (name, count, 100 * ratio)
+ for ratio, count, name in counters[:limit]
+ )
+
+ return top_n_counters
+
+
+class Database(object):
+ """Wraps a single physical database and connection pool.
+
+ A single database may be used by multiple data stores.
+ """
+
+ _TXN_ID = 0
+
+ def __init__(
+ self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ ):
+ self.hs = hs
+ self._clock = hs.get_clock()
+ self._database_config = database_config
+ self._db_pool = make_pool(hs.get_reactor(), database_config, engine)
+
+ self.updates = BackgroundUpdater(hs, self)
+
+ self._previous_txn_total_time = 0.0
+ self._current_txn_total_time = 0.0
+ self._previous_loop_ts = 0.0
+
+ # TODO(paul): These can eventually be removed once the metrics code
+ # is running in mainline, and we have some nice monitoring frontends
+ # to watch it
+ self._txn_perf_counters = PerformanceCounters()
+
+ self.engine = engine
+
+ # A set of tables that are not safe to use native upserts in.
+ self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+
+ # We add the user_directory_search table to the blacklist on SQLite
+ # because the existing search table does not have an index, making it
+ # unsafe to use native upserts.
+ if isinstance(self.engine, Sqlite3Engine):
+ self._unsafe_to_upsert_tables.add("user_directory_search")
+
+ if self.engine.can_native_upsert:
+ # Check ASAP (and then later, every 1s) to see if we have finished
+ # background updates of tables that aren't safe to update.
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert,
+ )
+
+ def is_running(self):
+ """Is the database pool currently running
+ """
+ return self._db_pool.running
+
+ @defer.inlineCallbacks
+ def _check_safe_to_upsert(self):
+ """
+ Is it safe to use native UPSERT?
+
+ If there are background updates, we will need to wait, as they may be
+ the addition of indexes that set the UNIQUE constraint that we require.
+
+ If the background updates have not completed, wait 15 sec and check again.
+ """
+ updates = yield self.simple_select_list(
+ "background_updates",
+ keyvalues=None,
+ retcols=["update_name"],
+ desc="check_background_updates",
+ )
+ updates = [x["update_name"] for x in updates]
+
+ for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
+ if update_name not in updates:
+ logger.debug("Now safe to upsert in %s", table)
+ self._unsafe_to_upsert_tables.discard(table)
+
+ # If there's any updates still running, reschedule to run.
+ if updates:
+ self._clock.call_later(
+ 15.0,
+ run_as_background_process,
+ "upsert_safety_check",
+ self._check_safe_to_upsert,
+ )
+
+ def start_profiling(self):
+ self._previous_loop_ts = monotonic_time()
+
+ def loop():
+ curr = self._current_txn_total_time
+ prev = self._previous_txn_total_time
+ self._previous_txn_total_time = curr
+
+ time_now = monotonic_time()
+ time_then = self._previous_loop_ts
+ self._previous_loop_ts = time_now
+
+ duration = time_now - time_then
+ ratio = (curr - prev) / duration
+
+ top_three_counters = self._txn_perf_counters.interval(duration, limit=3)
+
+ perf_logger.debug(
+ "Total database time: %.3f%% {%s}", ratio * 100, top_three_counters
+ )
+
+ self._clock.looping_call(loop, 10000)
+
+ def new_transaction(
+ self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
+ ):
+ start = monotonic_time()
+ txn_id = self._TXN_ID
+
+ # We don't really need these to be unique, so lets stop it from
+ # growing really large.
+ self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
+
+ name = "%s-%x" % (desc, txn_id)
+
+ transaction_logger.debug("[TXN START] {%s}", name)
+
+ try:
+ i = 0
+ N = 5
+ while True:
+ cursor = LoggingTransaction(
+ conn.cursor(),
+ name,
+ self.engine,
+ after_callbacks,
+ exception_callbacks,
+ )
+ try:
+ r = func(cursor, *args, **kwargs)
+ conn.commit()
+ return r
+ except self.engine.module.OperationalError as e:
+ # This can happen if the database disappears mid
+ # transaction.
+ logger.warning(
+ "[TXN OPERROR] {%s} %s %d/%d",
+ name,
+ exception_to_unicode(e),
+ i,
+ N,
+ )
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.engine.module.Error as e1:
+ logger.warning(
+ "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
+ )
+ continue
+ raise
+ except self.engine.module.DatabaseError as e:
+ if self.engine.is_deadlock(e):
+ logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+ if i < N:
+ i += 1
+ try:
+ conn.rollback()
+ except self.engine.module.Error as e1:
+ logger.warning(
+ "[TXN EROLL] {%s} %s",
+ name,
+ exception_to_unicode(e1),
+ )
+ continue
+ raise
+ finally:
+ # we're either about to retry with a new cursor, or we're about to
+ # release the connection. Once we release the connection, it could
+ # get used for another query, which might do a conn.rollback().
+ #
+ # In the latter case, even though that probably wouldn't affect the
+ # results of this transaction, python's sqlite will reset all
+ # statements on the connection [1], which will make our cursor
+ # invalid [2].
+ #
+ # In any case, continuing to read rows after commit()ing seems
+ # dubious from the PoV of ACID transactional semantics
+ # (sqlite explicitly says that once you commit, you may see rows
+ # from subsequent updates.)
+ #
+ # In psycopg2, cursors are essentially a client-side fabrication -
+ # all the data is transferred to the client side when the statement
+ # finishes executing - so in theory we could go on streaming results
+ # from the cursor, but attempting to do so would make us
+ # incompatible with sqlite, so let's make sure we're not doing that
+ # by closing the cursor.
+ #
+ # (*named* cursors in psycopg2 are different and are proper server-
+ # side things, but (a) we don't use them and (b) they are implicitly
+ # closed by ending the transaction anyway.)
+ #
+ # In short, if we haven't finished with the cursor yet, that's a
+ # problem waiting to bite us.
+ #
+ # TL;DR: we're done with the cursor, so we can close it.
+ #
+ # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
+ # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
+ cursor.close()
+ except Exception as e:
+ logger.debug("[TXN FAIL] {%s} %s", name, e)
+ raise
+ finally:
+ end = monotonic_time()
+ duration = end - start
+
+ LoggingContext.current_context().add_database_transaction(duration)
+
+ transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
+
+ self._current_txn_total_time += duration
+ self._txn_perf_counters.update(desc, duration)
+ sql_txn_timer.labels(desc).observe(duration)
+
+ @defer.inlineCallbacks
+ def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+ """Starts a transaction on the database and runs a given function
+
+ Arguments:
+ desc: description of the transaction, for logging and metrics
+ func: callback function, which will be called with a
+ database transaction (twisted.enterprise.adbapi.Transaction) as
+ its first argument, followed by `args` and `kwargs`.
+
+ args: positional args to pass to `func`
+ kwargs: named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ after_callbacks = [] # type: List[_CallbackListEntry]
+ exception_callbacks = [] # type: List[_CallbackListEntry]
+
+ if LoggingContext.current_context() == LoggingContext.sentinel:
+ logger.warning("Starting db txn '%s' from sentinel context", desc)
+
+ try:
+ result = yield self.runWithConnection(
+ self.new_transaction,
+ desc,
+ after_callbacks,
+ exception_callbacks,
+ func,
+ *args,
+ **kwargs
+ )
+
+ for after_callback, after_args, after_kwargs in after_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ except: # noqa: E722, as we reraise the exception this is fine.
+ for after_callback, after_args, after_kwargs in exception_callbacks:
+ after_callback(*after_args, **after_kwargs)
+ raise
+
+ return result
+
+ @defer.inlineCallbacks
+ def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+ """Wraps the .runWithConnection() method on the underlying db_pool.
+
+ Arguments:
+ func: callback function, which will be called with a
+ database connection (twisted.enterprise.adbapi.Connection) as
+ its first argument, followed by `args` and `kwargs`.
+ args: positional args to pass to `func`
+ kwargs: named args to pass to `func`
+
+ Returns:
+ Deferred: The result of func
+ """
+ parent_context = (
+ LoggingContext.current_context()
+ ) # type: Optional[LoggingContextOrSentinel]
+ if parent_context == LoggingContext.sentinel:
+ logger.warning(
+ "Starting db connection from sentinel context: metrics will be lost"
+ )
+ parent_context = None
+
+ start_time = monotonic_time()
+
+ def inner_func(conn, *args, **kwargs):
+ with LoggingContext("runWithConnection", parent_context) as context:
+ sched_duration_sec = monotonic_time() - start_time
+ sql_scheduling_timer.observe(sched_duration_sec)
+ context.add_database_scheduled(sched_duration_sec)
+
+ if self.engine.is_connection_closed(conn):
+ logger.debug("Reconnecting closed database connection")
+ conn.reconnect()
+
+ return func(conn, *args, **kwargs)
+
+ result = yield make_deferred_yieldable(
+ self._db_pool.runWithConnection(inner_func, *args, **kwargs)
+ )
+
+ return result
+
+ @staticmethod
+ def cursor_to_dict(cursor):
+ """Converts a SQL cursor into an list of dicts.
+
+ Args:
+ cursor : The DBAPI cursor which has executed a query.
+ Returns:
+ A list of dicts where the key is the column header.
+ """
+ col_headers = [intern(str(column[0])) for column in cursor.description]
+ results = [dict(zip(col_headers, row)) for row in cursor]
+ return results
+
+ def execute(self, desc, decoder, query, *args):
+ """Runs a single query for a result set.
+
+ Args:
+ decoder - The function which can resolve the cursor results to
+ something meaningful.
+ query - The query string to execute
+ *args - Query args.
+ Returns:
+ The result of decoder(results)
+ """
+
+ def interaction(txn):
+ txn.execute(query, args)
+ if decoder:
+ return decoder(txn)
+ else:
+ return txn.fetchall()
+
+ return self.runInteraction(desc, interaction)
+
+ # "Simple" SQL API methods that operate on a single table with no JOINs,
+ # no complex WHERE clauses, just a dict of values for columns.
+
+ @defer.inlineCallbacks
+ def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+ """Executes an INSERT query on the named table.
+
+ Args:
+ table : string giving the table name
+ values : dict of new column names and values for them
+ or_ignore : bool stating whether an exception should be raised
+ when a conflicting row already exists. If True, False will be
+ returned by the function instead
+ desc : string giving a description of the transaction
+
+ Returns:
+ bool: Whether the row was inserted or not. Only useful when
+ `or_ignore` is True
+ """
+ try:
+ yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+ except self.engine.module.IntegrityError:
+ # We have to do or_ignore flag at this layer, since we can't reuse
+ # a cursor after we receive an error from the db.
+ if not or_ignore:
+ raise
+ return False
+ return True
+
+ @staticmethod
+ def simple_insert_txn(txn, table, values):
+ keys, vals = zip(*values.items())
+
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in keys),
+ ", ".join("?" for _ in keys),
+ )
+
+ txn.execute(sql, vals)
+
+ def simple_insert_many(self, table, values, desc):
+ return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+
+ @staticmethod
+ def simple_insert_many_txn(txn, table, values):
+ if not values:
+ return
+
+ # This is a *slight* abomination to get a list of tuples of key names
+ # and a list of tuples of value names.
+ #
+ # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
+ # => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
+ #
+ # The sort is to ensure that we don't rely on dictionary iteration
+ # order.
+ keys, vals = zip(
+ *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
+ )
+
+ for k in keys:
+ if k != keys[0]:
+ raise RuntimeError("All items must have the same keys")
+
+ sql = "INSERT INTO %s (%s) VALUES(%s)" % (
+ table,
+ ", ".join(k for k in keys[0]),
+ ", ".join("?" for _ in keys[0]),
+ )
+
+ txn.executemany(sql, vals)
+
+ @defer.inlineCallbacks
+ def simple_upsert(
+ self,
+ table,
+ keyvalues,
+ values,
+ insertion_values={},
+ desc="simple_upsert",
+ lock=True,
+ ):
+ """
+
+ `lock` should generally be set to True (the default), but can be set
+ to False if either of the following are true:
+
+ * there is a UNIQUE INDEX on the key columns. In this case a conflict
+ will cause an IntegrityError in which case this function will retry
+ the update.
+
+ * we somehow know that we are the only thread which will be updating
+ this table.
+
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key columns and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ Deferred(None or bool): Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ attempts = 0
+ while True:
+ try:
+ result = yield self.runInteraction(
+ desc,
+ self.simple_upsert_txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values,
+ lock=lock,
+ )
+ return result
+ except self.engine.module.IntegrityError as e:
+ attempts += 1
+ if attempts >= 5:
+ # don't retry forever, because things other than races
+ # can cause IntegrityErrors
+ raise
+
+ # presumably we raced with another transaction: let's retry.
+ logger.warning(
+ "IntegrityError when upserting into %s; retrying: %s", table, e
+ )
+
+ def simple_upsert_txn(
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
+ ):
+ """
+ Pick the UPSERT method which works best on the platform. Either the
+ native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
+
+ Args:
+ txn: The transaction to use.
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key tables and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ None or bool: Native upserts always return None. Emulated
+ upserts return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+ return self.simple_upsert_txn_native_upsert(
+ txn, table, keyvalues, values, insertion_values=insertion_values
+ )
+ else:
+ return self.simple_upsert_txn_emulated(
+ txn,
+ table,
+ keyvalues,
+ values,
+ insertion_values=insertion_values,
+ lock=lock,
+ )
+
+ def simple_upsert_txn_emulated(
+ self, txn, table, keyvalues, values, insertion_values={}, lock=True
+ ):
+ """
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key tables and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ lock (bool): True to lock the table when doing the upsert.
+ Returns:
+ bool: Return True if a new entry was created, False if an existing
+ one was updated.
+ """
+ # We need to lock the table :(, unless we're *really* careful
+ if lock:
+ self.engine.lock_table(txn, table)
+
+ def _getwhere(key):
+ # If the value we're passing in is None (aka NULL), we need to use
+ # IS, not =, as NULL = NULL equals NULL (False).
+ if keyvalues[key] is None:
+ return "%s IS ?" % (key,)
+ else:
+ return "%s = ?" % (key,)
+
+ if not values:
+ # If `values` is empty, then all of the values we care about are in
+ # the unique key, so there is nothing to UPDATE. We can just do a
+ # SELECT instead to see if it exists.
+ sql = "SELECT 1 FROM %s WHERE %s" % (
+ table,
+ " AND ".join(_getwhere(k) for k in keyvalues),
+ )
+ sqlargs = list(keyvalues.values())
+ txn.execute(sql, sqlargs)
+ if txn.fetchall():
+ # We have an existing record.
+ return False
+ else:
+ # First try to update.
+ sql = "UPDATE %s SET %s WHERE %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in values),
+ " AND ".join(_getwhere(k) for k in keyvalues),
+ )
+ sqlargs = list(values.values()) + list(keyvalues.values())
+
+ txn.execute(sql, sqlargs)
+ if txn.rowcount > 0:
+ # successfully updated at least one row.
+ return False
+
+ # We didn't find any existing rows, so insert a new one
+ allvalues = {} # type: Dict[str, Any]
+ allvalues.update(keyvalues)
+ allvalues.update(values)
+ allvalues.update(insertion_values)
+
+ sql = "INSERT INTO %s (%s) VALUES (%s)" % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues),
+ )
+ txn.execute(sql, list(allvalues.values()))
+ # successfully inserted
+ return True
+
+ def simple_upsert_txn_native_upsert(
+ self, txn, table, keyvalues, values, insertion_values={}
+ ):
+ """
+ Use the native UPSERT functionality in recent PostgreSQL versions.
+
+ Args:
+ table (str): The table to upsert into
+ keyvalues (dict): The unique key tables and their new values
+ values (dict): The nonunique columns and their new values
+ insertion_values (dict): additional key/values to use only when
+ inserting
+ Returns:
+ None
+ """
+ allvalues = {} # type: Dict[str, Any]
+ allvalues.update(keyvalues)
+ allvalues.update(insertion_values)
+
+ if not values:
+ latter = "NOTHING"
+ else:
+ allvalues.update(values)
+ latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
+
+ sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
+ table,
+ ", ".join(k for k in allvalues),
+ ", ".join("?" for _ in allvalues),
+ ", ".join(k for k in keyvalues),
+ latter,
+ )
+ txn.execute(sql, list(allvalues.values()))
+
+ def simple_upsert_many_txn(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
+ return self.simple_upsert_many_txn_native_upsert(
+ txn, table, key_names, key_values, value_names, value_values
+ )
+ else:
+ return self.simple_upsert_many_txn_emulated(
+ txn, table, key_names, key_values, value_names, value_values
+ )
+
+ def simple_upsert_many_txn_emulated(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times, but without native UPSERT support or batching.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ # No value columns, therefore make a blank list so that the following
+ # zip() works correctly.
+ if not value_names:
+ value_values = [() for x in range(len(key_values))]
+
+ for keyv, valv in zip(key_values, value_values):
+ _keys = {x: y for x, y in zip(key_names, keyv)}
+ _vals = {x: y for x, y in zip(value_names, valv)}
+
+ self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
+
+ def simple_upsert_many_txn_native_upsert(
+ self, txn, table, key_names, key_values, value_names, value_values
+ ):
+ """
+ Upsert, many times, using batching where possible.
+
+ Args:
+ table (str): The table to upsert into
+ key_names (list[str]): The key column names.
+ key_values (list[list]): A list of each row's key column values.
+ value_names (list[str]): The value column names. If empty, no
+ values will be used, even if value_values is provided.
+ value_values (list[list]): A list of each row's value column values.
+ Returns:
+ None
+ """
+ allnames = [] # type: List[str]
+ allnames.extend(key_names)
+ allnames.extend(value_names)
+
+ if not value_names:
+ # No value columns, therefore make a blank list so that the
+ # following zip() works correctly.
+ latter = "NOTHING"
+ value_values = [() for x in range(len(key_values))]
+ else:
+ latter = "UPDATE SET " + ", ".join(
+ k + "=EXCLUDED." + k for k in value_names
+ )
+
+ sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
+ table,
+ ", ".join(k for k in allnames),
+ ", ".join("?" for _ in allnames),
+ ", ".join(key_names),
+ latter,
+ )
+
+ args = []
+
+ for x, y in zip(key_values, value_values):
+ args.append(tuple(x) + tuple(y))
+
+ return txn.execute_batch(sql, args)
+
+ def simple_select_one(
+ self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
+ ):
+ """Executes a SELECT query on the named table, which is expected to
+ return a single row, returning multiple columns from it.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ retcols : list of strings giving the names of the columns to return
+
+ allow_none : If true, return None instead of failing if the SELECT
+ statement returns no rows
+ """
+ return self.runInteraction(
+ desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
+ )
+
+ def simple_select_one_onecol(
+ self,
+ table,
+ keyvalues,
+ retcol,
+ allow_none=False,
+ desc="simple_select_one_onecol",
+ ):
+ """Executes a SELECT query on the named table, which is expected to
+ return a single row, returning a single column from it.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ retcol : string giving the name of the column to return
+ """
+ return self.runInteraction(
+ desc,
+ self.simple_select_one_onecol_txn,
+ table,
+ keyvalues,
+ retcol,
+ allow_none=allow_none,
+ )
+
+ @classmethod
+ def simple_select_one_onecol_txn(
+ cls, txn, table, keyvalues, retcol, allow_none=False
+ ):
+ ret = cls.simple_select_onecol_txn(
+ txn, table=table, keyvalues=keyvalues, retcol=retcol
+ )
+
+ if ret:
+ return ret[0]
+ else:
+ if allow_none:
+ return None
+ else:
+ raise StoreError(404, "No row found")
+
+ @staticmethod
+ def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
+
+ if keyvalues:
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ txn.execute(sql, list(keyvalues.values()))
+ else:
+ txn.execute(sql)
+
+ return [r[0] for r in txn]
+
+ def simple_select_onecol(
+ self, table, keyvalues, retcol, desc="simple_select_onecol"
+ ):
+ """Executes a SELECT query on the named table, which returns a list
+ comprising of the values of the named column from the selected rows.
+
+ Args:
+ table (str): table name
+ keyvalues (dict|None): column names and values to select the rows with
+ retcol (str): column whos value we wish to retrieve.
+
+ Returns:
+ Deferred: Results in a list
+ """
+ return self.runInteraction(
+ desc, self.simple_select_onecol_txn, table, keyvalues, retcol
+ )
+
+ def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ table (str): the table name
+ keyvalues (dict[str, Any] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.runInteraction(
+ desc, self.simple_select_list_txn, table, keyvalues, retcols
+ )
+
+ @classmethod
+ def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ retcols (iterable[str]): the names of the columns to return
+ """
+ if keyvalues:
+ sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+ txn.execute(sql, list(keyvalues.values()))
+ else:
+ sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
+ txn.execute(sql)
+
+ return cls.cursor_to_dict(txn)
+
+ @defer.inlineCallbacks
+ def simple_select_many_batch(
+ self,
+ table,
+ column,
+ iterable,
+ retcols,
+ keyvalues={},
+ desc="simple_select_many_batch",
+ batch_size=100,
+ ):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Filters rows by if value of `column` is in `iterable`.
+
+ Args:
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ results = [] # type: List[Dict[str, Any]]
+
+ if not iterable:
+ return results
+
+ # iterables can not be sliced, so convert it to a list first
+ it_list = list(iterable)
+
+ chunks = [
+ it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
+ ]
+ for chunk in chunks:
+ rows = yield self.runInteraction(
+ desc,
+ self.simple_select_many_txn,
+ table,
+ column,
+ chunk,
+ keyvalues,
+ retcols,
+ )
+
+ results.extend(rows)
+
+ return results
+
+ @classmethod
+ def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Filters rows by if value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+ retcols : list of strings giving the names of the columns to return
+ """
+ if not iterable:
+ return []
+
+ clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+ clauses = [clause]
+
+ for key, value in iteritems(keyvalues):
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join(clauses),
+ )
+
+ txn.execute(sql, values)
+ return cls.cursor_to_dict(txn)
+
+ def simple_update(self, table, keyvalues, updatevalues, desc):
+ return self.runInteraction(
+ desc, self.simple_update_txn, table, keyvalues, updatevalues
+ )
+
+ @staticmethod
+ def simple_update_txn(txn, table, keyvalues, updatevalues):
+ if keyvalues:
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ else:
+ where = ""
+
+ update_sql = "UPDATE %s SET %s %s" % (
+ table,
+ ", ".join("%s = ?" % (k,) for k in updatevalues),
+ where,
+ )
+
+ txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
+
+ return txn.rowcount
+
+ def simple_update_one(
+ self, table, keyvalues, updatevalues, desc="simple_update_one"
+ ):
+ """Executes an UPDATE query on the named table, setting new values for
+ columns in a row matching the key values.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ updatevalues : dict giving column names and values to update
+ retcols : optional list of column names to return
+
+ If present, retcols gives a list of column names on which to perform
+ a SELECT statement *before* performing the UPDATE statement. The values
+ of these will be returned in a dict.
+
+ These are performed within the same transaction, allowing an atomic
+ get-and-set. This can be used to implement compare-and-set by putting
+ the update column in the 'keyvalues' dict as well.
+ """
+ return self.runInteraction(
+ desc, self.simple_update_one_txn, table, keyvalues, updatevalues
+ )
+
+ @classmethod
+ def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
+
+ if rowcount == 0:
+ raise StoreError(404, "No row found (%s)" % (table,))
+ if rowcount > 1:
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+ @staticmethod
+ def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+ select_sql = "SELECT %s FROM %s WHERE %s" % (
+ ", ".join(retcols),
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+
+ txn.execute(select_sql, list(keyvalues.values()))
+ row = txn.fetchone()
+
+ if not row:
+ if allow_none:
+ return None
+ raise StoreError(404, "No row found (%s)" % (table,))
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+ return dict(zip(retcols, row))
+
+ def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+ """Executes a DELETE query on the named table, expecting to delete a
+ single row.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+ return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+
+ @staticmethod
+ def simple_delete_one_txn(txn, table, keyvalues):
+ """Executes a DELETE query on the named table, expecting to delete a
+ single row.
+
+ Args:
+ table : string giving the table name
+ keyvalues : dict of column names and values to select the row with
+ """
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+
+ txn.execute(sql, list(keyvalues.values()))
+ if txn.rowcount == 0:
+ raise StoreError(404, "No row found (%s)" % (table,))
+ if txn.rowcount > 1:
+ raise StoreError(500, "More than one row matched (%s)" % (table,))
+
+ def simple_delete(self, table, keyvalues, desc):
+ return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+
+ @staticmethod
+ def simple_delete_txn(txn, table, keyvalues):
+ sql = "DELETE FROM %s WHERE %s" % (
+ table,
+ " AND ".join("%s = ?" % (k,) for k in keyvalues),
+ )
+
+ txn.execute(sql, list(keyvalues.values()))
+ return txn.rowcount
+
+ def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+ return self.runInteraction(
+ desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
+ )
+
+ @staticmethod
+ def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ """Executes a DELETE query on the named table.
+
+ Filters rows by if value of `column` is in `iterable`.
+
+ Args:
+ txn : Transaction object
+ table : string giving the table name
+ column : column name to test for inclusion against `iterable`
+ iterable : list
+ keyvalues : dict of column names and values to select the rows with
+
+ Returns:
+ int: Number rows deleted
+ """
+ if not iterable:
+ return 0
+
+ sql = "DELETE FROM %s" % table
+
+ clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
+ clauses = [clause]
+
+ for key, value in iteritems(keyvalues):
+ clauses.append("%s = ?" % (key,))
+ values.append(value)
+
+ if clauses:
+ sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
+ txn.execute(sql, values)
+
+ return txn.rowcount
+
+ def get_cache_dict(
+ self, db_conn, table, entity_column, stream_column, max_value, limit=100000
+ ):
+ # Fetch a mapping of room_id -> max stream position for "recent" rooms.
+ # It doesn't really matter how many we get, the StreamChangeCache will
+ # do the right thing to ensure it respects the max size of cache.
+ sql = (
+ "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
+ " WHERE %(stream)s > ? - %(limit)s"
+ " GROUP BY %(entity)s"
+ ) % {
+ "table": table,
+ "entity": entity_column,
+ "stream": stream_column,
+ "limit": limit,
+ }
+
+ sql = self.engine.convert_param_style(sql)
+
+ txn = db_conn.cursor()
+ txn.execute(sql, (int(max_value),))
+
+ cache = {row[0]: int(row[1]) for row in txn}
+
+ txn.close()
+
+ if cache:
+ min_val = min(itervalues(cache))
+ else:
+ min_val = max_value
+
+ return cache, min_val
+
+ def simple_select_list_paginate(
+ self,
+ table,
+ orderby,
+ start,
+ limit,
+ retcols,
+ filters=None,
+ keyvalues=None,
+ order_direction="ASC",
+ desc="simple_select_list_paginate",
+ ):
+ """
+ Executes a SELECT query on the named table with start and limit,
+ of row numbers, which may return zero or number of rows from start to limit,
+ returning the result as a list of dicts.
+
+ Args:
+ table (str): the table name
+ filters (dict[str, T] | None):
+ column names and values to filter the rows with, or None to not
+ apply a WHERE ? LIKE ? clause.
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ orderby (str): Column to order the results by.
+ start (int): Index to begin the query at.
+ limit (int): Number of results to return.
+ retcols (iterable[str]): the names of the columns to return
+ order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ return self.runInteraction(
+ desc,
+ self.simple_select_list_paginate_txn,
+ table,
+ orderby,
+ start,
+ limit,
+ retcols,
+ filters=filters,
+ keyvalues=keyvalues,
+ order_direction=order_direction,
+ )
+
+ @classmethod
+ def simple_select_list_paginate_txn(
+ cls,
+ txn,
+ table,
+ orderby,
+ start,
+ limit,
+ retcols,
+ filters=None,
+ keyvalues=None,
+ order_direction="ASC",
+ ):
+ """
+ Executes a SELECT query on the named table with start and limit,
+ of row numbers, which may return zero or number of rows from start to limit,
+ returning the result as a list of dicts.
+
+ Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
+ select attributes with exact matches. All constraints are joined together
+ using 'AND'.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+ orderby (str): Column to order the results by.
+ start (int): Index to begin the query at.
+ limit (int): Number of results to return.
+ retcols (iterable[str]): the names of the columns to return
+ filters (dict[str, T] | None):
+ column names and values to filter the rows with, or None to not
+ apply a WHERE ? LIKE ? clause.
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
+ order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]]
+ """
+ if order_direction not in ["ASC", "DESC"]:
+ raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+
+ where_clause = "WHERE " if filters or keyvalues else ""
+ arg_list = [] # type: List[Any]
+ if filters:
+ where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
+ arg_list += list(filters.values())
+ where_clause += " AND " if filters and keyvalues else ""
+ if keyvalues:
+ where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ arg_list += list(keyvalues.values())
+
+ sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
+ ", ".join(retcols),
+ table,
+ where_clause,
+ orderby,
+ order_direction,
+ )
+ txn.execute(sql, arg_list + [limit, start])
+
+ return cls.cursor_to_dict(txn)
+
+ def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ table (str): the table name
+ term (str | None):
+ term for searching the table matched to a column.
+ col (str): column to query term should be matched to
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]] or None
+ """
+
+ return self.runInteraction(
+ desc, self.simple_search_list_txn, table, term, col, retcols
+ )
+
+ @classmethod
+ def simple_search_list_txn(cls, txn, table, term, col, retcols):
+ """Executes a SELECT query on the named table, which may return zero or
+ more rows, returning the result as a list of dicts.
+
+ Args:
+ txn : Transaction object
+ table (str): the table name
+ term (str | None):
+ term for searching the table matched to a column.
+ col (str): column to query term should be matched to
+ retcols (iterable[str]): the names of the columns to return
+ Returns:
+ defer.Deferred: resolves to list[dict[str, Any]] or None
+ """
+ if term:
+ sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
+ termvalues = ["%%" + term + "%%"]
+ txn.execute(sql, termvalues)
+ else:
+ return 0
+
+ return cls.cursor_to_dict(txn)
+
+
+def make_in_list_sql_clause(
+ database_engine, column: str, iterable: Iterable
+) -> Tuple[str, list]:
+ """Returns an SQL clause that checks the given column is in the iterable.
+
+ On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
+ it expands to `column = ANY(?)`. While both DBs support the `IN` form,
+ using the `ANY` form on postgres means that it views queries with
+ different length iterables as the same, helping the query stats.
+
+ Args:
+ database_engine
+ column: Name of the column
+ iterable: The values to check the column against.
+
+ Returns:
+ A tuple of SQL query and the args
+ """
+
+ if database_engine.supports_using_any_list:
+ # This should hopefully be faster, but also makes postgres query
+ # stats easier to understand.
+ return "%s = ANY(?)" % (column,), [list(iterable)]
+ else:
+ return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
deleted file mode 100644
index 2fabb9e2cb..0000000000
--- a/synapse/storage/end_to_end_keys.py
+++ /dev/null
@@ -1,283 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-from six import iteritems
-
-from canonicaljson import encode_canonical_json
-
-from twisted.internet import defer
-
-from synapse.util.caches.descriptors import cached
-
-from ._base import SQLBaseStore, db_to_json
-
-
-class EndToEndKeyWorkerStore(SQLBaseStore):
- @defer.inlineCallbacks
- def get_e2e_device_keys(
- self, query_list, include_all_devices=False, include_deleted_devices=False
- ):
- """Fetch a list of device keys.
- Args:
- query_list(list): List of pairs of user_ids and device_ids.
- include_all_devices (bool): whether to include entries for devices
- that don't have device keys
- include_deleted_devices (bool): whether to include null entries for
- devices which no longer exist (but were in the query_list).
- This option only takes effect if include_all_devices is true.
- Returns:
- Dict mapping from user-id to dict mapping from device_id to
- dict containing "key_json", "device_display_name".
- """
- if not query_list:
- defer.returnValue({})
-
- results = yield self.runInteraction(
- "get_e2e_device_keys",
- self._get_e2e_device_keys_txn,
- query_list,
- include_all_devices,
- include_deleted_devices,
- )
-
- for user_id, device_keys in iteritems(results):
- for device_id, device_info in iteritems(device_keys):
- device_info["keys"] = db_to_json(device_info.pop("key_json"))
-
- defer.returnValue(results)
-
- def _get_e2e_device_keys_txn(
- self, txn, query_list, include_all_devices=False, include_deleted_devices=False
- ):
- query_clauses = []
- query_params = []
-
- if include_all_devices is False:
- include_deleted_devices = False
-
- if include_deleted_devices:
- deleted_devices = set(query_list)
-
- for (user_id, device_id) in query_list:
- query_clause = "user_id = ?"
- query_params.append(user_id)
-
- if device_id is not None:
- query_clause += " AND device_id = ?"
- query_params.append(device_id)
-
- query_clauses.append(query_clause)
-
- sql = (
- "SELECT user_id, device_id, "
- " d.display_name AS device_display_name, "
- " k.key_json"
- " FROM devices d"
- " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
- " WHERE %s"
- ) % (
- "LEFT" if include_all_devices else "INNER",
- " OR ".join("(" + q + ")" for q in query_clauses),
- )
-
- txn.execute(sql, query_params)
- rows = self.cursor_to_dict(txn)
-
- result = {}
- for row in rows:
- if include_deleted_devices:
- deleted_devices.remove((row["user_id"], row["device_id"]))
- result.setdefault(row["user_id"], {})[row["device_id"]] = row
-
- if include_deleted_devices:
- for user_id, device_id in deleted_devices:
- result.setdefault(user_id, {})[device_id] = None
-
- return result
-
- @defer.inlineCallbacks
- def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
- """Retrieve a number of one-time keys for a user
-
- Args:
- user_id(str): id of user to get keys for
- device_id(str): id of device to get keys for
- key_ids(list[str]): list of key ids (excluding algorithm) to
- retrieve
-
- Returns:
- deferred resolving to Dict[(str, str), str]: map from (algorithm,
- key_id) to json string for key
- """
-
- rows = yield self._simple_select_many_batch(
- table="e2e_one_time_keys_json",
- column="key_id",
- iterable=key_ids,
- retcols=("algorithm", "key_id", "key_json"),
- keyvalues={"user_id": user_id, "device_id": device_id},
- desc="add_e2e_one_time_keys_check",
- )
-
- defer.returnValue(
- {(row["algorithm"], row["key_id"]): row["key_json"] for row in rows}
- )
-
- @defer.inlineCallbacks
- def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
- """Insert some new one time keys for a device. Errors if any of the
- keys already exist.
-
- Args:
- user_id(str): id of user to get keys for
- device_id(str): id of device to get keys for
- time_now(long): insertion time to record (ms since epoch)
- new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
- (algorithm, key_id, key json)
- """
-
- def _add_e2e_one_time_keys(txn):
- # We are protected from race between lookup and insertion due to
- # a unique constraint. If there is a race of two calls to
- # `add_e2e_one_time_keys` then they'll conflict and we will only
- # insert one set.
- self._simple_insert_many_txn(
- txn,
- table="e2e_one_time_keys_json",
- values=[
- {
- "user_id": user_id,
- "device_id": device_id,
- "algorithm": algorithm,
- "key_id": key_id,
- "ts_added_ms": time_now,
- "key_json": json_bytes,
- }
- for algorithm, key_id, json_bytes in new_keys
- ],
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- yield self.runInteraction(
- "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
- )
-
- @cached(max_entries=10000)
- def count_e2e_one_time_keys(self, user_id, device_id):
- """ Count the number of one time keys the server has for a device
- Returns:
- Dict mapping from algorithm to number of keys for that algorithm.
- """
-
- def _count_e2e_one_time_keys(txn):
- sql = (
- "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
- " WHERE user_id = ? AND device_id = ?"
- " GROUP BY algorithm"
- )
- txn.execute(sql, (user_id, device_id))
- result = {}
- for algorithm, key_count in txn:
- result[algorithm] = key_count
- return result
-
- return self.runInteraction("count_e2e_one_time_keys", _count_e2e_one_time_keys)
-
-
-class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
- def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
- """Stores device keys for a device. Returns whether there was a change
- or the keys were already in the database.
- """
-
- def _set_e2e_device_keys_txn(txn):
- old_key_json = self._simple_select_one_onecol_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- retcol="key_json",
- allow_none=True,
- )
-
- # In py3 we need old_key_json to match new_key_json type. The DB
- # returns unicode while encode_canonical_json returns bytes.
- new_key_json = encode_canonical_json(device_keys).decode("utf-8")
-
- if old_key_json == new_key_json:
- return False
-
- self._simple_upsert_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- values={"ts_added_ms": time_now, "key_json": new_key_json},
- )
-
- return True
-
- return self.runInteraction("set_e2e_device_keys", _set_e2e_device_keys_txn)
-
- def claim_e2e_one_time_keys(self, query_list):
- """Take a list of one time keys out of the database"""
-
- def _claim_e2e_one_time_keys(txn):
- sql = (
- "SELECT key_id, key_json FROM e2e_one_time_keys_json"
- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
- " LIMIT 1"
- )
- result = {}
- delete = []
- for user_id, device_id, algorithm in query_list:
- user_result = result.setdefault(user_id, {})
- device_result = user_result.setdefault(device_id, {})
- txn.execute(sql, (user_id, device_id, algorithm))
- for key_id, key_json in txn:
- device_result[algorithm + ":" + key_id] = key_json
- delete.append((user_id, device_id, algorithm, key_id))
- sql = (
- "DELETE FROM e2e_one_time_keys_json"
- " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
- " AND key_id = ?"
- )
- for user_id, device_id, algorithm, key_id in delete:
- txn.execute(sql, (user_id, device_id, algorithm, key_id))
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
- return result
-
- return self.runInteraction("claim_e2e_one_time_keys", _claim_e2e_one_time_keys)
-
- def delete_e2e_keys_by_device(self, user_id, device_id):
- def delete_e2e_keys_by_device_txn(txn):
- self._simple_delete_txn(
- txn,
- table="e2e_device_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- )
- self._simple_delete_txn(
- txn,
- table="e2e_one_time_keys_json",
- keyvalues={"user_id": user_id, "device_id": device_id},
- )
- self._invalidate_cache_and_stream(
- txn, self.count_e2e_one_time_keys, (user_id, device_id)
- )
-
- return self.runInteraction(
- "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
- )
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9d2d519922..035f9ea6e9 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,29 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import importlib
import platform
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
-SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
-
-def create_engine(database_config):
+def create_engine(database_config) -> BaseDatabaseEngine:
name = database_config["name"]
- engine_class = SUPPORTED_MODULE.get(name, None)
- if engine_class:
+ if name == "sqlite3":
+ import sqlite3
+
+ return Sqlite3Engine(sqlite3, database_config)
+
+ if name == "psycopg2":
# pypy requires psycopg2cffi rather than psycopg2
- if name == "psycopg2" and platform.python_implementation() == "PyPy":
- name = "psycopg2cffi"
- module = importlib.import_module(name)
- return engine_class(module, database_config)
+ if platform.python_implementation() == "PyPy":
+ import psycopg2cffi as psycopg2 # type: ignore
+ else:
+ import psycopg2 # type: ignore
+
+ return PostgresEngine(psycopg2, database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,))
-__all__ = ["create_engine", "IncorrectDatabaseSetup"]
+__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ec5a4d198b..ab0bbe4bd3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -12,7 +12,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
+from typing import Generic, TypeVar
+
+from synapse.storage.types import Connection
class IncorrectDatabaseSetup(RuntimeError):
pass
+
+
+ConnectionType = TypeVar("ConnectionType", bound=Connection)
+
+
+class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+ def __init__(self, module, database_config: dict):
+ self.module = module
+
+ @property
+ @abc.abstractmethod
+ def single_threaded(self) -> bool:
+ ...
+
+ @property
+ @abc.abstractmethod
+ def can_native_upsert(self) -> bool:
+ """
+ Do we support native UPSERTs?
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def supports_tuple_comparison(self) -> bool:
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def supports_using_any_list(self) -> bool:
+ """
+ Do we support using `a = ANY(?)` and passing a list
+ """
+ ...
+
+ @abc.abstractmethod
+ def check_database(
+ self, db_conn: ConnectionType, allow_outdated_version: bool = False
+ ) -> None:
+ ...
+
+ @abc.abstractmethod
+ def check_new_database(self, txn) -> None:
+ """Gets called when setting up a brand new database. This allows us to
+ apply stricter checks on new databases versus existing database.
+ """
+ ...
+
+ @abc.abstractmethod
+ def convert_param_style(self, sql: str) -> str:
+ ...
+
+ @abc.abstractmethod
+ def on_new_connection(self, db_conn: ConnectionType) -> None:
+ ...
+
+ @abc.abstractmethod
+ def is_deadlock(self, error: Exception) -> bool:
+ ...
+
+ @abc.abstractmethod
+ def is_connection_closed(self, conn: ConnectionType) -> bool:
+ ...
+
+ @abc.abstractmethod
+ def lock_table(self, txn, table: str) -> None:
+ ...
+
+ @abc.abstractmethod
+ def get_next_state_group_id(self, txn) -> int:
+ """Returns an int that can be used as a new state_group ID
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def server_version(self) -> str:
+ """Gets a string giving the server version. For example: '3.22.0'
+ """
+ ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 1b97ee74e3..6c7d08a6f2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,38 +13,97 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import IncorrectDatabaseSetup
+import logging
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
-class PostgresEngine(object):
- single_threaded = False
+logger = logging.getLogger(__name__)
+
+class PostgresEngine(BaseDatabaseEngine):
def __init__(self, database_module, database_config):
- self.module = database_module
+ super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE)
- self.synchronous_commit = database_config.get("synchronous_commit", True)
- self._version = None # unknown as yet
- def check_database(self, txn):
- txn.execute("SHOW SERVER_ENCODING")
- rows = txn.fetchall()
- if rows and rows[0][0] != "UTF8":
- raise IncorrectDatabaseSetup(
- "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
- "See docs/postgres.rst for more information." % (rows[0][0],)
- )
+ # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
+ # actually want to use bytes than wrap it in `bytearray`.
+ def _disable_bytes_adapter(_):
+ raise Exception("Passing bytes to DB is disabled.")
- def convert_param_style(self, sql):
- return sql.replace("?", "%s")
+ self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
+ self.synchronous_commit = database_config.get("synchronous_commit", True)
+ self._version = None # unknown as yet
- def on_new_connection(self, db_conn):
+ @property
+ def single_threaded(self) -> bool:
+ return False
+ def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
+ # Are we on a supported PostgreSQL version?
+ if not allow_outdated_version and self._version < 90500:
+ raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
+
+ with db_conn.cursor() as txn:
+ txn.execute("SHOW SERVER_ENCODING")
+ rows = txn.fetchall()
+ if rows and rows[0][0] != "UTF8":
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
+ "See docs/postgres.md for more information." % (rows[0][0],)
+ )
+
+ txn.execute(
+ "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+ )
+ collation, ctype = txn.fetchone()
+ if collation != "C":
+ logger.warning(
+ "Database has incorrect collation of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information.",
+ collation,
+ )
+
+ if ctype != "C":
+ logger.warning(
+ "Database has incorrect ctype of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information.",
+ ctype,
+ )
+
+ def check_new_database(self, txn):
+ """Gets called when setting up a brand new database. This allows us to
+ apply stricter checks on new databases versus existing database.
+ """
+
+ txn.execute(
+ "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+ )
+ collation, ctype = txn.fetchone()
+
+ errors = []
+
+ if collation != "C":
+ errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
+
+ if ctype != "C":
+ errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (collation,))
+
+ if errors:
+ raise IncorrectDatabaseSetup(
+ "Database is incorrectly configured:\n\n%s\n\n"
+ "See docs/postgres.md for more information." % ("\n".join(errors))
+ )
+
+ def convert_param_style(self, sql):
+ return sql.replace("?", "%s")
+
+ def on_new_connection(self, db_conn):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
@@ -64,9 +123,22 @@ class PostgresEngine(object):
@property
def can_native_upsert(self):
"""
- Can we use native UPSERTs? This requires PostgreSQL 9.5+.
+ Can we use native UPSERTs?
+ """
+ return True
+
+ @property
+ def supports_tuple_comparison(self):
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+ """
+ return True
+
+ @property
+ def supports_using_any_list(self):
+ """Do we support using `a = ANY(?)` and passing a list
"""
- return self._version >= 90500
+ return True
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
@@ -95,8 +167,8 @@ class PostgresEngine(object):
Returns:
string
"""
- # note that this is a bit of a hack because it relies on on_new_connection
- # having been called at least once. Still, that should be a safe bet here.
+ # note that this is a bit of a hack because it relies on check_database
+ # having been called. Still, that should be a safe bet here.
numver = self._version
assert numver is not None
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 933bcf42c2..3bc2e8b986 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,18 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import struct
import threading
+import typing
-from synapse.storage.prepare_database import prepare_database
+from synapse.storage.engines import BaseDatabaseEngine
+if typing.TYPE_CHECKING:
+ import sqlite3 # noqa: F401
-class Sqlite3Engine(object):
- single_threaded = True
+class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def __init__(self, database_module, database_config):
- self.module = database_module
+ super().__init__(database_module, database_config)
+
+ database = database_config.get("args", {}).get("database")
+ self._is_in_memory = database in (None, ":memory:",)
# The current max state_group, or None if we haven't looked
# in the DB yet.
@@ -31,6 +35,10 @@ class Sqlite3Engine(object):
self._current_state_group_id_lock = threading.Lock()
@property
+ def single_threaded(self) -> bool:
+ return True
+
+ @property
def can_native_upsert(self):
"""
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
@@ -38,14 +46,44 @@ class Sqlite3Engine(object):
"""
return self.module.sqlite_version_info >= (3, 24, 0)
- def check_database(self, txn):
- pass
+ @property
+ def supports_tuple_comparison(self):
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
+ SQLite 3.15+.
+ """
+ return self.module.sqlite_version_info >= (3, 15, 0)
+
+ @property
+ def supports_using_any_list(self):
+ """Do we support using `a = ANY(?)` and passing a list
+ """
+ return False
+
+ def check_database(self, db_conn, allow_outdated_version: bool = False):
+ if not allow_outdated_version:
+ version = self.module.sqlite_version_info
+ if version < (3, 11, 0):
+ raise RuntimeError("Synapse requires sqlite 3.11 or above.")
+
+ def check_new_database(self, txn):
+ """Gets called when setting up a brand new database. This allows us to
+ apply stricter checks on new databases versus existing database.
+ """
def convert_param_style(self, sql):
return sql
def on_new_connection(self, db_conn):
- prepare_database(db_conn, self, config=None)
+ # We need to import here to avoid an import loop.
+ from synapse.storage.prepare_database import prepare_database
+
+ if self._is_in_memory:
+ # In memory databases need to be rebuilt each time. Ideally we'd
+ # reuse the same connection as we do when starting up, but that
+ # would involve using adbapi before we have started the reactor.
+ prepare_database(db_conn, self, config=None)
+
db_conn.create_function("rank", 1, _rank)
def is_deadlock(self, error):
@@ -85,7 +123,7 @@ class Sqlite3Engine(object):
def _parse_match_info(buf):
bufsize = len(buf)
- return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
+ return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
deleted file mode 100644
index 5dc49822b5..0000000000
--- a/synapse/storage/events_worker.py
+++ /dev/null
@@ -1,742 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import division
-
-import itertools
-import logging
-from collections import namedtuple
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
-from synapse.api.room_versions import EventFormatVersions
-from synapse.events import FrozenEvent, event_type_from_format_version # noqa: F401
-# these are only included to make the type annotations work
-from synapse.events.snapshot import EventContext # noqa: F401
-from synapse.events.utils import prune_event
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import get_domain_from_id
-from synapse.util.logcontext import (
- LoggingContext,
- PreserveLoggingContext,
- make_deferred_yieldable,
- run_in_background,
-)
-from synapse.util.metrics import Measure
-
-from ._base import SQLBaseStore
-
-logger = logging.getLogger(__name__)
-
-
-# These values are used in the `enqueus_event` and `_do_fetch` methods to
-# control how we batch/bulk fetch events from the database.
-# The values are plucked out of thing air to make initial sync run faster
-# on jki.re
-# TODO: Make these configurable.
-EVENT_QUEUE_THREADS = 3 # Max number of threads that will fetch events
-EVENT_QUEUE_ITERATIONS = 3 # No. times we block waiting for requests for events
-EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
-
-
-_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
-
-
-class EventsWorkerStore(SQLBaseStore):
- def get_received_ts(self, event_id):
- """Get received_ts (when it was persisted) for the event.
-
- Raises an exception for unknown events.
-
- Args:
- event_id (str)
-
- Returns:
- Deferred[int|None]: Timestamp in milliseconds, or None for events
- that were persisted before received_ts was implemented.
- """
- return self._simple_select_one_onecol(
- table="events",
- keyvalues={"event_id": event_id},
- retcol="received_ts",
- desc="get_received_ts",
- )
-
- def get_received_ts_by_stream_pos(self, stream_ordering):
- """Given a stream ordering get an approximate timestamp of when it
- happened.
-
- This is done by simply taking the received ts of the first event that
- has a stream ordering greater than or equal to the given stream pos.
- If none exists returns the current time, on the assumption that it must
- have happened recently.
-
- Args:
- stream_ordering (int)
-
- Returns:
- Deferred[int]
- """
-
- def _get_approximate_received_ts_txn(txn):
- sql = """
- SELECT received_ts FROM events
- WHERE stream_ordering >= ?
- LIMIT 1
- """
-
- txn.execute(sql, (stream_ordering,))
- row = txn.fetchone()
- if row and row[0]:
- ts = row[0]
- else:
- ts = self.clock.time_msec()
-
- return ts
-
- return self.runInteraction(
- "get_approximate_received_ts",
- _get_approximate_received_ts_txn,
- )
-
- @defer.inlineCallbacks
- def get_event(
- self,
- event_id,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
- allow_none=False,
- check_room_id=None,
- ):
- """Get an event from the database by event_id.
-
- Args:
- event_id (str): The event_id of the event to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
- allow_none (bool): If True, return None if no event found, if
- False throw a NotFoundError
- check_room_id (str|None): if not None, check the room of the found event.
- If there is a mismatch, behave as per allow_none.
-
- Returns:
- Deferred : A FrozenEvent.
- """
- events = yield self.get_events_as_list(
- [event_id],
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- event = events[0] if events else None
-
- if event is not None and check_room_id is not None:
- if event.room_id != check_room_id:
- event = None
-
- if event is None and not allow_none:
- raise NotFoundError("Could not find event %s" % (event_id,))
-
- defer.returnValue(event)
-
- @defer.inlineCallbacks
- def get_events(
- self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
- ):
- """Get events from the database
-
- Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
-
- Returns:
- Deferred : Dict from event_id to event.
- """
- events = yield self.get_events_as_list(
- event_ids,
- check_redacted=check_redacted,
- get_prev_content=get_prev_content,
- allow_rejected=allow_rejected,
- )
-
- defer.returnValue({e.event_id: e for e in events})
-
- @defer.inlineCallbacks
- def get_events_as_list(
- self,
- event_ids,
- check_redacted=True,
- get_prev_content=False,
- allow_rejected=False,
- ):
- """Get events from the database and return in a list in the same order
- as given by `event_ids` arg.
-
- Args:
- event_ids (list): The event_ids of the events to fetch
- check_redacted (bool): If True, check if event has been redacted
- and redact it.
- get_prev_content (bool): If True and event is a state event,
- include the previous states content in the unsigned field.
- allow_rejected (bool): If True return rejected events.
-
- Returns:
- Deferred[list[EventBase]]: List of events fetched from the database. The
- events are in the same order as `event_ids` arg.
-
- Note that the returned list may be smaller than the list of event
- IDs if not all events could be fetched.
- """
-
- if not event_ids:
- defer.returnValue([])
-
- event_id_list = event_ids
- event_ids = set(event_ids)
-
- event_entry_map = self._get_events_from_cache(
- event_ids, allow_rejected=allow_rejected
- )
-
- missing_events_ids = [e for e in event_ids if e not in event_entry_map]
-
- if missing_events_ids:
- log_ctx = LoggingContext.current_context()
- log_ctx.record_event_fetch(len(missing_events_ids))
-
- # Note that _enqueue_events is also responsible for turning db rows
- # into FrozenEvents (via _get_event_from_row), which involves seeing if
- # the events have been redacted, and if so pulling the redaction event out
- # of the database to check it.
- #
- # _enqueue_events is a bit of a rubbish name but naming is hard.
- missing_events = yield self._enqueue_events(
- missing_events_ids, allow_rejected=allow_rejected
- )
-
- event_entry_map.update(missing_events)
-
- events = []
- for event_id in event_id_list:
- entry = event_entry_map.get(event_id, None)
- if not entry:
- continue
-
- # Starting in room version v3, some redactions need to be rechecked if we
- # didn't have the redacted event at the time, so we recheck on read
- # instead.
- if not allow_rejected and entry.event.type == EventTypes.Redaction:
- orig_event_info = yield self._simple_select_one(
- table="events",
- keyvalues={"event_id": entry.event.redacts},
- retcols=["sender", "room_id", "type"],
- allow_none=True,
- )
-
- if not orig_event_info:
- # We don't have the event that is being redacted, so we
- # assume that the event isn't authorized for now. (If we
- # later receive the event, then we will always redact
- # it anyway, since we have this redaction)
- continue
-
- if orig_event_info["room_id"] != entry.event.room_id:
- # Don't process redactions if the redacted event doesn't belong to the
- # redaction's room.
- logger.info("Ignoring redation in another room.")
- continue
-
- if entry.event.internal_metadata.need_to_check_redaction():
- # XXX: we need to avoid calling get_event here.
- #
- # The problem is that we end up at this point when an event
- # which has been redacted is pulled out of the database by
- # _enqueue_events, because _enqueue_events needs to check
- # the redaction before it can cache the redacted event. So
- # obviously, calling get_event to get the redacted event out
- # of the database gives us an infinite loop.
- #
- # For now (quick hack to fix during 0.99 release cycle), we
- # just go and fetch the relevant row from the db, but it
- # would be nice to think about how we can cache this rather
- # than hit the db every time we access a redaction event.
- #
- # One thought on how to do this:
- # 1. split get_events_as_list up so that it is divided into
- # (a) get the rawish event from the db/cache, (b) do the
- # redaction/rejection filtering
- # 2. have _get_event_from_row just call the first half of
- # that
-
- expected_domain = get_domain_from_id(entry.event.sender)
- if (
- get_domain_from_id(orig_event_info["sender"]) == expected_domain
- ):
- # This redaction event is allowed. Mark as not needing a
- # recheck.
- entry.event.internal_metadata.recheck_redaction = False
-
- if allow_rejected or not entry.event.rejected_reason:
- if check_redacted and entry.redacted_event:
- event = entry.redacted_event
- else:
- event = entry.event
-
- events.append(event)
-
- if get_prev_content:
- if "replaces_state" in event.unsigned:
- prev = yield self.get_event(
- event.unsigned["replaces_state"],
- get_prev_content=False,
- allow_none=True,
- )
- if prev:
- event.unsigned = dict(event.unsigned)
- event.unsigned["prev_content"] = prev.content
- event.unsigned["prev_sender"] = prev.sender
-
- defer.returnValue(events)
-
- def _invalidate_get_event_cache(self, event_id):
- self._get_event_cache.invalidate((event_id,))
-
- def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
- """Fetch events from the caches
-
- Args:
- events (list(str)): list of event_ids to fetch
- allow_rejected (bool): Whether to teturn events that were rejected
- update_metrics (bool): Whether to update the cache hit ratio metrics
-
- Returns:
- dict of event_id -> _EventCacheEntry for each event_id in cache. If
- allow_rejected is `False` then there will still be an entry but it
- will be `None`
- """
- event_map = {}
-
- for event_id in events:
- ret = self._get_event_cache.get(
- (event_id,), None, update_metrics=update_metrics
- )
- if not ret:
- continue
-
- if allow_rejected or not ret.event.rejected_reason:
- event_map[event_id] = ret
- else:
- event_map[event_id] = None
-
- return event_map
-
- def _do_fetch(self, conn):
- """Takes a database connection and waits for requests for events from
- the _event_fetch_list queue.
- """
- i = 0
- while True:
- with self._event_fetch_lock:
- event_list = self._event_fetch_list
- self._event_fetch_list = []
-
- if not event_list:
- single_threaded = self.database_engine.single_threaded
- if single_threaded or i > EVENT_QUEUE_ITERATIONS:
- self._event_fetch_ongoing -= 1
- return
- else:
- self._event_fetch_lock.wait(EVENT_QUEUE_TIMEOUT_S)
- i += 1
- continue
- i = 0
-
- self._fetch_event_list(conn, event_list)
-
- def _fetch_event_list(self, conn, event_list):
- """Handle a load of requests from the _event_fetch_list queue
-
- Args:
- conn (twisted.enterprise.adbapi.Connection): database connection
-
- event_list (list[Tuple[list[str], Deferred]]):
- The fetch requests. Each entry consists of a list of event
- ids to be fetched, and a deferred to be completed once the
- events have been fetched.
-
- """
- with Measure(self._clock, "_fetch_event_list"):
- try:
- event_id_lists = list(zip(*event_list))[0]
- event_ids = [item for sublist in event_id_lists for item in sublist]
-
- rows = self._new_transaction(
- conn, "do_fetch", [], [], self._fetch_event_rows, event_ids
- )
-
- row_dict = {r["event_id"]: r for r in rows}
-
- # We only want to resolve deferreds from the main thread
- def fire(lst, res):
- for ids, d in lst:
- if not d.called:
- try:
- with PreserveLoggingContext():
- d.callback([res[i] for i in ids if i in res])
- except Exception:
- logger.exception("Failed to callback")
-
- with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, row_dict)
- except Exception as e:
- logger.exception("do_fetch")
-
- # We only want to resolve deferreds from the main thread
- def fire(evs, exc):
- for _, d in evs:
- if not d.called:
- with PreserveLoggingContext():
- d.errback(exc)
-
- with PreserveLoggingContext():
- self.hs.get_reactor().callFromThread(fire, event_list, e)
-
- @defer.inlineCallbacks
- def _enqueue_events(self, events, allow_rejected=False):
- """Fetches events from the database using the _event_fetch_list. This
- allows batch and bulk fetching of events - it allows us to fetch events
- without having to create a new transaction for each request for events.
- """
- if not events:
- defer.returnValue({})
-
- events_d = defer.Deferred()
- with self._event_fetch_lock:
- self._event_fetch_list.append((events, events_d))
-
- self._event_fetch_lock.notify()
-
- if self._event_fetch_ongoing < EVENT_QUEUE_THREADS:
- self._event_fetch_ongoing += 1
- should_start = True
- else:
- should_start = False
-
- if should_start:
- run_as_background_process(
- "fetch_events", self.runWithConnection, self._do_fetch
- )
-
- logger.debug("Loading %d events", len(events))
- with PreserveLoggingContext():
- rows = yield events_d
- logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
-
- if not allow_rejected:
- rows[:] = [r for r in rows if not r["rejects"]]
-
- res = yield make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self._get_event_from_row,
- row["internal_metadata"],
- row["json"],
- row["redacts"],
- rejected_reason=row["rejects"],
- format_version=row["format_version"],
- )
- for row in rows
- ],
- consumeErrors=True,
- )
- )
-
- defer.returnValue({e.event.event_id: e for e in res if e})
-
- def _fetch_event_rows(self, txn, events):
- rows = []
- N = 200
- for i in range(1 + len(events) // N):
- evs = events[i * N : (i + 1) * N]
- if not evs:
- break
-
- sql = (
- "SELECT "
- " e.event_id as event_id, "
- " e.internal_metadata,"
- " e.json,"
- " e.format_version, "
- " r.redacts as redacts,"
- " rej.event_id as rejects "
- " FROM event_json as e"
- " LEFT JOIN rejections as rej USING (event_id)"
- " LEFT JOIN redactions as r ON e.event_id = r.redacts"
- " WHERE e.event_id IN (%s)"
- ) % (",".join(["?"] * len(evs)),)
-
- txn.execute(sql, evs)
- rows.extend(self.cursor_to_dict(txn))
-
- return rows
-
- @defer.inlineCallbacks
- def _get_event_from_row(
- self, internal_metadata, js, redacted, format_version, rejected_reason=None
- ):
- with Measure(self._clock, "_get_event_from_row"):
- d = json.loads(js)
- internal_metadata = json.loads(internal_metadata)
-
- if rejected_reason:
- rejected_reason = yield self._simple_select_one_onecol(
- table="rejections",
- keyvalues={"event_id": rejected_reason},
- retcol="reason",
- desc="_get_event_from_row_rejected_reason",
- )
-
- if format_version is None:
- # This means that we stored the event before we had the concept
- # of a event format version, so it must be a V1 event.
- format_version = EventFormatVersions.V1
-
- original_ev = event_type_from_format_version(format_version)(
- event_dict=d,
- internal_metadata_dict=internal_metadata,
- rejected_reason=rejected_reason,
- )
-
- redacted_event = None
- if redacted and original_ev.type != EventTypes.Redaction:
- redacted_event = prune_event(original_ev)
-
- redaction_id = yield self._simple_select_one_onecol(
- table="redactions",
- keyvalues={"redacts": redacted_event.event_id},
- retcol="event_id",
- desc="_get_event_from_row_redactions",
- )
-
- redacted_event.unsigned["redacted_by"] = redaction_id
- # Get the redaction event.
-
- because = yield self.get_event(
- redaction_id, check_redacted=False, allow_none=True
- )
-
- if because:
- # It's fine to do add the event directly, since get_pdu_json
- # will serialise this field correctly
- redacted_event.unsigned["redacted_because"] = because
-
- # Starting in room version v3, some redactions need to be
- # rechecked if we didn't have the redacted event at the
- # time, so we recheck on read instead.
- if because.internal_metadata.need_to_check_redaction():
- expected_domain = get_domain_from_id(original_ev.sender)
- if get_domain_from_id(because.sender) == expected_domain:
- # This redaction event is allowed. Mark as not needing a
- # recheck.
- because.internal_metadata.recheck_redaction = False
- else:
- # Senders don't match, so the event isn't actually
- # redacted
- redacted_event = None
-
- if because.room_id != original_ev.room_id:
- redacted_event = None
- else:
- # The lack of a redaction likely means that the redaction is invalid
- # and therefore not returned by get_event, so it should be safe to
- # just ignore it here.
- redacted_event = None
-
- cache_entry = _EventCacheEntry(
- event=original_ev, redacted_event=redacted_event
- )
-
- self._get_event_cache.prefill((original_ev.event_id,), cache_entry)
-
- defer.returnValue(cache_entry)
-
- @defer.inlineCallbacks
- def have_events_in_timeline(self, event_ids):
- """Given a list of event ids, check if we have already processed and
- stored them as non outliers.
- """
- rows = yield self._simple_select_many_batch(
- table="events",
- retcols=("event_id",),
- column="event_id",
- iterable=list(event_ids),
- keyvalues={"outlier": False},
- desc="have_events_in_timeline",
- )
-
- defer.returnValue(set(r["event_id"] for r in rows))
-
- @defer.inlineCallbacks
- def have_seen_events(self, event_ids):
- """Given a list of event ids, check if we have already processed them.
-
- Args:
- event_ids (iterable[str]):
-
- Returns:
- Deferred[set[str]]: The events we have already seen.
- """
- results = set()
-
- def have_seen_events_txn(txn, chunk):
- sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
- ",".join("?" * len(chunk)),
- )
- txn.execute(sql, chunk)
- for (event_id,) in txn:
- results.add(event_id)
-
- # break the input up into chunks of 100
- input_iterator = iter(event_ids)
- for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
- yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
- defer.returnValue(results)
-
- def get_seen_events_with_rejections(self, event_ids):
- """Given a list of event ids, check if we rejected them.
-
- Args:
- event_ids (list[str])
-
- Returns:
- Deferred[dict[str, str|None):
- Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps
- to None.
- """
- if not event_ids:
- return defer.succeed({})
-
- def f(txn):
- sql = (
- "SELECT e.event_id, reason FROM events as e "
- "LEFT JOIN rejections as r ON e.event_id = r.event_id "
- "WHERE e.event_id = ?"
- )
-
- res = {}
- for event_id in event_ids:
- txn.execute(sql, (event_id,))
- row = txn.fetchone()
- if row:
- _, rejected = row
- res[event_id] = rejected
-
- return res
-
- return self.runInteraction("get_seen_events_with_rejections", f)
-
- def _get_total_state_event_counts_txn(self, txn, room_id):
- """
- See get_total_state_event_counts.
- """
- # We join against the events table as that has an index on room_id
- sql = """
- SELECT COUNT(*) FROM state_events
- INNER JOIN events USING (room_id, event_id)
- WHERE room_id=?
- """
- txn.execute(sql, (room_id,))
- row = txn.fetchone()
- return row[0] if row else 0
-
- def get_total_state_event_counts(self, room_id):
- """
- Gets the total number of state events in a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[int]
- """
- return self.runInteraction(
- "get_total_state_event_counts",
- self._get_total_state_event_counts_txn, room_id
- )
-
- def _get_current_state_event_counts_txn(self, txn, room_id):
- """
- See get_current_state_event_counts.
- """
- sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?"
- txn.execute(sql, (room_id,))
- row = txn.fetchone()
- return row[0] if row else 0
-
- def get_current_state_event_counts(self, room_id):
- """
- Gets the current number of state events in a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[int]
- """
- return self.runInteraction(
- "get_current_state_event_counts",
- self._get_current_state_event_counts_txn, room_id
- )
-
- @defer.inlineCallbacks
- def get_room_complexity(self, room_id):
- """
- Get a rough approximation of the complexity of the room. This is used by
- remote servers to decide whether they wish to join the room or not.
- Higher complexity value indicates that being in the room will consume
- more resources.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[dict[str:int]] of complexity version to complexity.
- """
- state_events = yield self.get_current_state_event_counts(room_id)
-
- # Call this one "v1", so we can introduce new ones as we want to develop
- # it.
- complexity_v1 = round(state_events / 500, 2)
-
- defer.returnValue({"v1": complexity_v1})
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index e3655ad8d7..4769b21529 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -14,208 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
-import six
-
import attr
-from signedjson.key import decode_verify_key_bytes
-
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedList
-
-from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
-# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
-# despite being deprecated and removed in favor of memoryview
-if six.PY2:
- db_binary_type = six.moves.builtins.buffer
-else:
- db_binary_type = memoryview
-
@attr.s(slots=True, frozen=True)
class FetchKeyResult(object):
verify_key = attr.ib() # VerifyKey: the key itself
valid_until_ts = attr.ib() # int: how long we can use this key for
-
-
-class KeyStore(SQLBaseStore):
- """Persistence for signature verification keys
- """
-
- @cached()
- def _get_server_verify_key(self, server_name_and_key_id):
- raise NotImplementedError()
-
- @cachedList(
- cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
- )
- def get_server_verify_keys(self, server_name_and_key_ids):
- """
- Args:
- server_name_and_key_ids (iterable[Tuple[str, str]]):
- iterable of (server_name, key-id) tuples to fetch keys for
-
- Returns:
- Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
- map from (server_name, key_id) -> FetchKeyResult, or None if the key is
- unknown
- """
- keys = {}
-
- def _get_keys(txn, batch):
- """Processes a batch of keys to fetch, and adds the result to `keys`."""
-
- # batch_iter always returns tuples so it's safe to do len(batch)
- sql = (
- "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
- "FROM server_signature_keys WHERE 1=0"
- ) + " OR (server_name=? AND key_id=?)" * len(batch)
-
- txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
-
- for row in txn:
- server_name, key_id, key_bytes, ts_valid_until_ms = row
-
- if ts_valid_until_ms is None:
- # Old keys may be stored with a ts_valid_until_ms of null,
- # in which case we treat this as if it was set to `0`, i.e.
- # it won't match key requests that define a minimum
- # `ts_valid_until_ms`.
- ts_valid_until_ms = 0
-
- res = FetchKeyResult(
- verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
- valid_until_ts=ts_valid_until_ms,
- )
- keys[(server_name, key_id)] = res
-
- def _txn(txn):
- for batch in batch_iter(server_name_and_key_ids, 50):
- _get_keys(txn, batch)
- return keys
-
- return self.runInteraction("get_server_verify_keys", _txn)
-
- def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
- """Stores NACL verification keys for remote servers.
- Args:
- from_server (str): Where the verification keys were looked up
- ts_added_ms (int): The time to record that the key was added
- verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
- keys to be stored. Each entry is a triplet of
- (server_name, key_id, key).
- """
- key_values = []
- value_values = []
- invalidations = []
- for server_name, key_id, fetch_result in verify_keys:
- key_values.append((server_name, key_id))
- value_values.append(
- (
- from_server,
- ts_added_ms,
- fetch_result.valid_until_ts,
- db_binary_type(fetch_result.verify_key.encode()),
- )
- )
- # invalidate takes a tuple corresponding to the params of
- # _get_server_verify_key. _get_server_verify_key only takes one
- # param, which is itself the 2-tuple (server_name, key_id).
- invalidations.append((server_name, key_id))
-
- def _invalidate(res):
- f = self._get_server_verify_key.invalidate
- for i in invalidations:
- f((i, ))
- return res
-
- return self.runInteraction(
- "store_server_verify_keys",
- self._simple_upsert_many_txn,
- table="server_signature_keys",
- key_names=("server_name", "key_id"),
- key_values=key_values,
- value_names=(
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "verify_key",
- ),
- value_values=value_values,
- ).addCallback(_invalidate)
-
- def store_server_keys_json(
- self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
- ):
- """Stores the JSON bytes for a set of keys from a server
- The JSON should be signed by the originating server, the intermediate
- server, and by this server. Updates the value for the
- (server_name, key_id, from_server) triplet if one already existed.
- Args:
- server_name (str): The name of the server.
- key_id (str): The identifer of the key this JSON is for.
- from_server (str): The server this JSON was fetched from.
- ts_now_ms (int): The time now in milliseconds.
- ts_valid_until_ms (int): The time when this json stops being valid.
- key_json (bytes): The encoded JSON.
- """
- return self._simple_upsert(
- table="server_keys_json",
- keyvalues={
- "server_name": server_name,
- "key_id": key_id,
- "from_server": from_server,
- },
- values={
- "server_name": server_name,
- "key_id": key_id,
- "from_server": from_server,
- "ts_added_ms": ts_now_ms,
- "ts_valid_until_ms": ts_expires_ms,
- "key_json": db_binary_type(key_json_bytes),
- },
- desc="store_server_keys_json",
- )
-
- def get_server_keys_json(self, server_keys):
- """Retrive the key json for a list of server_keys and key ids.
- If no keys are found for a given server, key_id and source then
- that server, key_id, and source triplet entry will be an empty list.
- The JSON is returned as a byte array so that it can be efficiently
- used in an HTTP response.
- Args:
- server_keys (list): List of (server_name, key_id, source) triplets.
- Returns:
- Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
- Dict mapping (server_name, key_id, source) triplets to lists of dicts
- """
-
- def _get_server_keys_json_txn(txn):
- results = {}
- for server_name, key_id, from_server in server_keys:
- keyvalues = {"server_name": server_name}
- if key_id is not None:
- keyvalues["key_id"] = key_id
- if from_server is not None:
- keyvalues["from_server"] = from_server
- rows = self._simple_select_list_txn(
- txn,
- "server_keys_json",
- keyvalues=keyvalues,
- retcols=(
- "key_id",
- "from_server",
- "ts_added_ms",
- "ts_valid_until_ms",
- "key_json",
- ),
- )
- results[(server_name, key_id, from_server)] = rows
- return results
-
- return self.runInteraction("get_server_keys_json", _get_server_keys_json_txn)
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
new file mode 100644
index 0000000000..0f9ac1cf09
--- /dev/null
+++ b/synapse/storage/persist_events.py
@@ -0,0 +1,801 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+from collections import deque, namedtuple
+from typing import Iterable, List, Optional, Set, Tuple
+
+from six import iteritems
+from six.moves import range
+
+import attr
+from prometheus_client import Counter, Histogram
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events import FrozenEvent
+from synapse.events.snapshot import EventContext
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.state import StateResolutionStore
+from synapse.storage.data_stores import DataStores
+from synapse.types import StateMap
+from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+ "synapse_storage_events_state_delta_single_event", ""
+)
+
+# The number of times we are reculating state when we could have resonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+ "synapse_storage_events_state_delta_reuse_delta", ""
+)
+
+# The number of forward extremities for each new event.
+forward_extremities_counter = Histogram(
+ "synapse_storage_events_forward_extremities_persisted",
+ "Number of forward extremities for each new event",
+ buckets=(1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+# The number of stale forward extremities for each new event. Stale extremities
+# are those that were in the previous set of extremities as well as the new.
+stale_forward_extremities_counter = Histogram(
+ "synapse_storage_events_stale_forward_extremities_persisted",
+ "Number of unchanged forward extremities for each new event",
+ buckets=(0, 1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"),
+)
+
+
+@attr.s(slots=True)
+class DeltaState:
+ """Deltas to use to update the `current_state_events` table.
+
+ Attributes:
+ to_delete: List of type/state_keys to delete from current state
+ to_insert: Map of state to upsert into current state
+ no_longer_in_room: The server is not longer in the room, so the room
+ should e.g. be removed from `current_state_events` table.
+ """
+
+ to_delete = attr.ib(type=List[Tuple[str, str]])
+ to_insert = attr.ib(type=StateMap[str])
+ no_longer_in_room = attr.ib(type=bool, default=False)
+
+
+class _EventPeristenceQueue(object):
+ """Queues up events so that they can be persisted in bulk with only one
+ concurrent transaction per room.
+ """
+
+ _EventPersistQueueItem = namedtuple(
+ "_EventPersistQueueItem", ("events_and_contexts", "backfilled", "deferred")
+ )
+
+ def __init__(self):
+ self._event_persist_queues = {}
+ self._currently_persisting_rooms = set()
+
+ def add_to_queue(self, room_id, events_and_contexts, backfilled):
+ """Add events to the queue, with the given persist_event options.
+
+ NB: due to the normal usage pattern of this method, it does *not*
+ follow the synapse logcontext rules, and leaves the logcontext in
+ place whether or not the returned deferred is ready.
+
+ Args:
+ room_id (str):
+ events_and_contexts (list[(EventBase, EventContext)]):
+ backfilled (bool):
+
+ Returns:
+ defer.Deferred: a deferred which will resolve once the events are
+ persisted. Runs its callbacks *without* a logcontext.
+ """
+ queue = self._event_persist_queues.setdefault(room_id, deque())
+ if queue:
+ # if the last item in the queue has the same `backfilled` setting,
+ # we can just add these new events to that item.
+ end_item = queue[-1]
+ if end_item.backfilled == backfilled:
+ end_item.events_and_contexts.extend(events_and_contexts)
+ return end_item.deferred.observe()
+
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
+
+ queue.append(
+ self._EventPersistQueueItem(
+ events_and_contexts=events_and_contexts,
+ backfilled=backfilled,
+ deferred=deferred,
+ )
+ )
+
+ return deferred.observe()
+
+ def handle_queue(self, room_id, per_item_callback):
+ """Attempts to handle the queue for a room if not already being handled.
+
+ The given callback will be invoked with for each item in the queue,
+ of type _EventPersistQueueItem. The per_item_callback will continuously
+ be called with new items, unless the queue becomnes empty. The return
+ value of the function will be given to the deferreds waiting on the item,
+ exceptions will be passed to the deferreds as well.
+
+ This function should therefore be called whenever anything is added
+ to the queue.
+
+ If another callback is currently handling the queue then it will not be
+ invoked.
+ """
+
+ if room_id in self._currently_persisting_rooms:
+ return
+
+ self._currently_persisting_rooms.add(room_id)
+
+ async def handle_queue_loop():
+ try:
+ queue = self._get_drainining_queue(room_id)
+ for item in queue:
+ try:
+ ret = await per_item_callback(item)
+ except Exception:
+ with PreserveLoggingContext():
+ item.deferred.errback()
+ else:
+ with PreserveLoggingContext():
+ item.deferred.callback(ret)
+ finally:
+ queue = self._event_persist_queues.pop(room_id, None)
+ if queue:
+ self._event_persist_queues[room_id] = queue
+ self._currently_persisting_rooms.discard(room_id)
+
+ # set handle_queue_loop off in the background
+ run_as_background_process("persist_events", handle_queue_loop)
+
+ def _get_drainining_queue(self, room_id):
+ queue = self._event_persist_queues.setdefault(room_id, deque())
+
+ try:
+ while True:
+ yield queue.popleft()
+ except IndexError:
+ # Queue has been drained.
+ pass
+
+
+class EventsPersistenceStorage(object):
+ """High level interface for handling persisting newly received events.
+
+ Takes care of batching up events by room, and calculating the necessary
+ current state and forward extremity changes.
+ """
+
+ def __init__(self, hs, stores: DataStores):
+ # We ultimately want to split out the state store from the main store,
+ # so we use separate variables here even though they point to the same
+ # store for now.
+ self.main_store = stores.main
+ self.state_store = stores.state
+
+ self._clock = hs.get_clock()
+ self.is_mine_id = hs.is_mine_id
+ self._event_persist_queue = _EventPeristenceQueue()
+ self._state_resolution_handler = hs.get_state_resolution_handler()
+
+ @defer.inlineCallbacks
+ def persist_events(
+ self,
+ events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+ backfilled: bool = False,
+ ):
+ """
+ Write events to the database
+ Args:
+ events_and_contexts: list of tuples of (event, context)
+ backfilled: Whether the results are retrieved from federation
+ via backfill or not. Used to determine if they're "new" events
+ which might update the current state etc.
+
+ Returns:
+ Deferred[int]: the stream ordering of the latest persisted event
+ """
+ partitioned = {}
+ for event, ctx in events_and_contexts:
+ partitioned.setdefault(event.room_id, []).append((event, ctx))
+
+ deferreds = []
+ for room_id, evs_ctxs in iteritems(partitioned):
+ d = self._event_persist_queue.add_to_queue(
+ room_id, evs_ctxs, backfilled=backfilled
+ )
+ deferreds.append(d)
+
+ for room_id in partitioned:
+ self._maybe_start_persisting(room_id)
+
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
+
+ max_persisted_id = yield self.main_store.get_current_events_token()
+
+ return max_persisted_id
+
+ @defer.inlineCallbacks
+ def persist_event(
+ self, event: FrozenEvent, context: EventContext, backfilled: bool = False
+ ):
+ """
+ Returns:
+ Deferred[Tuple[int, int]]: the stream ordering of ``event``,
+ and the stream ordering of the latest persisted event
+ """
+ deferred = self._event_persist_queue.add_to_queue(
+ event.room_id, [(event, context)], backfilled=backfilled
+ )
+
+ self._maybe_start_persisting(event.room_id)
+
+ yield make_deferred_yieldable(deferred)
+
+ max_persisted_id = yield self.main_store.get_current_events_token()
+ return (event.internal_metadata.stream_ordering, max_persisted_id)
+
+ def _maybe_start_persisting(self, room_id: str):
+ async def persisting_queue(item):
+ with Measure(self._clock, "persist_events"):
+ await self._persist_events(
+ item.events_and_contexts, backfilled=item.backfilled
+ )
+
+ self._event_persist_queue.handle_queue(room_id, persisting_queue)
+
+ async def _persist_events(
+ self,
+ events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
+ backfilled: bool = False,
+ ):
+ """Calculates the change to current state and forward extremities, and
+ persists the given events and with those updates.
+ """
+ if not events_and_contexts:
+ return
+
+ chunks = [
+ events_and_contexts[x : x + 100]
+ for x in range(0, len(events_and_contexts), 100)
+ ]
+
+ for chunk in chunks:
+ # We can't easily parallelize these since different chunks
+ # might contain the same event. :(
+
+ # NB: Assumes that we are only persisting events for one room
+ # at a time.
+
+ # map room_id->list[event_ids] giving the new forward
+ # extremities in each room
+ new_forward_extremeties = {}
+
+ # map room_id->(type,state_key)->event_id tracking the full
+ # state in each room after adding these events.
+ # This is simply used to prefill the get_current_state_ids
+ # cache
+ current_state_for_room = {}
+
+ # map room_id->(to_delete, to_insert) where to_delete is a list
+ # of type/state keys to remove from current state, and to_insert
+ # is a map (type,key)->event_id giving the state delta in each
+ # room
+ state_delta_for_room = {}
+
+ # Set of remote users which were in rooms the server has left. We
+ # should check if we still share any rooms and if not we mark their
+ # device lists as stale.
+ potentially_left_users = set() # type: Set[str]
+
+ if not backfilled:
+ with Measure(self._clock, "_calculate_state_and_extrem"):
+ # Work out the new "current state" for each room.
+ # We do this by working out what the new extremities are and then
+ # calculating the state from that.
+ events_by_room = {}
+ for event, context in chunk:
+ events_by_room.setdefault(event.room_id, []).append(
+ (event, context)
+ )
+
+ for room_id, ev_ctx_rm in iteritems(events_by_room):
+ latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
+ room_id
+ )
+ new_latest_event_ids = await self._calculate_new_extremities(
+ room_id, ev_ctx_rm, latest_event_ids
+ )
+
+ latest_event_ids = set(latest_event_ids)
+ if new_latest_event_ids == latest_event_ids:
+ # No change in extremities, so no change in state
+ continue
+
+ # there should always be at least one forward extremity.
+ # (except during the initial persistence of the send_join
+ # results, in which case there will be no existing
+ # extremities, so we'll `continue` above and skip this bit.)
+ assert new_latest_event_ids, "No forward extremities left!"
+
+ new_forward_extremeties[room_id] = new_latest_event_ids
+
+ len_1 = (
+ len(latest_event_ids) == 1
+ and len(new_latest_event_ids) == 1
+ )
+ if len_1:
+ all_single_prev_not_state = all(
+ len(event.prev_event_ids()) == 1
+ and not event.is_state()
+ for event, ctx in ev_ctx_rm
+ )
+ # Don't bother calculating state if they're just
+ # a long chain of single ancestor non-state events.
+ if all_single_prev_not_state:
+ continue
+
+ state_delta_counter.inc()
+ if len(new_latest_event_ids) == 1:
+ state_delta_single_event_counter.inc()
+
+ # This is a fairly handwavey check to see if we could
+ # have guessed what the delta would have been when
+ # processing one of these events.
+ # What we're interested in is if the latest extremities
+ # were the same when we created the event as they are
+ # now. When this server creates a new event (as opposed
+ # to receiving it over federation) it will use the
+ # forward extremities as the prev_events, so we can
+ # guess this by looking at the prev_events and checking
+ # if they match the current forward extremities.
+ for ev, _ in ev_ctx_rm:
+ prev_event_ids = set(ev.prev_event_ids())
+ if latest_event_ids == prev_event_ids:
+ state_delta_reuse_delta_counter.inc()
+ break
+
+ logger.debug("Calculating state delta for room %s", room_id)
+ with Measure(
+ self._clock, "persist_events.get_new_state_after_events"
+ ):
+ res = await self._get_new_state_after_events(
+ room_id,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
+ )
+ current_state, delta_ids = res
+
+ # If either are not None then there has been a change,
+ # and we need to work out the delta (or use that
+ # given)
+ delta = None
+ if delta_ids is not None:
+ # If there is a delta we know that we've
+ # only added or replaced state, never
+ # removed keys entirely.
+ delta = DeltaState([], delta_ids)
+ elif current_state is not None:
+ with Measure(
+ self._clock, "persist_events.calculate_state_delta"
+ ):
+ delta = await self._calculate_state_delta(
+ room_id, current_state
+ )
+
+ if delta:
+ # If we have a change of state then lets check
+ # whether we're actually still a member of the room,
+ # or if our last user left. If we're no longer in
+ # the room then we delete the current state and
+ # extremities.
+ is_still_joined = await self._is_server_still_joined(
+ room_id,
+ ev_ctx_rm,
+ delta,
+ current_state,
+ potentially_left_users,
+ )
+ if not is_still_joined:
+ logger.info("Server no longer in room %s", room_id)
+ latest_event_ids = []
+ current_state = {}
+ delta.no_longer_in_room = True
+
+ state_delta_for_room[room_id] = delta
+
+ # If we have the current_state then lets prefill
+ # the cache with it.
+ if current_state is not None:
+ current_state_for_room[room_id] = current_state
+
+ await self.main_store._persist_events_and_state_updates(
+ chunk,
+ current_state_for_room=current_state_for_room,
+ state_delta_for_room=state_delta_for_room,
+ new_forward_extremeties=new_forward_extremeties,
+ backfilled=backfilled,
+ )
+
+ await self._handle_potentially_left_users(potentially_left_users)
+
+ async def _calculate_new_extremities(
+ self,
+ room_id: str,
+ event_contexts: List[Tuple[FrozenEvent, EventContext]],
+ latest_event_ids: List[str],
+ ):
+ """Calculates the new forward extremities for a room given events to
+ persist.
+
+ Assumes that we are only persisting events for one room at a time.
+ """
+
+ # we're only interested in new events which aren't outliers and which aren't
+ # being rejected.
+ new_events = [
+ event
+ for event, ctx in event_contexts
+ if not event.internal_metadata.is_outlier()
+ and not ctx.rejected
+ and not event.internal_metadata.is_soft_failed()
+ ]
+
+ latest_event_ids = set(latest_event_ids)
+
+ # start with the existing forward extremities
+ result = set(latest_event_ids)
+
+ # add all the new events to the list
+ result.update(event.event_id for event in new_events)
+
+ # Now remove all events which are prev_events of any of the new events
+ result.difference_update(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+
+ # Remove any events which are prev_events of any existing events.
+ existing_prevs = await self.main_store._get_events_which_are_prevs(result)
+ result.difference_update(existing_prevs)
+
+ # Finally handle the case where the new events have soft-failed prev
+ # events. If they do we need to remove them and their prev events,
+ # otherwise we end up with dangling extremities.
+ existing_prevs = await self.main_store._get_prevs_before_rejected(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+ result.difference_update(existing_prevs)
+
+ # We only update metrics for events that change forward extremities
+ # (e.g. we ignore backfill/outliers/etc)
+ if result != latest_event_ids:
+ forward_extremities_counter.observe(len(result))
+ stale = latest_event_ids & result
+ stale_forward_extremities_counter.observe(len(stale))
+
+ return result
+
+ async def _get_new_state_after_events(
+ self,
+ room_id: str,
+ events_context: List[Tuple[FrozenEvent, EventContext]],
+ old_latest_event_ids: Iterable[str],
+ new_latest_event_ids: Iterable[str],
+ ) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
+ """Calculate the current state dict after adding some new events to
+ a room
+
+ Args:
+ room_id (str):
+ room to which the events are being added. Used for logging etc
+
+ events_context (list[(EventBase, EventContext)]):
+ events and contexts which are being added to the room
+
+ old_latest_event_ids (iterable[str]):
+ the old forward extremities for the room.
+
+ new_latest_event_ids (iterable[str]):
+ the new forward extremities for the room.
+
+ Returns:
+ Returns a tuple of two state maps, the first being the full new current
+ state and the second being the delta to the existing current state.
+ If both are None then there has been no change.
+
+ If there has been a change then we only return the delta if its
+ already been calculated. Conversely if we do know the delta then
+ the new current state is only returned if we've already calculated
+ it.
+ """
+ # map from state_group to ((type, key) -> event_id) state map
+ state_groups_map = {}
+
+ # Map from (prev state group, new state group) -> delta state dict
+ state_group_deltas = {}
+
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # This should only happen for outlier events.
+ if not ev.internal_metadata.is_outlier():
+ raise Exception(
+ "Context for new event %s has no state "
+ "group" % (ev.event_id,)
+ )
+ continue
+
+ if ctx.state_group in state_groups_map:
+ continue
+
+ # We're only interested in pulling out state that has already
+ # been cached in the context. We'll pull stuff out of the DB later
+ # if necessary.
+ current_state_ids = ctx.get_cached_current_state_ids()
+ if current_state_ids is not None:
+ state_groups_map[ctx.state_group] = current_state_ids
+
+ if ctx.prev_group:
+ state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
+ for event_id in new_latest_event_ids:
+ # First search in the list of new events we're adding.
+ for ev, ctx in events_context:
+ if event_id == ev.event_id and ctx.state_group is not None:
+ event_id_to_state_group[event_id] = ctx.state_group
+ break
+ else:
+ # If we couldn't find it, then we'll need to pull
+ # the state from the database
+ missing_event_ids.add(event_id)
+
+ if missing_event_ids:
+ # Now pull out the state groups for any missing events from DB
+ event_to_groups = await self.main_store._get_state_group_for_events(
+ missing_event_ids
+ )
+ event_id_to_state_group.update(event_to_groups)
+
+ # State groups of old_latest_event_ids
+ old_state_groups = {
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ }
+
+ # State groups of new_latest_event_ids
+ new_state_groups = {
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ }
+
+ # If they old and new groups are the same then we don't need to do
+ # anything.
+ if old_state_groups == new_state_groups:
+ return None, None
+
+ if len(new_state_groups) == 1 and len(old_state_groups) == 1:
+ # If we're going from one state group to another, lets check if
+ # we have a delta for that transition. If we do then we can just
+ # return that.
+
+ new_state_group = next(iter(new_state_groups))
+ old_state_group = next(iter(old_state_groups))
+
+ delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
+ if delta_ids is not None:
+ # We have a delta from the existing to new current state,
+ # so lets just return that. If we happen to already have
+ # the current state in memory then lets also return that,
+ # but it doesn't matter if we don't.
+ new_state = state_groups_map.get(new_state_group)
+ return new_state, delta_ids
+
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ missing_state = new_state_groups - set(state_groups_map)
+ if missing_state:
+ group_to_state = await self.state_store._get_state_for_groups(missing_state)
+ state_groups_map.update(group_to_state)
+
+ if len(new_state_groups) == 1:
+ # If there is only one state group, then we know what the current
+ # state is.
+ return state_groups_map[new_state_groups.pop()], None
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
+
+ state_groups = {sg: state_groups_map[sg] for sg in new_state_groups}
+
+ events_map = {ev.event_id: ev for ev, _ in events_context}
+
+ # We need to get the room version, which is in the create event.
+ # Normally that'd be in the database, but its also possible that we're
+ # currently trying to persist it.
+ room_version = None
+ for ev, _ in events_context:
+ if ev.type == EventTypes.Create and ev.state_key == "":
+ room_version = ev.content.get("room_version", "1")
+ break
+
+ if not room_version:
+ room_version = await self.main_store.get_room_version_id(room_id)
+
+ logger.debug("calling resolve_state_groups from preserve_events")
+ res = await self._state_resolution_handler.resolve_state_groups(
+ room_id,
+ room_version,
+ state_groups,
+ events_map,
+ state_res_store=StateResolutionStore(self.main_store),
+ )
+
+ return res.state, None
+
+ async def _calculate_state_delta(
+ self, room_id: str, current_state: StateMap[str]
+ ) -> DeltaState:
+ """Calculate the new state deltas for a room.
+
+ Assumes that we are only persisting events for one room at a time.
+ """
+ existing_state = await self.main_store.get_current_state_ids(room_id)
+
+ to_delete = [key for key in existing_state if key not in current_state]
+
+ to_insert = {
+ key: ev_id
+ for key, ev_id in iteritems(current_state)
+ if ev_id != existing_state.get(key)
+ }
+
+ return DeltaState(to_delete=to_delete, to_insert=to_insert)
+
+ async def _is_server_still_joined(
+ self,
+ room_id: str,
+ ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
+ delta: DeltaState,
+ current_state: Optional[StateMap[str]],
+ potentially_left_users: Set[str],
+ ) -> bool:
+ """Check if the server will still be joined after the given events have
+ been persised.
+
+ Args:
+ room_id
+ ev_ctx_rm
+ delta: The delta of current state between what is in the database
+ and what the new current state will be.
+ current_state: The new current state if it already been calculated,
+ otherwise None.
+ potentially_left_users: If the server has left the room, then joined
+ remote users will be added to this set to indicate that the
+ server may no longer be sharing a room with them.
+ """
+
+ if not any(
+ self.is_mine_id(state_key)
+ for typ, state_key in itertools.chain(delta.to_delete, delta.to_insert)
+ if typ == EventTypes.Member
+ ):
+ # There have been no changes to membership of our users, so nothing
+ # has changed and we assume we're still in the room.
+ return True
+
+ # Check if any of the given events are a local join that appear in the
+ # current state
+ events_to_check = [] # Event IDs that aren't an event we're persisting
+ for (typ, state_key), event_id in delta.to_insert.items():
+ if typ != EventTypes.Member or not self.is_mine_id(state_key):
+ continue
+
+ for event, _ in ev_ctx_rm:
+ if event_id == event.event_id:
+ if event.membership == Membership.JOIN:
+ return True
+
+ # The event is not in `ev_ctx_rm`, so we need to pull it out of
+ # the DB.
+ events_to_check.append(event_id)
+
+ # Check if any of the changes that we don't have events for are joins.
+ if events_to_check:
+ rows = await self.main_store.get_membership_from_event_ids(events_to_check)
+ is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+ if is_still_joined:
+ return True
+
+ # None of the new state events are local joins, so we check the database
+ # to see if there are any other local users in the room. We ignore users
+ # whose state has changed as we've already their new state above.
+ users_to_ignore = [
+ state_key
+ for _, state_key in itertools.chain(delta.to_insert, delta.to_delete)
+ if self.is_mine_id(state_key)
+ ]
+
+ if await self.main_store.is_local_host_in_room_ignoring_users(
+ room_id, users_to_ignore
+ ):
+ return True
+
+ # The server will leave the room, so we go and find out which remote
+ # users will still be joined when we leave.
+ if current_state is None:
+ current_state = await self.main_store.get_current_state_ids(room_id)
+ current_state = dict(current_state)
+ for key in delta.to_delete:
+ current_state.pop(key, None)
+
+ current_state.update(delta.to_insert)
+
+ remote_event_ids = [
+ event_id
+ for (typ, state_key,), event_id in current_state.items()
+ if typ == EventTypes.Member and not self.is_mine_id(state_key)
+ ]
+ rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+ potentially_left_users.update(
+ row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+ )
+
+ return False
+
+ async def _handle_potentially_left_users(self, user_ids: Set[str]):
+ """Given a set of remote users check if the server still shares a room with
+ them. If not then mark those users' device cache as stale.
+ """
+
+ if not user_ids:
+ return
+
+ joined_users = await self.main_store.get_users_server_still_shares_room_with(
+ user_ids
+ )
+ left_users = user_ids - joined_users
+
+ for user_id in left_users:
+ await self.main_store.mark_remote_user_device_list_as_unsubscribed(user_id)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index f2c1bed487..6cb7d4b922 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -14,11 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import fnmatch
import imp
import logging
import os
import re
+from collections import Counter
+
+import attr
from synapse.storage.engines.postgres import PostgresEngine
@@ -27,7 +29,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 55
+SCHEMA_VERSION = 57
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -40,7 +42,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -53,7 +55,10 @@ def prepare_database(db_conn, database_engine, config):
config (synapse.config.homeserver.HomeServerConfig|None):
application config, or None if we are connecting to an existing
database which we expect to be configured already
+ data_stores (list[str]): The name of the data stores that will be used
+ with this database. Defaults to all data stores.
"""
+
try:
cur = db_conn.cursor()
version_info = _get_or_create_schema_state(cur, database_engine)
@@ -65,13 +70,22 @@ def prepare_database(db_conn, database_engine, config):
if user_version != SCHEMA_VERSION:
# If we don't pass in a config file then we are expecting to
# have already upgraded the DB.
- raise UpgradeDatabaseException("Database needs to be upgraded")
+ raise UpgradeDatabaseException(
+ "Expected database schema version %i but got %i"
+ % (SCHEMA_VERSION, user_version)
+ )
else:
_upgrade_existing_database(
- cur, user_version, delta_files, upgraded, database_engine, config
+ cur,
+ user_version,
+ delta_files,
+ upgraded,
+ database_engine,
+ config,
+ data_stores=data_stores,
)
else:
- _setup_new_database(cur, database_engine)
+ _setup_new_database(cur, database_engine, data_stores=data_stores)
# check if any of our configured dynamic modules want a database
if config is not None:
@@ -84,9 +98,10 @@ def prepare_database(db_conn, database_engine, config):
raise
-def _setup_new_database(cur, database_engine):
+def _setup_new_database(cur, database_engine, data_stores):
"""Sets up the database by finding a base set of "full schemas" and then
- applying any necessary deltas.
+ applying any necessary deltas, including schemas from the given data
+ stores.
The "full_schemas" directory has subdirectories named after versions. This
function searches for the highest version less than or equal to
@@ -111,51 +126,83 @@ def _setup_new_database(cur, database_engine):
In the example foo.sql and bar.sql would be run, and then any delta files
for versions strictly greater than 11.
+
+ Note: we apply the full schemas and deltas from the top level `schema/`
+ folder as well those in the data stores specified.
+
+ Args:
+ cur (Cursor): a database cursor
+ database_engine (DatabaseEngine)
+ data_stores (list[str]): The names of the data stores to instantiate
+ on the given database.
"""
- current_dir = os.path.join(dir_path, "schema", "full_schemas")
- directory_entries = os.listdir(current_dir)
- valid_dirs = []
- pattern = re.compile(r"^\d+(\.sql)?$")
+ # We're about to set up a brand new database so we check that its
+ # configured to our liking.
+ database_engine.check_new_database(cur)
- if isinstance(database_engine, PostgresEngine):
- specific = "postgres"
- else:
- specific = "sqlite"
+ current_dir = os.path.join(dir_path, "schema", "full_schemas")
+ directory_entries = os.listdir(current_dir)
- specific_pattern = re.compile(r"^\d+(\.sql." + specific + r")?$")
+ # First we find the highest full schema version we have
+ valid_versions = []
for filename in directory_entries:
- match = pattern.match(filename) or specific_pattern.match(filename)
- abs_path = os.path.join(current_dir, filename)
- if match and os.path.isdir(abs_path):
- ver = int(match.group(0))
- if ver <= SCHEMA_VERSION:
- valid_dirs.append((ver, abs_path))
- else:
- logger.warn("Unexpected entry in 'full_schemas': %s", filename)
+ try:
+ ver = int(filename)
+ except ValueError:
+ continue
- if not valid_dirs:
+ if ver <= SCHEMA_VERSION:
+ valid_versions.append(ver)
+
+ if not valid_versions:
raise PrepareDatabaseException(
"Could not find a suitable base set of full schemas"
)
- max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
+ max_current_ver = max(valid_versions)
logger.debug("Initialising schema v%d", max_current_ver)
- directory_entries = os.listdir(sql_dir)
+ # Now lets find all the full schema files, both in the global schema and
+ # in data store schemas.
+ directories = [os.path.join(current_dir, str(max_current_ver))]
+ directories.extend(
+ os.path.join(
+ dir_path,
+ "data_stores",
+ data_store,
+ "schema",
+ "full_schemas",
+ str(max_current_ver),
+ )
+ for data_store in data_stores
+ )
+
+ directory_entries = []
+ for directory in directories:
+ directory_entries.extend(
+ _DirectoryListing(file_name, os.path.join(directory, file_name))
+ for file_name in os.listdir(directory)
+ )
+
+ if isinstance(database_engine, PostgresEngine):
+ specific = "postgres"
+ else:
+ specific = "sqlite"
- for filename in sorted(fnmatch.filter(directory_entries, "*.sql") + fnmatch.filter(
- directory_entries, "*.sql." + specific
- )):
- sql_loc = os.path.join(sql_dir, filename)
- logger.debug("Applying schema %s", sql_loc)
- executescript(cur, sql_loc)
+ directory_entries.sort()
+ for entry in directory_entries:
+ if entry.file_name.endswith(".sql") or entry.file_name.endswith(
+ ".sql." + specific
+ ):
+ logger.debug("Applying schema %s", entry.absolute_path)
+ executescript(cur, entry.absolute_path)
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
),
(max_current_ver, False),
)
@@ -167,6 +214,7 @@ def _setup_new_database(cur, database_engine):
upgraded=False,
database_engine=database_engine,
config=None,
+ data_stores=data_stores,
is_empty=True,
)
@@ -178,6 +226,7 @@ def _upgrade_existing_database(
upgraded,
database_engine,
config,
+ data_stores,
is_empty=False,
):
"""Upgrades an existing database.
@@ -214,6 +263,10 @@ def _upgrade_existing_database(
only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in
some arbitrary order.
+ Note: we apply the delta files from the specified data stores as well as
+ those in the top-level schema. We apply all delta files across data stores
+ for a version before applying those in the next version.
+
Args:
cur (Cursor)
current_version (int): The current version of the schema.
@@ -223,7 +276,19 @@ def _upgrade_existing_database(
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
+ database_engine (DatabaseEngine)
+ config (synapse.config.homeserver.HomeServerConfig|None):
+ None if we are initialising a blank database, otherwise the application
+ config
+ data_stores (list[str]): The names of the data stores to instantiate
+ on the given database.
+ is_empty (bool): Is this a blank database? I.e. do we need to run the
+ upgrade portions of the delta scripts.
"""
+ if is_empty:
+ assert not applied_delta_files
+ else:
+ assert config
if current_version > SCHEMA_VERSION:
raise ValueError(
@@ -231,33 +296,89 @@ def _upgrade_existing_database(
+ "new for the server to understand"
)
+ # some of the deltas assume that config.server_name is set correctly, so now
+ # is a good time to run the sanity check.
+ if not is_empty and "main" in data_stores:
+ from synapse.storage.data_stores.main import check_database_before_upgrade
+
+ check_database_before_upgrade(cur, database_engine, config)
+
start_ver = current_version
if not upgraded:
start_ver += 1
logger.debug("applied_delta_files: %s", applied_delta_files)
+ if isinstance(database_engine, PostgresEngine):
+ specific_engine_extension = ".postgres"
+ else:
+ specific_engine_extension = ".sqlite"
+
+ specific_engine_extensions = (".sqlite", ".postgres")
+
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.info("Upgrading schema to v%d", v)
+ # We need to search both the global and per data store schema
+ # directories for schema updates.
+
+ # First we find the directories to search in
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
+ directories = [delta_dir]
+ for data_store in data_stores:
+ directories.append(
+ os.path.join(
+ dir_path, "data_stores", data_store, "schema", "delta", str(v)
+ )
+ )
- try:
- directory_entries = os.listdir(delta_dir)
- except OSError:
- logger.exception("Could not open delta dir for version %d", v)
- raise UpgradeDatabaseException(
- "Could not open delta dir for version %d" % (v,)
+ # Used to check if we have any duplicate file names
+ file_name_counter = Counter()
+
+ # Now find which directories have anything of interest.
+ directory_entries = []
+ for directory in directories:
+ logger.debug("Looking for schema deltas in %s", directory)
+ try:
+ file_names = os.listdir(directory)
+ directory_entries.extend(
+ _DirectoryListing(file_name, os.path.join(directory, file_name))
+ for file_name in file_names
+ )
+
+ for file_name in file_names:
+ file_name_counter[file_name] += 1
+ except FileNotFoundError:
+ # Data stores can have empty entries for a given version delta.
+ pass
+ except OSError:
+ raise UpgradeDatabaseException(
+ "Could not open delta dir for version %d: %s" % (v, directory)
+ )
+
+ duplicates = {
+ file_name for file_name, count in file_name_counter.items() if count > 1
+ }
+ if duplicates:
+ # We don't support using the same file name in the same delta version.
+ raise PrepareDatabaseException(
+ "Found multiple delta files with the same name in v%d: %s",
+ v,
+ duplicates,
)
+ # We sort to ensure that we apply the delta files in a consistent
+ # order (to avoid bugs caused by inconsistent directory listing order)
directory_entries.sort()
- for file_name in directory_entries:
+ for entry in directory_entries:
+ file_name = entry.file_name
relative_path = os.path.join(str(v), file_name)
- logger.debug("Found file: %s", relative_path)
+ absolute_path = entry.absolute_path
+
+ logger.debug("Found file: %s (%s)", relative_path, absolute_path)
if relative_path in applied_delta_files:
continue
- absolute_path = os.path.join(dir_path, "schema", "delta", relative_path)
root_name, ext = os.path.splitext(file_name)
if ext == ".py":
# This is a python upgrade module. We need to import into some
@@ -273,15 +394,22 @@ def _upgrade_existing_database(
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
# installers. Silently skip it
- pass
+ continue
elif ext == ".sql":
# A plain old .sql file, just read and execute it
logger.info("Applying schema %s", relative_path)
executescript(cur, absolute_path)
+ elif ext == specific_engine_extension and root_name.endswith(".sql"):
+ # A .sql file specific to our engine; just read and execute it
+ logger.info("Applying engine-specific schema %s", relative_path)
+ executescript(cur, absolute_path)
+ elif ext in specific_engine_extensions and root_name.endswith(".sql"):
+ # A .sql file for a different engine; skip it.
+ continue
else:
# Not a valid delta file.
- logger.warn(
- "Found directory entry that did not end in .py or" " .sql: %s",
+ logger.warning(
+ "Found directory entry that did not end in .py or .sql: %s",
relative_path,
)
continue
@@ -289,7 +417,7 @@ def _upgrade_existing_database(
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_schema_deltas (version, file)" " VALUES (?,?)"
+ "INSERT INTO applied_schema_deltas (version, file) VALUES (?,?)"
),
(v, relative_path),
)
@@ -297,7 +425,7 @@ def _upgrade_existing_database(
cur.execute("DELETE FROM schema_version")
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
),
(v, True),
)
@@ -313,7 +441,7 @@ def _apply_module_schemas(txn, database_engine, config):
application config
"""
for (mod, _config) in config.password_providers:
- if not hasattr(mod, 'get_db_schema_files'):
+ if not hasattr(mod, "get_db_schema_files"):
continue
modname = ".".join((mod.__module__, mod.__name__))
_apply_module_schema_files(
@@ -337,13 +465,13 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
),
(modname,),
)
- applied_deltas = set(d for d, in cur)
+ applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams:
if name in applied_deltas:
continue
root_name, ext = os.path.splitext(name)
- if ext != '.sql':
+ if ext != ".sql":
raise PrepareDatabaseException(
"only .sql files are currently supported for module schemas"
)
@@ -355,7 +483,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
+ "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
),
(modname, name),
)
@@ -407,7 +535,7 @@ def get_statements(f):
def executescript(txn, schema_path):
- with open(schema_path, 'r') as f:
+ with open(schema_path, "r") as f:
for statement in get_statements(f):
txn.execute(statement)
@@ -433,3 +561,16 @@ def _get_or_create_schema_state(txn, database_engine):
return current_version, applied_deltas, upgraded
return None
+
+
+@attr.s()
+class _DirectoryListing(object):
+ """Helper class to store schema file name and the
+ absolute path to it.
+
+ These entries get sorted, so for consistency we want to ensure that
+ `file_name` attr is kept first.
+ """
+
+ file_name = attr.ib()
+ absolute_path = attr.ib()
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 42ec8c6bb8..18a462f0ee 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -15,13 +15,7 @@
from collections import namedtuple
-from twisted.internet import defer
-
from synapse.api.constants import PresenceState
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedList
-
-from ._base import SQLBaseStore
class UserPresenceState(
@@ -73,135 +67,3 @@ class UserPresenceState(
status_msg=None,
currently_active=False,
)
-
-
-class PresenceStore(SQLBaseStore):
- @defer.inlineCallbacks
- def update_presence(self, presence_states):
- stream_ordering_manager = self._presence_id_gen.get_next_mult(
- len(presence_states)
- )
-
- with stream_ordering_manager as stream_orderings:
- yield self.runInteraction(
- "update_presence",
- self._update_presence_txn,
- stream_orderings,
- presence_states,
- )
-
- defer.returnValue(
- (stream_orderings[-1], self._presence_id_gen.get_current_token())
- )
-
- def _update_presence_txn(self, txn, stream_orderings, presence_states):
- for stream_id, state in zip(stream_orderings, presence_states):
- txn.call_after(
- self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
- )
- txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
-
- # Actually insert new rows
- self._simple_insert_many_txn(
- txn,
- table="presence_stream",
- values=[
- {
- "stream_id": stream_id,
- "user_id": state.user_id,
- "state": state.state,
- "last_active_ts": state.last_active_ts,
- "last_federation_update_ts": state.last_federation_update_ts,
- "last_user_sync_ts": state.last_user_sync_ts,
- "status_msg": state.status_msg,
- "currently_active": state.currently_active,
- }
- for state in presence_states
- ],
- )
-
- # Delete old rows to stop database from getting really big
- sql = (
- "DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
- )
-
- for states in batch_iter(presence_states, 50):
- args = [stream_id]
- args.extend(s.user_id for s in states)
- txn.execute(sql % (",".join("?" for _ in states),), args)
-
- def get_all_presence_updates(self, last_id, current_id):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_presence_updates_txn(txn):
- sql = (
- "SELECT stream_id, user_id, state, last_active_ts,"
- " last_federation_update_ts, last_user_sync_ts, status_msg,"
- " currently_active"
- " FROM presence_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- )
- txn.execute(sql, (last_id, current_id))
- return txn.fetchall()
-
- return self.runInteraction(
- "get_all_presence_updates", get_all_presence_updates_txn
- )
-
- @cached()
- def _get_presence_for_user(self, user_id):
- raise NotImplementedError()
-
- @cachedList(
- cached_method_name="_get_presence_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
- )
- def get_presence_for_users(self, user_ids):
- rows = yield self._simple_select_many_batch(
- table="presence_stream",
- column="user_id",
- iterable=user_ids,
- keyvalues={},
- retcols=(
- "user_id",
- "state",
- "last_active_ts",
- "last_federation_update_ts",
- "last_user_sync_ts",
- "status_msg",
- "currently_active",
- ),
- desc="get_presence_for_users",
- )
-
- for row in rows:
- row["currently_active"] = bool(row["currently_active"])
-
- defer.returnValue({row["user_id"]: UserPresenceState(**row) for row in rows})
-
- def get_current_presence_token(self):
- return self._presence_id_gen.get_current_token()
-
- def allow_presence_visible(self, observed_localpart, observer_userid):
- return self._simple_insert(
- table="presence_allow_inbound",
- values={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="allow_presence_visible",
- or_ignore=True,
- )
-
- def disallow_presence_visible(self, observed_localpart, observer_userid):
- return self._simple_delete_one(
- table="presence_allow_inbound",
- keyvalues={
- "observed_user_id": observed_localpart,
- "observer_user_id": observer_userid,
- },
- desc="disallow_presence_visible",
- )
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
new file mode 100644
index 0000000000..fdc0abf5cf
--- /dev/null
+++ b/synapse/storage/purge_events.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import logging
+
+from twisted.internet import defer
+
+logger = logging.getLogger(__name__)
+
+
+class PurgeEventsStorage(object):
+ """High level interface for purging rooms and event history.
+ """
+
+ def __init__(self, hs, stores):
+ self.stores = stores
+
+ @defer.inlineCallbacks
+ def purge_room(self, room_id: str):
+ """Deletes all record of a room
+ """
+
+ state_groups_to_delete = yield self.stores.main.purge_room(room_id)
+ yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
+
+ @defer.inlineCallbacks
+ def purge_history(self, room_id, token, delete_local_events):
+ """Deletes room history before a certain point
+
+ Args:
+ room_id (str):
+
+ token (str): A topological token to delete events before
+
+ delete_local_events (bool):
+ if True, we will delete local events as well as remote ones
+ (instead of just marking them as outliers and deleting their
+ state groups).
+ """
+ state_groups = yield self.stores.main.purge_history(
+ room_id, token, delete_local_events
+ )
+
+ logger.info("[purge] finding state groups that can be deleted")
+
+ sg_to_delete = yield self._find_unreferenced_groups(state_groups)
+
+ yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
+
+ @defer.inlineCallbacks
+ def _find_unreferenced_groups(self, state_groups):
+ """Used when purging history to figure out which state groups can be
+ deleted.
+
+ Args:
+ state_groups (set[int]): Set of state groups referenced by events
+ that are going to be deleted.
+
+ Returns:
+ Deferred[set[int]] The set of state groups that can be deleted.
+ """
+ # Graph of state group -> previous group
+ graph = {}
+
+ # Set of events that we have found to be referenced by events
+ referenced_groups = set()
+
+ # Set of state groups we've already seen
+ state_groups_seen = set(state_groups)
+
+ # Set of state groups to handle next.
+ next_to_search = set(state_groups)
+ while next_to_search:
+ # We bound size of groups we're looking up at once, to stop the
+ # SQL query getting too big
+ if len(next_to_search) < 100:
+ current_search = next_to_search
+ next_to_search = set()
+ else:
+ current_search = set(itertools.islice(next_to_search, 100))
+ next_to_search -= current_search
+
+ referenced = yield self.stores.main.get_referenced_state_groups(
+ current_search
+ )
+ referenced_groups |= referenced
+
+ # We don't continue iterating up the state group graphs for state
+ # groups that are referenced.
+ current_search -= referenced
+
+ edges = yield self.stores.state.get_previous_state_groups(current_search)
+
+ prevs = set(edges.values())
+ # We don't bother re-handling groups we've already seen
+ prevs -= state_groups_seen
+ next_to_search |= prevs
+ state_groups_seen |= prevs
+
+ graph.update(edges)
+
+ to_delete = state_groups_seen - referenced_groups
+
+ return to_delete
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 9e406baafa..f47cec0d86 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -14,710 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import abc
-import logging
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.push.baserules import list_with_base_rules
-from synapse.storage.appservice import ApplicationServiceWorkerStore
-from synapse.storage.pusher import PusherWorkerStore
-from synapse.storage.receipts import ReceiptsWorkerStore
-from synapse.storage.roommember import RoomMemberWorkerStore
-from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
-from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-from ._base import SQLBaseStore
-
-logger = logging.getLogger(__name__)
-
-
-def _load_rules(rawrules, enabled_map):
- ruleslist = []
- for rawrule in rawrules:
- rule = dict(rawrule)
- rule["conditions"] = json.loads(rawrule["conditions"])
- rule["actions"] = json.loads(rawrule["actions"])
- ruleslist.append(rule)
-
- # We're going to be mutating this a lot, so do a deep copy
- rules = list(list_with_base_rules(ruleslist))
-
- for i, rule in enumerate(rules):
- rule_id = rule['rule_id']
- if rule_id in enabled_map:
- if rule.get('enabled', True) != bool(enabled_map[rule_id]):
- # Rules are cached across users.
- rule = dict(rule)
- rule['enabled'] = bool(enabled_map[rule_id])
- rules[i] = rule
-
- return rules
-
-
-class PushRulesWorkerStore(
- ApplicationServiceWorkerStore,
- ReceiptsWorkerStore,
- PusherWorkerStore,
- RoomMemberWorkerStore,
- SQLBaseStore,
-):
- """This is an abstract base class where subclasses must implement
- `get_max_push_rules_stream_id` which can be called in the initializer.
- """
-
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
- def __init__(self, db_conn, hs):
- super(PushRulesWorkerStore, self).__init__(db_conn, hs)
-
- push_rules_prefill, push_rules_id = self._get_cache_dict(
- db_conn,
- "push_rules_stream",
- entity_column="user_id",
- stream_column="stream_id",
- max_value=self.get_max_push_rules_stream_id(),
- )
-
- self.push_rules_stream_cache = StreamChangeCache(
- "PushRulesStreamChangeCache",
- push_rules_id,
- prefilled_cache=push_rules_prefill,
- )
-
- @abc.abstractmethod
- def get_max_push_rules_stream_id(self):
- """Get the position of the push rules stream.
-
- Returns:
- int
- """
- raise NotImplementedError()
-
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_for_user(self, user_id):
- rows = yield self._simple_select_list(
- table="push_rules",
- keyvalues={"user_name": user_id},
- retcols=(
- "user_name",
- "rule_id",
- "priority_class",
- "priority",
- "conditions",
- "actions",
- ),
- desc="get_push_rules_enabled_for_user",
- )
-
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
-
- enabled_map = yield self.get_push_rules_enabled_for_user(user_id)
-
- rules = _load_rules(rows, enabled_map)
-
- defer.returnValue(rules)
-
- @cachedInlineCallbacks(max_entries=5000)
- def get_push_rules_enabled_for_user(self, user_id):
- results = yield self._simple_select_list(
- table="push_rules_enable",
- keyvalues={'user_name': user_id},
- retcols=("user_name", "rule_id", "enabled"),
- desc="get_push_rules_enabled_for_user",
- )
- defer.returnValue(
- {r['rule_id']: False if r['enabled'] == 0 else True for r in results}
- )
-
- def have_push_rules_changed_for_user(self, user_id, last_id):
- if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
- return defer.succeed(False)
- else:
-
- def have_push_rules_changed_txn(txn):
- sql = (
- "SELECT COUNT(stream_id) FROM push_rules_stream"
- " WHERE user_id = ? AND ? < stream_id"
- )
- txn.execute(sql, (user_id, last_id))
- count, = txn.fetchone()
- return bool(count)
-
- return self.runInteraction(
- "have_push_rules_changed", have_push_rules_changed_txn
- )
-
- @cachedList(
- cached_method_name="get_push_rules_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
- )
- def bulk_get_push_rules(self, user_ids):
- if not user_ids:
- defer.returnValue({})
-
- results = {user_id: [] for user_id in user_ids}
-
- rows = yield self._simple_select_many_batch(
- table="push_rules",
- column="user_name",
- iterable=user_ids,
- retcols=("*",),
- desc="bulk_get_push_rules",
- )
-
- rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
-
- for row in rows:
- results.setdefault(row['user_name'], []).append(row)
-
- enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
-
- for user_id, rules in results.items():
- results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {}))
-
- defer.returnValue(results)
-
- @defer.inlineCallbacks
- def move_push_rule_from_room_to_room(self, new_room_id, user_id, rule):
- """Move a single push rule from one room to another for a specific user.
-
- Args:
- new_room_id (str): ID of the new room.
- user_id (str): ID of user the push rule belongs to.
- rule (Dict): A push rule.
- """
- # Create new rule id
- rule_id_scope = '/'.join(rule["rule_id"].split('/')[:-1])
- new_rule_id = rule_id_scope + "/" + new_room_id
-
- # Change room id in each condition
- for condition in rule.get("conditions", []):
- if condition.get("key") == "room_id":
- condition["pattern"] = new_room_id
-
- # Add the rule for the new room
- yield self.add_push_rule(
- user_id=user_id,
- rule_id=new_rule_id,
- priority_class=rule["priority_class"],
- conditions=rule["conditions"],
- actions=rule["actions"],
- )
-
- # Delete push rule for the old room
- yield self.delete_push_rule(user_id, rule["rule_id"])
-
- @defer.inlineCallbacks
- def move_push_rules_from_room_to_room_for_user(
- self, old_room_id, new_room_id, user_id
- ):
- """Move all of the push rules from one room to another for a specific
- user.
-
- Args:
- old_room_id (str): ID of the old room.
- new_room_id (str): ID of the new room.
- user_id (str): ID of user to copy push rules for.
- """
- # Retrieve push rules for this user
- user_push_rules = yield self.get_push_rules_for_user(user_id)
-
- # Get rules relating to the old room, move them to the new room, then
- # delete them from the old room
- for rule in user_push_rules:
- conditions = rule.get("conditions", [])
- if any(
- (c.get("key") == "room_id" and c.get("pattern") == old_room_id)
- for c in conditions
- ):
- self.move_push_rule_from_room_to_room(new_room_id, user_id, rule)
-
- @defer.inlineCallbacks
- def bulk_get_push_rules_for_room(self, event, context):
- state_group = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- current_state_ids = yield context.get_current_state_ids(self)
- result = yield self._bulk_get_push_rules_for_room(
- event.room_id, state_group, current_state_ids, event=event
- )
- defer.returnValue(result)
-
- @cachedInlineCallbacks(num_args=2, cache_context=True)
- def _bulk_get_push_rules_for_room(
- self, room_id, state_group, current_state_ids, cache_context, event=None
- ):
- # We don't use `state_group`, its there so that we can cache based
- # on it. However, its important that its never None, since two current_state's
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
-
- # We also will want to generate notifs for other people in the room so
- # their unread countss are correct in the event stream, but to avoid
- # generating them for bot / AS users etc, we only do so for people who've
- # sent a read receipt into the room.
-
- users_in_room = yield self._get_joined_users_from_context(
- room_id,
- state_group,
- current_state_ids,
- on_invalidate=cache_context.invalidate,
- event=event,
- )
-
- # We ignore app service users for now. This is so that we don't fill
- # up the `get_if_users_have_pushers` cache with AS entries that we
- # know don't have pushers, nor even read receipts.
- local_users_in_room = set(
- u
- for u in users_in_room
- if self.hs.is_mine_id(u)
- and not self.get_if_app_services_interested_in_user(u)
- )
-
- # users in the room who have pushers need to get push rules run because
- # that's how their pushers work
- if_users_with_pushers = yield self.get_if_users_have_pushers(
- local_users_in_room, on_invalidate=cache_context.invalidate
- )
- user_ids = set(
- uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
- )
-
- users_with_receipts = yield self.get_users_with_read_receipts_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
-
- # any users with pushers must be ours: they have pushers
- for uid in users_with_receipts:
- if uid in local_users_in_room:
- user_ids.add(uid)
-
- rules_by_user = yield self.bulk_get_push_rules(
- user_ids, on_invalidate=cache_context.invalidate
- )
-
- rules_by_user = {k: v for k, v in rules_by_user.items() if v is not None}
-
- defer.returnValue(rules_by_user)
-
- @cachedList(
- cached_method_name="get_push_rules_enabled_for_user",
- list_name="user_ids",
- num_args=1,
- inlineCallbacks=True,
- )
- def bulk_get_push_rules_enabled(self, user_ids):
- if not user_ids:
- defer.returnValue({})
-
- results = {user_id: {} for user_id in user_ids}
-
- rows = yield self._simple_select_many_batch(
- table="push_rules_enable",
- column="user_name",
- iterable=user_ids,
- retcols=("user_name", "rule_id", "enabled"),
- desc="bulk_get_push_rules_enabled",
- )
- for row in rows:
- enabled = bool(row['enabled'])
- results.setdefault(row['user_name'], {})[row['rule_id']] = enabled
- defer.returnValue(results)
-
-
-class PushRuleStore(PushRulesWorkerStore):
- @defer.inlineCallbacks
- def add_push_rule(
- self,
- user_id,
- rule_id,
- priority_class,
- conditions,
- actions,
- before=None,
- after=None,
- ):
- conditions_json = json.dumps(conditions)
- actions_json = json.dumps(actions)
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- if before or after:
- yield self.runInteraction(
- "_add_push_rule_relative_txn",
- self._add_push_rule_relative_txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- before,
- after,
- )
- else:
- yield self.runInteraction(
- "_add_push_rule_highest_priority_txn",
- self._add_push_rule_highest_priority_txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- )
-
- def _add_push_rule_relative_txn(
- self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- before,
- after,
- ):
- # Lock the table since otherwise we'll have annoying races between the
- # SELECT here and the UPSERT below.
- self.database_engine.lock_table(txn, "push_rules")
-
- relative_to_rule = before or after
-
- res = self._simple_select_one_txn(
- txn,
- table="push_rules",
- keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
- retcols=["priority_class", "priority"],
- allow_none=True,
- )
-
- if not res:
- raise RuleNotFoundException(
- "before/after rule not found: %s" % (relative_to_rule,)
- )
-
- base_priority_class = res["priority_class"]
- base_rule_priority = res["priority"]
-
- if base_priority_class != priority_class:
- raise InconsistentRuleException(
- "Given priority class does not match class of relative rule"
- )
-
- if before:
- # Higher priority rules are executed first, So adding a rule before
- # a rule means giving it a higher priority than that rule.
- new_rule_priority = base_rule_priority + 1
- else:
- # We increment the priority of the existing rules to make space for
- # the new rule. Therefore if we want this rule to appear after
- # an existing rule we give it the priority of the existing rule,
- # and then increment the priority of the existing rule.
- new_rule_priority = base_rule_priority
-
- sql = (
- "UPDATE push_rules SET priority = priority + 1"
- " WHERE user_name = ? AND priority_class = ? AND priority >= ?"
- )
-
- txn.execute(sql, (user_id, priority_class, new_rule_priority))
-
- self._upsert_push_rule_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- new_rule_priority,
- conditions_json,
- actions_json,
- )
-
- def _add_push_rule_highest_priority_txn(
- self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- conditions_json,
- actions_json,
- ):
- # Lock the table since otherwise we'll have annoying races between the
- # SELECT here and the UPSERT below.
- self.database_engine.lock_table(txn, "push_rules")
-
- # find the highest priority rule in that class
- sql = (
- "SELECT COUNT(*), MAX(priority) FROM push_rules"
- " WHERE user_name = ? and priority_class = ?"
- )
- txn.execute(sql, (user_id, priority_class))
- res = txn.fetchall()
- (how_many, highest_prio) = res[0]
-
- new_prio = 0
- if how_many > 0:
- new_prio = highest_prio + 1
-
- self._upsert_push_rule_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- new_prio,
- conditions_json,
- actions_json,
- )
-
- def _upsert_push_rule_txn(
- self,
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- priority,
- conditions_json,
- actions_json,
- update_stream=True,
- ):
- """Specialised version of _simple_upsert_txn that picks a push_rule_id
- using the _push_rule_id_gen if it needs to insert the rule. It assumes
- that the "push_rules" table is locked"""
-
- sql = (
- "UPDATE push_rules"
- " SET priority_class = ?, priority = ?, conditions = ?, actions = ?"
- " WHERE user_name = ? AND rule_id = ?"
- )
-
- txn.execute(
- sql,
- (priority_class, priority, conditions_json, actions_json, user_id, rule_id),
- )
-
- if txn.rowcount == 0:
- # We didn't update a row with the given rule_id so insert one
- push_rule_id = self._push_rule_id_gen.get_next()
-
- self._simple_insert_txn(
- txn,
- table="push_rules",
- values={
- "id": push_rule_id,
- "user_name": user_id,
- "rule_id": rule_id,
- "priority_class": priority_class,
- "priority": priority,
- "conditions": conditions_json,
- "actions": actions_json,
- },
- )
-
- if update_stream:
- self._insert_push_rules_update_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- op="ADD",
- data={
- "priority_class": priority_class,
- "priority": priority,
- "conditions": conditions_json,
- "actions": actions_json,
- },
- )
-
- @defer.inlineCallbacks
- def delete_push_rule(self, user_id, rule_id):
- """
- Delete a push rule. Args specify the row to be deleted and can be
- any of the columns in the push_rule table, but below are the
- standard ones
-
- Args:
- user_id (str): The matrix ID of the push rule owner
- rule_id (str): The rule_id of the rule to be deleted
- """
-
- def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
- self._simple_delete_one_txn(
- txn, "push_rules", {'user_name': user_id, 'rule_id': rule_id}
- )
-
- self._insert_push_rules_update_txn(
- txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
- )
-
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.runInteraction(
- "delete_push_rule",
- delete_push_rule_txn,
- stream_id,
- event_stream_ordering,
- )
-
- @defer.inlineCallbacks
- def set_push_rule_enabled(self, user_id, rule_id, enabled):
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.runInteraction(
- "_set_push_rule_enabled_txn",
- self._set_push_rule_enabled_txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- enabled,
- )
-
- def _set_push_rule_enabled_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
- ):
- new_id = self._push_rules_enable_id_gen.get_next()
- self._simple_upsert_txn(
- txn,
- "push_rules_enable",
- {'user_name': user_id, 'rule_id': rule_id},
- {'enabled': 1 if enabled else 0},
- {'id': new_id},
- )
-
- self._insert_push_rules_update_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- op="ENABLE" if enabled else "DISABLE",
- )
-
- @defer.inlineCallbacks
- def set_push_rule_actions(self, user_id, rule_id, actions, is_default_rule):
- actions_json = json.dumps(actions)
-
- def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
- if is_default_rule:
- # Add a dummy rule to the rules table with the user specified
- # actions.
- priority_class = -1
- priority = 1
- self._upsert_push_rule_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- priority_class,
- priority,
- "[]",
- actions_json,
- update_stream=False,
- )
- else:
- self._simple_update_one_txn(
- txn,
- "push_rules",
- {'user_name': user_id, 'rule_id': rule_id},
- {'actions': actions_json},
- )
-
- self._insert_push_rules_update_txn(
- txn,
- stream_id,
- event_stream_ordering,
- user_id,
- rule_id,
- op="ACTIONS",
- data={"actions": actions_json},
- )
-
- with self._push_rules_stream_id_gen.get_next() as ids:
- stream_id, event_stream_ordering = ids
- yield self.runInteraction(
- "set_push_rule_actions",
- set_push_rule_actions_txn,
- stream_id,
- event_stream_ordering,
- )
-
- def _insert_push_rules_update_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, op, data=None
- ):
- values = {
- "stream_id": stream_id,
- "event_stream_ordering": event_stream_ordering,
- "user_id": user_id,
- "rule_id": rule_id,
- "op": op,
- }
- if data is not None:
- values.update(data)
-
- self._simple_insert_txn(txn, "push_rules_stream", values=values)
-
- txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
- txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
- txn.call_after(
- self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
- )
-
- def get_all_push_rule_updates(self, last_id, current_id, limit):
- """Get all the push rules changes that have happend on the server"""
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_push_rule_updates_txn(txn):
- sql = (
- "SELECT stream_id, event_stream_ordering, user_id, rule_id,"
- " op, priority_class, priority, conditions, actions"
- " FROM push_rules_stream"
- " WHERE ? < stream_id AND stream_id <= ?"
- " ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, current_id, limit))
- return txn.fetchall()
-
- return self.runInteraction(
- "get_all_push_rule_updates", get_all_push_rule_updates_txn
- )
-
- def get_push_rules_stream_token(self):
- """Get the position of the push rules stream.
- Returns a pair of a stream id for the push_rules stream and the
- room stream ordering it corresponds to."""
- return self._push_rules_stream_id_gen.get_current_token()
-
- def get_max_push_rules_stream_id(self):
- return self.get_push_rules_stream_token()[0]
-
class RuleNotFoundException(Exception):
pass
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 4c83800cca..d471ec9860 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -17,13 +17,7 @@ import logging
import attr
-from twisted.internet import defer
-
-from synapse.api.constants import RelationTypes
from synapse.api.errors import SynapseError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.stream import generate_pagination_where_clause
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
logger = logging.getLogger(__name__)
@@ -60,7 +54,7 @@ class PaginationChunk(object):
class RelationPaginationToken(object):
"""Pagination token for relation pagination API.
- As the results are order by topological ordering, we can use the
+ As the results are in topological order, we can use the
`topological_ordering` and `stream_ordering` fields of the events at the
boundaries of the chunk as pagination tokens.
@@ -115,362 +109,3 @@ class AggregationPaginationToken(object):
def as_tuple(self):
return attr.astuple(self)
-
-
-class RelationsWorkerStore(SQLBaseStore):
- @cached(tree=True)
- def get_relations_for_event(
- self,
- event_id,
- relation_type=None,
- event_type=None,
- aggregation_key=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
- """Get a list of relations for an event, ordered by topological ordering.
-
- Args:
- event_id (str): Fetch events that relate to this event ID.
- relation_type (str|None): Only fetch events with this relation
- type, if given.
- event_type (str|None): Only fetch events with this event type, if
- given.
- aggregation_key (str|None): Only fetch events with this aggregation
- key, if given.
- limit (int): Only fetch the most recent `limit` events.
- direction (str): Whether to fetch the most recent first (`"b"`) or
- the oldest first (`"f"`).
- from_token (RelationPaginationToken|None): Fetch rows from the given
- token, or from the start if None.
- to_token (RelationPaginationToken|None): Fetch rows up to the given
- token, or up to the end if None.
-
- Returns:
- Deferred[PaginationChunk]: List of event IDs that match relations
- requested. The rows are of the form `{"event_id": "..."}`.
- """
-
- where_clause = ["relates_to_id = ?"]
- where_args = [event_id]
-
- if relation_type is not None:
- where_clause.append("relation_type = ?")
- where_args.append(relation_type)
-
- if event_type is not None:
- where_clause.append("type = ?")
- where_args.append(event_type)
-
- if aggregation_key:
- where_clause.append("aggregation_key = ?")
- where_args.append(aggregation_key)
-
- pagination_clause = generate_pagination_where_clause(
- direction=direction,
- column_names=("topological_ordering", "stream_ordering"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
- engine=self.database_engine,
- )
-
- if pagination_clause:
- where_clause.append(pagination_clause)
-
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
- sql = """
- SELECT event_id, topological_ordering, stream_ordering
- FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE %s
- ORDER BY topological_ordering %s, stream_ordering %s
- LIMIT ?
- """ % (
- " AND ".join(where_clause),
- order,
- order,
- )
-
- def _get_recent_references_for_event_txn(txn):
- txn.execute(sql, where_args + [limit + 1])
-
- last_topo_id = None
- last_stream_id = None
- events = []
- for row in txn:
- events.append({"event_id": row[0]})
- last_topo_id = row[1]
- last_stream_id = row[2]
-
- next_batch = None
- if len(events) > limit and last_topo_id and last_stream_id:
- next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
-
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
- )
-
- return self.runInteraction(
- "get_recent_references_for_event", _get_recent_references_for_event_txn
- )
-
- @cached(tree=True)
- def get_aggregation_groups_for_event(
- self,
- event_id,
- event_type=None,
- limit=5,
- direction="b",
- from_token=None,
- to_token=None,
- ):
- """Get a list of annotations on the event, grouped by event type and
- aggregation key, sorted by count.
-
- This is used e.g. to get the what and how many reactions have happend
- on an event.
-
- Args:
- event_id (str): Fetch events that relate to this event ID.
- event_type (str|None): Only fetch events with this event type, if
- given.
- limit (int): Only fetch the `limit` groups.
- direction (str): Whether to fetch the highest count first (`"b"`) or
- the lowest count first (`"f"`).
- from_token (AggregationPaginationToken|None): Fetch rows from the
- given token, or from the start if None.
- to_token (AggregationPaginationToken|None): Fetch rows up to the
- given token, or up to the end if None.
-
-
- Returns:
- Deferred[PaginationChunk]: List of groups of annotations that
- match. Each row is a dict with `type`, `key` and `count` fields.
- """
-
- where_clause = ["relates_to_id = ?", "relation_type = ?"]
- where_args = [event_id, RelationTypes.ANNOTATION]
-
- if event_type:
- where_clause.append("type = ?")
- where_args.append(event_type)
-
- having_clause = generate_pagination_where_clause(
- direction=direction,
- column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None,
- to_token=attr.astuple(to_token) if to_token else None,
- engine=self.database_engine,
- )
-
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
- if having_clause:
- having_clause = "HAVING " + having_clause
- else:
- having_clause = ""
-
- sql = """
- SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
- FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE {where_clause}
- GROUP BY relation_type, type, aggregation_key
- {having_clause}
- ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
- LIMIT ?
- """.format(
- where_clause=" AND ".join(where_clause),
- order=order,
- having_clause=having_clause,
- )
-
- def _get_aggregation_groups_for_event_txn(txn):
- txn.execute(sql, where_args + [limit + 1])
-
- next_batch = None
- events = []
- for row in txn:
- events.append({"type": row[0], "key": row[1], "count": row[2]})
- next_batch = AggregationPaginationToken(row[2], row[3])
-
- if len(events) <= limit:
- next_batch = None
-
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
- )
-
- return self.runInteraction(
- "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
- )
-
- @cachedInlineCallbacks()
- def get_applicable_edit(self, event_id):
- """Get the most recent edit (if any) that has happened for the given
- event.
-
- Correctly handles checking whether edits were allowed to happen.
-
- Args:
- event_id (str): The original event ID
-
- Returns:
- Deferred[EventBase|None]: Returns the most recent edit, if any.
- """
-
- # We only allow edits for `m.room.message` events that have the same sender
- # and event type. We can't assert these things during regular event auth so
- # we have to do the checks post hoc.
-
- # Fetches latest edit that has the same type and sender as the
- # original, and is an `m.room.message`.
- sql = """
- SELECT edit.event_id FROM events AS edit
- INNER JOIN event_relations USING (event_id)
- INNER JOIN events AS original ON
- original.event_id = relates_to_id
- AND edit.type = original.type
- AND edit.sender = original.sender
- WHERE
- relates_to_id = ?
- AND relation_type = ?
- AND edit.type = 'm.room.message'
- ORDER by edit.origin_server_ts DESC, edit.event_id DESC
- LIMIT 1
- """
-
- def _get_applicable_edit_txn(txn):
- txn.execute(sql, (event_id, RelationTypes.REPLACE))
- row = txn.fetchone()
- if row:
- return row[0]
-
- edit_id = yield self.runInteraction(
- "get_applicable_edit", _get_applicable_edit_txn
- )
-
- if not edit_id:
- return
-
- edit_event = yield self.get_event(edit_id, allow_none=True)
- defer.returnValue(edit_event)
-
- def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
- """Check if a user has already annotated an event with the same key
- (e.g. already liked an event).
-
- Args:
- parent_id (str): The event being annotated
- event_type (str): The event type of the annotation
- aggregation_key (str): The aggregation key of the annotation
- sender (str): The sender of the annotation
-
- Returns:
- Deferred[bool]
- """
-
- sql = """
- SELECT 1 FROM event_relations
- INNER JOIN events USING (event_id)
- WHERE
- relates_to_id = ?
- AND relation_type = ?
- AND type = ?
- AND sender = ?
- AND aggregation_key = ?
- LIMIT 1;
- """
-
- def _get_if_user_has_annotated_event(txn):
- txn.execute(
- sql,
- (
- parent_id,
- RelationTypes.ANNOTATION,
- event_type,
- sender,
- aggregation_key,
- ),
- )
-
- return bool(txn.fetchone())
-
- return self.runInteraction(
- "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
- )
-
-
-class RelationsStore(RelationsWorkerStore):
- def _handle_event_relations(self, txn, event):
- """Handles inserting relation data during peristence of events
-
- Args:
- txn
- event (EventBase)
- """
- relation = event.content.get("m.relates_to")
- if not relation:
- # No relations
- return
-
- rel_type = relation.get("rel_type")
- if rel_type not in (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.REPLACE,
- ):
- # Unknown relation type
- return
-
- parent_id = relation.get("event_id")
- if not parent_id:
- # Invalid relation
- return
-
- aggregation_key = relation.get("key")
-
- self._simple_insert_txn(
- txn,
- table="event_relations",
- values={
- "event_id": event.event_id,
- "relates_to_id": parent_id,
- "relation_type": rel_type,
- "aggregation_key": aggregation_key,
- },
- )
-
- txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
- txn.call_after(
- self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
- )
-
- if rel_type == RelationTypes.REPLACE:
- txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
-
- def _handle_redaction(self, txn, redacted_event_id):
- """Handles receiving a redaction and checking whether we need to remove
- any redacted relations from the database.
-
- Args:
- txn
- redacted_event_id (str): The event that was redacted.
- """
-
- self._simple_delete_txn(
- txn,
- table="event_relations",
- keyvalues={
- "event_id": redacted_event_id,
- }
- )
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
deleted file mode 100644
index db3d052d33..0000000000
--- a/synapse/storage/room.py
+++ /dev/null
@@ -1,900 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import collections
-import logging
-import re
-
-from six import integer_types
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes
-from synapse.api.errors import StoreError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.search import SearchStore
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-
-logger = logging.getLogger(__name__)
-
-
-OpsLevel = collections.namedtuple(
- "OpsLevel", ("ban_level", "kick_level", "redact_level")
-)
-
-RatelimitOverride = collections.namedtuple(
- "RatelimitOverride", ("messages_per_second", "burst_count")
-)
-
-
-class RoomWorkerStore(SQLBaseStore):
- def get_room(self, room_id):
- """Retrieve a room.
-
- Args:
- room_id (str): The ID of the room to retrieve.
- Returns:
- A dict containing the room information, or None if the room is unknown.
- """
- return self._simple_select_one(
- table="rooms",
- keyvalues={"room_id": room_id},
- retcols=("room_id", "is_public", "creator"),
- desc="get_room",
- allow_none=True,
- )
-
- def get_public_room_ids(self):
- return self._simple_select_onecol(
- table="rooms",
- keyvalues={"is_public": True},
- retcol="room_id",
- desc="get_public_room_ids",
- )
-
- @cached(num_args=2, max_entries=100)
- def get_public_room_ids_at_stream_id(self, stream_id, network_tuple):
- """Get pulbic rooms for a particular list, or across all lists.
-
- Args:
- stream_id (int)
- network_tuple (ThirdPartyInstanceID): The list to use (None, None)
- means the main list, None means all lsits.
- """
- return self.runInteraction(
- "get_public_room_ids_at_stream_id",
- self.get_public_room_ids_at_stream_id_txn,
- stream_id,
- network_tuple=network_tuple,
- )
-
- def get_public_room_ids_at_stream_id_txn(self, txn, stream_id, network_tuple):
- return {
- rm
- for rm, vis in self.get_published_at_stream_id_txn(
- txn, stream_id, network_tuple=network_tuple
- ).items()
- if vis
- }
-
- def get_published_at_stream_id_txn(self, txn, stream_id, network_tuple):
- if network_tuple:
- # We want to get from a particular list. No aggregation required.
-
- sql = """
- SELECT room_id, visibility FROM public_room_list_stream
- INNER JOIN (
- SELECT room_id, max(stream_id) AS stream_id
- FROM public_room_list_stream
- WHERE stream_id <= ? %s
- GROUP BY room_id
- ) grouped USING (room_id, stream_id)
- """
-
- if network_tuple.appservice_id is not None:
- txn.execute(
- sql % ("AND appservice_id = ? AND network_id = ?",),
- (stream_id, network_tuple.appservice_id, network_tuple.network_id),
- )
- else:
- txn.execute(sql % ("AND appservice_id IS NULL",), (stream_id,))
- return dict(txn)
- else:
- # We want to get from all lists, so we need to aggregate the results
-
- logger.info("Executing full list")
-
- sql = """
- SELECT room_id, visibility
- FROM public_room_list_stream
- INNER JOIN (
- SELECT
- room_id, max(stream_id) AS stream_id, appservice_id,
- network_id
- FROM public_room_list_stream
- WHERE stream_id <= ?
- GROUP BY room_id, appservice_id, network_id
- ) grouped USING (room_id, stream_id)
- """
-
- txn.execute(sql, (stream_id,))
-
- results = {}
- # A room is visible if its visible on any list.
- for room_id, visibility in txn:
- results[room_id] = bool(visibility) or results.get(room_id, False)
-
- return results
-
- def get_public_room_changes(self, prev_stream_id, new_stream_id, network_tuple):
- def get_public_room_changes_txn(txn):
- then_rooms = self.get_public_room_ids_at_stream_id_txn(
- txn, prev_stream_id, network_tuple
- )
-
- now_rooms_dict = self.get_published_at_stream_id_txn(
- txn, new_stream_id, network_tuple
- )
-
- now_rooms_visible = set(rm for rm, vis in now_rooms_dict.items() if vis)
- now_rooms_not_visible = set(
- rm for rm, vis in now_rooms_dict.items() if not vis
- )
-
- newly_visible = now_rooms_visible - then_rooms
- newly_unpublished = now_rooms_not_visible & then_rooms
-
- return newly_visible, newly_unpublished
-
- return self.runInteraction(
- "get_public_room_changes", get_public_room_changes_txn
- )
-
- @cached(max_entries=10000)
- def is_room_blocked(self, room_id):
- return self._simple_select_one_onecol(
- table="blocked_rooms",
- keyvalues={"room_id": room_id},
- retcol="1",
- allow_none=True,
- desc="is_room_blocked",
- )
-
- @defer.inlineCallbacks
- def is_room_published(self, room_id):
- """Check whether a room has been published in the local public room
- directory.
-
- Args:
- room_id (str)
- Returns:
- bool: Whether the room is currently published in the room directory
- """
- # Get room information
- room_info = yield self.get_room(room_id)
- if not room_info:
- defer.returnValue(False)
-
- # Check the is_public value
- defer.returnValue(room_info.get("is_public", False))
-
- @cachedInlineCallbacks(max_entries=10000)
- def get_ratelimit_for_user(self, user_id):
- """Check if there are any overrides for ratelimiting for the given
- user
-
- Args:
- user_id (str)
-
- Returns:
- RatelimitOverride if there is an override, else None. If the contents
- of RatelimitOverride are None or 0 then ratelimitng has been
- disabled for that user entirely.
- """
- row = yield self._simple_select_one(
- table="ratelimit_override",
- keyvalues={"user_id": user_id},
- retcols=("messages_per_second", "burst_count"),
- allow_none=True,
- desc="get_ratelimit_for_user",
- )
-
- if row:
- defer.returnValue(
- RatelimitOverride(
- messages_per_second=row["messages_per_second"],
- burst_count=row["burst_count"],
- )
- )
- else:
- defer.returnValue(None)
-
- @cachedInlineCallbacks()
- def get_retention_policy_for_room(self, room_id):
- """Get the retention policy for a given room.
-
- If no retention policy has been found for this room, returns a policy defined
- by the configured default policy (which has None as both the 'min_lifetime' and
- the 'max_lifetime' if no default policy has been defined in the server's
- configuration).
-
- Args:
- room_id (str): The ID of the room to get the retention policy of.
-
- Returns:
- dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
- """
- # If the room retention feature is disabled, return a policy with no minimum nor
- # maximum, in order not to filter out events we should filter out when sending to
- # the client.
- if not self.config.retention_enabled:
- defer.returnValue({
- "min_lifetime": None,
- "max_lifetime": None,
- })
-
- def get_retention_policy_for_room_txn(txn):
- txn.execute(
- """
- SELECT min_lifetime, max_lifetime FROM room_retention
- INNER JOIN current_state_events USING (event_id, room_id)
- WHERE room_id = ?;
- """,
- (room_id,)
- )
-
- return self.cursor_to_dict(txn)
-
- ret = yield self.runInteraction(
- "get_retention_policy_for_room",
- get_retention_policy_for_room_txn,
- )
-
- # If we don't know this room ID, ret will be None, in this case return the default
- # policy.
- if not ret:
- defer.returnValue({
- "min_lifetime": self.config.retention_default_min_lifetime,
- "max_lifetime": self.config.retention_default_max_lifetime,
- })
-
- row = ret[0]
-
- # If one of the room's policy's attributes isn't defined, use the matching
- # attribute from the default policy.
- # The default values will be None if no default policy has been defined, or if one
- # of the attributes is missing from the default policy.
- if row["min_lifetime"] is None:
- row["min_lifetime"] = self.config.retention_default_min_lifetime
-
- if row["max_lifetime"] is None:
- row["max_lifetime"] = self.config.retention_default_max_lifetime
-
- defer.returnValue(row)
-
-
-class RoomStore(RoomWorkerStore, SearchStore):
- def __init__(self, db_conn, hs):
- super(RoomStore, self).__init__(db_conn, hs)
-
- self.config = hs.config
-
- self.register_background_update_handler(
- "insert_room_retention", self._background_insert_retention,
- )
-
- @defer.inlineCallbacks
- def _background_insert_retention(self, progress, batch_size):
- """Retrieves a list of all rooms within a range and inserts an entry for each of
- them into the room_retention table.
- NULLs the property's columns if missing from the retention event in the room's
- state (or NULLs all of them if there's no retention event in the room's state),
- so that we fall back to the server's retention policy.
- """
-
- last_room = progress.get("room_id", "")
-
- def _background_insert_retention_txn(txn):
- txn.execute(
- """
- SELECT state.room_id, state.event_id, events.json
- FROM current_state_events as state
- LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
- WHERE state.room_id > ? AND state.type = '%s'
- ORDER BY state.room_id ASC
- LIMIT ?;
- """ % EventTypes.Retention,
- (last_room, batch_size)
- )
-
- rows = self.cursor_to_dict(txn)
-
- if not rows:
- return True
-
- for row in rows:
- if not row["json"]:
- retention_policy = {}
- else:
- ev = json.loads(row["json"])
- retention_policy = json.dumps(ev["content"])
-
- self._simple_insert_txn(
- txn=txn,
- table="room_retention",
- values={
- "room_id": row["room_id"],
- "event_id": row["event_id"],
- "min_lifetime": retention_policy.get("min_lifetime"),
- "max_lifetime": retention_policy.get("max_lifetime"),
- }
- )
-
- logger.info("Inserted %d rows into room_retention", len(rows))
-
- self._background_update_progress_txn(
- txn, "insert_room_retention", {
- "room_id": rows[-1]["room_id"],
- }
- )
-
- if batch_size > len(rows):
- return True
- else:
- return False
-
- end = yield self.runInteraction(
- "insert_room_retention",
- _background_insert_retention_txn,
- )
-
- if end:
- yield self._end_background_update("insert_room_retention")
-
- defer.returnValue(batch_size)
-
- @defer.inlineCallbacks
- def store_room(self, room_id, room_creator_user_id, is_public):
- """Stores a room.
-
- Args:
- room_id (str): The desired room ID, can be None.
- room_creator_user_id (str): The user ID of the room creator.
- is_public (bool): True to indicate that this room should appear in
- public room lists.
- Raises:
- StoreError if the room could not be stored.
- """
- try:
-
- def store_room_txn(txn, next_id):
- self._simple_insert_txn(
- txn,
- "rooms",
- {
- "room_id": room_id,
- "creator": room_creator_user_id,
- "is_public": is_public,
- },
- )
- if is_public:
- self._simple_insert_txn(
- txn,
- table="public_room_list_stream",
- values={
- "stream_id": next_id,
- "room_id": room_id,
- "visibility": is_public,
- },
- )
-
- with self._public_room_id_gen.get_next() as next_id:
- yield self.runInteraction("store_room_txn", store_room_txn, next_id)
- except Exception as e:
- logger.error("store_room with room_id=%s failed: %s", room_id, e)
- raise StoreError(500, "Problem creating room.")
-
- @defer.inlineCallbacks
- def set_room_is_public(self, room_id, is_public):
- def set_room_is_public_txn(txn, next_id):
- self._simple_update_one_txn(
- txn,
- table="rooms",
- keyvalues={"room_id": room_id},
- updatevalues={"is_public": is_public},
- )
-
- entries = self._simple_select_list_txn(
- txn,
- table="public_room_list_stream",
- keyvalues={
- "room_id": room_id,
- "appservice_id": None,
- "network_id": None,
- },
- retcols=("stream_id", "visibility"),
- )
-
- entries.sort(key=lambda r: r["stream_id"])
-
- add_to_stream = True
- if entries:
- add_to_stream = bool(entries[-1]["visibility"]) != is_public
-
- if add_to_stream:
- self._simple_insert_txn(
- txn,
- table="public_room_list_stream",
- values={
- "stream_id": next_id,
- "room_id": room_id,
- "visibility": is_public,
- "appservice_id": None,
- "network_id": None,
- },
- )
-
- with self._public_room_id_gen.get_next() as next_id:
- yield self.runInteraction(
- "set_room_is_public", set_room_is_public_txn, next_id
- )
- self.hs.get_notifier().on_new_replication_data()
-
- @defer.inlineCallbacks
- def set_room_is_public_appservice(
- self, room_id, appservice_id, network_id, is_public
- ):
- """Edit the appservice/network specific public room list.
-
- Each appservice can have a number of published room lists associated
- with them, keyed off of an appservice defined `network_id`, which
- basically represents a single instance of a bridge to a third party
- network.
-
- Args:
- room_id (str)
- appservice_id (str)
- network_id (str)
- is_public (bool): Whether to publish or unpublish the room from the
- list.
- """
-
- def set_room_is_public_appservice_txn(txn, next_id):
- if is_public:
- try:
- self._simple_insert_txn(
- txn,
- table="appservice_room_list",
- values={
- "appservice_id": appservice_id,
- "network_id": network_id,
- "room_id": room_id,
- },
- )
- except self.database_engine.module.IntegrityError:
- # We've already inserted, nothing to do.
- return
- else:
- self._simple_delete_txn(
- txn,
- table="appservice_room_list",
- keyvalues={
- "appservice_id": appservice_id,
- "network_id": network_id,
- "room_id": room_id,
- },
- )
-
- entries = self._simple_select_list_txn(
- txn,
- table="public_room_list_stream",
- keyvalues={
- "room_id": room_id,
- "appservice_id": appservice_id,
- "network_id": network_id,
- },
- retcols=("stream_id", "visibility"),
- )
-
- entries.sort(key=lambda r: r["stream_id"])
-
- add_to_stream = True
- if entries:
- add_to_stream = bool(entries[-1]["visibility"]) != is_public
-
- if add_to_stream:
- self._simple_insert_txn(
- txn,
- table="public_room_list_stream",
- values={
- "stream_id": next_id,
- "room_id": room_id,
- "visibility": is_public,
- "appservice_id": appservice_id,
- "network_id": network_id,
- },
- )
-
- with self._public_room_id_gen.get_next() as next_id:
- yield self.runInteraction(
- "set_room_is_public_appservice",
- set_room_is_public_appservice_txn,
- next_id,
- )
- self.hs.get_notifier().on_new_replication_data()
-
- def get_room_count(self):
- """Retrieve a list of all rooms
- """
-
- def f(txn):
- sql = "SELECT count(*) FROM rooms"
- txn.execute(sql)
- row = txn.fetchone()
- return row[0] or 0
-
- return self.runInteraction("get_rooms", f)
-
- def _store_room_topic_txn(self, txn, event):
- if hasattr(event, "content") and "topic" in event.content:
- self._simple_insert_txn(
- txn,
- "topics",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "topic": event.content["topic"],
- },
- )
-
- self.store_event_search_txn(
- txn, event, "content.topic", event.content["topic"]
- )
-
- def _store_room_name_txn(self, txn, event):
- if hasattr(event, "content") and "name" in event.content:
- self._simple_insert_txn(
- txn,
- "room_names",
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "name": event.content["name"],
- },
- )
-
- self.store_event_search_txn(
- txn, event, "content.name", event.content["name"]
- )
-
- def _store_room_message_txn(self, txn, event):
- if hasattr(event, "content") and "body" in event.content:
- self.store_event_search_txn(
- txn, event, "content.body", event.content["body"]
- )
-
- def _store_history_visibility_txn(self, txn, event):
- self._store_content_index_txn(txn, event, "history_visibility")
-
- def _store_guest_access_txn(self, txn, event):
- self._store_content_index_txn(txn, event, "guest_access")
-
- def _store_content_index_txn(self, txn, event, key):
- if hasattr(event, "content") and key in event.content:
- sql = (
- "INSERT INTO %(key)s"
- " (event_id, room_id, %(key)s)"
- " VALUES (?, ?, ?)" % {"key": key}
- )
- txn.execute(sql, (event.event_id, event.room_id, event.content[key]))
-
- def _store_retention_policy_for_room_txn(self, txn, event):
- if (
- hasattr(event, "content")
- and ("min_lifetime" in event.content or "max_lifetime" in event.content)
- ):
- if (
- ("min_lifetime" in event.content and not isinstance(
- event.content.get("min_lifetime"), integer_types
- ))
- or ("max_lifetime" in event.content and not isinstance(
- event.content.get("max_lifetime"), integer_types
- ))
- ):
- # Ignore the event if one of the value isn't an integer.
- return
-
- self._simple_insert_txn(
- txn=txn,
- table="room_retention",
- values={
- "room_id": event.room_id,
- "event_id": event.event_id,
- "min_lifetime": event.content.get("min_lifetime"),
- "max_lifetime": event.content.get("max_lifetime"),
- },
- )
-
- self._invalidate_cache_and_stream(
- txn, self.get_retention_policy_for_room, (event.room_id,)
- )
-
- def add_event_report(
- self, room_id, event_id, user_id, reason, content, received_ts
- ):
- next_id = self._event_reports_id_gen.get_next()
- return self._simple_insert(
- table="event_reports",
- values={
- "id": next_id,
- "received_ts": received_ts,
- "room_id": room_id,
- "event_id": event_id,
- "user_id": user_id,
- "reason": reason,
- "content": json.dumps(content),
- },
- desc="add_event_report",
- )
-
- def get_current_public_room_stream_id(self):
- return self._public_room_id_gen.get_current_token()
-
- def get_all_new_public_rooms(self, prev_id, current_id, limit):
- def get_all_new_public_rooms(txn):
- sql = """
- SELECT stream_id, room_id, visibility, appservice_id, network_id
- FROM public_room_list_stream
- WHERE stream_id > ? AND stream_id <= ?
- ORDER BY stream_id ASC
- LIMIT ?
- """
-
- txn.execute(sql, (prev_id, current_id, limit))
- return txn.fetchall()
-
- if prev_id == current_id:
- return defer.succeed([])
-
- return self.runInteraction("get_all_new_public_rooms", get_all_new_public_rooms)
-
- @defer.inlineCallbacks
- def block_room(self, room_id, user_id):
- """Marks the room as blocked. Can be called multiple times.
-
- Args:
- room_id (str): Room to block
- user_id (str): Who blocked it
-
- Returns:
- Deferred
- """
- yield self._simple_upsert(
- table="blocked_rooms",
- keyvalues={"room_id": room_id},
- values={},
- insertion_values={"user_id": user_id},
- desc="block_room",
- )
- yield self.runInteraction(
- "block_room_invalidation",
- self._invalidate_cache_and_stream,
- self.is_room_blocked,
- (room_id,),
- )
-
- def get_media_mxcs_in_room(self, room_id):
- """Retrieves all the local and remote media MXC URIs in a given room
-
- Args:
- room_id (str)
-
- Returns:
- The local and remote media as a lists of tuples where the key is
- the hostname and the value is the media ID.
- """
-
- def _get_media_mxcs_in_room_txn(txn):
- local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
- local_media_mxcs = []
- remote_media_mxcs = []
-
- # Convert the IDs to MXC URIs
- for media_id in local_mxcs:
- local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
- for hostname, media_id in remote_mxcs:
- remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
-
- return local_media_mxcs, remote_media_mxcs
-
- return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
-
- def quarantine_media_ids_in_room(self, room_id, quarantined_by):
- """For a room loops through all events with media and quarantines
- the associated media
- """
-
- def _quarantine_media_in_room_txn(txn):
- local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
- total_media_quarantined = 0
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany(
- """
- UPDATE local_media_repository
- SET quarantined_by = ?
- WHERE media_id = ?
- """,
- ((quarantined_by, media_id) for media_id in local_mxcs),
- )
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin = ? AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_mxcs
- ),
- )
-
- total_media_quarantined += len(local_mxcs)
- total_media_quarantined += len(remote_mxcs)
-
- return total_media_quarantined
-
- return self.runInteraction(
- "quarantine_media_in_room", _quarantine_media_in_room_txn
- )
-
- def _get_media_mxcs_in_room_txn(self, txn, room_id):
- """Retrieves all the local and remote media MXC URIs in a given room
-
- Args:
- txn (cursor)
- room_id (str)
-
- Returns:
- The local and remote media as a lists of tuples where the key is
- the hostname and the value is the media ID.
- """
- mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
-
- next_token = self.get_current_events_token() + 1
- local_media_mxcs = []
- remote_media_mxcs = []
-
- while next_token:
- sql = """
- SELECT stream_ordering, json FROM events
- JOIN event_json USING (room_id, event_id)
- WHERE room_id = ?
- AND stream_ordering < ?
- AND contains_url = ? AND outlier = ?
- ORDER BY stream_ordering DESC
- LIMIT ?
- """
- txn.execute(sql, (room_id, next_token, True, False, 100))
-
- next_token = None
- for stream_ordering, content_json in txn:
- next_token = stream_ordering
- event_json = json.loads(content_json)
- content = event_json["content"]
- content_url = content.get("url")
- thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
- for url in (content_url, thumbnail_url):
- if not url:
- continue
- matches = mxc_re.match(url)
- if matches:
- hostname = matches.group(1)
- media_id = matches.group(2)
- if hostname == self.hs.hostname:
- local_media_mxcs.append(media_id)
- else:
- remote_media_mxcs.append((hostname, media_id))
-
- return local_media_mxcs, remote_media_mxcs
-
- @defer.inlineCallbacks
- def get_rooms_for_retention_period_in_range(self, min_ms, max_ms, include_null=False):
- """Retrieves all of the rooms within the given retention range.
-
- Optionally includes the rooms which don't have a retention policy.
-
- Args:
- min_ms (int|None): Duration in milliseconds that define the lower limit of
- the range to handle (exclusive). If None, doesn't set a lower limit.
- max_ms (int|None): Duration in milliseconds that define the upper limit of
- the range to handle (inclusive). If None, doesn't set an upper limit.
- include_null (bool): Whether to include rooms which retention policy is NULL
- in the returned set.
-
- Returns:
- dict[str, dict]: The rooms within this range, along with their retention
- policy. The key is "room_id", and maps to a dict describing the retention
- policy associated with this room ID. The keys for this nested dict are
- "min_lifetime" (int|None), and "max_lifetime" (int|None).
- """
-
- def get_rooms_for_retention_period_in_range_txn(txn):
- range_conditions = []
- args = []
-
- if min_ms is not None:
- range_conditions.append("max_lifetime > ?")
- args.append(min_ms)
-
- if max_ms is not None:
- range_conditions.append("max_lifetime <= ?")
- args.append(max_ms)
-
- # Do a first query which will retrieve the rooms that have a retention policy
- # in their current state.
- sql = """
- SELECT room_id, min_lifetime, max_lifetime FROM room_retention
- INNER JOIN current_state_events USING (event_id, room_id)
- """
-
- if len(range_conditions):
- sql += " WHERE (" + " AND ".join(range_conditions) + ")"
-
- if include_null:
- sql += " OR max_lifetime IS NULL"
-
- txn.execute(sql, args)
-
- rows = self.cursor_to_dict(txn)
- rooms_dict = {}
-
- for row in rows:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": row["min_lifetime"],
- "max_lifetime": row["max_lifetime"],
- }
-
- if include_null:
- # If required, do a second query that retrieves all of the rooms we know
- # of so we can handle rooms with no retention policy.
- sql = "SELECT DISTINCT room_id FROM current_state_events"
-
- txn.execute(sql)
-
- rows = self.cursor_to_dict(txn)
-
- # If a room isn't already in the dict (i.e. it doesn't have a retention
- # policy in its state), add it with a null policy.
- for row in rows:
- if row["room_id"] not in rooms_dict:
- rooms_dict[row["room_id"]] = {
- "min_lifetime": None,
- "max_lifetime": None,
- }
-
- return rooms_dict
-
- rooms = yield self.runInteraction(
- "get_rooms_for_retention_period_in_range",
- get_rooms_for_retention_period_in_range_txn,
- )
-
- defer.returnValue(rooms)
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 7617913326..8c4a83a840 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -17,20 +17,6 @@
import logging
from collections import namedtuple
-from six import iteritems, itervalues
-
-from canonicaljson import json
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes, Membership
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.types import get_domain_from_id
-from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.util.stringutils import to_ascii
-
logger = logging.getLogger(__name__)
@@ -51,780 +37,3 @@ ProfileInfo = namedtuple("ProfileInfo", ("avatar_url", "display_name"))
# a given membership type, suitable for use in calculating heroes for a room.
# "count" points to the total numberr of users of a given membership type.
MemberSummary = namedtuple("MemberSummary", ("members", "count"))
-
-_MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
-
-
-class RoomMemberWorkerStore(EventsWorkerStore):
- @cachedInlineCallbacks(max_entries=100000, iterable=True, cache_context=True)
- def get_hosts_in_room(self, room_id, cache_context):
- """Returns the set of all hosts currently in the room
- """
- user_ids = yield self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- hosts = frozenset(get_domain_from_id(user_id) for user_id in user_ids)
- defer.returnValue(hosts)
-
- @cached(max_entries=100000, iterable=True)
- def get_users_in_room(self, room_id):
- def f(txn):
- sql = (
- "SELECT m.user_id FROM room_memberships as m"
- " INNER JOIN current_state_events as c"
- " ON m.event_id = c.event_id "
- " AND m.room_id = c.room_id "
- " AND m.user_id = c.state_key"
- " WHERE c.type = 'm.room.member' AND c.room_id = ? AND m.membership = ?"
- )
-
- txn.execute(sql, (room_id, Membership.JOIN))
- return [to_ascii(r[0]) for r in txn]
-
- return self.runInteraction("get_users_in_room", f)
-
- @cached(max_entries=100000)
- def get_room_summary(self, room_id):
- """ Get the details of a room roughly suitable for use by the room
- summary extension to /sync. Useful when lazy loading room members.
- Args:
- room_id (str): The room ID to query
- Returns:
- Deferred[dict[str, MemberSummary]:
- dict of membership states, pointing to a MemberSummary named tuple.
- """
-
- def _get_room_summary_txn(txn):
- # first get counts.
- # We do this all in one transaction to keep the cache small.
- # FIXME: get rid of this when we have room_stats
- sql = """
- SELECT count(*), m.membership FROM room_memberships as m
- INNER JOIN current_state_events as c
- ON m.event_id = c.event_id
- AND m.room_id = c.room_id
- AND m.user_id = c.state_key
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
-
- txn.execute(sql, (room_id,))
- res = {}
- for count, membership in txn:
- summary = res.setdefault(to_ascii(membership), MemberSummary([], count))
-
- # we order by membership and then fairly arbitrarily by event_id so
- # heroes are consistent
- sql = """
- SELECT m.user_id, m.membership, m.event_id
- FROM room_memberships as m
- INNER JOIN current_state_events as c
- ON m.event_id = c.event_id
- AND m.room_id = c.room_id
- AND m.user_id = c.state_key
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- ORDER BY
- CASE m.membership WHEN ? THEN 1 WHEN ? THEN 2 ELSE 3 END ASC,
- m.event_id ASC
- LIMIT ?
- """
-
- # 6 is 5 (number of heroes) plus 1, in case one of them is the calling user.
- txn.execute(sql, (room_id, Membership.JOIN, Membership.INVITE, 6))
- for user_id, membership, event_id in txn:
- summary = res[to_ascii(membership)]
- # we will always have a summary for this membership type at this
- # point given the summary currently contains the counts.
- members = summary.members
- members.append((to_ascii(user_id), to_ascii(event_id)))
-
- return res
-
- return self.runInteraction("get_room_summary", _get_room_summary_txn)
-
- def _get_user_counts_in_room_txn(self, txn, room_id):
- """
- Get the user count in a room by membership.
-
- Args:
- room_id (str)
- membership (Membership)
-
- Returns:
- Deferred[int]
- """
- sql = """
- SELECT m.membership, count(*) FROM room_memberships as m
- INNER JOIN current_state_events as c USING(event_id)
- WHERE c.type = 'm.room.member' AND c.room_id = ?
- GROUP BY m.membership
- """
-
- txn.execute(sql, (room_id,))
- return {row[0]: row[1] for row in txn}
-
- @cached()
- def get_invited_rooms_for_user(self, user_id):
- """ Get all the rooms the user is invited to
- Args:
- user_id (str): The user ID.
- Returns:
- A deferred list of RoomsForUser.
- """
-
- return self.get_rooms_for_user_where_membership_is(user_id, [Membership.INVITE])
-
- @defer.inlineCallbacks
- def get_invite_for_user_in_room(self, user_id, room_id):
- """Gets the invite for the given user and room
-
- Args:
- user_id (str)
- room_id (str)
-
- Returns:
- Deferred: Resolves to either a RoomsForUser or None if no invite was
- found.
- """
- invites = yield self.get_invited_rooms_for_user(user_id)
- for invite in invites:
- if invite.room_id == room_id:
- defer.returnValue(invite)
- defer.returnValue(None)
-
- def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
- """ Get all the rooms for this user where the membership for this user
- matches one in the membership list.
-
- Args:
- user_id (str): The user ID.
- membership_list (list): A list of synapse.api.constants.Membership
- values which the user must be in.
- Returns:
- A list of dictionary objects, with room_id, membership and sender
- defined.
- """
- if not membership_list:
- return defer.succeed(None)
-
- return self.runInteraction(
- "get_rooms_for_user_where_membership_is",
- self._get_rooms_for_user_where_membership_is_txn,
- user_id,
- membership_list,
- )
-
- def _get_rooms_for_user_where_membership_is_txn(
- self, txn, user_id, membership_list
- ):
-
- do_invite = Membership.INVITE in membership_list
- membership_list = [m for m in membership_list if m != Membership.INVITE]
-
- results = []
- if membership_list:
- where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
- " OR ".join(["membership = ?" for _ in membership_list]),
- )
-
- args = [user_id]
- args.extend(membership_list)
-
- sql = (
- "SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
- " FROM current_state_events as c"
- " INNER JOIN room_memberships as m"
- " ON m.event_id = c.event_id"
- " INNER JOIN events as e"
- " ON e.event_id = c.event_id"
- " AND m.room_id = c.room_id"
- " AND m.user_id = c.state_key"
- " WHERE c.type = 'm.room.member' AND %s"
- ) % (where_clause,)
-
- txn.execute(sql, args)
- results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
-
- if do_invite:
- sql = (
- "SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
- " FROM local_invites as i"
- " INNER JOIN events as e USING (event_id)"
- " WHERE invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(sql, (user_id,))
- results.extend(
- RoomsForUser(
- room_id=r["room_id"],
- sender=r["inviter"],
- event_id=r["event_id"],
- stream_ordering=r["stream_ordering"],
- membership=Membership.INVITE,
- )
- for r in self.cursor_to_dict(txn)
- )
-
- return results
-
- @cachedInlineCallbacks(max_entries=500000, iterable=True)
- def get_rooms_for_user_with_stream_ordering(self, user_id):
- """Returns a set of room_ids the user is currently joined to
-
- Args:
- user_id (str)
-
- Returns:
- Deferred[frozenset[GetRoomsForUserWithStreamOrdering]]: Returns
- the rooms the user is in currently, along with the stream ordering
- of the most recent join for that user and room.
- """
- rooms = yield self.get_rooms_for_user_where_membership_is(
- user_id, membership_list=[Membership.JOIN]
- )
- defer.returnValue(
- frozenset(
- GetRoomsForUserWithStreamOrdering(r.room_id, r.stream_ordering)
- for r in rooms
- )
- )
-
- @defer.inlineCallbacks
- def get_rooms_for_user(self, user_id, on_invalidate=None):
- """Returns a set of room_ids the user is currently joined to
- """
- rooms = yield self.get_rooms_for_user_with_stream_ordering(
- user_id, on_invalidate=on_invalidate
- )
- defer.returnValue(frozenset(r.room_id for r in rooms))
-
- @cachedInlineCallbacks(max_entries=500000, cache_context=True, iterable=True)
- def get_users_who_share_room_with_user(self, user_id, cache_context):
- """Returns the set of users who share a room with `user_id`
- """
- room_ids = yield self.get_rooms_for_user(
- user_id, on_invalidate=cache_context.invalidate
- )
-
- user_who_share_room = set()
- for room_id in room_ids:
- user_ids = yield self.get_users_in_room(
- room_id, on_invalidate=cache_context.invalidate
- )
- user_who_share_room.update(user_ids)
-
- defer.returnValue(user_who_share_room)
-
- @defer.inlineCallbacks
- def get_joined_users_from_context(self, event, context):
- state_group = context.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- current_state_ids = yield context.get_current_state_ids(self)
- result = yield self._get_joined_users_from_context(
- event.room_id, state_group, current_state_ids, event=event, context=context
- )
- defer.returnValue(result)
-
- def get_joined_users_from_state(self, room_id, state_entry):
- state_group = state_entry.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- return self._get_joined_users_from_context(
- room_id, state_group, state_entry.state, context=state_entry
- )
-
- @cachedInlineCallbacks(
- num_args=2, cache_context=True, iterable=True, max_entries=100000
- )
- def _get_joined_users_from_context(
- self,
- room_id,
- state_group,
- current_state_ids,
- cache_context,
- event=None,
- context=None,
- ):
- # We don't use `state_group`, it's there so that we can cache based
- # on it. However, it's important that it's never None, since two current_states
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
-
- users_in_room = {}
- member_event_ids = [
- e_id
- for key, e_id in iteritems(current_state_ids)
- if key[0] == EventTypes.Member
- ]
-
- if context is not None:
- # If we have a context with a delta from a previous state group,
- # check if we also have the result from the previous group in cache.
- # If we do then we can reuse that result and simply update it with
- # any membership changes in `delta_ids`
- if context.prev_group and context.delta_ids:
- prev_res = self._get_joined_users_from_context.cache.get(
- (room_id, context.prev_group), None
- )
- if prev_res and isinstance(prev_res, dict):
- users_in_room = dict(prev_res)
- member_event_ids = [
- e_id
- for key, e_id in iteritems(context.delta_ids)
- if key[0] == EventTypes.Member
- ]
- for etype, state_key in context.delta_ids:
- users_in_room.pop(state_key, None)
-
- # We check if we have any of the member event ids in the event cache
- # before we ask the DB
-
- # We don't update the event cache hit ratio as it completely throws off
- # the hit ratio counts. After all, we don't populate the cache if we
- # miss it here
- event_map = self._get_events_from_cache(
- member_event_ids, allow_rejected=False, update_metrics=False
- )
-
- missing_member_event_ids = []
- for event_id in member_event_ids:
- ev_entry = event_map.get(event_id)
- if ev_entry:
- if ev_entry.event.membership == Membership.JOIN:
- users_in_room[to_ascii(ev_entry.event.state_key)] = ProfileInfo(
- display_name=to_ascii(
- ev_entry.event.content.get("displayname", None)
- ),
- avatar_url=to_ascii(
- ev_entry.event.content.get("avatar_url", None)
- ),
- )
- else:
- missing_member_event_ids.append(event_id)
-
- if missing_member_event_ids:
- rows = yield self._simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=missing_member_event_ids,
- retcols=('user_id', 'display_name', 'avatar_url'),
- keyvalues={"membership": Membership.JOIN},
- batch_size=500,
- desc="_get_joined_users_from_context",
- )
-
- users_in_room.update(
- {
- to_ascii(row["user_id"]): ProfileInfo(
- avatar_url=to_ascii(row["avatar_url"]),
- display_name=to_ascii(row["display_name"]),
- )
- for row in rows
- }
- )
-
- if event is not None and event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- if event.event_id in member_event_ids:
- users_in_room[to_ascii(event.state_key)] = ProfileInfo(
- display_name=to_ascii(event.content.get("displayname", None)),
- avatar_url=to_ascii(event.content.get("avatar_url", None)),
- )
-
- defer.returnValue(users_in_room)
-
- @cachedInlineCallbacks(max_entries=10000)
- def is_host_joined(self, room_id, host):
- if '%' in host or '_' in host:
- raise Exception("Invalid host name")
-
- sql = """
- SELECT state_key FROM current_state_events AS c
- INNER JOIN room_memberships USING (event_id)
- WHERE membership = 'join'
- AND type = 'm.room.member'
- AND c.room_id = ?
- AND state_key LIKE ?
- LIMIT 1
- """
-
- # We do need to be careful to ensure that host doesn't have any wild cards
- # in it, but we checked above for known ones and we'll check below that
- # the returned user actually has the correct domain.
- like_clause = "%:" + host
-
- rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
-
- if not rows:
- defer.returnValue(False)
-
- user_id = rows[0][0]
- if get_domain_from_id(user_id) != host:
- # This can only happen if the host name has something funky in it
- raise Exception("Invalid host name")
-
- defer.returnValue(True)
-
- @cachedInlineCallbacks()
- def was_host_joined(self, room_id, host):
- """Check whether the server is or ever was in the room.
-
- Args:
- room_id (str)
- host (str)
-
- Returns:
- Deferred: Resolves to True if the host is/was in the room, otherwise
- False.
- """
- if '%' in host or '_' in host:
- raise Exception("Invalid host name")
-
- sql = """
- SELECT user_id FROM room_memberships
- WHERE room_id = ?
- AND user_id LIKE ?
- AND membership = 'join'
- LIMIT 1
- """
-
- # We do need to be careful to ensure that host doesn't have any wild cards
- # in it, but we checked above for known ones and we'll check below that
- # the returned user actually has the correct domain.
- like_clause = "%:" + host
-
- rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
-
- if not rows:
- defer.returnValue(False)
-
- user_id = rows[0][0]
- if get_domain_from_id(user_id) != host:
- # This can only happen if the host name has something funky in it
- raise Exception("Invalid host name")
-
- defer.returnValue(True)
-
- def get_joined_hosts(self, room_id, state_entry):
- state_group = state_entry.state_group
- if not state_group:
- # If state_group is None it means it has yet to be assigned a
- # state group, i.e. we need to make sure that calls with a state_group
- # of None don't hit previous cached calls with a None state_group.
- # To do this we set the state_group to a new object as object() != object()
- state_group = object()
-
- return self._get_joined_hosts(
- room_id, state_group, state_entry.state, state_entry=state_entry
- )
-
- @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True)
- # @defer.inlineCallbacks
- def _get_joined_hosts(self, room_id, state_group, current_state_ids, state_entry):
- # We don't use `state_group`, its there so that we can cache based
- # on it. However, its important that its never None, since two current_state's
- # with a state_group of None are likely to be different.
- # See bulk_get_push_rules_for_room for how we work around this.
- assert state_group is not None
-
- cache = self._get_joined_hosts_cache(room_id)
- joined_hosts = yield cache.get_destinations(state_entry)
-
- defer.returnValue(joined_hosts)
-
- @cached(max_entries=10000)
- def _get_joined_hosts_cache(self, room_id):
- return _JoinedHostsCache(self, room_id)
-
- @cachedInlineCallbacks(num_args=2)
- def did_forget(self, user_id, room_id):
- """Returns whether user_id has elected to discard history for room_id.
-
- Returns False if they have since re-joined."""
-
- def f(txn):
- sql = (
- "SELECT"
- " COUNT(*)"
- " FROM"
- " room_memberships"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- " AND"
- " forgotten = 0"
- )
- txn.execute(sql, (user_id, room_id))
- rows = txn.fetchall()
- return rows[0][0]
-
- count = yield self.runInteraction("did_forget_membership", f)
- defer.returnValue(count == 0)
-
-
-class RoomMemberStore(RoomMemberWorkerStore):
- def __init__(self, db_conn, hs):
- super(RoomMemberStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
- _MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
- )
-
- def _store_room_members_txn(self, txn, events, backfilled):
- """Store a room member in the database.
- """
- self._simple_insert_many_txn(
- txn,
- table="room_memberships",
- values=[
- {
- "event_id": event.event_id,
- "user_id": event.state_key,
- "sender": event.user_id,
- "room_id": event.room_id,
- "membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
- }
- for event in events
- ],
- )
-
- for event in events:
- txn.call_after(
- self._membership_stream_cache.entity_has_changed,
- event.state_key,
- event.internal_metadata.stream_ordering,
- )
- txn.call_after(
- self.get_invited_rooms_for_user.invalidate, (event.state_key,)
- )
-
- # We update the local_invites table only if the event is "current",
- # i.e., its something that has just happened. If the event is an
- # outlier it is only current if its an "out of band membership",
- # like a remote invite or a rejection of a remote invite.
- is_new_state = not backfilled and (
- not event.internal_metadata.is_outlier()
- or event.internal_metadata.is_out_of_band_membership()
- )
- is_mine = self.hs.is_mine_id(event.state_key)
- if is_new_state and is_mine:
- if event.membership == Membership.INVITE:
- self._simple_insert_txn(
- txn,
- table="local_invites",
- values={
- "event_id": event.event_id,
- "invitee": event.state_key,
- "inviter": event.sender,
- "room_id": event.room_id,
- "stream_id": event.internal_metadata.stream_ordering,
- },
- )
- else:
- sql = (
- "UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- txn.execute(
- sql,
- (
- event.internal_metadata.stream_ordering,
- event.event_id,
- event.room_id,
- event.state_key,
- ),
- )
-
- @defer.inlineCallbacks
- def locally_reject_invite(self, user_id, room_id):
- sql = (
- "UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
- " room_id = ? AND invitee = ? AND locally_rejected is NULL"
- " AND replaced_by is NULL"
- )
-
- def f(txn, stream_ordering):
- txn.execute(sql, (stream_ordering, True, room_id, user_id))
-
- with self._stream_id_gen.get_next() as stream_ordering:
- yield self.runInteraction("locally_reject_invite", f, stream_ordering)
-
- def forget(self, user_id, room_id):
- """Indicate that user_id wishes to discard history for room_id."""
-
- def f(txn):
- sql = (
- "UPDATE"
- " room_memberships"
- " SET"
- " forgotten = 1"
- " WHERE"
- " user_id = ?"
- " AND"
- " room_id = ?"
- )
- txn.execute(sql, (user_id, room_id))
-
- self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
-
- return self.runInteraction("forget_membership", f)
-
- @defer.inlineCallbacks
- def _background_add_membership_profile(self, progress, batch_size):
- target_min_stream_id = progress.get(
- "target_min_stream_id_inclusive", self._min_stream_order_on_start
- )
- max_stream_id = progress.get(
- "max_stream_id_exclusive", self._stream_order_on_start + 1
- )
-
- INSERT_CLUMP_SIZE = 1000
-
- def add_membership_profile_txn(txn):
- sql = """
- SELECT stream_ordering, event_id, events.room_id, event_json.json
- FROM events
- INNER JOIN event_json USING (event_id)
- INNER JOIN room_memberships USING (event_id)
- WHERE ? <= stream_ordering AND stream_ordering < ?
- AND type = 'm.room.member'
- ORDER BY stream_ordering DESC
- LIMIT ?
- """
-
- txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
-
- rows = self.cursor_to_dict(txn)
- if not rows:
- return 0
-
- min_stream_id = rows[-1]["stream_ordering"]
-
- to_update = []
- for row in rows:
- event_id = row["event_id"]
- room_id = row["room_id"]
- try:
- event_json = json.loads(row["json"])
- content = event_json['content']
- except Exception:
- continue
-
- display_name = content.get("displayname", None)
- avatar_url = content.get("avatar_url", None)
-
- if display_name or avatar_url:
- to_update.append((display_name, avatar_url, event_id, room_id))
-
- to_update_sql = """
- UPDATE room_memberships SET display_name = ?, avatar_url = ?
- WHERE event_id = ? AND room_id = ?
- """
- for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
- clump = to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(to_update_sql, clump)
-
- progress = {
- "target_min_stream_id_inclusive": target_min_stream_id,
- "max_stream_id_exclusive": min_stream_id,
- }
-
- self._background_update_progress_txn(
- txn, _MEMBERSHIP_PROFILE_UPDATE_NAME, progress
- )
-
- return len(rows)
-
- result = yield self.runInteraction(
- _MEMBERSHIP_PROFILE_UPDATE_NAME, add_membership_profile_txn
- )
-
- if not result:
- yield self._end_background_update(_MEMBERSHIP_PROFILE_UPDATE_NAME)
-
- defer.returnValue(result)
-
-
-class _JoinedHostsCache(object):
- """Cache for joined hosts in a room that is optimised to handle updates
- via state deltas.
- """
-
- def __init__(self, store, room_id):
- self.store = store
- self.room_id = room_id
-
- self.hosts_to_joined_users = {}
-
- self.state_group = object()
-
- self.linearizer = Linearizer("_JoinedHostsCache")
-
- self._len = 0
-
- @defer.inlineCallbacks
- def get_destinations(self, state_entry):
- """Get set of destinations for a state entry
-
- Args:
- state_entry(synapse.state._StateCacheEntry)
- """
- if state_entry.state_group == self.state_group:
- defer.returnValue(frozenset(self.hosts_to_joined_users))
-
- with (yield self.linearizer.queue(())):
- if state_entry.state_group == self.state_group:
- pass
- elif state_entry.prev_group == self.state_group:
- for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
- if typ != EventTypes.Member:
- continue
-
- host = intern_string(get_domain_from_id(state_key))
- user_id = state_key
- known_joins = self.hosts_to_joined_users.setdefault(host, set())
-
- event = yield self.store.get_event(event_id)
- if event.membership == Membership.JOIN:
- known_joins.add(user_id)
- else:
- known_joins.discard(user_id)
-
- if not known_joins:
- self.hosts_to_joined_users.pop(host, None)
- else:
- joined_users = yield self.store.get_joined_users_from_state(
- self.room_id, state_entry
- )
-
- self.hosts_to_joined_users = {}
- for user_id in joined_users:
- host = intern_string(get_domain_from_id(user_id))
- self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
-
- if state_entry.state_group:
- self.state_group = state_entry.state_group
- else:
- self.state_group = object()
- self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
- defer.returnValue(frozenset(self.hosts_to_joined_users))
-
- def __len__(self):
- return self._len
diff --git a/synapse/storage/schema/delta/35/00background_updates_add_col.sql b/synapse/storage/schema/delta/35/00background_updates_add_col.sql
new file mode 100644
index 0000000000..c2d2a4f836
--- /dev/null
+++ b/synapse/storage/schema/delta/35/00background_updates_add_col.sql
@@ -0,0 +1,17 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+ALTER TABLE background_updates ADD COLUMN depends_on TEXT;
diff --git a/synapse/storage/schema/full_schemas/54/full.sql b/synapse/storage/schema/full_schemas/54/full.sql
new file mode 100644
index 0000000000..1005880466
--- /dev/null
+++ b/synapse/storage/schema/full_schemas/54/full.sql
@@ -0,0 +1,8 @@
+
+
+CREATE TABLE background_updates (
+ update_name text NOT NULL,
+ progress_json text NOT NULL,
+ depends_on text,
+ CONSTRAINT background_updates_uniqueness UNIQUE (update_name)
+);
diff --git a/synapse/storage/schema/full_schemas/README.txt b/synapse/storage/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0bfe1b4550..c522c80922 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -14,43 +14,21 @@
# limitations under the License.
import logging
-from collections import namedtuple
+from typing import Iterable, List, TypeVar
from six import iteritems, itervalues
-from six.moves import range
import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
-from synapse.api.errors import NotFoundError
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.background_updates import BackgroundUpdateStore
-from synapse.storage.engines import PostgresEngine
-from synapse.storage.events_worker import EventsWorkerStore
-from synapse.util.caches import get_cache_factor_for, intern_string
-from synapse.util.caches.descriptors import cached, cachedList
-from synapse.util.caches.dictionary_cache import DictionaryCache
-from synapse.util.stringutils import to_ascii
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
-
-MAX_STATE_DELTA_HOPS = 100
-
-
-class _GetStateGroupDelta(
- namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
-):
- """Return type of get_state_group_delta that implements __len__, which lets
- us use the itrable flag when caching
- """
-
- __slots__ = []
-
- def __len__(self):
- return len(self.delta_ids) if self.delta_ids else 0
+# Used for generic functions below
+T = TypeVar("T")
@attr.s(slots=True)
@@ -260,14 +238,14 @@ class StateFilter(object):
return len(self.concrete_types())
- def filter_state(self, state_dict):
+ def filter_state(self, state_dict: StateMap[T]) -> StateMap[T]:
"""Returns the state filtered with by this StateFilter
Args:
- state (dict[tuple[str, str], Any]): The state map to filter
+ state: The state map to filter
Returns:
- dict[tuple[str, str], Any]: The filtered state map
+ The filtered state map
"""
if self.is_full():
return dict(state_dict)
@@ -353,248 +331,23 @@ class StateFilter(object):
return member_filter, non_member_filter
-# this inherits from EventsWorkerStore because it calls self.get_events
-class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
- """The parts of StateGroupStore that can be called from workers.
+class StateGroupStorage(object):
+ """High level interface to fetching state for event.
"""
- STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
- STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
- CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
- def __init__(self, db_conn, hs):
- super(StateGroupWorkerStore, self).__init__(db_conn, hs)
-
- # Originally the state store used a single DictionaryCache to cache the
- # event IDs for the state types in a given state group to avoid hammering
- # on the state_group* tables.
- #
- # The point of using a DictionaryCache is that it can cache a subset
- # of the state events for a given state group (i.e. a subset of the keys for a
- # given dict which is an entry in the cache for a given state group ID).
- #
- # However, this poses problems when performing complicated queries
- # on the store - for instance: "give me all the state for this group, but
- # limit members to this subset of users", as DictionaryCache's API isn't
- # rich enough to say "please cache any of these fields, apart from this subset".
- # This is problematic when lazy loading members, which requires this behaviour,
- # as without it the cache has no choice but to speculatively load all
- # state events for the group, which negates the efficiency being sought.
- #
- # Rather than overcomplicating DictionaryCache's API, we instead split the
- # state_group_cache into two halves - one for tracking non-member events,
- # and the other for tracking member_events. This means that lazy loading
- # queries can be made in a cache-friendly manner by querying both caches
- # separately and then merging the result. So for the example above, you
- # would query the members cache for a specific subset of state keys
- # (which DictionaryCache will handle efficiently and fine) and the non-members
- # cache for all state (which DictionaryCache will similarly handle fine)
- # and then just merge the results together.
- #
- # We size the non-members cache to be smaller than the members cache as the
- # vast majority of state in Matrix (today) is member events.
-
- self._state_group_cache = DictionaryCache(
- "*stateGroupCache*",
- # TODO: this hasn't been tuned yet
- 50000 * get_cache_factor_for("stateGroupCache"),
- )
- self._state_group_members_cache = DictionaryCache(
- "*stateGroupMembersCache*",
- 500000 * get_cache_factor_for("stateGroupMembersCache"),
- )
-
- @defer.inlineCallbacks
- def get_room_version(self, room_id):
- """Get the room_version of a given room
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[str]
-
- Raises:
- NotFoundError if the room is unknown
- """
- # for now we do this by looking at the create event. We may want to cache this
- # more intelligently in future.
-
- # Retrieve the room's create event
- create_event = yield self.get_create_event_for_room(room_id)
- defer.returnValue(create_event.content.get("room_version", "1"))
-
- @defer.inlineCallbacks
- def get_room_predecessor(self, room_id):
- """Get the predecessor room of an upgraded room if one exists.
- Otherwise return None.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[unicode|None]: predecessor room id
-
- Raises:
- NotFoundError if the room is unknown
- """
- # Retrieve the room's create event
- create_event = yield self.get_create_event_for_room(room_id)
-
- # Return predecessor if present
- defer.returnValue(create_event.content.get("predecessor", None))
-
- @defer.inlineCallbacks
- def get_create_event_for_room(self, room_id):
- """Get the create state event for a room.
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[EventBase]: The room creation event.
-
- Raises:
- NotFoundError if the room is unknown
- """
- state_ids = yield self.get_current_state_ids(room_id)
- create_id = state_ids.get((EventTypes.Create, ""))
+ def __init__(self, hs, stores):
+ self.stores = stores
- # If we can't find the create event, assume we've hit a dead end
- if not create_id:
- raise NotFoundError("Unknown room %s" % (room_id))
-
- # Retrieve the room's create event and return
- create_event = yield self.get_event(create_id)
- defer.returnValue(create_event)
-
- @cached(max_entries=100000, iterable=True)
- def get_current_state_ids(self, room_id):
- """Get the current state event ids for a room based on the
- current_state_events table.
-
- Args:
- room_id (str)
-
- Returns:
- deferred: dict of (type, state_key) -> event_id
- """
-
- def _get_current_state_ids_txn(txn):
- txn.execute(
- """SELECT type, state_key, event_id FROM current_state_events
- WHERE room_id = ?
- """,
- (room_id,),
- )
-
- return {
- (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
- }
-
- return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
-
- # FIXME: how should this be cached?
- def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
- """Get the current state event of a given type for a room based on the
- current_state_events table. This may not be as up-to-date as the result
- of doing a fresh state resolution as per state_handler.get_current_state
-
- Args:
- room_id (str)
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
- event ID.
- """
-
- def _get_filtered_current_state_ids_txn(txn):
- results = {}
- sql = """
- SELECT type, state_key, event_id FROM current_state_events
- WHERE room_id = ?
- """
-
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
- if where_clause:
- sql += " AND (%s)" % (where_clause,)
-
- args = [room_id]
- args.extend(where_args)
- txn.execute(sql, args)
- for row in txn:
- typ, state_key, event_id = row
- key = (intern_string(typ), intern_string(state_key))
- results[key] = event_id
-
- return results
-
- return self.runInteraction(
- "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
- )
-
- @defer.inlineCallbacks
- def get_canonical_alias_for_room(self, room_id):
- """Get canonical alias for room, if any
-
- Args:
- room_id (str)
-
- Returns:
- Deferred[str|None]: The canonical alias, if any
- """
-
- state = yield self.get_filtered_current_state_ids(
- room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
- )
-
- event_id = state.get((EventTypes.CanonicalAlias, ""))
- if not event_id:
- return
-
- event = yield self.get_event(event_id, allow_none=True)
- if not event:
- return
-
- defer.returnValue(event.content.get("canonical_alias"))
-
- @cached(max_entries=10000, iterable=True)
- def get_state_group_delta(self, state_group):
+ def get_state_group_delta(self, state_group: int):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Returns:
- (prev_group, delta_ids), where both may be None.
+ Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
+ (prev_group, delta_ids)
"""
- def _get_state_group_delta_txn(txn):
- prev_group = self._simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": state_group},
- retcol="prev_state_group",
- allow_none=True,
- )
-
- if not prev_group:
- return _GetStateGroupDelta(None, None)
-
- delta_ids = self._simple_select_list_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- retcols=("type", "state_key", "event_id"),
- )
-
- return _GetStateGroupDelta(
- prev_group,
- {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
- )
-
- return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
+ return self.stores.state.get_state_group_delta(state_group)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
@@ -605,18 +358,18 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
event_ids (iterable[str]): ids of the events
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
+ Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
- defer.returnValue({})
+ return {}
- event_to_groups = yield self._get_state_group_for_events(event_ids)
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups)
+ group_to_state = yield self.stores.state._get_state_for_groups(groups)
- defer.returnValue(group_to_state)
+ return group_to_state
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
@@ -630,22 +383,21 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
group_to_state = yield self._get_state_for_groups((state_group,))
- defer.returnValue(group_to_state[state_group])
+ return group_to_state[state_group]
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
""" Get the state groups for the given list of event_ids
-
Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events.
"""
if not event_ids:
- defer.returnValue({})
+ return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
- state_event_map = yield self.get_events(
+ state_event_map = yield self.stores.main.get_events(
[
ev_id
for group_ids in itervalues(group_to_ids)
@@ -654,164 +406,50 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
get_prev_content=False,
)
- defer.returnValue(
- {
- group: [
- state_event_map[v]
- for v in itervalues(event_id_map)
- if v in state_event_map
- ]
- for group, event_id_map in iteritems(group_to_ids)
- }
- )
+ return {
+ group: [
+ state_event_map[v]
+ for v in itervalues(event_id_map)
+ if v in state_event_map
+ ]
+ for group, event_id_map in iteritems(group_to_ids)
+ }
- @defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, state_filter):
+ def _get_state_groups_from_groups(
+ self, groups: List[int], state_filter: StateFilter
+ ):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
- groups(list[int]): list of state group IDs to query
- state_filter (StateFilter): The state filter used to fetch state
+ groups: list of state group IDs to query
+ state_filter: The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
- results = {}
-
- chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
- for chunk in chunks:
- res = yield self.runInteraction(
- "_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn,
- chunk,
- state_filter,
- )
- results.update(res)
- defer.returnValue(results)
-
- def _get_state_groups_from_groups_txn(
- self, txn, groups, state_filter=StateFilter.all()
- ):
- results = {group: {} for group in groups}
-
- where_clause, where_args = state_filter.make_sql_filter_clause()
-
- # Unless the filter clause is empty, we're going to append it after an
- # existing where clause
- if where_clause:
- where_clause = " AND (%s)" % (where_clause,)
-
- if isinstance(self.database_engine, PostgresEngine):
- # Temporarily disable sequential scans in this transaction. This is
- # a temporary hack until we can add the right indices in
- txn.execute("SET LOCAL enable_seqscan=off")
-
- # The below query walks the state_group tree so that the "state"
- # table includes all state_groups in the tree. It then joins
- # against `state_groups_state` to fetch the latest state.
- # It assumes that previous state groups are always numerically
- # lesser.
- # The PARTITION is used to get the event_id in the greatest state
- # group for the given type, state_key.
- # This may return multiple rows per (type, state_key), but last_value
- # should be the same.
- sql = """
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT DISTINCT type, state_key, last_value(event_id) OVER (
- PARTITION BY type, state_key ORDER BY state_group ASC
- ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
- ) AS event_id FROM state_groups_state
- WHERE state_group IN (
- SELECT state_group FROM state
- )
- """
-
- for group in groups:
- args = [group]
- args.extend(where_args)
-
- txn.execute(sql + where_clause, args)
- for row in txn:
- typ, state_key, event_id = row
- key = (typ, state_key)
- results[group][key] = event_id
- else:
- max_entries_returned = state_filter.max_entries_returned()
-
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- for group in groups:
- next_group = group
-
- while next_group:
- # We did this before by getting the list of group ids, and
- # then passing that list to sqlite to get latest event for
- # each (type, state_key). However, that was terribly slow
- # without the right indices (which we can't add until
- # after we finish deduping state, which requires this func)
- args = [next_group]
- args.extend(where_args)
-
- txn.execute(
- "SELECT type, state_key, event_id FROM state_groups_state"
- " WHERE state_group = ? " + where_clause,
- args,
- )
- results[group].update(
- ((typ, state_key), event_id)
- for typ, state_key, event_id in txn
- if (typ, state_key) not in results[group]
- )
-
- # If the number of entries in the (type,state_key)->event_id dict
- # matches the number of (type,state_keys) types we were searching
- # for, then we must have found them all, so no need to go walk
- # further down the tree... UNLESS our types filter contained
- # wildcards (i.e. Nones) in which case we have to do an exhaustive
- # search
- if (
- max_entries_returned is not None
- and len(results[group]) == max_entries_returned
- ):
- break
-
- next_group = self._simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
-
- return results
+ return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
-
Args:
event_ids (list[string])
state_filter (StateFilter): The state filter used to fetch state
from the database.
-
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
"""
- event_to_groups = yield self._get_state_group_for_events(event_ids)
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, state_filter)
+ group_to_state = yield self.stores.state._get_state_for_groups(
+ groups, state_filter
+ )
- state_event_map = yield self.get_events(
+ state_event_map = yield self.stores.main.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
get_prev_content=False,
)
@@ -825,7 +463,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for event_id, group in iteritems(event_to_groups)
}
- defer.returnValue({event: event_to_state[event] for event in event_ids})
+ return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
@@ -841,17 +479,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
"""
- event_to_groups = yield self._get_state_group_for_events(event_ids)
+ event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, state_filter)
+ group_to_state = yield self.stores.state._get_state_for_groups(
+ groups, state_filter
+ )
event_to_state = {
event_id: group_to_state[group]
for event_id, group in iteritems(event_to_groups)
}
- defer.returnValue({event: event_to_state[event] for event in event_ids})
+ return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -867,7 +507,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], state_filter)
- defer.returnValue(state_map[event_id])
+ return state_map[event_id]
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
@@ -883,79 +523,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
- defer.returnValue(state_map[event_id])
-
- @cached(max_entries=50000)
- def _get_state_group_for_event(self, event_id):
- return self._simple_select_one_onecol(
- table="event_to_state_groups",
- keyvalues={"event_id": event_id},
- retcol="state_group",
- allow_none=True,
- desc="_get_state_group_for_event",
- )
-
- @cachedList(
- cached_method_name="_get_state_group_for_event",
- list_name="event_ids",
- num_args=1,
- inlineCallbacks=True,
- )
- def _get_state_group_for_events(self, event_ids):
- """Returns mapping event_id -> state_group
- """
- rows = yield self._simple_select_many_batch(
- table="event_to_state_groups",
- column="event_id",
- iterable=event_ids,
- keyvalues={},
- retcols=("event_id", "state_group"),
- desc="_get_state_group_for_events",
- )
-
- defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
-
- def _get_state_for_group_using_cache(self, cache, group, state_filter):
- """Checks if group is in cache. See `_get_state_for_groups`
-
- Args:
- cache(DictionaryCache): the state group cache to use
- group(int): The state group to lookup
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
+ return state_map[event_id]
- Returns 2-tuple (`state_dict`, `got_all`).
- `got_all` is a bool indicating if we successfully retrieved all
- requests state from the cache, if False we need to query the DB for the
- missing state.
- """
- is_all, known_absent, state_dict_ids = cache.get(group)
-
- if is_all or state_filter.is_full():
- # Either we have everything or want everything, either way
- # `is_all` tells us whether we've gotten everything.
- return state_filter.filter_state(state_dict_ids), is_all
-
- # tracks whether any of our requested types are missing from the cache
- missing_types = False
-
- if state_filter.has_wildcards():
- # We don't know if we fetched all the state keys for the types in
- # the filter that are wildcards, so we have to assume that we may
- # have missed some.
- missing_types = True
- else:
- # There aren't any wild cards, so `concrete_types()` returns the
- # complete list of event types we're wanting.
- for key in state_filter.concrete_types():
- if key not in state_dict_ids and key not in known_absent:
- missing_types = True
- break
-
- return state_filter.filter_state(state_dict_ids), not missing_types
-
- @defer.inlineCallbacks
- def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
+ def _get_state_for_groups(
+ self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
+ ):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
@@ -965,157 +537,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
"""
-
- member_filter, non_member_filter = state_filter.get_member_split()
-
- # Now we look them up in the member and non-member caches
- non_member_state, incomplete_groups_nm, = (
- yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache, state_filter=non_member_filter
- )
- )
-
- member_state, incomplete_groups_m, = (
- yield self._get_state_for_groups_using_cache(
- groups, self._state_group_members_cache, state_filter=member_filter
- )
- )
-
- state = dict(non_member_state)
- for group in groups:
- state[group].update(member_state[group])
-
- # Now fetch any missing groups from the database
-
- incomplete_groups = incomplete_groups_m | incomplete_groups_nm
-
- if not incomplete_groups:
- defer.returnValue(state)
-
- cache_sequence_nm = self._state_group_cache.sequence
- cache_sequence_m = self._state_group_members_cache.sequence
-
- # Help the cache hit ratio by expanding the filter a bit
- db_state_filter = state_filter.return_expanded()
-
- group_to_state_dict = yield self._get_state_groups_from_groups(
- list(incomplete_groups), state_filter=db_state_filter
- )
-
- # Now lets update the caches
- self._insert_into_cache(
- group_to_state_dict,
- db_state_filter,
- cache_seq_num_members=cache_sequence_m,
- cache_seq_num_non_members=cache_sequence_nm,
- )
-
- # And finally update the result dict, by filtering out any extra
- # stuff we pulled out of the database.
- for group, group_state_dict in iteritems(group_to_state_dict):
- # We just replace any existing entries, as we will have loaded
- # everything we need from the database anyway.
- state[group] = state_filter.filter_state(group_state_dict)
-
- defer.returnValue(state)
-
- def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
- """Gets the state at each of a list of state groups, optionally
- filtering by type/state_key, querying from a specific cache.
-
- Args:
- groups (iterable[int]): list of state groups for which we want
- to get the state.
- cache (DictionaryCache): the cache of group ids to state dicts which
- we will pass through - either the normal state cache or the specific
- members state cache.
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
-
- Returns:
- tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
- dict of state_group_id -> (dict of (type, state_key) -> event id)
- of entries in the cache, and the state group ids either missing
- from the cache or incomplete.
- """
- results = {}
- incomplete_groups = set()
- for group in set(groups):
- state_dict_ids, got_all = self._get_state_for_group_using_cache(
- cache, group, state_filter
- )
- results[group] = state_dict_ids
-
- if not got_all:
- incomplete_groups.add(group)
-
- return results, incomplete_groups
-
- def _insert_into_cache(
- self,
- group_to_state_dict,
- state_filter,
- cache_seq_num_members,
- cache_seq_num_non_members,
- ):
- """Inserts results from querying the database into the relevant cache.
-
- Args:
- group_to_state_dict (dict): The new entries pulled from database.
- Map from state group to state dict
- state_filter (StateFilter): The state filter used to fetch state
- from the database.
- cache_seq_num_members (int): Sequence number of member cache since
- last lookup in cache
- cache_seq_num_non_members (int): Sequence number of member cache since
- last lookup in cache
- """
-
- # We need to work out which types we've fetched from the DB for the
- # member vs non-member caches. This should be as accurate as possible,
- # but can be an underestimate (e.g. when we have wild cards)
-
- member_filter, non_member_filter = state_filter.get_member_split()
- if member_filter.is_full():
- # We fetched all member events
- member_types = None
- else:
- # `concrete_types()` will only return a subset when there are wild
- # cards in the filter, but that's fine.
- member_types = member_filter.concrete_types()
-
- if non_member_filter.is_full():
- # We fetched all non member events
- non_member_types = None
- else:
- non_member_types = non_member_filter.concrete_types()
-
- for group, group_state_dict in iteritems(group_to_state_dict):
- state_dict_members = {}
- state_dict_non_members = {}
-
- for k, v in iteritems(group_state_dict):
- if k[0] == EventTypes.Member:
- state_dict_members[k] = v
- else:
- state_dict_non_members[k] = v
-
- self._state_group_members_cache.update(
- cache_seq_num_members,
- key=group,
- value=state_dict_members,
- fetched_keys=member_types,
- )
-
- self._state_group_cache.update(
- cache_seq_num_non_members,
- key=group,
- value=state_dict_non_members,
- fetched_keys=non_member_types,
- )
+ return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
@@ -1135,393 +559,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
Deferred[int]: The state group ID
"""
-
- def _store_state_group_txn(txn):
- if current_state_ids is None:
- # AFAIK, this can never happen
- raise Exception("current_state_ids cannot be None")
-
- state_group = self.database_engine.get_next_state_group_id(txn)
-
- self._simple_insert_txn(
- txn,
- table="state_groups",
- values={"id": state_group, "room_id": room_id, "event_id": event_id},
- )
-
- # We persist as a delta if we can, while also ensuring the chain
- # of deltas isn't tooo long, as otherwise read performance degrades.
- if prev_group:
- is_in_db = self._simple_select_one_onecol_txn(
- txn,
- table="state_groups",
- keyvalues={"id": prev_group},
- retcol="id",
- allow_none=True,
- )
- if not is_in_db:
- raise Exception(
- "Trying to persist state with unpersisted prev_group: %r"
- % (prev_group,)
- )
-
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self._simple_insert_txn(
- txn,
- table="state_group_edges",
- values={"state_group": state_group, "prev_state_group": prev_group},
- )
-
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(delta_ids)
- ],
- )
- else:
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(current_state_ids)
- ],
- )
-
- # Prefill the state group caches with this group.
- # It's fine to use the sequence like this as the state group map
- # is immutable. (If the map wasn't immutable then this prefill could
- # race with another update)
-
- current_member_state_ids = {
- s: ev
- for (s, ev) in iteritems(current_state_ids)
- if s[0] == EventTypes.Member
- }
- txn.call_after(
- self._state_group_members_cache.update,
- self._state_group_members_cache.sequence,
- key=state_group,
- value=dict(current_member_state_ids),
- )
-
- current_non_member_state_ids = {
- s: ev
- for (s, ev) in iteritems(current_state_ids)
- if s[0] != EventTypes.Member
- }
- txn.call_after(
- self._state_group_cache.update,
- self._state_group_cache.sequence,
- key=state_group,
- value=dict(current_non_member_state_ids),
- )
-
- return state_group
-
- return self.runInteraction("store_state_group", _store_state_group_txn)
-
- def _count_state_group_hops_txn(self, txn, state_group):
- """Given a state group, count how many hops there are in the tree.
-
- This is used to ensure the delta chains don't get too long.
- """
- if isinstance(self.database_engine, PostgresEngine):
- sql = """
- WITH RECURSIVE state(state_group) AS (
- VALUES(?::bigint)
- UNION ALL
- SELECT prev_state_group FROM state_group_edges e, state s
- WHERE s.state_group = e.state_group
- )
- SELECT count(*) FROM state;
- """
-
- txn.execute(sql, (state_group,))
- row = txn.fetchone()
- if row and row[0]:
- return row[0]
- else:
- return 0
- else:
- # We don't use WITH RECURSIVE on sqlite3 as there are distributions
- # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
- next_group = state_group
- count = 0
-
- while next_group:
- next_group = self._simple_select_one_onecol_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": next_group},
- retcol="prev_state_group",
- allow_none=True,
- )
- if next_group:
- count += 1
-
- return count
-
-
-class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
- """ Keeps track of the state at a given event.
-
- This is done by the concept of `state groups`. Every event is a assigned
- a state group (identified by an arbitrary string), which references a
- collection of state events. The current state of an event is then the
- collection of state events referenced by the event's state group.
-
- Hence, every change in the current state causes a new state group to be
- generated. However, if no change happens (e.g., if we get a message event
- with only one parent it inherits the state group from its parent.)
-
- There are three tables:
- * `state_groups`: Stores group name, first event with in the group and
- room id.
- * `event_to_state_groups`: Maps events to state groups.
- * `state_groups_state`: Maps state group to state events.
- """
-
- STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
- STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
- CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
- EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
-
- def __init__(self, db_conn, hs):
- super(StateStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
- self._background_deduplicate_state,
+ return self.stores.state.store_state_group(
+ event_id, room_id, prev_group, delta_ids, current_state_ids
)
- self.register_background_update_handler(
- self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
- )
- self.register_background_index_update(
- self.CURRENT_STATE_INDEX_UPDATE_NAME,
- index_name="current_state_events_member_index",
- table="current_state_events",
- columns=["state_key"],
- where_clause="type='m.room.member'",
- )
- self.register_background_index_update(
- self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
- index_name="event_to_state_groups_sg_index",
- table="event_to_state_groups",
- columns=["state_group"],
- )
-
- def _store_event_state_mappings_txn(self, txn, events_and_contexts):
- state_groups = {}
- for event, context in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
-
- # if the event was rejected, just give it the same state as its
- # predecessor.
- if context.rejected:
- state_groups[event.event_id] = context.prev_group
- continue
-
- state_groups[event.event_id] = context.state_group
-
- self._simple_insert_many_txn(
- txn,
- table="event_to_state_groups",
- values=[
- {"state_group": state_group_id, "event_id": event_id}
- for event_id, state_group_id in iteritems(state_groups)
- ],
- )
-
- for event_id, state_group_id in iteritems(state_groups):
- txn.call_after(
- self._get_state_group_for_event.prefill, (event_id,), state_group_id
- )
-
- @defer.inlineCallbacks
- def _background_deduplicate_state(self, progress, batch_size):
- """This background update will slowly deduplicate state by reencoding
- them as deltas.
- """
- last_state_group = progress.get("last_state_group", 0)
- rows_inserted = progress.get("rows_inserted", 0)
- max_group = progress.get("max_group", None)
-
- BATCH_SIZE_SCALE_FACTOR = 100
-
- batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
-
- if max_group is None:
- rows = yield self._execute(
- "_background_deduplicate_state",
- None,
- "SELECT coalesce(max(id), 0) FROM state_groups",
- )
- max_group = rows[0][0]
-
- def reindex_txn(txn):
- new_last_state_group = last_state_group
- for count in range(batch_size):
- txn.execute(
- "SELECT id, room_id FROM state_groups"
- " WHERE ? < id AND id <= ?"
- " ORDER BY id ASC"
- " LIMIT 1",
- (new_last_state_group, max_group),
- )
- row = txn.fetchone()
- if row:
- state_group, room_id = row
-
- if not row or not state_group:
- return True, count
-
- txn.execute(
- "SELECT state_group FROM state_group_edges"
- " WHERE state_group = ?",
- (state_group,),
- )
-
- # If we reach a point where we've already started inserting
- # edges we should stop.
- if txn.fetchall():
- return True, count
-
- txn.execute(
- "SELECT coalesce(max(id), 0) FROM state_groups"
- " WHERE id < ? AND room_id = ?",
- (state_group, room_id),
- )
- prev_group, = txn.fetchone()
- new_last_state_group = state_group
-
- if prev_group:
- potential_hops = self._count_state_group_hops_txn(txn, prev_group)
- if potential_hops >= MAX_STATE_DELTA_HOPS:
- # We want to ensure chains are at most this long,#
- # otherwise read performance degrades.
- continue
-
- prev_state = self._get_state_groups_from_groups_txn(
- txn, [prev_group]
- )
- prev_state = prev_state[prev_group]
-
- curr_state = self._get_state_groups_from_groups_txn(
- txn, [state_group]
- )
- curr_state = curr_state[state_group]
-
- if not set(prev_state.keys()) - set(curr_state.keys()):
- # We can only do a delta if the current has a strict super set
- # of keys
-
- delta_state = {
- key: value
- for key, value in iteritems(curr_state)
- if prev_state.get(key, None) != value
- }
-
- self._simple_delete_txn(
- txn,
- table="state_group_edges",
- keyvalues={"state_group": state_group},
- )
-
- self._simple_insert_txn(
- txn,
- table="state_group_edges",
- values={
- "state_group": state_group,
- "prev_state_group": prev_group,
- },
- )
-
- self._simple_delete_txn(
- txn,
- table="state_groups_state",
- keyvalues={"state_group": state_group},
- )
-
- self._simple_insert_many_txn(
- txn,
- table="state_groups_state",
- values=[
- {
- "state_group": state_group,
- "room_id": room_id,
- "type": key[0],
- "state_key": key[1],
- "event_id": state_id,
- }
- for key, state_id in iteritems(delta_state)
- ],
- )
-
- progress = {
- "last_state_group": state_group,
- "rows_inserted": rows_inserted + batch_size,
- "max_group": max_group,
- }
-
- self._background_update_progress_txn(
- txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
- )
-
- return False, batch_size
-
- finished, result = yield self.runInteraction(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
- )
-
- if finished:
- yield self._end_background_update(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
- )
-
- defer.returnValue(result * BATCH_SIZE_SCALE_FACTOR)
-
- @defer.inlineCallbacks
- def _background_index_state(self, progress, batch_size):
- def reindex_txn(conn):
- conn.rollback()
- if isinstance(self.database_engine, PostgresEngine):
- # postgres insists on autocommit for the index
- conn.set_session(autocommit=True)
- try:
- txn = conn.cursor()
- txn.execute(
- "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
- " ON state_groups_state(state_group, type, state_key)"
- )
- txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
- finally:
- conn.set_session(autocommit=False)
- else:
- txn = conn.cursor()
- txn.execute(
- "CREATE INDEX state_groups_state_type_idx"
- " ON state_groups_state(state_group, type, state_key)"
- )
- txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
-
- yield self.runWithConnection(reindex_txn)
-
- yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
-
- defer.returnValue(1)
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
deleted file mode 100644
index ff266b09b0..0000000000
--- a/synapse/storage/stats.py
+++ /dev/null
@@ -1,468 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2018, 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-
-from twisted.internet import defer
-
-from synapse.api.constants import EventTypes, Membership
-from synapse.storage.prepare_database import get_statements
-from synapse.storage.state_deltas import StateDeltasStore
-from synapse.util.caches.descriptors import cached
-
-logger = logging.getLogger(__name__)
-
-# these fields track absolutes (e.g. total number of rooms on the server)
-ABSOLUTE_STATS_FIELDS = {
- "room": (
- "current_state_events",
- "joined_members",
- "invited_members",
- "left_members",
- "banned_members",
- "state_events",
- ),
- "user": ("public_rooms", "private_rooms"),
-}
-
-TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
-
-TEMP_TABLE = "_temp_populate_stats"
-
-
-class StatsStore(StateDeltasStore):
- def __init__(self, db_conn, hs):
- super(StatsStore, self).__init__(db_conn, hs)
-
- self.server_name = hs.hostname
- self.clock = self.hs.get_clock()
- self.stats_enabled = hs.config.stats_enabled
- self.stats_bucket_size = hs.config.stats_bucket_size
-
- self.register_background_update_handler(
- "populate_stats_createtables", self._populate_stats_createtables
- )
- self.register_background_update_handler(
- "populate_stats_process_rooms", self._populate_stats_process_rooms
- )
- self.register_background_update_handler(
- "populate_stats_cleanup", self._populate_stats_cleanup
- )
-
- @defer.inlineCallbacks
- def _populate_stats_createtables(self, progress, batch_size):
-
- if not self.stats_enabled:
- yield self._end_background_update("populate_stats_createtables")
- defer.returnValue(1)
-
- # Get all the rooms that we want to process.
- def _make_staging_area(txn):
- # Create the temporary tables
- stmts = get_statements("""
- -- We just recreate the table, we'll be reinserting the
- -- correct entries again later anyway.
- DROP TABLE IF EXISTS {temp}_rooms;
-
- CREATE TABLE IF NOT EXISTS {temp}_rooms(
- room_id TEXT NOT NULL,
- events BIGINT NOT NULL
- );
-
- CREATE INDEX {temp}_rooms_events
- ON {temp}_rooms(events);
- CREATE INDEX {temp}_rooms_id
- ON {temp}_rooms(room_id);
- """.format(temp=TEMP_TABLE).splitlines())
-
- for statement in stmts:
- txn.execute(statement)
-
- sql = (
- "CREATE TABLE IF NOT EXISTS "
- + TEMP_TABLE
- + "_position(position TEXT NOT NULL)"
- )
- txn.execute(sql)
-
- # Get rooms we want to process from the database, only adding
- # those that we haven't (i.e. those not in room_stats_earliest_token)
- sql = """
- INSERT INTO %s_rooms (room_id, events)
- SELECT c.room_id, count(*) FROM current_state_events AS c
- LEFT JOIN room_stats_earliest_token AS t USING (room_id)
- WHERE t.room_id IS NULL
- GROUP BY c.room_id
- """ % (TEMP_TABLE,)
- txn.execute(sql)
-
- new_pos = yield self.get_max_stream_id_in_current_state_deltas()
- yield self.runInteraction("populate_stats_temp_build", _make_staging_area)
- yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
- self.get_earliest_token_for_room_stats.invalidate_all()
-
- yield self._end_background_update("populate_stats_createtables")
- defer.returnValue(1)
-
- @defer.inlineCallbacks
- def _populate_stats_cleanup(self, progress, batch_size):
- """
- Update the user directory stream position, then clean up the old tables.
- """
- if not self.stats_enabled:
- yield self._end_background_update("populate_stats_cleanup")
- defer.returnValue(1)
-
- position = yield self._simple_select_one_onecol(
- TEMP_TABLE + "_position", None, "position"
- )
- yield self.update_stats_stream_pos(position)
-
- def _delete_staging_area(txn):
- txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
- txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
-
- yield self.runInteraction("populate_stats_cleanup", _delete_staging_area)
-
- yield self._end_background_update("populate_stats_cleanup")
- defer.returnValue(1)
-
- @defer.inlineCallbacks
- def _populate_stats_process_rooms(self, progress, batch_size):
-
- if not self.stats_enabled:
- yield self._end_background_update("populate_stats_process_rooms")
- defer.returnValue(1)
-
- # If we don't have progress filed, delete everything.
- if not progress:
- yield self.delete_all_stats()
-
- def _get_next_batch(txn):
- # Only fetch 250 rooms, so we don't fetch too many at once, even
- # if those 250 rooms have less than batch_size state events.
- sql = """
- SELECT room_id, events FROM %s_rooms
- ORDER BY events DESC
- LIMIT 250
- """ % (
- TEMP_TABLE,
- )
- txn.execute(sql)
- rooms_to_work_on = txn.fetchall()
-
- if not rooms_to_work_on:
- return None
-
- # Get how many are left to process, so we can give status on how
- # far we are in processing
- txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
- progress["remaining"] = txn.fetchone()[0]
-
- return rooms_to_work_on
-
- rooms_to_work_on = yield self.runInteraction(
- "populate_stats_temp_read", _get_next_batch
- )
-
- # No more rooms -- complete the transaction.
- if not rooms_to_work_on:
- yield self._end_background_update("populate_stats_process_rooms")
- defer.returnValue(1)
-
- logger.info(
- "Processing the next %d rooms of %d remaining",
- len(rooms_to_work_on), progress["remaining"],
- )
-
- # Number of state events we've processed by going through each room
- processed_event_count = 0
-
- for room_id, event_count in rooms_to_work_on:
-
- current_state_ids = yield self.get_current_state_ids(room_id)
-
- join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
- history_visibility_id = current_state_ids.get(
- (EventTypes.RoomHistoryVisibility, "")
- )
- encryption_id = current_state_ids.get((EventTypes.RoomEncryption, ""))
- name_id = current_state_ids.get((EventTypes.Name, ""))
- topic_id = current_state_ids.get((EventTypes.Topic, ""))
- avatar_id = current_state_ids.get((EventTypes.RoomAvatar, ""))
- canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, ""))
-
- state_events = yield self.get_events([
- join_rules_id, history_visibility_id, encryption_id, name_id,
- topic_id, avatar_id, canonical_alias_id,
- ])
-
- def _get_or_none(event_id, arg):
- event = state_events.get(event_id)
- if event:
- return event.content.get(arg)
- return None
-
- yield self.update_room_state(
- room_id,
- {
- "join_rules": _get_or_none(join_rules_id, "join_rule"),
- "history_visibility": _get_or_none(
- history_visibility_id, "history_visibility"
- ),
- "encryption": _get_or_none(encryption_id, "algorithm"),
- "name": _get_or_none(name_id, "name"),
- "topic": _get_or_none(topic_id, "topic"),
- "avatar": _get_or_none(avatar_id, "url"),
- "canonical_alias": _get_or_none(canonical_alias_id, "alias"),
- },
- )
-
- now = self.hs.get_reactor().seconds()
-
- # quantise time to the nearest bucket
- now = (now // self.stats_bucket_size) * self.stats_bucket_size
-
- def _fetch_data(txn):
-
- # Get the current token of the room
- current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
-
- current_state_events = len(current_state_ids)
-
- membership_counts = self._get_user_counts_in_room_txn(txn, room_id)
-
- total_state_events = self._get_total_state_event_counts_txn(
- txn, room_id
- )
-
- self._update_stats_txn(
- txn,
- "room",
- room_id,
- now,
- {
- "bucket_size": self.stats_bucket_size,
- "current_state_events": current_state_events,
- "joined_members": membership_counts.get(Membership.JOIN, 0),
- "invited_members": membership_counts.get(Membership.INVITE, 0),
- "left_members": membership_counts.get(Membership.LEAVE, 0),
- "banned_members": membership_counts.get(Membership.BAN, 0),
- "state_events": total_state_events,
- },
- )
- self._simple_insert_txn(
- txn,
- "room_stats_earliest_token",
- {"room_id": room_id, "token": current_token},
- )
-
- # We've finished a room. Delete it from the table.
- self._simple_delete_one_txn(
- txn, TEMP_TABLE + "_rooms", {"room_id": room_id},
- )
-
- yield self.runInteraction("update_room_stats", _fetch_data)
-
- # Update the remaining counter.
- progress["remaining"] -= 1
- yield self.runInteraction(
- "populate_stats",
- self._background_update_progress_txn,
- "populate_stats_process_rooms",
- progress,
- )
-
- processed_event_count += event_count
-
- if processed_event_count > batch_size:
- # Don't process any more rooms, we've hit our batch size.
- defer.returnValue(processed_event_count)
-
- defer.returnValue(processed_event_count)
-
- def delete_all_stats(self):
- """
- Delete all statistics records.
- """
-
- def _delete_all_stats_txn(txn):
- txn.execute("DELETE FROM room_state")
- txn.execute("DELETE FROM room_stats")
- txn.execute("DELETE FROM room_stats_earliest_token")
- txn.execute("DELETE FROM user_stats")
-
- return self.runInteraction("delete_all_stats", _delete_all_stats_txn)
-
- def get_stats_stream_pos(self):
- return self._simple_select_one_onecol(
- table="stats_stream_pos",
- keyvalues={},
- retcol="stream_id",
- desc="stats_stream_pos",
- )
-
- def update_stats_stream_pos(self, stream_id):
- return self._simple_update_one(
- table="stats_stream_pos",
- keyvalues={},
- updatevalues={"stream_id": stream_id},
- desc="update_stats_stream_pos",
- )
-
- def update_room_state(self, room_id, fields):
- """
- Args:
- room_id (str)
- fields (dict[str:Any])
- """
-
- # For whatever reason some of the fields may contain null bytes, which
- # postgres isn't a fan of, so we replace those fields with null.
- for col in (
- "join_rules",
- "history_visibility",
- "encryption",
- "name",
- "topic",
- "avatar",
- "canonical_alias"
- ):
- field = fields.get(col)
- if field and "\0" in field:
- fields[col] = None
-
- return self._simple_upsert(
- table="room_state",
- keyvalues={"room_id": room_id},
- values=fields,
- desc="update_room_state",
- )
-
- def get_deltas_for_room(self, room_id, start, size=100):
- """
- Get statistics deltas for a given room.
-
- Args:
- room_id (str)
- start (int): Pagination start. Number of entries, not timestamp.
- size (int): How many entries to return.
-
- Returns:
- Deferred[list[dict]], where the dict has the keys of
- ABSOLUTE_STATS_FIELDS["room"] and "ts".
- """
- return self._simple_select_list_paginate(
- "room_stats",
- {"room_id": room_id},
- "ts",
- start,
- size,
- retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]),
- order_direction="DESC",
- )
-
- def get_all_room_state(self):
- return self._simple_select_list(
- "room_state", None, retcols=("name", "topic", "canonical_alias")
- )
-
- @cached()
- def get_earliest_token_for_room_stats(self, room_id):
- """
- Fetch the "earliest token". This is used by the room stats delta
- processor to ignore deltas that have been processed between the
- start of the background task and any particular room's stats
- being calculated.
-
- Returns:
- Deferred[int]
- """
- return self._simple_select_one_onecol(
- "room_stats_earliest_token",
- {"room_id": room_id},
- retcol="token",
- allow_none=True,
- )
-
- def update_stats(self, stats_type, stats_id, ts, fields):
- table, id_col = TYPE_TO_ROOM[stats_type]
- return self._simple_upsert(
- table=table,
- keyvalues={id_col: stats_id, "ts": ts},
- values=fields,
- desc="update_stats",
- )
-
- def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields):
- table, id_col = TYPE_TO_ROOM[stats_type]
- return self._simple_upsert_txn(
- txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields
- )
-
- def update_stats_delta(self, ts, stats_type, stats_id, field, value):
- def _update_stats_delta(txn):
- table, id_col = TYPE_TO_ROOM[stats_type]
-
- sql = (
- "SELECT * FROM %s"
- " WHERE %s=? and ts=("
- " SELECT MAX(ts) FROM %s"
- " WHERE %s=?"
- ")"
- ) % (table, id_col, table, id_col)
- txn.execute(sql, (stats_id, stats_id))
- rows = self.cursor_to_dict(txn)
- if len(rows) == 0:
- # silently skip as we don't have anything to apply a delta to yet.
- # this tries to minimise any race between the initial sync and
- # subsequent deltas arriving.
- return
-
- current_ts = ts
- latest_ts = rows[0]["ts"]
- if current_ts < latest_ts:
- # This one is in the past, but we're just encountering it now.
- # Mark it as part of the current bucket.
- current_ts = latest_ts
- elif ts != latest_ts:
- # we have to copy our absolute counters over to the new entry.
- values = {
- key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type]
- }
- values[id_col] = stats_id
- values["ts"] = ts
- values["bucket_size"] = self.stats_bucket_size
-
- self._simple_insert_txn(txn, table=table, values=values)
-
- # actually update the new value
- if stats_type in ABSOLUTE_STATS_FIELDS[stats_type]:
- self._simple_update_txn(
- txn,
- table=table,
- keyvalues={id_col: stats_id, "ts": current_ts},
- updatevalues={field: value},
- )
- else:
- sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % (
- table,
- field,
- field,
- id_col,
- )
- txn.execute(sql, (value, stats_id, current_ts))
-
- return self.runInteraction("update_stats_delta", _update_stats_delta)
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
new file mode 100644
index 0000000000..daff81c5ee
--- /dev/null
+++ b/synapse/storage/types.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Iterable, Iterator, List, Tuple
+
+from typing_extensions import Protocol
+
+
+"""
+Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
+"""
+
+
+class Cursor(Protocol):
+ def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+ ...
+
+ def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+ ...
+
+ def fetchall(self) -> List[Tuple]:
+ ...
+
+ def fetchone(self) -> Tuple:
+ ...
+
+ @property
+ def description(self) -> Any:
+ return None
+
+ @property
+ def rowcount(self) -> int:
+ return 0
+
+ def __iter__(self) -> Iterator[Tuple]:
+ ...
+
+ def close(self) -> None:
+ ...
+
+
+class Connection(Protocol):
+ def cursor(self) -> Cursor:
+ ...
+
+ def close(self) -> None:
+ ...
+
+ def commit(self) -> None:
+ ...
+
+ def rollback(self, *args, **kwargs) -> None:
+ ...
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index f1c8d99419..9d851beaa5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -46,7 +46,7 @@ def _load_current_id(db_conn, table, column, step=1):
cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
- val, = cur.fetchone()
+ (val,) = cur.fetchone()
cur.close()
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
@@ -195,6 +195,6 @@ class ChainedIdGenerator(object):
with self._lock:
if self._unfinished_ids:
stream_id, chained_id = self._unfinished_ids[0]
- return (stream_id - 1, chained_id)
+ return stream_id - 1, chained_id
- return (self._current_max, self.chained_generator.get_current_token())
+ return self._current_max, self.chained_generator.get_current_token()
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 451e4fa441..cd56cd91ed 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -30,34 +30,34 @@ class SourcePaginationConfig(object):
"""A configuration object which stores pagination parameters for a
specific event source."""
- def __init__(self, from_key=None, to_key=None, direction='f',
- limit=None):
+ def __init__(self, from_key=None, to_key=None, direction="f", limit=None):
self.from_key = from_key
self.to_key = to_key
- self.direction = 'f' if direction == 'f' else 'b'
+ self.direction = "f" if direction == "f" else "b"
self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
def __repr__(self):
- return (
- "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)"
- ) % (self.from_key, self.to_key, self.direction, self.limit)
+ return "StreamConfig(from_key=%r, to_key=%r, direction=%r, limit=%r)" % (
+ self.from_key,
+ self.to_key,
+ self.direction,
+ self.limit,
+ )
class PaginationConfig(object):
"""A configuration object which stores pagination parameters."""
- def __init__(self, from_token=None, to_token=None, direction='f',
- limit=None):
+ def __init__(self, from_token=None, to_token=None, direction="f", limit=None):
self.from_token = from_token
self.to_token = to_token
- self.direction = 'f' if direction == 'f' else 'b'
+ self.direction = "f" if direction == "f" else "b"
self.limit = min(int(limit), MAX_LIMIT) if limit is not None else None
@classmethod
- def from_request(cls, request, raise_invalid_params=True,
- default_limit=None):
- direction = parse_string(request, "dir", default='f', allowed_values=['f', 'b'])
+ def from_request(cls, request, raise_invalid_params=True, default_limit=None):
+ direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
from_tok = parse_string(request, "from")
to_tok = parse_string(request, "to")
@@ -88,10 +88,12 @@ class PaginationConfig(object):
raise SynapseError(400, "Invalid request.")
def __repr__(self):
- return (
- "PaginationConfig(from_tok=%r, to_tok=%r,"
- " direction=%r, limit=%r)"
- ) % (self.from_token, self.to_token, self.direction, self.limit)
+ return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % (
+ self.from_token,
+ self.to_token,
+ self.direction,
+ self.limit,
+ )
def get_source_config(self, source_name):
keyname = "%s_key" % source_name
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index e5220132a3..fcd2aaa9c9 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Any, Dict
+
from twisted.internet import defer
from synapse.handlers.account_data import AccountDataEventSource
@@ -34,9 +36,8 @@ class EventSources(object):
def __init__(self, hs):
self.sources = {
- name: cls(hs)
- for name, cls in EventSources.SOURCE_TYPES.items()
- }
+ name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
+ } # type: Dict[str, Any]
self.store = hs.get_datastore()
@defer.inlineCallbacks
@@ -47,54 +48,38 @@ class EventSources(object):
groups_key = self.store.get_group_stream_token()
token = StreamToken(
- room_key=(
- yield self.sources["room"].get_current_key()
- ),
- presence_key=(
- yield self.sources["presence"].get_current_key()
- ),
- typing_key=(
- yield self.sources["typing"].get_current_key()
- ),
- receipt_key=(
- yield self.sources["receipt"].get_current_key()
- ),
- account_data_key=(
- yield self.sources["account_data"].get_current_key()
- ),
+ room_key=(yield self.sources["room"].get_current_key()),
+ presence_key=(yield self.sources["presence"].get_current_key()),
+ typing_key=(yield self.sources["typing"].get_current_key()),
+ receipt_key=(yield self.sources["receipt"].get_current_key()),
+ account_data_key=(yield self.sources["account_data"].get_current_key()),
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
groups_key=groups_key,
)
- defer.returnValue(token)
+ return token
@defer.inlineCallbacks
- def get_current_token_for_room(self, room_id):
- push_rules_key, _ = self.store.get_push_rules_stream_token()
- to_device_key = self.store.get_to_device_stream_token()
- device_list_key = self.store.get_device_stream_token()
- groups_key = self.store.get_group_stream_token()
+ def get_current_token_for_pagination(self):
+ """Get the current token for a given room to be used to paginate
+ events.
+ The returned token does not have the current values for fields other
+ than `room`, since they are not used during pagination.
+
+ Retuns:
+ Deferred[StreamToken]
+ """
token = StreamToken(
- room_key=(
- yield self.sources["room"].get_current_key_for_room(room_id)
- ),
- presence_key=(
- yield self.sources["presence"].get_current_key()
- ),
- typing_key=(
- yield self.sources["typing"].get_current_key()
- ),
- receipt_key=(
- yield self.sources["receipt"].get_current_key()
- ),
- account_data_key=(
- yield self.sources["account_data"].get_current_key()
- ),
- push_rules_key=push_rules_key,
- to_device_key=to_device_key,
- device_list_key=device_list_key,
- groups_key=groups_key,
+ room_key=(yield self.sources["room"].get_current_key()),
+ presence_key=0,
+ typing_key=0,
+ receipt_key=0,
+ account_data_key=0,
+ push_rules_key=0,
+ to_device_key=0,
+ device_list_key=0,
+ groups_key=0,
)
- defer.returnValue(token)
+ return token
diff --git a/synapse/third_party_rules/access_rules.py b/synapse/third_party_rules/access_rules.py
index bd79de845f..2c9155d15c 100644
--- a/synapse/third_party_rules/access_rules.py
+++ b/synapse/third_party_rules/access_rules.py
@@ -44,9 +44,7 @@ VALID_ACCESS_RULES = (
# * the default power level for users (users_default) being set to anything other than 0.
# * a non-default power level being assigned to any user which would be forbidden from
# joining a restricted room.
-RULES_WITH_RESTRICTED_POWER_LEVELS = (
- ACCESS_RULE_UNRESTRICTED,
-)
+RULES_WITH_RESTRICTED_POWER_LEVELS = (ACCESS_RULE_UNRESTRICTED,)
class RoomAccessRules(object):
@@ -76,7 +74,7 @@ class RoomAccessRules(object):
self.id_server = config["id_server"]
self.domains_forbidden_when_restricted = config.get(
- "domains_forbidden_when_restricted", [],
+ "domains_forbidden_when_restricted", []
)
@staticmethod
@@ -86,7 +84,7 @@ class RoomAccessRules(object):
else:
raise ConfigError("No IS for event rules TchapEventRules")
- def on_create_room(self, requester, config, is_requester_admin):
+ def on_create_room(self, requester, config, is_requester_admin) -> bool:
"""Implements synapse.events.ThirdPartyEventRules.on_create_room
Checks if a im.vector.room.access_rules event is being set during room creation.
@@ -113,9 +111,8 @@ class RoomAccessRules(object):
raise SynapseError(400, "Invalid access rule")
# Make sure the rule is "direct" if the room is a direct chat.
- if (
- (is_direct and access_rule != ACCESS_RULE_DIRECT)
- or (access_rule == ACCESS_RULE_DIRECT and not is_direct)
+ if (is_direct and access_rule != ACCESS_RULE_DIRECT) or (
+ access_rule == ACCESS_RULE_DIRECT and not is_direct
):
raise SynapseError(400, "Invalid access rule")
@@ -136,13 +133,13 @@ class RoomAccessRules(object):
if not config.get("initial_state"):
config["initial_state"] = []
- config["initial_state"].append({
- "type": ACCESS_RULES_TYPE,
- "state_key": "",
- "content": {
- "rule": default_rule,
+ config["initial_state"].append(
+ {
+ "type": ACCESS_RULES_TYPE,
+ "state_key": "",
+ "content": {"rule": default_rule},
}
- })
+ )
access_rule = default_rule
@@ -150,16 +147,13 @@ class RoomAccessRules(object):
# rule, whether it's a user-defined one or the default one (i.e. if it involves
# a "public" join rule, the access rule must be "restricted").
if (
- (
- join_rule == JoinRules.PUBLIC
- or preset == RoomCreationPreset.PUBLIC_CHAT
- ) and access_rule != ACCESS_RULE_RESTRICTED
- ):
+ join_rule == JoinRules.PUBLIC or preset == RoomCreationPreset.PUBLIC_CHAT
+ ) and access_rule != ACCESS_RULE_RESTRICTED:
raise SynapseError(400, "Invalid access rule")
# Check if the creator can override values for the power levels.
allowed = self._is_power_level_content_allowed(
- config.get("power_level_content_override", {}), access_rule,
+ config.get("power_level_content_override", {}), access_rule
)
if not allowed:
raise SynapseError(400, "Invalid power levels content override")
@@ -173,6 +167,8 @@ class RoomAccessRules(object):
if not allowed:
raise SynapseError(400, "Invalid power levels content")
+ return True
+
@defer.inlineCallbacks
def check_threepid_can_be_invited(self, medium, address, state_events):
"""Implements synapse.events.ThirdPartyEventRules.check_threepid_can_be_invited
@@ -202,10 +198,7 @@ class RoomAccessRules(object):
# Get the HS this address belongs to from the identity server.
res = yield self.http_client.get_json(
"https://%s/_matrix/identity/api/v1/info" % (self.id_server,),
- {
- "medium": medium,
- "address": address,
- }
+ {"medium": medium, "address": address},
)
# Look for a domain that's not forbidden from being invited.
@@ -411,7 +404,7 @@ class RoomAccessRules(object):
# user.
target = event.state_key
is_from_threepid_invite = self._is_invite_from_threepid(
- event, threepid_tokens[0],
+ event, threepid_tokens[0]
)
if is_from_threepid_invite or target == existing_members[0]:
return True
@@ -438,11 +431,11 @@ class RoomAccessRules(object):
return True
# If users_default is explicitly set to a non-0 value, deny the event.
- users_default = content.get('users_default', 0)
+ users_default = content.get("users_default", 0)
if users_default:
return False
- users = content.get('users', {})
+ users = content.get("users", {})
for user_id, power_level in users.items():
server_name = get_domain_from_id(user_id)
# Check the domain against the blacklist. If found, and the PL isn't 0, deny
@@ -477,7 +470,7 @@ class RoomAccessRules(object):
Returns:
bool, True if the event can be allowed, False otherwise.
"""
- if event.content.get('join_rule') == JoinRules.PUBLIC:
+ if event.content.get("join_rule") == JoinRules.PUBLIC:
return rule == ACCESS_RULE_RESTRICTED
return True
@@ -586,8 +579,10 @@ class RoomAccessRules(object):
invite (EventBase): The m.room.member event with "invite" membership.
threepid_invite_token (str): The state key from the 3PID invite.
"""
- token = invite.content.get(
- "third_party_invite", {},
- ).get("signed", {}).get("token", "")
+ token = (
+ invite.content.get("third_party_invite", {})
+ .get("signed", {})
+ .get("token", "")
+ )
return token == threepid_invite_token
diff --git a/synapse/types.py b/synapse/types.py
index e6afc05cee..9f28f9a192 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,18 +15,46 @@
# limitations under the License.
import re
import string
+import sys
from collections import namedtuple
+from typing import Any, Dict, Tuple, TypeVar
from six.moves import filter
import attr
+from signedjson.key import decode_verify_key_bytes
+from unpaddedbase64 import decode_base64
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
+# define a version of typing.Collection that works on python 3.5
+if sys.version_info[:3] >= (3, 6, 0):
+ from typing import Collection
+else:
+ from typing import Sized, Iterable, Container
-class Requester(namedtuple("Requester", [
- "user", "access_token_id", "is_guest", "device_id", "app_service",
-])):
+ T_co = TypeVar("T_co", covariant=True)
+
+ class Collection(Iterable[T_co], Container[T_co], Sized):
+ __slots__ = ()
+
+
+# Define a state map type from type/state_key to T (usually an event ID or
+# event)
+T = TypeVar("T")
+StateMap = Dict[Tuple[str, str], T]
+
+
+# the type of a JSON-serialisable dict. This could be made stronger, but it will
+# do for now.
+JsonDict = Dict[str, Any]
+
+
+class Requester(
+ namedtuple(
+ "Requester", ["user", "access_token_id", "is_guest", "device_id", "app_service"]
+ )
+):
"""
Represents the user making a request
@@ -78,8 +107,9 @@ class Requester(namedtuple("Requester", [
)
-def create_requester(user_id, access_token_id=None, is_guest=False,
- device_id=None, app_service=None):
+def create_requester(
+ user_id, access_token_id=None, is_guest=False, device_id=None, app_service=None
+):
"""
Create a new ``Requester`` object
@@ -103,7 +133,7 @@ def get_domain_from_id(string):
idx = string.find(":")
if idx == -1:
raise SynapseError(400, "Invalid ID: %r" % (string,))
- return string[idx + 1:]
+ return string[idx + 1 :]
def get_localpart_from_id(string):
@@ -113,9 +143,7 @@ def get_localpart_from_id(string):
return string[1:idx]
-class DomainSpecificString(
- namedtuple("DomainSpecificString", ("localpart", "domain"))
-):
+class DomainSpecificString(namedtuple("DomainSpecificString", ("localpart", "domain"))):
"""Common base class among ID/name strings that have a local part and a
domain name, prefixed with a sigil.
@@ -140,19 +168,22 @@ class DomainSpecificString(
return self
@classmethod
- def from_string(cls, s):
+ def from_string(cls, s: str):
"""Parse the string given by 's' into a structure object."""
if len(s) < 1 or s[0:1] != cls.SIGIL:
- raise SynapseError(400, "Expected %s string to start with '%s'" % (
- cls.__name__, cls.SIGIL,
- ))
+ raise SynapseError(
+ 400,
+ "Expected %s string to start with '%s'" % (cls.__name__, cls.SIGIL),
+ Codes.INVALID_PARAM,
+ )
- parts = s[1:].split(':', 1)
+ parts = s[1:].split(":", 1)
if len(parts) != 2:
raise SynapseError(
- 400, "Expected %s of the form '%slocalname:domain'" % (
- cls.__name__, cls.SIGIL,
- )
+ 400,
+ "Expected %s of the form '%slocalname:domain'"
+ % (cls.__name__, cls.SIGIL),
+ Codes.INVALID_PARAM,
)
domain = parts[1]
@@ -178,47 +209,52 @@ class DomainSpecificString(
class UserID(DomainSpecificString):
"""Structure representing a user ID."""
+
SIGIL = "@"
class RoomAlias(DomainSpecificString):
"""Structure representing a room name."""
+
SIGIL = "#"
class RoomID(DomainSpecificString):
"""Structure representing a room id. """
+
SIGIL = "!"
class EventID(DomainSpecificString):
"""Structure representing an event id. """
+
SIGIL = "$"
class GroupID(DomainSpecificString):
"""Structure representing a group ID."""
+
SIGIL = "+"
@classmethod
def from_string(cls, s):
group_id = super(GroupID, cls).from_string(s)
if not group_id.localpart:
- raise SynapseError(
- 400,
- "Group ID cannot be empty",
- )
+ raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
if contains_invalid_mxid_characters(group_id.localpart):
raise SynapseError(
400,
"Group ID can only contain characters a-z, 0-9, or '=_-./'",
+ Codes.INVALID_PARAM,
)
return group_id
-mxid_localpart_allowed_characters = set("_-./=" + string.ascii_lowercase + string.digits)
+mxid_localpart_allowed_characters = set(
+ "_-./=" + string.ascii_lowercase + string.digits
+)
def contains_invalid_mxid_characters(localpart):
@@ -260,9 +296,9 @@ UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
# bytes rather than strings
#
NON_MXID_CHARACTER_PATTERN = re.compile(
- ("[^%s]" % (
- re.escape("".join(mxid_localpart_allowed_characters - {"="}),),
- )).encode("ascii"),
+ ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters - {"="})),)).encode(
+ "ascii"
+ )
)
@@ -281,10 +317,11 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
unicode: string suitable for a mxid localpart
"""
if not isinstance(username, bytes):
- username = username.encode('utf-8')
+ username = username.encode("utf-8")
# first we sort out upper-case characters
if case_sensitive:
+
def f1(m):
return b"_" + m.group().lower()
@@ -304,27 +341,31 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)
# we also do the =-escaping to mxids starting with an underscore.
- username = re.sub(b'^_', b'=5f', username)
+ username = re.sub(b"^_", b"=5f", username)
# we should now only have ascii bytes left, so can decode back to a
# unicode.
- return username.decode('ascii')
+ return username.decode("ascii")
class StreamToken(
- namedtuple("Token", (
- "room_key",
- "presence_key",
- "typing_key",
- "receipt_key",
- "account_data_key",
- "push_rules_key",
- "to_device_key",
- "device_list_key",
- "groups_key",
- ))
+ namedtuple(
+ "Token",
+ (
+ "room_key",
+ "presence_key",
+ "typing_key",
+ "receipt_key",
+ "account_data_key",
+ "push_rules_key",
+ "to_device_key",
+ "device_list_key",
+ "groups_key",
+ ),
+ )
):
_SEPARATOR = "_"
+ START = None # type: StreamToken
@classmethod
def from_string(cls, string):
@@ -383,9 +424,7 @@ class StreamToken(
return self._replace(**{key: new_value})
-StreamToken.START = StreamToken(
- *(["s0"] + ["0"] * (len(StreamToken._fields) - 1))
-)
+StreamToken.START = StreamToken(*(["s0"] + ["0"] * (len(StreamToken._fields) - 1)))
class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
@@ -410,15 +449,16 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
"topological_ordering" id of the event it comes after, followed by "-",
followed by the "stream_ordering" id of the event it comes after.
"""
- __slots__ = []
+
+ __slots__ = [] # type: list
@classmethod
def parse(cls, string):
try:
- if string[0] == 's':
+ if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
- if string[0] == 't':
- parts = string[1:].split('-', 1)
+ if string[0] == "t":
+ parts = string[1:].split("-", 1)
return cls(topological=int(parts[0]), stream=int(parts[1]))
except Exception:
pass
@@ -427,7 +467,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
@classmethod
def parse_stream_token(cls, string):
try:
- if string[0] == 's':
+ if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
except Exception:
pass
@@ -441,7 +481,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
class ThirdPartyInstanceID(
- namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
+ namedtuple("ThirdPartyInstanceID", ("appservice_id", "network_id"))
):
# Deny iteration because it will bite you if you try to create a singleton
# set by:
@@ -465,20 +505,42 @@ class ThirdPartyInstanceID(
return cls(appservice_id=bits[0], network_id=bits[1])
def to_string(self):
- return "%s|%s" % (self.appservice_id, self.network_id,)
+ return "%s|%s" % (self.appservice_id, self.network_id)
__str__ = to_string
@classmethod
- def create(cls, appservice_id, network_id,):
+ def create(cls, appservice_id, network_id):
return cls(appservice_id=appservice_id, network_id=network_id)
@attr.s(slots=True)
class ReadReceipt(object):
"""Information about a read-receipt"""
+
room_id = attr.ib()
receipt_type = attr.ib()
user_id = attr.ib()
event_ids = attr.ib()
data = attr.ib()
+
+
+def get_verify_key_from_cross_signing_key(key_info):
+ """Get the key ID and signedjson verify key from a cross-signing key dict
+
+ Args:
+ key_info (dict): a cross-signing key dict, which must have a "keys"
+ property that has exactly one item in it
+
+ Returns:
+ (str, VerifyKey): the key ID and verify key for the cross-signing key
+ """
+ # make sure that exactly one key is provided
+ if "keys" not in key_info:
+ raise ValueError("Invalid key")
+ keys = key_info["keys"]
+ if len(keys) != 1:
+ raise ValueError("Invalid key")
+ # and return that one key
+ for key_id, key_data in keys.items():
+ return (key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)))
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 8f5a526800..60f0de70f7 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -15,13 +15,12 @@
import logging
import re
-from itertools import islice
import attr
from twisted.internet import defer, task
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.logging import context
logger = logging.getLogger(__name__)
@@ -40,15 +39,16 @@ class Clock(object):
Args:
reactor: The Twisted reactor to use.
"""
+
_reactor = attr.ib()
@defer.inlineCallbacks
def sleep(self, seconds):
d = defer.Deferred()
- with PreserveLoggingContext():
+ with context.PreserveLoggingContext():
self._reactor.callLater(seconds, d.callback, seconds)
res = yield d
- defer.returnValue(res)
+ return res
def time(self):
"""Returns the current system time in seconds since epoch."""
@@ -61,7 +61,10 @@ class Clock(object):
def looping_call(self, f, msec, *args, **kwargs):
"""Call a function repeatedly.
- Waits `msec` initially before calling `f` for the first time.
+ Waits `msec` initially before calling `f` for the first time.
+
+ Note that the function will be called with no logcontext, so if it is anything
+ other than trivial, you probably want to wrap it in run_as_background_process.
Args:
f(function): The function to call repeatedly.
@@ -72,25 +75,27 @@ class Clock(object):
call = task.LoopingCall(f, *args, **kwargs)
call.clock = self._reactor
d = call.start(msec / 1000.0, now=False)
- d.addErrback(
- log_failure, "Looping call died", consumeErrors=False,
- )
+ d.addErrback(log_failure, "Looping call died", consumeErrors=False)
return call
def call_later(self, delay, callback, *args, **kwargs):
"""Call something later
+ Note that the function will be called with no logcontext, so if it is anything
+ other than trivial, you probably want to wrap it in run_as_background_process.
+
Args:
delay(float): How long to wait in seconds.
callback(function): Function to call
*args: Postional arguments to pass to function.
**kwargs: Key arguments to pass to function.
"""
+
def wrapped_callback(*args, **kwargs):
- with PreserveLoggingContext():
+ with context.PreserveLoggingContext():
callback(*args, **kwargs)
- with PreserveLoggingContext():
+ with context.PreserveLoggingContext():
return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs)
def cancel_call_later(self, timer, ignore_errs=False):
@@ -101,22 +106,6 @@ class Clock(object):
raise
-def batch_iter(iterable, size):
- """batch an iterable up into tuples with a maximum size
-
- Args:
- iterable (iterable): the iterable to slice
- size (int): the maximum batch size
-
- Returns:
- an iterator over the chunks
- """
- # make sure we can deal with iterables like lists too
- sourceiter = iter(iterable)
- # call islice until it returns an empty tuple
- return iter(lambda: tuple(islice(sourceiter, size)), ())
-
-
def log_failure(failure, msg, consumeErrors=True):
"""Creates a function suitable for passing to `Deferred.addErrback` that
logs any failures that occur.
@@ -131,12 +120,7 @@ def log_failure(failure, msg, consumeErrors=True):
"""
logger.error(
- msg,
- exc_info=(
- failure.type,
- failure.value,
- failure.getTracebackObject()
- )
+ msg, exc_info=(failure.type, failure.value, failure.getTracebackObject())
)
if not consumeErrors:
@@ -154,12 +138,12 @@ def glob_to_regex(glob):
Returns:
re.RegexObject
"""
- res = ''
+ res = ""
for c in glob:
- if c == '*':
- res = res + '.*'
- elif c == '?':
- res = res + '.'
+ if c == "*":
+ res = res + ".*"
+ elif c == "?":
+ res = res + "."
else:
res = res + re.escape(c)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 7253ba120f..581dffd8a0 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -13,23 +13,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import collections
import logging
from contextlib import contextmanager
+from typing import Dict, Sequence, Set, Union
from six.moves import range
+import attr
+
from twisted.internet import defer
from twisted.internet.defer import CancelledError
from twisted.python import failure
-from synapse.util import Clock, logcontext, unwrapFirstError
-
-from .logcontext import (
+from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
)
+from synapse.util import Clock, unwrapFirstError
logger = logging.getLogger(__name__)
@@ -70,6 +73,10 @@ class ObservableDeferred(object):
def errback(f):
object.__setattr__(self, "_result", (False, f))
while self._observers:
+ # This is a little bit of magic to correctly propagate stack
+ # traces when we `await` on one of the observer deferreds.
+ f.value.__failure__ = f
+
try:
# TODO: Handle errors here.
self._observers.pop().errback(f)
@@ -83,11 +90,12 @@ class ObservableDeferred(object):
deferred.addCallbacks(callback, errback)
- def observe(self):
+ def observe(self) -> defer.Deferred:
"""Observe the underlying deferred.
- Can return either a deferred if the underlying deferred is still pending
- (or has failed), or the actual value. Callers may need to use maybeDeferred.
+ This returns a brand new deferred that is resolved when the underlying
+ deferred is resolved. Interacting with the returned deferred does not
+ effect the underdlying deferred.
"""
if not self._result:
d = defer.Deferred()
@@ -95,13 +103,14 @@ class ObservableDeferred(object):
def remove(r):
self._observers.discard(d)
return r
+
d.addBoth(remove)
self._observers.add(d)
return d
else:
success, res = self._result
- return res if success else defer.fail(res)
+ return defer.succeed(res) if success else defer.fail(res)
def observers(self):
return self._observers
@@ -123,7 +132,9 @@ class ObservableDeferred(object):
def __repr__(self):
return "<ObservableDeferred object at %s, result=%r, _deferred=%r>" % (
- id(self), self._result, self._deferred,
+ id(self),
+ self._result,
+ self._deferred,
)
@@ -132,9 +143,9 @@ def concurrently_execute(func, args, limit):
the number of concurrent executions.
Args:
- func (func): Function to execute, should return a deferred.
- args (list): List of arguments to pass to func, each invocation of func
- gets a signle argument.
+ func (func): Function to execute, should return a deferred or coroutine.
+ args (Iterable): List of arguments to pass to func, each invocation of func
+ gets a single argument.
limit (int): Maximum number of conccurent executions.
Returns:
@@ -142,18 +153,19 @@ def concurrently_execute(func, args, limit):
"""
it = iter(args)
- @defer.inlineCallbacks
- def _concurrently_execute_inner():
+ async def _concurrently_execute_inner():
try:
while True:
- yield func(next(it))
+ await maybe_awaitable(func(next(it)))
except StopIteration:
pass
- return logcontext.make_deferred_yieldable(defer.gatherResults([
- run_in_background(_concurrently_execute_inner)
- for _ in range(limit)
- ], consumeErrors=True)).addErrback(unwrapFirstError)
+ return make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(_concurrently_execute_inner) for _ in range(limit)],
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
def yieldable_gather_results(func, iter, *args, **kwargs):
@@ -169,10 +181,12 @@ def yieldable_gather_results(func, iter, *args, **kwargs):
Deferred[list]: Resolved when all functions have been invoked, or errors if
one of the function calls fails.
"""
- return logcontext.make_deferred_yieldable(defer.gatherResults([
- run_in_background(func, item, *args, **kwargs)
- for item in iter
- ], consumeErrors=True)).addErrback(unwrapFirstError)
+ return make_deferred_yieldable(
+ defer.gatherResults(
+ [run_in_background(func, item, *args, **kwargs) for item in iter],
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
class Linearizer(object):
@@ -185,6 +199,7 @@ class Linearizer(object):
# do some work.
"""
+
def __init__(self, name=None, max_count=1, clock=None):
"""
Args:
@@ -197,6 +212,7 @@ class Linearizer(object):
if not clock:
from twisted.internet import reactor
+
clock = Clock(reactor)
self._clock = clock
self.max_count = max_count
@@ -205,7 +221,9 @@ class Linearizer(object):
# the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
- self.key_to_defer = {}
+ self.key_to_defer = (
+ {}
+ ) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@@ -221,7 +239,7 @@ class Linearizer(object):
res = self._await_lock(key)
else:
logger.debug(
- "Acquired uncontended linearizer lock %r for key %r", self.name, key,
+ "Acquired uncontended linearizer lock %r for key %r", self.name, key
)
entry[0] += 1
res = defer.succeed(None)
@@ -266,9 +284,7 @@ class Linearizer(object):
"""
entry = self.key_to_defer[key]
- logger.debug(
- "Waiting to acquire linearizer lock %r for key %r", self.name, key,
- )
+ logger.debug("Waiting to acquire linearizer lock %r for key %r", self.name, key)
new_defer = make_deferred_yieldable(defer.Deferred())
entry[1][new_defer] = 1
@@ -293,14 +309,14 @@ class Linearizer(object):
logger.info("defer %r got err %r", new_defer, e)
if isinstance(e, CancelledError):
logger.debug(
- "Cancelling wait for linearizer lock %r for key %r",
- self.name, key,
+ "Cancelling wait for linearizer lock %r for key %r", self.name, key
)
else:
- logger.warn(
+ logger.warning(
"Unexpected exception waiting for linearizer lock %r for key %r",
- self.name, key,
+ self.name,
+ key,
)
# we just have to take ourselves back out of the queue.
@@ -334,10 +350,10 @@ class ReadWriteLock(object):
def __init__(self):
# Latest readers queued
- self.key_to_current_readers = {}
+ self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued
- self.key_to_current_writer = {}
+ self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks
def read(self, key):
@@ -360,7 +376,7 @@ class ReadWriteLock(object):
new_defer.callback(None)
self.key_to_current_readers.get(key, set()).discard(new_defer)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
@defer.inlineCallbacks
def write(self, key):
@@ -390,7 +406,7 @@ class ReadWriteLock(object):
if self.key_to_current_writer[key] == new_defer:
self.key_to_current_writer.pop(key)
- defer.returnValue(_ctx_manager())
+ return _ctx_manager()
def _cancelled_to_timed_out_error(value, timeout):
@@ -438,7 +454,7 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
try:
deferred.cancel()
- except: # noqa: E722, if we throw any exception it'll break time outs
+ except: # noqa: E722, if we throw any exception it'll break time outs
logger.exception("Canceller failed during timeout")
if not new_d.called:
@@ -473,3 +489,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
deferred.addCallbacks(success_cb, failure_cb)
return new_d
+
+
+@attr.s(slots=True, frozen=True)
+class DoneAwaitable(object):
+ """Simple awaitable that returns the provided value.
+ """
+
+ value = attr.ib()
+
+ def __await__(self):
+ return self
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ raise StopIteration(self.value)
+
+
+def maybe_awaitable(value):
+ """Convert a value to an awaitable if not already an awaitable.
+ """
+
+ if hasattr(value, "__await__"):
+ return value
+
+ return DoneAwaitable(value)
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index f37d5bec08..da5077b471 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,6 +16,7 @@
import logging
import os
+from typing import Dict
import six
from six.moves import intern
@@ -36,7 +38,7 @@ def get_cache_factor_for(cache_name):
caches_by_name = {}
-collectors_by_name = {}
+collectors_by_name = {} # type: Dict
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
@@ -51,7 +53,19 @@ response_cache_evicted = Gauge(
response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
-def register_cache(cache_type, cache_name, cache):
+def register_cache(cache_type, cache_name, cache, collect_callback=None):
+ """Register a cache object for metric collection.
+
+ Args:
+ cache_type (str):
+ cache_name (str): name of the cache
+ cache (object): cache itself
+ collect_callback (callable|None): if not None, a function which is called during
+ metric collection to update additional metrics.
+
+ Returns:
+ CacheMetric: an object which provides inc_{hits,misses,evictions} methods
+ """
# Check if the metric is already registered. Unregister it, if so.
# This usually happens during tests, as at runtime these caches are
@@ -90,8 +104,10 @@ def register_cache(cache_type, cache_name, cache):
cache_hits.labels(cache_name).set(self.hits)
cache_evicted.labels(cache_name).set(self.evicted_size)
cache_total.labels(cache_name).set(self.hits + self.misses)
+ if collect_callback:
+ collect_callback()
except Exception as e:
- logger.warn("Error calculating metrics for %s: %s", cache_name, e)
+ logger.warning("Error calculating metrics for %s: %s", cache_name, e)
raise
yield GaugeMetricFamily("__unused", "")
@@ -104,8 +120,8 @@ def register_cache(cache_type, cache_name, cache):
KNOWN_KEYS = {
- key: key for key in
- (
+ key: key
+ for key in (
"auth_events",
"content",
"depth",
@@ -150,7 +166,7 @@ def intern_dict(dictionary):
def _intern_known_values(key, value):
- intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key",)
+ intern_keys = ("event_id", "room_id", "sender", "user_id", "type", "state_key")
if key in intern_keys:
return intern_string(value)
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 187510576a..2e8f6543e5 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -17,32 +17,53 @@ import functools
import inspect
import logging
import threading
-from collections import namedtuple
+from typing import Any, Tuple, Union, cast
+from weakref import WeakValueDictionary
-import six
-from six import itervalues, string_types
+from six import itervalues
+
+from prometheus_client import Gauge
+from typing_extensions import Protocol
from twisted.internet import defer
-from synapse.util import logcontext, unwrapFirstError
+from synapse.logging.context import make_deferred_yieldable, preserve_fn
+from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
-from synapse.util.stringutils import to_ascii
from . import register_cache
logger = logging.getLogger(__name__)
+CacheKey = Union[Tuple, Any]
+
+
+class _CachedFunction(Protocol):
+ invalidate = None # type: Any
+ invalidate_all = None # type: Any
+ invalidate_many = None # type: Any
+ prefill = None # type: Any
+ cache = None # type: Any
+ num_args = None # type: Any
+
+ def __name__(self):
+ ...
+
+
+cache_pending_metric = Gauge(
+ "synapse_util_caches_cache_pending",
+ "Number of lookups currently pending for this cache",
+ ["name"],
+)
_CacheSentinel = object()
class CacheEntry(object):
- __slots__ = [
- "deferred", "callbacks", "invalidated"
- ]
+ __slots__ = ["deferred", "callbacks", "invalidated"]
def __init__(self, deferred, callbacks):
self.deferred = deferred
@@ -73,7 +94,9 @@ class Cache(object):
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
- max_size=max_entries, keylen=keylen, cache_type=cache_type,
+ max_size=max_entries,
+ keylen=keylen,
+ cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted,
)
@@ -81,11 +104,19 @@ class Cache(object):
self.name = name
self.keylen = keylen
self.thread = None
- self.metrics = register_cache("cache", name, self.cache)
+ self.metrics = register_cache(
+ "cache",
+ name,
+ self.cache,
+ collect_callback=self._metrics_collection_callback,
+ )
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
+ def _metrics_collection_callback(self):
+ cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
+
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
@@ -107,7 +138,7 @@ class Cache(object):
update_metrics (bool): whether to update the cache hit rate metrics
Returns:
- Either a Deferred or the raw result
+ Either an ObservableDeferred or the raw result
"""
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
@@ -131,12 +162,14 @@ class Cache(object):
return default
def set(self, key, value, callback=None):
+ if not isinstance(value, defer.Deferred):
+ raise TypeError("not a Deferred")
+
callbacks = [callback] if callback else []
self.check_thread()
- entry = CacheEntry(
- deferred=value,
- callbacks=callbacks,
- )
+ observable = ObservableDeferred(value, consumeErrors=True)
+ observer = defer.maybeDeferred(observable.observe)
+ entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
@@ -144,20 +177,31 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
- def shuffle(result):
+ def compare_and_pop():
+ """Check if our entry is still the one in _pending_deferred_cache, and
+ if so, pop it.
+
+ Returns true if the entries matched.
+ """
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
+ return True
+
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ return False
+
+ def cb(result):
+ if compare_and_pop():
self.cache.set(key, result, entry.callbacks)
else:
- # oops, the _pending_deferred_cache has been updated since
- # we started our query, so we are out of date.
- #
- # Better put back whatever we took out. (We do it this way
- # round, rather than peeking into the _pending_deferred_cache
- # and then removing on a match, to make the common case faster)
- if existing_entry is not None:
- self._pending_deferred_cache[key] = existing_entry
-
# we're not going to put this entry into the cache, so need
# to make sure that the invalidation callbacks are called.
# That was probably done when _pending_deferred_cache was
@@ -165,9 +209,16 @@ class Cache(object):
# `invalidate` being previously called, in which case it may
# not have been. Either way, let's double-check now.
entry.invalidate()
- return result
- entry.deferred.addCallback(shuffle)
+ def eb(_fail):
+ compare_and_pop()
+ entry.invalidate()
+
+ # once the deferred completes, we can move the entry from the
+ # _pending_deferred_cache to the real cache.
+ #
+ observer.addCallbacks(cb, eb)
+ return observable
def prefill(self, key, value, callback=None):
callbacks = [callback] if callback else []
@@ -191,9 +242,7 @@ class Cache(object):
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
- raise TypeError(
- "The cache key must be a tuple not %r" % (type(key),)
- )
+ raise TypeError("The cache key must be a tuple not %r" % (type(key),))
self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the
@@ -212,7 +261,9 @@ class Cache(object):
class _CacheDescriptorBase(object):
- def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
+ def __init__(
+ self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
+ ):
self.orig = orig
if inlineCallbacks:
@@ -220,7 +271,7 @@ class _CacheDescriptorBase(object):
else:
self.function_to_call = orig
- arg_spec = inspect.getargspec(orig)
+ arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
if "cache_context" in all_args:
@@ -244,29 +295,25 @@ class _CacheDescriptorBase(object):
raise Exception(
"Not enough explicit positional arguments to key off for %r: "
"got %i args, but wanted %i. (@cached cannot key off *args or "
- "**kwargs)"
- % (orig.__name__, len(all_args), num_args)
+ "**kwargs)" % (orig.__name__, len(all_args), num_args)
)
self.num_args = num_args
# list of the names of the args used as the cache key
- self.arg_names = all_args[1:num_args + 1]
+ self.arg_names = all_args[1 : num_args + 1]
# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
if arg_spec.defaults:
- self.arg_defaults = dict(zip(
- all_args[-len(arg_spec.defaults):],
- arg_spec.defaults
- ))
+ self.arg_defaults = dict(
+ zip(all_args[-len(arg_spec.defaults) :], arg_spec.defaults)
+ )
else:
self.arg_defaults = {}
if "cache_context" in self.arg_names:
- raise Exception(
- "cache_context arg cannot be included among the cache keys"
- )
+ raise Exception("cache_context arg cannot be included among the cache keys")
self.add_cache_context = cache_context
@@ -297,19 +344,31 @@ class CacheDescriptor(_CacheDescriptorBase):
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
- defer.returnValue(r1 + r2)
+ return r1 + r2
Args:
num_args (int): number of positional arguments (excluding ``self`` and
``cache_context``) to use as cache keys. Defaults to all named
args of the function.
"""
- def __init__(self, orig, max_entries=1000, num_args=None, tree=False,
- inlineCallbacks=False, cache_context=False, iterable=False):
+
+ def __init__(
+ self,
+ orig,
+ max_entries=1000,
+ num_args=None,
+ tree=False,
+ inlineCallbacks=False,
+ cache_context=False,
+ iterable=False,
+ ):
super(CacheDescriptor, self).__init__(
- orig, num_args=num_args, inlineCallbacks=inlineCallbacks,
- cache_context=cache_context)
+ orig,
+ num_args=num_args,
+ inlineCallbacks=inlineCallbacks,
+ cache_context=cache_context,
+ )
max_entries = int(max_entries * get_cache_factor_for(orig.__name__))
@@ -356,12 +415,14 @@ class CacheDescriptor(_CacheDescriptorBase):
return args[0]
else:
return self.arg_defaults[nm]
+
else:
+
def get_cache_key(args, kwargs):
return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig)
- def wrapped(*args, **kwargs):
+ def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)
@@ -371,7 +432,7 @@ class CacheDescriptor(_CacheDescriptorBase):
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
- kwargs["cache_context"] = _CacheContext(cache, cache_key)
+ kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
try:
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
@@ -379,12 +440,11 @@ class CacheDescriptor(_CacheDescriptorBase):
if isinstance(cached_result_d, ObservableDeferred):
observer = cached_result_d.observe()
else:
- observer = cached_result_d
+ observer = defer.succeed(cached_result_d)
except KeyError:
ret = defer.maybeDeferred(
- logcontext.preserve_fn(self.function_to_call),
- obj, *args, **kwargs
+ preserve_fn(self.function_to_call), obj, *args, **kwargs
)
def onErr(f):
@@ -393,20 +453,12 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
- # If our cache_key is a string on py2, try to convert to ascii
- # to save a bit of space in large caches. Py3 does this
- # internally automatically.
- if six.PY2 and isinstance(cache_key, string_types):
- cache_key = to_ascii(cache_key)
-
- result_d = ObservableDeferred(ret, consumeErrors=True)
- cache.set(cache_key, result_d, callback=invalidate_callback)
+ result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
- if isinstance(observer, defer.Deferred):
- return logcontext.make_deferred_yieldable(observer)
- else:
- return observer
+ return make_deferred_yieldable(observer)
+
+ wrapped = cast(_CachedFunction, _wrapped)
if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0])
@@ -432,13 +484,13 @@ class CacheListDescriptor(_CacheDescriptorBase):
Given a list of keys it looks in the cache to find any hits, then passes
the list of missing keys to the wrapped function.
- Once wrapped, the function returns either a Deferred which resolves to
- the list of results, or (if all results were cached), just the list of
- results.
+ Once wrapped, the function returns a Deferred which resolves to the list
+ of results.
"""
- def __init__(self, orig, cached_method_name, list_name, num_args=None,
- inlineCallbacks=False):
+ def __init__(
+ self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
+ ):
"""
Args:
orig (function)
@@ -451,7 +503,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
be wrapped by defer.inlineCallbacks
"""
super(CacheListDescriptor, self).__init__(
- orig, num_args=num_args, inlineCallbacks=inlineCallbacks)
+ orig, num_args=num_args, inlineCallbacks=inlineCallbacks
+ )
self.list_name = list_name
@@ -463,7 +516,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
if self.list_name not in self.arg_names:
raise Exception(
"Couldn't see arguments %r for %r."
- % (self.list_name, cached_method_name,)
+ % (self.list_name, cached_method_name)
)
def __get__(self, obj, objtype=None):
@@ -494,8 +547,10 @@ class CacheListDescriptor(_CacheDescriptorBase):
# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
+
def arg_to_cache_key(arg):
return arg
+
else:
keylist = list(keyargs)
@@ -505,8 +560,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args:
try:
- res = cache.get(arg_to_cache_key(arg),
- callback=invalidate_callback)
+ res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
@@ -519,7 +573,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
missing.add(arg)
if missing:
- # we need an observable deferred for each entry in the list,
+ # we need a deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
@@ -527,8 +581,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
- observable = ObservableDeferred(deferred)
- cache.set(key, observable, callback=invalidate_callback)
+ cache.set(key, deferred, callback=invalidate_callback)
def complete_all(res):
# the wrapped function has completed. It returns a
@@ -554,40 +607,62 @@ class CacheListDescriptor(_CacheDescriptorBase):
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = list(missing)
- cached_defers.append(defer.maybeDeferred(
- logcontext.preserve_fn(self.function_to_call),
- **args_to_call
- ).addCallbacks(complete_all, errback))
+ cached_defers.append(
+ defer.maybeDeferred(
+ preserve_fn(self.function_to_call), **args_to_call
+ ).addCallbacks(complete_all, errback)
+ )
if cached_defers:
- d = defer.gatherResults(
- cached_defers,
- consumeErrors=True,
- ).addCallbacks(
- lambda _: results,
- unwrapFirstError
+ d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
+ lambda _: results, unwrapFirstError
)
- return logcontext.make_deferred_yieldable(d)
+ return make_deferred_yieldable(d)
else:
- return results
+ return defer.succeed(results)
obj.__dict__[self.orig.__name__] = wrapped
return wrapped
-class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
- # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
- # which namedtuple does for us (i.e. two _CacheContext are the same if
- # their caches and keys match). This is important in particular to
- # dedupe when we add callbacks to lru cache nodes, otherwise the number
- # of callbacks would grow.
- def invalidate(self):
- self.cache.invalidate(self.key)
+class _CacheContext:
+ """Holds cache information from the cached function higher in the calling order.
+
+ Can be used to invalidate the higher level cache entry if something changes
+ on a lower level.
+ """
+
+ _cache_context_objects = (
+ WeakValueDictionary()
+ ) # type: WeakValueDictionary[Tuple[Cache, CacheKey], _CacheContext]
+
+ def __init__(self, cache, cache_key): # type: (Cache, CacheKey) -> None
+ self._cache = cache
+ self._cache_key = cache_key
+
+ def invalidate(self): # type: () -> None
+ """Invalidates the cache entry referred to by the context."""
+ self._cache.invalidate(self._cache_key)
+
+ @classmethod
+ def get_instance(cls, cache, cache_key): # type: (Cache, CacheKey) -> _CacheContext
+ """Returns an instance constructed with the given arguments.
+
+ A new instance is only created if none already exists.
+ """
+
+ # We make sure there are no identical _CacheContext instances. This is
+ # important in particular to dedupe when we add callbacks to lru cache
+ # nodes, otherwise the number of callbacks would grow.
+ return cls._cache_context_objects.setdefault(
+ (cache, cache_key), cls(cache, cache_key)
+ )
-def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
- iterable=False):
+def cached(
+ max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@@ -598,8 +673,9 @@ def cached(max_entries=1000, num_args=None, tree=False, cache_context=False,
)
-def cachedInlineCallbacks(max_entries=1000, num_args=None, tree=False,
- cache_context=False, iterable=False):
+def cachedInlineCallbacks(
+ max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
+):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 6c0b5a4094..6834e6f3ae 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -35,6 +35,7 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
there.
value (dict): The full or partial dict value
"""
+
def __len__(self):
return len(self.value)
@@ -84,13 +85,15 @@ class DictionaryCache(object):
self.metrics.inc_hits()
if dict_keys is None:
- return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
+ return DictionaryEntry(
+ entry.full, entry.known_absent, dict(entry.value)
+ )
else:
- return DictionaryEntry(entry.full, entry.known_absent, {
- k: entry.value[k]
- for k in dict_keys
- if k in entry.value
- })
+ return DictionaryEntry(
+ entry.full,
+ entry.known_absent,
+ {k: entry.value[k] for k in dict_keys if k in entry.value},
+ )
self.metrics.inc_misses()
return DictionaryEntry(False, set(), {})
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index f369780277..cddf1ed515 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -28,8 +28,15 @@ SENTINEL = object()
class ExpiringCache(object):
- def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
- reset_expiry_on_get=False, iterable=False):
+ def __init__(
+ self,
+ cache_name,
+ clock,
+ max_len=0,
+ expiry_ms=0,
+ reset_expiry_on_get=False,
+ iterable=False,
+ ):
"""
Args:
cache_name (str): Name of this cache, used for logging.
@@ -67,8 +74,7 @@ class ExpiringCache(object):
def f():
return run_as_background_process(
- "prune_cache_%s" % self._cache_name,
- self._prune_cache,
+ "prune_cache_%s" % self._cache_name, self._prune_cache
)
self._clock.looping_call(f, self._expiry_ms / 2)
@@ -153,7 +159,9 @@ class ExpiringCache(object):
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
- self._cache_name, begin_length, len(self)
+ self._cache_name,
+ begin_length,
+ len(self),
)
def __len__(self):
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index b684f24e7b..1536cb64f3 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,8 +49,15 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
- def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
- evicted_callback=None):
+
+ def __init__(
+ self,
+ max_size,
+ keylen=1,
+ cache_type=dict,
+ size_callback=None,
+ evicted_callback=None,
+ ):
"""
Args:
max_size (int):
@@ -93,9 +100,12 @@ class LruCache(object):
cached_cache_len = [0]
if size_callback is not None:
+
def cache_len():
return cached_cache_len[0]
+
else:
+
def cache_len():
return len(cache)
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index afb03b2e1b..b68f9fe0d4 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -16,9 +16,9 @@ import logging
from twisted.internet import defer
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -35,12 +35,10 @@ class ResponseCache(object):
self.pending_result_cache = {} # Requests that haven't finished yet.
self.clock = hs.get_clock()
- self.timeout_sec = timeout_ms / 1000.
+ self.timeout_sec = timeout_ms / 1000.0
self._name = name
- self._metrics = register_cache(
- "response_cache", name, self
- )
+ self._metrics = register_cache("response_cache", name, self)
def size(self):
return len(self.pending_result_cache)
@@ -80,7 +78,7 @@ class ResponseCache(object):
*deferred* should run its callbacks in the sentinel logcontext (ie,
you should wrap normal synapse deferreds with
- logcontext.run_in_background).
+ synapse.logging.context.run_in_background).
Can return either a new Deferred (which also doesn't follow the synapse
logcontext rules), or, if *deferred* was already complete, the actual
@@ -100,8 +98,7 @@ class ResponseCache(object):
def remove(r):
if self.timeout_sec:
self.clock.call_later(
- self.timeout_sec,
- self.pending_result_cache.pop, key, None,
+ self.timeout_sec, self.pending_result_cache.pop, key, None
)
else:
self.pending_result_cache.pop(key, None)
@@ -124,7 +121,7 @@ class ResponseCache(object):
@defer.inlineCallbacks
def handle_request(request):
# etc
- defer.returnValue(result)
+ return result
result = yield response_cache.wrap(
key,
@@ -140,21 +137,22 @@ class ResponseCache(object):
*args: positional parameters to pass to the callback, if it is used
- **kwargs: named paramters to pass to the callback, if it is used
+ **kwargs: named parameters to pass to the callback, if it is used
Returns:
twisted.internet.defer.Deferred: yieldable result
"""
result = self.get(key)
if not result:
- logger.info("[%s]: no cached result for [%s], calculating new one",
- self._name, key)
+ logger.debug(
+ "[%s]: no cached result for [%s], calculating new one", self._name, key
+ )
d = run_in_background(callback, *args, **kwargs)
result = self.set(key, d)
elif not isinstance(result, defer.Deferred) or result.called:
- logger.info("[%s]: using completed cached result for [%s]",
- self._name, key)
+ logger.info("[%s]: using completed cached result for [%s]", self._name, key)
else:
- logger.info("[%s]: using incomplete cached result for [%s]",
- self._name, key)
+ logger.info(
+ "[%s]: using incomplete cached result for [%s]", self._name, key
+ )
return make_deferred_yieldable(result)
diff --git a/synapse/util/caches/snapshot_cache.py b/synapse/util/caches/snapshot_cache.py
deleted file mode 100644
index 8318db8d2c..0000000000
--- a/synapse/util/caches/snapshot_cache.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.util.async_helpers import ObservableDeferred
-
-
-class SnapshotCache(object):
- """Cache for snapshots like the response of /initialSync.
- The response of initialSync only has to be a recent snapshot of the
- server state. It shouldn't matter to clients if it is a few minutes out
- of date.
-
- This caches a deferred response. Until the deferred completes it will be
- returned from the cache. This means that if the client retries the request
- while the response is still being computed, that original response will be
- used rather than trying to compute a new response.
-
- Once the deferred completes it will removed from the cache after 5 minutes.
- We delay removing it from the cache because a client retrying its request
- could race with us finishing computing the response.
-
- Rather than tracking precisely how long something has been in the cache we
- keep two generations of completed responses. Every 5 minutes discard the
- old generation, move the new generation to the old generation, and set the
- new generation to be empty. This means that a result will be in the cache
- somewhere between 5 and 10 minutes.
- """
-
- DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes.
-
- def __init__(self):
- self.pending_result_cache = {} # Request that haven't finished yet.
- self.prev_result_cache = {} # The older requests that have finished.
- self.next_result_cache = {} # The newer requests that have finished.
- self.time_last_rotated_ms = 0
-
- def rotate(self, time_now_ms):
- # Rotate once if the cache duration has passed since the last rotation.
- if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
- self.prev_result_cache = self.next_result_cache
- self.next_result_cache = {}
- self.time_last_rotated_ms += self.DURATION_MS
-
- # Rotate again if the cache duration has passed twice since the last
- # rotation.
- if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
- self.prev_result_cache = self.next_result_cache
- self.next_result_cache = {}
- self.time_last_rotated_ms = time_now_ms
-
- def get(self, time_now_ms, key):
- self.rotate(time_now_ms)
- # This cache is intended to deduplicate requests, so we expect it to be
- # missed most of the time. So we just lookup the key in all of the
- # dictionaries rather than trying to short circuit the lookup if the
- # key is found.
- result = self.prev_result_cache.get(key)
- result = self.next_result_cache.get(key, result)
- result = self.pending_result_cache.get(key, result)
- if result is not None:
- return result.observe()
- else:
- return None
-
- def set(self, time_now_ms, key, deferred):
- self.rotate(time_now_ms)
-
- result = ObservableDeferred(deferred)
-
- self.pending_result_cache[key] = result
-
- def shuffle_along(r):
- # When the deferred completes we shuffle it along to the first
- # generation of the result cache. So that it will eventually
- # expire from the rotation of that cache.
- self.next_result_cache[key] = result
- self.pending_result_cache.pop(key, None)
- return r
-
- result.addBoth(shuffle_along)
-
- return result.observe()
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 625aedc940..235f64049c 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -77,9 +77,8 @@ class StreamChangeCache(object):
if stream_pos >= self._earliest_known_stream_pos:
changed_entities = {
- self._cache[k] for k in self._cache.islice(
- start=self._cache.bisect_right(stream_pos),
- )
+ self._cache[k]
+ for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
}
result = changed_entities.intersection(entities)
@@ -114,8 +113,10 @@ class StreamChangeCache(object):
assert type(stream_pos) is int
if stream_pos >= self._earliest_known_stream_pos:
- return [self._cache[k] for k in self._cache.islice(
- start=self._cache.bisect_right(stream_pos))]
+ return [
+ self._cache[k]
+ for k in self._cache.islice(start=self._cache.bisect_right(stream_pos))
+ ]
else:
return None
@@ -136,7 +137,7 @@ class StreamChangeCache(object):
while len(self._cache) > self._max_size:
k, r = self._cache.popitem(0)
self._earliest_known_stream_pos = max(
- k, self._earliest_known_stream_pos,
+ k, self._earliest_known_stream_pos
)
self._entity_to_key.pop(r, None)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index dd4c9e6067..2ea4e4e911 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,3 +1,5 @@
+from typing import Dict
+
from six import itervalues
SENTINEL = object()
@@ -9,9 +11,10 @@ class TreeCache(object):
efficiently.
Keys must be tuples.
"""
+
def __init__(self):
self.size = 0
- self.root = {}
+ self.root = {} # type: Dict
def __setitem__(self, key, value):
return self.set(key, value)
diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py
index 5ba1862506..99646c7cf0 100644
--- a/synapse/util/caches/ttlcache.py
+++ b/synapse/util/caches/ttlcache.py
@@ -55,7 +55,7 @@ class TTLCache(object):
if e != SENTINEL:
self._expiry_list.remove(e)
- entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
+ entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value)
self._data[key] = entry
self._expiry_list.add(entry)
@@ -87,7 +87,8 @@ class TTLCache(object):
key: key to look up
Returns:
- Tuple[Any, float]: the value from the cache, and the expiry time
+ Tuple[Any, float, float]: the value from the cache, the expiry time
+ and the TTL
Raises:
KeyError if the entry is not found
@@ -99,7 +100,7 @@ class TTLCache(object):
self._metrics.inc_misses()
raise
self._metrics.inc_hits()
- return e.value, e.expiry_time
+ return e.value, e.expiry_time, e.ttl
def pop(self, key, default=SENTINEL):
"""Remove a value from the cache
@@ -155,7 +156,9 @@ class TTLCache(object):
@attr.s(frozen=True, slots=True)
class _CacheEntry(object):
"""TTLCache entry"""
+
# expiry_time is the first attribute, so that entries are sorted by expiry.
expiry_time = attr.ib()
+ ttl = attr.ib()
key = attr.ib()
value = attr.ib()
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index e14c8bdfda..45af8d3eeb 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -17,8 +17,8 @@ import logging
from twisted.internet import defer
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -51,9 +51,7 @@ class Distributor(object):
if name in self.signals:
raise KeyError("%r already has a signal named %s" % (self, name))
- self.signals[name] = Signal(
- name,
- )
+ self.signals[name] = Signal(name)
if name in self.pre_registration:
signal = self.signals[name]
@@ -78,11 +76,7 @@ class Distributor(object):
if name not in self.signals:
raise KeyError("%r does not have a signal named %s" % (self, name))
- run_as_background_process(
- name,
- self.signals[name].fire,
- *args, **kwargs
- )
+ run_as_background_process(name, self.signals[name].fire, *args, **kwargs)
class Signal(object):
@@ -118,22 +112,23 @@ class Signal(object):
def eb(failure):
logger.warning(
"%s signal observer %s failed: %r",
- self.name, observer, failure,
+ self.name,
+ observer,
+ failure,
exc_info=(
failure.type,
failure.value,
- failure.getTracebackObject()))
+ failure.getTracebackObject(),
+ ),
+ )
return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
- deferreds = [
- run_in_background(do, o)
- for o in self.observers
- ]
+ deferreds = [run_in_background(do, o) for o in self.observers]
- return make_deferred_yieldable(defer.gatherResults(
- deferreds, consumeErrors=True,
- ))
+ return make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True)
+ )
def __repr__(self):
return "<Signal name=%r>" % (self.name,)
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 629ed44149..8b17d1c8b8 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -17,7 +17,7 @@ from six.moves import queue
from twisted.internet import threads
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.logging.context import make_deferred_yieldable, run_in_background
class BackgroundFileConsumer(object):
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 014edea971..f2ccd5e7c6 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -30,7 +30,7 @@ def freeze(o):
return o
try:
- return tuple([freeze(i) for i in o])
+ return tuple(freeze(i) for i in o)
except TypeError:
pass
@@ -60,11 +60,10 @@ def _handle_frozendict(obj):
# fishing the protected dict out of the object is a bit nasty,
# but we don't really want the overhead of copying the dict.
return obj._dict
- raise TypeError('Object of type %s is not JSON serializable' %
- obj.__class__.__name__)
+ raise TypeError(
+ "Object of type %s is not JSON serializable" % obj.__class__.__name__
+ )
# A JSONEncoder which is capable of encoding frozendics without barfing
-frozendict_json_encoder = json.JSONEncoder(
- default=_handle_frozendict,
-)
+frozendict_json_encoder = json.JSONEncoder(default=_handle_frozendict)
diff --git a/synapse/util/hash.py b/synapse/util/hash.py
new file mode 100644
index 0000000000..359168704e
--- /dev/null
+++ b/synapse/util/hash.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+
+import unpaddedbase64
+
+
+def sha256_and_url_safe_base64(input_text):
+ """SHA256 hash an input string, encode the digest as url-safe base64, and
+ return
+
+ :param input_text: string to hash
+ :type input_text: str
+
+ :returns a sha256 hashed and url-safe base64 encoded digest
+ :rtype: str
+ """
+ digest = hashlib.sha256(input_text.encode()).digest()
+ return unpaddedbase64.encode_base64(digest, urlsafe=True)
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 2d7ddc1cbe..3c0e8469f3 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
def create_resource_tree(desired_tree, root_resource):
- """Create the resource tree for this Home Server.
+ """Create the resource tree for this homeserver.
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
@@ -45,7 +45,7 @@ def create_resource_tree(desired_tree, root_resource):
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
- for path_seg in full_path.split(b'/')[1:-1]:
+ for path_seg in full_path.split(b"/")[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = NoResource()
@@ -60,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
# ===========================
# now attach the actual desired resource
- last_path_seg = full_path.split(b'/')[-1]
+ last_path_seg = full_path.split(b"/")[-1]
# if there is already a resource here, thieve its children and
# replace it
@@ -70,9 +70,7 @@ def create_resource_tree(desired_tree, root_resource):
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
- child_res_id = _resource_id(
- existing_dummy_resource, child_name
- )
+ child_res_id = _resource_id(existing_dummy_resource, child_name)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
new file mode 100644
index 0000000000..06faeebe7f
--- /dev/null
+++ b/synapse/util/iterutils.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from itertools import islice
+from typing import Iterable, Iterator, Sequence, Tuple, TypeVar
+
+T = TypeVar("T")
+
+
+def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T]]:
+ """batch an iterable up into tuples with a maximum size
+
+ Args:
+ iterable (iterable): the iterable to slice
+ size (int): the maximum batch size
+
+ Returns:
+ an iterator over the chunks
+ """
+ # make sure we can deal with iterables like lists too
+ sourceiter = iter(iterable)
+ # call islice until it returns an empty tuple
+ return iter(lambda: tuple(islice(sourceiter, size)), ())
+
+
+ISeq = TypeVar("ISeq", bound=Sequence, covariant=True)
+
+
+def chunk_seq(iseq: ISeq, maxlen: int) -> Iterable[ISeq]:
+ """Split the given sequence into chunks of the given size
+
+ The last chunk may be shorter than the given size.
+
+ If the input is empty, no chunks are returned.
+ """
+ return (iseq[i : i + maxlen] for i in range(0, len(iseq), maxlen))
diff --git a/synapse/util/jsonobject.py b/synapse/util/jsonobject.py
index d668e5a6b8..6dce03dd3a 100644
--- a/synapse/util/jsonobject.py
+++ b/synapse/util/jsonobject.py
@@ -70,7 +70,8 @@ class JsonEncodedObject(object):
dict
"""
d = {
- k: _encode(v) for (k, v) in self.__dict__.items()
+ k: _encode(v)
+ for (k, v) in self.__dict__.items()
if k in self.valid_keys and k not in self.internal_keys
}
d.update(self.unrecognized_keys)
@@ -78,7 +79,8 @@ class JsonEncodedObject(object):
def get_internal_dict(self):
d = {
- k: _encode(v, internal=True) for (k, v) in self.__dict__.items()
+ k: _encode(v, internal=True)
+ for (k, v) in self.__dict__.items()
if k in self.valid_keys
}
d.update(self.unrecognized_keys)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index fe412355d8..40e5c10a49 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,633 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-""" Thread-local-alike tracking of log contexts within synapse
-
-This module provides objects and utilities for tracking contexts through
-synapse code, so that log lines can include a request identifier, and so that
-CPU and database activity can be accounted for against the request that caused
-them.
-
-See doc/log_contexts.rst for details on how this works.
+"""
+Backwards compatibility re-exports of ``synapse.logging.context`` functionality.
"""
-import logging
-import threading
-
-from twisted.internet import defer, threads
-
-logger = logging.getLogger(__name__)
-
-try:
- import resource
-
- # Python doesn't ship with a definition of RUSAGE_THREAD but it's defined
- # to be 1 on linux so we hard code it.
- RUSAGE_THREAD = 1
-
- # If the system doesn't support RUSAGE_THREAD then this should throw an
- # exception.
- resource.getrusage(RUSAGE_THREAD)
-
- def get_thread_resource_usage():
- return resource.getrusage(RUSAGE_THREAD)
-except Exception:
- # If the system doesn't support resource.getrusage(RUSAGE_THREAD) then we
- # won't track resource usage by returning None.
- def get_thread_resource_usage():
- return None
-
-
-class ContextResourceUsage(object):
- """Object for tracking the resources used by a log context
-
- Attributes:
- ru_utime (float): user CPU time (in seconds)
- ru_stime (float): system CPU time (in seconds)
- db_txn_count (int): number of database transactions done
- db_sched_duration_sec (float): amount of time spent waiting for a
- database connection
- db_txn_duration_sec (float): amount of time spent doing database
- transactions (excluding scheduling time)
- evt_db_fetch_count (int): number of events requested from the database
- """
-
- __slots__ = [
- "ru_stime", "ru_utime",
- "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
- "evt_db_fetch_count",
- ]
-
- def __init__(self, copy_from=None):
- """Create a new ContextResourceUsage
-
- Args:
- copy_from (ContextResourceUsage|None): if not None, an object to
- copy stats from
- """
- if copy_from is None:
- self.reset()
- else:
- self.ru_utime = copy_from.ru_utime
- self.ru_stime = copy_from.ru_stime
- self.db_txn_count = copy_from.db_txn_count
-
- self.db_txn_duration_sec = copy_from.db_txn_duration_sec
- self.db_sched_duration_sec = copy_from.db_sched_duration_sec
- self.evt_db_fetch_count = copy_from.evt_db_fetch_count
-
- def copy(self):
- return ContextResourceUsage(copy_from=self)
-
- def reset(self):
- self.ru_stime = 0.
- self.ru_utime = 0.
- self.db_txn_count = 0
-
- self.db_txn_duration_sec = 0
- self.db_sched_duration_sec = 0
- self.evt_db_fetch_count = 0
-
- def __repr__(self):
- return ("<ContextResourceUsage ru_stime='%r', ru_utime='%r', "
- "db_txn_count='%r', db_txn_duration_sec='%r', "
- "db_sched_duration_sec='%r', evt_db_fetch_count='%r'>") % (
- self.ru_stime,
- self.ru_utime,
- self.db_txn_count,
- self.db_txn_duration_sec,
- self.db_sched_duration_sec,
- self.evt_db_fetch_count,)
-
- def __iadd__(self, other):
- """Add another ContextResourceUsage's stats to this one's.
-
- Args:
- other (ContextResourceUsage): the other resource usage object
- """
- self.ru_utime += other.ru_utime
- self.ru_stime += other.ru_stime
- self.db_txn_count += other.db_txn_count
- self.db_txn_duration_sec += other.db_txn_duration_sec
- self.db_sched_duration_sec += other.db_sched_duration_sec
- self.evt_db_fetch_count += other.evt_db_fetch_count
- return self
-
- def __isub__(self, other):
- self.ru_utime -= other.ru_utime
- self.ru_stime -= other.ru_stime
- self.db_txn_count -= other.db_txn_count
- self.db_txn_duration_sec -= other.db_txn_duration_sec
- self.db_sched_duration_sec -= other.db_sched_duration_sec
- self.evt_db_fetch_count -= other.evt_db_fetch_count
- return self
-
- def __add__(self, other):
- res = ContextResourceUsage(copy_from=self)
- res += other
- return res
-
- def __sub__(self, other):
- res = ContextResourceUsage(copy_from=self)
- res -= other
- return res
-
-
-class LoggingContext(object):
- """Additional context for log formatting. Contexts are scoped within a
- "with" block.
-
- If a parent is given when creating a new context, then:
- - logging fields are copied from the parent to the new context on entry
- - when the new context exits, the cpu usage stats are copied from the
- child to the parent
-
- Args:
- name (str): Name for the context for debugging.
- parent_context (LoggingContext|None): The parent of the new context
- """
-
- __slots__ = [
- "previous_context", "name", "parent_context",
- "_resource_usage",
- "usage_start",
- "main_thread", "alive",
- "request", "tag",
- ]
-
- thread_local = threading.local()
-
- class Sentinel(object):
- """Sentinel to represent the root context"""
-
- __slots__ = []
-
- def __str__(self):
- return "sentinel"
-
- def copy_to(self, record):
- pass
-
- def start(self):
- pass
-
- def stop(self):
- pass
-
- def add_database_transaction(self, duration_sec):
- pass
-
- def add_database_scheduled(self, sched_sec):
- pass
-
- def record_event_fetch(self, event_count):
- pass
-
- def __nonzero__(self):
- return False
- __bool__ = __nonzero__ # python3
-
- sentinel = Sentinel()
-
- def __init__(self, name=None, parent_context=None, request=None):
- self.previous_context = LoggingContext.current_context()
- self.name = name
-
- # track the resources used by this context so far
- self._resource_usage = ContextResourceUsage()
-
- # If alive has the thread resource usage when the logcontext last
- # became active.
- self.usage_start = None
-
- self.main_thread = threading.current_thread()
- self.request = None
- self.tag = ""
- self.alive = True
-
- self.parent_context = parent_context
-
- if self.parent_context is not None:
- self.parent_context.copy_to(self)
-
- if request is not None:
- # the request param overrides the request from the parent context
- self.request = request
-
- def __str__(self):
- if self.request:
- return str(self.request)
- return "%s@%x" % (self.name, id(self))
-
- @classmethod
- def current_context(cls):
- """Get the current logging context from thread local storage
-
- Returns:
- LoggingContext: the current logging context
- """
- return getattr(cls.thread_local, "current_context", cls.sentinel)
-
- @classmethod
- def set_current_context(cls, context):
- """Set the current logging context in thread local storage
- Args:
- context(LoggingContext): The context to activate.
- Returns:
- The context that was previously active
- """
- current = cls.current_context()
-
- if current is not context:
- current.stop()
- cls.thread_local.current_context = context
- context.start()
- return current
-
- def __enter__(self):
- """Enters this logging context into thread local storage"""
- old_context = self.set_current_context(self)
- if self.previous_context != old_context:
- logger.warn(
- "Expected previous context %r, found %r",
- self.previous_context, old_context
- )
- self.alive = True
-
- return self
-
- def __exit__(self, type, value, traceback):
- """Restore the logging context in thread local storage to the state it
- was before this context was entered.
- Returns:
- None to avoid suppressing any exceptions that were thrown.
- """
- current = self.set_current_context(self.previous_context)
- if current is not self:
- if current is self.sentinel:
- logger.warning("Expected logging context %s was lost", self)
- else:
- logger.warning(
- "Expected logging context %s but found %s", self, current
- )
- self.previous_context = None
- self.alive = False
-
- # if we have a parent, pass our CPU usage stats on
- if (
- self.parent_context is not None
- and hasattr(self.parent_context, '_resource_usage')
- ):
- self.parent_context._resource_usage += self._resource_usage
-
- # reset them in case we get entered again
- self._resource_usage.reset()
-
- def copy_to(self, record):
- """Copy logging fields from this context to a log record or
- another LoggingContext
- """
-
- # 'request' is the only field we currently use in the logger, so that's
- # all we need to copy
- record.request = self.request
-
- def start(self):
- if threading.current_thread() is not self.main_thread:
- logger.warning("Started logcontext %s on different thread", self)
- return
-
- # If we haven't already started record the thread resource usage so
- # far
- if not self.usage_start:
- self.usage_start = get_thread_resource_usage()
-
- def stop(self):
- if threading.current_thread() is not self.main_thread:
- logger.warning("Stopped logcontext %s on different thread", self)
- return
-
- # When we stop, let's record the cpu used since we started
- if not self.usage_start:
- logger.warning(
- "Called stop on logcontext %s without calling start", self,
- )
- return
-
- usage_end = get_thread_resource_usage()
-
- self._resource_usage.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime
- self._resource_usage.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime
-
- self.usage_start = None
-
- def get_resource_usage(self):
- """Get resources used by this logcontext so far.
-
- Returns:
- ContextResourceUsage: a *copy* of the object tracking resource
- usage so far
- """
- # we always return a copy, for consistency
- res = self._resource_usage.copy()
-
- # If we are on the correct thread and we're currently running then we
- # can include resource usage so far.
- is_main_thread = threading.current_thread() is self.main_thread
- if self.alive and self.usage_start and is_main_thread:
- current = get_thread_resource_usage()
- res.ru_utime += current.ru_utime - self.usage_start.ru_utime
- res.ru_stime += current.ru_stime - self.usage_start.ru_stime
-
- return res
-
- def add_database_transaction(self, duration_sec):
- self._resource_usage.db_txn_count += 1
- self._resource_usage.db_txn_duration_sec += duration_sec
-
- def add_database_scheduled(self, sched_sec):
- """Record a use of the database pool
-
- Args:
- sched_sec (float): number of seconds it took us to get a
- connection
- """
- self._resource_usage.db_sched_duration_sec += sched_sec
-
- def record_event_fetch(self, event_count):
- """Record a number of events being fetched from the db
-
- Args:
- event_count (int): number of events being fetched
- """
- self._resource_usage.evt_db_fetch_count += event_count
-
-
-class LoggingContextFilter(logging.Filter):
- """Logging filter that adds values from the current logging context to each
- record.
- Args:
- **defaults: Default values to avoid formatters complaining about
- missing fields
- """
- def __init__(self, **defaults):
- self.defaults = defaults
-
- def filter(self, record):
- """Add each fields from the logging contexts to the record.
- Returns:
- True to include the record in the log output.
- """
- context = LoggingContext.current_context()
- for key, value in self.defaults.items():
- setattr(record, key, value)
-
- # context should never be None, but if it somehow ends up being, then
- # we end up in a death spiral of infinite loops, so let's check, for
- # robustness' sake.
- if context is not None:
- context.copy_to(record)
-
- return True
-
-
-class PreserveLoggingContext(object):
- """Captures the current logging context and restores it when the scope is
- exited. Used to restore the context after a function using
- @defer.inlineCallbacks is resumed by a callback from the reactor."""
-
- __slots__ = ["current_context", "new_context", "has_parent"]
-
- def __init__(self, new_context=None):
- if new_context is None:
- new_context = LoggingContext.sentinel
- self.new_context = new_context
-
- def __enter__(self):
- """Captures the current logging context"""
- self.current_context = LoggingContext.set_current_context(
- self.new_context
- )
-
- if self.current_context:
- self.has_parent = self.current_context.previous_context is not None
- if not self.current_context.alive:
- logger.debug(
- "Entering dead context: %s",
- self.current_context,
- )
-
- def __exit__(self, type, value, traceback):
- """Restores the current logging context"""
- context = LoggingContext.set_current_context(self.current_context)
-
- if context != self.new_context:
- if context is LoggingContext.sentinel:
- logger.warning("Expected logging context %s was lost", self.new_context)
- else:
- logger.warning(
- "Expected logging context %s but found %s",
- self.new_context,
- context,
- )
-
- if self.current_context is not LoggingContext.sentinel:
- if not self.current_context.alive:
- logger.debug(
- "Restoring dead context: %s",
- self.current_context,
- )
-
-
-def nested_logging_context(suffix, parent_context=None):
- """Creates a new logging context as a child of another.
-
- The nested logging context will have a 'request' made up of the parent context's
- request, plus the given suffix.
-
- CPU/db usage stats will be added to the parent context's on exit.
-
- Normal usage looks like:
-
- with nested_logging_context(suffix):
- # ... do stuff
-
- Args:
- suffix (str): suffix to add to the parent context's 'request'.
- parent_context (LoggingContext|None): parent context. Will use the current context
- if None.
-
- Returns:
- LoggingContext: new logging context.
- """
- if parent_context is None:
- parent_context = LoggingContext.current_context()
- return LoggingContext(
- parent_context=parent_context,
- request=parent_context.request + "-" + suffix,
- )
-
-
-def preserve_fn(f):
- """Function decorator which wraps the function with run_in_background"""
- def g(*args, **kwargs):
- return run_in_background(f, *args, **kwargs)
- return g
-
-
-def run_in_background(f, *args, **kwargs):
- """Calls a function, ensuring that the current context is restored after
- return from the function, and that the sentinel context is set once the
- deferred returned by the function completes.
-
- Useful for wrapping functions that return a deferred which you don't yield
- on (for instance because you want to pass it to deferred.gatherResults()).
-
- Note that if you completely discard the result, you should make sure that
- `f` doesn't raise any deferred exceptions, otherwise a scary-looking
- CRITICAL error about an unhandled error will be logged without much
- indication about where it came from.
- """
- current = LoggingContext.current_context()
- try:
- res = f(*args, **kwargs)
- except: # noqa: E722
- # the assumption here is that the caller doesn't want to be disturbed
- # by synchronous exceptions, so let's turn them into Failures.
- return defer.fail()
-
- if not isinstance(res, defer.Deferred):
- return res
-
- if res.called and not res.paused:
- # The function should have maintained the logcontext, so we can
- # optimise out the messing about
- return res
-
- # The function may have reset the context before returning, so
- # we need to restore it now.
- ctx = LoggingContext.set_current_context(current)
-
- # The original context will be restored when the deferred
- # completes, but there is nothing waiting for it, so it will
- # get leaked into the reactor or some other function which
- # wasn't expecting it. We therefore need to reset the context
- # here.
- #
- # (If this feels asymmetric, consider it this way: we are
- # effectively forking a new thread of execution. We are
- # probably currently within a ``with LoggingContext()`` block,
- # which is supposed to have a single entry and exit point. But
- # by spawning off another deferred, we are effectively
- # adding a new exit point.)
- res.addBoth(_set_context_cb, ctx)
- return res
-
-
-def make_deferred_yieldable(deferred):
- """Given a deferred, make it follow the Synapse logcontext rules:
-
- If the deferred has completed (or is not actually a Deferred), essentially
- does nothing (just returns another completed deferred with the
- result/failure).
-
- If the deferred has not yet completed, resets the logcontext before
- returning a deferred. Then, when the deferred completes, restores the
- current logcontext before running callbacks/errbacks.
-
- (This is more-or-less the opposite operation to run_in_background.)
- """
- if not isinstance(deferred, defer.Deferred):
- return deferred
-
- if deferred.called and not deferred.paused:
- # it looks like this deferred is ready to run any callbacks we give it
- # immediately. We may as well optimise out the logcontext faffery.
- return deferred
-
- # ok, we can't be sure that a yield won't block, so let's reset the
- # logcontext, and add a callback to the deferred to restore it.
- prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
- deferred.addBoth(_set_context_cb, prev_context)
- return deferred
-
-
-def _set_context_cb(result, context):
- """A callback function which just sets the logging context"""
- LoggingContext.set_current_context(context)
- return result
-
-
-def defer_to_thread(reactor, f, *args, **kwargs):
- """
- Calls the function `f` using a thread from the reactor's default threadpool and
- returns the result as a Deferred.
-
- Creates a new logcontext for `f`, which is created as a child of the current
- logcontext (so its CPU usage metrics will get attributed to the current
- logcontext). `f` should preserve the logcontext it is given.
-
- The result deferred follows the Synapse logcontext rules: you should `yield`
- on it.
-
- Args:
- reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
- the Deferred will be invoked, and whose threadpool we should use for the
- function.
-
- Normally this will be hs.get_reactor().
-
- f (callable): The function to call.
-
- args: positional arguments to pass to f.
-
- kwargs: keyword arguments to pass to f.
-
- Returns:
- Deferred: A Deferred which fires a callback with the result of `f`, or an
- errback if `f` throws an exception.
- """
- return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
-
-
-def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
- """
- A wrapper for twisted.internet.threads.deferToThreadpool, which handles
- logcontexts correctly.
-
- Calls the function `f` using a thread from the given threadpool and returns
- the result as a Deferred.
-
- Creates a new logcontext for `f`, which is created as a child of the current
- logcontext (so its CPU usage metrics will get attributed to the current
- logcontext). `f` should preserve the logcontext it is given.
-
- The result deferred follows the Synapse logcontext rules: you should `yield`
- on it.
-
- Args:
- reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
- the Deferred will be invoked. Normally this will be hs.get_reactor().
-
- threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for
- running `f`. Normally this will be hs.get_reactor().getThreadPool().
-
- f (callable): The function to call.
-
- args: positional arguments to pass to f.
-
- kwargs: keyword arguments to pass to f.
-
- Returns:
- Deferred: A Deferred which fires a callback with the result of `f`, or an
- errback if `f` throws an exception.
- """
- logcontext = LoggingContext.current_context()
-
- def g():
- with LoggingContext(parent_context=logcontext):
- return f(*args, **kwargs)
-
- return make_deferred_yieldable(
- threads.deferToThreadPool(reactor, threadpool, g)
- )
+from synapse.logging.context import (
+ LoggingContext,
+ LoggingContextFilter,
+ PreserveLoggingContext,
+ defer_to_thread,
+ make_deferred_yieldable,
+ nested_logging_context,
+ preserve_fn,
+ run_in_background,
+)
+
+__all__ = [
+ "defer_to_thread",
+ "LoggingContext",
+ "LoggingContextFilter",
+ "make_deferred_yieldable",
+ "nested_logging_context",
+ "preserve_fn",
+ "PreserveLoggingContext",
+ "run_in_background",
+]
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
index a46bc47ce3..320e8f8174 100644
--- a/synapse/util/logformatter.py
+++ b/synapse/util/logformatter.py
@@ -1,5 +1,4 @@
-# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,40 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Backwards compatibility re-exports of ``synapse.logging.formatter`` functionality.
+"""
-import logging
-import traceback
+from synapse.logging.formatter import LogFormatter
-from six import StringIO
-
-
-class LogFormatter(logging.Formatter):
- """Log formatter which gives more detail for exceptions
-
- This is the same as the standard log formatter, except that when logging
- exceptions [typically via log.foo("msg", exc_info=1)], it prints the
- sequence that led up to the point at which the exception was caught.
- (Normally only stack frames between the point the exception was raised and
- where it was caught are logged).
- """
- def __init__(self, *args, **kwargs):
- super(LogFormatter, self).__init__(*args, **kwargs)
-
- def formatException(self, ei):
- sio = StringIO()
- (typ, val, tb) = ei
-
- # log the stack above the exception capture point if possible, but
- # check that we actually have an f_back attribute to work around
- # https://twistedmatrix.com/trac/ticket/9305
-
- if tb and hasattr(tb.tb_frame, 'f_back'):
- sio.write("Capture point (most recent call last):\n")
- traceback.print_stack(tb.tb_frame.f_back, None, sio)
-
- traceback.print_exception(typ, val, tb, None, sio)
- s = sio.getvalue()
- sio.close()
- if s[-1:] == "\n":
- s = s[:-1]
- return s
+__all__ = ["LogFormatter"]
diff --git a/synapse/util/manhole.py b/synapse/util/manhole.py
index 628a2962d9..631654f297 100644
--- a/synapse/util/manhole.py
+++ b/synapse/util/manhole.py
@@ -74,27 +74,25 @@ def manhole(username, password, globals):
twisted.internet.protocol.Factory: A factory to pass to ``listenTCP``
"""
if not isinstance(password, bytes):
- password = password.encode('ascii')
+ password = password.encode("ascii")
- checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(
- **{username: password}
- )
+ checker = checkers.InMemoryUsernamePasswordDatabaseDontUse(**{username: password})
rlm = manhole_ssh.TerminalRealm()
rlm.chainedProtocolFactory = lambda: insults.ServerProtocol(
- SynapseManhole,
- dict(globals, __name__="__console__")
+ SynapseManhole, dict(globals, __name__="__console__")
)
factory = manhole_ssh.ConchFactory(portal.Portal(rlm, [checker]))
- factory.publicKeys[b'ssh-rsa'] = Key.fromString(PUBLIC_KEY)
- factory.privateKeys[b'ssh-rsa'] = Key.fromString(PRIVATE_KEY)
+ factory.publicKeys[b"ssh-rsa"] = Key.fromString(PUBLIC_KEY)
+ factory.privateKeys[b"ssh-rsa"] = Key.fromString(PRIVATE_KEY)
return factory
class SynapseManhole(ColoredManhole):
"""Overrides connectionMade to create our own ManholeInterpreter"""
+
def connectionMade(self):
super(SynapseManhole, self).connectionMade()
@@ -127,7 +125,7 @@ class SynapseManholeInterpreter(ManholeInterpreter):
value = SyntaxError(msg, (filename, lineno, offset, line))
sys.last_value = value
lines = traceback.format_exception_only(type, value)
- self.write(''.join(lines))
+ self.write("".join(lines))
def showtraceback(self):
"""Display the exception that just occurred.
@@ -140,6 +138,6 @@ class SynapseManholeInterpreter(ManholeInterpreter):
try:
# We remove the first stack item because it is our own code.
lines = traceback.format_exception(ei[0], ei[1], last_tb.tb_next)
- self.write(''.join(lines))
+ self.write("".join(lines))
finally:
last_tb = ei = None
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index 4b4ac5f6c7..7b18455469 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
from functools import wraps
@@ -20,8 +21,8 @@ from prometheus_client import Counter
from twisted.internet import defer
+from synapse.logging.context import LoggingContext
from synapse.metrics import InFlightGauge
-from synapse.util.logcontext import LoggingContext
logger = logging.getLogger(__name__)
@@ -30,108 +31,108 @@ block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"])
block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"])
block_ru_utime = Counter(
- "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"]
+)
block_ru_stime = Counter(
- "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"]
+)
block_db_txn_count = Counter(
- "synapse_util_metrics_block_db_txn_count", "", ["block_name"])
+ "synapse_util_metrics_block_db_txn_count", "", ["block_name"]
+)
# seconds spent waiting for db txns, excluding scheduling time, in this block
block_db_txn_duration = Counter(
- "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"]
+)
# seconds spent waiting for a db connection, in this block
block_db_sched_duration = Counter(
- "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])
+ "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"]
+)
# Tracks the number of blocks currently active
in_flight = InFlightGauge(
- "synapse_util_metrics_block_in_flight", "",
+ "synapse_util_metrics_block_in_flight",
+ "",
labels=["block_name"],
sub_metrics=["real_time_max", "real_time_sum"],
)
-def measure_func(name):
+def measure_func(name=None):
def wrapper(func):
- @wraps(func)
- @defer.inlineCallbacks
- def measured_func(self, *args, **kwargs):
- with Measure(self.clock, name):
- r = yield func(self, *args, **kwargs)
- defer.returnValue(r)
+ block_name = func.__name__ if name is None else name
+
+ if inspect.iscoroutinefunction(func):
+
+ @wraps(func)
+ async def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, block_name):
+ r = await func(self, *args, **kwargs)
+ return r
+
+ else:
+
+ @wraps(func)
+ @defer.inlineCallbacks
+ def measured_func(self, *args, **kwargs):
+ with Measure(self.clock, block_name):
+ r = yield func(self, *args, **kwargs)
+ return r
+
return measured_func
+
return wrapper
class Measure(object):
__slots__ = [
- "clock", "name", "start_context", "start",
- "created_context",
- "start_usage",
+ "clock",
+ "name",
+ "_logging_context",
+ "start",
]
def __init__(self, clock, name):
self.clock = clock
self.name = name
- self.start_context = None
+ self._logging_context = None
self.start = None
- self.created_context = False
def __enter__(self):
- self.start = self.clock.time()
- self.start_context = LoggingContext.current_context()
- if not self.start_context:
- self.start_context = LoggingContext("Measure")
- self.start_context.__enter__()
- self.created_context = True
-
- self.start_usage = self.start_context.get_resource_usage()
+ if self._logging_context:
+ raise RuntimeError("Measure() objects cannot be re-used")
+ self.start = self.clock.time()
+ parent_context = LoggingContext.current_context()
+ self._logging_context = LoggingContext(
+ "Measure[%s]" % (self.name,), parent_context
+ )
+ self._logging_context.__enter__()
in_flight.register((self.name,), self._update_in_flight)
def __exit__(self, exc_type, exc_val, exc_tb):
- if isinstance(exc_type, Exception) or not self.start_context:
- return
-
- in_flight.unregister((self.name,), self._update_in_flight)
+ if not self._logging_context:
+ raise RuntimeError("Measure() block exited without being entered")
duration = self.clock.time() - self.start
+ usage = self._logging_context.get_resource_usage()
- block_counter.labels(self.name).inc()
- block_timer.labels(self.name).inc(duration)
-
- context = LoggingContext.current_context()
-
- if context != self.start_context:
- logger.warn(
- "Context has unexpectedly changed from '%s' to '%s'. (%r)",
- self.start_context, context, self.name
- )
- return
-
- if not context:
- logger.warn("Expected context. (%r)", self.name)
- return
+ in_flight.unregister((self.name,), self._update_in_flight)
+ self._logging_context.__exit__(exc_type, exc_val, exc_tb)
- current = context.get_resource_usage()
- usage = current - self.start_usage
try:
+ block_counter.labels(self.name).inc()
+ block_timer.labels(self.name).inc(duration)
block_ru_utime.labels(self.name).inc(usage.ru_utime)
block_ru_stime.labels(self.name).inc(usage.ru_stime)
block_db_txn_count.labels(self.name).inc(usage.db_txn_count)
block_db_txn_duration.labels(self.name).inc(usage.db_txn_duration_sec)
block_db_sched_duration.labels(self.name).inc(usage.db_sched_duration_sec)
except ValueError:
- logger.warn(
- "Failed to save metrics! OLD: %r, NEW: %r",
- self.start_usage, current
- )
-
- if self.created_context:
- self.start_context.__exit__(exc_type, exc_val, exc_tb)
+ logger.warning("Failed to save metrics! Usage: %s", usage)
def _update_in_flight(self, metrics):
"""Gets called when processing in flight metrics
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 4288312b8a..bb62db4637 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -14,12 +14,13 @@
# limitations under the License.
import importlib
+import importlib.util
from synapse.config._base import ConfigError
def load_module(provider):
- """ Loads a module with its config
+ """ Loads a synapse module with its config
Take a dict with keys 'module' (the module name) and 'config'
(the config dict).
@@ -28,15 +29,30 @@ def load_module(provider):
"""
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
- module, clz = provider['module'].rsplit(".", 1)
+ module, clz = provider["module"].rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
try:
- provider_config = provider_class.parse_config(provider["config"])
+ provider_config = provider_class.parse_config(provider.get("config"))
except Exception as e:
- raise ConfigError(
- "Failed to parse config for %r: %r" % (provider['module'], e)
- )
+ raise ConfigError("Failed to parse config for %r: %r" % (provider["module"], e))
return provider_class, provider_config
+
+
+def load_python_module(location: str):
+ """Load a python module, and return a reference to its global namespace
+
+ Args:
+ location (str): path to the module
+
+ Returns:
+ python module object
+ """
+ spec = importlib.util.spec_from_file_location(location, location)
+ if spec is None:
+ raise Exception("Unable to load module at %s" % (location,))
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod) # type: ignore
+ return mod
diff --git a/synapse/util/msisdn.py b/synapse/util/msisdn.py
index a6c30e5265..c8bcbe297a 100644
--- a/synapse/util/msisdn.py
+++ b/synapse/util/msisdn.py
@@ -36,6 +36,6 @@ def phone_number_to_msisdn(country, number):
phoneNumber = phonenumbers.parse(number, country)
except phonenumbers.NumberParseException:
raise SynapseError(400, "Unable to parse phone number")
- return phonenumbers.format_number(
- phoneNumber, phonenumbers.PhoneNumberFormat.E164
- )[1:]
+ return phonenumbers.format_number(phoneNumber, phonenumbers.PhoneNumberFormat.E164)[
+ 1:
+ ]
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
new file mode 100644
index 0000000000..3925927f9f
--- /dev/null
+++ b/synapse/util/patch_inline_callbacks.py
@@ -0,0 +1,219 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import functools
+import sys
+from typing import Any, Callable, List
+
+from twisted.internet import defer
+from twisted.internet.defer import Deferred
+from twisted.python.failure import Failure
+
+# Tracks if we've already patched inlineCallbacks
+_already_patched = False
+
+
+def do_patch():
+ """
+ Patch defer.inlineCallbacks so that it checks the state of the logcontext on exit
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ global _already_patched
+
+ orig_inline_callbacks = defer.inlineCallbacks
+ if _already_patched:
+ return
+
+ def new_inline_callbacks(f):
+ @functools.wraps(f)
+ def wrapped(*args, **kwargs):
+ start_context = LoggingContext.current_context()
+ changes = [] # type: List[str]
+ orig = orig_inline_callbacks(_check_yield_points(f, changes))
+
+ try:
+ res = orig(*args, **kwargs)
+ except Exception:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "%s changed context from %s to %s on exception" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ raise
+
+ if not isinstance(res, Deferred) or res.called:
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+
+ err = "Completed %s changed context from %s to %s" % (
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ # print the error to stderr because otherwise all we
+ # see in travis-ci is the 500 error
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return res
+
+ if LoggingContext.current_context() != LoggingContext.sentinel:
+ err = (
+ "%s returned incomplete deferred in non-sentinel context "
+ "%s (start was %s)"
+ ) % (f, LoggingContext.current_context(), start_context)
+ print(err, file=sys.stderr)
+ raise Exception(err)
+
+ def check_ctx(r):
+ if LoggingContext.current_context() != start_context:
+ for err in changes:
+ print(err, file=sys.stderr)
+ err = "%s completion of %s changed context from %s to %s" % (
+ "Failure" if isinstance(r, Failure) else "Success",
+ f,
+ start_context,
+ LoggingContext.current_context(),
+ )
+ print(err, file=sys.stderr)
+ raise Exception(err)
+ return r
+
+ res.addBoth(check_ctx)
+ return res
+
+ return wrapped
+
+ defer.inlineCallbacks = new_inline_callbacks
+ _already_patched = True
+
+
+def _check_yield_points(f: Callable, changes: List[str]):
+ """Wraps a generator that is about to be passed to defer.inlineCallbacks
+ checking that after every yield the log contexts are correct.
+
+ It's perfectly valid for log contexts to change within a function, e.g. due
+ to new Measure blocks, so such changes are added to the given `changes`
+ list instead of triggering an exception.
+
+ Args:
+ f: generator function to wrap
+ changes: A list of strings detailing how the contexts
+ changed within a function.
+
+ Returns:
+ function
+ """
+
+ from synapse.logging.context import LoggingContext
+
+ @functools.wraps(f)
+ def check_yield_points_inner(*args, **kwargs):
+ gen = f(*args, **kwargs)
+
+ last_yield_line_no = gen.gi_frame.f_lineno
+ result = None # type: Any
+ while True:
+ expected_context = LoggingContext.current_context()
+
+ try:
+ isFailure = isinstance(result, Failure)
+ if isFailure:
+ d = result.throwExceptionIntoGenerator(gen)
+ else:
+ d = gen.send(result)
+ except (StopIteration, defer._DefGen_Return) as e:
+ if LoggingContext.current_context() != expected_context:
+ # This happens when the context is lost sometime *after* the
+ # final yield and returning. E.g. we forgot to yield on a
+ # function that returns a deferred.
+ #
+ # We don't raise here as it's perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "Function %r returned and changed context from %s to %s,"
+ " in %s between %d and end of func"
+ % (
+ f.__qualname__,
+ expected_context,
+ LoggingContext.current_context(),
+ f.__code__.co_filename,
+ last_yield_line_no,
+ )
+ )
+ changes.append(err)
+ return getattr(e, "value", None)
+
+ frame = gen.gi_frame
+
+ if isinstance(d, defer.Deferred) and not d.called:
+ # This happens if we yield on a deferred that doesn't follow
+ # the log context rules without wrapping in a `make_deferred_yieldable`.
+ # We raise here as this should never happen.
+ if LoggingContext.current_context() is not LoggingContext.sentinel:
+ err = (
+ "%s yielded with context %s rather than sentinel,"
+ " yielded on line %d in %s"
+ % (
+ frame.f_code.co_name,
+ LoggingContext.current_context(),
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ raise Exception(err)
+
+ try:
+ result = yield d
+ except Exception as e:
+ result = Failure(e)
+
+ if LoggingContext.current_context() != expected_context:
+
+ # This happens because the context is lost sometime *after* the
+ # previous yield and *after* the current yield. E.g. the
+ # deferred we waited on didn't follow the rules, or we forgot to
+ # yield on a function between the two yield points.
+ #
+ # We don't raise here as its perfectly valid for contexts to
+ # change in a function, as long as it sets the correct context
+ # on resolving (which is checked separately).
+ err = (
+ "%s changed context from %s to %s, happened between lines %d and %d in %s"
+ % (
+ frame.f_code.co_name,
+ expected_context,
+ LoggingContext.current_context(),
+ last_yield_line_no,
+ frame.f_lineno,
+ frame.f_code.co_filename,
+ )
+ )
+ changes.append(err)
+
+ last_yield_line_no = frame.f_lineno
+
+ return check_yield_points_inner
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index b146d137f4..5ca4521ce3 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -20,7 +20,7 @@ import logging
from twisted.internet import defer
from synapse.api.errors import LimitExceededError
-from synapse.util.logcontext import (
+from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
run_in_background,
@@ -36,9 +36,11 @@ class FederationRateLimiter(object):
clock (Clock)
config (FederationRateLimitConfig)
"""
- self.clock = clock
- self._config = config
- self.ratelimiters = {}
+
+ def new_limiter():
+ return _PerHostRatelimiter(clock=clock, config=config)
+
+ self.ratelimiters = collections.defaultdict(new_limiter)
def ratelimit(self, host):
"""Used to ratelimit an incoming request from given host
@@ -53,15 +55,9 @@ class FederationRateLimiter(object):
host (str): Origin of incoming request.
Returns:
- _PerHostRatelimiter
+ context manager which returns a deferred.
"""
- return self.ratelimiters.setdefault(
- host,
- _PerHostRatelimiter(
- clock=self.clock,
- config=self._config,
- )
- ).ratelimit()
+ return self.ratelimiters[host].ratelimit()
class _PerHostRatelimiter(object):
@@ -112,8 +108,7 @@ class _PerHostRatelimiter(object):
# remove any entries from request_times which aren't within the window
self.request_times[:] = [
- r for r in self.request_times
- if time_now - r < self.window_size
+ r for r in self.request_times if time_now - r < self.window_size
]
# reject the request if we already have too many queued up (either
@@ -121,15 +116,13 @@ class _PerHostRatelimiter(object):
queue_size = len(self.ready_request_queue) + len(self.sleeping_requests)
if queue_size > self.reject_limit:
raise LimitExceededError(
- retry_after_ms=int(
- self.window_size / self.sleep_limit
- ),
+ retry_after_ms=int(self.window_size / self.sleep_limit)
)
self.request_times.append(time_now)
def queue_request():
- if len(self.current_processing) > self.concurrent_requests:
+ if len(self.current_processing) >= self.concurrent_requests:
queue_defer = defer.Deferred()
self.ready_request_queue[request_id] = queue_defer
logger.info(
@@ -143,22 +136,18 @@ class _PerHostRatelimiter(object):
logger.debug(
"Ratelimit [%s]: len(self.request_times)=%d",
- id(request_id), len(self.request_times),
+ id(request_id),
+ len(self.request_times),
)
if len(self.request_times) > self.sleep_limit:
- logger.debug(
- "Ratelimiter: sleeping request for %f sec", self.sleep_sec,
- )
+ logger.debug("Ratelimiter: sleeping request for %f sec", self.sleep_sec)
ret_defer = run_in_background(self.clock.sleep, self.sleep_sec)
self.sleeping_requests.add(request_id)
def on_wait_finished(_):
- logger.debug(
- "Ratelimit [%s]: Finished sleeping",
- id(request_id),
- )
+ logger.debug("Ratelimit [%s]: Finished sleeping", id(request_id))
self.sleeping_requests.discard(request_id)
queue_defer = queue_request()
return queue_defer
@@ -168,10 +157,7 @@ class _PerHostRatelimiter(object):
ret_defer = queue_request()
def on_start(r):
- logger.debug(
- "Ratelimit [%s]: Processing req",
- id(request_id),
- )
+ logger.debug("Ratelimit [%s]: Processing req", id(request_id))
self.current_processing.add(request_id)
return r
@@ -193,10 +179,7 @@ class _PerHostRatelimiter(object):
return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id):
- logger.debug(
- "Ratelimit [%s]: Processed req",
- id(request_id),
- )
+ logger.debug("Ratelimit [%s]: Processed req", id(request_id))
self.current_processing.discard(request_id)
try:
# start processing the next item on the queue.
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 1a77456498..af69587196 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -17,11 +17,20 @@ import random
from twisted.internet import defer
-import synapse.util.logcontext
+import synapse.logging.context
from synapse.api.errors import CodeMessageException
logger = logging.getLogger(__name__)
+# the intial backoff, after the first transaction fails
+MIN_RETRY_INTERVAL = 10 * 60 * 1000
+
+# how much we multiply the backoff by after each subsequent fail
+RETRY_MULTIPLIER = 5
+
+# a cap on the backoff. (Essentially none)
+MAX_RETRY_INTERVAL = 2 ** 62
+
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
@@ -71,11 +80,13 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
# We aren't ready to retry that destination.
raise
"""
+ failure_ts = None
retry_last_ts, retry_interval = (0, 0)
retry_timings = yield store.get_destination_retry_timings(destination)
if retry_timings:
+ failure_ts = retry_timings["failure_ts"]
retry_last_ts, retry_interval = (
retry_timings["retry_last_ts"],
retry_timings["retry_interval"],
@@ -95,15 +106,14 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs)
# maximum backoff even though it might only have been down briefly
backoff_on_failure = not ignore_backoff
- defer.returnValue(
- RetryDestinationLimiter(
- destination,
- clock,
- store,
- retry_interval,
- backoff_on_failure=backoff_on_failure,
- **kwargs
- )
+ return RetryDestinationLimiter(
+ destination,
+ clock,
+ store,
+ failure_ts,
+ retry_interval,
+ backoff_on_failure=backoff_on_failure,
+ **kwargs
)
@@ -113,10 +123,8 @@ class RetryDestinationLimiter(object):
destination,
clock,
store,
+ failure_ts,
retry_interval,
- min_retry_interval=10 * 60 * 1000,
- max_retry_interval=24 * 60 * 60 * 1000,
- multiplier_retry_interval=5,
backoff_on_404=False,
backoff_on_failure=True,
):
@@ -129,15 +137,11 @@ class RetryDestinationLimiter(object):
destination (str)
clock (Clock)
store (DataStore)
+ failure_ts (int|None): when this destination started failing (in ms since
+ the epoch), or zero if the last request was successful
retry_interval (int): The next retry interval taken from the
database in milliseconds, or zero if the last request was
successful.
- min_retry_interval (int): The minimum retry interval to use after
- a failed request, in milliseconds.
- max_retry_interval (int): The maximum retry interval to use after
- a failed request, in milliseconds.
- multiplier_retry_interval (int): The multiplier to use to increase
- the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404
backoff_on_failure (bool): set to False if we should not increase the
@@ -147,10 +151,8 @@ class RetryDestinationLimiter(object):
self.store = store
self.destination = destination
+ self.failure_ts = failure_ts
self.retry_interval = retry_interval
- self.min_retry_interval = min_retry_interval
- self.max_retry_interval = max_retry_interval
- self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404
self.backoff_on_failure = backoff_on_failure
@@ -191,6 +193,7 @@ class RetryDestinationLimiter(object):
logger.debug(
"Connection to %s was successful; clearing backoff", self.destination
)
+ self.failure_ts = None
retry_last_ts = 0
self.retry_interval = 0
elif not self.backoff_on_failure:
@@ -198,13 +201,14 @@ class RetryDestinationLimiter(object):
else:
# We couldn't connect.
if self.retry_interval:
- self.retry_interval *= self.multiplier_retry_interval
- self.retry_interval *= int(random.uniform(0.8, 1.4))
+ self.retry_interval = int(
+ self.retry_interval * RETRY_MULTIPLIER * random.uniform(0.8, 1.4)
+ )
- if self.retry_interval >= self.max_retry_interval:
- self.retry_interval = self.max_retry_interval
+ if self.retry_interval >= MAX_RETRY_INTERVAL:
+ self.retry_interval = MAX_RETRY_INTERVAL
else:
- self.retry_interval = self.min_retry_interval
+ self.retry_interval = MIN_RETRY_INTERVAL
logger.info(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
@@ -215,14 +219,20 @@ class RetryDestinationLimiter(object):
)
retry_last_ts = int(self.clock.time_msec())
+ if self.failure_ts is None:
+ self.failure_ts = retry_last_ts
+
@defer.inlineCallbacks
def store_retry_timings():
try:
yield self.store.set_destination_retry_timings(
- self.destination, retry_last_ts, self.retry_interval
+ self.destination,
+ self.failure_ts,
+ retry_last_ts,
+ self.retry_interval,
)
except Exception:
logger.exception("Failed to store destination_retry_timings")
# we deliberately do this in the background.
- synapse.util.logcontext.run_in_background(store_retry_timings)
+ synapse.logging.context.run_in_background(store_retry_timings)
diff --git a/synapse/util/rlimit.py b/synapse/util/rlimit.py
index 6c0f2bb0cf..207cd17c2a 100644
--- a/synapse/util/rlimit.py
+++ b/synapse/util/rlimit.py
@@ -33,4 +33,4 @@ def change_resource_limit(soft_file_no):
resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
)
except (ValueError, resource.error) as e:
- logger.warn("Failed to set file or core limit: %s", e)
+ logger.warning("Failed to set file or core limit: %s", e)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 5fb18ee1f8..2c0dcb5208 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -24,26 +24,25 @@ from six.moves import range
from synapse.api.errors import Codes, SynapseError
-_string_with_symbols = (
- string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
-)
+_string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
+
+# https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
+# Note: The : character is allowed here for older clients, but will be removed in a
+# future release. Context: https://github.com/matrix-org/synapse/issues/6766
+client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-\:]+$")
# random_string and random_string_with_symbols are used for a range of things,
# some cryptographically important, some less so. We use SystemRandom to make sure
# we get cryptographically-secure randoms.
rand = random.SystemRandom()
-client_secret_regex = re.compile(r"^[0-9a-zA-Z.=_-]+$")
-
def random_string(length):
- return ''.join(rand.choice(string.ascii_letters) for _ in range(length))
+ return "".join(rand.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
- return ''.join(
- rand.choice(_string_with_symbols) for _ in range(length)
- )
+ return "".join(rand.choice(_string_with_symbols) for _ in range(length))
def is_ascii(s):
@@ -51,7 +50,7 @@ def is_ascii(s):
if PY3:
if isinstance(s, bytes):
try:
- s.decode('ascii').encode('ascii')
+ s.decode("ascii").encode("ascii")
except UnicodeDecodeError:
return False
except UnicodeEncodeError:
@@ -110,13 +109,13 @@ def exception_to_unicode(e):
# and instead look at what is in the args member.
if len(e.args) == 0:
- return u""
+ return ""
elif len(e.args) > 1:
return six.text_type(repr(e.args))
msg = e.args[0]
if isinstance(msg, bytes):
- return msg.decode('utf-8', errors='replace')
+ return msg.decode("utf-8", errors="replace")
else:
return msg
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
index 4cc7d27ce5..34ce7cac16 100644
--- a/synapse/util/threepids.py
+++ b/synapse/util/threepids.py
@@ -36,25 +36,26 @@ def check_3pid_allowed(hs, medium, address):
if hs.config.check_is_for_allowed_local_3pids:
data = yield hs.get_simple_http_client().get_json(
- "https://%s%s" % (
+ "https://%s%s"
+ % (
hs.config.check_is_for_allowed_local_3pids,
- "/_matrix/identity/api/v1/internal-info"
+ "/_matrix/identity/api/v1/internal-info",
),
- {'medium': medium, 'address': address}
+ {"medium": medium, "address": address},
)
# Check for invalid response
- if 'hs' not in data and 'shadow_hs' not in data:
+ if "hs" not in data and "shadow_hs" not in data:
defer.returnValue(False)
# Check if this user is intended to register for this homeserver
if (
- data.get('hs') != hs.config.server_name
- and data.get('shadow_hs') != hs.config.server_name
+ data.get("hs") != hs.config.server_name
+ and data.get("shadow_hs") != hs.config.server_name
):
defer.returnValue(False)
- if data.get('requires_invite', False) and not data.get('invited', False):
+ if data.get("requires_invite", False) and not data.get("invited", False):
# Requires an invite but hasn't been invited
defer.returnValue(False)
@@ -64,11 +65,13 @@ def check_3pid_allowed(hs, medium, address):
for constraint in hs.config.allowed_local_3pids:
logger.debug(
"Checking 3PID %s (%s) against %s (%s)",
- address, medium, constraint['pattern'], constraint['medium'],
+ address,
+ medium,
+ constraint["pattern"],
+ constraint["medium"],
)
- if (
- medium == constraint['medium'] and
- re.match(constraint['pattern'], address)
+ if medium == constraint["medium"] and re.match(
+ constraint["pattern"], address
):
defer.returnValue(True)
else:
diff --git a/synapse/util/versionstring.py b/synapse/util/versionstring.py
index 3baba3225a..ab7d03af3a 100644
--- a/synapse/util/versionstring.py
+++ b/synapse/util/versionstring.py
@@ -22,63 +22,87 @@ logger = logging.getLogger(__name__)
def get_version_string(module):
+ """Given a module calculate a git-aware version string for it.
+
+ If called on a module not in a git checkout will return `__verison__`.
+
+ Args:
+ module (module)
+
+ Returns:
+ str
+ """
+
+ cached_version = getattr(module, "_synapse_version_string_cache", None)
+ if cached_version:
+ return cached_version
+
+ version_string = module.__version__
+
try:
- null = open(os.devnull, 'w')
+ null = open(os.devnull, "w")
cwd = os.path.dirname(os.path.abspath(module.__file__))
+
try:
- git_branch = subprocess.check_output(
- ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii')
+ git_branch = (
+ subprocess.check_output(
+ ["git", "rev-parse", "--abbrev-ref", "HEAD"], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ )
git_branch = "b=" + git_branch
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
+ # FileNotFoundError can arise when git is not installed
git_branch = ""
try:
- git_tag = subprocess.check_output(
- ['git', 'describe', '--exact-match'],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii')
+ git_tag = (
+ subprocess.check_output(
+ ["git", "describe", "--exact-match"], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ )
git_tag = "t=" + git_tag
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_tag = ""
try:
- git_commit = subprocess.check_output(
- ['git', 'rev-parse', '--short', 'HEAD'],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii')
- except subprocess.CalledProcessError:
+ git_commit = (
+ subprocess.check_output(
+ ["git", "rev-parse", "--short", "HEAD"], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ )
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_commit = ""
try:
dirty_string = "-this_is_a_dirty_checkout"
- is_dirty = subprocess.check_output(
- ['git', 'describe', '--dirty=' + dirty_string],
- stderr=null,
- cwd=cwd,
- ).strip().decode('ascii').endswith(dirty_string)
+ is_dirty = (
+ subprocess.check_output(
+ ["git", "describe", "--dirty=" + dirty_string], stderr=null, cwd=cwd
+ )
+ .strip()
+ .decode("ascii")
+ .endswith(dirty_string)
+ )
git_dirty = "dirty" if is_dirty else ""
- except subprocess.CalledProcessError:
+ except (subprocess.CalledProcessError, FileNotFoundError):
git_dirty = ""
if git_branch or git_tag or git_commit or git_dirty:
git_version = ",".join(
- s for s in
- (git_branch, git_tag, git_commit, git_dirty,)
- if s
+ s for s in (git_branch, git_tag, git_commit, git_dirty) if s
)
- return (
- "%s (%s)" % (
- module.__version__, git_version,
- )
- )
+ version_string = "%s (%s)" % (module.__version__, git_version)
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
- return module.__version__
+ module._synapse_version_string_cache = version_string
+
+ return version_string
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index 7a9e45aca9..9bf6a44f75 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -69,9 +69,7 @@ class WheelTimer(object):
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
- self.entries.extend(
- _Entry(key) for key in range(last_key, then_key + 1)
- )
+ self.entries.extend(_Entry(key) for key in range(last_key, then_key + 1))
self.entries[-1].queue.append(obj)
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 09d8334b26..bab41182b9 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -23,18 +23,14 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event
+from synapse.storage import Storage
from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__)
-VISIBILITY_PRIORITY = (
- "world_readable",
- "shared",
- "invited",
- "joined",
-)
+VISIBILITY_PRIORITY = ("world_readable", "shared", "invited", "joined")
MEMBERSHIP_PRIORITY = (
@@ -47,15 +43,20 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks
-def filter_events_for_client(store, user_id, events, is_peeking=False,
- always_include_ids=frozenset(),
- apply_retention_policies=True):
+def filter_events_for_client(
+ storage: Storage,
+ user_id,
+ events,
+ is_peeking=False,
+ always_include_ids=frozenset(),
+ filter_send_to_client=True,
+):
"""
- Check which events a user is allowed to see
+ Check which events a user is allowed to see. If the user can see the event but its
+ sender asked for their data to be erased, prune the content of the event.
Args:
- store (synapse.storage.DataStore): our datastore (can also be a worker
- store)
+ storage
user_id(str): user id to be checked
events(list[synapse.events.EventBase]): sequence of events to be checked
is_peeking(bool): should be True if:
@@ -64,47 +65,44 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
events
always_include_ids (set(event_id)): set of event ids to specifically
include (unless sender is ignored)
- apply_retention_policies (bool): Whether to filter out events that's older than
- allowed by the room's retention policy. Useful when this function is called
- to e.g. check whether a user should be allowed to see the state at a given
- event rather than to know if it should send an event to a user's client(s).
+ filter_send_to_client (bool): Whether we're checking an event that's going to be
+ sent to a client. This might not always be the case since this function can
+ also be called to check whether a user can see the state at a given point.
Returns:
Deferred[list[synapse.events.EventBase]]
"""
# Filter out events that have been soft failed so that we don't relay them
# to clients.
- events = list(e for e in events if not e.internal_metadata.is_soft_failed())
+ events = [e for e in events if not e.internal_metadata.is_soft_failed()]
- types = (
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, user_id),
- )
- event_id_to_state = yield store.get_state_for_events(
+ types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
+ event_id_to_state = yield storage.state.get_state_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types),
)
- ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
- "m.ignored_user_list", user_id,
+ ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
+ "m.ignored_user_list", user_id
)
# FIXME: This will explode if people upload something incorrect.
ignore_list = frozenset(
ignore_dict_content.get("ignored_users", {}).keys()
- if ignore_dict_content else []
+ if ignore_dict_content
+ else []
)
- erased_senders = yield store.are_users_erased((e.sender for e in events))
+ erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
- if apply_retention_policies:
- room_ids = set(e.room_id for e in events)
+ if filter_send_to_client:
+ room_ids = {e.room_id for e in events}
retention_policies = {}
for room_id in room_ids:
- retention_policies[room_id] = (
- yield store.get_retention_policy_for_room(room_id)
- )
+ retention_policies[
+ room_id
+ ] = yield storage.main.get_retention_policy_for_room(room_id)
def allowed(event):
"""
@@ -120,20 +118,36 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
the original event if they can see it as normal.
"""
- if not event.is_state() and event.sender in ignore_list:
- return None
-
- # Don't try to apply the room's retention policy if the event is a state event, as
- # MSC1763 states that retention is only considered for non-state events.
- if apply_retention_policies and not event.is_state():
- retention_policy = retention_policies[event.room_id]
- max_lifetime = retention_policy.get("max_lifetime")
-
- if max_lifetime is not None:
- oldest_allowed_ts = store.clock.time_msec() - max_lifetime
-
- if event.origin_server_ts < oldest_allowed_ts:
- return None
+ # Only run some checks if these events aren't about to be sent to clients. This is
+ # because, if this is not the case, we're probably only checking if the users can
+ # see events in the room at that point in the DAG, and that shouldn't be decided
+ # on those checks.
+ if filter_send_to_client:
+ if event.type == "org.matrix.dummy_event":
+ return None
+
+ if not event.is_state() and event.sender in ignore_list:
+ return None
+
+ # Until MSC2261 has landed we can't redact malicious alias events, so for
+ # now we temporarily filter out m.room.aliases entirely to mitigate
+ # abuse, while we spec a better solution to advertising aliases
+ # on rooms.
+ if event.type == EventTypes.Aliases:
+ return None
+
+ # Don't try to apply the room's retention policy if the event is a state
+ # event, as MSC1763 states that retention is only considered for non-state
+ # events.
+ if not event.is_state():
+ retention_policy = retention_policies[event.room_id]
+ max_lifetime = retention_policy.get("max_lifetime")
+
+ if max_lifetime is not None:
+ oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
+
+ if event.origin_server_ts < oldest_allowed_ts:
+ return None
if event.event_id in always_include_ids:
return event
@@ -211,9 +225,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
elif visibility == "invited":
# user can also see the event if they were *invited* at the time
# of the event.
- return (
- event if membership == Membership.INVITE else None
- )
+ return event if membership == Membership.INVITE else None
elif visibility == "shared" and is_peeking:
# if the visibility is shared, users cannot see the event unless
@@ -242,17 +254,22 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
filtered_events = filter(operator.truth, filtered_events)
# we turn it into a list before returning it.
- defer.returnValue(list(filtered_events))
+ return list(filtered_events)
@defer.inlineCallbacks
-def filter_events_for_server(store, server_name, events, redact=True,
- check_history_visibility_only=False):
+def filter_events_for_server(
+ storage: Storage,
+ server_name,
+ events,
+ redact=True,
+ check_history_visibility_only=False,
+):
"""Filter a list of events based on whether given server is allowed to
see them.
Args:
- store (DataStore)
+ storage
server_name (str)
events (iterable[FrozenEvent])
redact (bool): Whether to return a redacted version of the event, or
@@ -268,15 +285,12 @@ def filter_events_for_server(store, server_name, events, redact=True,
def is_sender_erased(event, erased_senders):
if erased_senders and erased_senders[event.sender]:
- logger.info(
- "Sender of %s has been erased, redacting",
- event.event_id,
- )
+ logger.info("Sender of %s has been erased, redacting", event.event_id)
return True
return False
def check_event_is_visible(event, state):
- history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
+ history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
if history:
visibility = history.content.get("history_visibility", "shared")
if visibility in ["invited", "joined"]:
@@ -310,11 +324,11 @@ def filter_events_for_server(store, server_name, events, redact=True,
# Lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If thats the case then we don't
# need to check membership (as we know the server is in the room).
- event_to_state_ids = yield store.get_state_ids_for_events(
+ event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
- types=((EventTypes.RoomHistoryVisibility, ""),),
- )
+ types=((EventTypes.RoomHistoryVisibility, ""),)
+ ),
)
visibility_ids = set()
@@ -328,16 +342,14 @@ def filter_events_for_server(store, server_name, events, redact=True,
if not visibility_ids:
all_open = True
else:
- event_map = yield store.get_events(visibility_ids)
+ event_map = yield storage.main.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in itervalues(event_map)
)
if not check_history_visibility_only:
- erased_senders = yield store.are_users_erased(
- (e.sender for e in events),
- )
+ erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
else:
# We don't want to check whether users are erased, which is equivalent
# to no users having been erased.
@@ -355,25 +367,22 @@ def filter_events_for_server(store, server_name, events, redact=True,
elif redact:
to_return.append(prune_event(e))
- defer.returnValue(to_return)
+ return to_return
# If there are no erased users then we can just return the given list
# of events without having to copy it.
- defer.returnValue(events)
+ return events
# Ok, so we're dealing with events that have non-trivial visibility
# rules, so we need to also get the memberships of the room.
# first, for each event we're wanting to return, get the event_ids
# of the history vis and membership state at those events.
- event_to_state_ids = yield store.get_state_ids_for_events(
+ event_to_state_ids = yield storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, None),
- ),
- )
+ types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
+ ),
)
# We only want to pull out member events that correspond to the
@@ -397,13 +406,15 @@ def filter_events_for_server(store, server_name, events, redact=True,
idx = state_key.find(":")
if idx == -1:
return False
- return state_key[idx + 1:] == server_name
-
- event_map = yield store.get_events([
- e_id
- for e_id, key in iteritems(event_id_to_state_key)
- if include(key[0], key[1])
- ])
+ return state_key[idx + 1 :] == server_name
+
+ event_map = yield storage.main.get_events(
+ [
+ e_id
+ for e_id, key in iteritems(event_id_to_state_key)
+ if include(key[0], key[1])
+ ]
+ )
event_to_state = {
e_id: {
@@ -423,4 +434,4 @@ def filter_events_for_server(store, server_name, events, redact=True,
elif redact:
to_return.append(prune_event(e))
- defer.returnValue(to_return)
+ return to_return
|