diff --git a/synapse/__init__.py b/synapse/__init__.py
index ec16f54a49..f99de2f3f3 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-""" This is a reference implementation of a Matrix home server.
+""" This is a reference implementation of a Matrix homeserver.
"""
import os
@@ -36,7 +36,7 @@ try:
except ImportError:
pass
-__version__ = "1.5.1"
+__version__ = "1.6.1"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index bdcd915bbe..d528450c78 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -144,8 +144,8 @@ def main():
logging.captureWarnings(True)
parser = argparse.ArgumentParser(
- description="Used to register new users with a given home server when"
- " registration has been disabled. The home server must be"
+ description="Used to register new users with a given homeserver when"
+ " registration has been disabled. The homeserver must be"
" configured with the 'registration_shared_secret' option"
" set."
)
@@ -202,7 +202,7 @@ def main():
"server_url",
default="https://localhost:8448",
nargs="?",
- help="URL to use to talk to the home server. Defaults to "
+ help="URL to use to talk to the homeserver. Defaults to "
" 'https://localhost:8448'.",
)
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 49c4b85054..0ade47e624 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -94,6 +95,8 @@ class EventTypes(object):
ServerACL = "m.room.server_acl"
Pinned = "m.room.pinned_events"
+ Retention = "m.room.retention"
+
class RejectedReason(object):
AUTH_ERROR = "auth_error"
@@ -145,3 +148,7 @@ class EventContentFields(object):
# Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
LABELS = "org.matrix.labels"
+
+ # Timestamp to delete the event after
+ # cf https://github.com/matrix-org/matrix-doc/pull/2228
+ SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cca92c34ba..5853a54c95 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -457,7 +457,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
class FederationError(RuntimeError):
- """ This class is used to inform remote home servers about erroneous
+ """ This class is used to inform remote homeservers about erroneous
PDUs they sent us.
FATAL: The remote server could not interpret the source event.
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index bec13f08d8..6eab1f13f0 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 139221ad34..448e45e00f 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -69,7 +69,7 @@ class FederationSenderSlaveStore(
self.federation_out_pos_startup = self._get_federation_out_pos(db_conn)
def _get_federation_out_pos(self, db_conn):
- sql = "SELECT stream_id FROM federation_stream_position" " WHERE type = ?"
+ sql = "SELECT stream_id FROM federation_stream_position WHERE type = ?"
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 73e2c29d06..267aebaae9 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -542,8 +542,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
# Database version
#
- stats["database_engine"] = hs.get_datastore().database_engine_name
- stats["database_server_version"] = hs.get_datastore().get_server_version()
+ stats["database_engine"] = hs.database_engine.module.__name__
+ stats["database_server_version"] = hs.database_engine.server_version
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
yield hs.get_proxied_http_client().put_json(
@@ -585,7 +585,7 @@ def run(hs):
def performance_stats_init():
_stats_process.clear()
_stats_process.append(
- (int(hs.get_clock().time(), resource.getrusage(resource.RUSAGE_SELF)))
+ (int(hs.get_clock().time()), resource.getrusage(resource.RUSAGE_SELF))
)
def start_phone_stats_home():
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index b14da09f47..288ee64b42 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -151,7 +151,7 @@ class SynchrotronPresence(object):
def set_state(self, user, state, ignore_status_msg=False):
# TODO Hows this supposed to work?
- pass
+ return defer.succeed(None)
get_states = __func__(PresenceHandler.get_states)
get_state = __func__(PresenceHandler.get_state)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 6cb100319f..0fa2b50999 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -64,7 +64,7 @@ class UserDirectorySlaveStore(
super(UserDirectorySlaveStore, self).__init__(db_conn, hs)
events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+ curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 3e25bf5747..57174da021 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -185,7 +185,7 @@ class ApplicationServiceApi(SimpleHttpClient):
if not _is_valid_3pe_metadata(info):
logger.warning(
- "query_3pe_protocol to %s did not return a" " valid result", uri
+ "query_3pe_protocol to %s did not return a valid result", uri
)
return None
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index e77d3387ff..ca43e96bd1 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -134,7 +134,7 @@ def _load_appservice(hostname, as_info, config_filename):
for regex_obj in as_info["namespaces"][ns]:
if not isinstance(regex_obj, dict):
raise ValueError(
- "Expected namespace entry in %s to be an object," " but got %s",
+ "Expected namespace entry in %s to be an object, but got %s",
ns,
regex_obj,
)
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index 44bd5c6799..f0171bb5b2 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -35,11 +35,11 @@ class CaptchaConfig(Config):
## Captcha ##
# See docs/CAPTCHA_SETUP for full details of configuring this.
- # This Home Server's ReCAPTCHA public key.
+ # This homeserver's ReCAPTCHA public key.
#
#recaptcha_public_key: "YOUR_PUBLIC_KEY"
- # This Home Server's ReCAPTCHA private key.
+ # This homeserver's ReCAPTCHA private key.
#
#recaptcha_private_key: "YOUR_PRIVATE_KEY"
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 39e7a1dddb..18f42a87f9 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -146,6 +146,8 @@ class EmailConfig(Config):
if k not in email_config:
missing.append("email." + k)
+ # public_baseurl is required to build password reset and validation links that
+ # will be emailed to users
if config.get("public_baseurl") is None:
missing.append("public_baseurl")
@@ -305,8 +307,23 @@ class EmailConfig(Config):
# smtp_user: "exampleusername"
# smtp_pass: "examplepassword"
# require_transport_security: false
- # notif_from: "Your Friendly %(app)s Home Server <noreply@example.com>"
- # app_name: Matrix
+ #
+ # # notif_from defines the "From" address to use when sending emails.
+ # # It must be set if email sending is enabled.
+ # #
+ # # The placeholder '%(app)s' will be replaced by the application name,
+ # # which is normally 'app_name' (below), but may be overridden by the
+ # # Matrix client application.
+ # #
+ # # Note that the placeholder must be written '%(app)s', including the
+ # # trailing 's'.
+ # #
+ # notif_from: "Your Friendly %(app)s homeserver <noreply@example.com>"
+ #
+ # # app_name defines the default value for '%(app)s' in notif_from. It
+ # # defaults to 'Matrix'.
+ # #
+ # #app_name: my_branded_matrix_server
#
# # Enable email notifications by default
# #
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 1f6dac69da..ee9614c5f7 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -106,6 +106,13 @@ class RegistrationConfig(Config):
account_threepid_delegates = config.get("account_threepid_delegates") or {}
self.account_threepid_delegate_email = account_threepid_delegates.get("email")
self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn")
+ if self.account_threepid_delegate_msisdn and not self.public_baseurl:
+ raise ConfigError(
+ "The configuration option `public_baseurl` is required if "
+ "`account_threepid_delegate.msisdn` is set, such that "
+ "clients know where to submit validation tokens to. Please "
+ "configure `public_baseurl`."
+ )
self.default_identity_server = config.get("default_identity_server")
self.allow_guest_access = config.get("allow_guest_access", False)
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 7c9f05bde4..7ac7699676 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -170,7 +170,7 @@ class _RoomDirectoryRule(object):
self.action = action
else:
raise ConfigError(
- "%s rules can only have action of 'allow'" " or 'deny'" % (option_name,)
+ "%s rules can only have action of 'allow' or 'deny'" % (option_name,)
)
self._alias_matches_all = alias == "*"
diff --git a/synapse/config/server.py b/synapse/config/server.py
index d556df308d..a4bef00936 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -19,7 +19,7 @@ import logging
import os.path
import re
from textwrap import indent
-from typing import List
+from typing import Dict, List, Optional
import attr
import yaml
@@ -41,7 +41,7 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
-DEFAULT_ROOM_VERSION = "4"
+DEFAULT_ROOM_VERSION = "5"
ROOM_COMPLEXITY_TOO_GREAT = (
"Your homeserver is unable to join rooms this large or complex. "
@@ -118,15 +118,16 @@ class ServerConfig(Config):
self.allow_public_rooms_without_auth = False
self.allow_public_rooms_over_federation = False
else:
- # If set to 'False', requires authentication to access the server's public
- # rooms directory through the client API. Defaults to 'True'.
+ # If set to 'true', removes the need for authentication to access the server's
+ # public rooms directory through the client API, meaning that anyone can
+ # query the room directory. Defaults to 'false'.
self.allow_public_rooms_without_auth = config.get(
- "allow_public_rooms_without_auth", True
+ "allow_public_rooms_without_auth", False
)
- # If set to 'False', forbids any other homeserver to fetch the server's public
- # rooms directory via federation. Defaults to 'True'.
+ # If set to 'true', allows any other homeserver to fetch the server's public
+ # rooms directory via federation. Defaults to 'false'.
self.allow_public_rooms_over_federation = config.get(
- "allow_public_rooms_over_federation", True
+ "allow_public_rooms_over_federation", False
)
default_room_version = config.get("default_room_version", DEFAULT_ROOM_VERSION)
@@ -223,7 +224,7 @@ class ServerConfig(Config):
self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
except Exception as e:
raise ConfigError(
- "Invalid range(s) provided in " "federation_ip_range_blacklist: %s" % e
+ "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
)
if self.public_baseurl is not None:
@@ -246,6 +247,124 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+ retention_config = config.get("retention")
+ if retention_config is None:
+ retention_config = {}
+
+ self.retention_enabled = retention_config.get("enabled", False)
+
+ retention_default_policy = retention_config.get("default_policy")
+
+ if retention_default_policy is not None:
+ self.retention_default_min_lifetime = retention_default_policy.get(
+ "min_lifetime"
+ )
+ if self.retention_default_min_lifetime is not None:
+ self.retention_default_min_lifetime = self.parse_duration(
+ self.retention_default_min_lifetime
+ )
+
+ self.retention_default_max_lifetime = retention_default_policy.get(
+ "max_lifetime"
+ )
+ if self.retention_default_max_lifetime is not None:
+ self.retention_default_max_lifetime = self.parse_duration(
+ self.retention_default_max_lifetime
+ )
+
+ if (
+ self.retention_default_min_lifetime is not None
+ and self.retention_default_max_lifetime is not None
+ and (
+ self.retention_default_min_lifetime
+ > self.retention_default_max_lifetime
+ )
+ ):
+ raise ConfigError(
+ "The default retention policy's 'min_lifetime' can not be greater"
+ " than its 'max_lifetime'"
+ )
+ else:
+ self.retention_default_min_lifetime = None
+ self.retention_default_max_lifetime = None
+
+ self.retention_allowed_lifetime_min = retention_config.get(
+ "allowed_lifetime_min"
+ )
+ if self.retention_allowed_lifetime_min is not None:
+ self.retention_allowed_lifetime_min = self.parse_duration(
+ self.retention_allowed_lifetime_min
+ )
+
+ self.retention_allowed_lifetime_max = retention_config.get(
+ "allowed_lifetime_max"
+ )
+ if self.retention_allowed_lifetime_max is not None:
+ self.retention_allowed_lifetime_max = self.parse_duration(
+ self.retention_allowed_lifetime_max
+ )
+
+ if (
+ self.retention_allowed_lifetime_min is not None
+ and self.retention_allowed_lifetime_max is not None
+ and self.retention_allowed_lifetime_min
+ > self.retention_allowed_lifetime_max
+ ):
+ raise ConfigError(
+ "Invalid retention policy limits: 'allowed_lifetime_min' can not be"
+ " greater than 'allowed_lifetime_max'"
+ )
+
+ self.retention_purge_jobs = [] # type: List[Dict[str, Optional[int]]]
+ for purge_job_config in retention_config.get("purge_jobs", []):
+ interval_config = purge_job_config.get("interval")
+
+ if interval_config is None:
+ raise ConfigError(
+ "A retention policy's purge jobs configuration must have the"
+ " 'interval' key set."
+ )
+
+ interval = self.parse_duration(interval_config)
+
+ shortest_max_lifetime = purge_job_config.get("shortest_max_lifetime")
+
+ if shortest_max_lifetime is not None:
+ shortest_max_lifetime = self.parse_duration(shortest_max_lifetime)
+
+ longest_max_lifetime = purge_job_config.get("longest_max_lifetime")
+
+ if longest_max_lifetime is not None:
+ longest_max_lifetime = self.parse_duration(longest_max_lifetime)
+
+ if (
+ shortest_max_lifetime is not None
+ and longest_max_lifetime is not None
+ and shortest_max_lifetime > longest_max_lifetime
+ ):
+ raise ConfigError(
+ "A retention policy's purge jobs configuration's"
+ " 'shortest_max_lifetime' value can not be greater than its"
+ " 'longest_max_lifetime' value."
+ )
+
+ self.retention_purge_jobs.append(
+ {
+ "interval": interval,
+ "shortest_max_lifetime": shortest_max_lifetime,
+ "longest_max_lifetime": longest_max_lifetime,
+ }
+ )
+
+ if not self.retention_purge_jobs:
+ self.retention_purge_jobs = [
+ {
+ "interval": self.parse_duration("1d"),
+ "shortest_max_lifetime": None,
+ "longest_max_lifetime": None,
+ }
+ ]
+
self.listeners = [] # type: List[dict]
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
@@ -372,6 +491,8 @@ class ServerConfig(Config):
"cleanup_extremities_with_dummy_events", True
)
+ self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False)
+
def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners)
@@ -500,15 +621,16 @@ class ServerConfig(Config):
#
#require_auth_for_profile_requests: true
- # If set to 'false', requires authentication to access the server's public rooms
- # directory through the client API. Defaults to 'true'.
+ # If set to 'true', removes the need for authentication to access the server's
+ # public rooms directory through the client API, meaning that anyone can
+ # query the room directory. Defaults to 'false'.
#
- #allow_public_rooms_without_auth: false
+ #allow_public_rooms_without_auth: true
- # If set to 'false', forbids any other homeserver to fetch the server's public
- # rooms directory via federation. Defaults to 'true'.
+ # If set to 'true', allows any other homeserver to fetch the server's public
+ # rooms directory via federation. Defaults to 'false'.
#
- #allow_public_rooms_over_federation: false
+ #allow_public_rooms_over_federation: true
# The default room version for newly created rooms.
#
@@ -721,7 +843,7 @@ class ServerConfig(Config):
# Used by phonehome stats to group together related servers.
#server_context: context
- # Resource-constrained Homeserver Settings
+ # Resource-constrained homeserver Settings
#
# If limit_remote_rooms.enabled is True, the room complexity will be
# checked before a user joins a new remote room. If it is above
@@ -761,6 +883,69 @@ class ServerConfig(Config):
# Defaults to `28d`. Set to `null` to disable clearing out of old rows.
#
#user_ips_max_age: 14d
+
+ # Message retention policy at the server level.
+ #
+ # Room admins and mods can define a retention period for their rooms using the
+ # 'm.room.retention' state event, and server admins can cap this period by setting
+ # the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
+ #
+ # If this feature is enabled, Synapse will regularly look for and purge events
+ # which are older than the room's maximum retention period. Synapse will also
+ # filter events received over federation so that events that should have been
+ # purged are ignored and not stored again.
+ #
+ retention:
+ # The message retention policies feature is disabled by default. Uncomment the
+ # following line to enable it.
+ #
+ #enabled: true
+
+ # Default retention policy. If set, Synapse will apply it to rooms that lack the
+ # 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
+ # matter much because Synapse doesn't take it into account yet.
+ #
+ #default_policy:
+ # min_lifetime: 1d
+ # max_lifetime: 1y
+
+ # Retention policy limits. If set, a user won't be able to send a
+ # 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
+ # that's not within this range. This is especially useful in closed federations,
+ # in which server admins can make sure every federating server applies the same
+ # rules.
+ #
+ #allowed_lifetime_min: 1d
+ #allowed_lifetime_max: 1y
+
+ # Server admins can define the settings of the background jobs purging the
+ # events which lifetime has expired under the 'purge_jobs' section.
+ #
+ # If no configuration is provided, a single job will be set up to delete expired
+ # events in every room daily.
+ #
+ # Each job's configuration defines which range of message lifetimes the job
+ # takes care of. For example, if 'shortest_max_lifetime' is '2d' and
+ # 'longest_max_lifetime' is '3d', the job will handle purging expired events in
+ # rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
+ # lower than or equal to 3 days. Both the minimum and the maximum value of a
+ # range are optional, e.g. a job with no 'shortest_max_lifetime' and a
+ # 'longest_max_lifetime' of '3d' will handle every room with a retention policy
+ # which 'max_lifetime' is lower than or equal to three days.
+ #
+ # The rationale for this per-job configuration is that some rooms might have a
+ # retention policy with a low 'max_lifetime', where history needs to be purged
+ # of outdated messages on a very frequent basis (e.g. every 5min), but not want
+ # that purge to be performed by a job that's iterating over every room it knows,
+ # which would be quite heavy on the server.
+ #
+ #purge_jobs:
+ # - shortest_max_lifetime: 1d
+ # longest_max_lifetime: 3d
+ # interval: 5m:
+ # - shortest_max_lifetime: 3d
+ # longest_max_lifetime: 1y
+ # interval: 24h
"""
% locals()
)
@@ -781,20 +966,20 @@ class ServerConfig(Config):
"--daemonize",
action="store_true",
default=None,
- help="Daemonize the home server",
+ help="Daemonize the homeserver",
)
server_group.add_argument(
"--print-pidfile",
action="store_true",
default=None,
- help="Print the path to the pidfile just" " before daemonizing",
+ help="Print the path to the pidfile just before daemonizing",
)
server_group.add_argument(
"--manhole",
metavar="PORT",
dest="manhole",
type=int,
- help="Turn on the twisted telnet manhole" " service on the given port.",
+ help="Turn on the twisted telnet manhole service on the given port.",
)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 272426e105..9b90c9ce04 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from six import string_types
+from six import integer_types, string_types
from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@@ -22,11 +22,12 @@ from synapse.types import EventID, RoomID, UserID
class EventValidator(object):
- def validate_new(self, event):
+ def validate_new(self, event, config):
"""Validates the event has roughly the right format
Args:
- event (FrozenEvent)
+ event (FrozenEvent): The event to validate.
+ config (Config): The homeserver's configuration.
"""
self.validate_builder(event)
@@ -67,6 +68,99 @@ class EventValidator(object):
Codes.INVALID_PARAM,
)
+ if event.type == EventTypes.Retention:
+ self._validate_retention(event, config)
+
+ def _validate_retention(self, event, config):
+ """Checks that an event that defines the retention policy for a room respects the
+ boundaries imposed by the server's administrator.
+
+ Args:
+ event (FrozenEvent): The event to validate.
+ config (Config): The homeserver's configuration.
+ """
+ min_lifetime = event.content.get("min_lifetime")
+ max_lifetime = event.content.get("max_lifetime")
+
+ if min_lifetime is not None:
+ if not isinstance(min_lifetime, integer_types):
+ raise SynapseError(
+ code=400,
+ msg="'min_lifetime' must be an integer",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_min is not None
+ and min_lifetime < config.retention_allowed_lifetime_min
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'min_lifetime' can't be lower than the minimum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_max is not None
+ and min_lifetime > config.retention_allowed_lifetime_max
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'min_lifetime' can't be greater than the maximum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if max_lifetime is not None:
+ if not isinstance(max_lifetime, integer_types):
+ raise SynapseError(
+ code=400,
+ msg="'max_lifetime' must be an integer",
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_min is not None
+ and max_lifetime < config.retention_allowed_lifetime_min
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'max_lifetime' can't be lower than the minimum allowed value"
+ " enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ config.retention_allowed_lifetime_max is not None
+ and max_lifetime > config.retention_allowed_lifetime_max
+ ):
+ raise SynapseError(
+ code=400,
+ msg=(
+ "'max_lifetime' can't be greater than the maximum allowed"
+ " value enforced by the server's administrator"
+ ),
+ errcode=Codes.BAD_JSON,
+ )
+
+ if (
+ min_lifetime is not None
+ and max_lifetime is not None
+ and min_lifetime > max_lifetime
+ ):
+ raise SynapseError(
+ code=400,
+ msg="'min_lifetime' can't be greater than 'max_lifetime",
+ errcode=Codes.BAD_JSON,
+ )
+
def validate_builder(self, event):
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 23c08104b3..890a201a5a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -177,7 +177,7 @@ class FederationClient(FederationBase):
given destination server.
Args:
- dest (str): The remote home server to ask.
+ dest (str): The remote homeserver to ask.
room_id (str): The room_id to backfill.
limit (int): The maximum number of PDUs to return.
extremities (list): List of PDU id and origins of the first pdus
@@ -227,7 +227,7 @@ class FederationClient(FederationBase):
one succeeds.
Args:
- destinations (list): Which home servers to query
+ destinations (list): Which homeservers to query
event_id (str): event to fetch
room_version (str): version of the room
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
@@ -312,7 +312,7 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
@log_function
def get_state_for_room(self, destination, room_id, event_id):
- """Requests all of the room state at a given event from a remote home server.
+ """Requests all of the room state at a given event from a remote homeserver.
Args:
destination (str): The remote homeserver to query for the state.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 08a913e08a..d7ce333822 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -73,6 +74,7 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth()
self.handler = hs.get_handlers().federation_handler
+ self.state = hs.get_state_handler()
self._server_linearizer = Linearizer("fed_server")
self._transaction_linearizer = Linearizer("fed_txn_handler")
@@ -264,9 +266,6 @@ class FederationServer(FederationBase):
await self.registry.on_edu(edu_type, origin, content)
async def on_context_state_request(self, origin, room_id, event_id):
- if not event_id:
- raise NotImplementedError("Specify an event")
-
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)
@@ -280,13 +279,18 @@ class FederationServer(FederationBase):
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
with (await self._server_linearizer.queue((origin, room_id))):
- resp = await self._state_resp_cache.wrap(
- (room_id, event_id),
- self._on_context_state_request_compute,
- room_id,
- event_id,
+ resp = dict(
+ await self._state_resp_cache.wrap(
+ (room_id, event_id),
+ self._on_context_state_request_compute,
+ room_id,
+ event_id,
+ )
)
+ room_version = await self.store.get_room_version(room_id)
+ resp["room_version"] = room_version
+
return 200, resp
async def on_state_ids_request(self, origin, room_id, event_id):
@@ -306,7 +310,11 @@ class FederationServer(FederationBase):
return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
async def _on_context_state_request_compute(self, room_id, event_id):
- pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ if event_id:
+ pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ else:
+ pdus = (await self.state.get_current_state(room_id)).values()
+
auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus])
return {
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 44edcabed4..d68b4bd670 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -44,7 +44,7 @@ class TransactionActions(object):
response code and response body.
"""
if not transaction.transaction_id:
- raise RuntimeError("Cannot persist a transaction with no " "transaction_id")
+ raise RuntimeError("Cannot persist a transaction with no transaction_id")
return self.store.get_received_txn_response(transaction.transaction_id, origin)
@@ -56,7 +56,7 @@ class TransactionActions(object):
Deferred
"""
if not transaction.transaction_id:
- raise RuntimeError("Cannot persist a transaction with no " "transaction_id")
+ raise RuntimeError("Cannot persist a transaction with no transaction_id")
return self.store.set_received_txn_response(
transaction.transaction_id, origin, code, response
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 2b2ee8612a..4ebb0e8bc0 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -49,7 +49,7 @@ sent_pdus_destination_dist_count = Counter(
sent_pdus_destination_dist_total = Counter(
"synapse_federation_client_sent_pdu_destinations:total",
- "" "Total number of PDUs queued for sending across all destinations",
+ "Total number of PDUs queued for sending across all destinations",
)
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 67b3e1ab6e..5fed626d5b 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -84,7 +84,7 @@ class TransactionManager(object):
txn_id = str(self._next_txn_id)
logger.debug(
- "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)",
+ "TX [%s] {%s} Attempting new transaction (pdus: %d, edus: %d)",
destination,
txn_id,
len(pdus),
@@ -103,7 +103,7 @@ class TransactionManager(object):
self._next_txn_id += 1
logger.info(
- "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)",
+ "TX [%s] {%s} Sending transaction [%s], (PDUs: %d, EDUs: %d)",
destination,
txn_id,
transaction.transaction_id,
diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py
index d9fcc520a0..5db733af98 100644
--- a/synapse/federation/transport/__init__.py
+++ b/synapse/federation/transport/__init__.py
@@ -14,9 +14,9 @@
# limitations under the License.
"""The transport layer is responsible for both sending transactions to remote
-home servers and receiving a variety of requests from other home servers.
+homeservers and receiving a variety of requests from other homeservers.
-By default this is done over HTTPS (and all home servers are required to
+By default this is done over HTTPS (and all homeservers are required to
support HTTPS), however individual pairings of servers may decide to
communicate over a different (albeit still reliable) protocol.
"""
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index df2b5dc91b..bd7fb81995 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -44,7 +44,7 @@ class TransportLayerClient(object):
given event.
Args:
- destination (str): The host name of the remote home server we want
+ destination (str): The host name of the remote homeserver we want
to get the state from.
context (str): The name of the context we want the state of
event_id (str): The event we want the context at.
@@ -68,7 +68,7 @@ class TransportLayerClient(object):
given event. Returns the state's event_id's
Args:
- destination (str): The host name of the remote home server we want
+ destination (str): The host name of the remote homeserver we want
to get the state from.
context (str): The name of the context we want the state of
event_id (str): The event we want the context at.
@@ -91,7 +91,7 @@ class TransportLayerClient(object):
""" Requests the pdu with give id and origin from the given server.
Args:
- destination (str): The host name of the remote home server we want
+ destination (str): The host name of the remote homeserver we want
to get the state from.
event_id (str): The id of the event being requested.
timeout (int): How long to try (in ms) the destination for before
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 551a162eb7..b4cbf23394 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -421,7 +421,7 @@ class FederationEventServlet(BaseFederationServlet):
return await self.handler.on_pdu_request(origin, event_id)
-class FederationStateServlet(BaseFederationServlet):
+class FederationStateV1Servlet(BaseFederationServlet):
PATH = "/state/(?P<context>[^/]*)/?"
# This is when someone asks for all data for a given context.
@@ -429,7 +429,7 @@ class FederationStateServlet(BaseFederationServlet):
return await self.handler.on_context_state_request(
origin,
context,
- parse_string_from_args(query, "event_id", None, required=True),
+ parse_string_from_args(query, "event_id", None, required=False),
)
@@ -736,7 +736,7 @@ class PublicRoomList(BaseFederationServlet):
This API returns information in the same format as /publicRooms on the
client API, but will only ever include local public rooms and hence is
- intended for consumption by other home servers.
+ intended for consumption by other homeservers.
GET /publicRooms HTTP/1.1
@@ -1382,7 +1382,7 @@ class RoomComplexityServlet(BaseFederationServlet):
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationEventServlet,
- FederationStateServlet,
+ FederationStateV1Servlet,
FederationStateIdsServlet,
FederationBackfillServlet,
FederationQueryServlet,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 6407d56f8e..14449b9a1e 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -56,7 +56,7 @@ class AdminHandler(BaseHandler):
@defer.inlineCallbacks
def get_users(self):
- """Function to reterive a list of users in users table.
+ """Function to retrieve a list of users in users table.
Args:
Returns:
@@ -67,19 +67,22 @@ class AdminHandler(BaseHandler):
return ret
@defer.inlineCallbacks
- def get_users_paginate(self, order, start, limit):
- """Function to reterive a paginated list of users from
- users list. This will return a json object, which contains
- list of users and the total number of users in users table.
+ def get_users_paginate(self, start, limit, name, guests, deactivated):
+ """Function to retrieve a paginated list of users from
+ users list. This will return a json list of users.
Args:
- order (str): column name to order the select by this column
start (int): start number to begin the query from
- limit (int): number of rows to reterive
+ limit (int): number of rows to retrieve
+ name (string): filter for user names
+ guests (bool): whether to in include guest users
+ deactivated (bool): whether to include deactivated users
Returns:
- defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+ defer.Deferred: resolves to json list[dict[str, Any]]
"""
- ret = yield self.store.get_users_paginate(order, start, limit)
+ ret = yield self.store.get_users_paginate(
+ start, limit, name, guests, deactivated
+ )
return ret
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 7a0f54ca24..54a71c49d2 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -102,8 +102,9 @@ class AuthHandler(BaseHandler):
login_types.append(t)
self._supported_login_types = login_types
- self._account_ratelimiter = Ratelimiter()
- self._failed_attempts_ratelimiter = Ratelimiter()
+ # Ratelimiter for failed auth during UIA. Uses same ratelimit config
+ # as per `rc_login.failed_attempts`.
+ self._failed_uia_attempts_ratelimiter = Ratelimiter()
self._clock = self.hs.get_clock()
@@ -133,12 +134,38 @@ class AuthHandler(BaseHandler):
AuthError if the client has completed a login flow, and it gives
a different user to `requester`
+
+ LimitExceededError if the ratelimiter's failed request count for this
+ user is too high to proceed
+
"""
+ user_id = requester.user.to_string()
+
+ # Check if we should be ratelimited due to too many previous failed attempts
+ self._failed_uia_attempts_ratelimiter.ratelimit(
+ user_id,
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=False,
+ )
+
# build a list of supported flows
flows = [[login_type] for login_type in self._supported_login_types]
- result, params, _ = yield self.check_auth(flows, request_body, clientip)
+ try:
+ result, params, _ = yield self.check_auth(flows, request_body, clientip)
+ except LoginError:
+ # Update the ratelimite to say we failed (`can_do_action` doesn't raise).
+ self._failed_uia_attempts_ratelimiter.can_do_action(
+ user_id,
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=True,
+ )
+ raise
# find the completed login type
for login_type in self._supported_login_types:
@@ -223,7 +250,7 @@ class AuthHandler(BaseHandler):
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
- # on a home server.
+ # on a homeserver.
# Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary.
session["clientdict"] = clientdict
@@ -501,11 +528,8 @@ class AuthHandler(BaseHandler):
multiple matches
Raises:
- LimitExceededError if the ratelimiter's login requests count for this
- user is too high too proceed.
UserDeactivatedError if a user is found but is deactivated.
"""
- self.ratelimit_login_per_account(user_id)
res = yield self._find_user_id_and_pwd_hash(user_id)
if res is not None:
return res[0]
@@ -572,8 +596,6 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem accessing the database
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
- LimitExceededError if the ratelimiter's login requests count for this
- user is too high too proceed.
"""
if username.startswith("@"):
@@ -581,8 +603,6 @@ class AuthHandler(BaseHandler):
else:
qualified_user_id = UserID(username, self.hs.hostname).to_string()
- self.ratelimit_login_per_account(qualified_user_id)
-
login_type = login_submission.get("type")
known_login_type = False
@@ -650,15 +670,6 @@ class AuthHandler(BaseHandler):
if not known_login_type:
raise SynapseError(400, "Unknown login type %s" % login_type)
- # unknown username or invalid password.
- self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=True,
- )
-
# We raise a 403 here, but note that if we're doing user-interactive
# login, it turns all LoginErrors into a 401 anyway.
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@@ -710,10 +721,6 @@ class AuthHandler(BaseHandler):
Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
-
- Raises:
- LimitExceededError if the ratelimiter's login requests count for this
- user is too high too proceed.
"""
lookupres = yield self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
@@ -742,7 +749,7 @@ class AuthHandler(BaseHandler):
auth_api.validate_macaroon(macaroon, "login", user_id)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
- self.ratelimit_login_per_account(user_id)
+
yield self.auth.check_auth_blocking(user_id)
return user_id
@@ -810,7 +817,7 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
# 'Canonicalise' email addresses down to lower case.
- # We've now moving towards the Home Server being the entity that
+ # We've now moving towards the homeserver being the entity that
# is responsible for validating threepids used for resetting passwords
# on accounts, so in future Synapse will gain knowledge of specific
# types (mediums) of threepid. For now, we still use the existing
@@ -912,35 +919,6 @@ class AuthHandler(BaseHandler):
else:
return defer.succeed(False)
- def ratelimit_login_per_account(self, user_id):
- """Checks whether the process must be stopped because of ratelimiting.
-
- Checks against two ratelimiters: the generic one for login attempts per
- account and the one specific to failed attempts.
-
- Args:
- user_id (unicode): complete @user:id
-
- Raises:
- LimitExceededError if one of the ratelimiters' login requests count
- for this user is too high too proceed.
- """
- self._failed_attempts_ratelimiter.ratelimit(
- user_id.lower(),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- update=False,
- )
-
- self._account_ratelimiter.ratelimit(
- user_id.lower(),
- time_now_s=self._clock.time(),
- rate_hz=self.hs.config.rc_login_account.per_second,
- burst_count=self.hs.config.rc_login_account.burst_count,
- update=True,
- )
-
@attr.s
class MacaroonGenerator(object):
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 63267a0a4c..6dedaaff8d 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -95,6 +95,9 @@ class DeactivateAccountHandler(BaseHandler):
user_id, threepid["medium"], threepid["address"]
)
+ # Remove all 3PIDs this user has bound to the homeserver
+ yield self.store.user_delete_threepids(user_id)
+
# delete any devices belonging to the user, which will also
# delete corresponding access tokens.
yield self._device_handler.delete_all_devices_for_user(user_id)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index c4632f8984..a07d2f1a17 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -119,7 +119,7 @@ class DirectoryHandler(BaseHandler):
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400,
- "This application service has not reserved" " this kind of alias.",
+ "This application service has not reserved this kind of alias.",
errcode=Codes.EXCLUSIVE,
)
else:
@@ -283,7 +283,7 @@ class DirectoryHandler(BaseHandler):
def on_directory_query(self, args):
room_alias = RoomAlias.from_string(args["room_alias"])
if not self.hs.is_mine(room_alias):
- raise SynapseError(400, "Room Alias is not hosted on this Home Server")
+ raise SynapseError(400, "Room Alias is not hosted on this homeserver")
result = yield self.get_association_from_room_alias(room_alias)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index f09a0b73c8..28c12753c1 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -30,6 +30,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
+from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
from synapse.types import (
UserID,
get_domain_from_id,
@@ -53,6 +54,12 @@ class E2eKeysHandler(object):
self._edu_updater = SigningKeyEduUpdater(hs, self)
+ self._is_master = hs.config.worker_app is None
+ if not self._is_master:
+ self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
+ hs
+ )
+
federation_registry = hs.get_federation_registry()
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
@@ -191,9 +198,15 @@ class E2eKeysHandler(object):
# probably be tracking their device lists. However, we haven't
# done an initial sync on the device list so we do it now.
try:
- user_devices = yield self.device_handler.device_list_updater.user_device_resync(
- user_id
- )
+ if self._is_master:
+ user_devices = yield self.device_handler.device_list_updater.user_device_resync(
+ user_id
+ )
+ else:
+ user_devices = yield self._user_device_resync_client(
+ user_id=user_id
+ )
+
user_devices = user_devices["devices"]
for device in user_devices:
results[user_id] = {device["device_id"]: device["keys"]}
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 0cea445f0d..f1b4424a02 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -103,14 +104,35 @@ class E2eRoomKeysHandler(object):
rooms
session_id(string): session ID to delete keys for, for None to delete keys
for all sessions
+ Raises:
+ NotFoundError: if the backup version does not exist
Returns:
- A deferred of the deletion transaction
+ A dict containing the count and etag for the backup version
"""
# lock for consistency with uploading
with (yield self._upload_linearizer.queue(user_id)):
+ # make sure the backup version exists
+ try:
+ version_info = yield self.store.get_e2e_room_keys_version_info(
+ user_id, version
+ )
+ except StoreError as e:
+ if e.code == 404:
+ raise NotFoundError("Unknown backup version")
+ else:
+ raise
+
yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
+ version_etag = version_info["etag"] + 1
+ yield self.store.update_e2e_room_keys_version(
+ user_id, version, None, version_etag
+ )
+
+ count = yield self.store.count_e2e_room_keys(user_id, version)
+ return {"etag": str(version_etag), "count": count}
+
@trace
@defer.inlineCallbacks
def upload_room_keys(self, user_id, version, room_keys):
@@ -138,6 +160,9 @@ class E2eRoomKeysHandler(object):
}
}
+ Returns:
+ A dict containing the count and etag for the backup version
+
Raises:
NotFoundError: if there are no versions defined
RoomKeysVersionError: if the uploaded version is not the current version
@@ -171,59 +196,62 @@ class E2eRoomKeysHandler(object):
else:
raise
- # go through the room_keys.
- # XXX: this should/could be done concurrently, given we're in a lock.
+ # Fetch any existing room keys for the sessions that have been
+ # submitted. Then compare them with the submitted keys. If the
+ # key is new, insert it; if the key should be updated, then update
+ # it; otherwise, drop it.
+ existing_keys = yield self.store.get_e2e_room_keys_multi(
+ user_id, version, room_keys["rooms"]
+ )
+ to_insert = [] # batch the inserts together
+ changed = False # if anything has changed, we need to update the etag
for room_id, room in iteritems(room_keys["rooms"]):
- for session_id, session in iteritems(room["sessions"]):
- yield self._upload_room_key(
- user_id, version, room_id, session_id, session
+ for session_id, room_key in iteritems(room["sessions"]):
+ log_kv(
+ {
+ "message": "Trying to upload room key",
+ "room_id": room_id,
+ "session_id": session_id,
+ "user_id": user_id,
+ }
)
-
- @defer.inlineCallbacks
- def _upload_room_key(self, user_id, version, room_id, session_id, room_key):
- """Upload a given room_key for a given room and session into a given
- version of the backup. Merges the key with any which might already exist.
-
- Args:
- user_id(str): the user whose backup we're setting
- version(str): the version ID of the backup we're updating
- room_id(str): the ID of the room whose keys we're setting
- session_id(str): the session whose room_key we're setting
- room_key(dict): the room_key being set
- """
- log_kv(
- {
- "message": "Trying to upload room key",
- "room_id": room_id,
- "session_id": session_id,
- "user_id": user_id,
- }
- )
- # get the room_key for this particular row
- current_room_key = None
- try:
- current_room_key = yield self.store.get_e2e_room_key(
- user_id, version, room_id, session_id
- )
- except StoreError as e:
- if e.code == 404:
- log_kv(
- {
- "message": "Room key not found.",
- "room_id": room_id,
- "user_id": user_id,
- }
+ current_room_key = existing_keys.get(room_id, {}).get(session_id)
+ if current_room_key:
+ if self._should_replace_room_key(current_room_key, room_key):
+ log_kv({"message": "Replacing room key."})
+ # updates are done one at a time in the DB, so send
+ # updates right away rather than batching them up,
+ # like we do with the inserts
+ yield self.store.update_e2e_room_key(
+ user_id, version, room_id, session_id, room_key
+ )
+ changed = True
+ else:
+ log_kv({"message": "Not replacing room_key."})
+ else:
+ log_kv(
+ {
+ "message": "Room key not found.",
+ "room_id": room_id,
+ "user_id": user_id,
+ }
+ )
+ log_kv({"message": "Replacing room key."})
+ to_insert.append((room_id, session_id, room_key))
+ changed = True
+
+ if len(to_insert):
+ yield self.store.add_e2e_room_keys(user_id, version, to_insert)
+
+ version_etag = version_info["etag"]
+ if changed:
+ version_etag = version_etag + 1
+ yield self.store.update_e2e_room_keys_version(
+ user_id, version, None, version_etag
)
- else:
- raise
- if self._should_replace_room_key(current_room_key, room_key):
- log_kv({"message": "Replacing room key."})
- yield self.store.set_e2e_room_key(
- user_id, version, room_id, session_id, room_key
- )
- else:
- log_kv({"message": "Not replacing room_key."})
+ count = yield self.store.count_e2e_room_keys(user_id, version)
+ return {"etag": str(version_etag), "count": count}
@staticmethod
def _should_replace_room_key(current_room_key, room_key):
@@ -314,6 +342,8 @@ class E2eRoomKeysHandler(object):
raise NotFoundError("Unknown backup version")
else:
raise
+
+ res["count"] = yield self.store.count_e2e_room_keys(user_id, res["version"])
return res
@trace
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 05dd8d2671..bc26921768 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,11 +19,13 @@
import itertools
import logging
+from typing import Dict, Iterable, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
from six.moves import http_client, zip
+import attr
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from unpaddedbase64 import decode_base64
@@ -45,6 +47,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
from synapse.crypto.event_signing import compute_event_signature
from synapse.event_auth import auth_types_for_event
+from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.logging.context import (
@@ -72,6 +75,23 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
+@attr.s
+class _NewEventInfo:
+ """Holds information about a received event, ready for passing to _handle_new_events
+
+ Attributes:
+ event: the received event
+
+ state: the state at that event
+
+ auth_events: the auth_event map for that event
+ """
+
+ event = attr.ib(type=EventBase)
+ state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
+ auth_events = attr.ib(type=Optional[Dict[Tuple[str, str], EventBase]], default=None)
+
+
def shortstr(iterable, maxitems=5):
"""If iterable has maxitems or fewer, return the stringification of a list
containing those items.
@@ -97,9 +117,9 @@ class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
Responsible for:
a) handling received Pdus before handing them on as Events to the rest
- of the home server (including auth and state conflict resoultion)
+ of the homeserver (including auth and state conflict resoultion)
b) converting events that were produced by local clients that may need
- to be sent to remote home servers.
+ to be sent to remote homeservers.
c) doing the necessary dances to invite remote users and join remote
rooms.
"""
@@ -121,6 +141,7 @@ class FederationHandler(BaseHandler):
self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
+ self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
self.http_client = hs.get_simple_http_client()
@@ -141,6 +162,8 @@ class FederationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
+ self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
@defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
""" Process a PDU received via a federation /send/ transaction, or
@@ -594,14 +617,14 @@ class FederationHandler(BaseHandler):
for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
- event_infos.append({"event": e, "auth_events": auth})
+ event_infos.append(_NewEventInfo(event=e, auth_events=auth))
seen_ids.add(e.event_id)
logger.info(
"[%s %s] persisting newly-received auth/state events %s",
room_id,
event_id,
- [e["event"].event_id for e in event_infos],
+ [e.event.event_id for e in event_infos],
)
yield self._handle_new_events(origin, event_infos)
@@ -792,9 +815,9 @@ class FederationHandler(BaseHandler):
a.internal_metadata.outlier = True
ev_infos.append(
- {
- "event": a,
- "auth_events": {
+ _NewEventInfo(
+ event=a,
+ auth_events={
(
auth_events[a_id].type,
auth_events[a_id].state_key,
@@ -802,7 +825,7 @@ class FederationHandler(BaseHandler):
for a_id in a.auth_event_ids()
if a_id in auth_events
},
- }
+ )
)
# Step 1b: persist the events in the chunk we fetched state for (i.e.
@@ -814,10 +837,10 @@ class FederationHandler(BaseHandler):
assert not ev.internal_metadata.is_outlier()
ev_infos.append(
- {
- "event": ev,
- "state": events_to_state[e_id],
- "auth_events": {
+ _NewEventInfo(
+ event=ev,
+ state=events_to_state[e_id],
+ auth_events={
(
auth_events[a_id].type,
auth_events[a_id].state_key,
@@ -825,7 +848,7 @@ class FederationHandler(BaseHandler):
for a_id in ev.auth_event_ids()
if a_id in auth_events
},
- }
+ )
)
yield self._handle_new_events(dest, ev_infos, backfilled=True)
@@ -1428,9 +1451,9 @@ class FederationHandler(BaseHandler):
return event
@defer.inlineCallbacks
- def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
+ def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
origin, event, event_format_version = yield self._make_and_verify_event(
- target_hosts, room_id, user_id, "leave"
+ target_hosts, room_id, user_id, "leave", content=content,
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -1710,7 +1733,12 @@ class FederationHandler(BaseHandler):
return context
@defer.inlineCallbacks
- def _handle_new_events(self, origin, event_infos, backfilled=False):
+ def _handle_new_events(
+ self,
+ origin: str,
+ event_infos: Iterable[_NewEventInfo],
+ backfilled: bool = False,
+ ):
"""Creates the appropriate contexts and persists events. The events
should not depend on one another, e.g. this should be used to persist
a bunch of outliers, but not a chunk of individual events that depend
@@ -1720,14 +1748,14 @@ class FederationHandler(BaseHandler):
"""
@defer.inlineCallbacks
- def prep(ev_info):
- event = ev_info["event"]
+ def prep(ev_info: _NewEventInfo):
+ event = ev_info.event
with nested_logging_context(suffix=event.event_id):
res = yield self._prep_event(
origin,
event,
- state=ev_info.get("state"),
- auth_events=ev_info.get("auth_events"),
+ state=ev_info.state,
+ auth_events=ev_info.auth_events,
backfilled=backfilled,
)
return res
@@ -1741,7 +1769,7 @@ class FederationHandler(BaseHandler):
yield self.persist_events_and_notify(
[
- (ev_info["event"], context)
+ (ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
],
backfilled=backfilled,
@@ -1843,7 +1871,14 @@ class FederationHandler(BaseHandler):
yield self.persist_events_and_notify([(event, new_event_context)])
@defer.inlineCallbacks
- def _prep_event(self, origin, event, state, auth_events, backfilled):
+ def _prep_event(
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]],
+ auth_events: Optional[Dict[Tuple[str, str], EventBase]],
+ backfilled: bool,
+ ):
"""
Args:
@@ -1851,7 +1886,7 @@ class FederationHandler(BaseHandler):
event:
state:
auth_events:
- backfilled (bool)
+ backfilled:
Returns:
Deferred, which resolves to synapse.events.snapshot.EventContext
@@ -1887,15 +1922,16 @@ class FederationHandler(BaseHandler):
return context
@defer.inlineCallbacks
- def _check_for_soft_fail(self, event, state, backfilled):
+ def _check_for_soft_fail(
+ self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
+ ):
"""Checks if we should soft fail the event, if so marks the event as
such.
Args:
- event (FrozenEvent)
- state (dict|None): The state at the event if we don't have all the
- event's prev events
- backfilled (bool): Whether the event is from backfill
+ event
+ state: The state at the event if we don't have all the event's prev events
+ backfilled: Whether the event is from backfill
Returns:
Deferred
@@ -2040,8 +2076,10 @@ class FederationHandler(BaseHandler):
auth_events (dict[(str, str)->synapse.events.EventBase]):
Map from (event_type, state_key) to event
- What we expect the event's auth_events to be, based on the event's
- position in the dag. I think? maybe??
+ Normally, our calculated auth_events based on the state of the room
+ at the event's position in the DAG, though occasionally (eg if the
+ event is an outlier), may be the auth events claimed by the remote
+ server.
Also NB that this function adds entries to it.
Returns:
@@ -2091,35 +2129,35 @@ class FederationHandler(BaseHandler):
origin (str):
event (synapse.events.EventBase):
context (synapse.events.snapshot.EventContext):
+
auth_events (dict[(str, str)->synapse.events.EventBase]):
+ Map from (event_type, state_key) to event
+
+ Normally, our calculated auth_events based on the state of the room
+ at the event's position in the DAG, though occasionally (eg if the
+ event is an outlier), may be the auth events claimed by the remote
+ server.
+
+ Also NB that this function adds entries to it.
Returns:
defer.Deferred[EventContext]: updated context
"""
event_auth_events = set(event.auth_event_ids())
- if event.is_state():
- event_key = (event.type, event.state_key)
- else:
- event_key = None
-
- # if the event's auth_events refers to events which are not in our
- # calculated auth_events, we need to fetch those events from somewhere.
- #
- # we start by fetching them from the store, and then try calling /event_auth/.
+ # missing_auth is the set of the event's auth_events which we don't yet have
+ # in auth_events.
missing_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
+ # if we have missing events, we need to fetch those events from somewhere.
+ #
+ # we start by checking if they are in the store, and then try calling /event_auth/.
if missing_auth:
- # TODO: can we use store.have_seen_events here instead?
- have_events = yield self.store.get_seen_events_with_rejections(missing_auth)
- logger.debug("Got events %s from store", have_events)
- missing_auth.difference_update(have_events.keys())
- else:
- have_events = {}
-
- have_events.update({e.event_id: "" for e in auth_events.values()})
+ have_events = yield self.store.have_seen_events(missing_auth)
+ logger.debug("Events %s are in the store", have_events)
+ missing_auth.difference_update(have_events)
if missing_auth:
# If we don't have all the auth events, we need to get them.
@@ -2165,19 +2203,18 @@ class FederationHandler(BaseHandler):
except AuthError:
pass
- have_events = yield self.store.get_seen_events_with_rejections(
- event.auth_event_ids()
- )
except Exception:
- # FIXME:
logger.exception("Failed to get auth chain")
if event.internal_metadata.is_outlier():
+ # XXX: given that, for an outlier, we'll be working with the
+ # event's *claimed* auth events rather than those we calculated:
+ # (a) is there any point in this test, since different_auth below will
+ # obviously be empty
+ # (b) alternatively, why don't we do it earlier?
logger.info("Skipping auth_event fetch for outlier")
return context
- # FIXME: Assumes we have and stored all the state for all the
- # prev_events
different_auth = event_auth_events.difference(
e.event_id for e in auth_events.values()
)
@@ -2191,53 +2228,58 @@ class FederationHandler(BaseHandler):
different_auth,
)
- room_version = yield self.store.get_room_version(event.room_id)
+ # XXX: currently this checks for redactions but I'm not convinced that is
+ # necessary?
+ different_events = yield self.store.get_events_as_list(different_auth)
- different_events = yield make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.store.get_event, d, allow_none=True, allow_rejected=False
- )
- for d in different_auth
- if d in have_events and not have_events[d]
- ],
- consumeErrors=True,
- )
- ).addErrback(unwrapFirstError)
+ for d in different_events:
+ if d.room_id != event.room_id:
+ logger.warning(
+ "Event %s refers to auth_event %s which is in a different room",
+ event.event_id,
+ d.event_id,
+ )
- if different_events:
- local_view = dict(auth_events)
- remote_view = dict(auth_events)
- remote_view.update(
- {(d.type, d.state_key): d for d in different_events if d}
- )
+ # don't attempt to resolve the claimed auth events against our own
+ # in this case: just use our own auth events.
+ #
+ # XXX: should we reject the event in this case? It feels like we should,
+ # but then shouldn't we also do so if we've failed to fetch any of the
+ # auth events?
+ return context
- new_state = yield self.state_handler.resolve_events(
- room_version,
- [list(local_view.values()), list(remote_view.values())],
- event,
- )
+ # now we state-resolve between our own idea of the auth events, and the remote's
+ # idea of them.
- logger.info(
- "After state res: updating auth_events with new state %s",
- {
- (d.type, d.state_key): d.event_id
- for d in new_state.values()
- if auth_events.get((d.type, d.state_key)) != d
- },
- )
+ local_state = auth_events.values()
+ remote_auth_events = dict(auth_events)
+ remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
+ remote_state = remote_auth_events.values()
+
+ room_version = yield self.store.get_room_version(event.room_id)
+ new_state = yield self.state_handler.resolve_events(
+ room_version, (local_state, remote_state), event
+ )
+
+ logger.info(
+ "After state res: updating auth_events with new state %s",
+ {
+ (d.type, d.state_key): d.event_id
+ for d in new_state.values()
+ if auth_events.get((d.type, d.state_key)) != d
+ },
+ )
- auth_events.update(new_state)
+ auth_events.update(new_state)
- context = yield self._update_context_for_auth_events(
- event, context, auth_events, event_key
- )
+ context = yield self._update_context_for_auth_events(
+ event, context, auth_events
+ )
return context
@defer.inlineCallbacks
- def _update_context_for_auth_events(self, event, context, auth_events, event_key):
+ def _update_context_for_auth_events(self, event, context, auth_events):
"""Update the state_ids in an event context after auth event resolution,
storing the changes as a new state group.
@@ -2246,18 +2288,21 @@ class FederationHandler(BaseHandler):
context (synapse.events.snapshot.EventContext): initial event context
- auth_events (dict[(str, str)->str]): Events to update in the event
+ auth_events (dict[(str, str)->EventBase]): Events to update in the event
context.
- event_key ((str, str)): (type, state_key) for the current event.
- this will not be included in the current_state in the context.
-
Returns:
Deferred[EventContext]: new event context
"""
+ # exclude the state key of the new event from the current_state in the context.
+ if event.is_state():
+ event_key = (event.type, event.state_key)
+ else:
+ event_key = None
state_updates = {
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
}
+
current_state_ids = yield context.get_current_state_ids(self.store)
current_state_ids = dict(current_state_ids)
@@ -2459,7 +2504,7 @@ class FederationHandler(BaseHandler):
room_version, event_dict, event, context
)
- EventValidator().validate_new(event)
+ EventValidator().validate_new(event, self.config)
# We need to tell the transaction queue to send this out, even
# though the sender isn't a local user.
@@ -2574,7 +2619,7 @@ class FederationHandler(BaseHandler):
event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
- EventValidator().validate_new(event)
+ EventValidator().validate_new(event, self.config)
return (event, context)
@defer.inlineCallbacks
@@ -2708,6 +2753,11 @@ class FederationHandler(BaseHandler):
event_and_contexts, backfilled=backfilled
)
+ if self._ephemeral_messages_enabled:
+ for (event, context) in event_and_contexts:
+ # If there's an expiry timestamp on the event, schedule its expiry.
+ self._message_handler.maybe_schedule_expiry(event)
+
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
yield self._notify_persisted_event(event, max_stream_id)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d682dc2b7a..4f53a5f5dc 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Optional
from six import iteritems, itervalues, string_types
@@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
+from twisted.internet.interfaces import IDelayedCall
from synapse import event_auth
-from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ Membership,
+ RelationTypes,
+ UserTypes,
+)
from synapse.api.errors import (
AuthError,
Codes,
@@ -62,6 +70,17 @@ class MessageHandler(object):
self.storage = hs.get_storage()
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
+ self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._is_worker_app = bool(hs.config.worker_app)
+
+ # The scheduled call to self._expire_event. None if no call is currently
+ # scheduled.
+ self._scheduled_expiry = None # type: Optional[IDelayedCall]
+
+ if not hs.config.worker_app:
+ run_as_background_process(
+ "_schedule_next_expiry", self._schedule_next_expiry
+ )
@defer.inlineCallbacks
def get_room_data(
@@ -138,7 +157,7 @@ class MessageHandler(object):
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
- self.storage, user_id, last_events
+ self.storage, user_id, last_events, apply_retention_policies=False
)
event = last_events[0]
@@ -225,6 +244,100 @@ class MessageHandler(object):
for user_id, profile in iteritems(users_with_profile)
}
+ def maybe_schedule_expiry(self, event):
+ """Schedule the expiry of an event if there's not already one scheduled,
+ or if the one running is for an event that will expire after the provided
+ timestamp.
+
+ This function needs to invalidate the event cache, which is only possible on
+ the master process, and therefore needs to be run on there.
+
+ Args:
+ event (EventBase): The event to schedule the expiry of.
+ """
+ assert not self._is_worker_app
+
+ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+ if not isinstance(expiry_ts, int) or event.is_state():
+ return
+
+ # _schedule_expiry_for_event won't actually schedule anything if there's already
+ # a task scheduled for a timestamp that's sooner than the provided one.
+ self._schedule_expiry_for_event(event.event_id, expiry_ts)
+
+ @defer.inlineCallbacks
+ def _schedule_next_expiry(self):
+ """Retrieve the ID and the expiry timestamp of the next event to be expired,
+ and schedule an expiry task for it.
+
+ If there's no event left to expire, set _expiry_scheduled to None so that a
+ future call to save_expiry_ts can schedule a new expiry task.
+ """
+ # Try to get the expiry timestamp of the next event to expire.
+ res = yield self.store.get_next_event_to_expire()
+ if res:
+ event_id, expiry_ts = res
+ self._schedule_expiry_for_event(event_id, expiry_ts)
+
+ def _schedule_expiry_for_event(self, event_id, expiry_ts):
+ """Schedule an expiry task for the provided event if there's not already one
+ scheduled at a timestamp that's sooner than the provided one.
+
+ Args:
+ event_id (str): The ID of the event to expire.
+ expiry_ts (int): The timestamp at which to expire the event.
+ """
+ if self._scheduled_expiry:
+ # If the provided timestamp refers to a time before the scheduled time of the
+ # next expiry task, cancel that task and reschedule it for this timestamp.
+ next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000
+ if expiry_ts < next_scheduled_expiry_ts:
+ self._scheduled_expiry.cancel()
+ else:
+ return
+
+ # Figure out how many seconds we need to wait before expiring the event.
+ now_ms = self.clock.time_msec()
+ delay = (expiry_ts - now_ms) / 1000
+
+ # callLater doesn't support negative delays, so trim the delay to 0 if we're
+ # in that case.
+ if delay < 0:
+ delay = 0
+
+ logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay)
+
+ self._scheduled_expiry = self.clock.call_later(
+ delay,
+ run_as_background_process,
+ "_expire_event",
+ self._expire_event,
+ event_id,
+ )
+
+ @defer.inlineCallbacks
+ def _expire_event(self, event_id):
+ """Retrieve and expire an event that needs to be expired from the database.
+
+ If the event doesn't exist in the database, log it and delete the expiry date
+ from the database (so that we don't try to expire it again).
+ """
+ assert self._ephemeral_events_enabled
+
+ self._scheduled_expiry = None
+
+ logger.info("Expiring event %s", event_id)
+
+ try:
+ # Expire the event if we know about it. This function also deletes the expiry
+ # date from the database in the same database transaction.
+ yield self.store.expire_event(event_id)
+ except Exception as e:
+ logger.error("Could not expire event %s: %r", event_id, e)
+
+ # Schedule the expiry of the next event to expire.
+ yield self._schedule_next_expiry()
+
# The duration (in ms) after which rooms should be removed
# `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try
@@ -295,6 +408,10 @@ class EventCreationHandler(object):
5 * 60 * 1000,
)
+ self._message_handler = hs.get_message_handler()
+
+ self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+
@defer.inlineCallbacks
def create_event(
self,
@@ -417,7 +534,7 @@ class EventCreationHandler(object):
403, "You must be in the room to create an alias for it"
)
- self.validator.validate_new(event)
+ self.validator.validate_new(event, self.config)
return (event, context)
@@ -634,7 +751,7 @@ class EventCreationHandler(object):
if requester:
context.app_service = requester.app_service
- self.validator.validate_new(event)
+ self.validator.validate_new(event, self.config)
# If this event is an annotation then we check that that the sender
# can't annotate the same way twice (e.g. stops users from liking an
@@ -877,6 +994,10 @@ class EventCreationHandler(object):
event, context=context
)
+ if self._ephemeral_events_enabled:
+ # If there's an expiry timestamp on the event, schedule its expiry.
+ self._message_handler.maybe_schedule_expiry(event)
+
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 260a4351ca..8514ddc600 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -15,12 +15,15 @@
# limitations under the License.
import logging
+from six import iteritems
+
from twisted.internet import defer
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.logging.context import run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
@@ -80,6 +83,109 @@ class PaginationHandler(object):
self._purges_by_id = {}
self._event_serializer = hs.get_event_client_serializer()
+ self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime
+
+ if hs.config.retention_enabled:
+ # Run the purge jobs described in the configuration file.
+ for job in hs.config.retention_purge_jobs:
+ self.clock.looping_call(
+ run_as_background_process,
+ job["interval"],
+ "purge_history_for_rooms_in_range",
+ self.purge_history_for_rooms_in_range,
+ job["shortest_max_lifetime"],
+ job["longest_max_lifetime"],
+ )
+
+ @defer.inlineCallbacks
+ def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+ """Purge outdated events from rooms within the given retention range.
+
+ If a default retention policy is defined in the server's configuration and its
+ 'max_lifetime' is within this range, also targets rooms which don't have a
+ retention policy.
+
+ Args:
+ min_ms (int|None): Duration in milliseconds that define the lower limit of
+ the range to handle (exclusive). If None, it means that the range has no
+ lower limit.
+ max_ms (int|None): Duration in milliseconds that define the upper limit of
+ the range to handle (inclusive). If None, it means that the range has no
+ upper limit.
+ """
+ # We want the storage layer to to include rooms with no retention policy in its
+ # return value only if a default retention policy is defined in the server's
+ # configuration and that policy's 'max_lifetime' is either lower (or equal) than
+ # max_ms or higher than min_ms (or both).
+ if self._retention_default_max_lifetime is not None:
+ include_null = True
+
+ if min_ms is not None and min_ms >= self._retention_default_max_lifetime:
+ # The default max_lifetime is lower than (or equal to) min_ms.
+ include_null = False
+
+ if max_ms is not None and max_ms < self._retention_default_max_lifetime:
+ # The default max_lifetime is higher than max_ms.
+ include_null = False
+ else:
+ include_null = False
+
+ rooms = yield self.store.get_rooms_for_retention_period_in_range(
+ min_ms, max_ms, include_null
+ )
+
+ for room_id, retention_policy in iteritems(rooms):
+ if room_id in self._purges_in_progress_by_room:
+ logger.warning(
+ "[purge] not purging room %s as there's an ongoing purge running"
+ " for this room",
+ room_id,
+ )
+ continue
+
+ max_lifetime = retention_policy["max_lifetime"]
+
+ if max_lifetime is None:
+ # If max_lifetime is None, it means that include_null equals True,
+ # therefore we can safely assume that there is a default policy defined
+ # in the server's configuration.
+ max_lifetime = self._retention_default_max_lifetime
+
+ # Figure out what token we should start purging at.
+ ts = self.clock.time_msec() - max_lifetime
+
+ stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts)
+
+ r = yield self.store.get_room_event_after_stream_ordering(
+ room_id, stream_ordering,
+ )
+ if not r:
+ logger.warning(
+ "[purge] purging events not possible: No event found "
+ "(ts %i => stream_ordering %i)",
+ ts,
+ stream_ordering,
+ )
+ continue
+
+ (stream, topo, _event_id) = r
+ token = "t%d-%d" % (topo, stream)
+
+ purge_id = random_string(16)
+
+ self._purges_by_id[purge_id] = PurgeStatus()
+
+ logger.info(
+ "Starting purging events in room %s (purge_id %s)" % (room_id, purge_id)
+ )
+
+ # We want to purge everything, including local events, and to run the purge in
+ # the background so that it's not blocking any other operation apart from
+ # other purges in the same room.
+ run_as_background_process(
+ "_purge_history", self._purge_history, purge_id, room_id, token, True,
+ )
+
def start_purge_history(self, room_id, token, delete_local_events=False):
"""Start off a history purge on a room.
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 22e0a04da4..1e5a4613c9 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -152,7 +152,7 @@ class BaseProfileHandler(BaseHandler):
by_admin (bool): Whether this change was made by an administrator.
"""
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
@@ -207,7 +207,7 @@ class BaseProfileHandler(BaseHandler):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
@@ -231,7 +231,7 @@ class BaseProfileHandler(BaseHandler):
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
just_field = args.get("field", None)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 235f11c322..8a7d965feb 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -266,7 +266,7 @@ class RegistrationHandler(BaseHandler):
}
# Bind email to new account
- yield self._register_email_threepid(user_id, threepid_dict, None, False)
+ yield self._register_email_threepid(user_id, threepid_dict, None)
return user_id
@@ -630,7 +630,7 @@ class RegistrationHandler(BaseHandler):
# And we add an email pusher for them by default, but only
# if email notifications are enabled (so people don't start
# getting mail spam where they weren't before if email
- # notifs are set up on a home server)
+ # notifs are set up on a homeserver)
if (
self.hs.config.email_enable_notifs
and self.hs.config.email_notif_for_new_users
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e92b2eafd5..22768e97ff 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -198,21 +199,21 @@ class RoomCreationHandler(BaseHandler):
# finally, shut down the PLs in the old room, and update them in the new
# room.
yield self._update_upgraded_room_pls(
- requester, old_room_id, new_room_id, old_room_state
+ requester, old_room_id, new_room_id, old_room_state,
)
return new_room_id
@defer.inlineCallbacks
def _update_upgraded_room_pls(
- self, requester, old_room_id, new_room_id, old_room_state
+ self, requester, old_room_id, new_room_id, old_room_state,
):
"""Send updated power levels in both rooms after an upgrade
Args:
requester (synapse.types.Requester): the user requesting the upgrade
- old_room_id (unicode): the id of the room to be replaced
- new_room_id (unicode): the id of the replacement room
+ old_room_id (str): the id of the room to be replaced
+ new_room_id (str): the id of the replacement room
old_room_state (dict[tuple[str, str], str]): the state map for the old room
Returns:
@@ -298,7 +299,7 @@ class RoomCreationHandler(BaseHandler):
tombstone_event_id (unicode|str): the ID of the tombstone event in the old
room.
Returns:
- Deferred[None]
+ Deferred
"""
user_id = requester.user.to_string()
@@ -333,6 +334,7 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.Encryption, ""),
(EventTypes.ServerACL, ""),
(EventTypes.RelatedGroups, ""),
+ (EventTypes.PowerLevels, ""),
)
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
@@ -346,6 +348,31 @@ class RoomCreationHandler(BaseHandler):
if old_event:
initial_state[k] = old_event.content
+ # Resolve the minimum power level required to send any state event
+ # We will give the upgrading user this power level temporarily (if necessary) such that
+ # they are able to copy all of the state events over, then revert them back to their
+ # original power level afterwards in _update_upgraded_room_pls
+
+ # Copy over user power levels now as this will not be possible with >100PL users once
+ # the room has been created
+
+ power_levels = initial_state[(EventTypes.PowerLevels, "")]
+
+ # Calculate the minimum power level needed to clone the room
+ event_power_levels = power_levels.get("events", {})
+ state_default = power_levels.get("state_default", 0)
+ ban = power_levels.get("ban")
+ needed_power_level = max(state_default, ban, max(event_power_levels.values()))
+
+ # Raise the requester's power level in the new room if necessary
+ current_power_level = power_levels["users"][requester.user.to_string()]
+ if current_power_level < needed_power_level:
+ # Assign this power level to the requester
+ power_levels["users"][requester.user.to_string()] = needed_power_level
+
+ # Set the power levels to the modified state
+ initial_state[(EventTypes.PowerLevels, "")] = power_levels
+
yield self._send_events_for_new_room(
requester,
new_room_id,
@@ -874,6 +901,10 @@ class RoomContextHandler(object):
room_id, event_id, before_limit, after_limit, event_filter
)
+ if event_filter:
+ results["events_before"] = event_filter.filter(results["events_before"])
+ results["events_after"] = event_filter.filter(results["events_after"])
+
results["events_before"] = yield filter_evts(results["events_before"])
results["events_after"] = yield filter_evts(results["events_after"])
results["event"] = event
@@ -902,7 +933,12 @@ class RoomContextHandler(object):
state = yield self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
- results["state"] = list(state[last_event_id].values())
+
+ state_events = list(state[last_event_id].values())
+ if event_filter:
+ state_events = event_filter.filter(state_events)
+
+ results["state"] = state_events
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 6cfee4b361..7b7270fc61 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -94,7 +94,9 @@ class RoomMemberHandler(object):
raise NotImplementedError()
@abc.abstractmethod
- def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ def _remote_reject_invite(
+ self, requester, remote_room_hosts, room_id, target, content
+ ):
"""Attempt to reject an invite for a room this server is not in. If we
fail to do so we locally mark the invite as rejected.
@@ -104,6 +106,7 @@ class RoomMemberHandler(object):
reject invite
room_id (str)
target (UserID): The user rejecting the invite
+ content (dict): The content for the rejection event
Returns:
Deferred[dict]: A dictionary to be returned to the client, may
@@ -471,7 +474,7 @@ class RoomMemberHandler(object):
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
res = yield self._remote_reject_invite(
- requester, remote_room_hosts, room_id, target
+ requester, remote_room_hosts, room_id, target, content,
)
return res
@@ -971,13 +974,15 @@ class RoomMemberMasterHandler(RoomMemberHandler):
)
@defer.inlineCallbacks
- def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ def _remote_reject_invite(
+ self, requester, remote_room_hosts, room_id, target, content
+ ):
"""Implements RoomMemberHandler._remote_reject_invite
"""
fed_handler = self.federation_handler
try:
ret = yield fed_handler.do_remotely_reject_invite(
- remote_room_hosts, room_id, target.to_string()
+ remote_room_hosts, room_id, target.to_string(), content=content,
)
return ret
except Exception as e:
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 75e96ae1a2..69be86893b 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -55,7 +55,9 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
return ret
- def _remote_reject_invite(self, requester, remote_room_hosts, room_id, target):
+ def _remote_reject_invite(
+ self, requester, remote_room_hosts, room_id, target, content
+ ):
"""Implements RoomMemberHandler._remote_reject_invite
"""
return self._remote_reject_client(
@@ -63,6 +65,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
remote_room_hosts=remote_room_hosts,
room_id=room_id,
user_id=target.to_string(),
+ content=content,
)
def _user_joined_room(self, target, room_id):
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index ca8ae9fb5b..856337b7e2 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -120,7 +120,7 @@ class TypingHandler(object):
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
@@ -150,7 +150,7 @@ class TypingHandler(object):
auth_user_id = auth_user.to_string()
if not self.is_mine_id(target_user_id):
- raise SynapseError(400, "User is not hosted on this Home Server")
+ raise SynapseError(400, "User is not hosted on this homeserver")
if target_user_id != auth_user_id:
raise AuthError(400, "Cannot set another user's typing state")
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 691380abda..16765d54e0 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -530,7 +530,7 @@ class MatrixFederationHttpClient(object):
"""
Builds the Authorization headers for a federation request
Args:
- destination (bytes|None): The desination home server of the request.
+ destination (bytes|None): The desination homeserver of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
method (bytes): The HTTP method of the request
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index e9a5e46ced..13fcb408a6 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -96,7 +96,7 @@ def parse_boolean_from_args(args, name, default=None, required=False):
return {b"true": True, b"false": False}[args[name][0]]
except Exception:
message = (
- "Boolean query parameter %r must be one of" " ['true', 'false']"
+ "Boolean query parameter %r must be one of ['true', 'false']"
) % (name,)
raise SynapseError(400, message)
else:
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
index 334ddaf39a..ffa7b20ca8 100644
--- a/synapse/logging/_structured.py
+++ b/synapse/logging/_structured.py
@@ -261,6 +261,18 @@ def parse_drain_configs(
)
+class StoppableLogPublisher(LogPublisher):
+ """
+ A log publisher that can tell its observers to shut down any external
+ communications.
+ """
+
+ def stop(self):
+ for obs in self._observers:
+ if hasattr(obs, "stop"):
+ obs.stop()
+
+
def setup_structured_logging(
hs,
config,
@@ -336,7 +348,7 @@ def setup_structured_logging(
# We should never get here, but, just in case, throw an error.
raise ConfigError("%s drain type cannot be configured" % (observer.type,))
- publisher = LogPublisher(*observers)
+ publisher = StoppableLogPublisher(*observers)
log_filter = LogLevelFilterPredicate()
for namespace, namespace_config in log_config.get(
diff --git a/synapse/logging/_terse_json.py b/synapse/logging/_terse_json.py
index 0ebbde06f2..03934956f4 100644
--- a/synapse/logging/_terse_json.py
+++ b/synapse/logging/_terse_json.py
@@ -17,25 +17,29 @@
Log formatters that output terse JSON.
"""
+import json
import sys
+import traceback
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
-from typing import IO
+from typing import IO, Optional
import attr
-from simplejson import dumps
from zope.interface import implementer
from twisted.application.internet import ClientService
+from twisted.internet.defer import Deferred
from twisted.internet.endpoints import (
HostnameEndpoint,
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
+from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.logger import FileLogObserver, ILogObserver, Logger
-from twisted.python.failure import Failure
+
+_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
def flatten_event(event: dict, metadata: dict, include_time: bool = False):
@@ -141,19 +145,57 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
def formatEvent(_event: dict) -> str:
flattened = flatten_event(_event, metadata)
- return dumps(flattened, ensure_ascii=False, separators=(",", ":")) + "\n"
+ return _encoder.encode(flattened) + "\n"
return FileLogObserver(outFile, formatEvent)
@attr.s
+@implementer(IPushProducer)
+class LogProducer(object):
+ """
+ An IPushProducer that writes logs from its buffer to its transport when it
+ is resumed.
+
+ Args:
+ buffer: Log buffer to read logs from.
+ transport: Transport to write to.
+ """
+
+ transport = attr.ib(type=ITransport)
+ _buffer = attr.ib(type=deque)
+ _paused = attr.ib(default=False, type=bool, init=False)
+
+ def pauseProducing(self):
+ self._paused = True
+
+ def stopProducing(self):
+ self._paused = True
+ self._buffer = None
+
+ def resumeProducing(self):
+ self._paused = False
+
+ while self._paused is False and (self._buffer and self.transport.connected):
+ try:
+ event = self._buffer.popleft()
+ self.transport.write(_encoder.encode(event).encode("utf8"))
+ self.transport.write(b"\n")
+ except Exception:
+ # Something has gone wrong writing to the transport -- log it
+ # and break out of the while.
+ traceback.print_exc(file=sys.__stderr__)
+ break
+
+
+@attr.s
@implementer(ILogObserver)
class TerseJSONToTCPLogObserver(object):
"""
An IObserver that writes JSON logs to a TCP target.
Args:
- hs (HomeServer): The Homeserver that is being logged for.
+ hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
metadata: Metadata to be added to each log entry.
@@ -165,8 +207,9 @@ class TerseJSONToTCPLogObserver(object):
metadata = attr.ib(type=dict)
maximum_buffer = attr.ib(type=int)
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
- _writer = attr.ib(default=None)
+ _connection_waiter = attr.ib(default=None, type=Optional[Deferred])
_logger = attr.ib(default=attr.Factory(Logger))
+ _producer = attr.ib(default=None, type=Optional[LogProducer])
def start(self) -> None:
@@ -187,38 +230,44 @@ class TerseJSONToTCPLogObserver(object):
factory = Factory.forProtocol(Protocol)
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
self._service.startService()
+ self._connect()
- def _write_loop(self) -> None:
+ def stop(self):
+ self._service.stopService()
+
+ def _connect(self) -> None:
"""
- Implement the write loop.
+ Triggers an attempt to connect then write to the remote if not already writing.
"""
- if self._writer:
+ if self._connection_waiter:
return
- self._writer = self._service.whenConnected()
+ self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
+
+ @self._connection_waiter.addErrback
+ def fail(r):
+ r.printTraceback(file=sys.__stderr__)
+ self._connection_waiter = None
+ self._connect()
- @self._writer.addBoth
+ @self._connection_waiter.addCallback
def writer(r):
- if isinstance(r, Failure):
- r.printTraceback(file=sys.__stderr__)
- self._writer = None
- self.hs.get_reactor().callLater(1, self._write_loop)
+ # We have a connection. If we already have a producer, and its
+ # transport is the same, just trigger a resumeProducing.
+ if self._producer and r.transport is self._producer.transport:
+ self._producer.resumeProducing()
+ self._connection_waiter = None
return
- try:
- for event in self._buffer:
- r.transport.write(
- dumps(event, ensure_ascii=False, separators=(",", ":")).encode(
- "utf8"
- )
- )
- r.transport.write(b"\n")
- self._buffer.clear()
- except Exception as e:
- sys.__stderr__.write("Failed writing out logs with %s\n" % (str(e),))
-
- self._writer = False
- self.hs.get_reactor().callLater(1, self._write_loop)
+ # If the producer is still producing, stop it.
+ if self._producer:
+ self._producer.stopProducing()
+
+ # Make a new producer and start it.
+ self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
+ r.transport.registerProducer(self._producer, True)
+ self._producer.resumeProducing()
+ self._connection_waiter = None
def _handle_pressure(self) -> None:
"""
@@ -277,4 +326,4 @@ class TerseJSONToTCPLogObserver(object):
self._logger.failure("Failed clearing backpressure")
# Try and write immediately.
- self._write_loop()
+ self._connect()
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 1ba7bcd4d8..7881780760 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -386,15 +386,7 @@ class RulesForRoom(object):
"""
sequence = self.sequence
- rows = yield self.store._simple_select_many_batch(
- table="room_memberships",
- column="event_id",
- iterable=member_event_ids.values(),
- retcols=("user_id", "membership", "event_id"),
- keyvalues={},
- batch_size=500,
- desc="_get_rules_for_member_event_ids",
- )
+ rows = yield self.store.get_membership_from_event_ids(member_event_ids.values())
members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index e994037be6..d0879b0490 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -246,7 +246,7 @@ class HttpPusher(object):
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warning(
- "Giving up on a notification to user %s, " "pushkey %s",
+ "Giving up on a notification to user %s, pushkey %s",
self.user_id,
self.pushkey,
)
@@ -299,8 +299,7 @@ class HttpPusher(object):
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warning(
- ("Ignoring rejected pushkey %s because we" " didn't send it"),
- pk,
+ ("Ignoring rejected pushkey %s because we didn't send it"), pk,
)
else:
logger.info("Pushkey %s was rejected: removing", pk)
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 1d15a06a58..b13b646bfd 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)
MESSAGE_FROM_PERSON_IN_ROOM = (
- "You have a message on %(app)s from %(person)s " "in the %(room)s room..."
+ "You have a message on %(app)s from %(person)s in the %(room)s room..."
)
MESSAGE_FROM_PERSON = "You have a message on %(app)s from %(person)s..."
MESSAGES_FROM_PERSON = "You have messages on %(app)s from %(person)s..."
@@ -55,7 +55,7 @@ MESSAGES_FROM_PERSON_AND_OTHERS = (
"You have messages on %(app)s from %(person)s and others..."
)
INVITE_FROM_PERSON_TO_ROOM = (
- "%(person)s has invited you to join the " "%(room)s room on %(app)s..."
+ "%(person)s has invited you to join the %(room)s room on %(app)s..."
)
INVITE_FROM_PERSON = "%(person)s has invited you to chat on %(app)s..."
diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py
index 81b85352b1..28dbc6fcba 100644
--- a/synapse/replication/http/__init__.py
+++ b/synapse/replication/http/__init__.py
@@ -14,7 +14,14 @@
# limitations under the License.
from synapse.http.server import JsonResource
-from synapse.replication.http import federation, login, membership, register, send_event
+from synapse.replication.http import (
+ devices,
+ federation,
+ login,
+ membership,
+ register,
+ send_event,
+)
REPLICATION_PREFIX = "/_synapse/replication"
@@ -30,3 +37,4 @@ class ReplicationRestResource(JsonResource):
federation.register_servlets(hs, self)
login.register_servlets(hs, self)
register.register_servlets(hs, self)
+ devices.register_servlets(hs, self)
diff --git a/synapse/replication/http/devices.py b/synapse/replication/http/devices.py
new file mode 100644
index 0000000000..e32aac0a25
--- /dev/null
+++ b/synapse/replication/http/devices.py
@@ -0,0 +1,73 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from synapse.replication.http._base import ReplicationEndpoint
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
+ """Ask master to resync the device list for a user by contacting their
+ server.
+
+ This must happen on master so that the results can be correctly cached in
+ the database and streamed to workers.
+
+ Request format:
+
+ POST /_synapse/replication/user_device_resync/:user_id
+
+ {}
+
+ Response is equivalent to ` /_matrix/federation/v1/user/devices/:user_id`
+ response, e.g.:
+
+ {
+ "user_id": "@alice:example.org",
+ "devices": [
+ {
+ "device_id": "JLAFKJWSCS",
+ "keys": { ... },
+ "device_display_name": "Alice's Mobile Phone"
+ }
+ ]
+ }
+ """
+
+ NAME = "user_device_resync"
+ PATH_ARGS = ("user_id",)
+ CACHE = False
+
+ def __init__(self, hs):
+ super(ReplicationUserDevicesResyncRestServlet, self).__init__(hs)
+
+ self.device_list_updater = hs.get_device_handler().device_list_updater
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+
+ @staticmethod
+ def _serialize_payload(user_id):
+ return {}
+
+ async def _handle_request(self, request, user_id):
+ user_devices = await self.device_list_updater.user_device_resync(user_id)
+
+ return 200, user_devices
+
+
+def register_servlets(hs, http_server):
+ ReplicationUserDevicesResyncRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index cc1f249740..3577611fd7 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -93,6 +93,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
{
"requester": ...,
"remote_room_hosts": [...],
+ "content": { ... }
}
"""
@@ -107,7 +108,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- def _serialize_payload(requester, room_id, user_id, remote_room_hosts):
+ def _serialize_payload(requester, room_id, user_id, remote_room_hosts, content):
"""
Args:
requester(Requester)
@@ -118,12 +119,14 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
return {
"requester": requester.serialize(),
"remote_room_hosts": remote_room_hosts,
+ "content": content,
}
async def _handle_request(self, request, room_id, user_id):
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
+ event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
@@ -134,7 +137,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
try:
event = await self.federation_handler.do_remotely_reject_invite(
- remote_room_hosts, room_id, user_id
+ remote_room_hosts, room_id, user_id, event_content,
)
ret = event.get_pdu_json()
except Exception as e:
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 456bc005a0..6ece1d6745 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,7 +18,8 @@ from typing import Dict
import six
-from synapse.storage._base import _CURRENT_STATE_CACHE_NAME, SQLBaseStore
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
from synapse.storage.engines import PostgresEngine
from ._slaved_id_tracker import SlavedIdTracker
@@ -62,7 +63,7 @@ class BaseSlavedStore(SQLBaseStore):
if stream_name == "caches":
self._cache_id_gen.advance(token)
for row in rows:
- if row.cache_func == _CURRENT_STATE_CACHE_NAME:
+ if row.cache_func == CURRENT_STATE_CACHE_NAME:
room_id = row.keys[0]
members_changed = set(row.keys[1:])
self._invalidate_state_caches(room_id, members_changed)
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 9e45429d49..8512923eae 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -88,8 +88,7 @@ TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
)
AccountDataStreamRow = namedtuple(
- "AccountDataStream",
- ("user_id", "room_id", "data_type", "data"), # str # str # str # dict
+ "AccountDataStream", ("user_id", "room_id", "data_type") # str # str # str
)
GroupsStreamRow = namedtuple(
"GroupsStreamRow",
@@ -421,8 +420,8 @@ class AccountDataStream(Stream):
results = list(room_results)
results.extend(
- (stream_id, user_id, None, account_data_type, content)
- for stream_id, user_id, account_data_type, content in global_results
+ (stream_id, user_id, None, account_data_type)
+ for stream_id, user_id, account_data_type in global_results
)
return results
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 5c2a2eb593..c122c449f4 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -14,62 +14,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import hashlib
-import hmac
import logging
import platform
import re
-from six import text_type
-from six.moves import http_client
-
import synapse
-from synapse.api.constants import Membership, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
-from synapse.http.servlet import (
- RestServlet,
- assert_params_in_dict,
- parse_integer,
- parse_json_object_from_request,
- parse_string,
-)
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.admin._base import (
assert_requester_is_admin,
- assert_user_is_admin,
historical_admin_path_patterns,
)
+from synapse.rest.admin.groups import DeleteGroupAdminRestServlet
from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
+from synapse.rest.admin.rooms import ShutdownRoomRestServlet
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
-from synapse.rest.admin.users import UserAdminServlet
-from synapse.types import UserID, create_requester
-from synapse.util.async_helpers import maybe_awaitable
+from synapse.rest.admin.users import (
+ AccountValidityRenewServlet,
+ DeactivateAccountRestServlet,
+ ResetPasswordRestServlet,
+ SearchUsersRestServlet,
+ UserAdminServlet,
+ UserRegisterServlet,
+ UsersRestServlet,
+ UsersRestServletV2,
+ WhoisRestServlet,
+)
from synapse.util.versionstring import get_version_string
logger = logging.getLogger(__name__)
-class UsersRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
-
- def __init__(self, hs):
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- async def on_GET(self, request, user_id):
- target_user = UserID.from_string(user_id)
- await assert_requester_is_admin(self.auth, request)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- ret = await self.handlers.admin_handler.get_users()
-
- return 200, ret
-
-
class VersionServlet(RestServlet):
PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
@@ -83,159 +60,6 @@ class VersionServlet(RestServlet):
return 200, self.res
-class UserRegisterServlet(RestServlet):
- """
- Attributes:
- NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
- nonces (dict[str, int]): The nonces that we will accept. A dict of
- nonce to the time it was generated, in int seconds.
- """
-
- PATTERNS = historical_admin_path_patterns("/register")
- NONCE_TIMEOUT = 60
-
- def __init__(self, hs):
- self.handlers = hs.get_handlers()
- self.reactor = hs.get_reactor()
- self.nonces = {}
- self.hs = hs
-
- def _clear_old_nonces(self):
- """
- Clear out old nonces that are older than NONCE_TIMEOUT.
- """
- now = int(self.reactor.seconds())
-
- for k, v in list(self.nonces.items()):
- if now - v > self.NONCE_TIMEOUT:
- del self.nonces[k]
-
- def on_GET(self, request):
- """
- Generate a new nonce.
- """
- self._clear_old_nonces()
-
- nonce = self.hs.get_secrets().token_hex(64)
- self.nonces[nonce] = int(self.reactor.seconds())
- return 200, {"nonce": nonce}
-
- async def on_POST(self, request):
- self._clear_old_nonces()
-
- if not self.hs.config.registration_shared_secret:
- raise SynapseError(400, "Shared secret registration is not enabled")
-
- body = parse_json_object_from_request(request)
-
- if "nonce" not in body:
- raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
-
- nonce = body["nonce"]
-
- if nonce not in self.nonces:
- raise SynapseError(400, "unrecognised nonce")
-
- # Delete the nonce, so it can't be reused, even if it's invalid
- del self.nonces[nonce]
-
- if "username" not in body:
- raise SynapseError(
- 400, "username must be specified", errcode=Codes.BAD_JSON
- )
- else:
- if (
- not isinstance(body["username"], text_type)
- or len(body["username"]) > 512
- ):
- raise SynapseError(400, "Invalid username")
-
- username = body["username"].encode("utf-8")
- if b"\x00" in username:
- raise SynapseError(400, "Invalid username")
-
- if "password" not in body:
- raise SynapseError(
- 400, "password must be specified", errcode=Codes.BAD_JSON
- )
- else:
- if (
- not isinstance(body["password"], text_type)
- or len(body["password"]) > 512
- ):
- raise SynapseError(400, "Invalid password")
-
- password = body["password"].encode("utf-8")
- if b"\x00" in password:
- raise SynapseError(400, "Invalid password")
-
- admin = body.get("admin", None)
- user_type = body.get("user_type", None)
-
- if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
- raise SynapseError(400, "Invalid user type")
-
- got_mac = body["mac"]
-
- want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret.encode(),
- digestmod=hashlib.sha1,
- )
- want_mac.update(nonce.encode("utf8"))
- want_mac.update(b"\x00")
- want_mac.update(username)
- want_mac.update(b"\x00")
- want_mac.update(password)
- want_mac.update(b"\x00")
- want_mac.update(b"admin" if admin else b"notadmin")
- if user_type:
- want_mac.update(b"\x00")
- want_mac.update(user_type.encode("utf8"))
- want_mac = want_mac.hexdigest()
-
- if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
- raise SynapseError(403, "HMAC incorrect")
-
- # Reuse the parts of RegisterRestServlet to reduce code duplication
- from synapse.rest.client.v2_alpha.register import RegisterRestServlet
-
- register = RegisterRestServlet(self.hs)
-
- user_id = await register.registration_handler.register_user(
- localpart=body["username"].lower(),
- password=body["password"],
- admin=bool(admin),
- user_type=user_type,
- )
-
- result = await register._create_registration_details(user_id, body)
- return 200, result
-
-
-class WhoisRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
-
- def __init__(self, hs):
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- async def on_GET(self, request, user_id):
- target_user = UserID.from_string(user_id)
- requester = await self.auth.get_user_by_req(request)
- auth_user = requester.user
-
- if target_user != auth_user:
- await assert_user_is_admin(self.auth, auth_user)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only whois a local user")
-
- ret = await self.handlers.admin_handler.get_whois(target_user)
-
- return 200, ret
-
-
class PurgeHistoryRestServlet(RestServlet):
PATTERNS = historical_admin_path_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
@@ -342,369 +166,6 @@ class PurgeHistoryStatusRestServlet(RestServlet):
return 200, purge_status.asdict()
-class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
-
- def __init__(self, hs):
- self._deactivate_account_handler = hs.get_deactivate_account_handler()
- self.auth = hs.get_auth()
-
- async def on_POST(self, request, target_user_id):
- await assert_requester_is_admin(self.auth, request)
- body = parse_json_object_from_request(request, allow_empty_body=True)
- erase = body.get("erase", False)
- if not isinstance(erase, bool):
- raise SynapseError(
- http_client.BAD_REQUEST,
- "Param 'erase' must be a boolean, if given",
- Codes.BAD_JSON,
- )
-
- UserID.from_string(target_user_id)
-
- result = await self._deactivate_account_handler.deactivate_account(
- target_user_id, erase
- )
- if result:
- id_server_unbind_result = "success"
- else:
- id_server_unbind_result = "no-support"
-
- return 200, {"id_server_unbind_result": id_server_unbind_result}
-
-
-class ShutdownRoomRestServlet(RestServlet):
- """Shuts down a room by removing all local users from the room and blocking
- all future invites and joins to the room. Any local aliases will be repointed
- to a new room created by `new_room_user_id` and kicked users will be auto
- joined to the new room.
- """
-
- PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
-
- DEFAULT_MESSAGE = (
- "Sharing illegal content on this server is not permitted and rooms in"
- " violation will be blocked."
- )
-
- def __init__(self, hs):
- self.hs = hs
- self.store = hs.get_datastore()
- self.state = hs.get_state_handler()
- self._room_creation_handler = hs.get_room_creation_handler()
- self.event_creation_handler = hs.get_event_creation_handler()
- self.room_member_handler = hs.get_room_member_handler()
- self.auth = hs.get_auth()
-
- async def on_POST(self, request, room_id):
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
-
- content = parse_json_object_from_request(request)
- assert_params_in_dict(content, ["new_room_user_id"])
- new_room_user_id = content["new_room_user_id"]
-
- room_creator_requester = create_requester(new_room_user_id)
-
- message = content.get("message", self.DEFAULT_MESSAGE)
- room_name = content.get("room_name", "Content Violation Notification")
-
- info = await self._room_creation_handler.create_room(
- room_creator_requester,
- config={
- "preset": "public_chat",
- "name": room_name,
- "power_level_content_override": {"users_default": -10},
- },
- ratelimit=False,
- )
- new_room_id = info["room_id"]
-
- requester_user_id = requester.user.to_string()
-
- logger.info(
- "Shutting down room %r, joining to new room: %r", room_id, new_room_id
- )
-
- # This will work even if the room is already blocked, but that is
- # desirable in case the first attempt at blocking the room failed below.
- await self.store.block_room(room_id, requester_user_id)
-
- users = await self.state.get_current_users_in_room(room_id)
- kicked_users = []
- failed_to_kick_users = []
- for user_id in users:
- if not self.hs.is_mine_id(user_id):
- continue
-
- logger.info("Kicking %r from %r...", user_id, room_id)
-
- try:
- target_requester = create_requester(user_id)
- await self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=room_id,
- action=Membership.LEAVE,
- content={},
- ratelimit=False,
- require_consent=False,
- )
-
- await self.room_member_handler.forget(target_requester.user, room_id)
-
- await self.room_member_handler.update_membership(
- requester=target_requester,
- target=target_requester.user,
- room_id=new_room_id,
- action=Membership.JOIN,
- content={},
- ratelimit=False,
- require_consent=False,
- )
-
- kicked_users.append(user_id)
- except Exception:
- logger.exception(
- "Failed to leave old room and join new room for %r", user_id
- )
- failed_to_kick_users.append(user_id)
-
- await self.event_creation_handler.create_and_send_nonmember_event(
- room_creator_requester,
- {
- "type": "m.room.message",
- "content": {"body": message, "msgtype": "m.text"},
- "room_id": new_room_id,
- "sender": new_room_user_id,
- },
- ratelimit=False,
- )
-
- aliases_for_room = await maybe_awaitable(
- self.store.get_aliases_for_room(room_id)
- )
-
- await self.store.update_aliases_for_room(
- room_id, new_room_id, requester_user_id
- )
-
- return (
- 200,
- {
- "kicked_users": kicked_users,
- "failed_to_kick_users": failed_to_kick_users,
- "local_aliases": aliases_for_room,
- "new_room_id": new_room_id,
- },
- )
-
-
-class ResetPasswordRestServlet(RestServlet):
- """Post request to allow an administrator reset password for a user.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/reset_password/
- @user:to_reset_password?access_token=admin_access_token
- JsonBodyToSend:
- {
- "new_password": "secret"
- }
- Returns:
- 200 OK with empty object if success otherwise an error.
- """
-
- PATTERNS = historical_admin_path_patterns(
- "/reset_password/(?P<target_user_id>[^/]*)"
- )
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self._set_password_handler = hs.get_set_password_handler()
-
- async def on_POST(self, request, target_user_id):
- """Post request to allow an administrator reset password for a user.
- This needs user to have administrator access in Synapse.
- """
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
-
- UserID.from_string(target_user_id)
-
- params = parse_json_object_from_request(request)
- assert_params_in_dict(params, ["new_password"])
- new_password = params["new_password"]
-
- await self._set_password_handler.set_password(
- target_user_id, new_password, requester
- )
- return 200, {}
-
-
-class GetUsersPaginatedRestServlet(RestServlet):
- """Get request to get specific number of users from Synapse.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/users_paginate/
- @admin:user?access_token=admin_access_token&start=0&limit=10
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
-
- PATTERNS = historical_admin_path_patterns(
- "/users_paginate/(?P<target_user_id>[^/]*)"
- )
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- async def on_GET(self, request, target_user_id):
- """Get request to get specific number of users from Synapse.
- This needs user to have administrator access in Synapse.
- """
- await assert_requester_is_admin(self.auth, request)
-
- target_user = UserID.from_string(target_user_id)
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- order = "name" # order by name in user table
- start = parse_integer(request, "start", required=True)
- limit = parse_integer(request, "limit", required=True)
-
- logger.info("limit: %s, start: %s", limit, start)
-
- ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
- return 200, ret
-
- async def on_POST(self, request, target_user_id):
- """Post request to get specific number of users from Synapse..
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/users_paginate/
- @admin:user?access_token=admin_access_token
- JsonBodyToSend:
- {
- "start": "0",
- "limit": "10
- }
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
- await assert_requester_is_admin(self.auth, request)
- UserID.from_string(target_user_id)
-
- order = "name" # order by name in user table
- params = parse_json_object_from_request(request)
- assert_params_in_dict(params, ["limit", "start"])
- limit = params["limit"]
- start = params["start"]
- logger.info("limit: %s, start: %s", limit, start)
-
- ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
- return 200, ret
-
-
-class SearchUsersRestServlet(RestServlet):
- """Get request to search user table for specific users according to
- search term.
- This needs user to have administrator access in Synapse.
- Example:
- http://localhost:8008/_synapse/admin/v1/search_users/
- @admin:user?access_token=admin_access_token&term=alice
- Returns:
- 200 OK with json object {list[dict[str, Any]], count} or empty object.
- """
-
- PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
-
- def __init__(self, hs):
- self.store = hs.get_datastore()
- self.hs = hs
- self.auth = hs.get_auth()
- self.handlers = hs.get_handlers()
-
- async def on_GET(self, request, target_user_id):
- """Get request to search user table for specific users according to
- search term.
- This needs user to have a administrator access in Synapse.
- """
- await assert_requester_is_admin(self.auth, request)
-
- target_user = UserID.from_string(target_user_id)
-
- # To allow all users to get the users list
- # if not is_admin and target_user != auth_user:
- # raise AuthError(403, "You are not a server admin")
-
- if not self.hs.is_mine(target_user):
- raise SynapseError(400, "Can only users a local user")
-
- term = parse_string(request, "term", required=True)
- logger.info("term: %s ", term)
-
- ret = await self.handlers.admin_handler.search_users(term)
- return 200, ret
-
-
-class DeleteGroupAdminRestServlet(RestServlet):
- """Allows deleting of local groups
- """
-
- PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
-
- def __init__(self, hs):
- self.group_server = hs.get_groups_server_handler()
- self.is_mine_id = hs.is_mine_id
- self.auth = hs.get_auth()
-
- async def on_POST(self, request, group_id):
- requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
-
- if not self.is_mine_id(group_id):
- raise SynapseError(400, "Can only delete local groups")
-
- await self.group_server.delete_group(group_id, requester.user.to_string())
- return 200, {}
-
-
-class AccountValidityRenewServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
- self.hs = hs
- self.account_activity_handler = hs.get_account_validity_handler()
- self.auth = hs.get_auth()
-
- async def on_POST(self, request):
- await assert_requester_is_admin(self.auth, request)
-
- body = parse_json_object_from_request(request)
-
- if "user_id" not in body:
- raise SynapseError(400, "Missing property 'user_id' in the request body")
-
- expiration_ts = await self.account_activity_handler.renew_account_for_user(
- body["user_id"],
- body.get("expiration_ts"),
- not body.get("enable_renewal_emails", True),
- )
-
- res = {"expiration_ts": expiration_ts}
- return 200, res
-
-
########################################################################################
#
# please don't add more servlets here: this file is already long and unwieldy. Put
@@ -730,6 +191,7 @@ def register_servlets(hs, http_server):
SendServerNoticeServlet(hs).register(http_server)
VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server)
+ UsersRestServletV2(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
@@ -740,7 +202,6 @@ def register_servlets_for_client_rest_resource(hs, http_server):
PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server)
ResetPasswordRestServlet(hs).register(http_server)
- GetUsersPaginatedRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
new file mode 100644
index 0000000000..0b54ca09f4
--- /dev/null
+++ b/synapse/rest/admin/groups.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.rest.admin._base import (
+ assert_user_is_admin,
+ historical_admin_path_patterns,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class DeleteGroupAdminRestServlet(RestServlet):
+ """Allows deleting of local groups
+ """
+
+ PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.group_server = hs.get_groups_server_handler()
+ self.is_mine_id = hs.is_mine_id
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ if not self.is_mine_id(group_id):
+ raise SynapseError(400, "Can only delete local groups")
+
+ await self.group_server.delete_group(group_id, requester.user.to_string())
+ return 200, {}
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
new file mode 100644
index 0000000000..f7cc5e9be9
--- /dev/null
+++ b/synapse/rest/admin/rooms.py
@@ -0,0 +1,157 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.constants import Membership
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.rest.admin._base import (
+ assert_user_is_admin,
+ historical_admin_path_patterns,
+)
+from synapse.types import create_requester
+from synapse.util.async_helpers import maybe_awaitable
+
+logger = logging.getLogger(__name__)
+
+
+class ShutdownRoomRestServlet(RestServlet):
+ """Shuts down a room by removing all local users from the room and blocking
+ all future invites and joins to the room. Any local aliases will be repointed
+ to a new room created by `new_room_user_id` and kicked users will be auto
+ joined to the new room.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
+
+ DEFAULT_MESSAGE = (
+ "Sharing illegal content on this server is not permitted and rooms in"
+ " violation will be blocked."
+ )
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ content = parse_json_object_from_request(request)
+ assert_params_in_dict(content, ["new_room_user_id"])
+ new_room_user_id = content["new_room_user_id"]
+
+ room_creator_requester = create_requester(new_room_user_id)
+
+ message = content.get("message", self.DEFAULT_MESSAGE)
+ room_name = content.get("room_name", "Content Violation Notification")
+
+ info = await self._room_creation_handler.create_room(
+ room_creator_requester,
+ config={
+ "preset": "public_chat",
+ "name": room_name,
+ "power_level_content_override": {"users_default": -10},
+ },
+ ratelimit=False,
+ )
+ new_room_id = info["room_id"]
+
+ requester_user_id = requester.user.to_string()
+
+ logger.info(
+ "Shutting down room %r, joining to new room: %r", room_id, new_room_id
+ )
+
+ # This will work even if the room is already blocked, but that is
+ # desirable in case the first attempt at blocking the room failed below.
+ await self.store.block_room(room_id, requester_user_id)
+
+ users = await self.state.get_current_users_in_room(room_id)
+ kicked_users = []
+ failed_to_kick_users = []
+ for user_id in users:
+ if not self.hs.is_mine_id(user_id):
+ continue
+
+ logger.info("Kicking %r from %r...", user_id, room_id)
+
+ try:
+ target_requester = create_requester(user_id)
+ await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=room_id,
+ action=Membership.LEAVE,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
+
+ await self.room_member_handler.forget(target_requester.user, room_id)
+
+ await self.room_member_handler.update_membership(
+ requester=target_requester,
+ target=target_requester.user,
+ room_id=new_room_id,
+ action=Membership.JOIN,
+ content={},
+ ratelimit=False,
+ require_consent=False,
+ )
+
+ kicked_users.append(user_id)
+ except Exception:
+ logger.exception(
+ "Failed to leave old room and join new room for %r", user_id
+ )
+ failed_to_kick_users.append(user_id)
+
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ room_creator_requester,
+ {
+ "type": "m.room.message",
+ "content": {"body": message, "msgtype": "m.text"},
+ "room_id": new_room_id,
+ "sender": new_room_user_id,
+ },
+ ratelimit=False,
+ )
+
+ aliases_for_room = await maybe_awaitable(
+ self.store.get_aliases_for_room(room_id)
+ )
+
+ await self.store.update_aliases_for_room(
+ room_id, new_room_id, requester_user_id
+ )
+
+ return (
+ 200,
+ {
+ "kicked_users": kicked_users,
+ "failed_to_kick_users": failed_to_kick_users,
+ "local_aliases": aliases_for_room,
+ "new_room_id": new_room_id,
+ },
+ )
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index d5d124a0dc..1937879dbe 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -12,17 +12,394 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import hashlib
+import hmac
+import logging
import re
-from synapse.api.errors import SynapseError
+from six import text_type
+from six.moves import http_client
+
+from synapse.api.constants import UserTypes
+from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
+ parse_boolean,
+ parse_integer,
parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.admin._base import (
+ assert_requester_is_admin,
+ assert_user_is_admin,
+ historical_admin_path_patterns,
)
-from synapse.rest.admin import assert_requester_is_admin, assert_user_is_admin
from synapse.types import UserID
+logger = logging.getLogger(__name__)
+
+
+class UsersRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+
+ async def on_GET(self, request, user_id):
+ target_user = UserID.from_string(user_id)
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ ret = await self.admin_handler.get_users()
+
+ return 200, ret
+
+
+class UsersRestServletV2(RestServlet):
+ PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),)
+
+ """Get request to list all local users.
+ This needs user to have administrator access in Synapse.
+
+ GET /_synapse/admin/v2/users?from=0&limit=10&guests=false
+
+ returns:
+ 200 OK with list of users if success otherwise an error.
+
+ The parameters `from` and `limit` are required only for pagination.
+ By default, a `limit` of 100 is used.
+ The parameter `user_id` can be used to filter by user id.
+ The parameter `guests` can be used to exclude guest users.
+ The parameter `deactivated` can be used to include deactivated users.
+ """
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.admin_handler = hs.get_handlers().admin_handler
+
+ async def on_GET(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ start = parse_integer(request, "from", default=0)
+ limit = parse_integer(request, "limit", default=100)
+ user_id = parse_string(request, "user_id", default=None)
+ guests = parse_boolean(request, "guests", default=True)
+ deactivated = parse_boolean(request, "deactivated", default=False)
+
+ users = await self.admin_handler.get_users_paginate(
+ start, limit, user_id, guests, deactivated
+ )
+ ret = {"users": users}
+ if len(users) >= limit:
+ ret["next_token"] = str(start + len(users))
+
+ return 200, ret
+
+
+class UserRegisterServlet(RestServlet):
+ """
+ Attributes:
+ NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
+ nonces (dict[str, int]): The nonces that we will accept. A dict of
+ nonce to the time it was generated, in int seconds.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/register")
+ NONCE_TIMEOUT = 60
+
+ def __init__(self, hs):
+ self.handlers = hs.get_handlers()
+ self.reactor = hs.get_reactor()
+ self.nonces = {}
+ self.hs = hs
+
+ def _clear_old_nonces(self):
+ """
+ Clear out old nonces that are older than NONCE_TIMEOUT.
+ """
+ now = int(self.reactor.seconds())
+
+ for k, v in list(self.nonces.items()):
+ if now - v > self.NONCE_TIMEOUT:
+ del self.nonces[k]
+
+ def on_GET(self, request):
+ """
+ Generate a new nonce.
+ """
+ self._clear_old_nonces()
+
+ nonce = self.hs.get_secrets().token_hex(64)
+ self.nonces[nonce] = int(self.reactor.seconds())
+ return 200, {"nonce": nonce}
+
+ async def on_POST(self, request):
+ self._clear_old_nonces()
+
+ if not self.hs.config.registration_shared_secret:
+ raise SynapseError(400, "Shared secret registration is not enabled")
+
+ body = parse_json_object_from_request(request)
+
+ if "nonce" not in body:
+ raise SynapseError(400, "nonce must be specified", errcode=Codes.BAD_JSON)
+
+ nonce = body["nonce"]
+
+ if nonce not in self.nonces:
+ raise SynapseError(400, "unrecognised nonce")
+
+ # Delete the nonce, so it can't be reused, even if it's invalid
+ del self.nonces[nonce]
+
+ if "username" not in body:
+ raise SynapseError(
+ 400, "username must be specified", errcode=Codes.BAD_JSON
+ )
+ else:
+ if (
+ not isinstance(body["username"], text_type)
+ or len(body["username"]) > 512
+ ):
+ raise SynapseError(400, "Invalid username")
+
+ username = body["username"].encode("utf-8")
+ if b"\x00" in username:
+ raise SynapseError(400, "Invalid username")
+
+ if "password" not in body:
+ raise SynapseError(
+ 400, "password must be specified", errcode=Codes.BAD_JSON
+ )
+ else:
+ if (
+ not isinstance(body["password"], text_type)
+ or len(body["password"]) > 512
+ ):
+ raise SynapseError(400, "Invalid password")
+
+ password = body["password"].encode("utf-8")
+ if b"\x00" in password:
+ raise SynapseError(400, "Invalid password")
+
+ admin = body.get("admin", None)
+ user_type = body.get("user_type", None)
+
+ if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
+ raise SynapseError(400, "Invalid user type")
+
+ got_mac = body["mac"]
+
+ want_mac = hmac.new(
+ key=self.hs.config.registration_shared_secret.encode(),
+ digestmod=hashlib.sha1,
+ )
+ want_mac.update(nonce.encode("utf8"))
+ want_mac.update(b"\x00")
+ want_mac.update(username)
+ want_mac.update(b"\x00")
+ want_mac.update(password)
+ want_mac.update(b"\x00")
+ want_mac.update(b"admin" if admin else b"notadmin")
+ if user_type:
+ want_mac.update(b"\x00")
+ want_mac.update(user_type.encode("utf8"))
+ want_mac = want_mac.hexdigest()
+
+ if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
+ raise SynapseError(403, "HMAC incorrect")
+
+ # Reuse the parts of RegisterRestServlet to reduce code duplication
+ from synapse.rest.client.v2_alpha.register import RegisterRestServlet
+
+ register = RegisterRestServlet(self.hs)
+
+ user_id = await register.registration_handler.register_user(
+ localpart=body["username"].lower(),
+ password=body["password"],
+ admin=bool(admin),
+ user_type=user_type,
+ )
+
+ result = await register._create_registration_details(user_id, body)
+ return 200, result
+
+
+class WhoisRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ async def on_GET(self, request, user_id):
+ target_user = UserID.from_string(user_id)
+ requester = await self.auth.get_user_by_req(request)
+ auth_user = requester.user
+
+ if target_user != auth_user:
+ await assert_user_is_admin(self.auth, auth_user)
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only whois a local user")
+
+ ret = await self.handlers.admin_handler.get_whois(target_user)
+
+ return 200, ret
+
+
+class DeactivateAccountRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, target_user_id):
+ await assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+ erase = body.get("erase", False)
+ if not isinstance(erase, bool):
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Param 'erase' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
+ UserID.from_string(target_user_id)
+
+ result = await self._deactivate_account_handler.deactivate_account(
+ target_user_id, erase
+ )
+ if result:
+ id_server_unbind_result = "success"
+ else:
+ id_server_unbind_result = "no-support"
+
+ return 200, {"id_server_unbind_result": id_server_unbind_result}
+
+
+class AccountValidityRenewServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
+
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ self.hs = hs
+ self.account_activity_handler = hs.get_account_validity_handler()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request):
+ await assert_requester_is_admin(self.auth, request)
+
+ body = parse_json_object_from_request(request)
+
+ if "user_id" not in body:
+ raise SynapseError(400, "Missing property 'user_id' in the request body")
+
+ expiration_ts = await self.account_activity_handler.renew_account_for_user(
+ body["user_id"],
+ body.get("expiration_ts"),
+ not body.get("enable_renewal_emails", True),
+ )
+
+ res = {"expiration_ts": expiration_ts}
+ return 200, res
+
+
+class ResetPasswordRestServlet(RestServlet):
+ """Post request to allow an administrator reset password for a user.
+ This needs user to have administrator access in Synapse.
+ Example:
+ http://localhost:8008/_synapse/admin/v1/reset_password/
+ @user:to_reset_password?access_token=admin_access_token
+ JsonBodyToSend:
+ {
+ "new_password": "secret"
+ }
+ Returns:
+ 200 OK with empty object if success otherwise an error.
+ """
+
+ PATTERNS = historical_admin_path_patterns(
+ "/reset_password/(?P<target_user_id>[^/]*)"
+ )
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self._set_password_handler = hs.get_set_password_handler()
+
+ async def on_POST(self, request, target_user_id):
+ """Post request to allow an administrator reset password for a user.
+ This needs user to have administrator access in Synapse.
+ """
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ UserID.from_string(target_user_id)
+
+ params = parse_json_object_from_request(request)
+ assert_params_in_dict(params, ["new_password"])
+ new_password = params["new_password"]
+
+ await self._set_password_handler.set_password(
+ target_user_id, new_password, requester
+ )
+ return 200, {}
+
+
+class SearchUsersRestServlet(RestServlet):
+ """Get request to search user table for specific users according to
+ search term.
+ This needs user to have administrator access in Synapse.
+ Example:
+ http://localhost:8008/_synapse/admin/v1/search_users/
+ @admin:user?access_token=admin_access_token&term=alice
+ Returns:
+ 200 OK with json object {list[dict[str, Any]], count} or empty object.
+ """
+
+ PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.handlers = hs.get_handlers()
+
+ async def on_GET(self, request, target_user_id):
+ """Get request to search user table for specific users according to
+ search term.
+ This needs user to have a administrator access in Synapse.
+ """
+ await assert_requester_is_admin(self.auth, request)
+
+ target_user = UserID.from_string(target_user_id)
+
+ # To allow all users to get the users list
+ # if not is_admin and target_user != auth_user:
+ # raise AuthError(403, "You are not a server admin")
+
+ if not self.hs.is_mine(target_user):
+ raise SynapseError(400, "Can only users a local user")
+
+ term = parse_string(request, "term", required=True)
+ logger.info("term: %s ", term)
+
+ ret = await self.handlers.admin_handler.search_users(term)
+ return 200, ret
+
class UserAdminServlet(RestServlet):
"""
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 4ea3666874..5934b1fe8b 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import (
AuthError,
Codes,
@@ -47,17 +45,15 @@ class ClientDirectoryServer(RestServlet):
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_alias):
+ async def on_GET(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
dir_handler = self.handlers.directory_handler
- res = yield dir_handler.get_association(room_alias)
+ res = await dir_handler.get_association(room_alias)
return 200, res
- @defer.inlineCallbacks
- def on_PUT(self, request, room_alias):
+ async def on_PUT(self, request, room_alias):
room_alias = RoomAlias.from_string(room_alias)
content = parse_json_object_from_request(request)
@@ -77,26 +73,25 @@ class ClientDirectoryServer(RestServlet):
# TODO(erikj): Check types.
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if room is None:
raise SynapseError(400, "Room does not exist")
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
- yield self.handlers.directory_handler.create_association(
+ await self.handlers.directory_handler.create_association(
requester, room_alias, room_id, servers
)
return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, room_alias):
+ async def on_DELETE(self, request, room_alias):
dir_handler = self.handlers.directory_handler
try:
- service = yield self.auth.get_appservice_by_req(request)
+ service = await self.auth.get_appservice_by_req(request)
room_alias = RoomAlias.from_string(room_alias)
- yield dir_handler.delete_appservice_association(service, room_alias)
+ await dir_handler.delete_appservice_association(service, room_alias)
logger.info(
"Application service at %s deleted alias %s",
service.url,
@@ -107,12 +102,12 @@ class ClientDirectoryServer(RestServlet):
# fallback to default user behaviour if they aren't an AS
pass
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user = requester.user
room_alias = RoomAlias.from_string(room_alias)
- yield dir_handler.delete_association(requester, room_alias)
+ await dir_handler.delete_association(requester, room_alias)
logger.info(
"User %s deleted alias %s", user.to_string(), room_alias.to_string()
@@ -130,32 +125,29 @@ class ClientDirectoryListServer(RestServlet):
self.handlers = hs.get_handlers()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id):
- room = yield self.store.get_room(room_id)
+ async def on_GET(self, request, room_id):
+ room = await self.store.get_room(room_id)
if room is None:
raise NotFoundError("Unknown room")
return 200, {"visibility": "public" if room["is_public"] else "private"}
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
visibility = content.get("visibility", "public")
- yield self.handlers.directory_handler.edit_published_room_list(
+ await self.handlers.directory_handler.edit_published_room_list(
requester, room_id, visibility
)
return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, room_id):
+ requester = await self.auth.get_user_by_req(request)
- yield self.handlers.directory_handler.edit_published_room_list(
+ await self.handlers.directory_handler.edit_published_room_list(
requester, room_id, "private"
)
@@ -181,15 +173,14 @@ class ClientAppserviceDirectoryListServer(RestServlet):
def on_DELETE(self, request, network_id, room_id):
return self._edit(request, network_id, room_id, "private")
- @defer.inlineCallbacks
- def _edit(self, request, network_id, room_id, visibility):
- requester = yield self.auth.get_user_by_req(request)
+ async def _edit(self, request, network_id, room_id, visibility):
+ requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
raise AuthError(
403, "Only appservices can edit the appservice published room list"
)
- yield self.handlers.directory_handler.edit_published_appservice_room_list(
+ await self.handlers.directory_handler.edit_published_appservice_room_list(
requester.app_service.id, network_id, room_id, visibility
)
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 6651b4cf07..4beb617733 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -16,8 +16,6 @@
"""This module contains REST servlets to do with event streaming, /events."""
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -36,9 +34,8 @@ class EventStreamRestServlet(RestServlet):
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
is_guest = requester.is_guest
room_id = None
if is_guest:
@@ -57,7 +54,7 @@ class EventStreamRestServlet(RestServlet):
as_client_event = b"raw" not in request.args
- chunk = yield self.event_stream_handler.get_stream(
+ chunk = await self.event_stream_handler.get_stream(
requester.user.to_string(),
pagin_config,
timeout=timeout,
@@ -83,14 +80,13 @@ class EventRestServlet(RestServlet):
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
- @defer.inlineCallbacks
- def on_GET(self, request, event_id):
- requester = yield self.auth.get_user_by_req(request)
- event = yield self.event_handler.get_event(requester.user, None, event_id)
+ async def on_GET(self, request, event_id):
+ requester = await self.auth.get_user_by_req(request)
+ event = await self.event_handler.get_event(requester.user, None, event_id)
time_now = self.clock.time_msec()
if event:
- event = yield self._event_serializer.serialize_event(event, time_now)
+ event = await self._event_serializer.serialize_event(event, time_now)
return 200, event
else:
return 404, "Event not found."
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 2da3cd7511..910b3b4eeb 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_boolean
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -29,13 +28,12 @@ class InitialSyncRestServlet(RestServlet):
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
include_archived = parse_boolean(request, "archived", default=False)
- content = yield self.initial_sync_handler.snapshot_all_rooms(
+ content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,
as_client_event=as_client_event,
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 24a0ce74f2..ff9c978fe7 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,7 +18,6 @@ import xml.etree.ElementTree as ET
from six.moves import urllib
-from twisted.internet import defer
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
@@ -92,8 +91,11 @@ class LoginRestServlet(RestServlet):
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
+ self._clock = hs.get_clock()
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter()
+ self._account_ratelimiter = Ratelimiter()
+ self._failed_attempts_ratelimiter = Ratelimiter()
def on_GET(self, request):
flows = []
@@ -127,8 +129,7 @@ class LoginRestServlet(RestServlet):
def on_OPTIONS(self, request):
return 200, {}
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
self._address_ratelimiter.ratelimit(
request.getClientIP(),
time_now_s=self.hs.clock.time(),
@@ -142,11 +143,11 @@ class LoginRestServlet(RestServlet):
if self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
):
- result = yield self.do_jwt_login(login_submission)
+ result = await self.do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
- result = yield self.do_token_login(login_submission)
+ result = await self.do_token_login(login_submission)
else:
- result = yield self._do_other_login(login_submission)
+ result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -155,8 +156,7 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- @defer.inlineCallbacks
- def _do_other_login(self, login_submission):
+ async def _do_other_login(self, login_submission):
"""Handle non-token/saml/jwt logins
Args:
@@ -202,29 +202,55 @@ class LoginRestServlet(RestServlet):
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
+ # We also apply account rate limiting using the 3PID as a key, as
+ # otherwise using 3PID bypasses the ratelimiting based on user ID.
+ self._failed_attempts_ratelimiter.ratelimit(
+ (medium, address),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=False,
+ )
+
# Check for login providers that support 3pid login types
(
canonical_user_id,
callback_3pid,
- ) = yield self.auth_handler.check_password_provider_3pid(
+ ) = await self.auth_handler.check_password_provider_3pid(
medium, address, login_submission["password"]
)
if canonical_user_id:
# Authentication through password provider and 3pid succeeded
- result = yield self._register_device_with_callback(
+
+ result = await self._complete_login(
canonical_user_id, login_submission, callback_3pid
)
return result
# No password providers were able to handle this 3pid
# Check local store
- user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
medium, address
)
if not user_id:
logger.warning(
"unknown 3pid identifier medium %s, address %r", medium, address
)
+ # We mark that we've failed to log in here, as
+ # `check_password_provider_3pid` might have returned `None` due
+ # to an incorrect password, rather than the account not
+ # existing.
+ #
+ # If it returned None but the 3PID was bound then we won't hit
+ # this code path, which is fine as then the per-user ratelimit
+ # will kick in below.
+ self._failed_attempts_ratelimiter.can_do_action(
+ (medium, address),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=True,
+ )
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {"type": "m.id.user", "user": user_id}
@@ -236,32 +262,86 @@ class LoginRestServlet(RestServlet):
if "user" not in identifier:
raise SynapseError(400, "User identifier is missing 'user' key")
- canonical_user_id, callback = yield self.auth_handler.validate_login(
- identifier["user"], login_submission
+ if identifier["user"].startswith("@"):
+ qualified_user_id = identifier["user"]
+ else:
+ qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string()
+
+ # Check if we've hit the failed ratelimit (but don't update it)
+ self._failed_attempts_ratelimiter.ratelimit(
+ qualified_user_id.lower(),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=False,
)
- result = yield self._register_device_with_callback(
+ try:
+ canonical_user_id, callback = await self.auth_handler.validate_login(
+ identifier["user"], login_submission
+ )
+ except LoginError:
+ # The user has failed to log in, so we need to update the rate
+ # limiter. Using `can_do_action` avoids us raising a ratelimit
+ # exception and masking the LoginError. The actual ratelimiting
+ # should have happened above.
+ self._failed_attempts_ratelimiter.can_do_action(
+ qualified_user_id.lower(),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ update=True,
+ )
+ raise
+
+ result = await self._complete_login(
canonical_user_id, login_submission, callback
)
return result
- @defer.inlineCallbacks
- def _register_device_with_callback(self, user_id, login_submission, callback=None):
- """ Registers a device with a given user_id. Optionally run a callback
- function after registration has completed.
+ async def _complete_login(
+ self, user_id, login_submission, callback=None, create_non_existant_users=False
+ ):
+ """Called when we've successfully authed the user and now need to
+ actually login them in (e.g. create devices). This gets called on
+ all succesful logins.
+
+ Applies the ratelimiting for succesful login attempts against an
+ account.
Args:
user_id (str): ID of the user to register.
login_submission (dict): Dictionary of login information.
callback (func|None): Callback function to run after registration.
+ create_non_existant_users (bool): Whether to create the user if
+ they don't exist. Defaults to False.
Returns:
result (Dict[str,str]): Dictionary of account information after
successful registration.
"""
+
+ # Before we actually log them in we check if they've already logged in
+ # too often. This happens here rather than before as we don't
+ # necessarily know the user before now.
+ self._account_ratelimiter.ratelimit(
+ user_id.lower(),
+ time_now_s=self._clock.time(),
+ rate_hz=self.hs.config.rc_login_account.per_second,
+ burst_count=self.hs.config.rc_login_account.burst_count,
+ update=True,
+ )
+
+ if create_non_existant_users:
+ user_id = await self.auth_handler.check_user_exists(user_id)
+ if not user_id:
+ user_id = await self.registration_handler.register_user(
+ localpart=UserID.from_string(user_id).localpart
+ )
+
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
+ device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name
)
@@ -273,23 +353,21 @@ class LoginRestServlet(RestServlet):
}
if callback is not None:
- yield callback(result)
+ await callback(result)
return result
- @defer.inlineCallbacks
- def do_token_login(self, login_submission):
+ async def do_token_login(self, login_submission):
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id(
+ user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
- result = yield self._register_device_with_callback(user_id, login_submission)
+ result = await self._complete_login(user_id, login_submission)
return result
- @defer.inlineCallbacks
- def do_jwt_login(self, login_submission):
+ async def do_jwt_login(self, login_submission):
token = login_submission.get("token", None)
if token is None:
raise LoginError(
@@ -313,15 +391,8 @@ class LoginRestServlet(RestServlet):
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID(user, self.hs.hostname).to_string()
-
- registered_user_id = yield self.auth_handler.check_user_exists(user_id)
- if not registered_user_id:
- registered_user_id = yield self.registration_handler.register_user(
- localpart=user
- )
-
- result = yield self._register_device_with_callback(
- registered_user_id, login_submission
+ result = await self._complete_login(
+ user_id, login_submission, create_non_existant_users=True
)
return result
@@ -383,8 +454,7 @@ class CasTicketServlet(RestServlet):
self._sso_auth_handler = SSOAuthHandler(hs)
self._http_client = hs.get_proxied_http_client()
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True)
uri = self.cas_server_url + "/proxyValidate"
args = {
@@ -392,12 +462,12 @@ class CasTicketServlet(RestServlet):
"service": self.cas_service_url,
}
try:
- body = yield self._http_client.get_raw(uri, args)
+ body = await self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
body = pde.response
- result = yield self.handle_cas_response(request, body, client_redirect_url)
+ result = await self.handle_cas_response(request, body, client_redirect_url)
return result
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
@@ -478,8 +548,7 @@ class SSOAuthHandler(object):
self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator()
- @defer.inlineCallbacks
- def on_successful_auth(
+ async def on_successful_auth(
self, username, request, client_redirect_url, user_display_name=None
):
"""Called once the user has successfully authenticated with the SSO.
@@ -505,9 +574,9 @@ class SSOAuthHandler(object):
"""
localpart = map_username_to_mxid_localpart(username)
user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = yield self._auth_handler.check_user_exists(user_id)
+ registered_user_id = await self._auth_handler.check_user_exists(user_id)
if not registered_user_id:
- registered_user_id = yield self._registration_handler.register_user(
+ registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=user_display_name
)
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 4785a34d75..1cf3caf832 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -35,17 +33,16 @@ class LogoutRestServlet(RestServlet):
def on_OPTIONS(self, request):
return 200, {}
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = self.auth.get_access_token_from_request(request)
- yield self._auth_handler.delete_access_token(access_token)
+ await self._auth_handler.delete_access_token(access_token)
else:
- yield self._device_handler.delete_device(
+ await self._device_handler.delete_device(
requester.user.to_string(), requester.device_id
)
@@ -64,17 +61,16 @@ class LogoutAllRestServlet(RestServlet):
def on_OPTIONS(self, request):
return 200, {}
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# first delete all of the user's devices
- yield self._device_handler.delete_all_devices_for_user(user_id)
+ await self._device_handler.delete_all_devices_for_user(user_id)
# .. and then delete any access tokens which weren't associated with
# devices.
- yield self._auth_handler.delete_access_tokens_for_user(user_id)
+ await self._auth_handler.delete_access_tokens_for_user(user_id)
return 200, {}
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 0153525cef..eec16f8ad8 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -19,8 +19,6 @@ import logging
from six import string_types
-from twisted.internet import defer
-
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -40,27 +38,25 @@ class PresenceStatusRestServlet(RestServlet):
self.clock = hs.get_clock()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if requester.user != user:
- allowed = yield self.presence_handler.is_visible(
+ allowed = await self.presence_handler.is_visible(
observed_user=user, observer_user=requester.user
)
if not allowed:
raise AuthError(403, "You are not allowed to see their presence.")
- state = yield self.presence_handler.get_state(target_user=user)
+ state = await self.presence_handler.get_state(target_user=user)
state = format_user_presence_state(state, self.clock.time_msec())
return 200, state
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
if requester.user != user:
@@ -86,7 +82,7 @@ class PresenceStatusRestServlet(RestServlet):
raise SynapseError(400, "Unable to parse state")
if self.hs.config.use_presence:
- yield self.presence_handler.set_state(user, state)
+ await self.presence_handler.set_state(user, state)
return 200, {}
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index bbce2e2b71..1eac8a44c5 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -14,7 +14,6 @@
# limitations under the License.
""" This module contains REST servlets to do with profile: /profile/<paths> """
-from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -30,19 +29,18 @@ class ProfileDisplaynameRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
user = UserID.from_string(user_id)
- yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+ await self.profile_handler.check_profile_query_allowed(user, requester_user)
- displayname = yield self.profile_handler.get_displayname(user)
+ displayname = await self.profile_handler.get_displayname(user)
ret = {}
if displayname is not None:
@@ -50,11 +48,10 @@ class ProfileDisplaynameRestServlet(RestServlet):
return 200, ret
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
- is_admin = yield self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request)
@@ -63,7 +60,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
except Exception:
return 400, "Unable to parse name"
- yield self.profile_handler.set_displayname(user, requester, new_name, is_admin)
+ await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
return 200, {}
@@ -80,19 +77,18 @@ class ProfileAvatarURLRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
user = UserID.from_string(user_id)
- yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+ await self.profile_handler.check_profile_query_allowed(user, requester_user)
- avatar_url = yield self.profile_handler.get_avatar_url(user)
+ avatar_url = await self.profile_handler.get_avatar_url(user)
ret = {}
if avatar_url is not None:
@@ -100,11 +96,10 @@ class ProfileAvatarURLRestServlet(RestServlet):
return 200, ret
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
- is_admin = yield self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester.user)
content = parse_json_object_from_request(request)
try:
@@ -112,7 +107,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
except Exception:
return 400, "Unable to parse name"
- yield self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
+ await self.profile_handler.set_avatar_url(user, requester, new_name, is_admin)
return 200, {}
@@ -129,20 +124,19 @@ class ProfileRestServlet(RestServlet):
self.profile_handler = hs.get_profile_handler()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
+ async def on_GET(self, request, user_id):
requester_user = None
if self.hs.config.require_auth_for_profile_requests:
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
requester_user = requester.user
user = UserID.from_string(user_id)
- yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+ await self.profile_handler.check_profile_query_allowed(user, requester_user)
- displayname = yield self.profile_handler.get_displayname(user)
- avatar_url = yield self.profile_handler.get_avatar_url(user)
+ displayname = await self.profile_handler.get_displayname(user)
+ avatar_url = await self.profile_handler.get_avatar_url(user)
ret = {}
if displayname is not None:
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 9f8c3d09e3..4f74600239 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from synapse.api.errors import (
NotFoundError,
@@ -46,8 +45,7 @@ class PushRuleRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
- @defer.inlineCallbacks
- def on_PUT(self, request, path):
+ async def on_PUT(self, request, path):
if self._is_worker:
raise Exception("Cannot handle PUT /push_rules on worker")
@@ -57,7 +55,7 @@ class PushRuleRestServlet(RestServlet):
except InvalidRuleException as e:
raise SynapseError(400, str(e))
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if "/" in spec["rule_id"] or "\\" in spec["rule_id"]:
raise SynapseError(400, "rule_id may not contain slashes")
@@ -67,7 +65,7 @@ class PushRuleRestServlet(RestServlet):
user_id = requester.user.to_string()
if "attr" in spec:
- yield self.set_rule_attr(user_id, spec, content)
+ await self.set_rule_attr(user_id, spec, content)
self.notify_user(user_id)
return 200, {}
@@ -91,7 +89,7 @@ class PushRuleRestServlet(RestServlet):
after = _namespaced_rule_id(spec, after)
try:
- yield self.store.add_push_rule(
+ await self.store.add_push_rule(
user_id=user_id,
rule_id=_namespaced_rule_id_from_spec(spec),
priority_class=priority_class,
@@ -108,20 +106,19 @@ class PushRuleRestServlet(RestServlet):
return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, path):
+ async def on_DELETE(self, request, path):
if self._is_worker:
raise Exception("Cannot handle DELETE /push_rules on worker")
spec = _rule_spec_from_path([x for x in path.split("/")])
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
try:
- yield self.store.delete_push_rule(user_id, namespaced_rule_id)
+ await self.store.delete_push_rule(user_id, namespaced_rule_id)
self.notify_user(user_id)
return 200, {}
except StoreError as e:
@@ -130,15 +127,14 @@ class PushRuleRestServlet(RestServlet):
else:
raise
- @defer.inlineCallbacks
- def on_GET(self, request, path):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, path):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
# we build up the full structure and then decide which bits of it
# to send which means doing unnecessary work sometimes but is
# is probably not going to make a whole lot of difference
- rules = yield self.store.get_push_rules_for_user(user_id)
+ rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(requester.user, rules)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 41660682d9..0791866f55 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import (
@@ -39,12 +37,11 @@ class PushersRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
user = requester.user
- pushers = yield self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
+ pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
allowed_keys = [
"app_display_name",
@@ -78,9 +75,8 @@ class PushersSetRestServlet(RestServlet):
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
user = requester.user
content = parse_json_object_from_request(request)
@@ -91,7 +87,7 @@ class PushersSetRestServlet(RestServlet):
and "kind" in content
and content["kind"] is None
):
- yield self.pusher_pool.remove_pusher(
+ await self.pusher_pool.remove_pusher(
content["app_id"], content["pushkey"], user_id=user.to_string()
)
return 200, {}
@@ -117,14 +113,14 @@ class PushersSetRestServlet(RestServlet):
append = content["append"]
if not append:
- yield self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
+ await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"],
pushkey=content["pushkey"],
not_user_id=user.to_string(),
)
try:
- yield self.pusher_pool.add_pusher(
+ await self.pusher_pool.add_pusher(
user_id=user.to_string(),
access_token=requester.access_token_id,
kind=content["kind"],
@@ -164,16 +160,15 @@ class PushersRemoveRestServlet(RestServlet):
self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, rights="delete_pusher")
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, rights="delete_pusher")
user = requester.user
app_id = parse_string(request, "app_id", required=True)
pushkey = parse_string(request, "pushkey", required=True)
try:
- yield self.pusher_pool.remove_pusher(
+ await self.pusher_pool.remove_pusher(
app_id=app_id, pushkey=pushkey, user_id=user.to_string()
)
except StoreError as se:
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 86bbcc0eea..711d4ad304 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -714,7 +714,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
target = UserID.from_string(content["user_id"])
event_content = None
- if "reason" in content and membership_action in ["kick", "ban"]:
+ if "reason" in content:
event_content = {"reason": content["reason"]}
await self.room_member_handler.update_membership(
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 2afdbb89e5..747d46eac2 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -17,8 +17,6 @@ import base64
import hashlib
import hmac
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet
from synapse.rest.client.v2_alpha._base import client_patterns
@@ -31,9 +29,8 @@ class VoipRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(
request, self.hs.config.turn_allow_guests
)
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 8250ae0ae1..2a3f4dd58f 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -78,7 +78,7 @@ def interactive_auth_handler(orig):
"""
def wrapped(*args, **kwargs):
- res = defer.maybeDeferred(orig, *args, **kwargs)
+ res = defer.ensureDeferred(orig(*args, **kwargs))
res.addErrback(_catch_incomplete_interactive_auth)
return res
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index f26eae794c..fc240f5cf8 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -18,8 +18,6 @@ import logging
from six.moves import http_client
-from twisted.internet import defer
-
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour
@@ -67,8 +65,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
template_text=template_text,
)
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -95,7 +92,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -106,7 +103,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request
- ret = yield self.identity_handler.requestEmailToken(
+ ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email,
email,
client_secret,
@@ -115,7 +112,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
)
else:
# Send password reset emails from Synapse
- sid = yield self.identity_handler.send_threepid_validation(
+ sid = await self.identity_handler.send_threepid_validation(
email,
client_secret,
send_attempt,
@@ -153,8 +150,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
[self.config.email_password_reset_template_failure_html],
)
- @defer.inlineCallbacks
- def on_GET(self, request, medium):
+ async def on_GET(self, request, medium):
# We currently only handle threepid token submissions for email
if medium != "email":
raise SynapseError(
@@ -176,7 +172,7 @@ class PasswordResetSubmitTokenServlet(RestServlet):
# Attempt to validate a 3PID session
try:
# Mark the session as valid
- next_link = yield self.store.validate_threepid_session(
+ next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec()
)
@@ -218,8 +214,7 @@ class PasswordRestServlet(RestServlet):
self._set_password_handler = hs.get_set_password_handler()
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
# there are two possibilities here. Either the user does not have an
@@ -233,14 +228,14 @@ class PasswordRestServlet(RestServlet):
# In the second case, we require a password to confirm their identity.
if self.auth.has_access_token(request):
- requester = yield self.auth.get_user_by_req(request)
- params = yield self.auth_handler.validate_user_via_ui_auth(
+ requester = await self.auth.get_user_by_req(request)
+ params = await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
)
user_id = requester.user.to_string()
else:
requester = None
- result, params, _ = yield self.auth_handler.check_auth(
+ result, params, _ = await self.auth_handler.check_auth(
[[LoginType.EMAIL_IDENTITY]], body, self.hs.get_ip_from_request(request)
)
@@ -254,7 +249,7 @@ class PasswordRestServlet(RestServlet):
# (See add_threepid in synapse/handlers/auth.py)
threepid["address"] = threepid["address"].lower()
# if using email, we must know about the email they're authing with!
- threepid_user_id = yield self.datastore.get_user_id_by_threepid(
+ threepid_user_id = await self.datastore.get_user_id_by_threepid(
threepid["medium"], threepid["address"]
)
if not threepid_user_id:
@@ -267,7 +262,7 @@ class PasswordRestServlet(RestServlet):
assert_params_in_dict(params, ["new_password"])
new_password = params["new_password"]
- yield self._set_password_handler.set_password(user_id, new_password, requester)
+ await self._set_password_handler.set_password(user_id, new_password, requester)
return 200, {}
@@ -286,8 +281,7 @@ class DeactivateAccountRestServlet(RestServlet):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@@ -297,19 +291,19 @@ class DeactivateAccountRestServlet(RestServlet):
Codes.BAD_JSON,
)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
# allow ASes to dectivate their own users
if requester.app_service:
- yield self._deactivate_account_handler.deactivate_account(
+ await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase
)
return 200, {}
- yield self.auth_handler.validate_user_via_ui_auth(
+ await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
)
- result = yield self._deactivate_account_handler.deactivate_account(
+ result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase, id_server=body.get("id_server")
)
if result:
@@ -346,8 +340,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
template_text=template_text,
)
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -371,7 +364,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = yield self.store.get_user_id_by_threepid(
+ existing_user_id = await self.store.get_user_id_by_threepid(
"email", body["email"]
)
@@ -382,7 +375,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request
- ret = yield self.identity_handler.requestEmailToken(
+ ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email,
email,
client_secret,
@@ -391,7 +384,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
)
else:
# Send threepid validation emails from Synapse
- sid = yield self.identity_handler.send_threepid_validation(
+ sid = await self.identity_handler.send_threepid_validation(
email,
client_secret,
send_attempt,
@@ -414,8 +407,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
self.store = self.hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_dict(
body, ["client_secret", "country", "phone_number", "send_attempt"]
@@ -435,7 +427,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = yield self.store.get_user_id_by_threepid("msisdn", msisdn)
+ existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
if existing_user_id is not None:
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
@@ -450,7 +442,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
"Adding phone numbers to user account is not supported by this homeserver",
)
- ret = yield self.identity_handler.requestMsisdnToken(
+ ret = await self.identity_handler.requestMsisdnToken(
self.hs.config.account_threepid_delegate_msisdn,
country,
phone_number,
@@ -484,8 +476,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
[self.config.email_add_threepid_template_failure_html],
)
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -508,7 +499,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
# Attempt to validate a 3PID session
try:
# Mark the session as valid
- next_link = yield self.store.validate_threepid_session(
+ next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec()
)
@@ -558,8 +549,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
self.store = hs.get_datastore()
self.identity_handler = hs.get_handlers().identity_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if not self.config.account_threepid_delegate_msisdn:
raise SynapseError(
400,
@@ -571,7 +561,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
assert_params_in_dict(body, ["client_secret", "sid", "token"])
# Proxy submit_token request to msisdn threepid delegate
- response = yield self.identity_handler.proxy_msisdn_submit_token(
+ response = await self.identity_handler.proxy_msisdn_submit_token(
self.config.account_threepid_delegate_msisdn,
body["client_secret"],
body["sid"],
@@ -591,17 +581,15 @@ class ThreepidRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.datastore = self.hs.get_datastore()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
- threepids = yield self.datastore.user_get_threepids(requester.user.to_string())
+ threepids = await self.datastore.user_get_threepids(requester.user.to_string())
return 200, {"threepids": threepids}
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -615,11 +603,11 @@ class ThreepidRestServlet(RestServlet):
client_secret = threepid_creds["client_secret"]
sid = threepid_creds["sid"]
- validation_session = yield self.identity_handler.validate_threepid_session(
+ validation_session = await self.identity_handler.validate_threepid_session(
client_secret, sid
)
if validation_session:
- yield self.auth_handler.add_threepid(
+ await self.auth_handler.add_threepid(
user_id,
validation_session["medium"],
validation_session["address"],
@@ -642,9 +630,9 @@ class ThreepidAddRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ @interactive_auth_handler
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -652,11 +640,15 @@ class ThreepidAddRestServlet(RestServlet):
client_secret = body["client_secret"]
sid = body["sid"]
- validation_session = yield self.identity_handler.validate_threepid_session(
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester, body, self.hs.get_ip_from_request(request)
+ )
+
+ validation_session = await self.identity_handler.validate_threepid_session(
client_secret, sid
)
if validation_session:
- yield self.auth_handler.add_threepid(
+ await self.auth_handler.add_threepid(
user_id,
validation_session["medium"],
validation_session["address"],
@@ -678,8 +670,7 @@ class ThreepidBindRestServlet(RestServlet):
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["id_server", "sid", "client_secret"])
@@ -688,10 +679,10 @@ class ThreepidBindRestServlet(RestServlet):
client_secret = body["client_secret"]
id_access_token = body.get("id_access_token") # optional
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- yield self.identity_handler.bind_threepid(
+ await self.identity_handler.bind_threepid(
client_secret, sid, user_id, id_server, id_access_token
)
@@ -708,12 +699,11 @@ class ThreepidUnbindRestServlet(RestServlet):
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastore()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
@@ -723,7 +713,7 @@ class ThreepidUnbindRestServlet(RestServlet):
# Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past
- result = yield self.identity_handler.try_unbind_threepid(
+ result = await self.identity_handler.try_unbind_threepid(
requester.user.to_string(),
{"address": address, "medium": medium, "id_server": id_server},
)
@@ -738,16 +728,15 @@ class ThreepidDeleteRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
try:
- ret = yield self.auth_handler.delete_threepid(
+ ret = await self.auth_handler.delete_threepid(
user_id, body["medium"], body["address"], body.get("id_server")
)
except Exception:
@@ -772,9 +761,8 @@ class WhoamiRestServlet(RestServlet):
super(WhoamiRestServlet, self).__init__()
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
return 200, {"user_id": requester.user.to_string()}
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index f0db204ffa..64eb7fec3b 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -41,15 +39,14 @@ class AccountDataServlet(RestServlet):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id, account_data_type):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request)
- max_id = yield self.store.add_account_data_for_user(
+ max_id = await self.store.add_account_data_for_user(
user_id, account_data_type, body
)
@@ -57,13 +54,12 @@ class AccountDataServlet(RestServlet):
return 200, {}
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id, account_data_type):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
- event = yield self.store.get_global_account_data_by_type_for_user(
+ event = await self.store.get_global_account_data_by_type_for_user(
account_data_type, user_id
)
@@ -91,9 +87,8 @@ class RoomAccountDataServlet(RestServlet):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id, room_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id, room_id, account_data_type):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -106,7 +101,7 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers",
)
- max_id = yield self.store.add_account_data_to_room(
+ max_id = await self.store.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
@@ -114,13 +109,12 @@ class RoomAccountDataServlet(RestServlet):
return 200, {}
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, room_id, account_data_type):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id, room_id, account_data_type):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get account data for other users.")
- event = yield self.store.get_account_data_for_room_and_type(
+ event = await self.store.get_account_data_for_room_and_type(
user_id, room_id, account_data_type
)
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index 33f6a23028..2f10fa64e2 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet
@@ -45,13 +43,12 @@ class AccountValidityRenewServlet(RestServlet):
self.success_html = hs.config.account_validity.account_renewed_html_content
self.failure_html = hs.config.account_validity.invalid_token_html_content
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
if b"token" not in request.args:
raise SynapseError(400, "Missing renewal token")
renewal_token = request.args[b"token"][0]
- token_valid = yield self.account_activity_handler.renew_account(
+ token_valid = await self.account_activity_handler.renew_account(
renewal_token.decode("utf8")
)
@@ -67,7 +64,6 @@ class AccountValidityRenewServlet(RestServlet):
request.setHeader(b"Content-Length", b"%d" % (len(response),))
request.write(response.encode("utf8"))
finish_request(request)
- defer.returnValue(None)
class AccountValiditySendMailServlet(RestServlet):
@@ -85,18 +81,17 @@ class AccountValiditySendMailServlet(RestServlet):
self.auth = hs.get_auth()
self.account_validity = self.hs.config.account_validity
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if not self.account_validity.renew_by_email_enabled:
raise AuthError(
403, "Account renewal via email is disabled on this server."
)
- requester = yield self.auth.get_user_by_req(request, allow_expired=True)
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
- yield self.account_activity_handler.send_renewal_email_to_user(user_id)
+ await self.account_activity_handler.send_renewal_email_to_user(user_id)
- defer.returnValue((200, {}))
+ return 200, {}
def register_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index f21aff39e5..7a256b6ecb 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
from synapse.api.urls import CLIENT_API_PREFIX
@@ -171,8 +169,7 @@ class AuthRestServlet(RestServlet):
else:
raise SynapseError(404, "Unknown auth stage type")
- @defer.inlineCallbacks
- def on_POST(self, request, stagetype):
+ async def on_POST(self, request, stagetype):
session = parse_string(request, "session")
if not session:
@@ -186,7 +183,7 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
- success = yield self.auth_handler.add_oob_auth(
+ success = await self.auth_handler.add_oob_auth(
LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
)
@@ -215,7 +212,7 @@ class AuthRestServlet(RestServlet):
session = request.args["session"][0]
authdict = {"session": session}
- success = yield self.auth_handler.add_oob_auth(
+ success = await self.auth_handler.add_oob_auth(
LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
)
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index acd58af193..fe9d019c44 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -14,8 +14,6 @@
# limitations under the License.
import logging
-from twisted.internet import defer
-
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
@@ -40,10 +38,9 @@ class CapabilitiesRestServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- user = yield self.store.get_user_by_id(requester.user.to_string())
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ user = await self.store.get_user_by_id(requester.user.to_string())
change_password = bool(user["password_hash"])
response = {
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 26d0235208..94ff73f384 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api import errors
from synapse.http.servlet import (
RestServlet,
@@ -42,10 +40,9 @@ class DevicesRestServlet(RestServlet):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- devices = yield self.device_handler.get_devices_by_user(
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
)
return 200, {"devices": devices}
@@ -67,9 +64,8 @@ class DeleteDevicesRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
try:
body = parse_json_object_from_request(request)
@@ -84,11 +80,11 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
- yield self.auth_handler.validate_user_via_ui_auth(
+ await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
)
- yield self.device_handler.delete_devices(
+ await self.device_handler.delete_devices(
requester.user.to_string(), body["devices"]
)
return 200, {}
@@ -108,18 +104,16 @@ class DeviceRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.auth_handler = hs.get_auth_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
- device = yield self.device_handler.get_device(
+ async def on_GET(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ device = await self.device_handler.get_device(
requester.user.to_string(), device_id
)
return 200, device
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_DELETE(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request)
try:
body = parse_json_object_from_request(request)
@@ -132,19 +126,18 @@ class DeviceRestServlet(RestServlet):
else:
raise
- yield self.auth_handler.validate_user_via_ui_auth(
+ await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
)
- yield self.device_handler.delete_device(requester.user.to_string(), device_id)
+ await self.device_handler.delete_device(requester.user.to_string(), device_id)
return 200, {}
- @defer.inlineCallbacks
- def on_PUT(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_PUT(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
body = parse_json_object_from_request(request)
- yield self.device_handler.update_device(
+ await self.device_handler.update_device(
requester.user.to_string(), device_id, body
)
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index 17a8bc7366..b28da017cd 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
@@ -35,10 +33,9 @@ class GetFilterRestServlet(RestServlet):
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, filter_id):
+ async def on_GET(self, request, user_id, filter_id):
target_user = UserID.from_string(user_id)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if target_user != requester.user:
raise AuthError(403, "Cannot get filters for other users")
@@ -52,7 +49,7 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id")
try:
- filter_collection = yield self.filtering.get_user_filter(
+ filter_collection = await self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id
)
except StoreError as e:
@@ -72,11 +69,10 @@ class CreateFilterRestServlet(RestServlet):
self.auth = hs.get_auth()
self.filtering = hs.get_filtering()
- @defer.inlineCallbacks
- def on_POST(self, request, user_id):
+ async def on_POST(self, request, user_id):
target_user = UserID.from_string(user_id)
- requester = yield self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request)
if target_user != requester.user:
raise AuthError(403, "Cannot create filters for other users")
@@ -87,7 +83,7 @@ class CreateFilterRestServlet(RestServlet):
content = parse_json_object_from_request(request)
set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit)
- filter_id = yield self.filtering.add_user_filter(
+ filter_id = await self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_filter=content
)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 999a0fa80c..d84a6d7e11 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
@@ -38,24 +36,22 @@ class GroupServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- group_description = yield self.groups_handler.get_group_profile(
+ group_description = await self.groups_handler.get_group_profile(
group_id, requester_user_id
)
return 200, group_description
- @defer.inlineCallbacks
- def on_POST(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- yield self.groups_handler.update_group_profile(
+ await self.groups_handler.update_group_profile(
group_id, requester_user_id, content
)
@@ -74,12 +70,11 @@ class GroupSummaryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- get_group_summary = yield self.groups_handler.get_group_summary(
+ get_group_summary = await self.groups_handler.get_group_summary(
group_id, requester_user_id
)
@@ -106,13 +101,12 @@ class GroupSummaryRoomsCatServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, category_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, category_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_summary_room(
+ resp = await self.groups_handler.update_group_summary_room(
group_id,
requester_user_id,
room_id=room_id,
@@ -122,12 +116,11 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, category_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, category_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_summary_room(
+ resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
@@ -148,35 +141,32 @@ class GroupCategoryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id, category_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id, category_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_category(
+ category = await self.groups_handler.get_group_category(
group_id, requester_user_id, category_id=category_id
)
return 200, category
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, category_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, category_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_category(
+ resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content
)
return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, category_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, category_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_category(
+ resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id
)
@@ -195,12 +185,11 @@ class GroupCategoriesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_categories(
+ category = await self.groups_handler.get_group_categories(
group_id, requester_user_id
)
@@ -219,35 +208,32 @@ class GroupRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id, role_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id, role_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_role(
+ category = await self.groups_handler.get_group_role(
group_id, requester_user_id, role_id=role_id
)
return 200, category
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, role_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, role_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_role(
+ resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content
)
return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, role_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, role_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_role(
+ resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id
)
@@ -266,12 +252,11 @@ class GroupRolesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- category = yield self.groups_handler.get_group_roles(
+ category = await self.groups_handler.get_group_roles(
group_id, requester_user_id
)
@@ -298,13 +283,12 @@ class GroupSummaryUsersRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, role_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, role_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- resp = yield self.groups_handler.update_group_summary_user(
+ resp = await self.groups_handler.update_group_summary_user(
group_id,
requester_user_id,
user_id=user_id,
@@ -314,12 +298,11 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, role_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, role_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- resp = yield self.groups_handler.delete_group_summary_user(
+ resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
@@ -338,12 +321,11 @@ class GroupRoomServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_rooms_in_group(
+ result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
@@ -362,12 +344,11 @@ class GroupUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_users_in_group(
+ result = await self.groups_handler.get_users_in_group(
group_id, requester_user_id
)
@@ -386,12 +367,11 @@ class GroupInvitedUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_invited_users_in_group(
+ result = await self.groups_handler.get_invited_users_in_group(
group_id, requester_user_id
)
@@ -409,14 +389,13 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.set_group_join_policy(
+ result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content
)
@@ -436,9 +415,8 @@ class GroupCreateServlet(RestServlet):
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
# TODO: Create group on remote server
@@ -446,7 +424,7 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string()
- result = yield self.groups_handler.create_group(
+ result = await self.groups_handler.create_group(
group_id, requester_user_id, content
)
@@ -467,24 +445,22 @@ class GroupAdminRoomsServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.add_room_to_group(
+ result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content
)
return 200, result
- @defer.inlineCallbacks
- def on_DELETE(self, request, group_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, group_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.remove_room_from_group(
+ result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id
)
@@ -506,13 +482,12 @@ class GroupAdminRoomsConfigServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, room_id, config_key):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, room_id, config_key):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.update_room_in_group(
+ result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content
)
@@ -535,14 +510,13 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
config = content.get("config", {})
- result = yield self.groups_handler.invite(
+ result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config
)
@@ -563,13 +537,12 @@ class GroupAdminUsersKickServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id, user_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.remove_user_from_group(
+ result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
@@ -588,13 +561,12 @@ class GroupSelfLeaveServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.remove_user_from_group(
+ result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content
)
@@ -613,13 +585,12 @@ class GroupSelfJoinServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.join_group(
+ result = await self.groups_handler.join_group(
group_id, requester_user_id, content
)
@@ -638,13 +609,12 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
- result = yield self.groups_handler.accept_invite(
+ result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content
)
@@ -663,14 +633,13 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_PUT(self, request, group_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, group_id):
+ requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
publicise = content["publicise"]
- yield self.store.update_group_publicity(group_id, requester_user_id, publicise)
+ await self.store.update_group_publicity(group_id, requester_user_id, publicise)
return 200, {}
@@ -688,11 +657,10 @@ class PublicisedGroupsForUserServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, user_id):
+ await self.auth.get_user_by_req(request, allow_guest=True)
- result = yield self.groups_handler.get_publicised_groups_for_user(user_id)
+ result = await self.groups_handler.get_publicised_groups_for_user(user_id)
return 200, result
@@ -710,14 +678,13 @@ class PublicisedGroupsForUsersServlet(RestServlet):
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
user_ids = content["user_ids"]
- result = yield self.groups_handler.bulk_get_publicised_groups(user_ids)
+ result = await self.groups_handler.bulk_get_publicised_groups(user_ids)
return 200, result
@@ -734,12 +701,11 @@ class GroupsForUserServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- result = yield self.groups_handler.get_joined_groups(requester_user_id)
+ result = await self.groups_handler.get_joined_groups(requester_user_id)
return 200, result
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 341567ae21..f7ed4daf90 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -71,9 +69,8 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
@trace(opname="upload_keys")
- @defer.inlineCallbacks
- def on_POST(self, request, device_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request, device_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -103,7 +100,7 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating"
)
- result = yield self.e2e_keys_handler.upload_keys_for_user(
+ result = await self.e2e_keys_handler.upload_keys_for_user(
user_id, device_id, body
)
return 200, result
@@ -154,13 +151,12 @@ class KeyQueryServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.e2e_keys_handler.query_devices(body, timeout, user_id)
+ result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
return 200, result
@@ -185,9 +181,8 @@ class KeyChangesServlet(RestServlet):
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from")
set_tag("from", from_token_string)
@@ -200,7 +195,7 @@ class KeyChangesServlet(RestServlet):
user_id = requester.user.to_string()
- results = yield self.device_handler.get_user_ids_changed(user_id, from_token)
+ results = await self.device_handler.get_user_ids_changed(user_id, from_token)
return 200, results
@@ -231,12 +226,11 @@ class OneTimeKeyServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = yield self.e2e_keys_handler.claim_one_time_keys(body, timeout)
+ result = await self.e2e_keys_handler.claim_one_time_keys(body, timeout)
return 200, result
@@ -263,17 +257,16 @@ class SigningKeyUploadServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
- yield self.auth_handler.validate_user_via_ui_auth(
+ await self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request)
)
- result = yield self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
+ result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
return 200, result
@@ -315,13 +308,12 @@ class SignaturesUploadServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_POST(self, request):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
- result = yield self.e2e_keys_handler.upload_signatures_for_device_keys(
+ result = await self.e2e_keys_handler.upload_signatures_for_device_keys(
user_id, body
)
return 200, result
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index 10c1ad5b07..aa911d75ee 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string
@@ -35,9 +33,8 @@ class NotificationsServlet(RestServlet):
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
from_token = parse_string(request, "from", required=False)
@@ -46,16 +43,16 @@ class NotificationsServlet(RestServlet):
limit = min(limit, 500)
- push_actions = yield self.store.get_push_actions_for_user(
+ push_actions = await self.store.get_push_actions_for_user(
user_id, from_token, limit, only_highlight=(only == "highlight")
)
- receipts_by_room = yield self.store.get_receipts_for_user_with_orderings(
+ receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
user_id, "m.read"
)
notif_event_ids = [pa["event_id"] for pa in push_actions]
- notif_events = yield self.store.get_events(notif_event_ids)
+ notif_events = await self.store.get_events(notif_event_ids)
returned_push_actions = []
@@ -68,7 +65,7 @@ class NotificationsServlet(RestServlet):
"actions": pa["actions"],
"ts": pa["received_ts"],
"event": (
- yield self._event_serializer.serialize_event(
+ await self._event_serializer.serialize_event(
notif_events[pa["event_id"]],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index b4925c0f59..6ae9a5a8e9 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string
@@ -68,9 +66,8 @@ class IdTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.server_name = hs.config.server_name
- @defer.inlineCallbacks
- def on_POST(self, request, user_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot request tokens for other users.")
@@ -81,7 +78,7 @@ class IdTokenServlet(RestServlet):
token = random_string(24)
ts_valid_until_ms = self.clock.time_msec() + self.EXPIRES_MS
- yield self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
+ await self.store.insert_open_id_token(token, ts_valid_until_ms, user_id)
return (
200,
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 91db923814..66de16a1fa 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -20,8 +20,6 @@ from typing import List, Union
from six import string_types
-from twisted.internet import defer
-
import synapse
import synapse.types
from synapse.api.constants import LoginType
@@ -102,8 +100,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
template_text=template_text,
)
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF:
if self.hs.config.local_threepid_handling_disabled_due_to_email_config:
logger.warning(
@@ -129,7 +126,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", body["email"]
)
@@ -140,7 +137,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
assert self.hs.config.account_threepid_delegate_email
# Have the configured identity server handle the request
- ret = yield self.identity_handler.requestEmailToken(
+ ret = await self.identity_handler.requestEmailToken(
self.hs.config.account_threepid_delegate_email,
email,
client_secret,
@@ -149,7 +146,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
else:
# Send registration emails from Synapse
- sid = yield self.identity_handler.send_threepid_validation(
+ sid = await self.identity_handler.send_threepid_validation(
email,
client_secret,
send_attempt,
@@ -175,8 +172,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
self.hs = hs
self.identity_handler = hs.get_handlers().identity_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
assert_params_in_dict(
@@ -197,7 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existing_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
@@ -215,7 +211,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
400, "Registration by phone number is not supported on this homeserver"
)
- ret = yield self.identity_handler.requestMsisdnToken(
+ ret = await self.identity_handler.requestMsisdnToken(
self.hs.config.account_threepid_delegate_msisdn,
country,
phone_number,
@@ -258,8 +254,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
[self.config.email_registration_template_failure_html],
)
- @defer.inlineCallbacks
- def on_GET(self, request, medium):
+ async def on_GET(self, request, medium):
if medium != "email":
raise SynapseError(
400, "This medium is currently not supported for registration"
@@ -280,7 +275,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
# Attempt to validate a 3PID session
try:
# Mark the session as valid
- next_link = yield self.store.validate_threepid_session(
+ next_link = await self.store.validate_threepid_session(
sid, client_secret, token, self.clock.time_msec()
)
@@ -338,8 +333,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
),
)
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
if not self.hs.config.enable_registration:
raise SynapseError(
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
@@ -347,11 +341,11 @@ class UsernameAvailabilityRestServlet(RestServlet):
ip = self.hs.get_ip_from_request(request)
with self.ratelimiter.ratelimit(ip) as wait_deferred:
- yield wait_deferred
+ await wait_deferred
username = parse_string(request, "username", required=True)
- yield self.registration_handler.check_username(username)
+ await self.registration_handler.check_username(username)
return 200, {"available": True}
@@ -382,8 +376,7 @@ class RegisterRestServlet(RestServlet):
)
@interactive_auth_handler
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
body = parse_json_object_from_request(request)
client_addr = request.getClientIP()
@@ -408,7 +401,7 @@ class RegisterRestServlet(RestServlet):
kind = request.args[b"kind"][0]
if kind == b"guest":
- ret = yield self._do_guest_registration(body, address=client_addr)
+ ret = await self._do_guest_registration(body, address=client_addr)
return ret
elif kind != b"user":
raise UnrecognizedRequestError(
@@ -435,7 +428,7 @@ class RegisterRestServlet(RestServlet):
appservice = None
if self.auth.has_access_token(request):
- appservice = yield self.auth.get_appservice_by_req(request)
+ appservice = await self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes which have completely
# different registration flows to normal users
@@ -455,7 +448,7 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types):
- result = yield self._do_appservice_registration(
+ result = await self._do_appservice_registration(
desired_username, access_token, body
)
return 200, result # we throw for non 200 responses
@@ -495,13 +488,13 @@ class RegisterRestServlet(RestServlet):
)
if desired_username is not None:
- yield self.registration_handler.check_username(
+ await self.registration_handler.check_username(
desired_username,
guest_access_token=guest_access_token,
assigned_user_id=registered_user_id,
)
- auth_result, params, session_id = yield self.auth_handler.check_auth(
+ auth_result, params, session_id = await self.auth_handler.check_auth(
self._registration_flows, body, self.hs.get_ip_from_request(request)
)
@@ -557,7 +550,7 @@ class RegisterRestServlet(RestServlet):
medium = auth_result[login_type]["medium"]
address = auth_result[login_type]["address"]
- existing_user_id = yield self.store.get_user_id_by_threepid(
+ existing_user_id = await self.store.get_user_id_by_threepid(
medium, address
)
@@ -568,7 +561,7 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_IN_USE,
)
- registered_user_id = yield self.registration_handler.register_user(
+ registered_user_id = await self.registration_handler.register_user(
localpart=desired_username,
password=new_password,
guest_access_token=guest_access_token,
@@ -581,7 +574,7 @@ class RegisterRestServlet(RestServlet):
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
- yield self.store.upsert_monthly_active_user(registered_user_id)
+ await self.store.upsert_monthly_active_user(registered_user_id)
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
@@ -591,12 +584,12 @@ class RegisterRestServlet(RestServlet):
registered = True
- return_dict = yield self._create_registration_details(
+ return_dict = await self._create_registration_details(
registered_user_id, params
)
if registered:
- yield self.registration_handler.post_registration_actions(
+ await self.registration_handler.post_registration_actions(
user_id=registered_user_id,
auth_result=auth_result,
access_token=return_dict.get("access_token"),
@@ -607,15 +600,13 @@ class RegisterRestServlet(RestServlet):
def on_OPTIONS(self, _):
return 200, {}
- @defer.inlineCallbacks
- def _do_appservice_registration(self, username, as_token, body):
- user_id = yield self.registration_handler.appservice_register(
+ async def _do_appservice_registration(self, username, as_token, body):
+ user_id = await self.registration_handler.appservice_register(
username, as_token
)
- return (yield self._create_registration_details(user_id, body))
+ return await self._create_registration_details(user_id, body)
- @defer.inlineCallbacks
- def _create_registration_details(self, user_id, params):
+ async def _create_registration_details(self, user_id, params):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -631,18 +622,17 @@ class RegisterRestServlet(RestServlet):
if not params.get("inhibit_login", False):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
+ device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=False
)
result.update({"access_token": access_token, "device_id": device_id})
return result
- @defer.inlineCallbacks
- def _do_guest_registration(self, params, address=None):
+ async def _do_guest_registration(self, params, address=None):
if not self.hs.config.allow_guest_access:
raise SynapseError(403, "Guest access is disabled")
- user_id = yield self.registration_handler.register_user(
+ user_id = await self.registration_handler.register_user(
make_guest=True, address=address
)
@@ -650,7 +640,7 @@ class RegisterRestServlet(RestServlet):
# we have nowhere to store it.
device_id = synapse.api.auth.GUEST_DEVICE_ID
initial_display_name = params.get("initial_device_display_name")
- device_id, access_token = yield self.registration_handler.register_device(
+ device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name, is_guest=True
)
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 040b37c504..9be9a34b91 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -21,8 +21,6 @@ any time to reflect changes in the MSC.
import logging
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
@@ -86,11 +84,10 @@ class RelationSendServlet(RestServlet):
request, self.on_PUT_or_POST, request, *args, **kwargs
)
- @defer.inlineCallbacks
- def on_PUT_or_POST(
+ async def on_PUT_or_POST(
self, request, room_id, parent_id, relation_type, event_type, txn_id=None
):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
if event_type == EventTypes.Member:
# Add relations to a membership is meaningless, so we just deny it
@@ -114,7 +111,7 @@ class RelationSendServlet(RestServlet):
"sender": requester.user.to_string(),
}
- event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ event = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict=event_dict, txn_id=txn_id
)
@@ -140,17 +137,18 @@ class RelationPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(
+ self, request, room_id, parent_id, relation_type=None, event_type=None
+ ):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- yield self.auth.check_in_room_or_world_readable(
+ await self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string()
)
# This gets the original event and checks that a) the event exists and
# b) the user is allowed to view it.
- event = yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ event = await self.event_handler.get_event(requester.user, room_id, parent_id)
limit = parse_integer(request, "limit", default=5)
from_token = parse_string(request, "from")
@@ -167,7 +165,7 @@ class RelationPaginationServlet(RestServlet):
if to_token:
to_token = RelationPaginationToken.from_string(to_token)
- pagination_chunk = yield self.store.get_relations_for_event(
+ pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
relation_type=relation_type,
event_type=event_type,
@@ -176,7 +174,7 @@ class RelationPaginationServlet(RestServlet):
to_token=to_token,
)
- events = yield self.store.get_events_as_list(
+ events = await self.store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
)
@@ -184,13 +182,13 @@ class RelationPaginationServlet(RestServlet):
# We set bundle_aggregations to False when retrieving the original
# event because we want the content before relations were applied to
# it.
- original_event = yield self._event_serializer.serialize_event(
+ original_event = await self._event_serializer.serialize_event(
event, now, bundle_aggregations=False
)
# Similarly, we don't allow relations to be applied to relations, so we
# return the original relations without any aggregations on top of them
# here.
- events = yield self._event_serializer.serialize_events(
+ events = await self._event_serializer.serialize_events(
events, now, bundle_aggregations=False
)
@@ -232,17 +230,18 @@ class RelationAggregationPaginationServlet(RestServlet):
self.store = hs.get_datastore()
self.event_handler = hs.get_event_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(
+ self, request, room_id, parent_id, relation_type=None, event_type=None
+ ):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- yield self.auth.check_in_room_or_world_readable(
+ await self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string()
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
- event = yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ event = await self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type not in (RelationTypes.ANNOTATION, None):
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -262,7 +261,7 @@ class RelationAggregationPaginationServlet(RestServlet):
if to_token:
to_token = AggregationPaginationToken.from_string(to_token)
- pagination_chunk = yield self.store.get_aggregation_groups_for_event(
+ pagination_chunk = await self.store.get_aggregation_groups_for_event(
event_id=parent_id,
event_type=event_type,
limit=limit,
@@ -311,17 +310,16 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
self._event_serializer = hs.get_event_client_serializer()
self.event_handler = hs.get_event_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
- yield self.auth.check_in_room_or_world_readable(
+ await self.auth.check_in_room_or_world_readable(
room_id, requester.user.to_string()
)
# This checks that a) the event exists and b) the user is allowed to
# view it.
- yield self.event_handler.get_event(requester.user, room_id, parent_id)
+ await self.event_handler.get_event(requester.user, room_id, parent_id)
if relation_type != RelationTypes.ANNOTATION:
raise SynapseError(400, "Relation type must be 'annotation'")
@@ -336,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
if to_token:
to_token = RelationPaginationToken.from_string(to_token)
- result = yield self.store.get_relations_for_event(
+ result = await self.store.get_relations_for_event(
event_id=parent_id,
relation_type=relation_type,
event_type=event_type,
@@ -346,12 +344,12 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
to_token=to_token,
)
- events = yield self.store.get_events_as_list(
+ events = await self.store.get_events_as_list(
[c["event_id"] for c in result.chunk]
)
now = self.clock.time_msec()
- events = yield self._event_serializer.serialize_events(events, now)
+ events = await self._event_serializer.serialize_events(events, now)
return_value = result.to_dict()
return_value["chunk"] = events
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index e7449864cd..f067b5edac 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -18,8 +18,6 @@ import logging
from six import string_types
from six.moves import http_client
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -42,9 +40,8 @@ class ReportEventRestServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -63,7 +60,7 @@ class ReportEventRestServlet(RestServlet):
Codes.BAD_JSON,
)
- yield self.store.add_event_report(
+ await self.store.add_event_report(
room_id=room_id,
event_id=event_id,
user_id=user_id,
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index d596786430..38952a1d27 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -43,8 +41,7 @@ class RoomKeysServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- @defer.inlineCallbacks
- def on_PUT(self, request, room_id, session_id):
+ async def on_PUT(self, request, room_id, session_id):
"""
Uploads one or more encrypted E2E room keys for backup purposes.
room_id: the ID of the room the keys are for (optional)
@@ -123,7 +120,7 @@ class RoomKeysServlet(RestServlet):
}
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
version = parse_string(request, "version")
@@ -134,11 +131,10 @@ class RoomKeysServlet(RestServlet):
if room_id:
body = {"rooms": {room_id: body}}
- yield self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
- return 200, {}
+ ret = await self.e2e_room_keys_handler.upload_room_keys(user_id, version, body)
+ return 200, ret
- @defer.inlineCallbacks
- def on_GET(self, request, room_id, session_id):
+ async def on_GET(self, request, room_id, session_id):
"""
Retrieves one or more encrypted E2E room keys for backup purposes.
Symmetric with the PUT version of the API.
@@ -190,11 +186,11 @@ class RoomKeysServlet(RestServlet):
}
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version")
- room_keys = yield self.e2e_room_keys_handler.get_room_keys(
+ room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
)
@@ -220,8 +216,7 @@ class RoomKeysServlet(RestServlet):
return 200, room_keys
- @defer.inlineCallbacks
- def on_DELETE(self, request, room_id, session_id):
+ async def on_DELETE(self, request, room_id, session_id):
"""
Deletes one or more encrypted E2E room keys for a user for backup purposes.
@@ -235,14 +230,14 @@ class RoomKeysServlet(RestServlet):
the version must already have been created via the /change_secret API.
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version")
- yield self.e2e_room_keys_handler.delete_room_keys(
+ ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id
)
- return 200, {}
+ return 200, ret
class RoomKeysNewVersionServlet(RestServlet):
@@ -257,8 +252,7 @@ class RoomKeysNewVersionServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
"""
Create a new backup version for this user's room_keys with the given
info. The version is allocated by the server and returned to the user
@@ -288,11 +282,11 @@ class RoomKeysNewVersionServlet(RestServlet):
"version": 12345
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
info = parse_json_object_from_request(request)
- new_version = yield self.e2e_room_keys_handler.create_version(user_id, info)
+ new_version = await self.e2e_room_keys_handler.create_version(user_id, info)
return 200, {"version": new_version}
# we deliberately don't have a PUT /version, as these things really should
@@ -311,8 +305,7 @@ class RoomKeysVersionServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, version):
+ async def on_GET(self, request, version):
"""
Retrieve the version information about a given version of the user's
room_keys backup. If the version part is missing, returns info about the
@@ -330,18 +323,17 @@ class RoomKeysVersionServlet(RestServlet):
"auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
try:
- info = yield self.e2e_room_keys_handler.get_version_info(user_id, version)
+ info = await self.e2e_room_keys_handler.get_version_info(user_id, version)
except SynapseError as e:
if e.code == 404:
raise SynapseError(404, "No backup found", Codes.NOT_FOUND)
return 200, info
- @defer.inlineCallbacks
- def on_DELETE(self, request, version):
+ async def on_DELETE(self, request, version):
"""
Delete the information about a given version of the user's
room_keys backup. If the version part is missing, deletes the most
@@ -354,14 +346,13 @@ class RoomKeysVersionServlet(RestServlet):
if version is None:
raise SynapseError(400, "No version specified to delete", Codes.NOT_FOUND)
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
- yield self.e2e_room_keys_handler.delete_version(user_id, version)
+ await self.e2e_room_keys_handler.delete_version(user_id, version)
return 200, {}
- @defer.inlineCallbacks
- def on_PUT(self, request, version):
+ async def on_PUT(self, request, version):
"""
Update the information about a given version of the user's room_keys backup.
@@ -382,7 +373,7 @@ class RoomKeysVersionServlet(RestServlet):
Content-Type: application/json
{}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
info = parse_json_object_from_request(request)
@@ -391,7 +382,7 @@ class RoomKeysVersionServlet(RestServlet):
400, "No version specified to update", Codes.MISSING_PARAM
)
- yield self.e2e_room_keys_handler.update_version(user_id, version, info)
+ await self.e2e_room_keys_handler.update_version(user_id, version, info)
return 200, {}
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index d2c3316eb7..ca97330797 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import (
@@ -59,9 +57,8 @@ class RoomUpgradeRestServlet(RestServlet):
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
- @defer.inlineCallbacks
- def on_POST(self, request, room_id):
- requester = yield self._auth.get_user_by_req(request)
+ async def on_POST(self, request, room_id):
+ requester = await self._auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ("new_version",))
@@ -74,7 +71,7 @@ class RoomUpgradeRestServlet(RestServlet):
Codes.UNSUPPORTED_ROOM_VERSION,
)
- new_room_id = yield self._room_creation_handler.upgrade_room(
+ new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index d90e52ed1a..501b52fb6c 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
from synapse.logging.opentracing import set_tag, trace
@@ -51,15 +49,14 @@ class SendToDeviceRestServlet(servlet.RestServlet):
request, self._put, request, message_type, txn_id
)
- @defer.inlineCallbacks
- def _put(self, request, message_type, txn_id):
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def _put(self, request, message_type, txn_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
sender_user_id = requester.user.to_string()
- yield self.device_message_handler.send_device_message(
+ await self.device_message_handler.send_device_message(
sender_user_id, message_type, content["messages"]
)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index ccd8b17b23..d8292ce29f 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -18,8 +18,6 @@ import logging
from canonicaljson import json
-from twisted.internet import defer
-
from synapse.api.constants import PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@@ -87,8 +85,7 @@ class SyncRestServlet(RestServlet):
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
- @defer.inlineCallbacks
- def on_GET(self, request):
+ async def on_GET(self, request):
if b"from" in request.args:
# /events used to use 'from', but /sync uses 'since'.
# Lets be helpful and whine if we see a 'from'.
@@ -96,7 +93,7 @@ class SyncRestServlet(RestServlet):
400, "'from' is not a valid query parameter. Did you mean 'since'?"
)
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = requester.user
device_id = requester.device_id
@@ -138,7 +135,7 @@ class SyncRestServlet(RestServlet):
filter_collection = FilterCollection(filter_object)
else:
try:
- filter_collection = yield self.filtering.get_user_filter(
+ filter_collection = await self.filtering.get_user_filter(
user.localpart, filter_id
)
except StoreError as err:
@@ -161,20 +158,20 @@ class SyncRestServlet(RestServlet):
since_token = None
# send any outstanding server notices to the user.
- yield self._server_notices_sender.on_user_syncing(user.to_string())
+ await self._server_notices_sender.on_user_syncing(user.to_string())
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
- yield self.presence_handler.set_state(
+ await self.presence_handler.set_state(
user, {"presence": set_presence}, True
)
- context = yield self.presence_handler.user_syncing(
+ context = await self.presence_handler.user_syncing(
user.to_string(), affect_presence=affect_presence
)
with context:
- sync_result = yield self.sync_handler.wait_for_sync_for_user(
+ sync_result = await self.sync_handler.wait_for_sync_for_user(
sync_config,
since_token=since_token,
timeout=timeout,
@@ -182,14 +179,13 @@ class SyncRestServlet(RestServlet):
)
time_now = self.clock.time_msec()
- response_content = yield self.encode_response(
+ response_content = await self.encode_response(
time_now, sync_result, requester.access_token_id, filter_collection
)
return 200, response_content
- @defer.inlineCallbacks
- def encode_response(self, time_now, sync_result, access_token_id, filter):
+ async def encode_response(self, time_now, sync_result, access_token_id, filter):
if filter.event_format == "client":
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == "federation":
@@ -197,7 +193,7 @@ class SyncRestServlet(RestServlet):
else:
raise Exception("Unknown event format %s" % (filter.event_format,))
- joined = yield self.encode_joined(
+ joined = await self.encode_joined(
sync_result.joined,
time_now,
access_token_id,
@@ -205,11 +201,11 @@ class SyncRestServlet(RestServlet):
event_formatter,
)
- invited = yield self.encode_invited(
+ invited = await self.encode_invited(
sync_result.invited, time_now, access_token_id, event_formatter
)
- archived = yield self.encode_archived(
+ archived = await self.encode_archived(
sync_result.archived,
time_now,
access_token_id,
@@ -250,8 +246,9 @@ class SyncRestServlet(RestServlet):
]
}
- @defer.inlineCallbacks
- def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter):
+ async def encode_joined(
+ self, rooms, time_now, token_id, event_fields, event_formatter
+ ):
"""
Encode the joined rooms in a sync result
@@ -272,7 +269,7 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = yield self.encode_room(
+ joined[room.room_id] = await self.encode_room(
room,
time_now,
token_id,
@@ -283,8 +280,7 @@ class SyncRestServlet(RestServlet):
return joined
- @defer.inlineCallbacks
- def encode_invited(self, rooms, time_now, token_id, event_formatter):
+ async def encode_invited(self, rooms, time_now, token_id, event_formatter):
"""
Encode the invited rooms in a sync result
@@ -304,7 +300,7 @@ class SyncRestServlet(RestServlet):
"""
invited = {}
for room in rooms:
- invite = yield self._event_serializer.serialize_event(
+ invite = await self._event_serializer.serialize_event(
room.invite,
time_now,
token_id=token_id,
@@ -319,8 +315,9 @@ class SyncRestServlet(RestServlet):
return invited
- @defer.inlineCallbacks
- def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter):
+ async def encode_archived(
+ self, rooms, time_now, token_id, event_fields, event_formatter
+ ):
"""
Encode the archived rooms in a sync result
@@ -341,7 +338,7 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = yield self.encode_room(
+ joined[room.room_id] = await self.encode_room(
room,
time_now,
token_id,
@@ -352,8 +349,7 @@ class SyncRestServlet(RestServlet):
return joined
- @defer.inlineCallbacks
- def encode_room(
+ async def encode_room(
self, room, time_now, token_id, joined, only_fields, event_formatter
):
"""
@@ -401,8 +397,8 @@ class SyncRestServlet(RestServlet):
event.room_id,
)
- serialized_state = yield serialize(state_events)
- serialized_timeline = yield serialize(timeline_events)
+ serialized_state = await serialize(state_events)
+ serialized_timeline = await serialize(timeline_events)
account_data = room.account_data
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index 3b555669a0..a3f12e8a77 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -37,13 +35,12 @@ class TagListServlet(RestServlet):
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def on_GET(self, request, user_id, room_id):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_GET(self, request, user_id, room_id):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot get tags for other users.")
- tags = yield self.store.get_tags_for_room(user_id, room_id)
+ tags = await self.store.get_tags_for_room(user_id, room_id)
return 200, {"tags": tags}
@@ -64,27 +61,25 @@ class TagServlet(RestServlet):
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
- @defer.inlineCallbacks
- def on_PUT(self, request, user_id, room_id, tag):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_PUT(self, request, user_id, room_id, tag):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
body = parse_json_object_from_request(request)
- max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
+ max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
return 200, {}
- @defer.inlineCallbacks
- def on_DELETE(self, request, user_id, room_id, tag):
- requester = yield self.auth.get_user_by_req(request)
+ async def on_DELETE(self, request, user_id, room_id, tag):
+ requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
- max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
+ max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 2e8d672471..23709960ad 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -16,8 +16,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet
@@ -35,11 +33,10 @@ class ThirdPartyProtocolsServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request):
+ await self.auth.get_user_by_req(request, allow_guest=True)
- protocols = yield self.appservice_handler.get_3pe_protocols()
+ protocols = await self.appservice_handler.get_3pe_protocols()
return 200, protocols
@@ -52,11 +49,10 @@ class ThirdPartyProtocolServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, protocol):
+ await self.auth.get_user_by_req(request, allow_guest=True)
- protocols = yield self.appservice_handler.get_3pe_protocols(
+ protocols = await self.appservice_handler.get_3pe_protocols(
only_protocol=protocol
)
if protocol in protocols:
@@ -74,14 +70,13 @@ class ThirdPartyUserServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, protocol):
+ await self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop(b"access_token", None)
- results = yield self.appservice_handler.query_3pe(
+ results = await self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
@@ -97,14 +92,13 @@ class ThirdPartyLocationServlet(RestServlet):
self.auth = hs.get_auth()
self.appservice_handler = hs.get_application_service_handler()
- @defer.inlineCallbacks
- def on_GET(self, request, protocol):
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ async def on_GET(self, request, protocol):
+ await self.auth.get_user_by_req(request, allow_guest=True)
fields = request.args
fields.pop(b"access_token", None)
- results = yield self.appservice_handler.query_3pe(
+ results = await self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 2da0f55811..83f3b6b70a 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
-
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet
@@ -32,8 +30,7 @@ class TokenRefreshRestServlet(RestServlet):
def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
raise AuthError(403, "tokenrefresh is no longer supported.")
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index 2863affbab..bef91a2d3e 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -15,8 +15,6 @@
import logging
-from twisted.internet import defer
-
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -38,8 +36,7 @@ class UserDirectorySearchRestServlet(RestServlet):
self.auth = hs.get_auth()
self.user_directory_handler = hs.get_user_directory_handler()
- @defer.inlineCallbacks
- def on_POST(self, request):
+ async def on_POST(self, request):
"""Searches for users in directory
Returns:
@@ -56,7 +53,7 @@ class UserDirectorySearchRestServlet(RestServlet):
]
}
"""
- requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
if not self.hs.config.user_directory_search_enabled:
@@ -72,7 +69,7 @@ class UserDirectorySearchRestServlet(RestServlet):
except Exception:
raise SynapseError(400, "`search_term` is required field")
- results = yield self.user_directory_handler.search_users(
+ results = await self.user_directory_handler.search_users(
user_id, search_term, limit
)
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index bb30ce3f34..2a477ad22e 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 15c15a12f5..fb0d02aa83 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -77,8 +77,8 @@ class PreviewUrlResource(DirectServeResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
- http_proxy=os.getenv("http_proxy"),
- https_proxy=os.getenv("HTTPS_PROXY"),
+ http_proxy=os.getenvb(b"http_proxy"),
+ https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
@@ -122,7 +122,7 @@ class PreviewUrlResource(DirectServeResource):
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug(
- "Matching attrib '%s' with value '%s' against" " pattern '%s'",
+ "Matching attrib '%s' with value '%s' against pattern '%s'",
attrib,
value,
pattern,
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 8cf415e29d..c234ea7421 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -129,5 +129,8 @@ class Thumbnailer(object):
def _encode_image(self, output_image, output_type):
output_bytes_io = BytesIO()
- output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
+ fmt = self.FORMATS[output_type]
+ if fmt == "JPEG":
+ output_image = output_image.convert("RGB")
+ output_image.save(output_bytes_io, fmt, quality=80)
return output_bytes_io
diff --git a/synapse/server.py b/synapse/server.py
index 90c3b072e8..be9af7f986 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -318,8 +318,8 @@ class HomeServer(object):
def build_proxied_http_client(self):
return SimpleHttpClient(
self,
- http_proxy=os.getenv("http_proxy"),
- https_proxy=os.getenv("HTTPS_PROXY"),
+ http_proxy=os.getenvb(b"http_proxy"),
+ https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
def build_room_creation_handler(self):
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
index 415e9c17d8..5736c56032 100644
--- a/synapse/server_notices/consent_server_notices.py
+++ b/synapse/server_notices/consent_server_notices.py
@@ -54,7 +54,7 @@ class ConsentServerNotices(object):
)
if "body" not in self._server_notice_content:
raise ConfigError(
- "user_consent server_notice_consent must contain a 'body' " "key."
+ "user_consent server_notice_consent must contain a 'body' key."
)
self._consent_uri_builder = ConsentURIBuilder(hs.config)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 1a2b7ebe25..0d7c7dff27 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -14,11 +14,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import itertools
import logging
import random
import sys
-import threading
import time
from typing import Iterable, Tuple
@@ -35,8 +33,6 @@ from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id
-from synapse.util import batch_iter
-from synapse.util.caches.descriptors import Cache
from synapse.util.stringutils import exception_to_unicode
# import a function which will return a monotonic time, in seconds
@@ -79,10 +75,6 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
"event_search": "event_search_event_id_idx",
}
-# This is a special cache name we use to batch multiple invalidations of caches
-# based on the current state when notifying workers over replication.
-_CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
-
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
@@ -237,23 +229,11 @@ class SQLBaseStore(object):
# to watch it
self._txn_perf_counters = PerformanceCounters()
- self._get_event_cache = Cache(
- "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
- )
-
- self._event_fetch_lock = threading.Condition()
- self._event_fetch_list = []
- self._event_fetch_ongoing = 0
-
- self._pending_ds = []
-
self.database_engine = hs.database_engine
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
- self._account_validity = self.hs.config.account_validity
-
# We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it
# unsafe to use native upserts.
@@ -272,14 +252,6 @@ class SQLBaseStore(object):
self.rand = random.SystemRandom()
- if self._account_validity.enabled:
- self._clock.call_later(
- 0.0,
- run_as_background_process,
- "account_validity_set_expiration_dates",
- self._set_expiration_date_when_missing,
- )
-
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
@@ -290,7 +262,7 @@ class SQLBaseStore(object):
If the background updates have not completed, wait 15 sec and check again.
"""
- updates = yield self._simple_select_list(
+ updates = yield self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=["update_name"],
@@ -312,65 +284,6 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
- @defer.inlineCallbacks
- def _set_expiration_date_when_missing(self):
- """
- Retrieves the list of registered users that don't have an expiration date, and
- adds an expiration date for each of them.
- """
-
- def select_users_with_no_expiration_date_txn(txn):
- """Retrieves the list of registered users with no expiration date from the
- database, filtering out deactivated users.
- """
- sql = (
- "SELECT users.name FROM users"
- " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
- " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
- )
- txn.execute(sql, [])
-
- res = self.cursor_to_dict(txn)
- if res:
- for user in res:
- self.set_expiration_date_for_user_txn(
- txn, user["name"], use_delta=True
- )
-
- yield self.runInteraction(
- "get_users_with_no_expiration_date",
- select_users_with_no_expiration_date_txn,
- )
-
- def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
- """Sets an expiration date to the account with the given user ID.
-
- Args:
- user_id (str): User ID to set an expiration date for.
- use_delta (bool): If set to False, the expiration date for the user will be
- now + validity period. If set to True, this expiration date will be a
- random value in the [now + period - d ; now + period] range, d being a
- delta equal to 10% of the validity period.
- """
- now_ms = self._clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
-
- if use_delta:
- expiration_ts = self.rand.randrange(
- expiration_ts - self._account_validity.startup_job_max_delta,
- expiration_ts,
- )
-
- self._simple_insert_txn(
- txn,
- "account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- "email_sent": False,
- },
- )
-
def start_profiling(self):
self._previous_loop_ts = monotonic_time()
@@ -394,7 +307,7 @@ class SQLBaseStore(object):
self._clock.looping_call(loop, 10000)
- def _new_transaction(
+ def new_transaction(
self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
):
start = monotonic_time()
@@ -412,16 +325,15 @@ class SQLBaseStore(object):
i = 0
N = 5
while True:
+ cursor = LoggingTransaction(
+ conn.cursor(),
+ name,
+ self.database_engine,
+ after_callbacks,
+ exception_callbacks,
+ )
try:
- txn = conn.cursor()
- txn = LoggingTransaction(
- txn,
- name,
- self.database_engine,
- after_callbacks,
- exception_callbacks,
- )
- r = func(txn, *args, **kwargs)
+ r = func(cursor, *args, **kwargs)
conn.commit()
return r
except self.database_engine.module.OperationalError as e:
@@ -459,6 +371,40 @@ class SQLBaseStore(object):
)
continue
raise
+ finally:
+ # we're either about to retry with a new cursor, or we're about to
+ # release the connection. Once we release the connection, it could
+ # get used for another query, which might do a conn.rollback().
+ #
+ # In the latter case, even though that probably wouldn't affect the
+ # results of this transaction, python's sqlite will reset all
+ # statements on the connection [1], which will make our cursor
+ # invalid [2].
+ #
+ # In any case, continuing to read rows after commit()ing seems
+ # dubious from the PoV of ACID transactional semantics
+ # (sqlite explicitly says that once you commit, you may see rows
+ # from subsequent updates.)
+ #
+ # In psycopg2, cursors are essentially a client-side fabrication -
+ # all the data is transferred to the client side when the statement
+ # finishes executing - so in theory we could go on streaming results
+ # from the cursor, but attempting to do so would make us
+ # incompatible with sqlite, so let's make sure we're not doing that
+ # by closing the cursor.
+ #
+ # (*named* cursors in psycopg2 are different and are proper server-
+ # side things, but (a) we don't use them and (b) they are implicitly
+ # closed by ending the transaction anyway.)
+ #
+ # In short, if we haven't finished with the cursor yet, that's a
+ # problem waiting to bite us.
+ #
+ # TL;DR: we're done with the cursor, so we can close it.
+ #
+ # [1]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/connection.c#L465
+ # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
+ cursor.close()
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
@@ -498,7 +444,7 @@ class SQLBaseStore(object):
try:
result = yield self.runWithConnection(
- self._new_transaction,
+ self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
@@ -570,7 +516,7 @@ class SQLBaseStore(object):
results = list(dict(zip(col_headers, row)) for row in cursor)
return results
- def _execute(self, desc, decoder, query, *args):
+ def execute(self, desc, decoder, query, *args):
"""Runs a single query for a result set.
Args:
@@ -595,7 +541,7 @@ class SQLBaseStore(object):
# no complex WHERE clauses, just a dict of values for columns.
@defer.inlineCallbacks
- def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
+ def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
"""Executes an INSERT query on the named table.
Args:
@@ -611,7 +557,7 @@ class SQLBaseStore(object):
`or_ignore` is True
"""
try:
- yield self.runInteraction(desc, self._simple_insert_txn, table, values)
+ yield self.runInteraction(desc, self.simple_insert_txn, table, values)
except self.database_engine.module.IntegrityError:
# We have to do or_ignore flag at this layer, since we can't reuse
# a cursor after we receive an error from the db.
@@ -621,7 +567,7 @@ class SQLBaseStore(object):
return True
@staticmethod
- def _simple_insert_txn(txn, table, values):
+ def simple_insert_txn(txn, table, values):
keys, vals = zip(*values.items())
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -632,11 +578,11 @@ class SQLBaseStore(object):
txn.execute(sql, vals)
- def _simple_insert_many(self, table, values, desc):
- return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
+ def simple_insert_many(self, table, values, desc):
+ return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
@staticmethod
- def _simple_insert_many_txn(txn, table, values):
+ def simple_insert_many_txn(txn, table, values):
if not values:
return
@@ -665,13 +611,13 @@ class SQLBaseStore(object):
txn.executemany(sql, vals)
@defer.inlineCallbacks
- def _simple_upsert(
+ def simple_upsert(
self,
table,
keyvalues,
values,
insertion_values={},
- desc="_simple_upsert",
+ desc="simple_upsert",
lock=True,
):
"""
@@ -703,7 +649,7 @@ class SQLBaseStore(object):
try:
result = yield self.runInteraction(
desc,
- self._simple_upsert_txn,
+ self.simple_upsert_txn,
table,
keyvalues,
values,
@@ -723,7 +669,7 @@ class SQLBaseStore(object):
"IntegrityError when upserting into %s; retrying: %s", table, e
)
- def _simple_upsert_txn(
+ def simple_upsert_txn(
self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
"""
@@ -747,11 +693,11 @@ class SQLBaseStore(object):
self.database_engine.can_native_upsert
and table not in self._unsafe_to_upsert_tables
):
- return self._simple_upsert_txn_native_upsert(
+ return self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
)
else:
- return self._simple_upsert_txn_emulated(
+ return self.simple_upsert_txn_emulated(
txn,
table,
keyvalues,
@@ -760,7 +706,7 @@ class SQLBaseStore(object):
lock=lock,
)
- def _simple_upsert_txn_emulated(
+ def simple_upsert_txn_emulated(
self, txn, table, keyvalues, values, insertion_values={}, lock=True
):
"""
@@ -829,7 +775,7 @@ class SQLBaseStore(object):
# successfully inserted
return True
- def _simple_upsert_txn_native_upsert(
+ def simple_upsert_txn_native_upsert(
self, txn, table, keyvalues, values, insertion_values={}
):
"""
@@ -854,7 +800,7 @@ class SQLBaseStore(object):
allvalues.update(values)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
- sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
+ sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
@@ -863,7 +809,7 @@ class SQLBaseStore(object):
)
txn.execute(sql, list(allvalues.values()))
- def _simple_upsert_many_txn(
+ def simple_upsert_many_txn(
self, txn, table, key_names, key_values, value_names, value_values
):
"""
@@ -883,15 +829,15 @@ class SQLBaseStore(object):
self.database_engine.can_native_upsert
and table not in self._unsafe_to_upsert_tables
):
- return self._simple_upsert_many_txn_native_upsert(
+ return self.simple_upsert_many_txn_native_upsert(
txn, table, key_names, key_values, value_names, value_values
)
else:
- return self._simple_upsert_many_txn_emulated(
+ return self.simple_upsert_many_txn_emulated(
txn, table, key_names, key_values, value_names, value_values
)
- def _simple_upsert_many_txn_emulated(
+ def simple_upsert_many_txn_emulated(
self, txn, table, key_names, key_values, value_names, value_values
):
"""
@@ -916,9 +862,9 @@ class SQLBaseStore(object):
_keys = {x: y for x, y in zip(key_names, keyv)}
_vals = {x: y for x, y in zip(value_names, valv)}
- self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
+ self.simple_upsert_txn_emulated(txn, table, _keys, _vals)
- def _simple_upsert_many_txn_native_upsert(
+ def simple_upsert_many_txn_native_upsert(
self, txn, table, key_names, key_values, value_names, value_values
):
"""
@@ -963,8 +909,8 @@ class SQLBaseStore(object):
return txn.execute_batch(sql, args)
- def _simple_select_one(
- self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
+ def simple_select_one(
+ self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.
@@ -978,16 +924,16 @@ class SQLBaseStore(object):
statement returns no rows
"""
return self.runInteraction(
- desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
+ desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
- def _simple_select_one_onecol(
+ def simple_select_one_onecol(
self,
table,
keyvalues,
retcol,
allow_none=False,
- desc="_simple_select_one_onecol",
+ desc="simple_select_one_onecol",
):
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.
@@ -999,7 +945,7 @@ class SQLBaseStore(object):
"""
return self.runInteraction(
desc,
- self._simple_select_one_onecol_txn,
+ self.simple_select_one_onecol_txn,
table,
keyvalues,
retcol,
@@ -1007,10 +953,10 @@ class SQLBaseStore(object):
)
@classmethod
- def _simple_select_one_onecol_txn(
+ def simple_select_one_onecol_txn(
cls, txn, table, keyvalues, retcol, allow_none=False
):
- ret = cls._simple_select_onecol_txn(
+ ret = cls.simple_select_onecol_txn(
txn, table=table, keyvalues=keyvalues, retcol=retcol
)
@@ -1023,7 +969,7 @@ class SQLBaseStore(object):
raise StoreError(404, "No row found")
@staticmethod
- def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
+ def simple_select_onecol_txn(txn, table, keyvalues, retcol):
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
if keyvalues:
@@ -1034,8 +980,8 @@ class SQLBaseStore(object):
return [r[0] for r in txn]
- def _simple_select_onecol(
- self, table, keyvalues, retcol, desc="_simple_select_onecol"
+ def simple_select_onecol(
+ self, table, keyvalues, retcol, desc="simple_select_onecol"
):
"""Executes a SELECT query on the named table, which returns a list
comprising of the values of the named column from the selected rows.
@@ -1049,12 +995,10 @@ class SQLBaseStore(object):
Deferred: Results in a list
"""
return self.runInteraction(
- desc, self._simple_select_onecol_txn, table, keyvalues, retcol
+ desc, self.simple_select_onecol_txn, table, keyvalues, retcol
)
- def _simple_select_list(
- self, table, keyvalues, retcols, desc="_simple_select_list"
- ):
+ def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1068,11 +1012,11 @@ class SQLBaseStore(object):
defer.Deferred: resolves to list[dict[str, Any]]
"""
return self.runInteraction(
- desc, self._simple_select_list_txn, table, keyvalues, retcols
+ desc, self.simple_select_list_txn, table, keyvalues, retcols
)
@classmethod
- def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+ def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1098,14 +1042,14 @@ class SQLBaseStore(object):
return cls.cursor_to_dict(txn)
@defer.inlineCallbacks
- def _simple_select_many_batch(
+ def simple_select_many_batch(
self,
table,
column,
iterable,
retcols,
keyvalues={},
- desc="_simple_select_many_batch",
+ desc="simple_select_many_batch",
batch_size=100,
):
"""Executes a SELECT query on the named table, which may return zero or
@@ -1134,7 +1078,7 @@ class SQLBaseStore(object):
for chunk in chunks:
rows = yield self.runInteraction(
desc,
- self._simple_select_many_txn,
+ self.simple_select_many_txn,
table,
column,
chunk,
@@ -1147,7 +1091,7 @@ class SQLBaseStore(object):
return results
@classmethod
- def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+ def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1180,13 +1124,13 @@ class SQLBaseStore(object):
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
- def _simple_update(self, table, keyvalues, updatevalues, desc):
+ def simple_update(self, table, keyvalues, updatevalues, desc):
return self.runInteraction(
- desc, self._simple_update_txn, table, keyvalues, updatevalues
+ desc, self.simple_update_txn, table, keyvalues, updatevalues
)
@staticmethod
- def _simple_update_txn(txn, table, keyvalues, updatevalues):
+ def simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
else:
@@ -1202,8 +1146,8 @@ class SQLBaseStore(object):
return txn.rowcount
- def _simple_update_one(
- self, table, keyvalues, updatevalues, desc="_simple_update_one"
+ def simple_update_one(
+ self, table, keyvalues, updatevalues, desc="simple_update_one"
):
"""Executes an UPDATE query on the named table, setting new values for
columns in a row matching the key values.
@@ -1223,12 +1167,12 @@ class SQLBaseStore(object):
the update column in the 'keyvalues' dict as well.
"""
return self.runInteraction(
- desc, self._simple_update_one_txn, table, keyvalues, updatevalues
+ desc, self.simple_update_one_txn, table, keyvalues, updatevalues
)
@classmethod
- def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
- rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
+ def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+ rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
if rowcount == 0:
raise StoreError(404, "No row found (%s)" % (table,))
@@ -1236,7 +1180,7 @@ class SQLBaseStore(object):
raise StoreError(500, "More than one row matched (%s)" % (table,))
@staticmethod
- def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+ def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
select_sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
@@ -1255,7 +1199,7 @@ class SQLBaseStore(object):
return dict(zip(retcols, row))
- def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
+ def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@@ -1263,10 +1207,10 @@ class SQLBaseStore(object):
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
- return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
+ return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
@staticmethod
- def _simple_delete_one_txn(txn, table, keyvalues):
+ def simple_delete_one_txn(txn, table, keyvalues):
"""Executes a DELETE query on the named table, expecting to delete a
single row.
@@ -1285,11 +1229,11 @@ class SQLBaseStore(object):
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched (%s)" % (table,))
- def _simple_delete(self, table, keyvalues, desc):
- return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
+ def simple_delete(self, table, keyvalues, desc):
+ return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
@staticmethod
- def _simple_delete_txn(txn, table, keyvalues):
+ def simple_delete_txn(txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % (
table,
" AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1298,13 +1242,13 @@ class SQLBaseStore(object):
txn.execute(sql, list(keyvalues.values()))
return txn.rowcount
- def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
+ def simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
- desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
+ desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
)
@staticmethod
- def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+ def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
"""Executes a DELETE query on the named table.
Filters rows by if value of `column` is in `iterable`.
@@ -1337,7 +1281,7 @@ class SQLBaseStore(object):
return txn.rowcount
- def _get_cache_dict(
+ def get_cache_dict(
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
@@ -1370,47 +1314,6 @@ class SQLBaseStore(object):
return cache, min_val
- def _invalidate_cache_and_stream(self, txn, cache_func, keys):
- """Invalidates the cache and adds it to the cache stream so slaves
- will know to invalidate their caches.
-
- This should only be used to invalidate caches where slaves won't
- otherwise know from other replication streams that the cache should
- be invalidated.
- """
- txn.call_after(cache_func.invalidate, keys)
- self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
-
- def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
- """Special case invalidation of caches based on current state.
-
- We special case this so that we can batch the cache invalidations into a
- single replication poke.
-
- Args:
- txn
- room_id (str): Room where state changed
- members_changed (iterable[str]): The user_ids of members that have changed
- """
- txn.call_after(self._invalidate_state_caches, room_id, members_changed)
-
- if members_changed:
- # We need to be careful that the size of the `members_changed` list
- # isn't so large that it causes problems sending over replication, so we
- # send them in chunks.
- # Max line length is 16K, and max user ID length is 255, so 50 should
- # be safe.
- for chunk in batch_iter(members_changed, 50):
- keys = itertools.chain([room_id], chunk)
- self._send_invalidation_to_replication(
- txn, _CURRENT_STATE_CACHE_NAME, keys
- )
- else:
- # if no members changed, we still need to invalidate the other caches.
- self._send_invalidation_to_replication(
- txn, _CURRENT_STATE_CACHE_NAME, [room_id]
- )
-
def _invalidate_state_caches(self, room_id, members_changed):
"""Invalidates caches that are based on the current state, but does
not stream invalidations down replication.
@@ -1444,73 +1347,17 @@ class SQLBaseStore(object):
# which is fine.
pass
- def _send_invalidation_to_replication(self, txn, cache_name, keys):
- """Notifies replication that given cache has been invalidated.
-
- Note that this does *not* invalidate the cache locally.
-
- Args:
- txn
- cache_name (str)
- keys (iterable[str])
- """
-
- if isinstance(self.database_engine, PostgresEngine):
- # get_next() returns a context manager which is designed to wrap
- # the transaction. However, we want to only get an ID when we want
- # to use it, here, so we need to call __enter__ manually, and have
- # __exit__ called after the transaction finishes.
- ctx = self._cache_id_gen.get_next()
- stream_id = ctx.__enter__()
- txn.call_on_exception(ctx.__exit__, None, None, None)
- txn.call_after(ctx.__exit__, None, None, None)
- txn.call_after(self.hs.get_notifier().on_new_replication_data)
-
- self._simple_insert_txn(
- txn,
- table="cache_invalidation_stream",
- values={
- "stream_id": stream_id,
- "cache_func": cache_name,
- "keys": list(keys),
- "invalidation_ts": self.clock.time_msec(),
- },
- )
-
- def get_all_updated_caches(self, last_id, current_id, limit):
- if last_id == current_id:
- return defer.succeed([])
-
- def get_all_updated_caches_txn(txn):
- # We purposefully don't bound by the current token, as we want to
- # send across cache invalidations as quickly as possible. Cache
- # invalidations are idempotent, so duplicates are fine.
- sql = (
- "SELECT stream_id, cache_func, keys, invalidation_ts"
- " FROM cache_invalidation_stream"
- " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
- )
- txn.execute(sql, (last_id, limit))
- return txn.fetchall()
-
- return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
-
- def get_cache_stream_token(self):
- if self._cache_id_gen:
- return self._cache_id_gen.get_current_token()
- else:
- return 0
-
- def _simple_select_list_paginate(
+ def simple_select_list_paginate(
self,
table,
- keyvalues,
orderby,
start,
limit,
retcols,
+ filters=None,
+ keyvalues=None,
order_direction="ASC",
- desc="_simple_select_list_paginate",
+ desc="simple_select_list_paginate",
):
"""
Executes a SELECT query on the named table with start and limit,
@@ -1519,6 +1366,9 @@ class SQLBaseStore(object):
Args:
table (str): the table name
+ filters (dict[str, T] | None):
+ column names and values to filter the rows with, or None to not
+ apply a WHERE ? LIKE ? clause.
keyvalues (dict[str, T] | None):
column names and values to select the rows with, or None to not
apply a WHERE clause.
@@ -1532,26 +1382,28 @@ class SQLBaseStore(object):
"""
return self.runInteraction(
desc,
- self._simple_select_list_paginate_txn,
+ self.simple_select_list_paginate_txn,
table,
- keyvalues,
orderby,
start,
limit,
retcols,
+ filters=filters,
+ keyvalues=keyvalues,
order_direction=order_direction,
)
@classmethod
- def _simple_select_list_paginate_txn(
+ def simple_select_list_paginate_txn(
cls,
txn,
table,
- keyvalues,
orderby,
start,
limit,
retcols,
+ filters=None,
+ keyvalues=None,
order_direction="ASC",
):
"""
@@ -1559,16 +1411,23 @@ class SQLBaseStore(object):
of row numbers, which may return zero or number of rows from start to limit,
returning the result as a list of dicts.
+ Use `filters` to search attributes using SQL wildcards and/or `keyvalues` to
+ select attributes with exact matches. All constraints are joined together
+ using 'AND'.
+
Args:
txn : Transaction object
table (str): the table name
- keyvalues (dict[str, T] | None):
- column names and values to select the rows with, or None to not
- apply a WHERE clause.
orderby (str): Column to order the results by.
start (int): Index to begin the query at.
limit (int): Number of results to return.
retcols (iterable[str]): the names of the columns to return
+ filters (dict[str, T] | None):
+ column names and values to filter the rows with, or None to not
+ apply a WHERE ? LIKE ? clause.
+ keyvalues (dict[str, T] | None):
+ column names and values to select the rows with, or None to not
+ apply a WHERE clause.
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
@@ -1576,10 +1435,15 @@ class SQLBaseStore(object):
if order_direction not in ["ASC", "DESC"]:
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
+ where_clause = "WHERE " if filters or keyvalues else ""
+ arg_list = []
+ if filters:
+ where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
+ arg_list += list(filters.values())
+ where_clause += " AND " if filters and keyvalues else ""
if keyvalues:
- where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
- else:
- where_clause = ""
+ where_clause += " AND ".join("%s = ?" % (k,) for k in keyvalues)
+ arg_list += list(keyvalues.values())
sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
", ".join(retcols),
@@ -1588,25 +1452,11 @@ class SQLBaseStore(object):
orderby,
order_direction,
)
- txn.execute(sql, list(keyvalues.values()) + [limit, start])
+ txn.execute(sql, arg_list + [limit, start])
return cls.cursor_to_dict(txn)
- def get_user_count_txn(self, txn):
- """Get a total number of registered users in the users list.
-
- Args:
- txn : Transaction object
- Returns:
- int : number of users
- """
- sql_count = "SELECT COUNT(*) FROM users WHERE is_guest = 0;"
- txn.execute(sql_count)
- return txn.fetchone()[0]
-
- def _simple_search_list(
- self, table, term, col, retcols, desc="_simple_search_list"
- ):
+ def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1621,11 +1471,11 @@ class SQLBaseStore(object):
"""
return self.runInteraction(
- desc, self._simple_search_list_txn, table, term, col, retcols
+ desc, self.simple_search_list_txn, table, term, col, retcols
)
@classmethod
- def _simple_search_list_txn(cls, txn, table, term, col, retcols):
+ def simple_search_list_txn(cls, txn, table, term, col, retcols):
"""Executes a SELECT query on the named table, which may return zero or
more rows, returning the result as a list of dicts.
@@ -1648,14 +1498,6 @@ class SQLBaseStore(object):
return cls.cursor_to_dict(txn)
- @property
- def database_engine_name(self):
- return self.database_engine.module.__name__
-
- def get_server_version(self):
- """Returns a string describing the server version number"""
- return self.database_engine.server_version
-
class _RollbackButIsFineException(Exception):
""" This exception is used to rollback a transaction without implying
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 37d469ffd7..06955a0537 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -139,7 +139,7 @@ class BackgroundUpdateStore(SQLBaseStore):
# otherwise, check if there are updates to be run. This is important,
# as we may be running on a worker which doesn't perform the bg updates
# itself, but still wants to wait for them to happen.
- updates = yield self._simple_select_onecol(
+ updates = yield self.simple_select_onecol(
"background_updates",
keyvalues=None,
retcol="1",
@@ -161,7 +161,7 @@ class BackgroundUpdateStore(SQLBaseStore):
if update_name in self._background_update_queue:
return False
- update_exists = await self._simple_select_one_onecol(
+ update_exists = await self.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="1",
@@ -184,7 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore):
no more work to do.
"""
if not self._background_update_queue:
- updates = yield self._simple_select_list(
+ updates = yield self.simple_select_list(
"background_updates",
keyvalues=None,
retcols=("update_name", "depends_on"),
@@ -226,7 +226,7 @@ class BackgroundUpdateStore(SQLBaseStore):
else:
batch_size = self.DEFAULT_BACKGROUND_BATCH_SIZE
- progress_json = yield self._simple_select_one_onecol(
+ progress_json = yield self.simple_select_one_onecol(
"background_updates",
keyvalues={"update_name": update_name},
retcol="progress_json",
@@ -413,7 +413,7 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue = []
progress_json = json.dumps(progress)
- return self._simple_insert(
+ return self.simple_insert(
"background_updates",
{"update_name": update_name, "progress_json": progress_json},
)
@@ -429,7 +429,7 @@ class BackgroundUpdateStore(SQLBaseStore):
self._background_update_queue = [
name for name in self._background_update_queue if name != update_name
]
- return self._simple_delete_one(
+ return self.simple_delete_one(
"background_updates", keyvalues={"update_name": update_name}
)
@@ -444,7 +444,7 @@ class BackgroundUpdateStore(SQLBaseStore):
progress_json = json.dumps(progress)
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn,
"background_updates",
keyvalues={"update_name": update_name},
diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py
index 10c940df1e..3720ff3088 100644
--- a/synapse/storage/data_stores/main/__init__.py
+++ b/synapse/storage/data_stores/main/__init__.py
@@ -19,8 +19,6 @@ import calendar
import logging
import time
-from twisted.internet import defer
-
from synapse.api.constants import PresenceState
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import (
@@ -32,6 +30,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
+from .cache import CacheInvalidationStore
from .client_ips import ClientIpStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
@@ -110,6 +109,7 @@ class DataStore(
MonthlyActiveUsersStore,
StatsStore,
RelationsStore,
+ CacheInvalidationStore,
):
def __init__(self, db_conn, hs):
self.hs = hs
@@ -171,7 +171,7 @@ class DataStore(
self._presence_on_startup = self._get_active_presence(db_conn)
- presence_cache_prefill, min_presence_val = self._get_cache_dict(
+ presence_cache_prefill, min_presence_val = self.get_cache_dict(
db_conn,
"presence_stream",
entity_column="user_id",
@@ -185,7 +185,7 @@ class DataStore(
)
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
- device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
+ device_inbox_prefill, min_device_inbox_id = self.get_cache_dict(
db_conn,
"device_inbox",
entity_column="user_id",
@@ -200,7 +200,7 @@ class DataStore(
)
# The federation outbox and the local device inbox uses the same
# stream_id generator.
- device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
+ device_outbox_prefill, min_device_outbox_id = self.get_cache_dict(
db_conn,
"device_federation_outbox",
entity_column="destination",
@@ -226,7 +226,7 @@ class DataStore(
)
events_max = self._stream_id_gen.get_current_token()
- curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
+ curr_state_delta_prefill, min_curr_state_delta_id = self.get_cache_dict(
db_conn,
"current_state_delta_stream",
entity_column="room_id",
@@ -240,7 +240,7 @@ class DataStore(
prefilled_cache=curr_state_delta_prefill,
)
- _group_updates_prefill, min_group_updates_id = self._get_cache_dict(
+ _group_updates_prefill, min_group_updates_id = self.get_cache_dict(
db_conn,
"local_group_updates",
entity_column="user_id",
@@ -474,45 +474,68 @@ class DataStore(
)
def get_users(self):
- """Function to reterive a list of users in users table.
+ """Function to retrieve a list of users in users table.
Args:
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
- return self._simple_select_list(
+ return self.simple_select_list(
table="users",
keyvalues={},
- retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
+ retcols=[
+ "name",
+ "password_hash",
+ "is_guest",
+ "admin",
+ "user_type",
+ "deactivated",
+ ],
desc="get_users",
)
- @defer.inlineCallbacks
- def get_users_paginate(self, order, start, limit):
- """Function to reterive a paginated list of users from
- users list. This will return a json object, which contains
- list of users and the total number of users in users table.
+ def get_users_paginate(
+ self, start, limit, name=None, guests=True, deactivated=False
+ ):
+ """Function to retrieve a paginated list of users from
+ users list. This will return a json list of users.
Args:
- order (str): column name to order the select by this column
start (int): start number to begin the query from
- limit (int): number of rows to reterive
+ limit (int): number of rows to retrieve
+ name (string): filter for user names
+ guests (bool): whether to in include guest users
+ deactivated (bool): whether to include deactivated users
Returns:
- defer.Deferred: resolves to json object {list[dict[str, Any]], count}
+ defer.Deferred: resolves to list[dict[str, Any]]
"""
- users = yield self.runInteraction(
- "get_users_paginate",
- self._simple_select_list_paginate_txn,
+ name_filter = {}
+ if name:
+ name_filter["name"] = "%" + name + "%"
+
+ attr_filter = {}
+ if not guests:
+ attr_filter["is_guest"] = False
+ if not deactivated:
+ attr_filter["deactivated"] = False
+
+ return self.simple_select_list_paginate(
+ desc="get_users_paginate",
table="users",
- keyvalues={"is_guest": False},
- orderby=order,
+ orderby="name",
start=start,
limit=limit,
- retcols=["name", "password_hash", "is_guest", "admin", "user_type"],
+ filters=name_filter,
+ keyvalues=attr_filter,
+ retcols=[
+ "name",
+ "password_hash",
+ "is_guest",
+ "admin",
+ "user_type",
+ "deactivated",
+ ],
)
- count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
- retval = {"users": users, "total": count}
- return retval
def search_users(self, term):
"""Function to search users list for one or more users with
@@ -524,7 +547,7 @@ class DataStore(
Returns:
defer.Deferred: resolves to list[dict[str, Any]]
"""
- return self._simple_search_list(
+ return self.simple_search_list(
table="users",
term=term,
col="name",
diff --git a/synapse/storage/data_stores/main/account_data.py b/synapse/storage/data_stores/main/account_data.py
index 6afbfc0d74..b0d22faf3f 100644
--- a/synapse/storage/data_stores/main/account_data.py
+++ b/synapse/storage/data_stores/main/account_data.py
@@ -67,7 +67,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_user_txn(txn):
- rows = self._simple_select_list_txn(
+ rows = self.simple_select_list_txn(
txn,
"account_data",
{"user_id": user_id},
@@ -78,7 +78,7 @@ class AccountDataWorkerStore(SQLBaseStore):
row["account_data_type"]: json.loads(row["content"]) for row in rows
}
- rows = self._simple_select_list_txn(
+ rows = self.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id},
@@ -102,7 +102,7 @@ class AccountDataWorkerStore(SQLBaseStore):
Returns:
Deferred: A dict
"""
- result = yield self._simple_select_one_onecol(
+ result = yield self.simple_select_one_onecol(
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": data_type},
retcol="content",
@@ -127,7 +127,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_txn(txn):
- rows = self._simple_select_list_txn(
+ rows = self.simple_select_list_txn(
txn,
"room_account_data",
{"user_id": user_id, "room_id": room_id},
@@ -156,7 +156,7 @@ class AccountDataWorkerStore(SQLBaseStore):
"""
def get_account_data_for_room_and_type_txn(txn):
- content_json = self._simple_select_one_onecol_txn(
+ content_json = self.simple_select_one_onecol_txn(
txn,
table="room_account_data",
keyvalues={
@@ -184,14 +184,14 @@ class AccountDataWorkerStore(SQLBaseStore):
current_id(int): The position to fetch up to.
Returns:
A deferred pair of lists of tuples of stream_id int, user_id string,
- room_id string, type string, and content string.
+ room_id string, and type string.
"""
if last_room_id == current_id and last_global_id == current_id:
return defer.succeed(([], []))
def get_updated_account_data_txn(txn):
sql = (
- "SELECT stream_id, user_id, account_data_type, content"
+ "SELECT stream_id, user_id, account_data_type"
" FROM account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
@@ -199,7 +199,7 @@ class AccountDataWorkerStore(SQLBaseStore):
global_results = txn.fetchall()
sql = (
- "SELECT stream_id, user_id, room_id, account_data_type, content"
+ "SELECT stream_id, user_id, room_id, account_data_type"
" FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC LIMIT ?"
)
@@ -300,9 +300,9 @@ class AccountDataStore(AccountDataWorkerStore):
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
- # on (user_id, room_id, account_data_type) so _simple_upsert will
+ # on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
- yield self._simple_upsert(
+ yield self.simple_upsert(
desc="add_room_account_data",
table="room_account_data",
keyvalues={
@@ -346,9 +346,9 @@ class AccountDataStore(AccountDataWorkerStore):
with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
- # (user_id, account_data_type) so _simple_upsert will retry if
+ # (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
- yield self._simple_upsert(
+ yield self.simple_upsert(
desc="add_user_account_data",
table="account_data",
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
diff --git a/synapse/storage/data_stores/main/appservice.py b/synapse/storage/data_stores/main/appservice.py
index 81babf2029..6b82fd392a 100644
--- a/synapse/storage/data_stores/main/appservice.py
+++ b/synapse/storage/data_stores/main/appservice.py
@@ -133,7 +133,7 @@ class ApplicationServiceTransactionWorkerStore(
A Deferred which resolves to a list of ApplicationServices, which
may be empty.
"""
- results = yield self._simple_select_list(
+ results = yield self.simple_select_list(
"application_services_state", dict(state=state), ["as_id"]
)
# NB: This assumes this class is linked with ApplicationServiceStore
@@ -155,7 +155,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves to ApplicationServiceState.
"""
- result = yield self._simple_select_one(
+ result = yield self.simple_select_one(
"application_services_state",
dict(as_id=service.id),
["state"],
@@ -175,7 +175,7 @@ class ApplicationServiceTransactionWorkerStore(
Returns:
A Deferred which resolves when the state was set successfully.
"""
- return self._simple_upsert(
+ return self.simple_upsert(
"application_services_state", dict(as_id=service.id), dict(state=state)
)
@@ -249,7 +249,7 @@ class ApplicationServiceTransactionWorkerStore(
)
# Set current txn_id for AS to 'txn_id'
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
"application_services_state",
dict(as_id=service.id),
@@ -257,7 +257,7 @@ class ApplicationServiceTransactionWorkerStore(
)
# Delete txn
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
)
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
new file mode 100644
index 0000000000..de3256049d
--- /dev/null
+++ b/synapse/storage/data_stores/main/cache.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import itertools
+import logging
+
+from twisted.internet import defer
+
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.engines import PostgresEngine
+from synapse.util import batch_iter
+
+logger = logging.getLogger(__name__)
+
+
+# This is a special cache name we use to batch multiple invalidations of caches
+# based on the current state when notifying workers over replication.
+CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
+
+
+class CacheInvalidationStore(SQLBaseStore):
+ def _invalidate_cache_and_stream(self, txn, cache_func, keys):
+ """Invalidates the cache and adds it to the cache stream so slaves
+ will know to invalidate their caches.
+
+ This should only be used to invalidate caches where slaves won't
+ otherwise know from other replication streams that the cache should
+ be invalidated.
+ """
+ txn.call_after(cache_func.invalidate, keys)
+ self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
+
+ def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
+ """Special case invalidation of caches based on current state.
+
+ We special case this so that we can batch the cache invalidations into a
+ single replication poke.
+
+ Args:
+ txn
+ room_id (str): Room where state changed
+ members_changed (iterable[str]): The user_ids of members that have changed
+ """
+ txn.call_after(self._invalidate_state_caches, room_id, members_changed)
+
+ if members_changed:
+ # We need to be careful that the size of the `members_changed` list
+ # isn't so large that it causes problems sending over replication, so we
+ # send them in chunks.
+ # Max line length is 16K, and max user ID length is 255, so 50 should
+ # be safe.
+ for chunk in batch_iter(members_changed, 50):
+ keys = itertools.chain([room_id], chunk)
+ self._send_invalidation_to_replication(
+ txn, CURRENT_STATE_CACHE_NAME, keys
+ )
+ else:
+ # if no members changed, we still need to invalidate the other caches.
+ self._send_invalidation_to_replication(
+ txn, CURRENT_STATE_CACHE_NAME, [room_id]
+ )
+
+ def _send_invalidation_to_replication(self, txn, cache_name, keys):
+ """Notifies replication that given cache has been invalidated.
+
+ Note that this does *not* invalidate the cache locally.
+
+ Args:
+ txn
+ cache_name (str)
+ keys (iterable[str])
+ """
+
+ if isinstance(self.database_engine, PostgresEngine):
+ # get_next() returns a context manager which is designed to wrap
+ # the transaction. However, we want to only get an ID when we want
+ # to use it, here, so we need to call __enter__ manually, and have
+ # __exit__ called after the transaction finishes.
+ ctx = self._cache_id_gen.get_next()
+ stream_id = ctx.__enter__()
+ txn.call_on_exception(ctx.__exit__, None, None, None)
+ txn.call_after(ctx.__exit__, None, None, None)
+ txn.call_after(self.hs.get_notifier().on_new_replication_data)
+
+ self.simple_insert_txn(
+ txn,
+ table="cache_invalidation_stream",
+ values={
+ "stream_id": stream_id,
+ "cache_func": cache_name,
+ "keys": list(keys),
+ "invalidation_ts": self.clock.time_msec(),
+ },
+ )
+
+ def get_all_updated_caches(self, last_id, current_id, limit):
+ if last_id == current_id:
+ return defer.succeed([])
+
+ def get_all_updated_caches_txn(txn):
+ # We purposefully don't bound by the current token, as we want to
+ # send across cache invalidations as quickly as possible. Cache
+ # invalidations are idempotent, so duplicates are fine.
+ sql = (
+ "SELECT stream_id, cache_func, keys, invalidation_ts"
+ " FROM cache_invalidation_stream"
+ " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+ )
+ txn.execute(sql, (last_id, limit))
+ return txn.fetchall()
+
+ return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
+
+ def get_cache_stream_token(self):
+ if self._cache_id_gen:
+ return self._cache_id_gen.get_current_token()
+ else:
+ return 0
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 706c6a1f3f..66522a04b7 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -21,8 +21,8 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage import background_updates
-from synapse.storage._base import Cache
from synapse.util.caches import CACHE_SIZE_FACTOR
+from synapse.util.caches.descriptors import Cache
logger = logging.getLogger(__name__)
@@ -431,7 +431,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
try:
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
table="user_ips",
keyvalues={
@@ -450,7 +450,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -483,7 +483,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
if device_id is not None:
keyvalues["device_id"] = device_id
- res = yield self._simple_select_list(
+ res = yield self.simple_select_list(
table="devices",
keyvalues=keyvalues,
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
@@ -516,7 +516,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
user_agent, _, last_seen = self._batch_row_update[key]
results[(access_token, ip)] = (user_agent, last_seen)
- rows = yield self._simple_select_list(
+ rows = yield self.simple_select_list(
table="user_ips",
keyvalues={"user_id": user_id},
retcols=["access_token", "ip", "user_agent", "last_seen"],
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index f04aad0743..206d39134d 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -314,7 +314,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Check if we've already inserted a matching message_id for that
# origin. This can happen if the origin doesn't receive our
# acknowledgement from the first time we received the message.
- already_inserted = self._simple_select_one_txn(
+ already_inserted = self.simple_select_one_txn(
txn,
table="device_federation_inbox",
keyvalues={"origin": origin, "message_id": message_id},
@@ -326,7 +326,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
# Add an entry for this message_id so that we know we've processed
# it.
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="device_federation_inbox",
values={
@@ -358,8 +358,21 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
- sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
- txn.execute(sql, (stream_id, stream_id))
+ # Compatible method of performing an upsert
+ sql = "SELECT stream_id FROM device_max_stream_id"
+
+ txn.execute(sql)
+ rows = txn.fetchone()
+ if rows:
+ db_stream_id = rows[0]
+ if db_stream_id < stream_id:
+ # Insert the new stream_id
+ sql = "UPDATE device_max_stream_id SET stream_id = ?"
+ else:
+ # No rows, perform an insert
+ sql = "INSERT INTO device_max_stream_id (stream_id) VALUES (?)"
+
+ txn.execute(sql, (stream_id,))
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
@@ -367,7 +380,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
devices = list(messages_by_device.keys())
if len(devices) == 1 and devices[0] == "*":
# Handle wildcard device_ids.
- sql = "SELECT device_id FROM devices" " WHERE user_id = ?"
+ sql = "SELECT device_id FROM devices WHERE user_id = ?"
txn.execute(sql, (user_id,))
message_json = json.dumps(messages_by_device["*"])
for row in txn:
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 71f62036c0..727c582121 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -30,16 +30,16 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage._base import (
- Cache,
- SQLBaseStore,
- db_to_json,
- make_in_list_sql_clause,
-)
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.types import get_verify_key_from_cross_signing_key
from synapse.util import batch_iter
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util.caches.descriptors import (
+ Cache,
+ cached,
+ cachedInlineCallbacks,
+ cachedList,
+)
logger = logging.getLogger(__name__)
@@ -61,7 +61,7 @@ class DeviceWorkerStore(SQLBaseStore):
Raises:
StoreError: if the device is not found
"""
- return self._simple_select_one(
+ return self.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -80,7 +80,7 @@ class DeviceWorkerStore(SQLBaseStore):
containing "device_id", "user_id" and "display_name" for each
device.
"""
- devices = yield self._simple_select_list(
+ devices = yield self.simple_select_list(
table="devices",
keyvalues={"user_id": user_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
@@ -414,7 +414,7 @@ class DeviceWorkerStore(SQLBaseStore):
from_user_id,
stream_id,
)
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
"user_signature_stream",
values={
@@ -466,7 +466,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2, tree=True)
def _get_cached_user_device(self, user_id, device_id):
- content = yield self._simple_select_one_onecol(
+ content = yield self.simple_select_one_onecol(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="content",
@@ -476,7 +476,7 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedInlineCallbacks()
def _get_cached_devices_for_user(self, user_id):
- devices = yield self._simple_select_list(
+ devices = yield self.simple_select_list(
table="device_lists_remote_cache",
keyvalues={"user_id": user_id},
retcols=("device_id", "content"),
@@ -584,7 +584,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT DISTINCT user_ids FROM user_signature_stream
WHERE from_user_id = ? AND stream_id > ?
"""
- rows = yield self._execute(
+ rows = yield self.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
return set(user for row in rows for user in json.loads(row[0]))
@@ -605,7 +605,7 @@ class DeviceWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id, destination
"""
- return self._execute(
+ return self.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@@ -614,7 +614,7 @@ class DeviceWorkerStore(SQLBaseStore):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
@@ -628,7 +628,7 @@ class DeviceWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids):
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,
@@ -722,7 +722,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return False
try:
- inserted = yield self._simple_insert(
+ inserted = yield self.simple_insert(
"devices",
values={
"user_id": user_id,
@@ -736,7 +736,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not inserted:
# if the device already exists, check if it's a real device, or
# if the device ID is reserved by something else
- hidden = yield self._simple_select_one_onecol(
+ hidden = yield self.simple_select_one_onecol(
"devices",
keyvalues={"user_id": user_id, "device_id": device_id},
retcol="hidden",
@@ -771,7 +771,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
defer.Deferred
"""
- yield self._simple_delete_one(
+ yield self.simple_delete_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
desc="delete_device",
@@ -789,7 +789,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
Returns:
defer.Deferred
"""
- yield self._simple_delete_many(
+ yield self.simple_delete_many(
table="devices",
column="device_id",
iterable=device_ids,
@@ -818,7 +818,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
updates["display_name"] = new_display_name
if not updates:
return defer.succeed(None)
- return self._simple_update_one(
+ return self.simple_update_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
updatevalues=updates,
@@ -829,7 +829,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
- yield self._simple_delete(
+ yield self.simple_delete(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
desc="mark_remote_user_device_list_as_unsubscribed",
@@ -866,7 +866,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self, txn, user_id, device_id, content, stream_id
):
if content.get("deleted"):
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -874,7 +874,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
txn.call_after(self.device_id_exists_cache.invalidate, (user_id, device_id))
else:
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -890,7 +890,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -923,11 +923,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices, stream_id):
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
@@ -946,7 +946,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,)
)
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
@@ -995,7 +995,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
[(user_id, device_id, stream_id) for device_id in device_ids],
)
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
@@ -1006,7 +1006,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
context = get_active_span_text_map()
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
diff --git a/synapse/storage/data_stores/main/directory.py b/synapse/storage/data_stores/main/directory.py
index 297966d9f4..d332f8a409 100644
--- a/synapse/storage/data_stores/main/directory.py
+++ b/synapse/storage/data_stores/main/directory.py
@@ -36,7 +36,7 @@ class DirectoryWorkerStore(SQLBaseStore):
Deferred: results in namedtuple with keys "room_id" and
"servers" or None if no association can be found
"""
- room_id = yield self._simple_select_one_onecol(
+ room_id = yield self.simple_select_one_onecol(
"room_aliases",
{"room_alias": room_alias.to_string()},
"room_id",
@@ -47,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
if not room_id:
return None
- servers = yield self._simple_select_onecol(
+ servers = yield self.simple_select_onecol(
"room_alias_servers",
{"room_alias": room_alias.to_string()},
"server",
@@ -60,7 +60,7 @@ class DirectoryWorkerStore(SQLBaseStore):
return RoomAliasMapping(room_id, room_alias.to_string(), servers)
def get_room_alias_creator(self, room_alias):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
@@ -69,7 +69,7 @@ class DirectoryWorkerStore(SQLBaseStore):
@cached(max_entries=5000)
def get_aliases_for_room(self, room_id):
- return self._simple_select_onecol(
+ return self.simple_select_onecol(
"room_aliases",
{"room_id": room_id},
"room_alias",
@@ -93,7 +93,7 @@ class DirectoryStore(DirectoryWorkerStore):
"""
def alias_txn(txn):
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
"room_aliases",
{
@@ -103,7 +103,7 @@ class DirectoryStore(DirectoryWorkerStore):
},
)
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="room_alias_servers",
values=[
diff --git a/synapse/storage/data_stores/main/e2e_room_keys.py b/synapse/storage/data_stores/main/e2e_room_keys.py
index 1cbbae5b63..df89eda337 100644
--- a/synapse/storage/data_stores/main/e2e_room_keys.py
+++ b/synapse/storage/data_stores/main/e2e_room_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 New Vector Ltd
+# Copyright 2019 Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,49 +25,8 @@ from synapse.storage._base import SQLBaseStore
class EndToEndRoomKeyStore(SQLBaseStore):
@defer.inlineCallbacks
- def get_e2e_room_key(self, user_id, version, room_id, session_id):
- """Get the encrypted E2E room key for a given session from a given
- backup version of room_keys. We only store the 'best' room key for a given
- session at a given time, as determined by the handler.
-
- Args:
- user_id(str): the user whose backup we're querying
- version(str): the version ID of the backup for the set of keys we're querying
- room_id(str): the ID of the room whose keys we're querying.
- This is a bit redundant as it's implied by the session_id, but
- we include for consistency with the rest of the API.
- session_id(str): the session whose room_key we're querying.
-
- Returns:
- A deferred dict giving the session_data and message metadata for
- this room key.
- """
-
- row = yield self._simple_select_one(
- table="e2e_room_keys",
- keyvalues={
- "user_id": user_id,
- "version": version,
- "room_id": room_id,
- "session_id": session_id,
- },
- retcols=(
- "first_message_index",
- "forwarded_count",
- "is_verified",
- "session_data",
- ),
- desc="get_e2e_room_key",
- )
-
- row["session_data"] = json.loads(row["session_data"])
-
- return row
-
- @defer.inlineCallbacks
- def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
- """Replaces or inserts the encrypted E2E room key for a given session in
- a given backup
+ def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
+ """Replaces the encrypted E2E room key for a given session in a given backup
Args:
user_id(str): the user whose backup we're setting
@@ -78,7 +38,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
StoreError
"""
- yield self._simple_upsert(
+ yield self.simple_update_one(
table="e2e_room_keys",
keyvalues={
"user_id": user_id,
@@ -86,21 +46,51 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"room_id": room_id,
"session_id": session_id,
},
- values={
+ updatevalues={
"first_message_index": room_key["first_message_index"],
"forwarded_count": room_key["forwarded_count"],
"is_verified": room_key["is_verified"],
"session_data": json.dumps(room_key["session_data"]),
},
- lock=False,
+ desc="update_e2e_room_key",
)
- log_kv(
- {
- "message": "Set room key",
- "room_id": room_id,
- "session_id": session_id,
- "room_key": room_key,
- }
+
+ @defer.inlineCallbacks
+ def add_e2e_room_keys(self, user_id, version, room_keys):
+ """Bulk add room keys to a given backup.
+
+ Args:
+ user_id (str): the user whose backup we're adding to
+ version (str): the version ID of the backup for the set of keys we're adding to
+ room_keys (iterable[(str, str, dict)]): the keys to add, in the form
+ (roomID, sessionID, keyData)
+ """
+
+ values = []
+ for (room_id, session_id, room_key) in room_keys:
+ values.append(
+ {
+ "user_id": user_id,
+ "version": version,
+ "room_id": room_id,
+ "session_id": session_id,
+ "first_message_index": room_key["first_message_index"],
+ "forwarded_count": room_key["forwarded_count"],
+ "is_verified": room_key["is_verified"],
+ "session_data": json.dumps(room_key["session_data"]),
+ }
+ )
+ log_kv(
+ {
+ "message": "Set room key",
+ "room_id": room_id,
+ "session_id": session_id,
+ "room_key": room_key,
+ }
+ )
+
+ yield self.simple_insert_many(
+ table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
)
@trace
@@ -110,11 +100,11 @@ class EndToEndRoomKeyStore(SQLBaseStore):
room, or a given session.
Args:
- user_id(str): the user whose backup we're querying
- version(str): the version ID of the backup for the set of keys we're querying
- room_id(str): Optional. the ID of the room whose keys we're querying, if any.
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup for the set of keys we're querying
+ room_id (str): Optional. the ID of the room whose keys we're querying, if any.
If not specified, we return the keys for all the rooms in the backup.
- session_id(str): Optional. the session whose room_key we're querying, if any.
+ session_id (str): Optional. the session whose room_key we're querying, if any.
If specified, we also require the room_id to be specified.
If not specified, we return all the keys in this version of
the backup (or for the specified room)
@@ -135,7 +125,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- rows = yield self._simple_select_list(
+ rows = yield self.simple_select_list(
table="e2e_room_keys",
keyvalues=keyvalues,
retcols=(
@@ -162,6 +152,95 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return sessions
+ def get_e2e_room_keys_multi(self, user_id, version, room_keys):
+ """Get multiple room keys at a time. The difference between this function and
+ get_e2e_room_keys is that this function can be used to retrieve
+ multiple specific keys at a time, whereas get_e2e_room_keys is used for
+ getting all the keys in a backup version, all the keys for a room, or a
+ specific key.
+
+ Args:
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup we're querying about
+ room_keys (dict[str, dict[str, iterable[str]]]): a map from
+ room ID -> {"session": [session ids]} indicating the session IDs
+ that we want to query
+
+ Returns:
+ Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
+ """
+
+ return self.runInteraction(
+ "get_e2e_room_keys_multi",
+ self._get_e2e_room_keys_multi_txn,
+ user_id,
+ version,
+ room_keys,
+ )
+
+ @staticmethod
+ def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):
+ if not room_keys:
+ return {}
+
+ where_clauses = []
+ params = [user_id, version]
+ for room_id, room in room_keys.items():
+ sessions = list(room["sessions"])
+ if not sessions:
+ continue
+ params.append(room_id)
+ params.extend(sessions)
+ where_clauses.append(
+ "(room_id = ? AND session_id IN (%s))"
+ % (",".join(["?" for _ in sessions]),)
+ )
+
+ # check if we're actually querying something
+ if not where_clauses:
+ return {}
+
+ sql = """
+ SELECT room_id, session_id, first_message_index, forwarded_count,
+ is_verified, session_data
+ FROM e2e_room_keys
+ WHERE user_id = ? AND version = ? AND (%s)
+ """ % (
+ " OR ".join(where_clauses)
+ )
+
+ txn.execute(sql, params)
+
+ ret = {}
+
+ for row in txn:
+ room_id = row[0]
+ session_id = row[1]
+ ret.setdefault(room_id, {})
+ ret[room_id][session_id] = {
+ "first_message_index": row[2],
+ "forwarded_count": row[3],
+ "is_verified": row[4],
+ "session_data": json.loads(row[5]),
+ }
+
+ return ret
+
+ def count_e2e_room_keys(self, user_id, version):
+ """Get the number of keys in a backup version.
+
+ Args:
+ user_id (str): the user whose backup we're querying
+ version (str): the version ID of the backup we're querying about
+ """
+
+ return self.simple_select_one_onecol(
+ table="e2e_room_keys",
+ keyvalues={"user_id": user_id, "version": version},
+ retcol="COUNT(*)",
+ desc="count_e2e_room_keys",
+ )
+
@trace
@defer.inlineCallbacks
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
@@ -188,7 +267,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
if session_id:
keyvalues["session_id"] = session_id
- yield self._simple_delete(
+ yield self.simple_delete(
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
)
@@ -219,6 +298,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
version(str)
algorithm(str)
auth_data(object): opaque dict supplied by the client
+ etag(int): tag of the keys in the backup
"""
def _get_e2e_room_keys_version_info_txn(txn):
@@ -232,14 +312,16 @@ class EndToEndRoomKeyStore(SQLBaseStore):
# it isn't there.
raise StoreError(404, "No row found")
- result = self._simple_select_one_txn(
+ result = self.simple_select_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
- retcols=("version", "algorithm", "auth_data"),
+ retcols=("version", "algorithm", "auth_data", "etag"),
)
result["auth_data"] = json.loads(result["auth_data"])
result["version"] = str(result["version"])
+ if result["etag"] is None:
+ result["etag"] = 0
return result
return self.runInteraction(
@@ -270,7 +352,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
new_version = str(int(current_version) + 1)
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="e2e_room_keys_versions",
values={
@@ -288,21 +370,33 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
- def update_e2e_room_keys_version(self, user_id, version, info):
+ def update_e2e_room_keys_version(
+ self, user_id, version, info=None, version_etag=None
+ ):
"""Update a given backup version
Args:
user_id(str): the user whose backup version we're updating
version(str): the version ID of the backup version we're updating
- info(dict): the new backup version info to store
+ info (dict): the new backup version info to store. If None, then
+ the backup version info is not updated
+ version_etag (Optional[int]): etag of the keys in the backup. If
+ None, then the etag is not updated
"""
+ updatevalues = {}
- return self._simple_update(
- table="e2e_room_keys_versions",
- keyvalues={"user_id": user_id, "version": version},
- updatevalues={"auth_data": json.dumps(info["auth_data"])},
- desc="update_e2e_room_keys_version",
- )
+ if info is not None and "auth_data" in info:
+ updatevalues["auth_data"] = json.dumps(info["auth_data"])
+ if version_etag is not None:
+ updatevalues["etag"] = version_etag
+
+ if updatevalues:
+ return self.simple_update(
+ table="e2e_room_keys_versions",
+ keyvalues={"user_id": user_id, "version": version},
+ updatevalues=updatevalues,
+ desc="update_e2e_room_keys_version",
+ )
@trace
def delete_e2e_room_keys_version(self, user_id, version=None):
@@ -326,13 +420,13 @@ class EndToEndRoomKeyStore(SQLBaseStore):
else:
this_version = version
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": this_version},
)
- return self._simple_update_one_txn(
+ return self.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 073412a78d..08bcdc4725 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -138,20 +138,35 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
result.setdefault(user_id, {})[device_id] = None
# get signatures on the device
- signature_sql = (
- "SELECT * " " FROM e2e_cross_signing_signatures " " WHERE %s"
- ) % (" OR ".join("(" + q + ")" for q in signature_query_clauses))
+ signature_sql = ("SELECT * FROM e2e_cross_signing_signatures WHERE %s") % (
+ " OR ".join("(" + q + ")" for q in signature_query_clauses)
+ )
txn.execute(signature_sql, signature_query_params)
rows = self.cursor_to_dict(txn)
+ # add each cross-signing signature to the correct device in the result dict.
for row in rows:
+ signing_user_id = row["user_id"]
+ signing_key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
- if target_user_id in result and target_device_id in result[target_user_id]:
- result[target_user_id][target_device_id].setdefault(
- "signatures", {}
- ).setdefault(row["user_id"], {})[row["key_id"]] = row["signature"]
+ signature = row["signature"]
+
+ target_user_result = result.get(target_user_id)
+ if not target_user_result:
+ continue
+
+ target_device_result = target_user_result.get(target_device_id)
+ if not target_device_result:
+ # note that target_device_result will be None for deleted devices.
+ continue
+
+ target_device_signatures = target_device_result.setdefault("signatures", {})
+ signing_user_signatures = target_device_signatures.setdefault(
+ signing_user_id, {}
+ )
+ signing_user_signatures[signing_key_id] = signature
log_kv(result)
return result
@@ -171,7 +186,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
key_id) to json string for key
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
iterable=key_ids,
@@ -204,7 +219,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
# insert one set.
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="e2e_one_time_keys_json",
values=[
@@ -335,7 +350,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
GROUP BY user_id
"""
- return self._execute(
+ return self.execute(
"get_all_user_signature_changes_for_remotes", None, sql, from_key, to_key
)
@@ -352,7 +367,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
- old_key_json = self._simple_select_one_onecol_txn(
+ old_key_json = self.simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -368,7 +383,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
log_kv({"Message": "Device key already stored."})
return False
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -427,12 +442,12 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
"user_id": user_id,
}
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="e2e_device_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="e2e_one_time_keys_json",
keyvalues={"user_id": user_id, "device_id": device_id},
@@ -477,7 +492,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# The "keys" property must only have one entry, which will be the public
# key, so we just grab the first value in there
pubkey = next(iter(key["keys"].values()))
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
"devices",
values={
@@ -490,7 +505,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
# and finally, store the key itself
with self._cross_signing_id_gen.get_next() as stream_id:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
"e2e_cross_signing_keys",
values={
@@ -524,7 +539,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
user_id (str): the user who made the signatures
signatures (iterable[SignatureListItem]): signatures to add
"""
- return self._simple_insert_many(
+ return self.simple_insert_many(
"e2e_cross_signing_signatures",
[
{
diff --git a/synapse/storage/data_stores/main/event_federation.py b/synapse/storage/data_stores/main/event_federation.py
index 90bef0cd2c..051ac7a8cb 100644
--- a/synapse/storage/data_stores/main/event_federation.py
+++ b/synapse/storage/data_stores/main/event_federation.py
@@ -126,7 +126,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns
Deferred[int]
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="events",
column="event_id",
iterable=event_ids,
@@ -140,7 +140,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
- return self._simple_select_onecol_txn(
+ return self.simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={"room_id": room_id},
@@ -235,7 +235,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
- return self._simple_select_onecol(
+ return self.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
@@ -271,7 +271,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
)
def _get_min_depth_interaction(self, txn, room_id):
- min_depth = self._simple_select_one_onecol_txn(
+ min_depth = self.simple_select_one_onecol_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -383,7 +383,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
queue = PriorityQueue()
for event_id in event_list:
- depth = self._simple_select_one_onecol_txn(
+ depth = self.simple_select_one_onecol_txn(
txn,
table="events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -468,7 +468,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
Deferred[list[str]]
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=event_ids,
@@ -508,7 +508,7 @@ class EventFederationStore(EventFederationWorkerStore):
if min_depth and depth >= min_depth:
return
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
table="room_depth",
keyvalues={"room_id": room_id},
@@ -520,7 +520,7 @@ class EventFederationStore(EventFederationWorkerStore):
For the given event, update the event edges table and forward and
backward extremities tables.
"""
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="event_edges",
values=[
diff --git a/synapse/storage/data_stores/main/event_push_actions.py b/synapse/storage/data_stores/main/event_push_actions.py
index 04ce21ac66..0a37847cfd 100644
--- a/synapse/storage/data_stores/main/event_push_actions.py
+++ b/synapse/storage/data_stores/main/event_push_actions.py
@@ -441,7 +441,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _add_push_actions_to_staging_txn(txn):
- # We don't use _simple_insert_many here to avoid the overhead
+ # We don't use simple_insert_many here to avoid the overhead
# of generating lists of dicts.
sql = """
@@ -472,7 +472,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
"""
try:
- res = yield self._simple_delete(
+ res = yield self.simple_delete(
table="event_push_actions_staging",
keyvalues={"event_id": event_id},
desc="remove_push_actions_from_staging",
@@ -677,7 +677,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
for event, _ in events_and_contexts:
- user_ids = self._simple_select_onecol_txn(
+ user_ids = self.simple_select_onecol_txn(
txn,
table="event_push_actions_staging",
keyvalues={"event_id": event.event_id},
@@ -844,7 +844,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
the archiving process has caught up or not.
"""
- old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -880,7 +880,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
return caught_up
def _rotate_notifs_before_txn(self, txn, rotate_to_stream_ordering):
- old_rotate_stream_ordering = self._simple_select_one_onecol_txn(
+ old_rotate_stream_ordering = self.simple_select_one_onecol_txn(
txn,
table="event_push_summary_stream_ordering",
keyvalues={},
@@ -912,7 +912,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
# existing table.
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="event_push_summary",
values=[
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 878f7568a6..98ae69e996 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -130,6 +130,8 @@ class EventsStore(
if self.hs.config.redaction_retention_period is not None:
hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
+ self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
+
@defer.inlineCallbacks
def _read_forward_extremities(self):
def fetch(txn):
@@ -430,7 +432,7 @@ class EventsStore(
# event's auth chain, but its easier for now just to store them (and
# it doesn't take much storage compared to storing the entire event
# anyway).
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="event_auth",
values=[
@@ -578,12 +580,12 @@ class EventsStore(
self, txn, new_forward_extremities, max_stream_order
):
for room_id, new_extrem in iteritems(new_forward_extremities):
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
)
txn.call_after(self.get_latest_event_ids_in_room.invalidate, (room_id,))
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
@@ -596,7 +598,7 @@ class EventsStore(
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
@@ -713,16 +715,14 @@ class EventsStore(
metadata_json = encode_json(event.internal_metadata.get_dict())
- sql = (
- "UPDATE event_json SET internal_metadata = ?" " WHERE event_id = ?"
- )
+ sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
txn.execute(sql, (metadata_json, event.event_id))
# Add an entry to the ex_outlier_stream table to replicate the
# change in outlier status to our workers.
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="ex_outlier_stream",
values={
@@ -732,7 +732,7 @@ class EventsStore(
},
)
- sql = "UPDATE events SET outlier = ?" " WHERE event_id = ?"
+ sql = "UPDATE events SET outlier = ? WHERE event_id = ?"
txn.execute(sql, (False, event.event_id))
# Update the event_backward_extremities table now that this
@@ -794,7 +794,7 @@ class EventsStore(
d.pop("redacted_because", None)
return d
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="event_json",
values=[
@@ -811,7 +811,7 @@ class EventsStore(
],
)
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="events",
values=[
@@ -841,7 +841,7 @@ class EventsStore(
# If we're persisting an unredacted event we go and ensure
# that we mark any redactions that reference this event as
# requiring censoring.
- self._simple_update_txn(
+ self.simple_update_txn(
txn,
table="redactions",
keyvalues={"redacts": event.event_id},
@@ -929,6 +929,9 @@ class EventsStore(
elif event.type == EventTypes.Redaction:
# Insert into the redactions table.
self._store_redaction(txn, event)
+ elif event.type == EventTypes.Retention:
+ # Update the room_retention table.
+ self._store_retention_policy_for_room_txn(txn, event)
self._handle_event_relations(txn, event)
@@ -939,6 +942,12 @@ class EventsStore(
txn, event.event_id, labels, event.room_id, event.depth
)
+ if self._ephemeral_messages_enabled:
+ # If there's an expiry timestamp on the event, store it.
+ expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
+ if isinstance(expiry_ts, int) and not event.is_state():
+ self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
+
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@@ -974,7 +983,7 @@ class EventsStore(
state_values.append(vals)
- self._simple_insert_many_txn(txn, table="state_events", values=state_values)
+ self.simple_insert_many_txn(txn, table="state_events", values=state_values)
# Prefill the event cache
self._add_to_cache(txn, events_and_contexts)
@@ -1023,7 +1032,7 @@ class EventsStore(
# invalidate the cache for the redacted event
txn.call_after(self._invalidate_get_event_cache, event.redacts)
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="redactions",
values={
@@ -1068,9 +1077,7 @@ class EventsStore(
LIMIT ?
"""
- rows = yield self._execute(
- "_censor_redactions_fetch", None, sql, before_ts, 100
- )
+ rows = yield self.execute("_censor_redactions_fetch", None, sql, before_ts, 100)
updates = []
@@ -1100,14 +1107,9 @@ class EventsStore(
def _update_censor_txn(txn):
for redaction_id, event_id, pruned_json in updates:
if pruned_json:
- self._simple_update_one_txn(
- txn,
- table="event_json",
- keyvalues={"event_id": event_id},
- updatevalues={"json": pruned_json},
- )
+ self._censor_event_txn(txn, event_id, pruned_json)
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn,
table="redactions",
keyvalues={"event_id": redaction_id},
@@ -1116,6 +1118,22 @@ class EventsStore(
yield self.runInteraction("_update_censor_txn", _update_censor_txn)
+ def _censor_event_txn(self, txn, event_id, pruned_json):
+ """Censor an event by replacing its JSON in the event_json table with the
+ provided pruned JSON.
+
+ Args:
+ txn (LoggingTransaction): The database transaction.
+ event_id (str): The ID of the event to censor.
+ pruned_json (str): The pruned JSON
+ """
+ self.simple_update_one_txn(
+ txn,
+ table="event_json",
+ keyvalues={"event_id": event_id},
+ updatevalues={"json": pruned_json},
+ )
+
@defer.inlineCallbacks
def count_daily_messages(self):
"""
@@ -1479,7 +1497,7 @@ class EventsStore(
# We do joins against events_to_purge for e.g. calculating state
# groups to purge, etc., so lets make an index.
- txn.execute("CREATE INDEX events_to_purge_id" " ON events_to_purge(event_id)")
+ txn.execute("CREATE INDEX events_to_purge_id ON events_to_purge(event_id)")
txn.execute("SELECT event_id, should_delete FROM events_to_purge")
event_rows = txn.fetchall()
@@ -1760,7 +1778,7 @@ class EventsStore(
"[purge] found %i state groups to delete", len(state_groups_to_delete)
)
- rows = self._simple_select_many_txn(
+ rows = self.simple_select_many_txn(
txn,
table="state_group_edges",
column="prev_state_group",
@@ -1787,15 +1805,15 @@ class EventsStore(
curr_state = self._get_state_groups_from_groups_txn(txn, [sg])
curr_state = curr_state[sg]
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn, table="state_groups_state", keyvalues={"state_group": sg}
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn, table="state_group_edges", keyvalues={"state_group": sg}
)
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -1832,7 +1850,7 @@ class EventsStore(
state group.
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="state_group_edges",
column="prev_state_group",
iterable=state_groups,
@@ -1862,7 +1880,7 @@ class EventsStore(
# first we have to delete the state groups states
logger.info("[purge] removing %s from state_groups_state", room_id)
- self._simple_delete_many_txn(
+ self.simple_delete_many_txn(
txn,
table="state_groups_state",
column="state_group",
@@ -1873,7 +1891,7 @@ class EventsStore(
# ... and the state group edges
logger.info("[purge] removing %s from state_group_edges", room_id)
- self._simple_delete_many_txn(
+ self.simple_delete_many_txn(
txn,
table="state_group_edges",
column="state_group",
@@ -1884,7 +1902,7 @@ class EventsStore(
# ... and the state groups
logger.info("[purge] removing %s from state_groups", room_id)
- self._simple_delete_many_txn(
+ self.simple_delete_many_txn(
txn,
table="state_groups",
column="id",
@@ -1901,7 +1919,7 @@ class EventsStore(
@cachedInlineCallbacks(max_entries=5000)
def _get_event_ordering(self, event_id):
- res = yield self._simple_select_one(
+ res = yield self.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@@ -1942,7 +1960,7 @@ class EventsStore(
room_id (str): The ID of the room the event was sent to.
topological_ordering (int): The position of the event in the room's topology.
"""
- return self._simple_insert_many_txn(
+ return self.simple_insert_many_txn(
txn=txn,
table="event_labels",
values=[
@@ -1956,6 +1974,101 @@ class EventsStore(
],
)
+ def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
+ """Save the expiry timestamp associated with a given event ID.
+
+ Args:
+ txn (LoggingTransaction): The database transaction to use.
+ event_id (str): The event ID the expiry timestamp is associated with.
+ expiry_ts (int): The timestamp at which to expire (delete) the event.
+ """
+ return self.simple_insert_txn(
+ txn=txn,
+ table="event_expiry",
+ values={"event_id": event_id, "expiry_ts": expiry_ts},
+ )
+
+ @defer.inlineCallbacks
+ def expire_event(self, event_id):
+ """Retrieve and expire an event that has expired, and delete its associated
+ expiry timestamp. If the event can't be retrieved, delete its associated
+ timestamp so we don't try to expire it again in the future.
+
+ Args:
+ event_id (str): The ID of the event to delete.
+ """
+ # Try to retrieve the event's content from the database or the event cache.
+ event = yield self.get_event(event_id)
+
+ def delete_expired_event_txn(txn):
+ # Delete the expiry timestamp associated with this event from the database.
+ self._delete_event_expiry_txn(txn, event_id)
+
+ if not event:
+ # If we can't find the event, log a warning and delete the expiry date
+ # from the database so that we don't try to expire it again in the
+ # future.
+ logger.warning(
+ "Can't expire event %s because we don't have it.", event_id
+ )
+ return
+
+ # Prune the event's dict then convert it to JSON.
+ pruned_json = encode_json(prune_event_dict(event.get_dict()))
+
+ # Update the event_json table to replace the event's JSON with the pruned
+ # JSON.
+ self._censor_event_txn(txn, event.event_id, pruned_json)
+
+ # We need to invalidate the event cache entry for this event because we
+ # changed its content in the database. We can't call
+ # self._invalidate_cache_and_stream because self.get_event_cache isn't of the
+ # right type.
+ txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
+ # Send that invalidation to replication so that other workers also invalidate
+ # the event cache.
+ self._send_invalidation_to_replication(
+ txn, "_get_event_cache", (event.event_id,)
+ )
+
+ yield self.runInteraction("delete_expired_event", delete_expired_event_txn)
+
+ def _delete_event_expiry_txn(self, txn, event_id):
+ """Delete the expiry timestamp associated with an event ID without deleting the
+ actual event.
+
+ Args:
+ txn (LoggingTransaction): The transaction to use to perform the deletion.
+ event_id (str): The event ID to delete the associated expiry timestamp of.
+ """
+ return self.simple_delete_txn(
+ txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
+ )
+
+ def get_next_event_to_expire(self):
+ """Retrieve the entry with the lowest expiry timestamp in the event_expiry
+ table, or None if there's no more event to expire.
+
+ Returns: Deferred[Optional[Tuple[str, int]]]
+ A tuple containing the event ID as its first element and an expiry timestamp
+ as its second one, if there's at least one row in the event_expiry table.
+ None otherwise.
+ """
+
+ def get_next_event_to_expire_txn(txn):
+ txn.execute(
+ """
+ SELECT event_id, expiry_ts FROM event_expiry
+ ORDER BY expiry_ts ASC LIMIT 1
+ """
+ )
+
+ return txn.fetchone()
+
+ return self.runInteraction(
+ desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
+ )
+
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
diff --git a/synapse/storage/data_stores/main/events_bg_updates.py b/synapse/storage/data_stores/main/events_bg_updates.py
index 0ed59ef48e..37dfc8c871 100644
--- a/synapse/storage/data_stores/main/events_bg_updates.py
+++ b/synapse/storage/data_stores/main/events_bg_updates.py
@@ -189,7 +189,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
for chunk in chunks:
- ev_rows = self._simple_select_many_txn(
+ ev_rows = self.simple_select_many_txn(
txn,
table="event_json",
column="event_id",
@@ -366,7 +366,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
to_delete.intersection_update(original_set)
- deleted = self._simple_delete_many_txn(
+ deleted = self.simple_delete_many_txn(
txn=txn,
table="event_forward_extremities",
column="event_id",
@@ -382,7 +382,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
if deleted:
# We now need to invalidate the caches of these rooms
- rows = self._simple_select_many_txn(
+ rows = self.simple_select_many_txn(
txn,
table="events",
column="event_id",
@@ -396,7 +396,7 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
- self._simple_delete_many_txn(
+ self.simple_delete_many_txn(
txn=txn,
table="_extremities_to_check",
column="event_id",
@@ -530,24 +530,31 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
nbrows = 0
last_row_event_id = ""
for (event_id, event_json_raw) in results:
- event_json = json.loads(event_json_raw)
-
- self._simple_insert_many_txn(
- txn=txn,
- table="event_labels",
- values=[
- {
- "event_id": event_id,
- "label": label,
- "room_id": event_json["room_id"],
- "topological_ordering": event_json["depth"],
- }
- for label in event_json["content"].get(
- EventContentFields.LABELS, []
- )
- if isinstance(label, str)
- ],
- )
+ try:
+ event_json = json.loads(event_json_raw)
+
+ self.simple_insert_many_txn(
+ txn=txn,
+ table="event_labels",
+ values=[
+ {
+ "event_id": event_id,
+ "label": label,
+ "room_id": event_json["room_id"],
+ "topological_ordering": event_json["depth"],
+ }
+ for label in event_json["content"].get(
+ EventContentFields.LABELS, []
+ )
+ if isinstance(label, str)
+ ],
+ )
+ except Exception as e:
+ logger.warning(
+ "Unable to load event %s (no labels will be imported): %s",
+ event_id,
+ e,
+ )
nbrows += 1
last_row_event_id = event_id
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 4c4b76bd93..6a08a746b6 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -17,6 +17,7 @@ from __future__ import division
import itertools
import logging
+import threading
from collections import namedtuple
from canonicaljson import json
@@ -34,6 +35,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.types import get_domain_from_id
from synapse.util import batch_iter
+from synapse.util.caches.descriptors import Cache
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -53,6 +55,17 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
class EventsWorkerStore(SQLBaseStore):
+ def __init__(self, db_conn, hs):
+ super(EventsWorkerStore, self).__init__(db_conn, hs)
+
+ self._get_event_cache = Cache(
+ "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
+ )
+
+ self._event_fetch_lock = threading.Condition()
+ self._event_fetch_list = []
+ self._event_fetch_ongoing = 0
+
def get_received_ts(self, event_id):
"""Get received_ts (when it was persisted) for the event.
@@ -65,7 +78,7 @@ class EventsWorkerStore(SQLBaseStore):
Deferred[int|None]: Timestamp in milliseconds, or None for events
that were persisted before received_ts was implemented.
"""
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
@@ -439,7 +452,7 @@ class EventsWorkerStore(SQLBaseStore):
event_id for events, _ in event_list for event_id in events
)
- row_dict = self._new_transaction(
+ row_dict = self.new_transaction(
conn, "do_fetch", [], [], self._fetch_event_rows, events_to_fetch
)
@@ -732,7 +745,7 @@ class EventsWorkerStore(SQLBaseStore):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
@@ -770,40 +783,6 @@ class EventsWorkerStore(SQLBaseStore):
yield self.runInteraction("have_seen_events", have_seen_events_txn, chunk)
return results
- def get_seen_events_with_rejections(self, event_ids):
- """Given a list of event ids, check if we rejected them.
-
- Args:
- event_ids (list[str])
-
- Returns:
- Deferred[dict[str, str|None):
- Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps
- to None.
- """
- if not event_ids:
- return defer.succeed({})
-
- def f(txn):
- sql = (
- "SELECT e.event_id, reason FROM events as e "
- "LEFT JOIN rejections as r ON e.event_id = r.event_id "
- "WHERE e.event_id = ?"
- )
-
- res = {}
- for event_id in event_ids:
- txn.execute(sql, (event_id,))
- row = txn.fetchone()
- if row:
- _, rejected = row
- res[event_id] = rejected
-
- return res
-
- return self.runInteraction("get_seen_events_with_rejections", f)
-
def _get_total_state_event_counts_txn(self, txn, room_id):
"""
See get_total_state_event_counts.
diff --git a/synapse/storage/data_stores/main/filtering.py b/synapse/storage/data_stores/main/filtering.py
index a2a2a67927..17ef7b9354 100644
--- a/synapse/storage/data_stores/main/filtering.py
+++ b/synapse/storage/data_stores/main/filtering.py
@@ -30,7 +30,7 @@ class FilteringStore(SQLBaseStore):
except ValueError:
raise SynapseError(400, "Invalid filter ID", Codes.INVALID_PARAM)
- def_json = yield self._simple_select_one_onecol(
+ def_json = yield self.simple_select_one_onecol(
table="user_filters",
keyvalues={"user_id": user_localpart, "filter_id": filter_id},
retcol="filter_json",
@@ -55,7 +55,7 @@ class FilteringStore(SQLBaseStore):
if filter_id_response is not None:
return filter_id_response[0]
- sql = "SELECT MAX(filter_id) FROM user_filters " "WHERE user_id = ?"
+ sql = "SELECT MAX(filter_id) FROM user_filters WHERE user_id = ?"
txn.execute(sql, (user_localpart,))
max_id = txn.fetchone()[0]
if max_id is None:
diff --git a/synapse/storage/data_stores/main/group_server.py b/synapse/storage/data_stores/main/group_server.py
index 5ded539af8..9e1d12bcb7 100644
--- a/synapse/storage/data_stores/main/group_server.py
+++ b/synapse/storage/data_stores/main/group_server.py
@@ -35,7 +35,7 @@ class GroupServerStore(SQLBaseStore):
* "invite"
* "open"
"""
- return self._simple_update_one(
+ return self.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues={"join_policy": join_policy},
@@ -43,7 +43,7 @@ class GroupServerStore(SQLBaseStore):
)
def get_group(self, group_id):
- return self._simple_select_one(
+ return self.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
@@ -65,7 +65,7 @@ class GroupServerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
- return self._simple_select_list(
+ return self.simple_select_list(
table="group_users",
keyvalues=keyvalues,
retcols=("user_id", "is_public", "is_admin"),
@@ -75,7 +75,7 @@ class GroupServerStore(SQLBaseStore):
def get_invited_users_in_group(self, group_id):
# TODO: Pagination
- return self._simple_select_onecol(
+ return self.simple_select_onecol(
table="group_invites",
keyvalues={"group_id": group_id},
retcol="user_id",
@@ -89,7 +89,7 @@ class GroupServerStore(SQLBaseStore):
if not include_private:
keyvalues["is_public"] = True
- return self._simple_select_list(
+ return self.simple_select_list(
table="group_rooms",
keyvalues=keyvalues,
retcols=("room_id", "is_public"),
@@ -180,7 +180,7 @@ class GroupServerStore(SQLBaseStore):
an order of 1 will put the room first. Otherwise, the room gets
added to the end.
"""
- room_in_group = self._simple_select_one_onecol_txn(
+ room_in_group = self.simple_select_one_onecol_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
@@ -193,7 +193,7 @@ class GroupServerStore(SQLBaseStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
else:
- cat_exists = self._simple_select_one_onecol_txn(
+ cat_exists = self.simple_select_one_onecol_txn(
txn,
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -204,7 +204,7 @@ class GroupServerStore(SQLBaseStore):
raise SynapseError(400, "Category doesn't exist")
# TODO: Check category is part of summary already
- cat_exists = self._simple_select_one_onecol_txn(
+ cat_exists = self.simple_select_one_onecol_txn(
txn,
table="group_summary_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
@@ -224,7 +224,7 @@ class GroupServerStore(SQLBaseStore):
(group_id, category_id, group_id, category_id),
)
- existing = self._simple_select_one_txn(
+ existing = self.simple_select_one_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -257,7 +257,7 @@ class GroupServerStore(SQLBaseStore):
to_update["room_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self._simple_update_txn(
+ self.simple_update_txn(
txn,
table="group_summary_rooms",
keyvalues={
@@ -271,7 +271,7 @@ class GroupServerStore(SQLBaseStore):
if is_public is None:
is_public = True
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="group_summary_rooms",
values={
@@ -287,7 +287,7 @@ class GroupServerStore(SQLBaseStore):
if category_id is None:
category_id = _DEFAULT_CATEGORY_ID
- return self._simple_delete(
+ return self.simple_delete(
table="group_summary_rooms",
keyvalues={
"group_id": group_id,
@@ -299,7 +299,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_categories(self, group_id):
- rows = yield self._simple_select_list(
+ rows = yield self.simple_select_list(
table="group_room_categories",
keyvalues={"group_id": group_id},
retcols=("category_id", "is_public", "profile"),
@@ -316,7 +316,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_category(self, group_id, category_id):
- category = yield self._simple_select_one(
+ category = yield self.simple_select_one(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
retcols=("is_public", "profile"),
@@ -343,7 +343,7 @@ class GroupServerStore(SQLBaseStore):
else:
update_values["is_public"] = is_public
- return self._simple_upsert(
+ return self.simple_upsert(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
values=update_values,
@@ -352,7 +352,7 @@ class GroupServerStore(SQLBaseStore):
)
def remove_group_category(self, group_id, category_id):
- return self._simple_delete(
+ return self.simple_delete(
table="group_room_categories",
keyvalues={"group_id": group_id, "category_id": category_id},
desc="remove_group_category",
@@ -360,7 +360,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_roles(self, group_id):
- rows = yield self._simple_select_list(
+ rows = yield self.simple_select_list(
table="group_roles",
keyvalues={"group_id": group_id},
retcols=("role_id", "is_public", "profile"),
@@ -377,7 +377,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_group_role(self, group_id, role_id):
- role = yield self._simple_select_one(
+ role = yield self.simple_select_one(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
retcols=("is_public", "profile"),
@@ -404,7 +404,7 @@ class GroupServerStore(SQLBaseStore):
else:
update_values["is_public"] = is_public
- return self._simple_upsert(
+ return self.simple_upsert(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
values=update_values,
@@ -413,7 +413,7 @@ class GroupServerStore(SQLBaseStore):
)
def remove_group_role(self, group_id, role_id):
- return self._simple_delete(
+ return self.simple_delete(
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
desc="remove_group_role",
@@ -444,7 +444,7 @@ class GroupServerStore(SQLBaseStore):
an order of 1 will put the user first. Otherwise, the user gets
added to the end.
"""
- user_in_group = self._simple_select_one_onecol_txn(
+ user_in_group = self.simple_select_one_onecol_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -457,7 +457,7 @@ class GroupServerStore(SQLBaseStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
else:
- role_exists = self._simple_select_one_onecol_txn(
+ role_exists = self.simple_select_one_onecol_txn(
txn,
table="group_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -468,7 +468,7 @@ class GroupServerStore(SQLBaseStore):
raise SynapseError(400, "Role doesn't exist")
# TODO: Check role is part of the summary already
- role_exists = self._simple_select_one_onecol_txn(
+ role_exists = self.simple_select_one_onecol_txn(
txn,
table="group_summary_roles",
keyvalues={"group_id": group_id, "role_id": role_id},
@@ -488,7 +488,7 @@ class GroupServerStore(SQLBaseStore):
(group_id, role_id, group_id, role_id),
)
- existing = self._simple_select_one_txn(
+ existing = self.simple_select_one_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id, "role_id": role_id},
@@ -517,7 +517,7 @@ class GroupServerStore(SQLBaseStore):
to_update["user_order"] = order
if is_public is not None:
to_update["is_public"] = is_public
- self._simple_update_txn(
+ self.simple_update_txn(
txn,
table="group_summary_users",
keyvalues={
@@ -531,7 +531,7 @@ class GroupServerStore(SQLBaseStore):
if is_public is None:
is_public = True
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="group_summary_users",
values={
@@ -547,7 +547,7 @@ class GroupServerStore(SQLBaseStore):
if role_id is None:
role_id = _DEFAULT_ROLE_ID
- return self._simple_delete(
+ return self.simple_delete(
table="group_summary_users",
keyvalues={"group_id": group_id, "role_id": role_id, "user_id": user_id},
desc="remove_user_from_summary",
@@ -561,7 +561,7 @@ class GroupServerStore(SQLBaseStore):
Deferred[list[str]]: A twisted.Deferred containing a list of group ids
containing this room
"""
- return self._simple_select_onecol(
+ return self.simple_select_onecol(
table="group_rooms",
keyvalues={"room_id": room_id},
retcol="group_id",
@@ -630,7 +630,7 @@ class GroupServerStore(SQLBaseStore):
)
def is_user_in_group(self, user_id, group_id):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@@ -639,7 +639,7 @@ class GroupServerStore(SQLBaseStore):
).addCallback(lambda r: bool(r))
def is_user_admin_in_group(self, group_id, user_id):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
@@ -650,7 +650,7 @@ class GroupServerStore(SQLBaseStore):
def add_group_invite(self, group_id, user_id):
"""Record that the group server has invited a user
"""
- return self._simple_insert(
+ return self.simple_insert(
table="group_invites",
values={"group_id": group_id, "user_id": user_id},
desc="add_group_invite",
@@ -659,7 +659,7 @@ class GroupServerStore(SQLBaseStore):
def is_user_invited_to_local_group(self, group_id, user_id):
"""Has the group server invited a user?
"""
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
@@ -682,7 +682,7 @@ class GroupServerStore(SQLBaseStore):
"""
def _get_users_membership_in_group_txn(txn):
- row = self._simple_select_one_txn(
+ row = self.simple_select_one_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -697,7 +697,7 @@ class GroupServerStore(SQLBaseStore):
"is_privileged": row["is_admin"],
}
- row = self._simple_select_one_onecol_txn(
+ row = self.simple_select_one_onecol_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -738,7 +738,7 @@ class GroupServerStore(SQLBaseStore):
"""
def _add_user_to_group_txn(txn):
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="group_users",
values={
@@ -749,14 +749,14 @@ class GroupServerStore(SQLBaseStore):
},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
if local_attestation:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -766,7 +766,7 @@ class GroupServerStore(SQLBaseStore):
},
)
if remote_attestation:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
@@ -781,27 +781,27 @@ class GroupServerStore(SQLBaseStore):
def remove_user_from_group(self, group_id, user_id):
def _remove_user_from_group_txn(txn):
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="group_summary_users",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -812,14 +812,14 @@ class GroupServerStore(SQLBaseStore):
)
def add_room_to_group(self, group_id, room_id, is_public):
- return self._simple_insert(
+ return self.simple_insert(
table="group_rooms",
values={"group_id": group_id, "room_id": room_id, "is_public": is_public},
desc="add_room_to_group",
)
def update_room_in_group_visibility(self, group_id, room_id, is_public):
- return self._simple_update(
+ return self.simple_update(
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
updatevalues={"is_public": is_public},
@@ -828,13 +828,13 @@ class GroupServerStore(SQLBaseStore):
def remove_room_from_group(self, group_id, room_id):
def _remove_room_from_group_txn(txn):
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="group_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="group_summary_rooms",
keyvalues={"group_id": group_id, "room_id": room_id},
@@ -847,7 +847,7 @@ class GroupServerStore(SQLBaseStore):
def get_publicised_groups_for_user(self, user_id):
"""Get all groups a user is publicising
"""
- return self._simple_select_onecol(
+ return self.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join", "is_publicised": True},
retcol="group_id",
@@ -857,7 +857,7 @@ class GroupServerStore(SQLBaseStore):
def update_group_publicity(self, group_id, user_id, publicise):
"""Update whether the user is publicising their membership of the group
"""
- return self._simple_update_one(
+ return self.simple_update_one(
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_publicised": publicise},
@@ -893,12 +893,12 @@ class GroupServerStore(SQLBaseStore):
def _register_user_group_membership_txn(txn, next_id):
# TODO: Upsert?
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="local_group_membership",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="local_group_membership",
values={
@@ -911,7 +911,7 @@ class GroupServerStore(SQLBaseStore):
},
)
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="local_group_updates",
values={
@@ -930,7 +930,7 @@ class GroupServerStore(SQLBaseStore):
if membership == "join":
if local_attestation:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="group_attestations_renewals",
values={
@@ -940,7 +940,7 @@ class GroupServerStore(SQLBaseStore):
},
)
if remote_attestation:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="group_attestations_remote",
values={
@@ -951,12 +951,12 @@ class GroupServerStore(SQLBaseStore):
},
)
else:
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
@@ -976,7 +976,7 @@ class GroupServerStore(SQLBaseStore):
def create_group(
self, group_id, user_id, name, avatar_url, short_description, long_description
):
- yield self._simple_insert(
+ yield self.simple_insert(
table="groups",
values={
"group_id": group_id,
@@ -991,7 +991,7 @@ class GroupServerStore(SQLBaseStore):
@defer.inlineCallbacks
def update_group_profile(self, group_id, profile):
- yield self._simple_update_one(
+ yield self.simple_update_one(
table="groups",
keyvalues={"group_id": group_id},
updatevalues=profile,
@@ -1017,7 +1017,7 @@ class GroupServerStore(SQLBaseStore):
def update_attestation_renewal(self, group_id, user_id, attestation):
"""Update an attestation that we have renewed
"""
- return self._simple_update_one(
+ return self.simple_update_one(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"valid_until_ms": attestation["valid_until_ms"]},
@@ -1027,7 +1027,7 @@ class GroupServerStore(SQLBaseStore):
def update_remote_attestion(self, group_id, user_id, attestation):
"""Update an attestation that a remote has renewed
"""
- return self._simple_update_one(
+ return self.simple_update_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={
@@ -1046,7 +1046,7 @@ class GroupServerStore(SQLBaseStore):
group_id (str)
user_id (str)
"""
- return self._simple_delete(
+ return self.simple_delete(
table="group_attestations_renewals",
keyvalues={"group_id": group_id, "user_id": user_id},
desc="remove_attestation_renewal",
@@ -1057,7 +1057,7 @@ class GroupServerStore(SQLBaseStore):
"""Get the attestation that proves the remote agrees that the user is
in the group.
"""
- row = yield self._simple_select_one(
+ row = yield self.simple_select_one(
table="group_attestations_remote",
keyvalues={"group_id": group_id, "user_id": user_id},
retcols=("valid_until_ms", "attestation_json"),
@@ -1072,7 +1072,7 @@ class GroupServerStore(SQLBaseStore):
return None
def get_joined_groups(self, user_id):
- return self._simple_select_onecol(
+ return self.simple_select_onecol(
table="local_group_membership",
keyvalues={"user_id": user_id, "membership": "join"},
retcol="group_id",
@@ -1188,7 +1188,7 @@ class GroupServerStore(SQLBaseStore):
]
for table in tables:
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn, table=table, keyvalues={"group_id": group_id}
)
diff --git a/synapse/storage/data_stores/main/keys.py b/synapse/storage/data_stores/main/keys.py
index ebc7db3ed6..c7150432b3 100644
--- a/synapse/storage/data_stores/main/keys.py
+++ b/synapse/storage/data_stores/main/keys.py
@@ -129,7 +129,7 @@ class KeyStore(SQLBaseStore):
return self.runInteraction(
"store_server_verify_keys",
- self._simple_upsert_many_txn,
+ self.simple_upsert_many_txn,
table="server_signature_keys",
key_names=("server_name", "key_id"),
key_values=key_values,
@@ -157,7 +157,7 @@ class KeyStore(SQLBaseStore):
ts_valid_until_ms (int): The time when this json stops being valid.
key_json (bytes): The encoded JSON.
"""
- return self._simple_upsert(
+ return self.simple_upsert(
table="server_keys_json",
keyvalues={
"server_name": server_name,
@@ -196,7 +196,7 @@ class KeyStore(SQLBaseStore):
keyvalues["key_id"] = key_id
if from_server is not None:
keyvalues["from_server"] = from_server
- rows = self._simple_select_list_txn(
+ rows = self.simple_select_list_txn(
txn,
"server_keys_json",
keyvalues=keyvalues,
diff --git a/synapse/storage/data_stores/main/media_repository.py b/synapse/storage/data_stores/main/media_repository.py
index 84b5f3ad5e..0cb9446f96 100644
--- a/synapse/storage/data_stores/main/media_repository.py
+++ b/synapse/storage/data_stores/main/media_repository.py
@@ -39,7 +39,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
Returns:
None if the media_id doesn't exist.
"""
- return self._simple_select_one(
+ return self.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
@@ -64,7 +64,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
user_id,
url_cache=None,
):
- return self._simple_insert(
+ return self.simple_insert(
"local_media_repository",
{
"media_id": media_id,
@@ -129,7 +129,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def store_url_cache(
self, url, response_code, etag, expires_ts, og, media_id, download_ts
):
- return self._simple_insert(
+ return self.simple_insert(
"local_media_repository_url_cache",
{
"url": url,
@@ -144,7 +144,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
def get_local_media_thumbnails(self, media_id):
- return self._simple_select_list(
+ return self.simple_select_list(
"local_media_repository_thumbnails",
{"media_id": media_id},
(
@@ -166,7 +166,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self._simple_insert(
+ return self.simple_insert(
"local_media_repository_thumbnails",
{
"media_id": media_id,
@@ -180,7 +180,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
)
def get_cached_remote_media(self, origin, media_id):
- return self._simple_select_one(
+ return self.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
@@ -205,7 +205,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
upload_name,
filesystem_id,
):
- return self._simple_insert(
+ return self.simple_insert(
"remote_media_cache",
{
"media_origin": origin,
@@ -253,7 +253,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
return self.runInteraction("update_cached_last_access_time", update_cache_txn)
def get_remote_media_thumbnails(self, origin, media_id):
- return self._simple_select_list(
+ return self.simple_select_list(
"remote_media_cache_thumbnails",
{"media_origin": origin, "media_id": media_id},
(
@@ -278,7 +278,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
thumbnail_method,
thumbnail_length,
):
- return self._simple_insert(
+ return self.simple_insert(
"remote_media_cache_thumbnails",
{
"media_origin": origin,
@@ -300,18 +300,18 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
" WHERE last_access_ts < ?"
)
- return self._execute(
+ return self.execute(
"get_remote_media_before", self.cursor_to_dict, sql, before_ts
)
def delete_remote_media(self, media_origin, media_id):
def delete_remote_media_txn(txn):
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
"remote_media_cache",
keyvalues={"media_origin": media_origin, "media_id": media_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
"remote_media_cache_thumbnails",
keyvalues={"media_origin": media_origin, "media_id": media_id},
@@ -337,7 +337,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
if len(media_ids) == 0:
return
- sql = "DELETE FROM local_media_repository_url_cache" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
def _delete_url_cache_txn(txn):
txn.executemany(sql, [(media_id,) for media_id in media_ids])
@@ -365,11 +365,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
return
def _delete_url_cache_media_txn(txn):
- sql = "DELETE FROM local_media_repository" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
- sql = "DELETE FROM local_media_repository_thumbnails" " WHERE media_id = ?"
+ sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
txn.executemany(sql, [(media_id,) for media_id in media_ids])
diff --git a/synapse/storage/data_stores/main/monthly_active_users.py b/synapse/storage/data_stores/main/monthly_active_users.py
index b41c3d317a..b8fc28f97b 100644
--- a/synapse/storage/data_stores/main/monthly_active_users.py
+++ b/synapse/storage/data_stores/main/monthly_active_users.py
@@ -32,7 +32,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
# Do not add more reserved users than the total allowable number
- self._new_transaction(
+ self.new_transaction(
dbconn,
"initialise_mau_threepids",
[],
@@ -261,7 +261,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
- is_insert = self._simple_upsert_txn(
+ is_insert = self.simple_upsert_txn(
txn,
table="monthly_active_users",
keyvalues={"user_id": user_id},
@@ -281,7 +281,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
"""
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
diff --git a/synapse/storage/data_stores/main/openid.py b/synapse/storage/data_stores/main/openid.py
index 79b40044d9..650e49750e 100644
--- a/synapse/storage/data_stores/main/openid.py
+++ b/synapse/storage/data_stores/main/openid.py
@@ -3,7 +3,7 @@ from synapse.storage._base import SQLBaseStore
class OpenIdStore(SQLBaseStore):
def insert_open_id_token(self, token, ts_valid_until_ms, user_id):
- return self._simple_insert(
+ return self.simple_insert(
table="open_id_tokens",
values={
"token": token,
diff --git a/synapse/storage/data_stores/main/presence.py b/synapse/storage/data_stores/main/presence.py
index 523ed6575e..a5e121efd1 100644
--- a/synapse/storage/data_stores/main/presence.py
+++ b/synapse/storage/data_stores/main/presence.py
@@ -46,7 +46,7 @@ class PresenceStore(SQLBaseStore):
txn.call_after(self._get_presence_for_user.invalidate, (state.user_id,))
# Actually insert new rows
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="presence_stream",
values=[
@@ -103,7 +103,7 @@ class PresenceStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_presence_for_users(self, user_ids):
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -129,7 +129,7 @@ class PresenceStore(SQLBaseStore):
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
- return self._simple_insert(
+ return self.simple_insert(
table="presence_allow_inbound",
values={
"observed_user_id": observed_localpart,
@@ -140,7 +140,7 @@ class PresenceStore(SQLBaseStore):
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
- return self._simple_delete_one(
+ return self.simple_delete_one(
table="presence_allow_inbound",
keyvalues={
"observed_user_id": observed_localpart,
diff --git a/synapse/storage/data_stores/main/profile.py b/synapse/storage/data_stores/main/profile.py
index e4e8a1c1d6..c8b5b60301 100644
--- a/synapse/storage/data_stores/main/profile.py
+++ b/synapse/storage/data_stores/main/profile.py
@@ -24,7 +24,7 @@ class ProfileWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_profileinfo(self, user_localpart):
try:
- profile = yield self._simple_select_one(
+ profile = yield self.simple_select_one(
table="profiles",
keyvalues={"user_id": user_localpart},
retcols=("displayname", "avatar_url"),
@@ -42,7 +42,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_profile_displayname(self, user_localpart):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
@@ -50,7 +50,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_profile_avatar_url(self, user_localpart):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
@@ -58,7 +58,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def get_from_remote_profile_cache(self, user_id):
- return self._simple_select_one(
+ return self.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
@@ -67,12 +67,12 @@ class ProfileWorkerStore(SQLBaseStore):
)
def create_profile(self, user_localpart):
- return self._simple_insert(
+ return self.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
)
def set_profile_displayname(self, user_localpart, new_displayname):
- return self._simple_update_one(
+ return self.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"displayname": new_displayname},
@@ -80,7 +80,7 @@ class ProfileWorkerStore(SQLBaseStore):
)
def set_profile_avatar_url(self, user_localpart, new_avatar_url):
- return self._simple_update_one(
+ return self.simple_update_one(
table="profiles",
keyvalues={"user_id": user_localpart},
updatevalues={"avatar_url": new_avatar_url},
@@ -95,7 +95,7 @@ class ProfileStore(ProfileWorkerStore):
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
- return self._simple_upsert(
+ return self.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -107,7 +107,7 @@ class ProfileStore(ProfileWorkerStore):
)
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
- return self._simple_update(
+ return self.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
@@ -125,7 +125,7 @@ class ProfileStore(ProfileWorkerStore):
"""
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
- yield self._simple_delete(
+ yield self.simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
@@ -155,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
def is_subscribed_remote_profile_for_user(self, user_id):
"""Check whether we are interested in a remote user's profile.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
@@ -166,7 +166,7 @@ class ProfileStore(ProfileWorkerStore):
if res:
return True
- res = yield self._simple_select_one_onecol(
+ res = yield self.simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",
diff --git a/synapse/storage/data_stores/main/push_rule.py b/synapse/storage/data_stores/main/push_rule.py
index b520062d84..75bd499bcd 100644
--- a/synapse/storage/data_stores/main/push_rule.py
+++ b/synapse/storage/data_stores/main/push_rule.py
@@ -75,7 +75,7 @@ class PushRulesWorkerStore(
def __init__(self, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(db_conn, hs)
- push_rules_prefill, push_rules_id = self._get_cache_dict(
+ push_rules_prefill, push_rules_id = self.get_cache_dict(
db_conn,
"push_rules_stream",
entity_column="user_id",
@@ -100,7 +100,7 @@ class PushRulesWorkerStore(
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_for_user(self, user_id):
- rows = yield self._simple_select_list(
+ rows = yield self.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
retcols=(
@@ -124,7 +124,7 @@ class PushRulesWorkerStore(
@cachedInlineCallbacks(max_entries=5000)
def get_push_rules_enabled_for_user(self, user_id):
- results = yield self._simple_select_list(
+ results = yield self.simple_select_list(
table="push_rules_enable",
keyvalues={"user_name": user_id},
retcols=("user_name", "rule_id", "enabled"),
@@ -162,7 +162,7 @@ class PushRulesWorkerStore(
results = {user_id: [] for user_id in user_ids}
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -320,7 +320,7 @@ class PushRulesWorkerStore(
results = {user_id: {} for user_id in user_ids}
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,
@@ -395,7 +395,7 @@ class PushRuleStore(PushRulesWorkerStore):
relative_to_rule = before or after
- res = self._simple_select_one_txn(
+ res = self.simple_select_one_txn(
txn,
table="push_rules",
keyvalues={"user_name": user_id, "rule_id": relative_to_rule},
@@ -499,7 +499,7 @@ class PushRuleStore(PushRulesWorkerStore):
actions_json,
update_stream=True,
):
- """Specialised version of _simple_upsert_txn that picks a push_rule_id
+ """Specialised version of simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes
that the "push_rules" table is locked"""
@@ -518,7 +518,7 @@ class PushRuleStore(PushRulesWorkerStore):
# We didn't update a row with the given rule_id so insert one
push_rule_id = self._push_rule_id_gen.get_next()
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="push_rules",
values={
@@ -561,7 +561,7 @@ class PushRuleStore(PushRulesWorkerStore):
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
- self._simple_delete_one_txn(
+ self.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
@@ -596,7 +596,7 @@ class PushRuleStore(PushRulesWorkerStore):
self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
):
new_id = self._push_rules_enable_id_gen.get_next()
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
"push_rules_enable",
{"user_name": user_id, "rule_id": rule_id},
@@ -636,7 +636,7 @@ class PushRuleStore(PushRulesWorkerStore):
update_stream=False,
)
else:
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn,
"push_rules",
{"user_name": user_id, "rule_id": rule_id},
@@ -675,7 +675,7 @@ class PushRuleStore(PushRulesWorkerStore):
if data is not None:
values.update(data)
- self._simple_insert_txn(txn, "push_rules_stream", values=values)
+ self.simple_insert_txn(txn, "push_rules_stream", values=values)
txn.call_after(self.get_push_rules_for_user.invalidate, (user_id,))
txn.call_after(self.get_push_rules_enabled_for_user.invalidate, (user_id,))
diff --git a/synapse/storage/data_stores/main/pusher.py b/synapse/storage/data_stores/main/pusher.py
index d76861cdc0..d5a169872b 100644
--- a/synapse/storage/data_stores/main/pusher.py
+++ b/synapse/storage/data_stores/main/pusher.py
@@ -59,7 +59,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_has_pusher(self, user_id):
- ret = yield self._simple_select_one_onecol(
+ ret = yield self.simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
return ret is not None
@@ -72,7 +72,7 @@ class PusherWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
- ret = yield self._simple_select_list(
+ ret = yield self.simple_select_list(
"pushers",
keyvalues,
[
@@ -193,7 +193,7 @@ class PusherWorkerStore(SQLBaseStore):
inlineCallbacks=True,
)
def get_if_users_have_pushers(self, user_ids):
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,
@@ -229,8 +229,8 @@ class PusherStore(PusherWorkerStore):
):
with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
- # (app_id, pushkey, user_name) so _simple_upsert will retry
- yield self._simple_upsert(
+ # (app_id, pushkey, user_name) so simple_upsert will retry
+ yield self.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
values={
@@ -269,7 +269,7 @@ class PusherStore(PusherWorkerStore):
txn, self.get_if_user_has_pusher, (user_id,)
)
- self._simple_delete_one_txn(
+ self.simple_delete_one_txn(
txn,
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -278,7 +278,7 @@ class PusherStore(PusherWorkerStore):
# it's possible for us to end up with duplicate rows for
# (app_id, pushkey, user_id) at different stream_ids, but that
# doesn't really matter.
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="deleted_pushers",
values={
@@ -296,7 +296,7 @@ class PusherStore(PusherWorkerStore):
def update_pusher_last_stream_ordering(
self, app_id, pushkey, user_id, last_stream_ordering
):
- yield self._simple_update_one(
+ yield self.simple_update_one(
"pushers",
{"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
{"last_stream_ordering": last_stream_ordering},
@@ -319,7 +319,7 @@ class PusherStore(PusherWorkerStore):
Returns:
Deferred[bool]: True if the pusher still exists; False if it has been deleted.
"""
- updated = yield self._simple_update(
+ updated = yield self.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={
@@ -333,7 +333,7 @@ class PusherStore(PusherWorkerStore):
@defer.inlineCallbacks
def update_pusher_failing_since(self, app_id, pushkey, user_id, failing_since):
- yield self._simple_update(
+ yield self.simple_update(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
updatevalues={"failing_since": failing_since},
@@ -342,7 +342,7 @@ class PusherStore(PusherWorkerStore):
@defer.inlineCallbacks
def get_throttle_params_by_room(self, pusher_id):
- res = yield self._simple_select_list(
+ res = yield self.simple_select_list(
"pusher_throttle",
{"pusher": pusher_id},
["room_id", "last_sent_ts", "throttle_ms"],
@@ -361,8 +361,8 @@ class PusherStore(PusherWorkerStore):
@defer.inlineCallbacks
def set_throttle_params(self, pusher_id, room_id, params):
# no need to lock because `pusher_throttle` has a primary key on
- # (pusher, room_id) so _simple_upsert will retry
- yield self._simple_upsert(
+ # (pusher, room_id) so simple_upsert will retry
+ yield self.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
params,
diff --git a/synapse/storage/data_stores/main/receipts.py b/synapse/storage/data_stores/main/receipts.py
index 0c24430f28..380f388e30 100644
--- a/synapse/storage/data_stores/main/receipts.py
+++ b/synapse/storage/data_stores/main/receipts.py
@@ -61,7 +61,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=2)
def get_receipts_for_room(self, room_id, receipt_type):
- return self._simple_select_list(
+ return self.simple_select_list(
table="receipts_linearized",
keyvalues={"room_id": room_id, "receipt_type": receipt_type},
retcols=("user_id", "event_id"),
@@ -70,7 +70,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
@@ -84,7 +84,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
@cachedInlineCallbacks(num_args=2)
def get_receipts_for_user(self, user_id, receipt_type):
- rows = yield self._simple_select_list(
+ rows = yield self.simple_select_list(
table="receipts_linearized",
keyvalues={"user_id": user_id, "receipt_type": receipt_type},
retcols=("room_id", "event_id"),
@@ -280,7 +280,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
args.append(limit)
txn.execute(sql, args)
- return (r[0:5] + (json.loads(r[5]),) for r in txn)
+ return list(r[0:5] + (json.loads(r[5]),) for r in txn)
return self.runInteraction(
"get_all_updated_receipts", get_all_updated_receipts_txn
@@ -335,7 +335,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
otherwise, the rx timestamp of the event that the RR corresponds to
(or 0 if the event is unknown)
"""
- res = self._simple_select_one_txn(
+ res = self.simple_select_one_txn(
txn,
table="events",
retcols=["stream_ordering", "received_ts"],
@@ -388,7 +388,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type),
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="receipts_linearized",
keyvalues={
@@ -398,7 +398,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
},
)
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="receipts_linearized",
values={
@@ -514,7 +514,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
self._get_linearized_receipts_for_room.invalidate_many, (room_id,)
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="receipts_graph",
keyvalues={
@@ -523,7 +523,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"user_id": user_id,
},
)
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="receipts_graph",
values={
diff --git a/synapse/storage/data_stores/main/registration.py b/synapse/storage/data_stores/main/registration.py
index ee1b2b2bbf..debc6706f5 100644
--- a/synapse/storage/data_stores/main/registration.py
+++ b/synapse/storage/data_stores/main/registration.py
@@ -19,7 +19,6 @@ import logging
import re
from six import iterkeys
-from six.moves import range
from twisted.internet import defer
from twisted.internet.defer import Deferred
@@ -46,7 +45,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@cached()
def get_user_by_id(self, user_id):
- return self._simple_select_one(
+ return self.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
@@ -110,7 +109,7 @@ class RegistrationWorkerStore(SQLBaseStore):
otherwise int representation of the timestamp (as a number of
milliseconds since epoch).
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="expiration_ts_ms",
@@ -138,7 +137,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def set_account_validity_for_user_txn(txn):
- self._simple_update_txn(
+ self.simple_update_txn(
txn=txn,
table="account_validity",
keyvalues={"user_id": user_id},
@@ -168,7 +167,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Raises:
StoreError: The provided token is already set for another user.
"""
- yield self._simple_update_one(
+ yield self.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"renewal_token": renewal_token},
@@ -185,7 +184,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The ID of the user to which the token belongs.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.simple_select_one_onecol(
table="account_validity",
keyvalues={"renewal_token": renewal_token},
retcol="user_id",
@@ -204,7 +203,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
defer.Deferred[str]: The renewal token associated with this user ID.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.simple_select_one_onecol(
table="account_validity",
keyvalues={"user_id": user_id},
retcol="renewal_token",
@@ -251,7 +250,7 @@ class RegistrationWorkerStore(SQLBaseStore):
email_sent (bool): Flag which indicates whether a renewal email has been sent
to this user.
"""
- yield self._simple_update_one(
+ yield self.simple_update_one(
table="account_validity",
keyvalues={"user_id": user_id},
updatevalues={"email_sent": email_sent},
@@ -266,7 +265,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Args:
user_id (str): ID of the user to remove from the account validity table.
"""
- yield self._simple_delete_one(
+ yield self.simple_delete_one(
table="account_validity",
keyvalues={"user_id": user_id},
desc="delete_account_validity_for_user",
@@ -282,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns (bool):
true iff the user is a server admin, false otherwise.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.simple_select_one_onecol(
table="users",
keyvalues={"name": user.to_string()},
retcol="admin",
@@ -300,7 +299,7 @@ class RegistrationWorkerStore(SQLBaseStore):
admin (bool): true iff the user is to be a server admin,
false otherwise.
"""
- return self._simple_update_one(
+ return self.simple_update_one(
table="users",
keyvalues={"name": user.to_string()},
updatevalues={"admin": 1 if admin else 0},
@@ -352,7 +351,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return res
def is_real_user_txn(self, txn, user_id):
- res = self._simple_select_one_onecol_txn(
+ res = self.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -362,7 +361,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return res is None
def is_support_user_txn(self, txn, user_id):
- res = self._simple_select_one_onecol_txn(
+ res = self.simple_select_one_onecol_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -377,9 +376,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def f(txn):
- sql = (
- "SELECT name, password_hash FROM users" " WHERE lower(name) = lower(?)"
- )
+ sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
txn.execute(sql, (user_id,))
return dict(txn)
@@ -397,7 +394,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: the mxid of the user, or None if they are not known
"""
- return await self._simple_select_one_onecol(
+ return await self.simple_select_one_onecol(
table="user_external_ids",
keyvalues={"auth_provider": auth_provider, "external_id": external_id},
retcol="user_id",
@@ -484,12 +481,8 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
Gets the localpart of the next generated user ID.
- Generated user IDs are integers, and we aim for them to be as small as
- we can. Unfortunately, it's possible some of them are already taken by
- existing users, and there may be gaps in the already taken range. This
- function returns the start of the first allocatable gap. This is to
- avoid the case of ID 1000 being pre-allocated and starting at 1001 while
- 0-999 are available.
+ Generated user IDs are integers, so we find the largest integer user ID
+ already taken and return that plus one.
"""
def _find_next_generated_user_id(txn):
@@ -499,15 +492,14 @@ class RegistrationWorkerStore(SQLBaseStore):
regex = re.compile(r"^@(\d+):")
- found = set()
+ max_found = 0
for (user_id,) in txn:
match = regex.search(user_id)
if match:
- found.add(int(match.group(1)))
- for i in range(len(found) + 1):
- if i not in found:
- return i
+ max_found = max(int(match.group(1)), max_found)
+
+ return max_found + 1
return (
(
@@ -544,7 +536,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
str|None: user id or None if no user id/threepid mapping exists
"""
- ret = self._simple_select_one_txn(
+ ret = self.simple_select_one_txn(
txn,
"user_threepids",
{"medium": medium, "address": address},
@@ -557,7 +549,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_add_threepid(self, user_id, medium, address, validated_at, added_at):
- yield self._simple_upsert(
+ yield self.simple_upsert(
"user_threepids",
{"medium": medium, "address": address},
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
@@ -565,7 +557,7 @@ class RegistrationWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def user_get_threepids(self, user_id):
- ret = yield self._simple_select_list(
+ ret = yield self.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
@@ -574,9 +566,22 @@ class RegistrationWorkerStore(SQLBaseStore):
return ret
def user_delete_threepid(self, user_id, medium, address):
- return self._simple_delete(
+ return self.simple_delete(
"user_threepids",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
+ desc="user_delete_threepid",
+ )
+
+ def user_delete_threepids(self, user_id: str):
+ """Delete all threepid this user has bound
+
+ Args:
+ user_id: The user id to delete all threepids of
+
+ """
+ return self.simple_delete(
+ "user_threepids",
+ keyvalues={"user_id": user_id},
desc="user_delete_threepids",
)
@@ -596,7 +601,7 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
# We need to use an upsert, in case they user had already bound the
# threepid
- return self._simple_upsert(
+ return self.simple_upsert(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -622,7 +627,7 @@ class RegistrationWorkerStore(SQLBaseStore):
medium (str): The medium of the threepid (e.g "email")
address (str): The address of the threepid (e.g "bob@example.com")
"""
- return self._simple_select_list(
+ return self.simple_select_list(
table="user_threepid_id_server",
keyvalues={"user_id": user_id},
retcols=["medium", "address"],
@@ -643,7 +648,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred
"""
- return self._simple_delete(
+ return self.simple_delete(
table="user_threepid_id_server",
keyvalues={
"user_id": user_id,
@@ -666,7 +671,7 @@ class RegistrationWorkerStore(SQLBaseStore):
Returns:
Deferred[list[str]]: Resolves to a list of identity servers
"""
- return self._simple_select_onecol(
+ return self.simple_select_onecol(
table="user_threepid_id_server",
keyvalues={"user_id": user_id, "medium": medium, "address": address},
retcol="id_server",
@@ -684,7 +689,7 @@ class RegistrationWorkerStore(SQLBaseStore):
defer.Deferred(bool): The requested value.
"""
- res = yield self._simple_select_one_onecol(
+ res = yield self.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="deactivated",
@@ -771,12 +776,12 @@ class RegistrationWorkerStore(SQLBaseStore):
"""
def delete_threepid_session_txn(txn):
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -921,6 +926,14 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._account_validity = hs.config.account_validity
+ if self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "account_validity_set_expiration_dates",
+ self._set_expiration_date_when_missing,
+ )
+
# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
@@ -948,7 +961,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
next_id = self._access_tokens_id_gen.get_next()
- yield self._simple_insert(
+ yield self.simple_insert(
"access_tokens",
{
"id": next_id,
@@ -1024,7 +1037,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Ensure that the guest user actually exists
# ``allow_none=False`` makes this raise an exception
# if the row isn't in the database.
- self._simple_select_one_txn(
+ self.simple_select_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -1032,7 +1045,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
allow_none=False,
)
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn,
"users",
keyvalues={"name": user_id, "is_guest": 1},
@@ -1046,7 +1059,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
},
)
else:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
"users",
values={
@@ -1101,7 +1114,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
external_id: id on that system
user_id: complete mxid that it is mapped to
"""
- return self._simple_insert(
+ return self.simple_insert(
table="user_external_ids",
values={
"auth_provider": auth_provider,
@@ -1119,7 +1132,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def user_set_password_hash_txn(txn):
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn, "users", {"name": user_id}, {"password_hash": password_hash}
)
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
@@ -1139,7 +1152,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def f(txn):
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@@ -1163,7 +1176,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
def f(txn):
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn,
table="users",
keyvalues={"name": user_id},
@@ -1221,7 +1234,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def delete_access_token(self, access_token):
def f(txn):
- self._simple_delete_one_txn(
+ self.simple_delete_one_txn(
txn, table="access_tokens", keyvalues={"token": access_token}
)
@@ -1233,7 +1246,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
@cachedInlineCallbacks()
def is_guest(self, user_id):
- res = yield self._simple_select_one_onecol(
+ res = yield self.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
@@ -1248,7 +1261,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Adds a user to the table of users who need to be parted from all the rooms they're
in
"""
- return self._simple_insert(
+ return self.simple_insert(
"users_pending_deactivation",
values={"user_id": user_id},
desc="add_user_pending_deactivation",
@@ -1261,7 +1274,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"""
# XXX: This should be simple_delete_one but we failed to put a unique index on
# the table, so somehow duplicate entries have ended up in it.
- return self._simple_delete(
+ return self.simple_delete(
"users_pending_deactivation",
keyvalues={"user_id": user_id},
desc="del_user_pending_deactivation",
@@ -1272,7 +1285,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
@@ -1302,7 +1315,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
# Insert everything into a transaction in order to run atomically
def validate_threepid_session_txn(txn):
- row = self._simple_select_one_txn(
+ row = self.simple_select_one_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1320,7 +1333,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
400, "This client_secret does not match the provided session_id"
)
- row = self._simple_select_one_txn(
+ row = self.simple_select_one_txn(
txn,
table="threepid_validation_token",
keyvalues={"session_id": session_id, "token": token},
@@ -1345,7 +1358,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Looks good. Validate the session
- self._simple_update_txn(
+ self.simple_update_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1388,7 +1401,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
if validated_at:
insertion_values["validated_at"] = validated_at
- return self._simple_upsert(
+ return self.simple_upsert(
table="threepid_validation_session",
keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt},
@@ -1426,7 +1439,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def start_or_continue_validation_session_txn(txn):
# Create or update a validation session
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
table="threepid_validation_session",
keyvalues={"session_id": session_id},
@@ -1439,7 +1452,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
# Create a new validation token with this session ID
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="threepid_validation_token",
values={
@@ -1488,7 +1501,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
@@ -1497,3 +1510,59 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
+
+ @defer.inlineCallbacks
+ def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database, filtering out deactivated users.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL AND users.deactivated = 0;"
+ )
+ txn.execute(sql, [])
+
+ res = self.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn, user["name"], use_delta=True
+ )
+
+ yield self.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d ; now + period] range, d being a
+ delta equal to 10% of the validity period.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
+
+ self.simple_upsert_txn(
+ txn,
+ "account_validity",
+ keyvalues={"user_id": user_id},
+ values={"expiration_ts_ms": expiration_ts, "email_sent": False},
+ )
diff --git a/synapse/storage/data_stores/main/rejections.py b/synapse/storage/data_stores/main/rejections.py
index 7d5de0ea2e..f81f9279a1 100644
--- a/synapse/storage/data_stores/main/rejections.py
+++ b/synapse/storage/data_stores/main/rejections.py
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
class RejectionsStore(SQLBaseStore):
def _store_rejections_txn(self, txn, event_id, reason):
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="rejections",
values={
@@ -33,7 +33,7 @@ class RejectionsStore(SQLBaseStore):
)
def get_rejection_reason(self, event_id):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="rejections",
retcol="reason",
keyvalues={"event_id": event_id},
diff --git a/synapse/storage/data_stores/main/relations.py b/synapse/storage/data_stores/main/relations.py
index 858f65582b..aa5e10538b 100644
--- a/synapse/storage/data_stores/main/relations.py
+++ b/synapse/storage/data_stores/main/relations.py
@@ -352,7 +352,7 @@ class RelationsStore(RelationsWorkerStore):
aggregation_key = relation.get("key")
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="event_relations",
values={
@@ -380,6 +380,6 @@ class RelationsStore(RelationsWorkerStore):
redacted_event_id (str): The event that was redacted.
"""
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
)
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index 67bb1b6f60..f309e3640c 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -19,12 +19,16 @@ import logging
import re
from typing import Optional, Tuple
+from six import integer_types
+
from canonicaljson import json
from twisted.internet import defer
+from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
+from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.data_stores.main.search import SearchStore
from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -50,7 +54,7 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
A dict containing the room information, or None if the room is unknown.
"""
- return self._simple_select_one(
+ return self.simple_select_one(
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "is_public", "creator"),
@@ -59,7 +63,7 @@ class RoomWorkerStore(SQLBaseStore):
)
def get_public_room_ids(self):
- return self._simple_select_onecol(
+ return self.simple_select_onecol(
table="rooms",
keyvalues={"is_public": True},
retcol="room_id",
@@ -263,7 +267,7 @@ class RoomWorkerStore(SQLBaseStore):
@cached(max_entries=10000)
def is_room_blocked(self, room_id):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="blocked_rooms",
keyvalues={"room_id": room_id},
retcol="1",
@@ -284,7 +288,7 @@ class RoomWorkerStore(SQLBaseStore):
of RatelimitOverride are None or 0 then ratelimitng has been
disabled for that user entirely.
"""
- row = yield self._simple_select_one(
+ row = yield self.simple_select_one(
table="ratelimit_override",
keyvalues={"user_id": user_id},
retcols=("messages_per_second", "burst_count"),
@@ -300,8 +304,148 @@ class RoomWorkerStore(SQLBaseStore):
else:
return None
+ @cachedInlineCallbacks()
+ def get_retention_policy_for_room(self, room_id):
+ """Get the retention policy for a given room.
+
+ If no retention policy has been found for this room, returns a policy defined
+ by the configured default policy (which has None as both the 'min_lifetime' and
+ the 'max_lifetime' if no default policy has been defined in the server's
+ configuration).
+
+ Args:
+ room_id (str): The ID of the room to get the retention policy of.
+
+ Returns:
+ dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
+ """
+
+ def get_retention_policy_for_room_txn(txn):
+ txn.execute(
+ """
+ SELECT min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ WHERE room_id = ?;
+ """,
+ (room_id,),
+ )
+
+ return self.cursor_to_dict(txn)
+
+ ret = yield self.runInteraction(
+ "get_retention_policy_for_room", get_retention_policy_for_room_txn,
+ )
+
+ # If we don't know this room ID, ret will be None, in this case return the default
+ # policy.
+ if not ret:
+ defer.returnValue(
+ {
+ "min_lifetime": self.config.retention_default_min_lifetime,
+ "max_lifetime": self.config.retention_default_max_lifetime,
+ }
+ )
+
+ row = ret[0]
+
+ # If one of the room's policy's attributes isn't defined, use the matching
+ # attribute from the default policy.
+ # The default values will be None if no default policy has been defined, or if one
+ # of the attributes is missing from the default policy.
+ if row["min_lifetime"] is None:
+ row["min_lifetime"] = self.config.retention_default_min_lifetime
+
+ if row["max_lifetime"] is None:
+ row["max_lifetime"] = self.config.retention_default_max_lifetime
+
+ defer.returnValue(row)
+
+
+class RoomBackgroundUpdateStore(BackgroundUpdateStore):
+ def __init__(self, db_conn, hs):
+ super(RoomBackgroundUpdateStore, self).__init__(db_conn, hs)
+
+ self.config = hs.config
+
+ self.register_background_update_handler(
+ "insert_room_retention", self._background_insert_retention,
+ )
+
+ @defer.inlineCallbacks
+ def _background_insert_retention(self, progress, batch_size):
+ """Retrieves a list of all rooms within a range and inserts an entry for each of
+ them into the room_retention table.
+ NULLs the property's columns if missing from the retention event in the room's
+ state (or NULLs all of them if there's no retention event in the room's state),
+ so that we fall back to the server's retention policy.
+ """
+
+ last_room = progress.get("room_id", "")
+
+ def _background_insert_retention_txn(txn):
+ txn.execute(
+ """
+ SELECT state.room_id, state.event_id, events.json
+ FROM current_state_events as state
+ LEFT JOIN event_json AS events ON (state.event_id = events.event_id)
+ WHERE state.room_id > ? AND state.type = '%s'
+ ORDER BY state.room_id ASC
+ LIMIT ?;
+ """
+ % EventTypes.Retention,
+ (last_room, batch_size),
+ )
+
+ rows = self.cursor_to_dict(txn)
+
+ if not rows:
+ return True
+
+ for row in rows:
+ if not row["json"]:
+ retention_policy = {}
+ else:
+ ev = json.loads(row["json"])
+ retention_policy = json.dumps(ev["content"])
+
+ self.simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": row["room_id"],
+ "event_id": row["event_id"],
+ "min_lifetime": retention_policy.get("min_lifetime"),
+ "max_lifetime": retention_policy.get("max_lifetime"),
+ },
+ )
+
+ logger.info("Inserted %d rows into room_retention", len(rows))
+
+ self._background_update_progress_txn(
+ txn, "insert_room_retention", {"room_id": rows[-1]["room_id"]}
+ )
+
+ if batch_size > len(rows):
+ return True
+ else:
+ return False
+
+ end = yield self.runInteraction(
+ "insert_room_retention", _background_insert_retention_txn,
+ )
+
+ if end:
+ yield self._end_background_update("insert_room_retention")
+
+ defer.returnValue(batch_size)
+
+
+class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
+ def __init__(self, db_conn, hs):
+ super(RoomStore, self).__init__(db_conn, hs)
+
+ self.config = hs.config
-class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
"""Stores a room.
@@ -317,7 +461,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
try:
def store_room_txn(txn, next_id):
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
"rooms",
{
@@ -327,7 +471,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
},
)
if is_public:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -346,14 +490,14 @@ class RoomStore(RoomWorkerStore, SearchStore):
@defer.inlineCallbacks
def set_room_is_public(self, room_id, is_public):
def set_room_is_public_txn(txn, next_id):
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"is_public": is_public},
)
- entries = self._simple_select_list_txn(
+ entries = self.simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
@@ -371,7 +515,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -411,7 +555,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
def set_room_is_public_appservice_txn(txn, next_id):
if is_public:
try:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="appservice_room_list",
values={
@@ -424,7 +568,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
# We've already inserted, nothing to do.
return
else:
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="appservice_room_list",
keyvalues={
@@ -434,7 +578,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
},
)
- entries = self._simple_select_list_txn(
+ entries = self.simple_select_list_txn(
txn,
table="public_room_list_stream",
keyvalues={
@@ -452,7 +596,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
add_to_stream = bool(entries[-1]["visibility"]) != is_public
if add_to_stream:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="public_room_list_stream",
values={
@@ -502,11 +646,40 @@ class RoomStore(RoomWorkerStore, SearchStore):
txn, event, "content.body", event.content["body"]
)
+ def _store_retention_policy_for_room_txn(self, txn, event):
+ if hasattr(event, "content") and (
+ "min_lifetime" in event.content or "max_lifetime" in event.content
+ ):
+ if (
+ "min_lifetime" in event.content
+ and not isinstance(event.content.get("min_lifetime"), integer_types)
+ ) or (
+ "max_lifetime" in event.content
+ and not isinstance(event.content.get("max_lifetime"), integer_types)
+ ):
+ # Ignore the event if one of the value isn't an integer.
+ return
+
+ self.simple_insert_txn(
+ txn=txn,
+ table="room_retention",
+ values={
+ "room_id": event.room_id,
+ "event_id": event.event_id,
+ "min_lifetime": event.content.get("min_lifetime"),
+ "max_lifetime": event.content.get("max_lifetime"),
+ },
+ )
+
+ self._invalidate_cache_and_stream(
+ txn, self.get_retention_policy_for_room, (event.room_id,)
+ )
+
def add_event_report(
self, room_id, event_id, user_id, reason, content, received_ts
):
next_id = self._event_reports_id_gen.get_next()
- return self._simple_insert(
+ return self.simple_insert(
table="event_reports",
values={
"id": next_id,
@@ -552,7 +725,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
Returns:
Deferred
"""
- yield self._simple_upsert(
+ yield self.simple_upsert(
table="blocked_rooms",
keyvalues={"room_id": room_id},
values={},
@@ -683,3 +856,89 @@ class RoomStore(RoomWorkerStore, SearchStore):
remote_media_mxcs.append((hostname, media_id))
return local_media_mxcs, remote_media_mxcs
+
+ @defer.inlineCallbacks
+ def get_rooms_for_retention_period_in_range(
+ self, min_ms, max_ms, include_null=False
+ ):
+ """Retrieves all of the rooms within the given retention range.
+
+ Optionally includes the rooms which don't have a retention policy.
+
+ Args:
+ min_ms (int|None): Duration in milliseconds that define the lower limit of
+ the range to handle (exclusive). If None, doesn't set a lower limit.
+ max_ms (int|None): Duration in milliseconds that define the upper limit of
+ the range to handle (inclusive). If None, doesn't set an upper limit.
+ include_null (bool): Whether to include rooms which retention policy is NULL
+ in the returned set.
+
+ Returns:
+ dict[str, dict]: The rooms within this range, along with their retention
+ policy. The key is "room_id", and maps to a dict describing the retention
+ policy associated with this room ID. The keys for this nested dict are
+ "min_lifetime" (int|None), and "max_lifetime" (int|None).
+ """
+
+ def get_rooms_for_retention_period_in_range_txn(txn):
+ range_conditions = []
+ args = []
+
+ if min_ms is not None:
+ range_conditions.append("max_lifetime > ?")
+ args.append(min_ms)
+
+ if max_ms is not None:
+ range_conditions.append("max_lifetime <= ?")
+ args.append(max_ms)
+
+ # Do a first query which will retrieve the rooms that have a retention policy
+ # in their current state.
+ sql = """
+ SELECT room_id, min_lifetime, max_lifetime FROM room_retention
+ INNER JOIN current_state_events USING (event_id, room_id)
+ """
+
+ if len(range_conditions):
+ sql += " WHERE (" + " AND ".join(range_conditions) + ")"
+
+ if include_null:
+ sql += " OR max_lifetime IS NULL"
+
+ txn.execute(sql, args)
+
+ rows = self.cursor_to_dict(txn)
+ rooms_dict = {}
+
+ for row in rows:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": row["min_lifetime"],
+ "max_lifetime": row["max_lifetime"],
+ }
+
+ if include_null:
+ # If required, do a second query that retrieves all of the rooms we know
+ # of so we can handle rooms with no retention policy.
+ sql = "SELECT DISTINCT room_id FROM current_state_events"
+
+ txn.execute(sql)
+
+ rows = self.cursor_to_dict(txn)
+
+ # If a room isn't already in the dict (i.e. it doesn't have a retention
+ # policy in its state), add it with a null policy.
+ for row in rows:
+ if row["room_id"] not in rooms_dict:
+ rooms_dict[row["room_id"]] = {
+ "min_lifetime": None,
+ "max_lifetime": None,
+ }
+
+ return rooms_dict
+
+ rooms = yield self.runInteraction(
+ "get_rooms_for_retention_period_in_range",
+ get_rooms_for_retention_period_in_range_txn,
+ )
+
+ defer.returnValue(rooms)
diff --git a/synapse/storage/data_stores/main/roommember.py b/synapse/storage/data_stores/main/roommember.py
index 2af24a20b7..fe2428a281 100644
--- a/synapse/storage/data_stores/main/roommember.py
+++ b/synapse/storage/data_stores/main/roommember.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import Iterable, List
from six import iteritems, itervalues
@@ -127,7 +128,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership column is up to date
"""
- pending_update = self._simple_select_one_txn(
+ pending_update = self.simple_select_one_txn(
txn,
table="background_updates",
keyvalues={"update_name": _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME},
@@ -602,7 +603,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
to `user_id` and ProfileInfo (or None if not join event).
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
@@ -642,7 +643,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
- rows = yield self._execute("is_host_joined", None, sql, room_id, like_clause)
+ rows = yield self.execute("is_host_joined", None, sql, room_id, like_clause)
if not rows:
return False
@@ -682,7 +683,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# the returned user actually has the correct domain.
like_clause = "%:" + host
- rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
+ rows = yield self.execute("was_host_joined", None, sql, room_id, like_clause)
if not rows:
return False
@@ -804,7 +805,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Deferred[set[str]]: Set of room IDs.
"""
- room_ids = yield self._simple_select_onecol(
+ room_ids = yield self.simple_select_onecol(
table="room_memberships",
keyvalues={"membership": Membership.JOIN, "user_id": user_id},
retcol="room_id",
@@ -813,6 +814,22 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
+ def get_membership_from_event_ids(
+ self, member_event_ids: Iterable[str]
+ ) -> List[dict]:
+ """Get user_id and membership of a set of event IDs.
+ """
+
+ return self.simple_select_many_batch(
+ table="room_memberships",
+ column="event_id",
+ iterable=member_event_ids,
+ retcols=("user_id", "membership", "event_id"),
+ keyvalues={},
+ batch_size=500,
+ desc="get_membership_from_event_ids",
+ )
+
class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
@@ -973,7 +990,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="room_memberships",
values=[
@@ -1011,7 +1028,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="local_invites",
values={
diff --git a/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
new file mode 100644
index 0000000000..81a36a8b1d
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/event_expiry.sql
@@ -0,0 +1,21 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE IF NOT EXISTS event_expiry (
+ event_id TEXT PRIMARY KEY,
+ expiry_ts BIGINT NOT NULL
+);
+
+CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
new file mode 100644
index 0000000000..7d70dd071e
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_key_etag.sql
@@ -0,0 +1,17 @@
+/* Copyright 2019 Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- store the current etag of backup version
+ALTER TABLE e2e_room_keys_versions ADD COLUMN etag BIGINT;
diff --git a/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
new file mode 100644
index 0000000000..ee6cdf7a14
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/room_retention.sql
@@ -0,0 +1,33 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks the retention policy of a room.
+-- A NULL max_lifetime or min_lifetime means that the matching property is not defined in
+-- the room's retention policy state event.
+-- If a room doesn't have a retention policy state event in its state, both max_lifetime
+-- and min_lifetime are NULL.
+CREATE TABLE IF NOT EXISTS room_retention(
+ room_id TEXT,
+ event_id TEXT,
+ min_lifetime BIGINT,
+ max_lifetime BIGINT,
+
+ PRIMARY KEY(room_id, event_id)
+);
+
+CREATE INDEX room_retention_max_lifetime_idx on room_retention(max_lifetime);
+
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('insert_room_retention', '{}');
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
index 27a96123e3..5c5fffcafb 100644
--- a/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
+++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys.sql
@@ -40,7 +40,8 @@ CREATE TABLE IF NOT EXISTS e2e_cross_signing_signatures (
signature TEXT NOT NULL
);
-CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
+-- replaced by the index created in signing_keys_nonunique_signatures.sql
+-- CREATE UNIQUE INDEX e2e_cross_signing_signatures_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
-- stream of user signature updates
CREATE TABLE IF NOT EXISTS user_signature_stream (
diff --git a/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
new file mode 100644
index 0000000000..0aa90ebf0c
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/delta/56/signing_keys_nonunique_signatures.sql
@@ -0,0 +1,22 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* The cross-signing signatures index should not be a unique index, because a
+ * user may upload multiple signatures for the same target user. The previous
+ * index was unique, so delete it if it's there and create a new non-unique
+ * index. */
+
+DROP INDEX IF EXISTS e2e_cross_signing_signatures_idx; CREATE INDEX IF NOT
+EXISTS e2e_cross_signing_signatures2_idx ON e2e_cross_signing_signatures(user_id, target_user_id, target_device_id);
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index d1d7c6863d..f735cf095c 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -441,7 +441,7 @@ class SearchStore(SearchBackgroundUpdateStore):
# entire table from the database.
sql += " ORDER BY rank DESC LIMIT 500"
- results = yield self._execute("search_msgs", self.cursor_to_dict, sql, *args)
+ results = yield self.execute("search_msgs", self.cursor_to_dict, sql, *args)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
@@ -455,7 +455,7 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = yield self._execute(
+ count_results = yield self.execute(
"search_rooms_count", self.cursor_to_dict, count_sql, *count_args
)
@@ -586,7 +586,7 @@ class SearchStore(SearchBackgroundUpdateStore):
args.append(limit)
- results = yield self._execute("search_rooms", self.cursor_to_dict, sql, *args)
+ results = yield self.execute("search_rooms", self.cursor_to_dict, sql, *args)
results = list(filter(lambda row: row["room_id"] in room_ids, results))
@@ -600,7 +600,7 @@ class SearchStore(SearchBackgroundUpdateStore):
count_sql += " GROUP BY room_id"
- count_results = yield self._execute(
+ count_results = yield self.execute(
"search_rooms_count", self.cursor_to_dict, count_sql, *count_args
)
diff --git a/synapse/storage/data_stores/main/signatures.py b/synapse/storage/data_stores/main/signatures.py
index 556191b76f..f3da29ce14 100644
--- a/synapse/storage/data_stores/main/signatures.py
+++ b/synapse/storage/data_stores/main/signatures.py
@@ -98,4 +98,4 @@ class SignatureStore(SignatureWorkerStore):
}
)
- self._simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
+ self.simple_insert_many_txn(txn, table="event_reference_hashes", values=vals)
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 6a90daea31..2b33ec1a35 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -89,7 +89,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
count = 0
while next_group:
- next_group = self._simple_select_one_onecol_txn(
+ next_group = self.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@@ -192,7 +192,7 @@ class StateGroupBackgroundUpdateStore(SQLBaseStore):
):
break
- next_group = self._simple_select_one_onecol_txn(
+ next_group = self.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
@@ -431,7 +431,7 @@ class StateGroupWorkerStore(
"""
def _get_state_group_delta_txn(txn):
- prev_group = self._simple_select_one_onecol_txn(
+ prev_group = self.simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
@@ -442,7 +442,7 @@ class StateGroupWorkerStore(
if not prev_group:
return _GetStateGroupDelta(None, None)
- delta_ids = self._simple_select_list_txn(
+ delta_ids = self.simple_select_list_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
@@ -644,7 +644,7 @@ class StateGroupWorkerStore(
@cached(max_entries=50000)
def _get_state_group_for_event(self, event_id):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="event_to_state_groups",
keyvalues={"event_id": event_id},
retcol="state_group",
@@ -661,7 +661,7 @@ class StateGroupWorkerStore(
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,
@@ -902,7 +902,7 @@ class StateGroupWorkerStore(
state_group = self.database_engine.get_next_state_group_id(txn)
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="state_groups",
values={"id": state_group, "room_id": room_id, "event_id": event_id},
@@ -911,7 +911,7 @@ class StateGroupWorkerStore(
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
if prev_group:
- is_in_db = self._simple_select_one_onecol_txn(
+ is_in_db = self.simple_select_one_onecol_txn(
txn,
table="state_groups",
keyvalues={"id": prev_group},
@@ -926,13 +926,13 @@ class StateGroupWorkerStore(
potential_hops = self._count_state_group_hops_txn(txn, prev_group)
if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="state_group_edges",
values={"state_group": state_group, "prev_state_group": prev_group},
)
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -947,7 +947,7 @@ class StateGroupWorkerStore(
],
)
else:
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -1007,7 +1007,7 @@ class StateGroupWorkerStore(
referenced.
"""
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="event_to_state_groups",
column="state_group",
iterable=state_groups,
@@ -1065,7 +1065,7 @@ class StateBackgroundUpdateStore(
batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))
if max_group is None:
- rows = yield self._execute(
+ rows = yield self.execute(
"_background_deduplicate_state",
None,
"SELECT coalesce(max(id), 0) FROM state_groups",
@@ -1135,13 +1135,13 @@ class StateBackgroundUpdateStore(
if prev_state.get(key, None) != value
}
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": state_group},
)
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="state_group_edges",
values={
@@ -1150,13 +1150,13 @@ class StateBackgroundUpdateStore(
},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="state_groups_state",
keyvalues={"state_group": state_group},
)
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="state_groups_state",
values=[
@@ -1263,7 +1263,7 @@ class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
state_groups[event.event_id] = context.state_group
- self._simple_insert_many_txn(
+ self.simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
diff --git a/synapse/storage/data_stores/main/state_deltas.py b/synapse/storage/data_stores/main/state_deltas.py
index 28f33ec18f..03b908026b 100644
--- a/synapse/storage/data_stores/main/state_deltas.py
+++ b/synapse/storage/data_stores/main/state_deltas.py
@@ -105,7 +105,7 @@ class StateDeltasStore(SQLBaseStore):
)
def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
- return self._simple_select_one_onecol_txn(
+ return self.simple_select_one_onecol_txn(
txn,
table="current_state_delta_stream",
keyvalues={},
diff --git a/synapse/storage/data_stores/main/stats.py b/synapse/storage/data_stores/main/stats.py
index 45b3de7d56..b306478824 100644
--- a/synapse/storage/data_stores/main/stats.py
+++ b/synapse/storage/data_stores/main/stats.py
@@ -186,7 +186,7 @@ class StatsStore(StateDeltasStore):
"""
Returns the stats processor positions.
"""
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="stats_incremental_position",
keyvalues={},
retcol="stream_id",
@@ -215,7 +215,7 @@ class StatsStore(StateDeltasStore):
if field and "\0" in field:
fields[col] = None
- return self._simple_upsert(
+ return self.simple_upsert(
table="room_stats_state",
keyvalues={"room_id": room_id},
values=fields,
@@ -257,14 +257,14 @@ class StatsStore(StateDeltasStore):
ABSOLUTE_STATS_FIELDS[stats_type] + PER_SLICE_FIELDS[stats_type]
)
- slice_list = self._simple_select_list_paginate_txn(
+ slice_list = self.simple_select_list_paginate_txn(
txn,
table + "_historical",
- {id_col: stats_id},
"end_ts",
start,
size,
retcols=selected_columns + ["bucket_size", "end_ts"],
+ keyvalues={id_col: stats_id},
order_direction="DESC",
)
@@ -282,7 +282,7 @@ class StatsStore(StateDeltasStore):
"name", "topic", "canonical_alias", "avatar", "join_rules",
"history_visibility"
"""
- return self._simple_select_one(
+ return self.simple_select_one(
"room_stats_state",
{"room_id": room_id},
retcols=(
@@ -308,7 +308,7 @@ class StatsStore(StateDeltasStore):
"""
table, id_col = TYPE_TO_TABLE[stats_type]
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
"%s_current" % (table,),
keyvalues={id_col: id},
retcol="completed_delta_stream_id",
@@ -344,7 +344,7 @@ class StatsStore(StateDeltasStore):
complete_with_stream_id=stream_id,
)
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn,
table="stats_incremental_position",
keyvalues={},
@@ -517,17 +517,17 @@ class StatsStore(StateDeltasStore):
else:
self.database_engine.lock_table(txn, table)
retcols = list(chain(absolutes.keys(), additive_relatives.keys()))
- current_row = self._simple_select_one_txn(
+ current_row = self.simple_select_one_txn(
txn, table, keyvalues, retcols, allow_none=True
)
if current_row is None:
merged_dict = {**keyvalues, **absolutes, **additive_relatives}
- self._simple_insert_txn(txn, table, merged_dict)
+ self.simple_insert_txn(txn, table, merged_dict)
else:
for (key, val) in additive_relatives.items():
current_row[key] += val
current_row.update(absolutes)
- self._simple_update_one_txn(txn, table, keyvalues, current_row)
+ self.simple_update_one_txn(txn, table, keyvalues, current_row)
def _upsert_copy_from_table_with_additive_relatives_txn(
self,
@@ -614,11 +614,11 @@ class StatsStore(StateDeltasStore):
txn.execute(sql, qargs)
else:
self.database_engine.lock_table(txn, into_table)
- src_row = self._simple_select_one_txn(
+ src_row = self.simple_select_one_txn(
txn, src_table, keyvalues, copy_columns
)
all_dest_keyvalues = {**keyvalues, **extra_dst_keyvalues}
- dest_current_row = self._simple_select_one_txn(
+ dest_current_row = self.simple_select_one_txn(
txn,
into_table,
keyvalues=all_dest_keyvalues,
@@ -634,11 +634,11 @@ class StatsStore(StateDeltasStore):
**src_row,
**additive_relatives,
}
- self._simple_insert_txn(txn, into_table, merged_dict)
+ self.simple_insert_txn(txn, into_table, merged_dict)
else:
for (key, val) in additive_relatives.items():
src_row[key] = dest_current_row[key] + val
- self._simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
+ self.simple_update_txn(txn, into_table, all_dest_keyvalues, src_row)
def get_changes_room_total_events_and_bytes(self, min_pos, max_pos):
"""Fetches the counts of events in the given range of stream IDs.
@@ -735,7 +735,7 @@ class StatsStore(StateDeltasStore):
def _fetch_current_state_stats(txn):
pos = self.get_room_max_stream_ordering()
- rows = self._simple_select_many_txn(
+ rows = self.simple_select_many_txn(
txn,
table="current_state_events",
column="type",
diff --git a/synapse/storage/data_stores/main/stream.py b/synapse/storage/data_stores/main/stream.py
index 8780fdd989..60487c4559 100644
--- a/synapse/storage/data_stores/main/stream.py
+++ b/synapse/storage/data_stores/main/stream.py
@@ -1,5 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017 Vector Creations Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -252,7 +255,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
super(StreamWorkerStore, self).__init__(db_conn, hs)
events_max = self.get_room_max_stream_ordering()
- event_cache_prefill, min_event_val = self._get_cache_dict(
+ event_cache_prefill, min_event_val = self.get_cache_dict(
db_conn,
"events",
entity_column="room_id",
@@ -573,7 +576,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "s%d" stream token.
"""
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
).addCallback(lambda row: "s%d" % (row,))
@@ -586,7 +589,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A deferred "t%d-%d" topological token.
"""
- return self._simple_select_one(
+ return self.simple_select_one(
table="events",
keyvalues={"event_id": event_id},
retcols=("stream_ordering", "topological_ordering"),
@@ -610,13 +613,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"SELECT coalesce(max(topological_ordering), 0) FROM events"
" WHERE room_id = ? AND stream_ordering < ?"
)
- return self._execute(
+ return self.execute(
"get_max_topological_token", None, sql, room_id, stream_key
).addCallback(lambda r: r[0][0] if r else 0)
def _get_max_topological_txn(self, txn, room_id):
txn.execute(
- "SELECT MAX(topological_ordering) FROM events" " WHERE room_id = ?",
+ "SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
)
@@ -706,7 +709,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
dict
"""
- results = self._simple_select_one_txn(
+ results = self.simple_select_one_txn(
txn,
"events",
keyvalues={"event_id": event_id, "room_id": room_id},
@@ -794,7 +797,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return upper_bound, events
def get_federation_out_pos(self, typ):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="federation_stream_position",
retcol="stream_id",
keyvalues={"type": typ},
@@ -802,7 +805,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
def update_federation_out_pos(self, typ, stream_id):
- return self._simple_update_one(
+ return self.simple_update_one(
table="federation_stream_position",
keyvalues={"type": typ},
updatevalues={"stream_id": stream_id},
diff --git a/synapse/storage/data_stores/main/tags.py b/synapse/storage/data_stores/main/tags.py
index 10d1887f75..85012403be 100644
--- a/synapse/storage/data_stores/main/tags.py
+++ b/synapse/storage/data_stores/main/tags.py
@@ -41,7 +41,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
tag strings to tag content.
"""
- deferred = self._simple_select_list(
+ deferred = self.simple_select_list(
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
@@ -83,9 +83,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
)
def get_tag_content(txn, tag_ids):
- sql = (
- "SELECT tag, content" " FROM room_tags" " WHERE user_id=? AND room_id=?"
- )
+ sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
results = []
for stream_id, user_id, room_id in tag_ids:
txn.execute(sql, (user_id, room_id))
@@ -155,7 +153,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
Returns:
A deferred list of string tags.
"""
- return self._simple_select_list(
+ return self.simple_select_list(
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id},
retcols=("tag", "content"),
@@ -180,7 +178,7 @@ class TagsStore(TagsWorkerStore):
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
table="room_tags",
keyvalues={"user_id": user_id, "room_id": room_id, "tag": tag},
diff --git a/synapse/storage/data_stores/main/transactions.py b/synapse/storage/data_stores/main/transactions.py
index 01b1be5e14..c162f3ea16 100644
--- a/synapse/storage/data_stores/main/transactions.py
+++ b/synapse/storage/data_stores/main/transactions.py
@@ -85,7 +85,7 @@ class TransactionStore(SQLBaseStore):
)
def _get_received_txn_response(self, txn, transaction_id, origin):
- result = self._simple_select_one_txn(
+ result = self.simple_select_one_txn(
txn,
table="received_transactions",
keyvalues={"transaction_id": transaction_id, "origin": origin},
@@ -119,7 +119,7 @@ class TransactionStore(SQLBaseStore):
response_json (str)
"""
- return self._simple_insert(
+ return self.simple_insert(
table="received_transactions",
values={
"transaction_id": transaction_id,
@@ -160,7 +160,7 @@ class TransactionStore(SQLBaseStore):
return result
def _get_destination_retry_timings(self, txn, destination):
- result = self._simple_select_one_txn(
+ result = self.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
@@ -227,7 +227,7 @@ class TransactionStore(SQLBaseStore):
# We need to be careful here as the data may have changed from under us
# due to a worker setting the timings.
- prev_row = self._simple_select_one_txn(
+ prev_row = self.simple_select_one_txn(
txn,
table="destinations",
keyvalues={"destination": destination},
@@ -236,7 +236,7 @@ class TransactionStore(SQLBaseStore):
)
if not prev_row:
- self._simple_insert_txn(
+ self.simple_insert_txn(
txn,
table="destinations",
values={
@@ -247,7 +247,7 @@ class TransactionStore(SQLBaseStore):
},
)
elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
- self._simple_update_one_txn(
+ self.simple_update_one_txn(
txn,
"destinations",
keyvalues={"destination": destination},
diff --git a/synapse/storage/data_stores/main/user_directory.py b/synapse/storage/data_stores/main/user_directory.py
index 652abe0e6a..1a85aabbfb 100644
--- a/synapse/storage/data_stores/main/user_directory.py
+++ b/synapse/storage/data_stores/main/user_directory.py
@@ -85,7 +85,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
txn.execute(sql)
rooms = [{"room_id": x[0], "events": x[1]} for x in txn.fetchall()]
- self._simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
+ self.simple_insert_many_txn(txn, TEMP_TABLE + "_rooms", rooms)
del rooms
# If search all users is on, get all the users we want to add.
@@ -100,13 +100,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
txn.execute("SELECT name FROM users")
users = [{"user_id": x[0]} for x in txn.fetchall()]
- self._simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
+ self.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
new_pos = yield self.get_max_stream_id_in_current_state_deltas()
yield self.runInteraction(
"populate_user_directory_temp_build", _make_staging_area
)
- yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+ yield self.simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
yield self._end_background_update("populate_user_directory_createtables")
return 1
@@ -116,7 +116,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
Update the user directory stream position, then clean up the old tables.
"""
- position = yield self._simple_select_one_onecol(
+ position = yield self.simple_select_one_onecol(
TEMP_TABLE + "_position", None, "position"
)
yield self.update_user_directory_stream_pos(position)
@@ -243,7 +243,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
to_insert.clear()
# We've finished a room. Delete it from the table.
- yield self._simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
+ yield self.simple_delete_one(TEMP_TABLE + "_rooms", {"room_id": room_id})
# Update the remaining counter.
progress["remaining"] -= 1
yield self.runInteraction(
@@ -312,7 +312,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
)
# We've finished processing a user. Delete it from the table.
- yield self._simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
+ yield self.simple_delete_one(TEMP_TABLE + "_users", {"user_id": user_id})
# Update the remaining counter.
progress["remaining"] -= 1
yield self.runInteraction(
@@ -361,7 +361,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
def _update_profile_in_user_dir_txn(txn):
- new_entry = self._simple_upsert_txn(
+ new_entry = self.simple_upsert_txn(
txn,
table="user_directory",
keyvalues={"user_id": user_id},
@@ -435,7 +435,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
)
elif isinstance(self.database_engine, Sqlite3Engine):
value = "%s %s" % (user_id, display_name) if display_name else user_id
- self._simple_upsert_txn(
+ self.simple_upsert_txn(
txn,
table="user_directory_search",
keyvalues={"user_id": user_id},
@@ -462,7 +462,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
"""
def _add_users_who_share_room_txn(txn):
- self._simple_upsert_many_txn(
+ self.simple_upsert_many_txn(
txn,
table="users_who_share_private_rooms",
key_names=["user_id", "other_user_id", "room_id"],
@@ -489,7 +489,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
def _add_users_in_public_rooms_txn(txn):
- self._simple_upsert_many_txn(
+ self.simple_upsert_many_txn(
txn,
table="users_in_public_rooms",
key_names=["user_id", "room_id"],
@@ -519,7 +519,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
@cached()
def get_user_in_directory(self, user_id):
- return self._simple_select_one(
+ return self.simple_select_one(
table="user_directory",
keyvalues={"user_id": user_id},
retcols=("display_name", "avatar_url"),
@@ -528,7 +528,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore, BackgroundUpdateStore
)
def update_user_directory_stream_pos(self, stream_id):
- return self._simple_update_one(
+ return self.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},
updatevalues={"stream_id": stream_id},
@@ -547,21 +547,21 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
def remove_from_user_dir(self, user_id):
def _remove_from_user_dir_txn(txn):
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn, table="user_directory", keyvalues={"user_id": user_id}
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn, table="user_directory_search", keyvalues={"user_id": user_id}
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id},
@@ -575,14 +575,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"""Get all user_ids that are in the room directory because they're
in the given room_id
"""
- user_ids_share_pub = yield self._simple_select_onecol(
+ user_ids_share_pub = yield self.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
- user_ids_share_priv = yield self._simple_select_onecol(
+ user_ids_share_priv = yield self.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"room_id": room_id},
retcol="other_user_id",
@@ -605,17 +605,17 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
"""
def _remove_user_who_share_room_txn(txn):
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="users_who_share_private_rooms",
keyvalues={"other_user_id": user_id, "room_id": room_id},
)
- self._simple_delete_txn(
+ self.simple_delete_txn(
txn,
table="users_in_public_rooms",
keyvalues={"user_id": user_id, "room_id": room_id},
@@ -636,14 +636,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
Returns:
list: user_id
"""
- rows = yield self._simple_select_onecol(
+ rows = yield self.simple_select_onecol(
table="users_who_share_private_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
desc="get_rooms_user_is_in",
)
- pub_rows = yield self._simple_select_onecol(
+ pub_rows = yield self.simple_select_onecol(
table="users_in_public_rooms",
keyvalues={"user_id": user_id},
retcol="room_id",
@@ -674,14 +674,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
) f2 USING (room_id)
"""
- rows = yield self._execute(
+ rows = yield self.execute(
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
)
return [room_id for room_id, in rows]
def get_user_directory_stream_pos(self):
- return self._simple_select_one_onecol(
+ return self.simple_select_one_onecol(
table="user_directory_stream_pos",
keyvalues={},
retcol="stream_id",
@@ -786,9 +786,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
- results = yield self._execute(
- "search_user_dir", self.cursor_to_dict, sql, *args
- )
+ results = yield self.execute("search_user_dir", self.cursor_to_dict, sql, *args)
limited = len(results) > limit
diff --git a/synapse/storage/data_stores/main/user_erasure_store.py b/synapse/storage/data_stores/main/user_erasure_store.py
index aa4f0da5f0..37860af070 100644
--- a/synapse/storage/data_stores/main/user_erasure_store.py
+++ b/synapse/storage/data_stores/main/user_erasure_store.py
@@ -31,7 +31,7 @@ class UserErasureWorkerStore(SQLBaseStore):
Returns:
Deferred[bool]: True if the user has requested erasure
"""
- return self._simple_select_onecol(
+ return self.simple_select_onecol(
table="erased_users",
keyvalues={"user_id": user_id},
retcol="1",
@@ -56,7 +56,7 @@ class UserErasureWorkerStore(SQLBaseStore):
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
- rows = yield self._simple_select_many_batch(
+ rows = yield self.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 2e7753820e..731e1c9d9c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -447,7 +447,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
+ "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
),
(modname, name),
)
diff --git a/synapse/streams/config.py b/synapse/streams/config.py
index 02994ab2a5..cd56cd91ed 100644
--- a/synapse/streams/config.py
+++ b/synapse/streams/config.py
@@ -88,9 +88,12 @@ class PaginationConfig(object):
raise SynapseError(400, "Invalid request.")
def __repr__(self):
- return (
- "PaginationConfig(from_tok=%r, to_tok=%r," " direction=%r, limit=%r)"
- ) % (self.from_token, self.to_token, self.direction, self.limit)
+ return ("PaginationConfig(from_tok=%r, to_tok=%r, direction=%r, limit=%r)") % (
+ self.from_token,
+ self.to_token,
+ self.direction,
+ self.limit,
+ )
def get_source_config(self, source_name):
keyname = "%s_key" % source_name
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 1a20c596bf..3c0e8469f3 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -20,7 +20,7 @@ logger = logging.getLogger(__name__)
def create_resource_tree(desired_tree, root_resource):
- """Create the resource tree for this Home Server.
+ """Create the resource tree for this homeserver.
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 8c843febd8..dffe943b28 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -44,7 +44,12 @@ MEMBERSHIP_PRIORITY = (
@defer.inlineCallbacks
def filter_events_for_client(
- storage: Storage, user_id, events, is_peeking=False, always_include_ids=frozenset()
+ storage: Storage,
+ user_id,
+ events,
+ is_peeking=False,
+ always_include_ids=frozenset(),
+ apply_retention_policies=True,
):
"""
Check which events a user is allowed to see
@@ -59,6 +64,10 @@ def filter_events_for_client(
events
always_include_ids (set(event_id)): set of event ids to specifically
include (unless sender is ignored)
+ apply_retention_policies (bool): Whether to filter out events that's older than
+ allowed by the room's retention policy. Useful when this function is called
+ to e.g. check whether a user should be allowed to see the state at a given
+ event rather than to know if it should send an event to a user's client(s).
Returns:
Deferred[list[synapse.events.EventBase]]
@@ -86,6 +95,15 @@ def filter_events_for_client(
erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
+ if apply_retention_policies:
+ room_ids = set(e.room_id for e in events)
+ retention_policies = {}
+
+ for room_id in room_ids:
+ retention_policies[
+ room_id
+ ] = yield storage.main.get_retention_policy_for_room(room_id)
+
def allowed(event):
"""
Args:
@@ -103,6 +121,18 @@ def filter_events_for_client(
if not event.is_state() and event.sender in ignore_list:
return None
+ # Don't try to apply the room's retention policy if the event is a state event, as
+ # MSC1763 states that retention is only considered for non-state events.
+ if apply_retention_policies and not event.is_state():
+ retention_policy = retention_policies[event.room_id]
+ max_lifetime = retention_policy.get("max_lifetime")
+
+ if max_lifetime is not None:
+ oldest_allowed_ts = storage.main.clock.time_msec() - max_lifetime
+
+ if event.origin_server_ts < oldest_allowed_ts:
+ return None
+
if event.event_id in always_include_ids:
return event
|