diff --git a/synapse/__init__.py b/synapse/__init__.py
index bf9e810da6..d0e8d7c21b 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -27,4 +27,4 @@ try:
except ImportError:
pass
-__version__ = "0.99.4"
+__version__ = "0.99.5.2"
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 6b347b1749..ee129c8689 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -79,6 +79,7 @@ class EventTypes(object):
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
+ Encryption = "m.room.encryption"
RoomAvatar = "m.room.avatar"
RoomEncryption = "m.room.encryption"
GuestAccess = "m.room.guest_access"
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index 485b3d0237..d644803d38 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -50,6 +50,7 @@ class RoomVersion(object):
disposition = attr.ib() # str; one of the RoomDispositions
event_format = attr.ib() # int; one of the EventFormatVersions
state_res = attr.ib() # int; one of the StateResolutionVersions
+ enforce_key_validity = attr.ib() # bool
class RoomVersions(object):
@@ -58,35 +59,36 @@ class RoomVersions(object):
RoomDisposition.STABLE,
EventFormatVersions.V1,
StateResolutionVersions.V1,
- )
- STATE_V2_TEST = RoomVersion(
- "state-v2-test",
- RoomDisposition.UNSTABLE,
- EventFormatVersions.V1,
- StateResolutionVersions.V2,
+ enforce_key_validity=False,
)
V2 = RoomVersion(
"2",
RoomDisposition.STABLE,
EventFormatVersions.V1,
StateResolutionVersions.V2,
+ enforce_key_validity=False,
)
V3 = RoomVersion(
"3",
RoomDisposition.STABLE,
EventFormatVersions.V2,
StateResolutionVersions.V2,
+ enforce_key_validity=False,
)
- EVENTID_NOSLASH_TEST = RoomVersion(
- "eventid-noslash-test",
- RoomDisposition.UNSTABLE,
+ V4 = RoomVersion(
+ "4",
+ RoomDisposition.STABLE,
EventFormatVersions.V3,
StateResolutionVersions.V2,
+ enforce_key_validity=False,
+ )
+ V5 = RoomVersion(
+ "5",
+ RoomDisposition.STABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ enforce_key_validity=True,
)
-
-
-# the version we will give rooms which are created on this server
-DEFAULT_ROOM_VERSION = RoomVersions.V1
KNOWN_ROOM_VERSIONS = {
@@ -94,7 +96,7 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V1,
RoomVersions.V2,
RoomVersions.V3,
- RoomVersions.STATE_V2_TEST,
- RoomVersions.EVENTID_NOSLASH_TEST,
+ RoomVersions.V4,
+ RoomVersions.V5,
)
} # type: dict[str, RoomVersion]
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 3c6bddff7a..e16c386a14 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -26,6 +26,7 @@ CLIENT_API_PREFIX = "/_matrix/client"
FEDERATION_PREFIX = "/_matrix/federation"
FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
+FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 08199a5e8d..8cc990399f 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -344,15 +344,21 @@ class _LimitedHostnameResolver(object):
def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
addressTypes=None, transportSemantics='TCP'):
- # Note this is happening deep within the reactor, so we don't need to
- # worry about log contexts.
-
# We need this function to return `resolutionReceiver` so we do all the
# actual logic involving deferreds in a separate function.
- self._resolve(
- resolutionReceiver, hostName, portNumber,
- addressTypes, transportSemantics,
- )
+
+ # even though this is happening within the depths of twisted, we need to drop
+ # our logcontext before starting _resolve, otherwise: (a) _resolve will drop
+ # the logcontext if it returns an incomplete deferred; (b) _resolve will
+ # call the resolutionReceiver *with* a logcontext, which it won't be expecting.
+ with PreserveLoggingContext():
+ self._resolve(
+ resolutionReceiver,
+ hostName,
+ portNumber,
+ addressTypes,
+ transportSemantics,
+ )
return resolutionReceiver
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 864f1eac48..a16e037f32 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -38,6 +38,7 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
+from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
@@ -81,6 +82,7 @@ class ClientReaderSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedTransactionStore,
+ SlavedProfileStore,
SlavedClientIpStore,
BaseSlavedStore,
):
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 8479fee738..6504da5278 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -37,8 +37,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns
-from synapse.rest.client.v2_alpha._base import client_v2_patterns
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
@@ -49,11 +48,11 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.frontend_proxy")
-class PresenceStatusStubServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
+class PresenceStatusStubServlet(RestServlet):
+ PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs):
- super(PresenceStatusStubServlet, self).__init__(hs)
+ super(PresenceStatusStubServlet, self).__init__()
self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
self.main_uri = hs.config.worker_main_http_uri
@@ -84,7 +83,7 @@ class PresenceStatusStubServlet(ClientV1RestServlet):
class KeyUploadServlet(RestServlet):
- PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
+ PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 342a6ce5fd..8400471f40 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,12 +31,50 @@ logger = logging.getLogger(__name__)
class EmailConfig(Config):
def read_config(self, config):
+ # TODO: We should separate better the email configuration from the notification
+ # and account validity config.
+
self.email_enable_notifs = False
email_config = config.get("email", {})
+
+ self.email_smtp_host = email_config.get("smtp_host", None)
+ self.email_smtp_port = email_config.get("smtp_port", None)
+ self.email_smtp_user = email_config.get("smtp_user", None)
+ self.email_smtp_pass = email_config.get("smtp_pass", None)
+ self.require_transport_security = email_config.get(
+ "require_transport_security", False
+ )
+ if "app_name" in email_config:
+ self.email_app_name = email_config["app_name"]
+ else:
+ self.email_app_name = "Matrix"
+
+ self.email_notif_from = email_config.get("notif_from", None)
+ if self.email_notif_from is not None:
+ # make sure it's valid
+ parsed = email.utils.parseaddr(self.email_notif_from)
+ if parsed[1] == '':
+ raise RuntimeError("Invalid notif_from address")
+
+ template_dir = email_config.get("template_dir")
+ # we need an absolute path, because we change directory after starting (and
+ # we don't yet know what auxilliary templates like mail.css we will need).
+ # (Note that loading as package_resources with jinja.PackageLoader doesn't
+ # work for the same reason.)
+ if not template_dir:
+ template_dir = pkg_resources.resource_filename(
+ 'synapse', 'res/templates'
+ )
+
+ self.email_template_dir = os.path.abspath(template_dir)
+
self.email_enable_notifs = email_config.get("enable_notifs", False)
+ account_validity_renewal_enabled = config.get(
+ "account_validity", {},
+ ).get("renew_at")
- if self.email_enable_notifs:
+ if self.email_enable_notifs or account_validity_renewal_enabled:
# make sure we can import the required deps
import jinja2
import bleach
@@ -42,6 +82,7 @@ class EmailConfig(Config):
jinja2
bleach
+ if self.email_enable_notifs:
required = [
"smtp_host",
"smtp_port",
@@ -66,34 +107,13 @@ class EmailConfig(Config):
"email.enable_notifs is True but no public_baseurl is set"
)
- self.email_smtp_host = email_config["smtp_host"]
- self.email_smtp_port = email_config["smtp_port"]
- self.email_notif_from = email_config["notif_from"]
self.email_notif_template_html = email_config["notif_template_html"]
self.email_notif_template_text = email_config["notif_template_text"]
- self.email_expiry_template_html = email_config.get(
- "expiry_template_html", "notice_expiry.html",
- )
- self.email_expiry_template_text = email_config.get(
- "expiry_template_text", "notice_expiry.txt",
- )
-
- template_dir = email_config.get("template_dir")
- # we need an absolute path, because we change directory after starting (and
- # we don't yet know what auxilliary templates like mail.css we will need).
- # (Note that loading as package_resources with jinja.PackageLoader doesn't
- # work for the same reason.)
- if not template_dir:
- template_dir = pkg_resources.resource_filename(
- 'synapse', 'res/templates'
- )
- template_dir = os.path.abspath(template_dir)
for f in self.email_notif_template_text, self.email_notif_template_html:
- p = os.path.join(template_dir, f)
+ p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p, ))
- self.email_template_dir = template_dir
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
@@ -101,29 +121,24 @@ class EmailConfig(Config):
self.email_riot_base_url = email_config.get(
"riot_base_url", None
)
- self.email_smtp_user = email_config.get(
- "smtp_user", None
- )
- self.email_smtp_pass = email_config.get(
- "smtp_pass", None
- )
- self.require_transport_security = email_config.get(
- "require_transport_security", False
- )
- if "app_name" in email_config:
- self.email_app_name = email_config["app_name"]
- else:
- self.email_app_name = "Matrix"
-
- # make sure it's valid
- parsed = email.utils.parseaddr(self.email_notif_from)
- if parsed[1] == '':
- raise RuntimeError("Invalid notif_from address")
else:
self.email_enable_notifs = False
# Not much point setting defaults for the rest: it would be an
# error for them to be used.
+ if account_validity_renewal_enabled:
+ self.email_expiry_template_html = email_config.get(
+ "expiry_template_html", "notice_expiry.html",
+ )
+ self.email_expiry_template_text = email_config.get(
+ "expiry_template_text", "notice_expiry.txt",
+ )
+
+ for f in self.email_expiry_template_text, self.email_expiry_template_html:
+ p = os.path.join(self.email_template_dir, f)
+ if not os.path.isfile(p):
+ raise ConfigError("Unable to find email template file %s" % (p, ))
+
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable sending emails for notification events or expiry notices
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 727fdc54d8..5c4fc8ff21 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from .api import ApiConfig
from .appservice import AppServiceConfig
from .captcha import CaptchaConfig
@@ -36,20 +37,41 @@ from .saml2_config import SAML2Config
from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
+from .stats import StatsConfig
from .tls import TlsConfig
from .user_directory import UserDirectoryConfig
from .voip import VoipConfig
from .workers import WorkerConfig
-class HomeServerConfig(ServerConfig, TlsConfig, DatabaseConfig, LoggingConfig,
- RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
- VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
- AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
- JWTConfig, PasswordConfig, EmailConfig,
- WorkerConfig, PasswordAuthProviderConfig, PushConfig,
- SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
- ConsentConfig,
- ServerNoticesConfig, RoomDirectoryConfig,
- ):
+class HomeServerConfig(
+ ServerConfig,
+ TlsConfig,
+ DatabaseConfig,
+ LoggingConfig,
+ RatelimitConfig,
+ ContentRepositoryConfig,
+ CaptchaConfig,
+ VoipConfig,
+ RegistrationConfig,
+ MetricsConfig,
+ ApiConfig,
+ AppServiceConfig,
+ KeyConfig,
+ SAML2Config,
+ CasConfig,
+ JWTConfig,
+ PasswordConfig,
+ EmailConfig,
+ WorkerConfig,
+ PasswordAuthProviderConfig,
+ PushConfig,
+ SpamCheckerConfig,
+ GroupsConfig,
+ UserDirectoryConfig,
+ ConsentConfig,
+ StatsConfig,
+ ServerNoticesConfig,
+ RoomDirectoryConfig,
+):
pass
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 1309bce3ee..aad3400819 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -39,6 +39,8 @@ class AccountValidityConfig(Config):
else:
self.renew_email_subject = "Renew your %(app)s account"
+ self.startup_job_max_delta = self.period * 10. / 100.
+
if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
@@ -123,6 +125,16 @@ class RegistrationConfig(Config):
# link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter
# from the ``email`` section.
#
+ # Once this feature is enabled, Synapse will look for registered users without an
+ # expiration date at startup and will add one to every account it found using the
+ # current settings at that time.
+ # This means that, if a validity period is set, and Synapse is restarted (it will
+ # then derive an expiration date from the current validity period), and some time
+ # after that the validity period changes and Synapse is restarted, the users'
+ # expiration dates won't be updated unless their account is manually renewed. This
+ # date will be randomly selected within a range [now + period - d ; now + period],
+ # where d is equal to 10%% of the validity period.
+ #
#account_validity:
# enabled: True
# period: 6w
diff --git a/synapse/config/server.py b/synapse/config/server.py
index f34aa42afa..334921d421 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -20,6 +20,7 @@ import os.path
from netaddr import IPSet
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.python_dependencies import DependencyException, check_requirements
@@ -35,6 +36,8 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0']
+DEFAULT_ROOM_VERSION = "1"
+
class ServerConfig(Config):
@@ -88,6 +91,22 @@ class ServerConfig(Config):
"restrict_public_rooms_to_local_users", False,
)
+ default_room_version = config.get(
+ "default_room_version", DEFAULT_ROOM_VERSION,
+ )
+
+ # Ensure room version is a str
+ default_room_version = str(default_room_version)
+
+ if default_room_version not in KNOWN_ROOM_VERSIONS:
+ raise ConfigError(
+ "Unknown default_room_version: %s, known room versions: %s" %
+ (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
+ )
+
+ # Get the actual room version object rather than just the identifier
+ self.default_room_version = KNOWN_ROOM_VERSIONS[default_room_version]
+
# whether to enable search. If disabled, new entries will not be inserted
# into the search tables and they will not be indexed. Users will receive
# errors when attempting to search for messages.
@@ -310,6 +329,10 @@ class ServerConfig(Config):
unsecure_port = 8008
pid_file = os.path.join(data_dir_path, "homeserver.pid")
+
+ # Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
+ # default config string
+ default_room_version = DEFAULT_ROOM_VERSION
return """\
## Server ##
@@ -384,6 +407,16 @@ class ServerConfig(Config):
#
#restrict_public_rooms_to_local_users: true
+ # The default room version for newly created rooms.
+ #
+ # Known room versions are listed here:
+ # https://matrix.org/docs/spec/#complete-list-of-room-versions
+ #
+ # For example, for room version 1, default_room_version should be set
+ # to "1".
+ #
+ #default_room_version: "%(default_room_version)s"
+
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#
#gc_thresholds: [700, 10, 10]
@@ -552,6 +585,22 @@ class ServerConfig(Config):
# Monthly Active User Blocking
#
+ # Used in cases where the admin or server owner wants to limit to the
+ # number of monthly active users.
+ #
+ # 'limit_usage_by_mau' disables/enables monthly active user blocking. When
+ # anabled and a limit is reached the server returns a 'ResourceLimitError'
+ # with error type Codes.RESOURCE_LIMIT_EXCEEDED
+ #
+ # 'max_mau_value' is the hard limit of monthly active users above which
+ # the server will start blocking user actions.
+ #
+ # 'mau_trial_days' is a means to add a grace period for active users. It
+ # means that users must be active for this number of days before they
+ # can be considered active and guards against the case where lots of users
+ # sign up in a short space of time never to return after their initial
+ # session.
+ #
#limit_usage_by_mau: False
#max_mau_value: 50
#mau_trial_days: 2
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
new file mode 100644
index 0000000000..80fc1b9dd0
--- /dev/null
+++ b/synapse/config/stats.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import sys
+
+from ._base import Config
+
+
+class StatsConfig(Config):
+ """Stats Configuration
+ Configuration for the behaviour of synapse's stats engine
+ """
+
+ def read_config(self, config):
+ self.stats_enabled = True
+ self.stats_bucket_size = 86400
+ self.stats_retention = sys.maxsize
+ stats_config = config.get("stats", None)
+ if stats_config:
+ self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
+ self.stats_bucket_size = (
+ self.parse_duration(stats_config.get("bucket_size", "1d")) / 1000
+ )
+ self.stats_retention = (
+ self.parse_duration(
+ stats_config.get("retention", "%ds" % (sys.maxsize,))
+ )
+ / 1000
+ )
+
+ def default_config(self, config_dir_path, server_name, **kwargs):
+ return """
+ # Local statistics collection. Used in populating the room directory.
+ #
+ # 'bucket_size' controls how large each statistics timeslice is. It can
+ # be defined in a human readable short form -- e.g. "1d", "1y".
+ #
+ # 'retention' controls how long historical statistics will be kept for.
+ # It can be defined in a human readable short form -- e.g. "1d", "1y".
+ #
+ #
+ #stats:
+ # enabled: true
+ # bucket_size: 1d
+ # retention: 1y
+ """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 72dd5926f9..658f9dd361 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -74,7 +74,7 @@ class TlsConfig(Config):
# Whether to verify certificates on outbound federation traffic
self.federation_verify_certificates = config.get(
- "federation_verify_certificates", False,
+ "federation_verify_certificates", True,
)
# Whitelist of domains to not verify certificates for
@@ -107,7 +107,7 @@ class TlsConfig(Config):
certs = []
for ca_file in custom_ca_list:
logger.debug("Reading custom CA certificate file: %s", ca_file)
- content = self.read_file(ca_file)
+ content = self.read_file(ca_file, "federation_custom_ca_list")
# Parse the CA certificates
try:
@@ -241,12 +241,12 @@ class TlsConfig(Config):
#
#tls_private_key_path: "%(tls_private_key_path)s"
- # Whether to verify TLS certificates when sending federation traffic.
+ # Whether to verify TLS server certificates for outbound federation requests.
#
- # This currently defaults to `false`, however this will change in
- # Synapse 1.0 when valid federation certificates will be required.
+ # Defaults to `true`. To disable certificate verification, uncomment the
+ # following line.
#
- #federation_verify_certificates: true
+ #federation_verify_certificates: false
# Skip federation certificate verification on the following whitelist
# of domains.
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 142754a7dc..023997ccde 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -43,9 +43,9 @@ class UserDirectoryConfig(Config):
#
# 'search_all_users' defines whether to search all users visible to your HS
# when searching the user directory, rather than limiting to users visible
- # in public rooms. Defaults to false. If you set it True, you'll have to run
- # UPDATE user_directory_stream_pos SET stream_id = NULL;
- # on your database to tell it to rebuild the user_directory search indexes.
+ # in public rooms. Defaults to false. If you set it True, you'll have to
+ # rebuild the user_directory search indexes, see
+ # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
#
#user_directory:
# enabled: true
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 1dfa727fcf..99a586655b 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -31,7 +31,11 @@ logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
- logger.debug("Expecting hash: %s", encode_base64(expected_hash))
+ logger.debug(
+ "Verifying content hash on %s (expecting: %s)",
+ event.event_id,
+ encode_base64(expected_hash),
+ )
# some malformed events lack a 'hashes'. Protect against it being missing
# or a weird type by basically treating it the same as an unhashed event.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d8ba870cca..2b6b5913bc 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,12 +15,13 @@
# limitations under the License.
import logging
-from collections import namedtuple
+from collections import defaultdict
+import six
from six import raise_from
from six.moves import urllib
-import nacl.signing
+import attr
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
@@ -43,7 +44,9 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.logcontext import (
LoggingContext,
PreserveLoggingContext,
@@ -56,22 +59,40 @@ from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
-VerifyKeyRequest = namedtuple("VerifyRequest", (
- "server_name", "key_ids", "json_object", "deferred"
-))
-"""
-A request for a verify key to verify a JSON object.
+@attr.s(slots=True, cmp=False)
+class VerifyJsonRequest(object):
+ """
+ A request to verify a JSON object.
+
+ Attributes:
+ server_name(str): The name of the server to verify against.
+
+ key_ids(set[str]): The set of key_ids to that could be used to verify the
+ JSON object
+
+ json_object(dict): The JSON object to verify.
+
+ minimum_valid_until_ts (int): time at which we require the signing key to
+ be valid. (0 implies we don't care)
+
+ key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
+ A deferred (server_name, key_id, verify_key) tuple that resolves when
+ a verify key has been fetched. The deferreds' callbacks are run with no
+ logcontext.
+
+ If we are unable to find a key which satisfies the request, the deferred
+ errbacks with an M_UNAUTHORIZED SynapseError.
+ """
+
+ server_name = attr.ib()
+ json_object = attr.ib()
+ minimum_valid_until_ts = attr.ib()
+ request_name = attr.ib()
+ key_ids = attr.ib(init=False)
+ key_ready = attr.ib(default=attr.Factory(defer.Deferred))
-Attributes:
- server_name(str): The name of the server to verify against.
- key_ids(set(str)): The set of key_ids to that could be used to verify the
- JSON object
- json_object(dict): The JSON object to verify.
- deferred(Deferred[str, str, nacl.signing.VerifyKey]):
- A deferred (server_name, key_id, verify_key) tuple that resolves when
- a verify key has been fetched. The deferreds' callbacks are run with no
- logcontext.
-"""
+ def __attrs_post_init__(self):
+ self.key_ids = signature_ids(self.json_object, self.server_name)
class KeyLookupError(ValueError):
@@ -79,13 +100,16 @@ class KeyLookupError(ValueError):
class Keyring(object):
- def __init__(self, hs):
- self.store = hs.get_datastore()
+ def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock()
- self.client = hs.get_http_client()
- self.config = hs.get_config()
- self.perspective_servers = self.config.perspectives
- self.hs = hs
+
+ if key_fetchers is None:
+ key_fetchers = (
+ StoreKeyFetcher(hs),
+ PerspectivesKeyFetcher(hs),
+ ServerKeyFetcher(hs),
+ )
+ self._key_fetchers = key_fetchers
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
@@ -94,56 +118,99 @@ class Keyring(object):
# These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {}
- def verify_json_for_server(self, server_name, json_object):
- return logcontext.make_deferred_yieldable(
- self.verify_json_objects_for_server(
- [(server_name, json_object)]
- )[0]
- )
+ def verify_json_for_server(
+ self, server_name, json_object, validity_time, request_name
+ ):
+ """Verify that a JSON object has been signed by a given server
+
+ Args:
+ server_name (str): name of the server which must have signed this object
+
+ json_object (dict): object to be checked
+
+ validity_time (int): timestamp at which we require the signing key to
+ be valid. (0 implies we don't care)
+
+ request_name (str): an identifier for this json object (eg, an event id)
+ for logging.
+
+ Returns:
+ Deferred[None]: completes if the the object was correctly signed, otherwise
+ errbacks with an error
+ """
+ req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+ requests = (req,)
+ return logcontext.make_deferred_yieldable(self._verify_objects(requests)[0])
def verify_json_objects_for_server(self, server_and_json):
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
- server_and_json (list): List of pairs of (server_name, json_object)
+ server_and_json (iterable[Tuple[str, dict, int, str]):
+ Iterable of (server_name, json_object, validity_time, request_name)
+ tuples.
+
+ validity_time is a timestamp at which the signing key must be
+ valid.
+
+ request_name is an identifier for this json object (eg, an event id)
+ for logging.
Returns:
- List<Deferred>: for each input pair, a deferred indicating success
+ List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
- # a list of VerifyKeyRequests
- verify_requests = []
+ return self._verify_objects(
+ VerifyJsonRequest(server_name, json_object, validity_time, request_name)
+ for server_name, json_object, validity_time, request_name in server_and_json
+ )
+
+ def _verify_objects(self, verify_requests):
+ """Does the work of verify_json_[objects_]for_server
+
+
+ Args:
+ verify_requests (iterable[VerifyJsonRequest]):
+ Iterable of verification requests.
+
+ Returns:
+ List<Deferred[None]>: for each input item, a deferred indicating success
+ or failure to verify each json object's signature for the given
+ server_name. The deferreds run their callbacks in the sentinel
+ logcontext.
+ """
+ # a list of VerifyJsonRequests which are awaiting a key lookup
+ key_lookups = []
handle = preserve_fn(_handle_key_deferred)
- def process(server_name, json_object):
+ def process(verify_request):
"""Process an entry in the request list
- Given a (server_name, json_object) pair from the request list,
- adds a key request to verify_requests, and returns a deferred which will
- complete or fail (in the sentinel context) when verification completes.
+ Adds a key request to key_lookups, and returns a deferred which
+ will complete or fail (in the sentinel context) when verification completes.
"""
- key_ids = signature_ids(json_object, server_name)
-
- if not key_ids:
+ if not verify_request.key_ids:
return defer.fail(
SynapseError(
400,
- "Not signed by %s" % (server_name,),
+ "Not signed by %s" % (verify_request.server_name,),
Codes.UNAUTHORIZED,
)
)
- logger.debug("Verifying for %s with key_ids %s",
- server_name, key_ids)
+ logger.debug(
+ "Verifying %s for %s with key_ids %s, min_validity %i",
+ verify_request.request_name,
+ verify_request.server_name,
+ verify_request.key_ids,
+ verify_request.minimum_valid_until_ts,
+ )
# add the key request to the queue, but don't start it off yet.
- verify_request = VerifyKeyRequest(
- server_name, key_ids, json_object, defer.Deferred(),
- )
- verify_requests.append(verify_request)
+ key_lookups.append(verify_request)
# now run _handle_key_deferred, which will wait for the key request
# to complete and then do the verification.
@@ -152,13 +219,10 @@ class Keyring(object):
# wrap it with preserve_fn (aka run_in_background)
return handle(verify_request)
- results = [
- process(server_name, json_object)
- for server_name, json_object in server_and_json
- ]
+ results = [process(r) for r in verify_requests]
- if verify_requests:
- run_in_background(self._start_key_lookups, verify_requests)
+ if key_lookups:
+ run_in_background(self._start_key_lookups, key_lookups)
return results
@@ -166,10 +230,10 @@ class Keyring(object):
def _start_key_lookups(self, verify_requests):
"""Sets off the key fetches for each verify request
- Once each fetch completes, verify_request.deferred will be resolved.
+ Once each fetch completes, verify_request.key_ready will be resolved.
Args:
- verify_requests (List[VerifyKeyRequest]):
+ verify_requests (List[VerifyJsonRequest]):
"""
try:
@@ -179,16 +243,12 @@ class Keyring(object):
# any other lookups until we have finished.
# The deferreds are called with no logcontext.
server_to_deferred = {
- rq.server_name: defer.Deferred()
- for rq in verify_requests
+ rq.server_name: defer.Deferred() for rq in verify_requests
}
# We want to wait for any previous lookups to complete before
# proceeding.
- yield self.wait_for_previous_lookups(
- [rq.server_name for rq in verify_requests],
- server_to_deferred,
- )
+ yield self.wait_for_previous_lookups(server_to_deferred)
# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
@@ -216,19 +276,16 @@ class Keyring(object):
return res
for verify_request in verify_requests:
- verify_request.deferred.addBoth(
- remove_deferreds, verify_request,
- )
+ verify_request.key_ready.addBoth(remove_deferreds, verify_request)
except Exception:
logger.exception("Error starting key lookups")
@defer.inlineCallbacks
- def wait_for_previous_lookups(self, server_names, server_to_deferred):
+ def wait_for_previous_lookups(self, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish.
Args:
- server_names (list): list of server_names we want to lookup
- server_to_deferred (dict): server_name to deferred which gets
+ server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
resolved once we've finished looking up keys for that server.
The Deferreds should be regular twisted ones which call their
callbacks with no logcontext.
@@ -241,14 +298,15 @@ class Keyring(object):
while True:
wait_on = [
(server_name, self.key_downloads[server_name])
- for server_name in server_names
+ for server_name in server_to_deferred.keys()
if server_name in self.key_downloads
]
if not wait_on:
break
logger.info(
"Waiting for existing lookups for %s to complete [loop %i]",
- [w[0] for w in wait_on], loop_count,
+ [w[0] for w in wait_on],
+ loop_count,
)
with PreserveLoggingContext():
yield defer.DeferredList((w[1] for w in wait_on))
@@ -271,137 +329,296 @@ class Keyring(object):
def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
- For each verify_request, verify_request.deferred is called back with
+ For each verify_request, verify_request.key_ready is called back with
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
with a SynapseError if none of the keys are found.
Args:
- verify_requests (list[VerifyKeyRequest]): list of verify requests
+ verify_requests (list[VerifyJsonRequest]): list of verify requests
"""
- # These are functions that produce keys given a list of key ids
- key_fetch_fns = (
- self.get_keys_from_store, # First try the local store
- self.get_keys_from_perspectives, # Then try via perspectives
- self.get_keys_from_server, # Then try directly
+ remaining_requests = set(
+ (rq for rq in verify_requests if not rq.key_ready.called)
)
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
- # dict[str, set(str)]: keys to fetch for each server
- missing_keys = {}
- for verify_request in verify_requests:
- missing_keys.setdefault(verify_request.server_name, set()).update(
- verify_request.key_ids
- )
-
- for fn in key_fetch_fns:
- results = yield fn(missing_keys.items())
-
- # We now need to figure out which verify requests we have keys
- # for and which we don't
- missing_keys = {}
- requests_missing_keys = []
- for verify_request in verify_requests:
- if verify_request.deferred.called:
- # We've already called this deferred, which probably
- # means that we've already found a key for it.
- continue
-
- server_name = verify_request.server_name
-
- # see if any of the keys we got this time are sufficient to
- # complete this VerifyKeyRequest.
- result_keys = results.get(server_name, {})
- for key_id in verify_request.key_ids:
- key = result_keys.get(key_id)
- if key:
- with PreserveLoggingContext():
- verify_request.deferred.callback(
- (server_name, key_id, key)
- )
- break
- else:
- # The else block is only reached if the loop above
- # doesn't break.
- missing_keys.setdefault(server_name, set()).update(
- verify_request.key_ids
- )
- requests_missing_keys.append(verify_request)
-
- if not missing_keys:
- break
+ for f in self._key_fetchers:
+ if not remaining_requests:
+ return
+ yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
+ # look for any requests which weren't satisfied
with PreserveLoggingContext():
- for verify_request in requests_missing_keys:
- verify_request.deferred.errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- verify_request.server_name, verify_request.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ for verify_request in remaining_requests:
+ verify_request.key_ready.errback(
+ SynapseError(
+ 401,
+ "No key for %s with ids in %s (min_validity %i)"
+ % (
+ verify_request.server_name,
+ verify_request.key_ids,
+ verify_request.minimum_valid_until_ts,
+ ),
+ Codes.UNAUTHORIZED,
+ )
+ )
def on_err(err):
+ # we don't really expect to get here, because any errors should already
+ # have been caught and logged. But if we do, let's log the error and make
+ # sure that all of the deferreds are resolved.
+ logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext():
- for verify_request in verify_requests:
- if not verify_request.deferred.called:
- verify_request.deferred.errback(err)
+ for verify_request in remaining_requests:
+ if not verify_request.key_ready.called:
+ verify_request.key_ready.errback(err)
run_in_background(do_iterations).addErrback(on_err)
@defer.inlineCallbacks
- def get_keys_from_store(self, server_name_and_key_ids):
+ def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+ """Use a key fetcher to attempt to satisfy some key requests
+
+ Args:
+ fetcher (KeyFetcher): fetcher to use to fetch the keys
+ remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
+ Any successfully-completed requests will be removed from the list.
+ """
+ # dict[str, dict[str, int]]: keys to fetch.
+ # server_name -> key_id -> min_valid_ts
+ missing_keys = defaultdict(dict)
+
+ for verify_request in remaining_requests:
+ # any completed requests should already have been removed
+ assert not verify_request.key_ready.called
+ keys_for_server = missing_keys[verify_request.server_name]
+
+ for key_id in verify_request.key_ids:
+ # If we have several requests for the same key, then we only need to
+ # request that key once, but we should do so with the greatest
+ # min_valid_until_ts of the requests, so that we can satisfy all of
+ # the requests.
+ keys_for_server[key_id] = max(
+ keys_for_server.get(key_id, -1),
+ verify_request.minimum_valid_until_ts,
+ )
+
+ results = yield fetcher.get_keys(missing_keys)
+
+ completed = list()
+ for verify_request in remaining_requests:
+ server_name = verify_request.server_name
+
+ # see if any of the keys we got this time are sufficient to
+ # complete this VerifyJsonRequest.
+ result_keys = results.get(server_name, {})
+ for key_id in verify_request.key_ids:
+ fetch_key_result = result_keys.get(key_id)
+ if not fetch_key_result:
+ # we didn't get a result for this key
+ continue
+
+ if (
+ fetch_key_result.valid_until_ts
+ < verify_request.minimum_valid_until_ts
+ ):
+ # key was not valid at this point
+ continue
+
+ with PreserveLoggingContext():
+ verify_request.key_ready.callback(
+ (server_name, key_id, fetch_key_result.verify_key)
+ )
+ completed.append(verify_request)
+ break
+
+ remaining_requests.difference_update(completed)
+
+
+class KeyFetcher(object):
+ def get_keys(self, keys_to_fetch):
"""
Args:
- server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
- list of (server_name, iterable[key_id]) tuples to fetch keys for
+ keys_to_fetch (dict[str, dict[str, int]]):
+ the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns:
- Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
- server_name -> key_id -> VerifyKey
+ Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+ map from server_name -> key_id -> FetchKeyResult
"""
+ raise NotImplementedError
+
+
+class StoreKeyFetcher(KeyFetcher):
+ """KeyFetcher impl which fetches keys from our data store"""
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def get_keys(self, keys_to_fetch):
+ """see KeyFetcher.get_keys"""
+
keys_to_fetch = (
(server_name, key_id)
- for server_name, key_ids in server_name_and_key_ids
- for key_id in key_ids
+ for server_name, keys_for_server in keys_to_fetch.items()
+ for key_id in keys_for_server.keys()
)
+
res = yield self.store.get_server_verify_keys(keys_to_fetch)
keys = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
defer.returnValue(keys)
+
+class BaseV2KeyFetcher(object):
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.config = hs.get_config()
+
+ @defer.inlineCallbacks
+ def process_v2_response(self, from_server, response_json, time_added_ms):
+ """Parse a 'Server Keys' structure from the result of a /key request
+
+ This is used to parse either the entirety of the response from
+ GET /_matrix/key/v2/server, or a single entry from the list returned by
+ POST /_matrix/key/v2/query.
+
+ Checks that each signature in the response that claims to come from the origin
+ server is valid, and that there is at least one such signature.
+
+ Stores the json in server_keys_json so that it can be used for future responses
+ to /_matrix/key/v2/query.
+
+ Args:
+ from_server (str): the name of the server producing this result: either
+ the origin server for a /_matrix/key/v2/server request, or the notary
+ for a /_matrix/key/v2/query.
+
+ response_json (dict): the json-decoded Server Keys response object
+
+ time_added_ms (int): the timestamp to record in server_keys_json
+
+ Returns:
+ Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+ """
+ ts_valid_until_ms = response_json[u"valid_until_ts"]
+
+ # start by extracting the keys from the response, since they may be required
+ # to validate the signature on the response.
+ verify_keys = {}
+ for key_id, key_data in response_json["verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_keys[key_id] = FetchKeyResult(
+ verify_key=verify_key, valid_until_ts=ts_valid_until_ms
+ )
+
+ server_name = response_json["server_name"]
+ verified = False
+ for key_id in response_json["signatures"].get(server_name, {}):
+ # each of the keys used for the signature must be present in the response
+ # json.
+ key = verify_keys.get(key_id)
+ if not key:
+ raise KeyLookupError(
+ "Key response is signed by key id %s:%s but that key is not "
+ "present in the response" % (server_name, key_id)
+ )
+
+ verify_signed_json(response_json, server_name, key.verify_key)
+ verified = True
+
+ if not verified:
+ raise KeyLookupError(
+ "Key response for %s is not signed by the origin server"
+ % (server_name,)
+ )
+
+ for key_id, key_data in response_json["old_verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_keys[key_id] = FetchKeyResult(
+ verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
+ )
+
+ # re-sign the json with our own key, so that it is ready if we are asked to
+ # give it out as a notary server
+ signed_key_json = sign_json(
+ response_json, self.config.server_name, self.config.signing_key[0]
+ )
+
+ signed_key_json_bytes = encode_canonical_json(signed_key_json)
+
+ yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self.store.store_server_keys_json,
+ server_name=server_name,
+ key_id=key_id,
+ from_server=from_server,
+ ts_now_ms=time_added_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
+ )
+ for key_id in verify_keys
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ defer.returnValue(verify_keys)
+
+
+class PerspectivesKeyFetcher(BaseV2KeyFetcher):
+ """KeyFetcher impl which fetches keys from the "perspectives" servers"""
+
+ def __init__(self, hs):
+ super(PerspectivesKeyFetcher, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+ self.perspective_servers = self.config.perspectives
+
@defer.inlineCallbacks
- def get_keys_from_perspectives(self, server_name_and_key_ids):
+ def get_keys(self, keys_to_fetch):
+ """see KeyFetcher.get_keys"""
+
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
- server_name_and_key_ids, perspective_name, perspective_keys
+ keys_to_fetch, perspective_name, perspective_keys
)
defer.returnValue(result)
except KeyLookupError as e:
- logger.warning(
- "Key lookup failed from %r: %s", perspective_name, e,
- )
+ logger.warning("Key lookup failed from %r: %s", perspective_name, e)
except Exception as e:
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
- type(e).__name__, str(e),
+ type(e).__name__,
+ str(e),
)
defer.returnValue({})
- results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(get_key, p_name, p_keys)
- for p_name, p_keys in self.perspective_servers.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
+ results = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(get_key, p_name, p_keys)
+ for p_name, p_keys in self.perspective_servers.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
union_of_keys = {}
for result in results:
@@ -411,36 +628,33 @@ class Keyring(object):
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
- def get_keys_from_server(self, server_name_and_key_ids):
- results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.get_server_verify_key_v2_direct,
- server_name,
- key_ids,
- )
- for server_name, key_ids in server_name_and_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
+ def get_server_verify_key_v2_indirect(
+ self, keys_to_fetch, perspective_name, perspective_keys
+ ):
+ """
+ Args:
+ keys_to_fetch (dict[str, dict[str, int]]):
+ the keys to be fetched. server_name -> key_id -> min_valid_ts
- merged = {}
- for result in results:
- merged.update(result)
+ perspective_name (str): name of the notary server to query for the keys
- defer.returnValue({
- server_name: keys
- for server_name, keys in merged.items()
- if keys
- })
+ perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
+ notary server
+
+ Returns:
+ Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
+ from server_name -> key_id -> FetchKeyResult
+
+ Raises:
+ KeyLookupError if there was an error processing the entire response from
+ the server
+ """
+ logger.info(
+ "Requesting keys %s from notary server %s",
+ keys_to_fetch.items(),
+ perspective_name,
+ )
- @defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
- perspective_name,
- perspective_keys):
- # TODO(mark): Set the minimum_valid_until_ts to that needed by
- # the events being validated or the current time if validating
- # an incoming request.
try:
query_response = yield self.client.post_json(
destination=perspective_name,
@@ -448,249 +662,213 @@ class Keyring(object):
data={
u"server_keys": {
server_name: {
- key_id: {
- u"minimum_valid_until_ts": 0
- } for key_id in key_ids
+ key_id: {u"minimum_valid_until_ts": min_valid_ts}
+ for key_id, min_valid_ts in server_keys.items()
}
- for server_name, key_ids in server_names_and_key_ids
+ for server_name, server_keys in keys_to_fetch.items()
}
},
- long_retries=True,
)
except (NotRetryingDestination, RequestSendFailed) as e:
- raise_from(
- KeyLookupError("Failed to connect to remote server"), e,
- )
+ raise_from(KeyLookupError("Failed to connect to remote server"), e)
except HttpResponseException as e:
- raise_from(
- KeyLookupError("Remote server returned an error"), e,
- )
+ raise_from(KeyLookupError("Remote server returned an error"), e)
keys = {}
+ added_keys = []
- responses = query_response["server_keys"]
+ time_now_ms = self.clock.time_msec()
- for response in responses:
- if (u"signatures" not in response
- or perspective_name not in response[u"signatures"]):
+ for response in query_response["server_keys"]:
+ # do this first, so that we can give useful errors thereafter
+ server_name = response.get("server_name")
+ if not isinstance(server_name, six.string_types):
raise KeyLookupError(
- "Key response not signed by perspective server"
- " %r" % (perspective_name,)
+ "Malformed response from key notary server %s: invalid server_name"
+ % (perspective_name,)
)
- verified = False
- for key_id in response[u"signatures"][perspective_name]:
- if key_id in perspective_keys:
- verify_signed_json(
- response,
- perspective_name,
- perspective_keys[key_id]
- )
- verified = True
-
- if not verified:
- logging.info(
- "Response from perspective server %r not signed with a"
- " known key, signed with: %r, known keys: %r",
+ try:
+ processed_response = yield self._process_perspectives_response(
perspective_name,
- list(response[u"signatures"][perspective_name]),
- list(perspective_keys)
+ perspective_keys,
+ response,
+ time_added_ms=time_now_ms,
)
- raise KeyLookupError(
- "Response not signed with a known key for perspective"
- " server %r" % (perspective_name,)
+ except KeyLookupError as e:
+ logger.warning(
+ "Error processing response from key notary server %s for origin "
+ "server %s: %s",
+ perspective_name,
+ server_name,
+ e,
)
+ # we continue to process the rest of the response
+ continue
- processed_response = yield self.process_v2_response(
- perspective_name, response
+ added_keys.extend(
+ (server_name, key_id, key) for key_id, key in processed_response.items()
)
- server_name = response["server_name"]
-
keys.setdefault(server_name, {}).update(processed_response)
- yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store_keys,
- server_name=server_name,
- from_server=perspective_name,
- verify_keys=response_keys,
- )
- for server_name, response_keys in keys.items()
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError))
+ yield self.store.store_server_verify_keys(
+ perspective_name, time_now_ms, added_keys
+ )
defer.returnValue(keys)
- @defer.inlineCallbacks
- def get_server_verify_key_v2_direct(self, server_name, key_ids):
- keys = {} # type: dict[str, nacl.signing.VerifyKey]
+ def _process_perspectives_response(
+ self, perspective_name, perspective_keys, response, time_added_ms
+ ):
+ """Parse a 'Server Keys' structure from the result of a /key/query request
- for requested_key_id in key_ids:
- if requested_key_id in keys:
- continue
+ Checks that the entry is correctly signed by the perspectives server, and then
+ passes over to process_v2_response
- try:
- response = yield self.client.get_json(
- destination=server_name,
- path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
- ignore_backoff=True,
- )
- except (NotRetryingDestination, RequestSendFailed) as e:
- raise_from(
- KeyLookupError("Failed to connect to remote server"), e,
- )
- except HttpResponseException as e:
- raise_from(
- KeyLookupError("Remote server returned an error"), e,
- )
+ Args:
+ perspective_name (str): the name of the notary server that produced this
+ result
- if (u"signatures" not in response
- or server_name not in response[u"signatures"]):
- raise KeyLookupError("Key response not signed by remote server")
+ perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
+ notary server
- if response["server_name"] != server_name:
- raise KeyLookupError("Expected a response for server %r not %r" % (
- server_name, response["server_name"]
- ))
+ response (dict): the json-decoded Server Keys response object
- response_keys = yield self.process_v2_response(
- from_server=server_name,
- requested_ids=[requested_key_id],
- response_json=response,
- )
+ time_added_ms (int): the timestamp to record in server_keys_json
- keys.update(response_keys)
+ Returns:
+ Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+ """
+ if (
+ u"signatures" not in response
+ or perspective_name not in response[u"signatures"]
+ ):
+ raise KeyLookupError("Response not signed by the notary server")
+
+ verified = False
+ for key_id in response[u"signatures"][perspective_name]:
+ if key_id in perspective_keys:
+ verify_signed_json(response, perspective_name, perspective_keys[key_id])
+ verified = True
+
+ if not verified:
+ raise KeyLookupError(
+ "Response not signed with a known key: signed with: %r, known keys: %r"
+ % (
+ list(response[u"signatures"][perspective_name].keys()),
+ list(perspective_keys.keys()),
+ )
+ )
- yield self.store_keys(
- server_name=server_name,
- from_server=server_name,
- verify_keys=keys,
+ return self.process_v2_response(
+ perspective_name, response, time_added_ms=time_added_ms
)
- defer.returnValue({server_name: keys})
- @defer.inlineCallbacks
- def process_v2_response(
- self, from_server, response_json, requested_ids=[],
- ):
- """Parse a 'Server Keys' structure from the result of a /key request
- This is used to parse either the entirety of the response from
- GET /_matrix/key/v2/server, or a single entry from the list returned by
- POST /_matrix/key/v2/query.
+class ServerKeyFetcher(BaseV2KeyFetcher):
+ """KeyFetcher impl which fetches keys from the origin servers"""
- Checks that each signature in the response that claims to come from the origin
- server is valid. (Does not check that there actually is such a signature, for
- some reason.)
-
- Stores the json in server_keys_json so that it can be used for future responses
- to /_matrix/key/v2/query.
+ def __init__(self, hs):
+ super(ServerKeyFetcher, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+ def get_keys(self, keys_to_fetch):
+ """
Args:
- from_server (str): the name of the server producing this result: either
- the origin server for a /_matrix/key/v2/server request, or the notary
- for a /_matrix/key/v2/query.
-
- response_json (dict): the json-decoded Server Keys response object
-
- requested_ids (iterable[str]): a list of the key IDs that were requested.
- We will store the json for these key ids as well as any that are
- actually in the response
+ keys_to_fetch (dict[str, iterable[str]]):
+ the keys to be fetched. server_name -> key_ids
Returns:
- Deferred[dict[str, nacl.signing.VerifyKey]]:
- map from key_id to key object
+ Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+ map from server_name -> key_id -> FetchKeyResult
"""
- time_now_ms = self.clock.time_msec()
- response_keys = {}
- verify_keys = {}
- for key_id, key_data in response_json["verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_base64 = key_data["key"]
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_key.time_added = time_now_ms
- verify_keys[key_id] = verify_key
- old_verify_keys = {}
- for key_id, key_data in response_json["old_verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_base64 = key_data["key"]
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_key.expired = key_data["expired_ts"]
- verify_key.time_added = time_now_ms
- old_verify_keys[key_id] = verify_key
+ results = {}
- server_name = response_json["server_name"]
- for key_id in response_json["signatures"].get(server_name, {}):
- if key_id not in response_json["verify_keys"]:
- raise KeyLookupError(
- "Key response must include verification keys for all"
- " signatures"
- )
- if key_id in verify_keys:
- verify_signed_json(
- response_json,
- server_name,
- verify_keys[key_id]
+ @defer.inlineCallbacks
+ def get_key(key_to_fetch_item):
+ server_name, key_ids = key_to_fetch_item
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+ results[server_name] = keys
+ except KeyLookupError as e:
+ logger.warning(
+ "Error looking up keys %s from %s: %s", key_ids, server_name, e
)
+ except Exception:
+ logger.exception("Error getting keys %s from %s", key_ids, server_name)
- signed_key_json = sign_json(
- response_json,
- self.config.server_name,
- self.config.signing_key[0],
+ return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
+ lambda _: results
)
- signed_key_json_bytes = encode_canonical_json(signed_key_json)
- ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
-
- updated_key_ids = set(requested_ids)
- updated_key_ids.update(verify_keys)
- updated_key_ids.update(old_verify_keys)
-
- response_keys.update(verify_keys)
- response_keys.update(old_verify_keys)
-
- yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store.store_server_keys_json,
- server_name=server_name,
- key_id=key_id,
- from_server=from_server,
- ts_now_ms=time_now_ms,
- ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
- )
- for key_id in updated_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
-
- defer.returnValue(response_keys)
+ @defer.inlineCallbacks
+ def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ """
- def store_keys(self, server_name, from_server, verify_keys):
- """Store a collection of verify keys for a given server
Args:
- server_name(str): The name of the server the keys are for.
- from_server(str): The server the keys were downloaded from.
- verify_keys(dict): A mapping of key_id to VerifyKey.
+ server_name (str):
+ key_ids (iterable[str]):
+
Returns:
- A deferred that completes when the keys are stored.
+ Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+
+ Raises:
+ KeyLookupError if there was a problem making the lookup
"""
- # TODO(markjh): Store whether the keys have expired.
- return logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store.store_server_verify_key,
- server_name, server_name, key.time_added, key
+ keys = {} # type: dict[str, FetchKeyResult]
+
+ for requested_key_id in key_ids:
+ # we may have found this key as a side-effect of asking for another.
+ if requested_key_id in keys:
+ continue
+
+ time_now_ms = self.clock.time_msec()
+ try:
+ response = yield self.client.get_json(
+ destination=server_name,
+ path="/_matrix/key/v2/server/"
+ + urllib.parse.quote(requested_key_id),
+ ignore_backoff=True,
+
+ # we only give the remote server 10s to respond. It should be an
+ # easy request to handle, so if it doesn't reply within 10s, it's
+ # probably not going to.
+ #
+ # Furthermore, when we are acting as a notary server, we cannot
+ # wait all day for all of the origin servers, as the requesting
+ # server will otherwise time out before we can respond.
+ #
+ # (Note that get_json may make 4 attempts, so this can still take
+ # almost 45 seconds to fetch the headers, plus up to another 60s to
+ # read the response).
+ timeout=10000,
)
- for key_id, key in verify_keys.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
+ except (NotRetryingDestination, RequestSendFailed) as e:
+ raise_from(KeyLookupError("Failed to connect to remote server"), e)
+ except HttpResponseException as e:
+ raise_from(KeyLookupError("Remote server returned an error"), e)
+
+ if response["server_name"] != server_name:
+ raise KeyLookupError(
+ "Expected a response for server %r not %r"
+ % (server_name, response["server_name"])
+ )
+
+ response_keys = yield self.process_v2_response(
+ from_server=server_name,
+ response_json=response,
+ time_added_ms=time_now_ms,
+ )
+ yield self.store.store_server_verify_keys(
+ server_name,
+ time_now_ms,
+ ((server_name, key_id, key) for key_id, key in response_keys.items()),
+ )
+ keys.update(response_keys)
+
+ defer.returnValue(keys)
@defer.inlineCallbacks
@@ -698,7 +876,7 @@ def _handle_key_deferred(verify_request):
"""Waits for the key to become available, and then performs a verification
Args:
- verify_request (VerifyKeyRequest):
+ verify_request (VerifyJsonRequest):
Returns:
Deferred[None]
@@ -707,48 +885,25 @@ def _handle_key_deferred(verify_request):
SynapseError if there was a problem performing the verification
"""
server_name = verify_request.server_name
- try:
- with PreserveLoggingContext():
- _, key_id, verify_key = yield verify_request.deferred
- except KeyLookupError as e:
- logger.warn(
- "Failed to download keys for %s: %s %s",
- server_name, type(e).__name__, str(e),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.exception(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, verify_request.key_ids),
- Codes.UNAUTHORIZED,
- )
+ with PreserveLoggingContext():
+ _, key_id, verify_key = yield verify_request.key_ready
json_object = verify_request.json_object
- logger.debug("Got key %s %s:%s for server %s, verifying" % (
- key_id, verify_key.alg, verify_key.version, server_name,
- ))
try:
verify_signed_json(json_object, server_name, verify_key)
except SignatureVerifyException as e:
logger.debug(
"Error verifying signature for %s:%s:%s with key %s: %s",
- server_name, verify_key.alg, verify_key.version,
+ server_name,
+ verify_key.alg,
+ verify_key.version,
encode_verify_key_base64(verify_key),
str(e),
)
raise SynapseError(
401,
- "Invalid signature for server %s with key %s:%s: %s" % (
- server_name, verify_key.alg, verify_key.version, str(e),
- ),
+ "Invalid signature for server %s with key %s:%s: %s"
+ % (server_name, verify_key.alg, verify_key.version, str(e)),
Codes.UNAUTHORIZED,
)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 1fe995f212..546b6f4982 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -76,6 +76,7 @@ class EventBuilder(object):
# someone tries to get them when they don't exist.
_state_key = attr.ib(default=None)
_redacts = attr.ib(default=None)
+ _origin_server_ts = attr.ib(default=None)
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
@@ -142,6 +143,9 @@ class EventBuilder(object):
if self._redacts is not None:
event_dict["redacts"] = self._redacts
+ if self._origin_server_ts is not None:
+ event_dict["origin_server_ts"] = self._origin_server_ts
+
defer.returnValue(
create_local_event_from_event_dict(
clock=self._clock,
@@ -209,6 +213,7 @@ class EventBuilderFactory(object):
content=key_values.get("content", {}),
unsigned=key_values.get("unsigned", {}),
redacts=key_values.get("redacts", None),
+ origin_server_ts=key_values.get("origin_server_ts", None),
)
@@ -245,7 +250,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
event_dict["event_id"] = _create_event_id(clock, hostname)
event_dict["origin"] = hostname
- event_dict["origin_server_ts"] = time_now
+ event_dict.setdefault("origin_server_ts", time_now)
event_dict.setdefault("unsigned", {})
age = event_dict["unsigned"].pop("age", 0)
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 27a2a9ef98..e2d4384de1 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -330,12 +330,13 @@ class EventClientSerializer(object):
)
@defer.inlineCallbacks
- def serialize_event(self, event, time_now, **kwargs):
+ def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
"""Serializes a single event.
Args:
event (EventBase)
time_now (int): The current time in milliseconds
+ bundle_aggregations (bool): Whether to bundle in related events
**kwargs: Arguments to pass to `serialize_event`
Returns:
@@ -350,7 +351,7 @@ class EventClientSerializer(object):
# If MSC1849 is enabled then we need to look if thre are any relations
# we need to bundle in with the event
- if self.experimental_msc1849_support_enabled:
+ if self.experimental_msc1849_support_enabled and bundle_aggregations:
annotations = yield self.store.get_aggregation_groups_for_event(
event_id,
)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index cffa831d80..fc5cfb7d83 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -223,9 +223,6 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
the signatures are valid, or fail (with a SynapseError) if not.
"""
- # (currently this is written assuming the v1 room structure; we'll probably want a
- # separate function for checking v2 rooms)
-
# we want to check that the event is signed by:
#
# (a) the sender's server
@@ -257,6 +254,10 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
for p in pdus
]
+ v = KNOWN_ROOM_VERSIONS.get(room_version)
+ if not v:
+ raise RuntimeError("Unrecognized room version %s" % (room_version,))
+
# First we check that the sender event is signed by the sender's domain
# (except if its a 3pid invite, in which case it may be sent by any server)
pdus_to_check_sender = [
@@ -264,10 +265,17 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
if not _is_invite_via_3pid(p.pdu)
]
- more_deferreds = keyring.verify_json_objects_for_server([
- (p.sender_domain, p.redacted_pdu_json)
- for p in pdus_to_check_sender
- ])
+ more_deferreds = keyring.verify_json_objects_for_server(
+ [
+ (
+ p.sender_domain,
+ p.redacted_pdu_json,
+ p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+ p.pdu.event_id,
+ )
+ for p in pdus_to_check_sender
+ ]
+ )
def sender_err(e, pdu_to_check):
errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
@@ -287,20 +295,23 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
# event id's domain (normally only the case for joins/leaves), and add additional
# checks. Only do this if the room version has a concept of event ID domain
# (ie, the room version uses old-style non-hash event IDs).
- v = KNOWN_ROOM_VERSIONS.get(room_version)
- if not v:
- raise RuntimeError("Unrecognized room version %s" % (room_version,))
-
if v.event_format == EventFormatVersions.V1:
pdus_to_check_event_id = [
p for p in pdus_to_check
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
]
- more_deferreds = keyring.verify_json_objects_for_server([
- (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
- for p in pdus_to_check_event_id
- ])
+ more_deferreds = keyring.verify_json_objects_for_server(
+ [
+ (
+ get_domain_from_id(p.pdu.event_id),
+ p.redacted_pdu_json,
+ p.pdu.origin_server_ts if v.enforce_key_validity else 0,
+ p.pdu.event_id,
+ )
+ for p in pdus_to_check_event_id
+ ]
+ )
def event_err(e, pdu_to_check):
errmsg = (
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index f3fc897a0a..70573746d6 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -17,7 +17,6 @@
import copy
import itertools
import logging
-import random
from six.moves import range
@@ -233,7 +232,8 @@ class FederationClient(FederationBase):
moving to the next destination. None indicates no timeout.
Returns:
- Deferred: Results in the requested PDU.
+ Deferred: Results in the requested PDU, or None if we were unable to find
+ it.
"""
# TODO: Rate limit the number of times we try and get the same event.
@@ -258,7 +258,12 @@ class FederationClient(FederationBase):
destination, event_id, timeout=timeout,
)
- logger.debug("transaction_data %r", transaction_data)
+ logger.debug(
+ "retrieved event id %s from %s: %r",
+ event_id,
+ destination,
+ transaction_data,
+ )
pdu_list = [
event_from_pdu_json(p, format_ver, outlier=outlier)
@@ -280,6 +285,7 @@ class FederationClient(FederationBase):
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
+ continue
except NotRetryingDestination as e:
logger.info(str(e))
continue
@@ -326,12 +332,16 @@ class FederationClient(FederationBase):
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
- fetched_events, failed_to_fetch = yield self.get_events(
- [destination], room_id, set(state_event_ids + auth_event_ids)
+ fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
+ destination, room_id, set(state_event_ids + auth_event_ids)
)
if failed_to_fetch:
- logger.warn("Failed to get %r", failed_to_fetch)
+ logger.warning(
+ "Failed to fetch missing state/auth events for %s: %s",
+ room_id,
+ failed_to_fetch
+ )
event_map = {
ev.event_id: ev for ev in fetched_events
@@ -397,27 +407,20 @@ class FederationClient(FederationBase):
defer.returnValue((signed_pdus, signed_auth))
@defer.inlineCallbacks
- def get_events(self, destinations, room_id, event_ids, return_local=True):
- """Fetch events from some remote destinations, checking if we already
- have them.
+ def get_events_from_store_or_dest(self, destination, room_id, event_ids):
+ """Fetch events from a remote destination, checking if we already have them.
Args:
- destinations (list)
+ destination (str)
room_id (str)
event_ids (list)
- return_local (bool): Whether to include events we already have in
- the DB in the returned list of events
Returns:
Deferred: A deferred resolving to a 2-tuple where the first is a list of
events and the second is a list of event ids that we failed to fetch.
"""
- if return_local:
- seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
- signed_events = list(seen_events.values())
- else:
- seen_events = yield self.store.have_seen_events(event_ids)
- signed_events = []
+ seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
+ signed_events = list(seen_events.values())
failed_to_fetch = set()
@@ -428,10 +431,11 @@ class FederationClient(FederationBase):
if not missing_events:
defer.returnValue((signed_events, failed_to_fetch))
- def random_server_list():
- srvs = list(destinations)
- random.shuffle(srvs)
- return srvs
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
+ event_ids,
+ )
room_version = yield self.store.get_room_version(room_id)
@@ -443,7 +447,7 @@ class FederationClient(FederationBase):
deferreds = [
run_in_background(
self.get_pdu,
- destinations=random_server_list(),
+ destinations=[destination],
event_id=e_id,
room_version=room_version,
)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 385eda2dca..949a5fb2aa 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -23,7 +23,11 @@ from twisted.internet import defer
import synapse
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.room_versions import RoomVersions
-from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
+from synapse.api.urls import (
+ FEDERATION_UNSTABLE_PREFIX,
+ FEDERATION_V1_PREFIX,
+ FEDERATION_V2_PREFIX,
+)
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
@@ -90,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):
class Authenticator(object):
def __init__(self, hs):
+ self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
@@ -98,6 +103,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request, content):
+ now = self._clock.time_msec()
json_request = {
"method": request.method.decode('ascii'),
"uri": request.uri.decode('ascii'),
@@ -134,7 +140,9 @@ class Authenticator(object):
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
- yield self.keyring.verify_json_for_server(origin, json_request)
+ yield self.keyring.verify_json_for_server(
+ origin, json_request, now, "Incoming request"
+ )
logger.info("Request from %s", origin)
request.authenticated_entity = origin
@@ -1304,6 +1312,30 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
defer.returnValue((200, new_content))
+class RoomComplexityServlet(BaseFederationServlet):
+ """
+ Indicates to other servers how complex (and therefore likely
+ resource-intensive) a public room this server knows about is.
+ """
+ PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
+ PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, room_id):
+
+ store = self.handler.hs.get_datastore()
+
+ is_public = yield store.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
+
+ if not is_public:
+ raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
+
+ complexity = yield store.get_room_complexity(room_id)
+ defer.returnValue((200, complexity))
+
+
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationEventServlet,
@@ -1327,6 +1359,7 @@ FEDERATION_SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
FederationVersionServlet,
+ RoomComplexityServlet,
)
OPENID_SERVLET_CLASSES = (
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 786149be65..e5dda1975f 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -97,10 +97,13 @@ class GroupAttestationSigning(object):
# TODO: We also want to check that *new* attestations that people give
# us to store are valid for at least a little while.
- if valid_until_ms < self.clock.time_msec():
+ now = self.clock.time_msec()
+ if valid_until_ms < now:
raise SynapseError(400, "Attestation expired")
- yield self.keyring.verify_json_for_server(server_name, attestation)
+ yield self.keyring.verify_json_for_server(
+ server_name, attestation, now, "Group attestation"
+ )
def create_attestation(self, group_id, user_id):
"""Create an attestation for the group_id and user_id with default
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 6003ad9cca..eb525070cf 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -122,6 +122,9 @@ class EventStreamHandler(BaseHandler):
chunks = yield self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event,
+ # We don't bundle "live" events, as otherwise clients
+ # will end up double counting annotations.
+ bundle_aggregations=False,
)
chunk = {
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0684778882..ac5ca79143 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -35,6 +35,7 @@ from synapse.api.errors import (
CodeMessageException,
FederationDeniedError,
FederationError,
+ RequestSendFailed,
StoreError,
SynapseError,
)
@@ -1916,6 +1917,11 @@ class FederationHandler(BaseHandler):
event.room_id, latest_event_ids=extrem_ids,
)
+ logger.debug(
+ "Doing soft-fail check for %s: state %s",
+ event.event_id, current_state_ids,
+ )
+
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [
@@ -1932,7 +1938,7 @@ class FederationHandler(BaseHandler):
self.auth.check(room_version, event, auth_events=current_auth_events)
except AuthError as e:
logger.warn(
- "Failed current state auth resolution for %r because %s",
+ "Soft-failing %r because %s",
event, e,
)
event.internal_metadata.soft_failed = True
@@ -2008,15 +2014,65 @@ class FederationHandler(BaseHandler):
Args:
origin (str):
- event (synapse.events.FrozenEvent):
+ event (synapse.events.EventBase):
context (synapse.events.snapshot.EventContext):
- auth_events (dict[(str, str)->str]):
+ auth_events (dict[(str, str)->synapse.events.EventBase]):
+ Map from (event_type, state_key) to event
+
+ What we expect the event's auth_events to be, based on the event's
+ position in the dag. I think? maybe??
+
+ Also NB that this function adds entries to it.
+ Returns:
+ defer.Deferred[None]
+ """
+ room_version = yield self.store.get_room_version(event.room_id)
+
+ try:
+ yield self._update_auth_events_and_context_for_auth(
+ origin, event, context, auth_events
+ )
+ except Exception:
+ # We don't really mind if the above fails, so lets not fail
+ # processing if it does. However, it really shouldn't fail so
+ # let's still log as an exception since we'll still want to fix
+ # any bugs.
+ logger.exception(
+ "Failed to double check auth events for %s with remote. "
+ "Ignoring failure and continuing processing of event.",
+ event.event_id,
+ )
+
+ try:
+ self.auth.check(room_version, event, auth_events=auth_events)
+ except AuthError as e:
+ logger.warn("Failed auth resolution for %r because %s", event, e)
+ raise e
+
+ @defer.inlineCallbacks
+ def _update_auth_events_and_context_for_auth(
+ self, origin, event, context, auth_events
+ ):
+ """Helper for do_auth. See there for docs.
+
+ Checks whether a given event has the expected auth events. If it
+ doesn't then we talk to the remote server to compare state to see if
+ we can come to a consensus (e.g. if one server missed some valid
+ state).
+
+ This attempts to resovle any potential divergence of state between
+ servers, but is not essential and so failures should not block further
+ processing of the event.
+
+ Args:
+ origin (str):
+ event (synapse.events.EventBase):
+ context (synapse.events.snapshot.EventContext):
+ auth_events (dict[(str, str)->synapse.events.EventBase]):
Returns:
defer.Deferred[None]
"""
- # Check if we have all the auth events.
- current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(event.auth_event_ids())
if event.is_state():
@@ -2024,11 +2080,21 @@ class FederationHandler(BaseHandler):
else:
event_key = None
- if event_auth_events - current_state:
+ # if the event's auth_events refers to events which are not in our
+ # calculated auth_events, we need to fetch those events from somewhere.
+ #
+ # we start by fetching them from the store, and then try calling /event_auth/.
+ missing_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
+
+ if missing_auth:
# TODO: can we use store.have_seen_events here instead?
have_events = yield self.store.get_seen_events_with_rejections(
- event_auth_events - current_state
+ missing_auth
)
+ logger.debug("Got events %s from store", have_events)
+ missing_auth.difference_update(have_events.keys())
else:
have_events = {}
@@ -2037,17 +2103,22 @@ class FederationHandler(BaseHandler):
for e in auth_events.values()
})
- seen_events = set(have_events.keys())
-
- missing_auth = event_auth_events - seen_events - current_state
-
if missing_auth:
- logger.info("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them.
+ logger.info(
+ "auth_events contains unknown events: %s",
+ missing_auth,
+ )
try:
- remote_auth_chain = yield self.federation_client.get_event_auth(
- origin, event.room_id, event.event_id
- )
+ try:
+ remote_auth_chain = yield self.federation_client.get_event_auth(
+ origin, event.room_id, event.event_id
+ )
+ except RequestSendFailed as e:
+ # The other side isn't around or doesn't implement the
+ # endpoint, so lets just bail out.
+ logger.info("Failed to get event auth from remote: %s", e)
+ return
seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
@@ -2084,145 +2155,174 @@ class FederationHandler(BaseHandler):
have_events = yield self.store.get_seen_events_with_rejections(
event.auth_event_ids()
)
- seen_events = set(have_events.keys())
except Exception:
# FIXME:
logger.exception("Failed to get auth chain")
+ if event.internal_metadata.is_outlier():
+ logger.info("Skipping auth_event fetch for outlier")
+ return
+
# FIXME: Assumes we have and stored all the state for all the
# prev_events
- current_state = set(e.event_id for e in auth_events.values())
- different_auth = event_auth_events - current_state
+ different_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
- room_version = yield self.store.get_room_version(event.room_id)
+ if not different_auth:
+ return
- if different_auth and not event.internal_metadata.is_outlier():
- # Do auth conflict res.
- logger.info("Different auth: %s", different_auth)
-
- different_events = yield logcontext.make_deferred_yieldable(
- defer.gatherResults([
- logcontext.run_in_background(
- self.store.get_event,
- d,
- allow_none=True,
- allow_rejected=False,
- )
- for d in different_auth
- if d in have_events and not have_events[d]
- ], consumeErrors=True)
- ).addErrback(unwrapFirstError)
-
- if different_events:
- local_view = dict(auth_events)
- remote_view = dict(auth_events)
- remote_view.update({
- (d.type, d.state_key): d for d in different_events if d
- })
+ logger.info(
+ "auth_events refers to events which are not in our calculated auth "
+ "chain: %s",
+ different_auth,
+ )
+
+ room_version = yield self.store.get_room_version(event.room_id)
- new_state = yield self.state_handler.resolve_events(
- room_version,
- [list(local_view.values()), list(remote_view.values())],
- event
+ different_events = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults([
+ logcontext.run_in_background(
+ self.store.get_event,
+ d,
+ allow_none=True,
+ allow_rejected=False,
)
+ for d in different_auth
+ if d in have_events and not have_events[d]
+ ], consumeErrors=True)
+ ).addErrback(unwrapFirstError)
+
+ if different_events:
+ local_view = dict(auth_events)
+ remote_view = dict(auth_events)
+ remote_view.update({
+ (d.type, d.state_key): d for d in different_events if d
+ })
- auth_events.update(new_state)
+ new_state = yield self.state_handler.resolve_events(
+ room_version,
+ [list(local_view.values()), list(remote_view.values())],
+ event
+ )
- current_state = set(e.event_id for e in auth_events.values())
- different_auth = event_auth_events - current_state
+ logger.info(
+ "After state res: updating auth_events with new state %s",
+ {
+ (d.type, d.state_key): d.event_id for d in new_state.values()
+ if auth_events.get((d.type, d.state_key)) != d
+ },
+ )
- yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
- )
+ auth_events.update(new_state)
- if different_auth and not event.internal_metadata.is_outlier():
- logger.info("Different auth after resolution: %s", different_auth)
+ different_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
- # Only do auth resolution if we have something new to say.
- # We can't rove an auth failure.
- do_resolution = False
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
- provable = [
- RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR,
- ]
+ if not different_auth:
+ # we're done
+ return
- for e_id in different_auth:
- if e_id in have_events:
- if have_events[e_id] in provable:
- do_resolution = True
- break
+ logger.info(
+ "auth_events still refers to events which are not in the calculated auth "
+ "chain after state resolution: %s",
+ different_auth,
+ )
- if do_resolution:
- prev_state_ids = yield context.get_prev_state_ids(self.store)
- # 1. Get what we think is the auth chain.
- auth_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids
- )
- local_auth_chain = yield self.store.get_auth_chain(
- auth_ids, include_given=True
- )
+ # Only do auth resolution if we have something new to say.
+ # We can't prove an auth failure.
+ do_resolution = False
- try:
- # 2. Get remote difference.
- result = yield self.federation_client.query_auth(
- origin,
- event.room_id,
- event.event_id,
- local_auth_chain,
- )
+ for e_id in different_auth:
+ if e_id in have_events:
+ if have_events[e_id] == RejectedReason.NOT_ANCESTOR:
+ do_resolution = True
+ break
- seen_remotes = yield self.store.have_seen_events(
- [e.event_id for e in result["auth_chain"]]
- )
+ if not do_resolution:
+ logger.info(
+ "Skipping auth resolution due to lack of provable rejection reasons"
+ )
+ return
- # 3. Process any remote auth chain events we haven't seen.
- for ev in result["auth_chain"]:
- if ev.event_id in seen_remotes:
- continue
+ logger.info("Doing auth resolution")
- if ev.event_id == event.event_id:
- continue
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
- try:
- auth_ids = ev.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in result["auth_chain"]
- if e.event_id in auth_ids
- or event.type == EventTypes.Create
- }
- ev.internal_metadata.outlier = True
+ # 1. Get what we think is the auth chain.
+ auth_ids = yield self.auth.compute_auth_events(
+ event, prev_state_ids
+ )
+ local_auth_chain = yield self.store.get_auth_chain(
+ auth_ids, include_given=True
+ )
- logger.debug(
- "do_auth %s different_auth: %s",
- event.event_id, e.event_id
- )
+ try:
+ # 2. Get remote difference.
+ try:
+ result = yield self.federation_client.query_auth(
+ origin,
+ event.room_id,
+ event.event_id,
+ local_auth_chain,
+ )
+ except RequestSendFailed as e:
+ # The other side isn't around or doesn't implement the
+ # endpoint, so lets just bail out.
+ logger.info("Failed to query auth from remote: %s", e)
+ return
+
+ seen_remotes = yield self.store.have_seen_events(
+ [e.event_id for e in result["auth_chain"]]
+ )
- yield self._handle_new_event(
- origin, ev, auth_events=auth
- )
+ # 3. Process any remote auth chain events we haven't seen.
+ for ev in result["auth_chain"]:
+ if ev.event_id in seen_remotes:
+ continue
- if ev.event_id in event_auth_events:
- auth_events[(ev.type, ev.state_key)] = ev
- except AuthError:
- pass
+ if ev.event_id == event.event_id:
+ continue
- except Exception:
- # FIXME:
- logger.exception("Failed to query auth chain")
+ try:
+ auth_ids = ev.auth_event_ids()
+ auth = {
+ (e.type, e.state_key): e
+ for e in result["auth_chain"]
+ if e.event_id in auth_ids
+ or event.type == EventTypes.Create
+ }
+ ev.internal_metadata.outlier = True
+
+ logger.debug(
+ "do_auth %s different_auth: %s",
+ event.event_id, e.event_id
+ )
- # 4. Look at rejects and their proofs.
- # TODO.
+ yield self._handle_new_event(
+ origin, ev, auth_events=auth
+ )
- yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
- )
+ if ev.event_id in event_auth_events:
+ auth_events[(ev.type, ev.state_key)] = ev
+ except AuthError:
+ pass
- try:
- self.auth.check(room_version, event, auth_events=auth_events)
- except AuthError as e:
- logger.warn("Failed auth resolution for %r because %s", event, e)
- raise e
+ except Exception:
+ # FIXME:
+ logger.exception("Failed to query auth chain")
+
+ # 4. Look at rejects and their proofs.
+ # TODO.
+
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events,
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 7b2c33a922..0b02469ceb 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.api.errors import (
AuthError,
Codes,
@@ -166,6 +166,9 @@ class MessageHandler(object):
now = self.clock.time_msec()
events = yield self._event_serializer.serialize_events(
room_state.values(), now,
+ # We don't bother bundling aggregations in when asked for state
+ # events, as clients won't use them.
+ bundle_aggregations=False,
)
defer.returnValue(events)
@@ -601,6 +604,20 @@ class EventCreationHandler(object):
self.validator.validate_new(event)
+ # If this event is an annotation then we check that that the sender
+ # can't annotate the same way twice (e.g. stops users from liking an
+ # event multiple times).
+ relation = event.content.get("m.relates_to", {})
+ if relation.get("rel_type") == RelationTypes.ANNOTATION:
+ relates_to = relation["event_id"]
+ aggregation_key = relation["key"]
+
+ already_exists = yield self.store.has_user_annotated_event(
+ relates_to, event.type, aggregation_key, event.sender,
+ )
+ if already_exists:
+ raise SynapseError(400, "Can't send same reaction twice")
+
logger.debug(
"Created event %s",
event.event_id,
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 59d53f1050..e49c8203ef 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -182,17 +182,27 @@ class PresenceHandler(object):
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
+ def run_timeout_handler():
+ return run_as_background_process(
+ "handle_presence_timeouts", self._handle_timeouts
+ )
+
self.clock.call_later(
30,
self.clock.looping_call,
- self._handle_timeouts,
+ run_timeout_handler,
5000,
)
+ def run_persister():
+ return run_as_background_process(
+ "persist_presence_changes", self._persist_unpersisted_changes
+ )
+
self.clock.call_later(
60,
self.clock.looping_call,
- self._persist_unpersisted_changes,
+ run_persister,
60 * 1000,
)
@@ -229,6 +239,7 @@ class PresenceHandler(object):
)
if self.unpersisted_users_changes:
+
yield self.store.update_presence([
self.user_to_current_state[user_id]
for user_id in self.unpersisted_users_changes
@@ -240,30 +251,18 @@ class PresenceHandler(object):
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
"""
- logger.info(
- "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
- len(self.unpersisted_users_changes)
- )
-
unpersisted = self.unpersisted_users_changes
self.unpersisted_users_changes = set()
if unpersisted:
+ logger.info(
+ "Persisting %d upersisted presence updates", len(unpersisted)
+ )
yield self.store.update_presence([
self.user_to_current_state[user_id]
for user_id in unpersisted
])
- logger.info("Finished _persist_unpersisted_changes")
-
- @defer.inlineCallbacks
- def _update_states_and_catch_exception(self, new_states):
- try:
- res = yield self._update_states(new_states)
- defer.returnValue(res)
- except Exception:
- logger.exception("Error updating presence")
-
@defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
@@ -338,45 +337,41 @@ class PresenceHandler(object):
logger.info("Handling presence timeouts")
now = self.clock.time_msec()
- try:
- with Measure(self.clock, "presence_handle_timeouts"):
- # Fetch the list of users that *may* have timed out. Things may have
- # changed since the timeout was set, so we won't necessarily have to
- # take any action.
- users_to_check = set(self.wheel_timer.fetch(now))
-
- # Check whether the lists of syncing processes from an external
- # process have expired.
- expired_process_ids = [
- process_id for process_id, last_update
- in self.external_process_last_updated_ms.items()
- if now - last_update > EXTERNAL_PROCESS_EXPIRY
- ]
- for process_id in expired_process_ids:
- users_to_check.update(
- self.external_process_last_updated_ms.pop(process_id, ())
- )
- self.external_process_last_update.pop(process_id)
+ # Fetch the list of users that *may* have timed out. Things may have
+ # changed since the timeout was set, so we won't necessarily have to
+ # take any action.
+ users_to_check = set(self.wheel_timer.fetch(now))
+
+ # Check whether the lists of syncing processes from an external
+ # process have expired.
+ expired_process_ids = [
+ process_id for process_id, last_update
+ in self.external_process_last_updated_ms.items()
+ if now - last_update > EXTERNAL_PROCESS_EXPIRY
+ ]
+ for process_id in expired_process_ids:
+ users_to_check.update(
+ self.external_process_last_updated_ms.pop(process_id, ())
+ )
+ self.external_process_last_update.pop(process_id)
- states = [
- self.user_to_current_state.get(
- user_id, UserPresenceState.default(user_id)
- )
- for user_id in users_to_check
- ]
+ states = [
+ self.user_to_current_state.get(
+ user_id, UserPresenceState.default(user_id)
+ )
+ for user_id in users_to_check
+ ]
- timers_fired_counter.inc(len(states))
+ timers_fired_counter.inc(len(states))
- changes = handle_timeouts(
- states,
- is_mine_fn=self.is_mine_id,
- syncing_user_ids=self.get_currently_syncing_users(),
- now=now,
- )
+ changes = handle_timeouts(
+ states,
+ is_mine_fn=self.is_mine_id,
+ syncing_user_ids=self.get_currently_syncing_users(),
+ now=now,
+ )
- run_in_background(self._update_states_and_catch_exception, changes)
- except Exception:
- logger.exception("Exception in _handle_timeouts loop")
+ return self._update_states(changes)
@defer.inlineCallbacks
def bump_presence_active_time(self, user):
@@ -833,14 +828,17 @@ class PresenceHandler(object):
# joins.
continue
- event = yield self.store.get_event(event_id)
- if event.content.get("membership") != Membership.JOIN:
+ event = yield self.store.get_event(event_id, allow_none=True)
+ if not event or event.content.get("membership") != Membership.JOIN:
# We only care about joins
continue
if prev_event_id:
- prev_event = yield self.store.get_event(prev_event_id)
- if prev_event.content.get("membership") == Membership.JOIN:
+ prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ if (
+ prev_event
+ and prev_event.content.get("membership") == Membership.JOIN
+ ):
# Ignore changes to join events.
continue
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 91fc718ff8..a5fc6c5dbf 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -31,6 +31,9 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
+MAX_DISPLAYNAME_LEN = 100
+MAX_AVATAR_URL_LEN = 1000
+
class BaseProfileHandler(BaseHandler):
"""Handles fetching and updating user profile information.
@@ -162,6 +165,11 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
+ if len(new_displayname) > MAX_DISPLAYNAME_LEN:
+ raise SynapseError(
+ 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ),
+ )
+
if new_displayname == '':
new_displayname = None
@@ -217,6 +225,11 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
+ if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
+ raise SynapseError(
+ 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ),
+ )
+
yield self.store.set_profile_avatar_url(
target_user.localpart, new_avatar_url
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index e83ee24f10..9a388ea013 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -531,6 +531,8 @@ class RegistrationHandler(BaseHandler):
A tuple of (user_id, access_token).
Raises:
RegistrationError if there was a problem registering.
+
+ NB this is only used in tests. TODO: move it to the test package!
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e37ae96899..4a17911a87 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -27,7 +27,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
-from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
@@ -70,6 +70,7 @@ class RoomCreationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
+ self.config = hs.config
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
@@ -475,7 +476,11 @@ class RoomCreationHandler(BaseHandler):
if ratelimit:
yield self.ratelimit(requester)
- room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier)
+ room_version = config.get(
+ "room_version",
+ self.config.default_room_version.identifier,
+ )
+
if not isinstance(room_version, string_types):
raise SynapseError(
400,
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
new file mode 100644
index 0000000000..7ad16c8566
--- /dev/null
+++ b/synapse/handlers/stats.py
@@ -0,0 +1,333 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.handlers.state_deltas import StateDeltasHandler
+from synapse.metrics import event_processing_positions
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+class StatsHandler(StateDeltasHandler):
+ """Handles keeping the *_stats tables updated with a simple time-series of
+ information about the users, rooms and media on the server, such that admins
+ have some idea of who is consuming their resources.
+
+ Heavily derived from UserDirectoryHandler
+ """
+
+ def __init__(self, hs):
+ super(StatsHandler, self).__init__(hs)
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.server_name = hs.hostname
+ self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
+ self.is_mine_id = hs.is_mine_id
+ self.stats_bucket_size = hs.config.stats_bucket_size
+
+ # The current position in the current_state_delta stream
+ self.pos = None
+
+ # Guard to ensure we only process deltas one at a time
+ self._is_processing = False
+
+ if hs.config.stats_enabled:
+ self.notifier.add_replication_callback(self.notify_new_event)
+
+ # We kick this off so that we don't have to wait for a change before
+ # we start populating stats
+ self.clock.call_later(0, self.notify_new_event)
+
+ def notify_new_event(self):
+ """Called when there may be more deltas to process
+ """
+ if not self.hs.config.stats_enabled:
+ return
+
+ if self._is_processing:
+ return
+
+ @defer.inlineCallbacks
+ def process():
+ try:
+ yield self._unsafe_process()
+ finally:
+ self._is_processing = False
+
+ self._is_processing = True
+ run_as_background_process("stats.notify_new_event", process)
+
+ @defer.inlineCallbacks
+ def _unsafe_process(self):
+ # If self.pos is None then means we haven't fetched it from DB
+ if self.pos is None:
+ self.pos = yield self.store.get_stats_stream_pos()
+
+ # If still None then the initial background update hasn't happened yet
+ if self.pos is None:
+ defer.returnValue(None)
+
+ # Loop round handling deltas until we're up to date
+ while True:
+ with Measure(self.clock, "stats_delta"):
+ deltas = yield self.store.get_current_state_deltas(self.pos)
+ if not deltas:
+ return
+
+ logger.info("Handling %d state deltas", len(deltas))
+ yield self._handle_deltas(deltas)
+
+ self.pos = deltas[-1]["stream_id"]
+ yield self.store.update_stats_stream_pos(self.pos)
+
+ event_processing_positions.labels("stats").set(self.pos)
+
+ @defer.inlineCallbacks
+ def _handle_deltas(self, deltas):
+ """
+ Called with the state deltas to process
+ """
+ for delta in deltas:
+ typ = delta["type"]
+ state_key = delta["state_key"]
+ room_id = delta["room_id"]
+ event_id = delta["event_id"]
+ stream_id = delta["stream_id"]
+ prev_event_id = delta["prev_event_id"]
+ stream_pos = delta["stream_id"]
+
+ logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+
+ token = yield self.store.get_earliest_token_for_room_stats(room_id)
+
+ # If the earliest token to begin from is larger than our current
+ # stream ID, skip processing this delta.
+ if token is not None and token >= stream_id:
+ logger.debug(
+ "Ignoring: %s as earlier than this room's initial ingestion event",
+ event_id,
+ )
+ continue
+
+ if event_id is None and prev_event_id is None:
+ # Errr...
+ continue
+
+ event_content = {}
+
+ if event_id is not None:
+ event = yield self.store.get_event(event_id, allow_none=True)
+ if event:
+ event_content = event.content or {}
+
+ # We use stream_pos here rather than fetch by event_id as event_id
+ # may be None
+ now = yield self.store.get_received_ts_by_stream_pos(stream_pos)
+
+ # quantise time to the nearest bucket
+ now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size
+
+ if typ == EventTypes.Member:
+ # we could use _get_key_change here but it's a bit inefficient
+ # given we're not testing for a specific result; might as well
+ # just grab the prev_membership and membership strings and
+ # compare them.
+ prev_event_content = {}
+ if prev_event_id is not None:
+ prev_event = yield self.store.get_event(
+ prev_event_id, allow_none=True,
+ )
+ if prev_event:
+ prev_event_content = prev_event.content
+
+ membership = event_content.get("membership", Membership.LEAVE)
+ prev_membership = prev_event_content.get("membership", Membership.LEAVE)
+
+ if prev_membership == membership:
+ continue
+
+ if prev_membership == Membership.JOIN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "joined_members", -1
+ )
+ elif prev_membership == Membership.INVITE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "invited_members", -1
+ )
+ elif prev_membership == Membership.LEAVE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "left_members", -1
+ )
+ elif prev_membership == Membership.BAN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "banned_members", -1
+ )
+ else:
+ err = "%s is not a valid prev_membership" % (repr(prev_membership),)
+ logger.error(err)
+ raise ValueError(err)
+
+ if membership == Membership.JOIN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "joined_members", +1
+ )
+ elif membership == Membership.INVITE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "invited_members", +1
+ )
+ elif membership == Membership.LEAVE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "left_members", +1
+ )
+ elif membership == Membership.BAN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "banned_members", +1
+ )
+ else:
+ err = "%s is not a valid membership" % (repr(membership),)
+ logger.error(err)
+ raise ValueError(err)
+
+ user_id = state_key
+ if self.is_mine_id(user_id):
+ # update user_stats as it's one of our users
+ public = yield self._is_public_room(room_id)
+
+ if membership == Membership.LEAVE:
+ yield self.store.update_stats_delta(
+ now,
+ "user",
+ user_id,
+ "public_rooms" if public else "private_rooms",
+ -1,
+ )
+ elif membership == Membership.JOIN:
+ yield self.store.update_stats_delta(
+ now,
+ "user",
+ user_id,
+ "public_rooms" if public else "private_rooms",
+ +1,
+ )
+
+ elif typ == EventTypes.Create:
+ # Newly created room. Add it with all blank portions.
+ yield self.store.update_room_state(
+ room_id,
+ {
+ "join_rules": None,
+ "history_visibility": None,
+ "encryption": None,
+ "name": None,
+ "topic": None,
+ "avatar": None,
+ "canonical_alias": None,
+ },
+ )
+
+ elif typ == EventTypes.JoinRules:
+ yield self.store.update_room_state(
+ room_id, {"join_rules": event_content.get("join_rule")}
+ )
+
+ is_public = yield self._get_key_change(
+ prev_event_id, event_id, "join_rule", JoinRules.PUBLIC
+ )
+ if is_public is not None:
+ yield self.update_public_room_stats(now, room_id, is_public)
+
+ elif typ == EventTypes.RoomHistoryVisibility:
+ yield self.store.update_room_state(
+ room_id,
+ {"history_visibility": event_content.get("history_visibility")},
+ )
+
+ is_public = yield self._get_key_change(
+ prev_event_id, event_id, "history_visibility", "world_readable"
+ )
+ if is_public is not None:
+ yield self.update_public_room_stats(now, room_id, is_public)
+
+ elif typ == EventTypes.Encryption:
+ yield self.store.update_room_state(
+ room_id, {"encryption": event_content.get("algorithm")}
+ )
+ elif typ == EventTypes.Name:
+ yield self.store.update_room_state(
+ room_id, {"name": event_content.get("name")}
+ )
+ elif typ == EventTypes.Topic:
+ yield self.store.update_room_state(
+ room_id, {"topic": event_content.get("topic")}
+ )
+ elif typ == EventTypes.RoomAvatar:
+ yield self.store.update_room_state(
+ room_id, {"avatar": event_content.get("url")}
+ )
+ elif typ == EventTypes.CanonicalAlias:
+ yield self.store.update_room_state(
+ room_id, {"canonical_alias": event_content.get("alias")}
+ )
+
+ @defer.inlineCallbacks
+ def update_public_room_stats(self, ts, room_id, is_public):
+ """
+ Increment/decrement a user's number of public rooms when a room they are
+ in changes to/from public visibility.
+
+ Args:
+ ts (int): Timestamp in seconds
+ room_id (str)
+ is_public (bool)
+ """
+ # For now, blindly iterate over all local users in the room so that
+ # we can handle the whole problem of copying buckets over as needed
+ user_ids = yield self.store.get_users_in_room(room_id)
+
+ for user_id in user_ids:
+ if self.hs.is_mine(UserID.from_string(user_id)):
+ yield self.store.update_stats_delta(
+ ts, "user", user_id, "public_rooms", +1 if is_public else -1
+ )
+ yield self.store.update_stats_delta(
+ ts, "user", user_id, "private_rooms", -1 if is_public else +1
+ )
+
+ @defer.inlineCallbacks
+ def _is_public_room(self, room_id):
+ join_rules = yield self.state.get_current_state(room_id, EventTypes.JoinRules)
+ history_visibility = yield self.state.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility
+ )
+
+ if (join_rules and join_rules.content.get("join_rule") == JoinRules.PUBLIC) or (
+ (
+ history_visibility
+ and history_visibility.content.get("history_visibility")
+ == "world_readable"
+ )
+ ):
+ defer.returnValue(True)
+ else:
+ defer.returnValue(False)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1ee9a6e313..62fda0c664 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -583,30 +583,42 @@ class SyncHandler(object):
)
# if the room has a name or canonical_alias set, we can skip
- # calculating heroes. we assume that if the event has contents, it'll
- # be a valid name or canonical_alias - i.e. we're checking that they
- # haven't been "deleted" by blatting {} over the top.
+ # calculating heroes. Empty strings are falsey, so we check
+ # for the "name" value and default to an empty string.
if name_id:
name = yield self.store.get_event(name_id, allow_none=True)
- if name and name.content:
+ if name and name.content.get("name"):
defer.returnValue(summary)
if canonical_alias_id:
canonical_alias = yield self.store.get_event(
canonical_alias_id, allow_none=True,
)
- if canonical_alias and canonical_alias.content:
+ if canonical_alias and canonical_alias.content.get("alias"):
defer.returnValue(summary)
+ me = sync_config.user.to_string()
+
joined_user_ids = [
- r[0] for r in details.get(Membership.JOIN, empty_ms).members
+ r[0]
+ for r in details.get(Membership.JOIN, empty_ms).members
+ if r[0] != me
]
invited_user_ids = [
- r[0] for r in details.get(Membership.INVITE, empty_ms).members
+ r[0]
+ for r in details.get(Membership.INVITE, empty_ms).members
+ if r[0] != me
]
gone_user_ids = (
- [r[0] for r in details.get(Membership.LEAVE, empty_ms).members] +
- [r[0] for r in details.get(Membership.BAN, empty_ms).members]
+ [
+ r[0]
+ for r in details.get(Membership.LEAVE, empty_ms).members
+ if r[0] != me
+ ] + [
+ r[0]
+ for r in details.get(Membership.BAN, empty_ms).members
+ if r[0] != me
+ ]
)
# FIXME: only build up a member_ids list for our heroes
@@ -621,22 +633,13 @@ class SyncHandler(object):
member_ids[user_id] = event_id
# FIXME: order by stream ordering rather than as returned by SQL
- me = sync_config.user.to_string()
if (joined_user_ids or invited_user_ids):
summary['m.heroes'] = sorted(
- [
- user_id
- for user_id in (joined_user_ids + invited_user_ids)
- if user_id != me
- ]
+ [user_id for user_id in (joined_user_ids + invited_user_ids)]
)[0:5]
else:
summary['m.heroes'] = sorted(
- [
- user_id
- for user_id in gone_user_ids
- if user_id != me
- ]
+ [user_id for user_id in gone_user_ids]
)[0:5]
if not sync_config.filter_collection.lazy_load_members():
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 7eefc7b1fc..663ea72a7a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -285,7 +285,24 @@ class MatrixFederationHttpClient(object):
request (MatrixFederationRequest): details of request to be sent
timeout (int|None): number of milliseconds to wait for the response headers
- (including connecting to the server). 60s by default.
+ (including connecting to the server), *for each attempt*.
+ 60s by default.
+
+ long_retries (bool): whether to use the long retry algorithm.
+
+ The regular retry algorithm makes 4 attempts, with intervals
+ [0.5s, 1s, 2s].
+
+ The long retry algorithm makes 11 attempts, with intervals
+ [4s, 16s, 60s, 60s, ...]
+
+ Both algorithms add -20%/+40% jitter to the retry intervals.
+
+ Note that the above intervals are *in addition* to the time spent
+ waiting for the request to complete (up to `timeout` ms).
+
+ NB: the long retry algorithm takes over 20 minutes to complete, with
+ a default timeout of 60s!
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
@@ -566,10 +583,14 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
use as the request body.
- long_retries (bool): A boolean that indicates whether we should
- retry for a short or long time.
- timeout(int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout.
+
+ long_retries (bool): whether to use the long retry algorithm. See
+ docs on _send_request for details.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self._default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as
@@ -627,15 +648,22 @@ class MatrixFederationHttpClient(object):
Args:
destination (str): The remote server to send the HTTP request
to.
+
path (str): The HTTP path.
+
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- long_retries (bool): A boolean that indicates whether we should
- retry for a short or long time.
- timeout(int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout.
+
+ long_retries (bool): whether to use the long retry algorithm. See
+ docs on _send_request for details.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self._default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
+
args (dict): query params
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
@@ -686,14 +714,19 @@ class MatrixFederationHttpClient(object):
Args:
destination (str): The remote server to send the HTTP request
to.
+
path (str): The HTTP path.
+
args (dict|None): A dictionary used to create query strings, defaults to
None.
- timeout (int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout and that the request will
- be retried.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self._default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
+
try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
@@ -711,10 +744,6 @@ class MatrixFederationHttpClient(object):
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
- logger.debug("get_json args: %s", args)
-
- logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
-
request = MatrixFederationRequest(
method="GET",
destination=destination,
@@ -746,12 +775,18 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
- long_retries (bool): A boolean that indicates whether we should
- retry for a short or long time.
- timeout(int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout.
+
+ long_retries (bool): whether to use the long retry algorithm. See
+ docs on _send_request for details.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self._default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
+
+ args (dict): query params
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 528125e737..197c652850 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -55,7 +55,7 @@ def parse_integer_from_args(args, name, default=None, required=False):
return int(args[name][0])
except Exception:
message = "Query parameter %r must be an integer" % (name,)
- raise SynapseError(400, message)
+ raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index fdcfb90a7e..f64baa4d58 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -16,7 +16,12 @@
import logging
-from pkg_resources import DistributionNotFound, VersionConflict, get_distribution
+from pkg_resources import (
+ DistributionNotFound,
+ Requirement,
+ VersionConflict,
+ get_provider,
+)
logger = logging.getLogger(__name__)
@@ -69,14 +74,6 @@ REQUIREMENTS = [
"attrs>=17.4.0",
"netaddr>=0.7.18",
-
- # requests is a transitive dep of treq, and urlib3 is a transitive dep
- # of requests, as well as of sentry-sdk.
- #
- # As of requests 2.21, requests does not yet support urllib3 1.25.
- # (If we do not pin it here, pip will give us the latest urllib3
- # due to the dep via sentry-sdk.)
- "urllib3<1.25",
]
CONDITIONAL_REQUIREMENTS = {
@@ -91,7 +88,13 @@ CONDITIONAL_REQUIREMENTS = {
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
- "acme": ["txacme>=0.9.2"],
+ "acme": [
+ "txacme>=0.9.2",
+
+ # txacme depends on eliot. Eliot 1.8.0 is incompatible with
+ # python 3.5.2, as per https://github.com/itamarst/eliot/issues/418
+ 'eliot<1.8.0;python_version<"3.5.3"',
+ ],
"saml2": ["pysaml2>=4.5.0"],
"systemd": ["systemd-python>=231"],
@@ -125,10 +128,10 @@ class DependencyException(Exception):
@property
def dependencies(self):
for i in self.args[0]:
- yield '"' + i + '"'
+ yield "'" + i + "'"
-def check_requirements(for_feature=None, _get_distribution=get_distribution):
+def check_requirements(for_feature=None):
deps_needed = []
errors = []
@@ -139,7 +142,7 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution):
for dependency in reqs:
try:
- _get_distribution(dependency)
+ _check_requirement(dependency)
except VersionConflict as e:
deps_needed.append(dependency)
errors.append(
@@ -157,7 +160,7 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution):
for dependency in OPTS:
try:
- _get_distribution(dependency)
+ _check_requirement(dependency)
except VersionConflict as e:
deps_needed.append(dependency)
errors.append(
@@ -175,6 +178,23 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution):
raise DependencyException(deps_needed)
+def _check_requirement(dependency_string):
+ """Parses a dependency string, and checks if the specified requirement is installed
+
+ Raises:
+ VersionConflict if the requirement is installed, but with the the wrong version
+ DistributionNotFound if nothing is found to provide the requirement
+ """
+ req = Requirement.parse(dependency_string)
+
+ # first check if the markers specify that this requirement needs installing
+ if req.marker is not None and not req.marker.evaluate():
+ # not required for this environment
+ return
+
+ get_provider(req)
+
+
if __name__ == "__main__":
import sys
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 744d85594f..d6c4dcdb18 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -822,10 +822,16 @@ class AdminRestResource(JsonResource):
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
+ register_servlets(hs, self)
- register_servlets_for_client_rest_resource(hs, self)
- SendServerNoticeServlet(hs).register(self)
- VersionServlet(hs).register(self)
+
+def register_servlets(hs, http_server):
+ """
+ Register all the admin servlets.
+ """
+ register_servlets_for_client_rest_resource(hs, http_server)
+ SendServerNoticeServlet(hs).register(http_server)
+ VersionServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
deleted file mode 100644
index dc63b661c0..0000000000
--- a/synapse/rest/client/v1/base.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This module contains base REST classes for constructing client v1 servlets.
-"""
-
-import logging
-import re
-
-from synapse.api.urls import CLIENT_API_PREFIX
-from synapse.http.servlet import RestServlet
-from synapse.rest.client.transactions import HttpTransactionCache
-
-logger = logging.getLogger(__name__)
-
-
-def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True):
- """Creates a regex compiled client path with the correct client path
- prefix.
-
- Args:
- path_regex (str): The regex string to match. This should NOT have a ^
- as this will be prefixed.
- Returns:
- SRE_Pattern
- """
- patterns = [re.compile("^" + CLIENT_API_PREFIX + "/api/v1" + path_regex)]
- if include_in_unstable:
- unstable_prefix = CLIENT_API_PREFIX + "/unstable"
- patterns.append(re.compile("^" + unstable_prefix + path_regex))
- for release in releases:
- new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
- patterns.append(re.compile("^" + new_prefix + path_regex))
- return patterns
-
-
-class ClientV1RestServlet(RestServlet):
- """A base Synapse REST Servlet for the client version 1 API.
- """
-
- # This subclass was presumably created to allow the auth for the v1
- # protocol version to be different, however this behaviour was removed.
- # it may no longer be necessary
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
- self.hs = hs
- self.builder_factory = hs.get_event_builder_factory()
- self.auth = hs.get_auth()
- self.txns = HttpTransactionCache(hs)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 0220acf644..0035182bb9 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -19,11 +19,10 @@ import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import RoomAlias
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
@@ -33,13 +32,14 @@ def register_servlets(hs, http_server):
ClientAppserviceDirectoryListServer(hs).register(http_server)
-class ClientDirectoryServer(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
+class ClientDirectoryServer(RestServlet):
+ PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs):
- super(ClientDirectoryServer, self).__init__(hs)
+ super(ClientDirectoryServer, self).__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
@@ -120,13 +120,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
defer.returnValue((200, {}))
-class ClientDirectoryListServer(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/directory/list/room/(?P<room_id>[^/]*)$")
+class ClientDirectoryListServer(RestServlet):
+ PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs):
- super(ClientDirectoryListServer, self).__init__(hs)
+ super(ClientDirectoryListServer, self).__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -162,15 +163,16 @@ class ClientDirectoryListServer(ClientV1RestServlet):
defer.returnValue((200, {}))
-class ClientAppserviceDirectoryListServer(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$"
+class ClientAppserviceDirectoryListServer(RestServlet):
+ PATTERNS = client_patterns(
+ "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
)
def __init__(self, hs):
- super(ClientAppserviceDirectoryListServer, self).__init__(hs)
+ super(ClientAppserviceDirectoryListServer, self).__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
+ self.auth = hs.get_auth()
def on_PUT(self, request, network_id, room_id):
content = parse_json_object_from_request(request)
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index c3b0a39ab7..84ca36270b 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -19,21 +19,22 @@ import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
-class EventStreamRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/events$")
+class EventStreamRestServlet(RestServlet):
+ PATTERNS = client_patterns("/events$", v1=True)
DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
- super(EventStreamRestServlet, self).__init__(hs)
+ super(EventStreamRestServlet, self).__init__()
self.event_stream_handler = hs.get_event_stream_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -76,11 +77,11 @@ class EventStreamRestServlet(ClientV1RestServlet):
# TODO: Unit test gets, with and without auth, with different kinds of events.
-class EventRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$")
+class EventRestServlet(RestServlet):
+ PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs):
- super(EventRestServlet, self).__init__(hs)
+ super(EventRestServlet, self).__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 3ead75cb77..0fe5f2d79b 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -15,19 +15,19 @@
from twisted.internet import defer
-from synapse.http.servlet import parse_boolean
+from synapse.http.servlet import RestServlet, parse_boolean
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig
-from .base import ClientV1RestServlet, client_path_patterns
-
# TODO: Needs unit testing
-class InitialSyncRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/initialSync$")
+class InitialSyncRestServlet(RestServlet):
+ PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs):
- super(InitialSyncRestServlet, self).__init__(hs)
+ super(InitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 5180e9eaf1..3b60728628 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -29,12 +29,11 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
@@ -81,15 +80,16 @@ def login_id_thirdparty_from_phone(identifier):
}
-class LoginRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/login$")
+class LoginRestServlet(RestServlet):
+ PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "m.login.jwt"
def __init__(self, hs):
- super(LoginRestServlet, self).__init__(hs)
+ super(LoginRestServlet, self).__init__()
+ self.hs = hs
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
@@ -371,7 +371,7 @@ class LoginRestServlet(ClientV1RestServlet):
class CasRedirectServlet(RestServlet):
- PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")
+ PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
@@ -386,7 +386,7 @@ class CasRedirectServlet(RestServlet):
b"redirectUrl": args[b"redirectUrl"][0]
}).encode('ascii')
hs_redirect_url = (self.cas_service_url +
- b"/_matrix/client/api/v1/login/cas/ticket")
+ b"/_matrix/client/r0/login/cas/ticket")
service_param = urllib.parse.urlencode({
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
}).encode('ascii')
@@ -394,27 +394,27 @@ class CasRedirectServlet(RestServlet):
finish_request(request)
-class CasTicketServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/login/cas/ticket", releases=())
+class CasTicketServlet(RestServlet):
+ PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs):
- super(CasTicketServlet, self).__init__(hs)
+ super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
+ self._http_client = hs.get_simple_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True)
- http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url
}
try:
- body = yield http_client.get_raw(uri, args)
+ body = yield self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 430c692336..b8064f261e 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -17,19 +17,18 @@ import logging
from twisted.internet import defer
-from synapse.api.errors import AuthError
-
-from .base import ClientV1RestServlet, client_path_patterns
+from synapse.http.servlet import RestServlet
+from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
-class LogoutRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/logout$")
+class LogoutRestServlet(RestServlet):
+ PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs):
- super(LogoutRestServlet, self).__init__(hs)
- self._auth = hs.get_auth()
+ super(LogoutRestServlet, self).__init__()
+ self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@@ -38,32 +37,25 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- try:
- requester = yield self.auth.get_user_by_req(request)
- except AuthError:
- # this implies the access token has already been deleted.
- defer.returnValue((401, {
- "errcode": "M_UNKNOWN_TOKEN",
- "error": "Access Token unknown or expired"
- }))
+ requester = yield self.auth.get_user_by_req(request)
+
+ if requester.device_id is None:
+ # the acccess token wasn't associated with a device.
+ # Just delete the access token
+ access_token = self.auth.get_access_token_from_request(request)
+ yield self._auth_handler.delete_access_token(access_token)
else:
- if requester.device_id is None:
- # the acccess token wasn't associated with a device.
- # Just delete the access token
- access_token = self._auth.get_access_token_from_request(request)
- yield self._auth_handler.delete_access_token(access_token)
- else:
- yield self._device_handler.delete_device(
- requester.user.to_string(), requester.device_id)
+ yield self._device_handler.delete_device(
+ requester.user.to_string(), requester.device_id)
defer.returnValue((200, {}))
-class LogoutAllRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/logout/all$")
+class LogoutAllRestServlet(RestServlet):
+ PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs):
- super(LogoutAllRestServlet, self).__init__(hs)
+ super(LogoutAllRestServlet, self).__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 045d5a20ac..e263da3cb7 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -23,21 +23,22 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
-class PresenceStatusRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
+class PresenceStatusRestServlet(RestServlet):
+ PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs):
- super(PresenceStatusRestServlet, self).__init__(hs)
+ super(PresenceStatusRestServlet, self).__init__()
+ self.hs = hs
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index eac1966c5e..e15d9d82a6 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -16,18 +16,19 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
-from .base import ClientV1RestServlet, client_path_patterns
-
-class ProfileDisplaynameRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
+class ProfileDisplaynameRestServlet(RestServlet):
+ PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
def __init__(self, hs):
- super(ProfileDisplaynameRestServlet, self).__init__(hs)
+ super(ProfileDisplaynameRestServlet, self).__init__()
+ self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
@@ -71,12 +72,14 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
return (200, {})
-class ProfileAvatarURLRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
+class ProfileAvatarURLRestServlet(RestServlet):
+ PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs):
- super(ProfileAvatarURLRestServlet, self).__init__(hs)
+ super(ProfileAvatarURLRestServlet, self).__init__()
+ self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
@@ -119,12 +122,14 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
return (200, {})
-class ProfileRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
+class ProfileRestServlet(RestServlet):
+ PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs):
- super(ProfileRestServlet, self).__init__(hs)
+ super(ProfileRestServlet, self).__init__()
+ self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 506ec95ddd..3d6326fe2f 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -21,22 +21,22 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
-from synapse.http.servlet import parse_json_value_from_request, parse_string
+from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from .base import ClientV1RestServlet, client_path_patterns
-
-class PushRuleRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/(?P<path>pushrules/.*)$")
+class PushRuleRestServlet(RestServlet):
+ PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs):
- super(PushRuleRestServlet, self).__init__(hs)
+ super(PushRuleRestServlet, self).__init__()
+ self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 4c07ae7f45..15d860db37 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -26,17 +26,18 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.push import PusherConfigException
-
-from .base import ClientV1RestServlet, client_path_patterns
+from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
-class PushersRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/pushers$")
+class PushersRestServlet(RestServlet):
+ PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs):
- super(PushersRestServlet, self).__init__(hs)
+ super(PushersRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -69,11 +70,13 @@ class PushersRestServlet(ClientV1RestServlet):
return 200, {}
-class PushersSetRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/pushers/set$")
+class PushersSetRestServlet(RestServlet):
+ PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs):
- super(PushersSetRestServlet, self).__init__(hs)
+ super(PushersSetRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
@@ -141,7 +144,7 @@ class PushersRemoveRestServlet(RestServlet):
"""
To allow pusher to be delete by clicking a link (ie. GET request)
"""
- PATTERNS = client_path_patterns("/pushers/remove$")
+ PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 255a85c588..e8f672c4ba 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -28,37 +28,45 @@ from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
+ RestServlet,
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
-class RoomCreateRestServlet(ClientV1RestServlet):
+class TransactionRestServlet(RestServlet):
+ def __init__(self, hs):
+ super(TransactionRestServlet, self).__init__()
+ self.txns = HttpTransactionCache(hs)
+
+
+class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_paths("OPTIONS",
- client_path_patterns("/rooms(?:/.*)?$"),
+ client_patterns("/rooms(?:/.*)?$", v1=True),
self.on_OPTIONS)
# define CORS for /createRoom[/txnid]
http_server.register_paths("OPTIONS",
- client_path_patterns("/createRoom(?:/.*)?$"),
+ client_patterns("/createRoom(?:/.*)?$", v1=True),
self.on_OPTIONS)
def on_PUT(self, request, txn_id):
@@ -85,13 +93,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events
-class RoomStateEventRestServlet(ClientV1RestServlet):
+class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
# /room/$roomid/state/$eventtype
@@ -102,16 +111,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_paths("GET",
- client_path_patterns(state_key),
+ client_patterns(state_key, v1=True),
self.on_GET)
http_server.register_paths("PUT",
- client_path_patterns(state_key),
+ client_patterns(state_key, v1=True),
self.on_PUT)
http_server.register_paths("GET",
- client_path_patterns(no_state_key),
+ client_patterns(no_state_key, v1=True),
self.on_GET_no_state_key)
http_server.register_paths("PUT",
- client_path_patterns(no_state_key),
+ client_patterns(no_state_key, v1=True),
self.on_PUT_no_state_key)
def on_GET_no_state_key(self, request, room_id, event_type):
@@ -185,11 +194,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback
-class RoomSendEventRestServlet(ClientV1RestServlet):
+class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -229,10 +239,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins
-class JoinRoomAliasServlet(ClientV1RestServlet):
+class JoinRoomAliasServlet(TransactionRestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@@ -291,8 +302,13 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
# TODO: Needs unit testing
-class PublicRoomListRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/publicRooms$")
+class PublicRoomListRestServlet(TransactionRestServlet):
+ PATTERNS = client_patterns("/publicRooms$", v1=True)
+
+ def __init__(self, hs):
+ super(PublicRoomListRestServlet, self).__init__(hs)
+ self.hs = hs
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -382,12 +398,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
-class RoomMemberListRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
+class RoomMemberListRestServlet(RestServlet):
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs):
- super(RoomMemberListRestServlet, self).__init__(hs)
+ super(RoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -436,12 +453,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
# deprecated in favour of /members?membership=join?
# except it does custom AS logic and has a simpler return format
-class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$")
+class JoinedRoomMemberListRestServlet(RestServlet):
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs):
- super(JoinedRoomMemberListRestServlet, self).__init__(hs)
+ super(JoinedRoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -457,12 +475,13 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
# TODO: Needs better unit testing
-class RoomMessageListRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
+class RoomMessageListRestServlet(RestServlet):
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs):
- super(RoomMessageListRestServlet, self).__init__(hs)
+ super(RoomMessageListRestServlet, self).__init__()
self.pagination_handler = hs.get_pagination_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -475,6 +494,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
event_filter = Filter(json.loads(filter_json))
+ if event_filter.filter_json.get("event_format", "client") == "federation":
+ as_client_event = False
else:
event_filter = None
msgs = yield self.pagination_handler.get_messages(
@@ -489,12 +510,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
-class RoomStateRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
+class RoomStateRestServlet(RestServlet):
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs):
- super(RoomStateRestServlet, self).__init__(hs)
+ super(RoomStateRestServlet, self).__init__()
self.message_handler = hs.get_message_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -509,12 +531,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
-class RoomInitialSyncRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
+class RoomInitialSyncRestServlet(RestServlet):
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs):
- super(RoomInitialSyncRestServlet, self).__init__(hs)
+ super(RoomInitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -528,16 +551,17 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content))
-class RoomEventServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
+class RoomEventServlet(RestServlet):
+ PATTERNS = client_patterns(
+ "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
)
def __init__(self, hs):
- super(RoomEventServlet, self).__init__(hs)
+ super(RoomEventServlet, self).__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@@ -552,16 +576,17 @@ class RoomEventServlet(ClientV1RestServlet):
defer.returnValue((404, "Event not found."))
-class RoomEventContextServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
+class RoomEventContextServlet(RestServlet):
+ PATTERNS = client_patterns(
+ "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
)
def __init__(self, hs):
- super(RoomEventContextServlet, self).__init__(hs)
+ super(RoomEventContextServlet, self).__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@@ -607,10 +632,11 @@ class RoomEventContextServlet(ClientV1RestServlet):
defer.returnValue((200, results))
-class RoomForgetRestServlet(ClientV1RestServlet):
+class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@@ -637,11 +663,12 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
-class RoomMembershipRestServlet(ClientV1RestServlet):
+class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
@@ -720,11 +747,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
)
-class RoomRedactEventRestServlet(ClientV1RestServlet):
+class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -755,15 +783,16 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
)
-class RoomTypingRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
+class RoomTypingRestServlet(RestServlet):
+ PATTERNS = client_patterns(
+ "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
)
def __init__(self, hs):
- super(RoomTypingRestServlet, self).__init__(hs)
+ super(RoomTypingRestServlet, self).__init__()
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
@@ -796,14 +825,13 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
-class SearchRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/search$"
- )
+class SearchRestServlet(RestServlet):
+ PATTERNS = client_patterns("/search$", v1=True)
def __init__(self, hs):
- super(SearchRestServlet, self).__init__(hs)
+ super(SearchRestServlet, self).__init__()
self.handlers = hs.get_handlers()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -821,12 +849,13 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results))
-class JoinedRoomsRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/joined_rooms$")
+class JoinedRoomsRestServlet(RestServlet):
+ PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs):
- super(JoinedRoomsRestServlet, self).__init__(hs)
+ super(JoinedRoomsRestServlet, self).__init__()
self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -851,18 +880,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
"""
http_server.register_paths(
"POST",
- client_path_patterns(regex_string + "$"),
+ client_patterns(regex_string + "$", v1=True),
servlet.on_POST
)
http_server.register_paths(
"PUT",
- client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"),
+ client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_PUT
)
if with_get:
http_server.register_paths(
"GET",
- client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"),
+ client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_GET
)
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 53da905eea..6381049210 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -19,11 +19,17 @@ import hmac
from twisted.internet import defer
-from .base import ClientV1RestServlet, client_path_patterns
+from synapse.http.servlet import RestServlet
+from synapse.rest.client.v2_alpha._base import client_patterns
-class VoipRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/voip/turnServer$")
+class VoipRestServlet(RestServlet):
+ PATTERNS = client_patterns("/voip/turnServer$", v1=True)
+
+ def __init__(self, hs):
+ super(VoipRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 24ac26bf03..5236d5d566 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -26,8 +26,7 @@ from synapse.api.urls import CLIENT_API_PREFIX
logger = logging.getLogger(__name__)
-def client_v2_patterns(path_regex, releases=(0,),
- unstable=True):
+def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
"""Creates a regex compiled client path with the correct client path
prefix.
@@ -41,6 +40,9 @@ def client_v2_patterns(path_regex, releases=(0,),
if unstable:
unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
+ if v1:
+ v1_prefix = CLIENT_API_PREFIX + "/api/v1"
+ patterns.append(re.compile("^" + v1_prefix + path_regex))
for release in releases:
new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index ee069179f0..ca35dc3c83 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -30,13 +30,13 @@ from synapse.http.servlet import (
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed
-from ._base import client_v2_patterns, interactive_auth_handler
+from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
+ PATTERNS = client_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
super(EmailPasswordRequestTokenRestServlet, self).__init__()
@@ -70,7 +70,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
class MsisdnPasswordRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
+ PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
def __init__(self, hs):
super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
@@ -108,7 +108,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
class PasswordRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/password$")
+ PATTERNS = client_patterns("/account/password$")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
@@ -180,7 +180,7 @@ class PasswordRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/deactivate$")
+ PATTERNS = client_patterns("/account/deactivate$")
def __init__(self, hs):
super(DeactivateAccountRestServlet, self).__init__()
@@ -228,7 +228,7 @@ class DeactivateAccountRestServlet(RestServlet):
class EmailThreepidRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
+ PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
@@ -263,7 +263,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
+ PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
def __init__(self, hs):
self.hs = hs
@@ -300,7 +300,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
class ThreepidRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/3pid$")
+ PATTERNS = client_patterns("/account/3pid$")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
@@ -364,7 +364,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/3pid/delete$")
+ PATTERNS = client_patterns("/account/3pid/delete$")
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
@@ -401,7 +401,7 @@ class ThreepidDeleteRestServlet(RestServlet):
class WhoamiRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/whoami$")
+ PATTERNS = client_patterns("/account/whoami$")
def __init__(self, hs):
super(WhoamiRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index f171b8d626..574a6298ce 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -30,7 +30,7 @@ class AccountDataServlet(RestServlet):
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
@@ -79,7 +79,7 @@ class RoomAccountDataServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)"
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index fc8dbeb617..55c4ed5660 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account_validity/renew$")
+ PATTERNS = client_patterns("/account_validity/renew$")
SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
def __init__(self, hs):
@@ -60,7 +60,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account_validity/send_mail$")
+ PATTERNS = client_patterns("/account_validity/send_mail$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 4c380ab84d..8dfe5cba02 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -23,7 +23,7 @@ from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_string
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -122,7 +122,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth.
"""
- PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
+ PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
super(AuthRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index a868d06098..fc7e2f4dd5 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -16,10 +16,10 @@ import logging
from twisted.internet import defer
-from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
class CapabilitiesRestServlet(RestServlet):
"""End point to expose the capabilities of the server."""
- PATTERNS = client_v2_patterns("/capabilities$")
+ PATTERNS = client_patterns("/capabilities$")
def __init__(self, hs):
"""
@@ -36,6 +36,7 @@ class CapabilitiesRestServlet(RestServlet):
"""
super(CapabilitiesRestServlet, self).__init__()
self.hs = hs
+ self.config = hs.config
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -48,7 +49,7 @@ class CapabilitiesRestServlet(RestServlet):
response = {
"capabilities": {
"m.room_versions": {
- "default": DEFAULT_ROOM_VERSION.identifier,
+ "default": self.config.default_room_version.identifier,
"available": {
v.identifier: v.disposition
for v in KNOWN_ROOM_VERSIONS.values()
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 5a5be7c390..78665304a5 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
-from ._base import client_v2_patterns, interactive_auth_handler
+from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/devices$")
+ PATTERNS = client_patterns("/devices$")
def __init__(self, hs):
"""
@@ -56,7 +56,7 @@ class DeleteDevicesRestServlet(RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
- PATTERNS = client_v2_patterns("/delete_devices")
+ PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__()
@@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$")
+ PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index ae86728879..65db48c3cc 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
-from ._base import client_v2_patterns, set_timeline_upper_limit
+from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs):
super(GetFilterRestServlet, self).__init__()
@@ -63,7 +63,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter")
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 21e02c07c0..d082385ec7 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class GroupServlet(RestServlet):
"""Get the group profile
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs):
super(GroupServlet, self).__init__()
@@ -65,7 +65,7 @@ class GroupServlet(RestServlet):
class GroupSummaryServlet(RestServlet):
"""Get the full group summary
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs):
super(GroupSummaryServlet, self).__init__()
@@ -93,7 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
- /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?"
"/rooms/(?P<room_id>[^/]*)$"
@@ -137,7 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
class GroupCategoryServlet(RestServlet):
"""Get/add/update/delete a group category
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
@@ -189,7 +189,7 @@ class GroupCategoryServlet(RestServlet):
class GroupCategoriesServlet(RestServlet):
"""Get all group categories
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/$"
)
@@ -214,7 +214,7 @@ class GroupCategoriesServlet(RestServlet):
class GroupRoleServlet(RestServlet):
"""Get/add/update/delete a group role
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
)
@@ -266,7 +266,7 @@ class GroupRoleServlet(RestServlet):
class GroupRolesServlet(RestServlet):
"""Get all group roles
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/$"
)
@@ -295,7 +295,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
- /groups/:group/summary/users/:room_id
- /groups/:group/summary/roles/:role/users/:user_id
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?"
"/users/(?P<user_id>[^/]*)$"
@@ -339,7 +339,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
class GroupRoomServlet(RestServlet):
"""Get all rooms in a group
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs):
super(GroupRoomServlet, self).__init__()
@@ -360,7 +360,7 @@ class GroupRoomServlet(RestServlet):
class GroupUsersServlet(RestServlet):
"""Get all users in a group
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs):
super(GroupUsersServlet, self).__init__()
@@ -381,7 +381,7 @@ class GroupUsersServlet(RestServlet):
class GroupInvitedUsersServlet(RestServlet):
"""Get users invited to a group
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs):
super(GroupInvitedUsersServlet, self).__init__()
@@ -405,7 +405,7 @@ class GroupInvitedUsersServlet(RestServlet):
class GroupSettingJoinPolicyServlet(RestServlet):
"""Set group join policy
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs):
super(GroupSettingJoinPolicyServlet, self).__init__()
@@ -431,7 +431,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
class GroupCreateServlet(RestServlet):
"""Create a group
"""
- PATTERNS = client_v2_patterns("/create_group$")
+ PATTERNS = client_patterns("/create_group$")
def __init__(self, hs):
super(GroupCreateServlet, self).__init__()
@@ -462,7 +462,7 @@ class GroupCreateServlet(RestServlet):
class GroupAdminRoomsServlet(RestServlet):
"""Add a room to the group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
@@ -499,7 +499,7 @@ class GroupAdminRoomsServlet(RestServlet):
class GroupAdminRoomsConfigServlet(RestServlet):
"""Update the config of a room in a group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$"
)
@@ -526,7 +526,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
class GroupAdminUsersInviteServlet(RestServlet):
"""Invite a user to the group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
@@ -555,7 +555,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
class GroupAdminUsersKickServlet(RestServlet):
"""Kick a user from the group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
@@ -581,7 +581,7 @@ class GroupAdminUsersKickServlet(RestServlet):
class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/leave$"
)
@@ -607,7 +607,7 @@ class GroupSelfLeaveServlet(RestServlet):
class GroupSelfJoinServlet(RestServlet):
"""Attempt to join a group, or knock
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/join$"
)
@@ -633,7 +633,7 @@ class GroupSelfJoinServlet(RestServlet):
class GroupSelfAcceptInviteServlet(RestServlet):
"""Accept a group invite
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/accept_invite$"
)
@@ -659,7 +659,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
class GroupSelfUpdatePublicityServlet(RestServlet):
"""Update whether we publicise a users membership of a group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/update_publicity$"
)
@@ -686,7 +686,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
class PublicisedGroupsForUserServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/publicised_groups/(?P<user_id>[^/]*)$"
)
@@ -711,7 +711,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
class PublicisedGroupsForUsersServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/publicised_groups$"
)
@@ -739,7 +739,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
class GroupsForUserServlet(RestServlet):
"""Get all groups the logged in user is joined to
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/joined_groups$"
)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 8486086b51..4cbfbf5631 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
)
from synapse.types import StreamToken
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ class KeyUploadServlet(RestServlet):
},
}
"""
- PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
+ PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
@@ -130,7 +130,7 @@ class KeyQueryServlet(RestServlet):
} } } } } }
"""
- PATTERNS = client_v2_patterns("/keys/query$")
+ PATTERNS = client_patterns("/keys/query$")
def __init__(self, hs):
"""
@@ -159,7 +159,7 @@ class KeyChangesServlet(RestServlet):
200 OK
{ "changed": ["@foo:example.com"] }
"""
- PATTERNS = client_v2_patterns("/keys/changes$")
+ PATTERNS = client_patterns("/keys/changes$")
def __init__(self, hs):
"""
@@ -209,7 +209,7 @@ class OneTimeKeyServlet(RestServlet):
} } } }
"""
- PATTERNS = client_v2_patterns("/keys/claim$")
+ PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index 0a1eb0ae45..53e666989b 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/notifications$")
+ PATTERNS = client_patterns("/notifications$")
def __init__(self, hs):
super(NotificationsServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index 01c90aa2a3..bb927d9f9d 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -22,7 +22,7 @@ from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ class IdTokenServlet(RestServlet):
"expires_in": 3600,
}
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/openid/request_token"
)
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index a6e582a5ae..f4bd0d077f 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -19,13 +19,13 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs):
super(ReadMarkerRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index de370cac45..fa12ac3e4d 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$"
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 042f636135..79c085408b 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -43,7 +43,7 @@ from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.threepids import check_3pid_allowed
-from ._base import client_v2_patterns, interactive_auth_handler
+from ._base import client_patterns, interactive_auth_handler
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
@@ -60,7 +60,7 @@ logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/register/email/requestToken$")
+ PATTERNS = client_patterns("/register/email/requestToken$")
def __init__(self, hs):
"""
@@ -98,7 +98,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
+ PATTERNS = client_patterns("/register/msisdn/requestToken$")
def __init__(self, hs):
"""
@@ -142,7 +142,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/register/available")
+ PATTERNS = client_patterns("/register/available")
def __init__(self, hs):
"""
@@ -182,7 +182,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/register$")
+ PATTERNS = client_patterns("/register$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 41e0a44936..f8f8742bdc 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -34,7 +34,7 @@ from synapse.http.servlet import (
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -66,12 +66,12 @@ class RelationSendServlet(RestServlet):
def register(self, http_server):
http_server.register_paths(
"POST",
- client_v2_patterns(self.PATTERN + "$", releases=()),
+ client_patterns(self.PATTERN + "$", releases=()),
self.on_PUT_or_POST,
)
http_server.register_paths(
"PUT",
- client_v2_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
+ client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
self.on_PUT,
)
@@ -120,7 +120,7 @@ class RelationPaginationServlet(RestServlet):
filtered by relation type and event type.
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(),
@@ -197,7 +197,7 @@ class RelationAggregationPaginationServlet(RestServlet):
}
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
releases=(),
@@ -269,7 +269,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
}
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
"/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
releases=(),
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 95d2a71ec2..10198662a9 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -27,13 +27,13 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
)
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 220a0de30b..87779645f9 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_string,
)
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class RoomKeysServlet(RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
)
@@ -256,7 +256,7 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/room_keys/version$"
)
@@ -314,7 +314,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/room_keys/version(/(?P<version>[^/]+))?$"
)
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index 62b8de71fa..c621a90fba 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -25,7 +25,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -47,7 +47,7 @@ class RoomUpgradeRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
# /rooms/$roomid/upgrade
"/rooms/(?P<room_id>[^/]*)/upgrade$",
)
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index 21e9cef2d0..120a713361 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -21,13 +21,13 @@ from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.transactions import HttpTransactionCache
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
)
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index c701e534e7..148fc6c985 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -32,7 +32,7 @@ from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
-from ._base import client_v2_patterns, set_timeline_upper_limit
+from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__)
@@ -73,7 +73,7 @@ class SyncRestServlet(RestServlet):
}
"""
- PATTERNS = client_v2_patterns("/sync$")
+ PATTERNS = client_patterns("/sync$")
ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
def __init__(self, hs):
@@ -358,6 +358,9 @@ class SyncRestServlet(RestServlet):
def serialize(events):
return self._event_serializer.serialize_events(
events, time_now=time_now,
+ # We don't bundle "live" events, as otherwise clients
+ # will end up double counting annotations.
+ bundle_aggregations=False,
token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index 4fea614e95..ebff7cff45 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ class TagListServlet(RestServlet):
"""
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
)
@@ -54,7 +54,7 @@ class TagServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index b9b5d07677..e7a987466a 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -21,13 +21,13 @@ from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocols")
+ PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
@@ -44,7 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
+ PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
@@ -66,7 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
+ PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
@@ -89,7 +89,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
+ PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 6e76b9e9c2..6c366142e1 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
+from ._base import client_patterns
class TokenRefreshRestServlet(RestServlet):
@@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet):
Exchanges refresh tokens for a pair of an access token and a new refresh
token.
"""
- PATTERNS = client_v2_patterns("/tokenrefresh")
+ PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index 36b02de37f..69e4efc47a 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/user_directory/search$")
+ PATTERNS = client_patterns("/user_directory/search$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 27e7cbf3cc..babbf6a23c 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -39,6 +39,7 @@ class VersionsRestServlet(RestServlet):
"r0.2.0",
"r0.3.0",
"r0.4.0",
+ "r0.5.0",
],
# as per MSC1497:
"unstable_features": {
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index eb8782aa6e..8a730bbc35 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -20,7 +20,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.servlet import parse_integer, parse_json_object_from_request
@@ -89,7 +89,7 @@ class RemoteKey(Resource):
isLeaf = True
def __init__(self, hs):
- self.keyring = hs.get_keyring()
+ self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -215,15 +215,7 @@ class RemoteKey(Resource):
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
- for server_name, key_ids in cache_misses.items():
- try:
- yield self.keyring.get_server_verify_key_v2_direct(
- server_name, key_ids
- )
- except KeyLookupError as e:
- logger.info("Failed to fetch key: %s", e)
- except Exception:
- logger.exception("Failed to get key for %r", server_name)
+ yield self.fetcher.get_keys(cache_misses)
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 5305e9175f..35a750923b 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -56,8 +56,8 @@ class ThumbnailResource(Resource):
def _async_render_GET(self, request):
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
- width = parse_integer(request, "width")
- height = parse_integer(request, "height")
+ width = parse_integer(request, "width", required=True)
+ height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale")
m_type = parse_string(request, "type", "image/png")
diff --git a/synapse/server.py b/synapse/server.py
index 80d40b9272..9229a68a8d 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -72,6 +72,7 @@ from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.set_password import SetPasswordHandler
+from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.user_directory import UserDirectoryHandler
@@ -139,6 +140,7 @@ class HomeServer(object):
'acme_handler',
'auth_handler',
'device_handler',
+ 'stats_handler',
'e2e_keys_handler',
'e2e_room_keys_handler',
'event_handler',
@@ -191,6 +193,7 @@ class HomeServer(object):
REQUIRED_ON_MASTER_STARTUP = [
"user_directory_handler",
+ "stats_handler"
]
# This is overridden in derived application classes
@@ -474,6 +477,9 @@ class HomeServer(object):
def build_secrets(self):
return Secrets()
+ def build_stats_handler(self):
+ return StatsHandler(self)
+
def build_spam_checker(self):
return SpamChecker(self)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7522d3fd57..71316f7d09 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -36,6 +36,7 @@ from .engines import PostgresEngine
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events import EventsStore
+from .events_bg_updates import EventsBackgroundUpdatesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
@@ -55,6 +56,7 @@ from .roommember import RoomMemberStore
from .search import SearchStore
from .signatures import SignatureStore
from .state import StateStore
+from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
@@ -65,6 +67,7 @@ logger = logging.getLogger(__name__)
class DataStore(
+ EventsBackgroundUpdatesStore,
RoomMemberStore,
RoomStore,
RegistrationStore,
@@ -100,6 +103,7 @@ class DataStore(
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
+ StatsStore,
RelationsStore,
):
def __init__(self, db_conn, hs):
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 983ce026e1..52891bb9eb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +16,7 @@
# limitations under the License.
import itertools
import logging
+import random
import sys
import threading
import time
@@ -227,6 +230,8 @@ class SQLBaseStore(object):
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+ self._account_validity = self.hs.config.account_validity
+
# We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it
# unsafe to use native upserts.
@@ -243,6 +248,16 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
+ self.rand = random.SystemRandom()
+
+ if self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "account_validity_set_expiration_dates",
+ self._set_expiration_date_when_missing,
+ )
+
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
@@ -275,6 +290,67 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
+ @defer.inlineCallbacks
+ def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL;"
+ )
+ txn.execute(sql, [])
+
+ res = self.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn,
+ user["name"],
+ use_delta=True,
+ )
+
+ yield self.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d ; now + period] range, d being a
+ delta equal to 10% of the validity period.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
+
+ self._simple_insert_txn(
+ txn,
+ "account_validity",
+ values={
+ "user_id": user_id,
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": False,
+ },
+ )
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -1203,7 +1279,8 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues),
)
- return txn.execute(sql, list(keyvalues.values()))
+ txn.execute(sql, list(keyvalues.values()))
+ return txn.rowcount
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
@@ -1222,9 +1299,12 @@ class SQLBaseStore(object):
column : column name to test for inclusion against `iterable`
iterable : list
keyvalues : dict of column names and values to select the rows with
+
+ Returns:
+ int: Number rows deleted
"""
if not iterable:
- return
+ return 0
sql = "DELETE FROM %s" % table
@@ -1239,7 +1319,9 @@ class SQLBaseStore(object):
if clauses:
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
- return txn.execute(sql, values)
+ txn.execute(sql, values)
+
+ return txn.rowcount
def _get_cache_dict(
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 881d6d0126..f9162be9b9 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -219,41 +220,11 @@ class EventsStore(
EventsWorkerStore,
BackgroundUpdateStore,
):
- EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
- EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, db_conn, hs):
super(EventsStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
- self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
- )
- self.register_background_update_handler(
- self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
- self._background_reindex_fields_sender,
- )
-
- self.register_background_index_update(
- "event_contains_url_index",
- index_name="event_contains_url_index",
- table="events",
- columns=["room_id", "topological_ordering", "stream_ordering"],
- where_clause="contains_url = true AND outlier = false",
- )
-
- # an event_id index on event_search is useful for the purge_history
- # api. Plus it means we get to enforce some integrity with a UNIQUE
- # clause
- self.register_background_index_update(
- "event_search_event_id_idx",
- index_name="event_search_event_id_idx",
- table="event_search",
- columns=["event_id"],
- unique=True,
- psql_only=True,
- )
self._event_persist_queue = _EventPeristenceQueue()
-
self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
@@ -554,10 +525,18 @@ class EventsStore(
e_id for event in new_events for e_id in event.prev_event_ids()
)
- # Finally, remove any events which are prev_events of any existing events.
+ # Remove any events which are prev_events of any existing events.
existing_prevs = yield self._get_events_which_are_prevs(result)
result.difference_update(existing_prevs)
+ # Finally handle the case where the new events have soft-failed prev
+ # events. If they do we need to remove them and their prev events,
+ # otherwise we end up with dangling extremities.
+ existing_prevs = yield self._get_prevs_before_rejected(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+ result.difference_update(existing_prevs)
+
defer.returnValue(result)
@defer.inlineCallbacks
@@ -573,12 +552,13 @@ class EventsStore(
"""
results = []
- def _get_events(txn, batch):
+ def _get_events_which_are_prevs_txn(txn, batch):
sql = """
- SELECT prev_event_id
+ SELECT prev_event_id, internal_metadata
FROM event_edges
INNER JOIN events USING (event_id)
LEFT JOIN rejections USING (event_id)
+ LEFT JOIN event_json USING (event_id)
WHERE
prev_event_id IN (%s)
AND NOT events.outlier
@@ -588,14 +568,86 @@ class EventsStore(
)
txn.execute(sql, batch)
- results.extend(r[0] for r in txn)
+ results.extend(
+ r[0]
+ for r in txn
+ if not json.loads(r[1]).get("soft_failed")
+ )
for chunk in batch_iter(event_ids, 100):
- yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk)
+ yield self.runInteraction(
+ "_get_events_which_are_prevs",
+ _get_events_which_are_prevs_txn,
+ chunk,
+ )
defer.returnValue(results)
@defer.inlineCallbacks
+ def _get_prevs_before_rejected(self, event_ids):
+ """Get soft-failed ancestors to remove from the extremities.
+
+ Given a set of events, find all those that have been soft-failed or
+ rejected. Returns those soft failed/rejected events and their prev
+ events (whether soft-failed/rejected or not), and recurses up the
+ prev-event graph until it finds no more soft-failed/rejected events.
+
+ This is used to find extremities that are ancestors of new events, but
+ are separated by soft failed events.
+
+ Args:
+ event_ids (Iterable[str]): Events to find prev events for. Note
+ that these must have already been persisted.
+
+ Returns:
+ Deferred[set[str]]
+ """
+
+ # The set of event_ids to return. This includes all soft-failed events
+ # and their prev events.
+ existing_prevs = set()
+
+ def _get_prevs_before_rejected_txn(txn, batch):
+ to_recursively_check = batch
+
+ while to_recursively_check:
+ sql = """
+ SELECT
+ event_id, prev_event_id, internal_metadata,
+ rejections.event_id IS NOT NULL
+ FROM event_edges
+ INNER JOIN events USING (event_id)
+ LEFT JOIN rejections USING (event_id)
+ LEFT JOIN event_json USING (event_id)
+ WHERE
+ event_id IN (%s)
+ AND NOT events.outlier
+ """ % (
+ ",".join("?" for _ in to_recursively_check),
+ )
+
+ txn.execute(sql, to_recursively_check)
+ to_recursively_check = []
+
+ for event_id, prev_event_id, metadata, rejected in txn:
+ if prev_event_id in existing_prevs:
+ continue
+
+ soft_failed = json.loads(metadata).get("soft_failed")
+ if soft_failed or rejected:
+ to_recursively_check.append(prev_event_id)
+ existing_prevs.add(prev_event_id)
+
+ for chunk in batch_iter(event_ids, 100):
+ yield self.runInteraction(
+ "_get_prevs_before_rejected",
+ _get_prevs_before_rejected_txn,
+ chunk,
+ )
+
+ defer.returnValue(existing_prevs)
+
+ @defer.inlineCallbacks
def _get_new_state_after_events(
self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
):
@@ -1498,153 +1550,6 @@ class EventsStore(
ret = yield self.runInteraction("count_daily_active_rooms", _count)
defer.returnValue(ret)
- @defer.inlineCallbacks
- def _background_reindex_fields_sender(self, progress, batch_size):
- target_min_stream_id = progress["target_min_stream_id_inclusive"]
- max_stream_id = progress["max_stream_id_exclusive"]
- rows_inserted = progress.get("rows_inserted", 0)
-
- INSERT_CLUMP_SIZE = 1000
-
- def reindex_txn(txn):
- sql = (
- "SELECT stream_ordering, event_id, json FROM events"
- " INNER JOIN event_json USING (event_id)"
- " WHERE ? <= stream_ordering AND stream_ordering < ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- )
-
- txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
-
- rows = txn.fetchall()
- if not rows:
- return 0
-
- min_stream_id = rows[-1][0]
-
- update_rows = []
- for row in rows:
- try:
- event_id = row[1]
- event_json = json.loads(row[2])
- sender = event_json["sender"]
- content = event_json["content"]
-
- contains_url = "url" in content
- if contains_url:
- contains_url &= isinstance(content["url"], text_type)
- except (KeyError, AttributeError):
- # If the event is missing a necessary field then
- # skip over it.
- continue
-
- update_rows.append((sender, contains_url, event_id))
-
- sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
-
- for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
- clump = update_rows[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
-
- progress = {
- "target_min_stream_id_inclusive": target_min_stream_id,
- "max_stream_id_exclusive": min_stream_id,
- "rows_inserted": rows_inserted + len(rows),
- }
-
- self._background_update_progress_txn(
- txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
- )
-
- return len(rows)
-
- result = yield self.runInteraction(
- self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
- )
-
- if not result:
- yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def _background_reindex_origin_server_ts(self, progress, batch_size):
- target_min_stream_id = progress["target_min_stream_id_inclusive"]
- max_stream_id = progress["max_stream_id_exclusive"]
- rows_inserted = progress.get("rows_inserted", 0)
-
- INSERT_CLUMP_SIZE = 1000
-
- def reindex_search_txn(txn):
- sql = (
- "SELECT stream_ordering, event_id FROM events"
- " WHERE ? <= stream_ordering AND stream_ordering < ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- )
-
- txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
-
- rows = txn.fetchall()
- if not rows:
- return 0
-
- min_stream_id = rows[-1][0]
- event_ids = [row[1] for row in rows]
-
- rows_to_update = []
-
- chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
- for chunk in chunks:
- ev_rows = self._simple_select_many_txn(
- txn,
- table="event_json",
- column="event_id",
- iterable=chunk,
- retcols=["event_id", "json"],
- keyvalues={},
- )
-
- for row in ev_rows:
- event_id = row["event_id"]
- event_json = json.loads(row["json"])
- try:
- origin_server_ts = event_json["origin_server_ts"]
- except (KeyError, AttributeError):
- # If the event is missing a necessary field then
- # skip over it.
- continue
-
- rows_to_update.append((origin_server_ts, event_id))
-
- sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
-
- for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
- clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
-
- progress = {
- "target_min_stream_id_inclusive": target_min_stream_id,
- "max_stream_id_exclusive": min_stream_id,
- "rows_inserted": rows_inserted + len(rows_to_update),
- }
-
- self._background_update_progress_txn(
- txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
- )
-
- return len(rows_to_update)
-
- result = yield self.runInteraction(
- self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
- )
-
- if not result:
- yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
-
- defer.returnValue(result)
-
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py
new file mode 100644
index 0000000000..75c1935bf3
--- /dev/null
+++ b/synapse/storage/events_bg_updates.py
@@ -0,0 +1,401 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import text_type
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.storage.background_updates import BackgroundUpdateStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
+
+ EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
+ EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
+ DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
+
+ def __init__(self, db_conn, hs):
+ super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
+
+ self.register_background_update_handler(
+ self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
+ )
+ self.register_background_update_handler(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
+ self._background_reindex_fields_sender,
+ )
+
+ self.register_background_index_update(
+ "event_contains_url_index",
+ index_name="event_contains_url_index",
+ table="events",
+ columns=["room_id", "topological_ordering", "stream_ordering"],
+ where_clause="contains_url = true AND outlier = false",
+ )
+
+ # an event_id index on event_search is useful for the purge_history
+ # api. Plus it means we get to enforce some integrity with a UNIQUE
+ # clause
+ self.register_background_index_update(
+ "event_search_event_id_idx",
+ index_name="event_search_event_id_idx",
+ table="event_search",
+ columns=["event_id"],
+ unique=True,
+ psql_only=True,
+ )
+
+ self.register_background_update_handler(
+ self.DELETE_SOFT_FAILED_EXTREMITIES,
+ self._cleanup_extremities_bg_update,
+ )
+
+ @defer.inlineCallbacks
+ def _background_reindex_fields_sender(self, progress, batch_size):
+ target_min_stream_id = progress["target_min_stream_id_inclusive"]
+ max_stream_id = progress["max_stream_id_exclusive"]
+ rows_inserted = progress.get("rows_inserted", 0)
+
+ INSERT_CLUMP_SIZE = 1000
+
+ def reindex_txn(txn):
+ sql = (
+ "SELECT stream_ordering, event_id, json FROM events"
+ " INNER JOIN event_json USING (event_id)"
+ " WHERE ? <= stream_ordering AND stream_ordering < ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT ?"
+ )
+
+ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ min_stream_id = rows[-1][0]
+
+ update_rows = []
+ for row in rows:
+ try:
+ event_id = row[1]
+ event_json = json.loads(row[2])
+ sender = event_json["sender"]
+ content = event_json["content"]
+
+ contains_url = "url" in content
+ if contains_url:
+ contains_url &= isinstance(content["url"], text_type)
+ except (KeyError, AttributeError):
+ # If the event is missing a necessary field then
+ # skip over it.
+ continue
+
+ update_rows.append((sender, contains_url, event_id))
+
+ sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
+
+ for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
+ clump = update_rows[index : index + INSERT_CLUMP_SIZE]
+ txn.executemany(sql, clump)
+
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ "rows_inserted": rows_inserted + len(rows),
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
+ )
+
+ return len(rows)
+
+ result = yield self.runInteraction(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _background_reindex_origin_server_ts(self, progress, batch_size):
+ target_min_stream_id = progress["target_min_stream_id_inclusive"]
+ max_stream_id = progress["max_stream_id_exclusive"]
+ rows_inserted = progress.get("rows_inserted", 0)
+
+ INSERT_CLUMP_SIZE = 1000
+
+ def reindex_search_txn(txn):
+ sql = (
+ "SELECT stream_ordering, event_id FROM events"
+ " WHERE ? <= stream_ordering AND stream_ordering < ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT ?"
+ )
+
+ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ min_stream_id = rows[-1][0]
+ event_ids = [row[1] for row in rows]
+
+ rows_to_update = []
+
+ chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
+ for chunk in chunks:
+ ev_rows = self._simple_select_many_txn(
+ txn,
+ table="event_json",
+ column="event_id",
+ iterable=chunk,
+ retcols=["event_id", "json"],
+ keyvalues={},
+ )
+
+ for row in ev_rows:
+ event_id = row["event_id"]
+ event_json = json.loads(row["json"])
+ try:
+ origin_server_ts = event_json["origin_server_ts"]
+ except (KeyError, AttributeError):
+ # If the event is missing a necessary field then
+ # skip over it.
+ continue
+
+ rows_to_update.append((origin_server_ts, event_id))
+
+ sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
+
+ for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
+ clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
+ txn.executemany(sql, clump)
+
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ "rows_inserted": rows_inserted + len(rows_to_update),
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
+ )
+
+ return len(rows_to_update)
+
+ result = yield self.runInteraction(
+ self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _cleanup_extremities_bg_update(self, progress, batch_size):
+ """Background update to clean out extremities that should have been
+ deleted previously.
+
+ Mainly used to deal with the aftermath of #5269.
+ """
+
+ # This works by first copying all existing forward extremities into the
+ # `_extremities_to_check` table at start up, and then checking each
+ # event in that table whether we have any descendants that are not
+ # soft-failed/rejected. If that is the case then we delete that event
+ # from the forward extremities table.
+ #
+ # For efficiency, we do this in batches by recursively pulling out all
+ # descendants of a batch until we find the non soft-failed/rejected
+ # events, i.e. the set of descendants whose chain of prev events back
+ # to the batch of extremities are all soft-failed or rejected.
+ # Typically, we won't find any such events as extremities will rarely
+ # have any descendants, but if they do then we should delete those
+ # extremities.
+
+ def _cleanup_extremities_bg_update_txn(txn):
+ # The set of extremity event IDs that we're checking this round
+ original_set = set()
+
+ # A dict[str, set[str]] of event ID to their prev events.
+ graph = {}
+
+ # The set of descendants of the original set that are not rejected
+ # nor soft-failed. Ancestors of these events should be removed
+ # from the forward extremities table.
+ non_rejected_leaves = set()
+
+ # Set of event IDs that have been soft failed, and for which we
+ # should check if they have descendants which haven't been soft
+ # failed.
+ soft_failed_events_to_lookup = set()
+
+ # First, we get `batch_size` events from the table, pulling out
+ # their successor events, if any, and the successor events'
+ # rejection status.
+ txn.execute(
+ """SELECT prev_event_id, event_id, internal_metadata,
+ rejections.event_id IS NOT NULL, events.outlier
+ FROM (
+ SELECT event_id AS prev_event_id
+ FROM _extremities_to_check
+ LIMIT ?
+ ) AS f
+ LEFT JOIN event_edges USING (prev_event_id)
+ LEFT JOIN events USING (event_id)
+ LEFT JOIN event_json USING (event_id)
+ LEFT JOIN rejections USING (event_id)
+ """, (batch_size,)
+ )
+
+ for prev_event_id, event_id, metadata, rejected, outlier in txn:
+ original_set.add(prev_event_id)
+
+ if not event_id or outlier:
+ # Common case where the forward extremity doesn't have any
+ # descendants.
+ continue
+
+ graph.setdefault(event_id, set()).add(prev_event_id)
+
+ soft_failed = False
+ if metadata:
+ soft_failed = json.loads(metadata).get("soft_failed")
+
+ if soft_failed or rejected:
+ soft_failed_events_to_lookup.add(event_id)
+ else:
+ non_rejected_leaves.add(event_id)
+
+ # Now we recursively check all the soft-failed descendants we
+ # found above in the same way, until we have nothing left to
+ # check.
+ while soft_failed_events_to_lookup:
+ # We only want to do 100 at a time, so we split given list
+ # into two.
+ batch = list(soft_failed_events_to_lookup)
+ to_check, to_defer = batch[:100], batch[100:]
+ soft_failed_events_to_lookup = set(to_defer)
+
+ sql = """SELECT prev_event_id, event_id, internal_metadata,
+ rejections.event_id IS NOT NULL
+ FROM event_edges
+ INNER JOIN events USING (event_id)
+ INNER JOIN event_json USING (event_id)
+ LEFT JOIN rejections USING (event_id)
+ WHERE
+ prev_event_id IN (%s)
+ AND NOT events.outlier
+ """ % (
+ ",".join("?" for _ in to_check),
+ )
+ txn.execute(sql, to_check)
+
+ for prev_event_id, event_id, metadata, rejected in txn:
+ if event_id in graph:
+ # Already handled this event previously, but we still
+ # want to record the edge.
+ graph[event_id].add(prev_event_id)
+ continue
+
+ graph[event_id] = {prev_event_id}
+
+ soft_failed = json.loads(metadata).get("soft_failed")
+ if soft_failed or rejected:
+ soft_failed_events_to_lookup.add(event_id)
+ else:
+ non_rejected_leaves.add(event_id)
+
+ # We have a set of non-soft-failed descendants, so we recurse up
+ # the graph to find all ancestors and add them to the set of event
+ # IDs that we can delete from forward extremities table.
+ to_delete = set()
+ while non_rejected_leaves:
+ event_id = non_rejected_leaves.pop()
+ prev_event_ids = graph.get(event_id, set())
+ non_rejected_leaves.update(prev_event_ids)
+ to_delete.update(prev_event_ids)
+
+ to_delete.intersection_update(original_set)
+
+ deleted = self._simple_delete_many_txn(
+ txn=txn,
+ table="event_forward_extremities",
+ column="event_id",
+ iterable=to_delete,
+ keyvalues={},
+ )
+
+ logger.info(
+ "Deleted %d forward extremities of %d checked, to clean up #5269",
+ deleted,
+ len(original_set),
+ )
+
+ if deleted:
+ # We now need to invalidate the caches of these rooms
+ rows = self._simple_select_many_txn(
+ txn,
+ table="events",
+ column="event_id",
+ iterable=to_delete,
+ keyvalues={},
+ retcols=("room_id",)
+ )
+ room_ids = set(row["room_id"] for row in rows)
+ for room_id in room_ids:
+ txn.call_after(
+ self.get_latest_event_ids_in_room.invalidate,
+ (room_id,)
+ )
+
+ self._simple_delete_many_txn(
+ txn=txn,
+ table="_extremities_to_check",
+ column="event_id",
+ iterable=original_set,
+ keyvalues={},
+ )
+
+ return len(original_set)
+
+ num_handled = yield self.runInteraction(
+ "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn,
+ )
+
+ if not num_handled:
+ yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES)
+
+ def _drop_table_txn(txn):
+ txn.execute("DROP TABLE _extremities_to_check")
+
+ yield self.runInteraction(
+ "_cleanup_extremities_bg_update_drop_table",
+ _drop_table_txn,
+ )
+
+ defer.returnValue(num_handled)
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index adc6cf26b5..cc7df5cf14 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import division
+
import itertools
import logging
from collections import namedtuple
@@ -76,6 +78,43 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_received_ts",
)
+ def get_received_ts_by_stream_pos(self, stream_ordering):
+ """Given a stream ordering get an approximate timestamp of when it
+ happened.
+
+ This is done by simply taking the received ts of the first event that
+ has a stream ordering greater than or equal to the given stream pos.
+ If none exists returns the current time, on the assumption that it must
+ have happened recently.
+
+ Args:
+ stream_ordering (int)
+
+ Returns:
+ Deferred[int]
+ """
+
+ def _get_approximate_received_ts_txn(txn):
+ sql = """
+ SELECT received_ts FROM events
+ WHERE stream_ordering >= ?
+ LIMIT 1
+ """
+
+ txn.execute(sql, (stream_ordering,))
+ row = txn.fetchone()
+ if row and row[0]:
+ ts = row[0]
+ else:
+ ts = self.clock.time_msec()
+
+ return ts
+
+ return self.runInteraction(
+ "get_approximate_received_ts",
+ _get_approximate_received_ts_txn,
+ )
+
@defer.inlineCallbacks
def get_event(
self,
@@ -610,4 +649,79 @@ class EventsWorkerStore(SQLBaseStore):
return res
- return self.runInteraction("get_rejection_reasons", f)
+ return self.runInteraction("get_seen_events_with_rejections", f)
+
+ def _get_total_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_total_state_event_counts.
+ """
+ # We join against the events table as that has an index on room_id
+ sql = """
+ SELECT COUNT(*) FROM state_events
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id=?
+ """
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_total_state_event_counts(self, room_id):
+ """
+ Gets the total number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.runInteraction(
+ "get_total_state_event_counts",
+ self._get_total_state_event_counts_txn, room_id
+ )
+
+ def _get_current_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_current_state_event_counts.
+ """
+ sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?"
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_current_state_event_counts(self, room_id):
+ """
+ Gets the current number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.runInteraction(
+ "get_current_state_event_counts",
+ self._get_current_state_event_counts_txn, room_id
+ )
+
+ @defer.inlineCallbacks
+ def get_room_complexity(self, room_id):
+ """
+ Get a rough approximation of the complexity of the room. This is used by
+ remote servers to decide whether they wish to join the room or not.
+ Higher complexity value indicates that being in the room will consume
+ more resources.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[dict[str:int]] of complexity version to complexity.
+ """
+ state_events = yield self.get_current_state_event_counts(room_id)
+
+ # Call this one "v1", so we can introduce new ones as we want to develop
+ # it.
+ complexity_v1 = round(state_events / 500, 2)
+
+ defer.returnValue({"v1": complexity_v1})
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 7036541792..5300720dbb 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -19,6 +19,7 @@ import logging
import six
+import attr
from signedjson.key import decode_verify_key_bytes
from synapse.util import batch_iter
@@ -36,6 +37,12 @@ else:
db_binary_type = memoryview
+@attr.s(slots=True, frozen=True)
+class FetchKeyResult(object):
+ verify_key = attr.ib() # VerifyKey: the key itself
+ valid_until_ts = attr.ib() # int: how long we can use this key for
+
+
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys
"""
@@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore):
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
- Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
- map from (server_name, key_id) -> VerifyKey, or None if the key is
+ Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
+ map from (server_name, key_id) -> FetchKeyResult, or None if the key is
unknown
"""
keys = {}
@@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore):
# batch_iter always returns tuples so it's safe to do len(batch)
sql = (
- "SELECT server_name, key_id, verify_key FROM server_signature_keys "
- "WHERE 1=0"
+ "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
+ "FROM server_signature_keys WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch)
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
for row in txn:
- server_name, key_id, key_bytes = row
- keys[(server_name, key_id)] = decode_verify_key_bytes(
- key_id, bytes(key_bytes)
+ server_name, key_id, key_bytes, ts_valid_until_ms = row
+ res = FetchKeyResult(
+ verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
+ valid_until_ts=ts_valid_until_ms,
)
+ keys[(server_name, key_id)] = res
def _txn(txn):
for batch in batch_iter(server_name_and_key_ids, 50):
@@ -84,38 +93,53 @@ class KeyStore(SQLBaseStore):
return self.runInteraction("get_server_verify_keys", _txn)
- def store_server_verify_key(
- self, server_name, from_server, time_now_ms, verify_key
- ):
- """Stores a NACL verification key for the given server.
+ def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ """Stores NACL verification keys for remote servers.
Args:
- server_name (str): The name of the server.
- from_server (str): Where the verification key was looked up
- time_now_ms (int): The time now in milliseconds
- verify_key (nacl.signing.VerifyKey): The NACL verify key.
+ from_server (str): Where the verification keys were looked up
+ ts_added_ms (int): The time to record that the key was added
+ verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ keys to be stored. Each entry is a triplet of
+ (server_name, key_id, key).
"""
- key_id = "%s:%s" % (verify_key.alg, verify_key.version)
-
- # XXX fix this to not need a lock (#3819)
- def _txn(txn):
- self._simple_upsert_txn(
- txn,
- table="server_signature_keys",
- keyvalues={"server_name": server_name, "key_id": key_id},
- values={
- "from_server": from_server,
- "ts_added_ms": time_now_ms,
- "verify_key": db_binary_type(verify_key.encode()),
- },
+ key_values = []
+ value_values = []
+ invalidations = []
+ for server_name, key_id, fetch_result in verify_keys:
+ key_values.append((server_name, key_id))
+ value_values.append(
+ (
+ from_server,
+ ts_added_ms,
+ fetch_result.valid_until_ts,
+ db_binary_type(fetch_result.verify_key.encode()),
+ )
)
# invalidate takes a tuple corresponding to the params of
# _get_server_verify_key. _get_server_verify_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
- txn.call_after(
- self._get_server_verify_key.invalidate, ((server_name, key_id),)
- )
-
- return self.runInteraction("store_server_verify_key", _txn)
+ invalidations.append((server_name, key_id))
+
+ def _invalidate(res):
+ f = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ f((i, ))
+ return res
+
+ return self.runInteraction(
+ "store_server_verify_keys",
+ self._simple_upsert_many_txn,
+ table="server_signature_keys",
+ key_names=("server_name", "key_id"),
+ key_values=key_values,
+ value_names=(
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "verify_key",
+ ),
+ value_values=value_values,
+ ).addCallback(_invalidate)
def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 03a06a83d6..4cf159ba81 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -725,17 +727,7 @@ class RegistrationStore(
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
if self._account_validity.enabled:
- now_ms = self.clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
- self._simple_insert_txn(
- txn,
- "account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- "email_sent": False,
- }
- )
+ self.set_expiration_date_for_user_txn(txn, user_id)
if token:
# it's possible for this to get a conflict, but only for a single user
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index 493abe405e..4c83800cca 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -280,7 +280,7 @@ class RelationsWorkerStore(SQLBaseStore):
having_clause = ""
sql = """
- SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering)
+ SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE {where_clause}
@@ -350,9 +350,7 @@ class RelationsWorkerStore(SQLBaseStore):
"""
def _get_applicable_edit_txn(txn):
- txn.execute(
- sql, (event_id, RelationTypes.REPLACE,)
- )
+ txn.execute(sql, (event_id, RelationTypes.REPLACE))
row = txn.fetchone()
if row:
return row[0]
@@ -367,6 +365,50 @@ class RelationsWorkerStore(SQLBaseStore):
edit_event = yield self.get_event(edit_id, allow_none=True)
defer.returnValue(edit_event)
+ def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+ """Check if a user has already annotated an event with the same key
+ (e.g. already liked an event).
+
+ Args:
+ parent_id (str): The event being annotated
+ event_type (str): The event type of the annotation
+ aggregation_key (str): The aggregation key of the annotation
+ sender (str): The sender of the annotation
+
+ Returns:
+ Deferred[bool]
+ """
+
+ sql = """
+ SELECT 1 FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND type = ?
+ AND sender = ?
+ AND aggregation_key = ?
+ LIMIT 1;
+ """
+
+ def _get_if_user_has_annotated_event(txn):
+ txn.execute(
+ sql,
+ (
+ parent_id,
+ RelationTypes.ANNOTATION,
+ event_type,
+ sender,
+ aggregation_key,
+ ),
+ )
+
+ return bool(txn.fetchone())
+
+ return self.runInteraction(
+ "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
+ )
+
class RelationsStore(RelationsWorkerStore):
def _handle_event_relations(self, txn, event):
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 57df17bcc2..7617913326 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -142,6 +142,27 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self.runInteraction("get_room_summary", _get_room_summary_txn)
+ def _get_user_counts_in_room_txn(self, txn, room_id):
+ """
+ Get the user count in a room by membership.
+
+ Args:
+ room_id (str)
+ membership (Membership)
+
+ Returns:
+ Deferred[int]
+ """
+ sql = """
+ SELECT m.membership, count(*) FROM room_memberships as m
+ INNER JOIN current_state_events as c USING(event_id)
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ return {row[0]: row[1] for row in txn}
+
@cached()
def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to
diff --git a/synapse/storage/schema/delta/54/account_validity.sql b/synapse/storage/schema/delta/54/account_validity_with_renewal.sql
index 2357626000..0adb2ad55e 100644
--- a/synapse/storage/schema/delta/54/account_validity.sql
+++ b/synapse/storage/schema/delta/54/account_validity_with_renewal.sql
@@ -13,6 +13,9 @@
* limitations under the License.
*/
+-- We previously changed the schema for this table without renaming the file, which means
+-- that some databases might still be using the old schema. This ensures Synapse uses the
+-- right schema for the table.
DROP TABLE IF EXISTS account_validity;
-- Track what users are in public rooms.
diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
new file mode 100644
index 0000000000..c01aa9d2d9
--- /dev/null
+++ b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
@@ -0,0 +1,23 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* When we can use this key until, before we have to refresh it. */
+ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT;
+
+UPDATE server_signature_keys SET ts_valid_until_ms = (
+ SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE
+ skj.server_name = server_signature_keys.server_name AND
+ skj.key_id = server_signature_keys.key_id
+);
diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/delta/54/delete_forward_extremities.sql
new file mode 100644
index 0000000000..b062ec840c
--- /dev/null
+++ b/synapse/storage/schema/delta/54/delete_forward_extremities.sql
@@ -0,0 +1,23 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Start a background job to cleanup extremities that were incorrectly added
+-- by bug #5269.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('delete_soft_failed_extremities', '{}');
+
+DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent.
+CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities;
+CREATE INDEX _extremities_to_check_id ON _extremities_to_check(event_id);
diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/schema/delta/54/stats.sql
new file mode 100644
index 0000000000..652e58308e
--- /dev/null
+++ b/synapse/storage/schema/delta/54/stats.sql
@@ -0,0 +1,80 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stats_stream_pos (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_id BIGINT,
+ CHECK (Lock='X')
+);
+
+INSERT INTO stats_stream_pos (stream_id) VALUES (null);
+
+CREATE TABLE user_stats (
+ user_id TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ bucket_size INT NOT NULL,
+ public_rooms INT NOT NULL,
+ private_rooms INT NOT NULL
+);
+
+CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts);
+
+CREATE TABLE room_stats (
+ room_id TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ bucket_size INT NOT NULL,
+ current_state_events INT NOT NULL,
+ joined_members INT NOT NULL,
+ invited_members INT NOT NULL,
+ left_members INT NOT NULL,
+ banned_members INT NOT NULL,
+ state_events INT NOT NULL
+);
+
+CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts);
+
+-- cache of current room state; useful for the publicRooms list
+CREATE TABLE room_state (
+ room_id TEXT NOT NULL,
+ join_rules TEXT,
+ history_visibility TEXT,
+ encryption TEXT,
+ name TEXT,
+ topic TEXT,
+ avatar TEXT,
+ canonical_alias TEXT
+ -- get aliases straight from the right table
+);
+
+CREATE UNIQUE INDEX room_state_room ON room_state(room_id);
+
+CREATE TABLE room_stats_earliest_token (
+ room_id TEXT NOT NULL,
+ token BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id);
+
+-- Set up staging tables
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('populate_stats_createtables', '{}');
+
+-- Run through each room and update stats
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_stats_process_rooms', '{}', 'populate_stats_createtables');
+
+-- Clean up staging tables
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms');
diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/schema/delta/54/stats2.sql
new file mode 100644
index 0000000000..3b2d48447f
--- /dev/null
+++ b/synapse/storage/schema/delta/54/stats2.sql
@@ -0,0 +1,28 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This delta file gets run after `54/stats.sql` delta.
+
+-- We want to add some indices to the temporary stats table, so we re-insert
+-- 'populate_stats_createtables' if we are still processing the rooms update.
+INSERT INTO background_updates (update_name, progress_json)
+ SELECT 'populate_stats_createtables', '{}'
+ WHERE
+ 'populate_stats_process_rooms' IN (
+ SELECT update_name FROM background_updates
+ )
+ AND 'populate_stats_createtables' NOT IN ( -- don't insert if already exists
+ SELECT update_name FROM background_updates
+ );
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
index 31a0279b18..5fdb442104 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/state_deltas.py
@@ -84,10 +84,16 @@ class StateDeltasStore(SQLBaseStore):
"get_current_state_deltas", get_current_state_deltas_txn
)
- def get_max_stream_id_in_current_state_deltas(self):
- return self._simple_select_one_onecol(
+ def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
+ return self._simple_select_one_onecol_txn(
+ txn,
table="current_state_delta_stream",
keyvalues={},
retcol="COALESCE(MAX(stream_id), -1)",
- desc="get_max_stream_id_in_current_state_deltas",
+ )
+
+ def get_max_stream_id_in_current_state_deltas(self):
+ return self.runInteraction(
+ "get_max_stream_id_in_current_state_deltas",
+ self._get_max_stream_id_in_current_state_deltas_txn,
)
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
new file mode 100644
index 0000000000..ff266b09b0
--- /dev/null
+++ b/synapse/storage/stats.py
@@ -0,0 +1,468 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018, 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.state_deltas import StateDeltasStore
+from synapse.util.caches.descriptors import cached
+
+logger = logging.getLogger(__name__)
+
+# these fields track absolutes (e.g. total number of rooms on the server)
+ABSOLUTE_STATS_FIELDS = {
+ "room": (
+ "current_state_events",
+ "joined_members",
+ "invited_members",
+ "left_members",
+ "banned_members",
+ "state_events",
+ ),
+ "user": ("public_rooms", "private_rooms"),
+}
+
+TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
+
+TEMP_TABLE = "_temp_populate_stats"
+
+
+class StatsStore(StateDeltasStore):
+ def __init__(self, db_conn, hs):
+ super(StatsStore, self).__init__(db_conn, hs)
+
+ self.server_name = hs.hostname
+ self.clock = self.hs.get_clock()
+ self.stats_enabled = hs.config.stats_enabled
+ self.stats_bucket_size = hs.config.stats_bucket_size
+
+ self.register_background_update_handler(
+ "populate_stats_createtables", self._populate_stats_createtables
+ )
+ self.register_background_update_handler(
+ "populate_stats_process_rooms", self._populate_stats_process_rooms
+ )
+ self.register_background_update_handler(
+ "populate_stats_cleanup", self._populate_stats_cleanup
+ )
+
+ @defer.inlineCallbacks
+ def _populate_stats_createtables(self, progress, batch_size):
+
+ if not self.stats_enabled:
+ yield self._end_background_update("populate_stats_createtables")
+ defer.returnValue(1)
+
+ # Get all the rooms that we want to process.
+ def _make_staging_area(txn):
+ # Create the temporary tables
+ stmts = get_statements("""
+ -- We just recreate the table, we'll be reinserting the
+ -- correct entries again later anyway.
+ DROP TABLE IF EXISTS {temp}_rooms;
+
+ CREATE TABLE IF NOT EXISTS {temp}_rooms(
+ room_id TEXT NOT NULL,
+ events BIGINT NOT NULL
+ );
+
+ CREATE INDEX {temp}_rooms_events
+ ON {temp}_rooms(events);
+ CREATE INDEX {temp}_rooms_id
+ ON {temp}_rooms(room_id);
+ """.format(temp=TEMP_TABLE).splitlines())
+
+ for statement in stmts:
+ txn.execute(statement)
+
+ sql = (
+ "CREATE TABLE IF NOT EXISTS "
+ + TEMP_TABLE
+ + "_position(position TEXT NOT NULL)"
+ )
+ txn.execute(sql)
+
+ # Get rooms we want to process from the database, only adding
+ # those that we haven't (i.e. those not in room_stats_earliest_token)
+ sql = """
+ INSERT INTO %s_rooms (room_id, events)
+ SELECT c.room_id, count(*) FROM current_state_events AS c
+ LEFT JOIN room_stats_earliest_token AS t USING (room_id)
+ WHERE t.room_id IS NULL
+ GROUP BY c.room_id
+ """ % (TEMP_TABLE,)
+ txn.execute(sql)
+
+ new_pos = yield self.get_max_stream_id_in_current_state_deltas()
+ yield self.runInteraction("populate_stats_temp_build", _make_staging_area)
+ yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+ self.get_earliest_token_for_room_stats.invalidate_all()
+
+ yield self._end_background_update("populate_stats_createtables")
+ defer.returnValue(1)
+
+ @defer.inlineCallbacks
+ def _populate_stats_cleanup(self, progress, batch_size):
+ """
+ Update the user directory stream position, then clean up the old tables.
+ """
+ if not self.stats_enabled:
+ yield self._end_background_update("populate_stats_cleanup")
+ defer.returnValue(1)
+
+ position = yield self._simple_select_one_onecol(
+ TEMP_TABLE + "_position", None, "position"
+ )
+ yield self.update_stats_stream_pos(position)
+
+ def _delete_staging_area(txn):
+ txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
+ txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
+
+ yield self.runInteraction("populate_stats_cleanup", _delete_staging_area)
+
+ yield self._end_background_update("populate_stats_cleanup")
+ defer.returnValue(1)
+
+ @defer.inlineCallbacks
+ def _populate_stats_process_rooms(self, progress, batch_size):
+
+ if not self.stats_enabled:
+ yield self._end_background_update("populate_stats_process_rooms")
+ defer.returnValue(1)
+
+ # If we don't have progress filed, delete everything.
+ if not progress:
+ yield self.delete_all_stats()
+
+ def _get_next_batch(txn):
+ # Only fetch 250 rooms, so we don't fetch too many at once, even
+ # if those 250 rooms have less than batch_size state events.
+ sql = """
+ SELECT room_id, events FROM %s_rooms
+ ORDER BY events DESC
+ LIMIT 250
+ """ % (
+ TEMP_TABLE,
+ )
+ txn.execute(sql)
+ rooms_to_work_on = txn.fetchall()
+
+ if not rooms_to_work_on:
+ return None
+
+ # Get how many are left to process, so we can give status on how
+ # far we are in processing
+ txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
+ progress["remaining"] = txn.fetchone()[0]
+
+ return rooms_to_work_on
+
+ rooms_to_work_on = yield self.runInteraction(
+ "populate_stats_temp_read", _get_next_batch
+ )
+
+ # No more rooms -- complete the transaction.
+ if not rooms_to_work_on:
+ yield self._end_background_update("populate_stats_process_rooms")
+ defer.returnValue(1)
+
+ logger.info(
+ "Processing the next %d rooms of %d remaining",
+ len(rooms_to_work_on), progress["remaining"],
+ )
+
+ # Number of state events we've processed by going through each room
+ processed_event_count = 0
+
+ for room_id, event_count in rooms_to_work_on:
+
+ current_state_ids = yield self.get_current_state_ids(room_id)
+
+ join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
+ history_visibility_id = current_state_ids.get(
+ (EventTypes.RoomHistoryVisibility, "")
+ )
+ encryption_id = current_state_ids.get((EventTypes.RoomEncryption, ""))
+ name_id = current_state_ids.get((EventTypes.Name, ""))
+ topic_id = current_state_ids.get((EventTypes.Topic, ""))
+ avatar_id = current_state_ids.get((EventTypes.RoomAvatar, ""))
+ canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, ""))
+
+ state_events = yield self.get_events([
+ join_rules_id, history_visibility_id, encryption_id, name_id,
+ topic_id, avatar_id, canonical_alias_id,
+ ])
+
+ def _get_or_none(event_id, arg):
+ event = state_events.get(event_id)
+ if event:
+ return event.content.get(arg)
+ return None
+
+ yield self.update_room_state(
+ room_id,
+ {
+ "join_rules": _get_or_none(join_rules_id, "join_rule"),
+ "history_visibility": _get_or_none(
+ history_visibility_id, "history_visibility"
+ ),
+ "encryption": _get_or_none(encryption_id, "algorithm"),
+ "name": _get_or_none(name_id, "name"),
+ "topic": _get_or_none(topic_id, "topic"),
+ "avatar": _get_or_none(avatar_id, "url"),
+ "canonical_alias": _get_or_none(canonical_alias_id, "alias"),
+ },
+ )
+
+ now = self.hs.get_reactor().seconds()
+
+ # quantise time to the nearest bucket
+ now = (now // self.stats_bucket_size) * self.stats_bucket_size
+
+ def _fetch_data(txn):
+
+ # Get the current token of the room
+ current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
+
+ current_state_events = len(current_state_ids)
+
+ membership_counts = self._get_user_counts_in_room_txn(txn, room_id)
+
+ total_state_events = self._get_total_state_event_counts_txn(
+ txn, room_id
+ )
+
+ self._update_stats_txn(
+ txn,
+ "room",
+ room_id,
+ now,
+ {
+ "bucket_size": self.stats_bucket_size,
+ "current_state_events": current_state_events,
+ "joined_members": membership_counts.get(Membership.JOIN, 0),
+ "invited_members": membership_counts.get(Membership.INVITE, 0),
+ "left_members": membership_counts.get(Membership.LEAVE, 0),
+ "banned_members": membership_counts.get(Membership.BAN, 0),
+ "state_events": total_state_events,
+ },
+ )
+ self._simple_insert_txn(
+ txn,
+ "room_stats_earliest_token",
+ {"room_id": room_id, "token": current_token},
+ )
+
+ # We've finished a room. Delete it from the table.
+ self._simple_delete_one_txn(
+ txn, TEMP_TABLE + "_rooms", {"room_id": room_id},
+ )
+
+ yield self.runInteraction("update_room_stats", _fetch_data)
+
+ # Update the remaining counter.
+ progress["remaining"] -= 1
+ yield self.runInteraction(
+ "populate_stats",
+ self._background_update_progress_txn,
+ "populate_stats_process_rooms",
+ progress,
+ )
+
+ processed_event_count += event_count
+
+ if processed_event_count > batch_size:
+ # Don't process any more rooms, we've hit our batch size.
+ defer.returnValue(processed_event_count)
+
+ defer.returnValue(processed_event_count)
+
+ def delete_all_stats(self):
+ """
+ Delete all statistics records.
+ """
+
+ def _delete_all_stats_txn(txn):
+ txn.execute("DELETE FROM room_state")
+ txn.execute("DELETE FROM room_stats")
+ txn.execute("DELETE FROM room_stats_earliest_token")
+ txn.execute("DELETE FROM user_stats")
+
+ return self.runInteraction("delete_all_stats", _delete_all_stats_txn)
+
+ def get_stats_stream_pos(self):
+ return self._simple_select_one_onecol(
+ table="stats_stream_pos",
+ keyvalues={},
+ retcol="stream_id",
+ desc="stats_stream_pos",
+ )
+
+ def update_stats_stream_pos(self, stream_id):
+ return self._simple_update_one(
+ table="stats_stream_pos",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ desc="update_stats_stream_pos",
+ )
+
+ def update_room_state(self, room_id, fields):
+ """
+ Args:
+ room_id (str)
+ fields (dict[str:Any])
+ """
+
+ # For whatever reason some of the fields may contain null bytes, which
+ # postgres isn't a fan of, so we replace those fields with null.
+ for col in (
+ "join_rules",
+ "history_visibility",
+ "encryption",
+ "name",
+ "topic",
+ "avatar",
+ "canonical_alias"
+ ):
+ field = fields.get(col)
+ if field and "\0" in field:
+ fields[col] = None
+
+ return self._simple_upsert(
+ table="room_state",
+ keyvalues={"room_id": room_id},
+ values=fields,
+ desc="update_room_state",
+ )
+
+ def get_deltas_for_room(self, room_id, start, size=100):
+ """
+ Get statistics deltas for a given room.
+
+ Args:
+ room_id (str)
+ start (int): Pagination start. Number of entries, not timestamp.
+ size (int): How many entries to return.
+
+ Returns:
+ Deferred[list[dict]], where the dict has the keys of
+ ABSOLUTE_STATS_FIELDS["room"] and "ts".
+ """
+ return self._simple_select_list_paginate(
+ "room_stats",
+ {"room_id": room_id},
+ "ts",
+ start,
+ size,
+ retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]),
+ order_direction="DESC",
+ )
+
+ def get_all_room_state(self):
+ return self._simple_select_list(
+ "room_state", None, retcols=("name", "topic", "canonical_alias")
+ )
+
+ @cached()
+ def get_earliest_token_for_room_stats(self, room_id):
+ """
+ Fetch the "earliest token". This is used by the room stats delta
+ processor to ignore deltas that have been processed between the
+ start of the background task and any particular room's stats
+ being calculated.
+
+ Returns:
+ Deferred[int]
+ """
+ return self._simple_select_one_onecol(
+ "room_stats_earliest_token",
+ {"room_id": room_id},
+ retcol="token",
+ allow_none=True,
+ )
+
+ def update_stats(self, stats_type, stats_id, ts, fields):
+ table, id_col = TYPE_TO_ROOM[stats_type]
+ return self._simple_upsert(
+ table=table,
+ keyvalues={id_col: stats_id, "ts": ts},
+ values=fields,
+ desc="update_stats",
+ )
+
+ def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields):
+ table, id_col = TYPE_TO_ROOM[stats_type]
+ return self._simple_upsert_txn(
+ txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields
+ )
+
+ def update_stats_delta(self, ts, stats_type, stats_id, field, value):
+ def _update_stats_delta(txn):
+ table, id_col = TYPE_TO_ROOM[stats_type]
+
+ sql = (
+ "SELECT * FROM %s"
+ " WHERE %s=? and ts=("
+ " SELECT MAX(ts) FROM %s"
+ " WHERE %s=?"
+ ")"
+ ) % (table, id_col, table, id_col)
+ txn.execute(sql, (stats_id, stats_id))
+ rows = self.cursor_to_dict(txn)
+ if len(rows) == 0:
+ # silently skip as we don't have anything to apply a delta to yet.
+ # this tries to minimise any race between the initial sync and
+ # subsequent deltas arriving.
+ return
+
+ current_ts = ts
+ latest_ts = rows[0]["ts"]
+ if current_ts < latest_ts:
+ # This one is in the past, but we're just encountering it now.
+ # Mark it as part of the current bucket.
+ current_ts = latest_ts
+ elif ts != latest_ts:
+ # we have to copy our absolute counters over to the new entry.
+ values = {
+ key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type]
+ }
+ values[id_col] = stats_id
+ values["ts"] = ts
+ values["bucket_size"] = self.stats_bucket_size
+
+ self._simple_insert_txn(txn, table=table, values=values)
+
+ # actually update the new value
+ if stats_type in ABSOLUTE_STATS_FIELDS[stats_type]:
+ self._simple_update_txn(
+ txn,
+ table=table,
+ keyvalues={id_col: stats_id, "ts": current_ts},
+ updatevalues={field: value},
+ )
+ else:
+ sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % (
+ table,
+ field,
+ field,
+ id_col,
+ )
+ txn.execute(sql, (value, stats_id, current_ts))
+
+ return self.runInteraction("update_stats_delta", _update_stats_delta)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 311b49e18a..fe412355d8 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -226,6 +226,8 @@ class LoggingContext(object):
self.request = request
def __str__(self):
+ if self.request:
+ return str(self.request)
return "%s@%x" % (self.name, id(self))
@classmethod
@@ -274,12 +276,10 @@ class LoggingContext(object):
current = self.set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
- logger.warn("Expected logging context %s has been lost", self)
+ logger.warning("Expected logging context %s was lost", self)
else:
- logger.warn(
- "Current logging context %s is not expected context %s",
- current,
- self
+ logger.warning(
+ "Expected logging context %s but found %s", self, current
)
self.previous_context = None
self.alive = False
@@ -433,10 +433,14 @@ class PreserveLoggingContext(object):
context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
- logger.warn(
- "Unexpected logging context: %s is not %s",
- context, self.new_context,
- )
+ if context is LoggingContext.sentinel:
+ logger.warning("Expected logging context %s was lost", self.new_context)
+ else:
+ logger.warning(
+ "Expected logging context %s but found %s",
+ self.new_context,
+ context,
+ )
if self.current_context is not LoggingContext.sentinel:
if not self.current_context.alive:
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 26cce7d197..1a77456498 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -46,8 +46,7 @@ class NotRetryingDestination(Exception):
@defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, ignore_backoff=False,
- **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
@@ -60,8 +59,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
clock (synapse.util.clock): timing source
store (synapse.storage.transactions.TransactionStore): datastore
ignore_backoff (bool): true to ignore the historical backoff data and
- try the request anyway. We will still update the next
- retry_interval on success/failure.
+ try the request anyway. We will still reset the retry_interval on success.
Example usage:
@@ -75,13 +73,12 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
"""
retry_last_ts, retry_interval = (0, 0)
- retry_timings = yield store.get_destination_retry_timings(
- destination
- )
+ retry_timings = yield store.get_destination_retry_timings(destination)
if retry_timings:
retry_last_ts, retry_interval = (
- retry_timings["retry_last_ts"], retry_timings["retry_interval"]
+ retry_timings["retry_last_ts"],
+ retry_timings["retry_interval"],
)
now = int(clock.time_msec())
@@ -93,22 +90,36 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
destination=destination,
)
+ # if we are ignoring the backoff data, we should also not increment the backoff
+ # when we get another failure - otherwise a server can very quickly reach the
+ # maximum backoff even though it might only have been down briefly
+ backoff_on_failure = not ignore_backoff
+
defer.returnValue(
RetryDestinationLimiter(
destination,
clock,
store,
retry_interval,
+ backoff_on_failure=backoff_on_failure,
**kwargs
)
)
class RetryDestinationLimiter(object):
- def __init__(self, destination, clock, store, retry_interval,
- min_retry_interval=10 * 60 * 1000,
- max_retry_interval=24 * 60 * 60 * 1000,
- multiplier_retry_interval=5, backoff_on_404=False):
+ def __init__(
+ self,
+ destination,
+ clock,
+ store,
+ retry_interval,
+ min_retry_interval=10 * 60 * 1000,
+ max_retry_interval=24 * 60 * 60 * 1000,
+ multiplier_retry_interval=5,
+ backoff_on_404=False,
+ backoff_on_failure=True,
+ ):
"""Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500.
@@ -128,6 +139,9 @@ class RetryDestinationLimiter(object):
multiplier_retry_interval (int): The multiplier to use to increase
the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404
+
+ backoff_on_failure (bool): set to False if we should not increase the
+ retry interval on a failure.
"""
self.clock = clock
self.store = store
@@ -138,6 +152,7 @@ class RetryDestinationLimiter(object):
self.max_retry_interval = max_retry_interval
self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404
+ self.backoff_on_failure = backoff_on_failure
def __enter__(self):
pass
@@ -173,10 +188,13 @@ class RetryDestinationLimiter(object):
if not self.retry_interval:
return
- logger.debug("Connection to %s was successful; clearing backoff",
- self.destination)
+ logger.debug(
+ "Connection to %s was successful; clearing backoff", self.destination
+ )
retry_last_ts = 0
self.retry_interval = 0
+ elif not self.backoff_on_failure:
+ return
else:
# We couldn't connect.
if self.retry_interval:
@@ -190,7 +208,10 @@ class RetryDestinationLimiter(object):
logger.info(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
- self.destination, exc_type, exc_val, self.retry_interval
+ self.destination,
+ exc_type,
+ exc_val,
+ self.retry_interval,
)
retry_last_ts = int(self.clock.time_msec())
@@ -201,9 +222,7 @@ class RetryDestinationLimiter(object):
self.destination, retry_last_ts, self.retry_interval
)
except Exception:
- logger.exception(
- "Failed to store destination_retry_timings",
- )
+ logger.exception("Failed to store destination_retry_timings")
# we deliberately do this in the background.
synapse.util.logcontext.run_in_background(store_retry_timings)
|