diff --git a/synapse/__init__.py b/synapse/__init__.py
index 6bb5a8b24d..d0e8d7c21b 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -27,4 +27,4 @@ try:
except ImportError:
pass
-__version__ = "0.99.3"
+__version__ = "0.99.5.2"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 960e66dbdc..0c6c93a87b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -556,7 +556,7 @@ class Auth(object):
""" Check if the given user is a local server admin.
Args:
- user (str): mxid of user to check
+ user (UserID): user to check
Returns:
bool: True if the user is an admin
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 0860b75905..ee129c8689 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -20,6 +20,12 @@
# the "depth" field on events is limited to 2**63 - 1
MAX_DEPTH = 2**63 - 1
+# the maximum length for a room alias is 255 characters
+MAX_ALIAS_LENGTH = 255
+
+# the maximum length for a user id is 255 characters
+MAX_USERID_LENGTH = 255
+
class Membership(object):
@@ -73,6 +79,7 @@ class EventTypes(object):
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
+ Encryption = "m.room.encryption"
RoomAvatar = "m.room.avatar"
RoomEncryption = "m.room.encryption"
GuestAccess = "m.room.guest_access"
@@ -113,3 +120,11 @@ class UserTypes(object):
"""
SUPPORT = "support"
ALL_USER_TYPES = (SUPPORT,)
+
+
+class RelationTypes(object):
+ """The types of relations known to this server.
+ """
+ ANNOTATION = "m.annotation"
+ REPLACE = "m.replace"
+ REFERENCE = "m.reference"
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index ff89259dec..e91697049c 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -328,9 +328,23 @@ class RoomKeysVersionError(SynapseError):
self.current_version = current_version
+class UnsupportedRoomVersionError(SynapseError):
+ """The client's request to create a room used a room version that the server does
+ not support."""
+ def __init__(self):
+ super(UnsupportedRoomVersionError, self).__init__(
+ code=400,
+ msg="Homeserver does not support this room version",
+ errcode=Codes.UNSUPPORTED_ROOM_VERSION,
+ )
+
+
class IncompatibleRoomVersionError(SynapseError):
- """A server is trying to join a room whose version it does not support."""
+ """A server is trying to join a room whose version it does not support.
+    Unlike UnsupportedRoomVersionError, this is specific to the case of a
+    make_join request failing.
+ """
def __init__(self, room_version):
super(IncompatibleRoomVersionError, self).__init__(
code=400,
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index e77abe1040..4085bd10b9 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -19,13 +19,15 @@ class EventFormatVersions(object):
"""This is an internal enum for tracking the version of the event format,
independently from the room version.
"""
- V1 = 1 # $id:server format
- V2 = 2 # MSC1659-style $hash format: introduced for room v3
+ V1 = 1 # $id:server event id format
+ V2 = 2 # MSC1659-style $hash event id format: introduced for room v3
+ V3 = 3 # MSC1884-style $hash format: introduced for room v4
KNOWN_EVENT_FORMAT_VERSIONS = {
EventFormatVersions.V1,
EventFormatVersions.V2,
+ EventFormatVersions.V3,
}
@@ -75,10 +77,12 @@ class RoomVersions(object):
EventFormatVersions.V2,
StateResolutionVersions.V2,
)
-
-
-# the version we will give rooms which are created on this server
-DEFAULT_ROOM_VERSION = RoomVersions.V1
+ V4 = RoomVersion(
+ "4",
+ RoomDisposition.STABLE,
+ EventFormatVersions.V3,
+ StateResolutionVersions.V2,
+ )
KNOWN_ROOM_VERSIONS = {
@@ -87,5 +91,6 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V2,
RoomVersions.V3,
RoomVersions.STATE_V2_TEST,
+ RoomVersions.V4,
)
} # type: dict[str, RoomVersion]
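
Consumers resolve a version identifier against KNOWN_ROOM_VERSIONS and fail on unknown values. A minimal sketch of that pattern, mirroring the default_room_version handling added to config/server.py later in this diff (the `event_format` attribute name is assumed from the RoomVersion definition in this file):

```python
from synapse.api.room_versions import EventFormatVersions, KNOWN_ROOM_VERSIONS

def resolve_room_version(identifier):
    # identifiers are strings ("1" .. "4"), so coerce before the lookup
    version = KNOWN_ROOM_VERSIONS.get(str(identifier))
    if version is None:
        raise ValueError(
            "Unknown room version: %s, known versions: %s"
            % (identifier, list(KNOWN_ROOM_VERSIONS.keys()))
        )
    return version

v4 = resolve_room_version(4)
assert v4.event_format == EventFormatVersions.V3
```
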
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index cb71d80875..e16c386a14 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -22,11 +22,11 @@ from six.moves.urllib.parse import urlencode
from synapse.config import ConfigError
-CLIENT_PREFIX = "/_matrix/client/api/v1"
-CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
+CLIENT_API_PREFIX = "/_matrix/client"
FEDERATION_PREFIX = "/_matrix/federation"
FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
+FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index d4c6c4c8e2..8cc990399f 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -22,13 +22,14 @@ import traceback
import psutil
from daemonize import Daemonize
-from twisted.internet import error, reactor
+from twisted.internet import defer, error, reactor
from twisted.protocols.tls import TLSMemoryBIOFactory
import synapse
from synapse.app import check_bind_error
from synapse.crypto import context_factory
from synapse.util import PreserveLoggingContext
+from synapse.util.async_helpers import Linearizer
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
@@ -99,6 +100,8 @@ def start_reactor(
logger (logging.Logger): logger instance to pass to Daemonize
"""
+ install_dns_limiter(reactor)
+
def run():
# make sure that we run the reactor with the sentinel log context,
# otherwise other PreserveLoggingContext instances will get confused
@@ -312,3 +315,87 @@ def setup_sentry(hs):
name = hs.config.worker_name if hs.config.worker_name else "master"
scope.set_tag("worker_app", app)
scope.set_tag("worker_name", name)
+
+
+def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
+ """Replaces the resolver with one that limits the number of in flight DNS
+ requests.
+
+    This is a workaround for https://twistedmatrix.com/trac/ticket/9620, where we
+    can run out of file descriptors and loop indefinitely if we attempt to do too
+    many DNS queries at once.
+ """
+ new_resolver = _LimitedHostnameResolver(
+ reactor.nameResolver, max_dns_requests_in_flight,
+ )
+
+ reactor.installNameResolver(new_resolver)
+
+
+class _LimitedHostnameResolver(object):
+ """Wraps a IHostnameResolver, limiting the number of in-flight DNS lookups.
+ """
+
+ def __init__(self, resolver, max_dns_requests_in_flight):
+ self._resolver = resolver
+ self._limiter = Linearizer(
+ name="dns_client_limiter", max_count=max_dns_requests_in_flight,
+ )
+
+ def resolveHostName(self, resolutionReceiver, hostName, portNumber=0,
+ addressTypes=None, transportSemantics='TCP'):
+ # We need this function to return `resolutionReceiver` so we do all the
+ # actual logic involving deferreds in a separate function.
+
+ # even though this is happening within the depths of twisted, we need to drop
+ # our logcontext before starting _resolve, otherwise: (a) _resolve will drop
+ # the logcontext if it returns an incomplete deferred; (b) _resolve will
+ # call the resolutionReceiver *with* a logcontext, which it won't be expecting.
+ with PreserveLoggingContext():
+ self._resolve(
+ resolutionReceiver,
+ hostName,
+ portNumber,
+ addressTypes,
+ transportSemantics,
+ )
+
+ return resolutionReceiver
+
+ @defer.inlineCallbacks
+ def _resolve(self, resolutionReceiver, hostName, portNumber=0,
+ addressTypes=None, transportSemantics='TCP'):
+
+ with (yield self._limiter.queue(())):
+ # resolveHostName doesn't return a Deferred, so we need to hook into
+ # the receiver interface to get told when resolution has finished.
+
+ deferred = defer.Deferred()
+ receiver = _DeferredResolutionReceiver(resolutionReceiver, deferred)
+
+ self._resolver.resolveHostName(
+ receiver, hostName, portNumber,
+ addressTypes, transportSemantics,
+ )
+
+ yield deferred
+
+
+class _DeferredResolutionReceiver(object):
+ """Wraps a IResolutionReceiver and simply resolves the given deferred when
+ resolution is complete
+ """
+
+ def __init__(self, receiver, deferred):
+ self._receiver = receiver
+ self._deferred = deferred
+
+ def resolutionBegan(self, resolutionInProgress):
+ self._receiver.resolutionBegan(resolutionInProgress)
+
+ def addressResolved(self, address):
+ self._receiver.addressResolved(address)
+
+ def resolutionComplete(self):
+ self._deferred.callback(())
+ self._receiver.resolutionComplete()
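
The limiter above goes through Synapse's Linearizer (with max_count) so that logcontexts are preserved; the underlying capping idea can be shown with a plain twisted DeferredSemaphore. A minimal sketch, with a stand-in lookup that fires synchronously so no reactor is needed:

```python
from twisted.internet import defer

MAX_DNS_REQUESTS_IN_FLIGHT = 100
_sem = defer.DeferredSemaphore(MAX_DNS_REQUESTS_IN_FLIGHT)

def fake_lookup(hostname):
    # stand-in for the real resolver call
    return defer.succeed("10.0.0.1")

def limited_lookup(hostname):
    # at most MAX_DNS_REQUESTS_IN_FLIGHT lookups run at once; the rest queue
    # on the semaphore, the same effect the Linearizer gives
    # _LimitedHostnameResolver above
    return _sem.run(fake_lookup, hostname)

results = []
limited_lookup("matrix.org").addCallback(results.append)
assert results == ["10.0.0.1"]
```
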
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 864f1eac48..a16e037f32 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -38,6 +38,7 @@ from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.directory import DirectoryStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.keys import SlavedKeyStore
+from synapse.replication.slave.storage.profile import SlavedProfileStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
@@ -81,6 +82,7 @@ class ClientReaderSlavedStore(
SlavedApplicationServiceStore,
SlavedRegistrationStore,
SlavedTransactionStore,
+ SlavedProfileStore,
SlavedClientIpStore,
BaseSlavedStore,
):
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 8479fee738..6504da5278 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -37,8 +37,7 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
-from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns
-from synapse.rest.client.v2_alpha._base import client_v2_patterns
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
@@ -49,11 +48,11 @@ from synapse.util.versionstring import get_version_string
logger = logging.getLogger("synapse.app.frontend_proxy")
-class PresenceStatusStubServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
+class PresenceStatusStubServlet(RestServlet):
+ PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status")
def __init__(self, hs):
- super(PresenceStatusStubServlet, self).__init__(hs)
+ super(PresenceStatusStubServlet, self).__init__()
self.http_client = hs.get_simple_http_client()
self.auth = hs.get_auth()
self.main_uri = hs.config.worker_main_http_uri
@@ -84,7 +83,7 @@ class PresenceStatusStubServlet(ClientV1RestServlet):
class KeyUploadServlet(RestServlet):
- PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
+ PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 79be977ea6..1045d28949 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -62,6 +62,7 @@ from synapse.python_dependencies import check_requirements
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource
+from synapse.rest.admin import AdminRestResource
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.well_known import WellKnownResource
@@ -180,6 +181,7 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/v2_alpha": client_resource,
"/_matrix/client/versions": client_resource,
"/.well-known/matrix/client": WellKnownResource(self),
+ "/_synapse/admin": AdminRestResource(self),
})
if self.get_config().saml2_enabled:
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index 342a6ce5fd..8400471f40 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2015-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -29,12 +31,50 @@ logger = logging.getLogger(__name__)
class EmailConfig(Config):
def read_config(self, config):
+ # TODO: We should separate better the email configuration from the notification
+ # and account validity config.
+
self.email_enable_notifs = False
email_config = config.get("email", {})
+
+ self.email_smtp_host = email_config.get("smtp_host", None)
+ self.email_smtp_port = email_config.get("smtp_port", None)
+ self.email_smtp_user = email_config.get("smtp_user", None)
+ self.email_smtp_pass = email_config.get("smtp_pass", None)
+ self.require_transport_security = email_config.get(
+ "require_transport_security", False
+ )
+ if "app_name" in email_config:
+ self.email_app_name = email_config["app_name"]
+ else:
+ self.email_app_name = "Matrix"
+
+ self.email_notif_from = email_config.get("notif_from", None)
+ if self.email_notif_from is not None:
+ # make sure it's valid
+ parsed = email.utils.parseaddr(self.email_notif_from)
+ if parsed[1] == '':
+ raise RuntimeError("Invalid notif_from address")
+
+ template_dir = email_config.get("template_dir")
+ # we need an absolute path, because we change directory after starting (and
+        # we don't yet know what auxiliary templates like mail.css we will need).
+ # (Note that loading as package_resources with jinja.PackageLoader doesn't
+ # work for the same reason.)
+ if not template_dir:
+ template_dir = pkg_resources.resource_filename(
+ 'synapse', 'res/templates'
+ )
+
+ self.email_template_dir = os.path.abspath(template_dir)
+
self.email_enable_notifs = email_config.get("enable_notifs", False)
+ account_validity_renewal_enabled = config.get(
+ "account_validity", {},
+ ).get("renew_at")
- if self.email_enable_notifs:
+ if self.email_enable_notifs or account_validity_renewal_enabled:
# make sure we can import the required deps
import jinja2
import bleach
@@ -42,6 +82,7 @@ class EmailConfig(Config):
jinja2
bleach
+ if self.email_enable_notifs:
required = [
"smtp_host",
"smtp_port",
@@ -66,34 +107,13 @@ class EmailConfig(Config):
"email.enable_notifs is True but no public_baseurl is set"
)
- self.email_smtp_host = email_config["smtp_host"]
- self.email_smtp_port = email_config["smtp_port"]
- self.email_notif_from = email_config["notif_from"]
self.email_notif_template_html = email_config["notif_template_html"]
self.email_notif_template_text = email_config["notif_template_text"]
- self.email_expiry_template_html = email_config.get(
- "expiry_template_html", "notice_expiry.html",
- )
- self.email_expiry_template_text = email_config.get(
- "expiry_template_text", "notice_expiry.txt",
- )
-
- template_dir = email_config.get("template_dir")
- # we need an absolute path, because we change directory after starting (and
- # we don't yet know what auxilliary templates like mail.css we will need).
- # (Note that loading as package_resources with jinja.PackageLoader doesn't
- # work for the same reason.)
- if not template_dir:
- template_dir = pkg_resources.resource_filename(
- 'synapse', 'res/templates'
- )
- template_dir = os.path.abspath(template_dir)
for f in self.email_notif_template_text, self.email_notif_template_html:
- p = os.path.join(template_dir, f)
+ p = os.path.join(self.email_template_dir, f)
if not os.path.isfile(p):
raise ConfigError("Unable to find email template file %s" % (p, ))
- self.email_template_dir = template_dir
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
@@ -101,29 +121,24 @@ class EmailConfig(Config):
self.email_riot_base_url = email_config.get(
"riot_base_url", None
)
- self.email_smtp_user = email_config.get(
- "smtp_user", None
- )
- self.email_smtp_pass = email_config.get(
- "smtp_pass", None
- )
- self.require_transport_security = email_config.get(
- "require_transport_security", False
- )
- if "app_name" in email_config:
- self.email_app_name = email_config["app_name"]
- else:
- self.email_app_name = "Matrix"
-
- # make sure it's valid
- parsed = email.utils.parseaddr(self.email_notif_from)
- if parsed[1] == '':
- raise RuntimeError("Invalid notif_from address")
else:
self.email_enable_notifs = False
# Not much point setting defaults for the rest: it would be an
# error for them to be used.
+ if account_validity_renewal_enabled:
+ self.email_expiry_template_html = email_config.get(
+ "expiry_template_html", "notice_expiry.html",
+ )
+ self.email_expiry_template_text = email_config.get(
+ "expiry_template_text", "notice_expiry.txt",
+ )
+
+ for f in self.email_expiry_template_text, self.email_expiry_template_html:
+ p = os.path.join(self.email_template_dir, f)
+ if not os.path.isfile(p):
+ raise ConfigError("Unable to find email template file %s" % (p, ))
+
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable sending emails for notification events or expiry notices
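
The notif_from check above relies on email.utils.parseaddr returning an empty address part when it cannot extract one; a quick standalone illustration:

```python
from email.utils import parseaddr

# well-formed: realname and address are split apart
assert parseaddr("Matrix <noreply@example.com>")[1] == "noreply@example.com"

# nothing extractable: this is the case the RuntimeError above catches
assert parseaddr("")[1] == ""
```
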
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 727fdc54d8..5c4fc8ff21 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from .api import ApiConfig
from .appservice import AppServiceConfig
from .captcha import CaptchaConfig
@@ -36,20 +37,41 @@ from .saml2_config import SAML2Config
from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
from .spam_checker import SpamCheckerConfig
+from .stats import StatsConfig
from .tls import TlsConfig
from .user_directory import UserDirectoryConfig
from .voip import VoipConfig
from .workers import WorkerConfig
-class HomeServerConfig(ServerConfig, TlsConfig, DatabaseConfig, LoggingConfig,
- RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
- VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
- AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
- JWTConfig, PasswordConfig, EmailConfig,
- WorkerConfig, PasswordAuthProviderConfig, PushConfig,
- SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
- ConsentConfig,
- ServerNoticesConfig, RoomDirectoryConfig,
- ):
+class HomeServerConfig(
+ ServerConfig,
+ TlsConfig,
+ DatabaseConfig,
+ LoggingConfig,
+ RatelimitConfig,
+ ContentRepositoryConfig,
+ CaptchaConfig,
+ VoipConfig,
+ RegistrationConfig,
+ MetricsConfig,
+ ApiConfig,
+ AppServiceConfig,
+ KeyConfig,
+ SAML2Config,
+ CasConfig,
+ JWTConfig,
+ PasswordConfig,
+ EmailConfig,
+ WorkerConfig,
+ PasswordAuthProviderConfig,
+ PushConfig,
+ SpamCheckerConfig,
+ GroupsConfig,
+ UserDirectoryConfig,
+ ConsentConfig,
+ StatsConfig,
+ ServerNoticesConfig,
+ RoomDirectoryConfig,
+):
pass
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 5a68399e63..5a9adac480 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -16,16 +16,56 @@ from ._base import Config
class RateLimitConfig(object):
- def __init__(self, config):
- self.per_second = config.get("per_second", 0.17)
- self.burst_count = config.get("burst_count", 3.0)
+ def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}):
+ self.per_second = config.get("per_second", defaults["per_second"])
+ self.burst_count = config.get("burst_count", defaults["burst_count"])
-class RatelimitConfig(Config):
+class FederationRateLimitConfig(object):
+ _items_and_default = {
+ "window_size": 10000,
+ "sleep_limit": 10,
+ "sleep_delay": 500,
+ "reject_limit": 50,
+ "concurrent": 3,
+ }
+
+ def __init__(self, **kwargs):
+ for i in self._items_and_default.keys():
+ setattr(self, i, kwargs.get(i) or self._items_and_default[i])
+
+class RatelimitConfig(Config):
def read_config(self, config):
- self.rc_messages_per_second = config.get("rc_messages_per_second", 0.2)
- self.rc_message_burst_count = config.get("rc_message_burst_count", 10.0)
+
+ # Load the new-style messages config if it exists. Otherwise fall back
+ # to the old method.
+ if "rc_message" in config:
+ self.rc_message = RateLimitConfig(
+ config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0}
+ )
+ else:
+ self.rc_message = RateLimitConfig(
+ {
+ "per_second": config.get("rc_messages_per_second", 0.2),
+ "burst_count": config.get("rc_message_burst_count", 10.0),
+ }
+ )
+
+ # Load the new-style federation config, if it exists. Otherwise, fall
+ # back to the old method.
+ if "federation_rc" in config:
+ self.rc_federation = FederationRateLimitConfig(**config["rc_federation"])
+ else:
+ self.rc_federation = FederationRateLimitConfig(
+ **{
+ "window_size": config.get("federation_rc_window_size"),
+ "sleep_limit": config.get("federation_rc_sleep_limit"),
+ "sleep_delay": config.get("federation_rc_sleep_delay"),
+ "reject_limit": config.get("federation_rc_reject_limit"),
+ "concurrent": config.get("federation_rc_concurrent"),
+ }
+ )
self.rc_registration = RateLimitConfig(config.get("rc_registration", {}))
@@ -33,38 +73,26 @@ class RatelimitConfig(Config):
self.rc_login_address = RateLimitConfig(rc_login_config.get("address", {}))
self.rc_login_account = RateLimitConfig(rc_login_config.get("account", {}))
self.rc_login_failed_attempts = RateLimitConfig(
- rc_login_config.get("failed_attempts", {}),
+ rc_login_config.get("failed_attempts", {})
)
- self.federation_rc_window_size = config.get("federation_rc_window_size", 1000)
- self.federation_rc_sleep_limit = config.get("federation_rc_sleep_limit", 10)
- self.federation_rc_sleep_delay = config.get("federation_rc_sleep_delay", 500)
- self.federation_rc_reject_limit = config.get("federation_rc_reject_limit", 50)
- self.federation_rc_concurrent = config.get("federation_rc_concurrent", 3)
-
self.federation_rr_transactions_per_room_per_second = config.get(
- "federation_rr_transactions_per_room_per_second", 50,
+ "federation_rr_transactions_per_room_per_second", 50
)
def default_config(self, **kwargs):
return """\
## Ratelimiting ##
- # Number of messages a client can send per second
- #
- #rc_messages_per_second: 0.2
-
- # Number of message a client can send before being throttled
- #
- #rc_message_burst_count: 10.0
-
- # Ratelimiting settings for registration and login.
+ # Ratelimiting settings for client actions (registration, login, messaging).
#
# Each ratelimiting configuration is made of two parameters:
# - per_second: number of requests a client can send per second.
# - burst_count: number of requests a client can send before being throttled.
#
# Synapse currently uses the following configurations:
+ # - one for messages that ratelimits sending based on the account the client
+ # is using
# - one for registration that ratelimits registration requests based on the
# client's IP address.
# - one for login that ratelimits login requests based on the client's IP
@@ -77,6 +105,10 @@ class RatelimitConfig(Config):
#
# The defaults are as shown below.
#
+ #rc_message:
+ # per_second: 0.2
+ # burst_count: 10
+ #
#rc_registration:
# per_second: 0.17
# burst_count: 3
@@ -92,29 +124,28 @@ class RatelimitConfig(Config):
# per_second: 0.17
# burst_count: 3
- # The federation window size in milliseconds
- #
- #federation_rc_window_size: 1000
-
- # The number of federation requests from a single server in a window
- # before the server will delay processing the request.
- #
- #federation_rc_sleep_limit: 10
- # The duration in milliseconds to delay processing events from
- # remote servers by if they go over the sleep limit.
+ # Ratelimiting settings for incoming federation
#
- #federation_rc_sleep_delay: 500
-
- # The maximum number of concurrent federation requests allowed
- # from a single server
+ # The rc_federation configuration is made up of the following settings:
+ # - window_size: window size in milliseconds
+ # - sleep_limit: number of federation requests from a single server in
+ # a window before the server will delay processing the request.
+ # - sleep_delay: duration in milliseconds to delay processing events
+ # from remote servers by if they go over the sleep limit.
+ # - reject_limit: maximum number of concurrent federation requests
+ # allowed from a single server
+ # - concurrent: number of federation requests to concurrently process
+ # from a single server
#
- #federation_rc_reject_limit: 50
-
- # The number of federation requests to concurrently process from a
- # single server
+ # The defaults are as shown below.
#
- #federation_rc_concurrent: 3
+ #rc_federation:
+ # window_size: 1000
+ # sleep_limit: 10
+ # sleep_delay: 500
+ # reject_limit: 50
+ # concurrent: 3
# Target outgoing federation transaction frequency for sending read-receipts,
# per-room.
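
The back-compat handling above is easiest to see side by side: the legacy flat keys and the new nested rc_message section should produce identical limits. A self-contained sketch (RateLimitConfig is re-stated here in simplified form, avoiding the mutable default argument):

```python
class RateLimitConfig(object):
    def __init__(self, config, defaults=None):
        defaults = defaults or {"per_second": 0.17, "burst_count": 3.0}
        self.per_second = config.get("per_second", defaults["per_second"])
        self.burst_count = config.get("burst_count", defaults["burst_count"])

old_style = {"rc_messages_per_second": 0.5, "rc_message_burst_count": 20.0}
new_style = {"rc_message": {"per_second": 0.5, "burst_count": 20.0}}

rc_from_old = RateLimitConfig({
    "per_second": old_style.get("rc_messages_per_second", 0.2),
    "burst_count": old_style.get("rc_message_burst_count", 10.0),
})
rc_from_new = RateLimitConfig(
    new_style["rc_message"],
    defaults={"per_second": 0.2, "burst_count": 10.0},
)

assert rc_from_old.per_second == rc_from_new.per_second == 0.5
assert rc_from_old.burst_count == rc_from_new.burst_count == 20.0
```
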
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 1309bce3ee..aad3400819 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -39,6 +39,8 @@ class AccountValidityConfig(Config):
else:
self.renew_email_subject = "Renew your %(app)s account"
+ self.startup_job_max_delta = self.period * 10. / 100.
+
if self.renew_by_email_enabled and "public_baseurl" not in synapse_config:
raise ConfigError("Can't send renewal emails without 'public_baseurl'")
@@ -123,6 +125,16 @@ class RegistrationConfig(Config):
# link. ``%%(app)s`` can be used as a placeholder for the ``app_name`` parameter
# from the ``email`` section.
#
+    # Once this feature is enabled, Synapse will look for registered users without
+    # an expiration date at startup, and will give each of them one based on the
+    # validity period configured at that time.
+    # This means that, if the validity period is later changed and Synapse is
+    # restarted, users who already have an expiration date won't have it recomputed
+    # unless their account is manually renewed. The expiration date is selected
+    # randomly within the range [now + period - d ; now + period], where d is equal
+    # to 10%% of the validity period.
+ #
#account_validity:
# enabled: True
# period: 6w
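
A sketch of the randomised expiry described above: accounts without an expiration date get one drawn uniformly from [now + period - d ; now + period], with d = 10% of the period (startup_job_max_delta). The use of random.randrange is illustrative; the actual assignment lives in the account validity handler, which is not part of this diff:

```python
import random
import time

period_ms = 6 * 7 * 24 * 3600 * 1000   # "6w" in milliseconds
max_delta = period_ms * 10. / 100.     # matches startup_job_max_delta above
now_ms = int(time.time() * 1000)

expiration_ts = random.randrange(
    int(now_ms + period_ms - max_delta),
    now_ms + period_ms,
)
assert now_ms + period_ms - max_delta <= expiration_ts < now_ms + period_ms
```
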
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 3f34ad9b2a..fbfcecc240 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -186,17 +186,21 @@ class ContentRepositoryConfig(Config):
except ImportError:
raise ConfigError(MISSING_NETADDR)
- if "url_preview_ip_range_blacklist" in config:
- self.url_preview_ip_range_blacklist = IPSet(
- config["url_preview_ip_range_blacklist"]
- )
- else:
+ if "url_preview_ip_range_blacklist" not in config:
raise ConfigError(
"For security, you must specify an explicit target IP address "
"blacklist in url_preview_ip_range_blacklist for url previewing "
"to work"
)
+ self.url_preview_ip_range_blacklist = IPSet(
+ config["url_preview_ip_range_blacklist"]
+ )
+
+ # we always blacklist '0.0.0.0' and '::', which are supposed to be
+ # unroutable addresses.
+ self.url_preview_ip_range_blacklist.update(['0.0.0.0', '::'])
+
self.url_preview_ip_range_whitelist = IPSet(
config.get("url_preview_ip_range_whitelist", ())
)
@@ -260,11 +264,12 @@ class ContentRepositoryConfig(Config):
#thumbnail_sizes:
%(formatted_thumbnail_sizes)s
- # Is the preview URL API enabled? If enabled, you *must* specify
- # an explicit url_preview_ip_range_blacklist of IPs that the spider is
- # denied from accessing.
+ # Is the preview URL API enabled?
+ #
+    # 'false' by default: uncomment the following to enable it (and specify an
+    # explicit url_preview_ip_range_blacklist).
#
- #url_preview_enabled: false
+ #url_preview_enabled: true
# List of IP address CIDR ranges that the URL preview spider is denied
# from accessing. There are no defaults: you must explicitly
@@ -274,6 +279,12 @@ class ContentRepositoryConfig(Config):
# synapse to issue arbitrary GET requests to your internal services,
# causing serious security issues.
#
+ # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+ # listed here, since they correspond to unroutable addresses.)
+ #
+ # This must be specified if url_preview_enabled is set. It is recommended that
+ # you uncomment the following list as a starting point.
+ #
#url_preview_ip_range_blacklist:
# - '127.0.0.0/8'
# - '10.0.0.0/8'
@@ -284,7 +295,7 @@ class ContentRepositoryConfig(Config):
# - '::1/128'
# - 'fe80::/64'
# - 'fc00::/7'
- #
+
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
# This is useful for specifying exceptions to wide-ranging blacklisted
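
Once built, the blacklist above is just a netaddr IPSet, so containment checks cover whole CIDR ranges as well as the always-added unroutable addresses. A short illustration:

```python
from netaddr import IPAddress, IPSet

blacklist = IPSet(["127.0.0.0/8", "10.0.0.0/8"])
blacklist.update(["0.0.0.0", "::"])          # mirrors the code above

assert IPAddress("10.1.2.3") in blacklist    # inside 10.0.0.0/8
assert IPAddress("0.0.0.0") in blacklist     # always blacklisted
assert IPAddress("1.1.1.1") not in blacklist
```
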
diff --git a/synapse/config/server.py b/synapse/config/server.py
index c5e5679d52..e763e19e15 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,6 +18,9 @@
import logging
import os.path
+from netaddr import IPSet
+
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.python_dependencies import DependencyException, check_requirements
@@ -32,6 +36,8 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ['::', '0.0.0.0']
+DEFAULT_ROOM_VERSION = "1"
+
class ServerConfig(Config):
@@ -72,6 +78,35 @@ class ServerConfig(Config):
# master, potentially causing inconsistency.
self.enable_media_repo = config.get("enable_media_repo", True)
+ # Whether to require authentication to retrieve profile data (avatars,
+ # display names) of other users through the client API.
+ self.require_auth_for_profile_requests = config.get(
+ "require_auth_for_profile_requests", False,
+ )
+
+ # If set to 'True', requires authentication to access the server's
+ # public rooms directory through the client API, and forbids any other
+ # homeserver to fetch it via federation.
+ self.restrict_public_rooms_to_local_users = config.get(
+ "restrict_public_rooms_to_local_users", False,
+ )
+
+ default_room_version = config.get(
+ "default_room_version", DEFAULT_ROOM_VERSION,
+ )
+
+ # Ensure room version is a str
+ default_room_version = str(default_room_version)
+
+ if default_room_version not in KNOWN_ROOM_VERSIONS:
+ raise ConfigError(
+ "Unknown default_room_version: %s, known room versions: %s" %
+ (default_room_version, list(KNOWN_ROOM_VERSIONS.keys()))
+ )
+
+ # Get the actual room version object rather than just the identifier
+ self.default_room_version = KNOWN_ROOM_VERSIONS[default_room_version]
+
# whether to enable search. If disabled, new entries will not be inserted
# into the search tables and they will not be indexed. Users will receive
# errors when attempting to search for messages.
@@ -85,6 +120,11 @@ class ServerConfig(Config):
"block_non_admin_invites", False,
)
+ # Whether to enable experimental MSC1849 (aka relations) support
+ self.experimental_msc1849_support_enabled = config.get(
+ "experimental_msc1849_support_enabled", False,
+ )
+
# Options to control access by tracking MAU
self.limit_usage_by_mau = config.get("limit_usage_by_mau", False)
self.max_mau_value = 0
@@ -114,14 +154,34 @@ class ServerConfig(Config):
# FIXME: federation_domain_whitelist needs sytests
self.federation_domain_whitelist = None
federation_domain_whitelist = config.get(
- "federation_domain_whitelist", None
+ "federation_domain_whitelist", None,
)
- # turn the whitelist into a hash for speed of lookup
+
if federation_domain_whitelist is not None:
+ # turn the whitelist into a hash for speed of lookup
self.federation_domain_whitelist = {}
+
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
+ self.federation_ip_range_blacklist = config.get(
+ "federation_ip_range_blacklist", [],
+ )
+
+ # Attempt to create an IPSet from the given ranges
+ try:
+ self.federation_ip_range_blacklist = IPSet(
+ self.federation_ip_range_blacklist
+ )
+
+ # Always blacklist 0.0.0.0, ::
+ self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+ except Exception as e:
+ raise ConfigError(
+ "Invalid range(s) provided in "
+ "federation_ip_range_blacklist: %s" % e
+ )
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
@@ -132,6 +192,16 @@ class ServerConfig(Config):
# sending out any replication updates.
self.replication_torture_level = config.get("replication_torture_level")
+ # Whether to require a user to be in the room to add an alias to it.
+ # Defaults to True.
+ self.require_membership_for_aliases = config.get(
+ "require_membership_for_aliases", True,
+ )
+
+        # Whether to allow per-room membership profiles, i.e. membership events
+        # whose profile information differs from the target's global profile.
+ self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
+
self.listeners = []
for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int):
@@ -259,6 +329,10 @@ class ServerConfig(Config):
unsecure_port = 8008
pid_file = os.path.join(data_dir_path, "homeserver.pid")
+
+ # Bring DEFAULT_ROOM_VERSION into the local-scope for use in the
+ # default config string
+ default_room_version = DEFAULT_ROOM_VERSION
return """\
## Server ##
@@ -319,6 +393,30 @@ class ServerConfig(Config):
#
#use_presence: false
+ # Whether to require authentication to retrieve profile data (avatars,
+ # display names) of other users through the client API. Defaults to
+ # 'false'. Note that profile data is also available via the federation
+ # API, so this setting is of limited value if federation is enabled on
+ # the server.
+ #
+ #require_auth_for_profile_requests: true
+
+ # If set to 'true', requires authentication to access the server's
+ # public rooms directory through the client API, and forbids any other
+ # homeserver to fetch it via federation. Defaults to 'false'.
+ #
+ #restrict_public_rooms_to_local_users: true
+
+ # The default room version for newly created rooms.
+ #
+ # Known room versions are listed here:
+ # https://matrix.org/docs/spec/#complete-list-of-room-versions
+ #
+ # For example, for room version 1, default_room_version should be set
+ # to "1".
+ #
+ #default_room_version: "%(default_room_version)s"
+
# The GC threshold parameters to pass to `gc.set_threshold`, if defined
#
#gc_thresholds: [700, 10, 10]
@@ -351,6 +449,24 @@ class ServerConfig(Config):
# - nyc.example.com
# - syd.example.com
+  # Prevent federation requests from being sent to the following blacklisted
+  # IP address CIDR ranges. If this option is not specified, or specified
+  # with an empty list, no IP range blacklist will be enforced.
+ #
+ # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+ # listed here, since they correspond to unroutable addresses.)
+ #
+ federation_ip_range_blacklist:
+ - '127.0.0.0/8'
+ - '10.0.0.0/8'
+ - '172.16.0.0/12'
+ - '192.168.0.0/16'
+ - '100.64.0.0/10'
+ - '169.254.0.0/16'
+ - '::1/128'
+ - 'fe80::/64'
+ - 'fc00::/7'
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
@@ -386,8 +502,8 @@ class ServerConfig(Config):
#
# Valid resource names are:
#
- # client: the client-server API (/_matrix/client). Also implies 'media' and
- # 'static'.
+ # client: the client-server API (/_matrix/client), and the synapse admin
+ # API (/_synapse/admin). Also implies 'media' and 'static'.
#
# consent: user consent forms (/_matrix/consent). See
# docs/consent_tracking.md.
@@ -488,6 +604,17 @@ class ServerConfig(Config):
# Used by phonehome stats to group together related servers.
#server_context: context
+
+ # Whether to require a user to be in the room to add an alias to it.
+ # Defaults to 'true'.
+ #
+ #require_membership_for_aliases: false
+
+  # Whether to allow per-room membership profiles, i.e. membership events whose
+  # profile information differs from the target's global profile.
+ # Defaults to 'true'.
+ #
+ #allow_per_room_profiles: false
""" % locals()
def read_arguments(self, args):
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
new file mode 100644
index 0000000000..80fc1b9dd0
--- /dev/null
+++ b/synapse/config/stats.py
@@ -0,0 +1,60 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import sys
+
+from ._base import Config
+
+
+class StatsConfig(Config):
+ """Stats Configuration
+ Configuration for the behaviour of synapse's stats engine
+ """
+
+ def read_config(self, config):
+ self.stats_enabled = True
+ self.stats_bucket_size = 86400
+ self.stats_retention = sys.maxsize
+ stats_config = config.get("stats", None)
+ if stats_config:
+ self.stats_enabled = stats_config.get("enabled", self.stats_enabled)
+ self.stats_bucket_size = (
+ self.parse_duration(stats_config.get("bucket_size", "1d")) / 1000
+ )
+ self.stats_retention = (
+ self.parse_duration(
+ stats_config.get("retention", "%ds" % (sys.maxsize,))
+ )
+ / 1000
+ )
+
+ def default_config(self, config_dir_path, server_name, **kwargs):
+ return """
+ # Local statistics collection. Used in populating the room directory.
+ #
+ # 'bucket_size' controls how large each statistics timeslice is. It can
+ # be defined in a human readable short form -- e.g. "1d", "1y".
+ #
+ # 'retention' controls how long historical statistics will be kept for.
+ # It can be defined in a human readable short form -- e.g. "1d", "1y".
+    #
+ #stats:
+ # enabled: true
+ # bucket_size: 1d
+ # retention: 1y
+ """
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index f0014902da..72dd5926f9 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -24,8 +24,10 @@ import six
from unpaddedbase64 import encode_base64
from OpenSSL import crypto
+from twisted.internet._sslverify import Certificate, trustRootFromCertificates
from synapse.config._base import Config, ConfigError
+from synapse.util import glob_to_regex
logger = logging.getLogger(__name__)
@@ -70,6 +72,53 @@ class TlsConfig(Config):
self.tls_fingerprints = list(self._original_tls_fingerprints)
+ # Whether to verify certificates on outbound federation traffic
+ self.federation_verify_certificates = config.get(
+ "federation_verify_certificates", False,
+ )
+
+ # Whitelist of domains to not verify certificates for
+ fed_whitelist_entries = config.get(
+ "federation_certificate_verification_whitelist", [],
+ )
+
+ # Support globs (*) in whitelist values
+ self.federation_certificate_verification_whitelist = []
+ for entry in fed_whitelist_entries:
+ # Convert globs to regex
+ entry_regex = glob_to_regex(entry)
+ self.federation_certificate_verification_whitelist.append(entry_regex)
+
+ # List of custom certificate authorities for federation traffic validation
+ custom_ca_list = config.get(
+ "federation_custom_ca_list", None,
+ )
+
+ # Read in and parse custom CA certificates
+ self.federation_ca_trust_root = None
+ if custom_ca_list is not None:
+ if len(custom_ca_list) == 0:
+ # A trustroot cannot be generated without any CA certificates.
+ # Raise an error if this option has been specified without any
+ # corresponding certificates.
+ raise ConfigError("federation_custom_ca_list specified without "
+ "any certificate files")
+
+ certs = []
+ for ca_file in custom_ca_list:
+ logger.debug("Reading custom CA certificate file: %s", ca_file)
+ content = self.read_file(ca_file)
+
+ # Parse the CA certificates
+ try:
+ cert_base = Certificate.loadPEM(content)
+ certs.append(cert_base)
+ except Exception as e:
+ raise ConfigError("Error parsing custom CA certificate file %s: %s"
+ % (ca_file, e))
+
+ self.federation_ca_trust_root = trustRootFromCertificates(certs)
+
# This config option applies to non-federation HTTP clients
# (e.g. for talking to recaptcha, identity servers, and such)
# It should never be used in production, and is intended for
@@ -99,15 +148,15 @@ class TlsConfig(Config):
try:
with open(self.tls_certificate_file, 'rb') as f:
cert_pem = f.read()
- except Exception:
- logger.exception("Failed to read existing certificate off disk!")
- raise
+ except Exception as e:
+ raise ConfigError("Failed to read existing certificate file %s: %s"
+ % (self.tls_certificate_file, e))
try:
tls_certificate = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
- except Exception:
- logger.exception("Failed to parse existing certificate off disk!")
- raise
+ except Exception as e:
+ raise ConfigError("Failed to parse existing certificate file %s: %s"
+ % (self.tls_certificate_file, e))
if not allow_self_signed:
if tls_certificate.get_subject() == tls_certificate.get_issuer():
@@ -192,6 +241,40 @@ class TlsConfig(Config):
#
#tls_private_key_path: "%(tls_private_key_path)s"
+ # Whether to verify TLS certificates when sending federation traffic.
+ #
+ # This currently defaults to `false`, however this will change in
+ # Synapse 1.0 when valid federation certificates will be required.
+ #
+ #federation_verify_certificates: true
+
+ # Skip federation certificate verification on the following whitelist
+ # of domains.
+ #
+ # This setting should only be used in very specific cases, such as
+ # federation over Tor hidden services and similar. For private networks
+ # of homeservers, you likely want to use a private CA instead.
+ #
+    # Only effective if federation_verify_certificates is `true`.
+ #
+ #federation_certificate_verification_whitelist:
+ # - lon.example.com
+    # - "*.domain.com"
+    # - "*.onion"
+
+ # List of custom certificate authorities for federation traffic.
+ #
+ # This setting should only normally be used within a private network of
+ # homeservers.
+ #
+ # Note that this list will replace those that are provided by your
+ # operating environment. Certificates must be in PEM format.
+ #
+ #federation_custom_ca_list:
+ # - myCA1.pem
+ # - myCA2.pem
+ # - myCA3.pem
+
# ACME support: This will configure Synapse to request a valid TLS certificate
# for your configured `server_name` via Let's Encrypt.
#
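
The custom CA handling above can be exercised standalone via twisted's public ssl API (the diff imports the same helpers from twisted.internet._sslverify). The PEM paths below are hypothetical:

```python
from twisted.internet.ssl import (
    Certificate,
    CertificateOptions,
    trustRootFromCertificates,
)

ca_paths = ["myCA1.pem", "myCA2.pem"]        # hypothetical files
certs = []
for path in ca_paths:
    with open(path) as f:
        certs.append(Certificate.loadPEM(f.read()))

# the resulting trust root is what context_factory.py feeds into
# CertificateOptions for verified federation connections
trust_root = trustRootFromCertificates(certs)
options = CertificateOptions(trustRoot=trust_root)
```
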
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index 142754a7dc..023997ccde 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -43,9 +43,9 @@ class UserDirectoryConfig(Config):
#
# 'search_all_users' defines whether to search all users visible to your HS
# when searching the user directory, rather than limiting to users visible
- # in public rooms. Defaults to false. If you set it True, you'll have to run
- # UPDATE user_directory_stream_pos SET stream_id = NULL;
- # on your database to tell it to rebuild the user_directory search indexes.
+ # in public rooms. Defaults to false. If you set it True, you'll have to
+ # rebuild the user_directory search indexes, see
+ # https://github.com/matrix-org/synapse/blob/master/docs/user_directory.md
#
#user_directory:
# enabled: true
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index 49cbc7098f..59ea087e66 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -18,10 +18,10 @@ import logging
from zope.interface import implementer
from OpenSSL import SSL, crypto
-from twisted.internet._sslverify import _defaultCurveName
+from twisted.internet._sslverify import ClientTLSOptions, _defaultCurveName
from twisted.internet.abstract import isIPAddress, isIPv6Address
from twisted.internet.interfaces import IOpenSSLClientConnectionCreator
-from twisted.internet.ssl import CertificateOptions, ContextFactory
+from twisted.internet.ssl import CertificateOptions, ContextFactory, platformTrust
from twisted.python.failure import Failure
logger = logging.getLogger(__name__)
@@ -90,7 +90,7 @@ def _tolerateErrors(wrapped):
@implementer(IOpenSSLClientConnectionCreator)
-class ClientTLSOptions(object):
+class ClientTLSOptionsNoVerify(object):
"""
Client creator for TLS without certificate identity verification. This is a
copy of twisted.internet._sslverify.ClientTLSOptions with the identity
@@ -127,9 +127,30 @@ class ClientTLSOptionsFactory(object):
to remote servers for federation."""
def __init__(self, config):
- # We don't use config options yet
- self._options = CertificateOptions(verify=False)
+ self._config = config
+ self._options_noverify = CertificateOptions()
+
+ # Check if we're using a custom list of a CA certificates
+ trust_root = config.federation_ca_trust_root
+ if trust_root is None:
+ # Use CA root certs provided by OpenSSL
+ trust_root = platformTrust()
+
+ self._options_verify = CertificateOptions(trustRoot=trust_root)
def get_options(self, host):
# Use _makeContext so that we get a fresh OpenSSL CTX each time.
- return ClientTLSOptions(host, self._options._makeContext())
+
+ # Check if certificate verification has been enabled
+ should_verify = self._config.federation_verify_certificates
+
+ # Check if we've disabled certificate verification for this host
+ if should_verify:
+ for regex in self._config.federation_certificate_verification_whitelist:
+ if regex.match(host):
+ should_verify = False
+ break
+
+ if should_verify:
+ return ClientTLSOptions(host, self._options_verify._makeContext())
+ return ClientTLSOptionsNoVerify(host, self._options_noverify._makeContext())
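
The per-host decision in get_options reduces to: verify unless the host matches a whitelist glob. glob_to_regex is Synapse's own helper; fnmatch.translate is used below to keep the sketch self-contained:

```python
import re
from fnmatch import translate

whitelist = [
    re.compile(translate(glob))
    for glob in ["lon.example.com", "*.onion"]
]

def should_verify(host, federation_verify_certificates=True):
    if not federation_verify_certificates:
        return False
    return not any(regex.match(host) for regex in whitelist)

assert should_verify("matrix.org")
assert not should_verify("abcdef.onion")
assert not should_verify("anything", federation_verify_certificates=False)
```
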
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 1dfa727fcf..99a586655b 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -31,7 +31,11 @@ logger = logging.getLogger(__name__)
def check_event_content_hash(event, hash_algorithm=hashlib.sha256):
"""Check whether the hash for this PDU matches the contents"""
name, expected_hash = compute_content_hash(event.get_pdu_json(), hash_algorithm)
- logger.debug("Expecting hash: %s", encode_base64(expected_hash))
+ logger.debug(
+ "Verifying content hash on %s (expecting: %s)",
+ event.event_id,
+ encode_base64(expected_hash),
+ )
# some malformed events lack a 'hashes'. Protect against it being missing
# or a weird type by basically treating it the same as an unhashed event.
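
For reference, the content hash being checked above is, per the spec, a SHA-256 over the event's canonical JSON with the signatures, unsigned and hashes fields removed. A simplified stand-in (real canonicalisation uses the canonicaljson library; sorted-key json is a close approximation for illustration):

```python
import hashlib
import json

def simplified_content_hash(pdu_json):
    pdu = dict(pdu_json)
    for field in ("signatures", "unsigned", "hashes"):
        pdu.pop(field, None)
    canonical = json.dumps(pdu, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).digest()

digest = simplified_content_hash(
    {"type": "m.room.message", "content": {"body": "hi"}, "unsigned": {}}
)
assert len(digest) == 32  # this is what gets base64-encoded in the log line
```
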
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index ed2e994437..e94e71bdad 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,12 +15,13 @@
# limitations under the License.
import logging
-from collections import namedtuple
+from collections import defaultdict
+import six
from six import raise_from
from six.moves import urllib
-import nacl.signing
+import attr
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
@@ -43,7 +44,9 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.logcontext import (
LoggingContext,
PreserveLoggingContext,
@@ -56,22 +59,36 @@ from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
-VerifyKeyRequest = namedtuple("VerifyRequest", (
- "server_name", "key_ids", "json_object", "deferred"
-))
-"""
-A request for a verify key to verify a JSON object.
+@attr.s(slots=True, cmp=False)
+class VerifyKeyRequest(object):
+ """
+ A request for a verify key to verify a JSON object.
+
+ Attributes:
+ server_name(str): The name of the server to verify against.
+
+        key_ids(set[str]): The set of key_ids that could be used to verify the
+ JSON object
-Attributes:
- server_name(str): The name of the server to verify against.
- key_ids(set(str)): The set of key_ids to that could be used to verify the
- JSON object
- json_object(dict): The JSON object to verify.
- deferred(Deferred[str, str, nacl.signing.VerifyKey]):
- A deferred (server_name, key_id, verify_key) tuple that resolves when
- a verify key has been fetched. The deferreds' callbacks are run with no
- logcontext.
-"""
+ json_object(dict): The JSON object to verify.
+
+ minimum_valid_until_ts (int): time at which we require the signing key to
+ be valid. (0 implies we don't care)
+
+ key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
+ A deferred (server_name, key_id, verify_key) tuple that resolves when
+ a verify key has been fetched. The deferreds' callbacks are run with no
+ logcontext.
+
+ If we are unable to find a key which satisfies the request, the deferred
+ errbacks with an M_UNAUTHORIZED SynapseError.
+ """
+
+ server_name = attr.ib()
+ key_ids = attr.ib()
+ json_object = attr.ib()
+ minimum_valid_until_ts = attr.ib()
+ key_ready = attr.ib(default=attr.Factory(defer.Deferred))
class KeyLookupError(ValueError):
@@ -79,13 +96,16 @@ class KeyLookupError(ValueError):
class Keyring(object):
- def __init__(self, hs):
- self.store = hs.get_datastore()
+ def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock()
- self.client = hs.get_http_client()
- self.config = hs.get_config()
- self.perspective_servers = self.config.perspectives
- self.hs = hs
+
+ if key_fetchers is None:
+ key_fetchers = (
+ StoreKeyFetcher(hs),
+ PerspectivesKeyFetcher(hs),
+ ServerKeyFetcher(hs),
+ )
+ self._key_fetchers = key_fetchers
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
@@ -94,11 +114,25 @@ class Keyring(object):
# These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {}
- def verify_json_for_server(self, server_name, json_object):
+ def verify_json_for_server(self, server_name, json_object, validity_time):
+ """Verify that a JSON object has been signed by a given server
+
+ Args:
+ server_name (str): name of the server which must have signed this object
+
+ json_object (dict): object to be checked
+
+ validity_time (int): timestamp at which we require the signing key to
+ be valid. (0 implies we don't care)
+
+ Returns:
+            Deferred[None]: completes if the object was correctly signed, otherwise
+ errbacks with an error
+ """
+ req = server_name, json_object, validity_time
+
return logcontext.make_deferred_yieldable(
- self.verify_json_objects_for_server(
- [(server_name, json_object)]
- )[0]
+ self.verify_json_objects_for_server((req,))[0]
)
def verify_json_objects_for_server(self, server_and_json):
@@ -106,53 +140,71 @@ class Keyring(object):
necessary.
Args:
- server_and_json (list): List of pairs of (server_name, json_object)
+            server_and_json (iterable[Tuple[str, dict, int]]):
+ Iterable of triplets of (server_name, json_object, validity_time)
+ validity_time is a timestamp at which the signing key must be valid.
Returns:
- List<Deferred>: for each input pair, a deferred indicating success
+ List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
"""
+ # a list of VerifyKeyRequests
verify_requests = []
+ handle = preserve_fn(_handle_key_deferred)
- for server_name, json_object in server_and_json:
+ def process(server_name, json_object, validity_time):
+ """Process an entry in the request list
+ Given a (server_name, json_object, validity_time) triplet from the request
+ list, adds a key request to verify_requests, and returns a deferred which
+ will complete or fail (in the sentinel context) when verification completes.
+ """
key_ids = signature_ids(json_object, server_name)
+
if not key_ids:
- logger.warn("Request from %s: no supported signature keys",
- server_name)
- deferred = defer.fail(SynapseError(
- 400,
- "Not signed with a supported algorithm",
- Codes.UNAUTHORIZED,
- ))
- else:
- deferred = defer.Deferred()
-
- logger.debug("Verifying for %s with key_ids %s",
- server_name, key_ids)
+ return defer.fail(
+ SynapseError(
+ 400, "Not signed by %s" % (server_name,), Codes.UNAUTHORIZED
+ )
+ )
- verify_request = VerifyKeyRequest(
- server_name, key_ids, json_object, deferred
+ logger.debug(
+ "Verifying for %s with key_ids %s, min_validity %i",
+ server_name,
+ key_ids,
+ validity_time,
)
+ # add the key request to the queue, but don't start it off yet.
+ verify_request = VerifyKeyRequest(
+ server_name, key_ids, json_object, validity_time
+ )
verify_requests.append(verify_request)
- run_in_background(self._start_key_lookups, verify_requests)
+ # now run _handle_key_deferred, which will wait for the key request
+ # to complete and then do the verification.
+ #
+            # We want _handle_key_deferred to log to the right context, so we
+ # wrap it with preserve_fn (aka run_in_background)
+ return handle(verify_request)
- # Pass those keys to handle_key_deferred so that the json object
- # signatures can be verified
- handle = preserve_fn(_handle_key_deferred)
- return [
- handle(rq) for rq in verify_requests
+ results = [
+ process(server_name, json_object, validity_time)
+ for server_name, json_object, validity_time in server_and_json
]
+ if verify_requests:
+ run_in_background(self._start_key_lookups, verify_requests)
+
+ return results
+
@defer.inlineCallbacks
def _start_key_lookups(self, verify_requests):
"""Sets off the key fetches for each verify request
- Once each fetch completes, verify_request.deferred will be resolved.
+ Once each fetch completes, verify_request.key_ready will be resolved.
Args:
verify_requests (List[VerifyKeyRequest]):
@@ -165,16 +217,12 @@ class Keyring(object):
# any other lookups until we have finished.
# The deferreds are called with no logcontext.
server_to_deferred = {
- rq.server_name: defer.Deferred()
- for rq in verify_requests
+ rq.server_name: defer.Deferred() for rq in verify_requests
}
# We want to wait for any previous lookups to complete before
# proceeding.
- yield self.wait_for_previous_lookups(
- [rq.server_name for rq in verify_requests],
- server_to_deferred,
- )
+ yield self.wait_for_previous_lookups(server_to_deferred)
# Actually start fetching keys.
self._get_server_verify_keys(verify_requests)
@@ -202,19 +250,16 @@ class Keyring(object):
return res
for verify_request in verify_requests:
- verify_request.deferred.addBoth(
- remove_deferreds, verify_request,
- )
+ verify_request.key_ready.addBoth(remove_deferreds, verify_request)
except Exception:
logger.exception("Error starting key lookups")
@defer.inlineCallbacks
- def wait_for_previous_lookups(self, server_names, server_to_deferred):
+ def wait_for_previous_lookups(self, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish.
Args:
- server_names (list): list of server_names we want to lookup
- server_to_deferred (dict): server_name to deferred which gets
+ server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
resolved once we've finished looking up keys for that server.
The Deferreds should be regular twisted ones which call their
callbacks with no logcontext.
@@ -227,14 +272,15 @@ class Keyring(object):
while True:
wait_on = [
(server_name, self.key_downloads[server_name])
- for server_name in server_names
+ for server_name in server_to_deferred.keys()
if server_name in self.key_downloads
]
if not wait_on:
break
logger.info(
"Waiting for existing lookups for %s to complete [loop %i]",
- [w[0] for w in wait_on], loop_count,
+ [w[0] for w in wait_on],
+ loop_count,
)
with PreserveLoggingContext():
yield defer.DeferredList((w[1] for w in wait_on))
@@ -257,7 +303,7 @@ class Keyring(object):
def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
- For each verify_request, verify_request.deferred is called back with
+ For each verify_request, verify_request.key_ready is called back with
params (server_name, key_id, VerifyKey) if a key is found, or errbacked
with a SynapseError if none of the keys are found.
@@ -265,129 +311,290 @@ class Keyring(object):
verify_requests (list[VerifyKeyRequest]): list of verify requests
"""
- # These are functions that produce keys given a list of key ids
- key_fetch_fns = (
- self.get_keys_from_store, # First try the local store
- self.get_keys_from_perspectives, # Then try via perspectives
- self.get_keys_from_server, # Then try directly
+ remaining_requests = set(
+ rq for rq in verify_requests if not rq.key_ready.called
)
@defer.inlineCallbacks
def do_iterations():
with Measure(self.clock, "get_server_verify_keys"):
- # dict[str, set(str)]: keys to fetch for each server
- missing_keys = {}
- for verify_request in verify_requests:
- missing_keys.setdefault(verify_request.server_name, set()).update(
- verify_request.key_ids
- )
-
- for fn in key_fetch_fns:
- results = yield fn(missing_keys.items())
-
- # We now need to figure out which verify requests we have keys
- # for and which we don't
- missing_keys = {}
- requests_missing_keys = []
- for verify_request in verify_requests:
- if verify_request.deferred.called:
- # We've already called this deferred, which probably
- # means that we've already found a key for it.
- continue
-
- server_name = verify_request.server_name
-
- # see if any of the keys we got this time are sufficient to
- # complete this VerifyKeyRequest.
- result_keys = results.get(server_name, {})
- for key_id in verify_request.key_ids:
- key = result_keys.get(key_id)
- if key:
- with PreserveLoggingContext():
- verify_request.deferred.callback(
- (server_name, key_id, key)
- )
- break
- else:
- # The else block is only reached if the loop above
- # doesn't break.
- missing_keys.setdefault(server_name, set()).update(
- verify_request.key_ids
- )
- requests_missing_keys.append(verify_request)
-
- if not missing_keys:
- break
+ for f in self._key_fetchers:
+ if not remaining_requests:
+ return
+ yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
+ # look for any requests which weren't satisfied
with PreserveLoggingContext():
- for verify_request in requests_missing_keys:
- verify_request.deferred.errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- verify_request.server_name, verify_request.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ for verify_request in remaining_requests:
+ verify_request.key_ready.errback(
+ SynapseError(
+ 401,
+ "No key for %s with ids in %s (min_validity %i)"
+ % (
+ verify_request.server_name,
+ verify_request.key_ids,
+ verify_request.minimum_valid_until_ts,
+ ),
+ Codes.UNAUTHORIZED,
+ )
+ )
def on_err(err):
+ # we don't really expect to get here, because any errors should already
+ # have been caught and logged. But if we do, let's log the error and make
+ # sure that all of the deferreds are resolved.
+ logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext():
- for verify_request in verify_requests:
- if not verify_request.deferred.called:
- verify_request.deferred.errback(err)
+ for verify_request in remaining_requests:
+ if not verify_request.key_ready.called:
+ verify_request.key_ready.errback(err)
run_in_background(do_iterations).addErrback(on_err)
@defer.inlineCallbacks
- def get_keys_from_store(self, server_name_and_key_ids):
+ def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+ """Use a key fetcher to attempt to satisfy some key requests
+
+ Args:
+ fetcher (KeyFetcher): fetcher to use to fetch the keys
+ remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
+ Any successfully-completed requests will be removed from the list.
+ """
+ # dict[str, dict[str, int]]: keys to fetch.
+ # server_name -> key_id -> min_valid_ts
+ missing_keys = defaultdict(dict)
+
+ for verify_request in remaining_requests:
+ # any completed requests should already have been removed
+ assert not verify_request.key_ready.called
+ keys_for_server = missing_keys[verify_request.server_name]
+
+ for key_id in verify_request.key_ids:
+ # If we have several requests for the same key, then we only need to
+ # request that key once, but we should do so with the greatest
+ # min_valid_until_ts of the requests, so that we can satisfy all of
+ # the requests.
+ keys_for_server[key_id] = max(
+ keys_for_server.get(key_id, -1),
+ verify_request.minimum_valid_until_ts
+ )
+
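+ # missing_keys now maps server_name -> key_id -> the greatest
+ # minimum_valid_until_ts, e.g. (illustrative values only):
+ # {"example.com": {"ed25519:abc123": 1558000000000}}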
+ results = yield fetcher.get_keys(missing_keys)
+
+ completed = []
+ for verify_request in remaining_requests:
+ server_name = verify_request.server_name
+
+ # see if any of the keys we got this time are sufficient to
+ # complete this VerifyKeyRequest.
+ result_keys = results.get(server_name, {})
+ for key_id in verify_request.key_ids:
+ fetch_key_result = result_keys.get(key_id)
+ if not fetch_key_result:
+ # we didn't get a result for this key
+ continue
+
+ if (
+ fetch_key_result.valid_until_ts
+ < verify_request.minimum_valid_until_ts
+ ):
+ # key was not valid at this point
+ continue
+
+ with PreserveLoggingContext():
+ verify_request.key_ready.callback(
+ (server_name, key_id, fetch_key_result.verify_key)
+ )
+ completed.append(verify_request)
+ break
+
+ remaining_requests.difference_update(completed)
+
+
+class KeyFetcher(object):
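+ # Abstract base for the fetch strategies defined below: StoreKeyFetcher
+ # (our data store), PerspectivesKeyFetcher (notary servers) and
+ # ServerKeyFetcher (the origin server itself).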
+ def get_keys(self, keys_to_fetch):
"""
Args:
- server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
- list of (server_name, iterable[key_id]) tuples to fetch keys for
+ keys_to_fetch (dict[str, dict[str, int]]):
+ the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns:
- Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
- server_name -> key_id -> VerifyKey
+ Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+ map from server_name -> key_id -> FetchKeyResult
"""
+ raise NotImplementedError
+
+
+class StoreKeyFetcher(KeyFetcher):
+ """KeyFetcher impl which fetches keys from our data store"""
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def get_keys(self, keys_to_fetch):
+ """see KeyFetcher.get_keys"""
+
keys_to_fetch = (
(server_name, key_id)
- for server_name, key_ids in server_name_and_key_ids
- for key_id in key_ids
+ for server_name, keys_for_server in keys_to_fetch.items()
+ for key_id in keys_for_server.keys()
)
+
res = yield self.store.get_server_verify_keys(keys_to_fetch)
keys = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
defer.returnValue(keys)
+
+class BaseV2KeyFetcher(object):
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.config = hs.get_config()
+
@defer.inlineCallbacks
- def get_keys_from_perspectives(self, server_name_and_key_ids):
+ def process_v2_response(
+ self, from_server, response_json, time_added_ms
+ ):
+ """Parse a 'Server Keys' structure from the result of a /key request
+
+ This is used to parse either the entirety of the response from
+ GET /_matrix/key/v2/server, or a single entry from the list returned by
+ POST /_matrix/key/v2/query.
+
+ Checks that each signature in the response that claims to come from the origin
+ server is valid, and that there is at least one such signature.
+
+ Stores the json in server_keys_json so that it can be used for future responses
+ to /_matrix/key/v2/query.
+
+ Args:
+ from_server (str): the name of the server producing this result: either
+ the origin server for a /_matrix/key/v2/server request, or the notary
+ for a /_matrix/key/v2/query.
+
+ response_json (dict): the json-decoded Server Keys response object
+
+ time_added_ms (int): the timestamp to record in server_keys_json
+
+ Returns:
+ Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+ """
+ ts_valid_until_ms = response_json[u"valid_until_ts"]
+
+ # start by extracting the keys from the response, since they may be required
+ # to validate the signature on the response.
+ verify_keys = {}
+ for key_id, key_data in response_json["verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_keys[key_id] = FetchKeyResult(
+ verify_key=verify_key, valid_until_ts=ts_valid_until_ms
+ )
+
+ server_name = response_json["server_name"]
+ verified = False
+ for key_id in response_json["signatures"].get(server_name, {}):
+ # each of the keys used for the signature must be present in the response
+ # json.
+ key = verify_keys.get(key_id)
+ if not key:
+ raise KeyLookupError(
+ "Key response is signed by key id %s:%s but that key is not "
+ "present in the response" % (server_name, key_id)
+ )
+
+ verify_signed_json(response_json, server_name, key.verify_key)
+ verified = True
+
+ if not verified:
+ raise KeyLookupError(
+ "Key response for %s is not signed by the origin server"
+ % (server_name,)
+ )
+
+ for key_id, key_data in response_json["old_verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_keys[key_id] = FetchKeyResult(
+ verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
+ )
+
+ # re-sign the json with our own key, so that it is ready if we are asked to
+ # give it out as a notary server
+ signed_key_json = sign_json(
+ response_json, self.config.server_name, self.config.signing_key[0]
+ )
+
+ signed_key_json_bytes = encode_canonical_json(signed_key_json)
+
+ yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self.store.store_server_keys_json,
+ server_name=server_name,
+ key_id=key_id,
+ from_server=from_server,
+ ts_now_ms=time_added_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
+ )
+ for key_id in verify_keys
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
+
+ defer.returnValue(verify_keys)
+
+
+class PerspectivesKeyFetcher(BaseV2KeyFetcher):
+ """KeyFetcher impl which fetches keys from the "perspectives" servers"""
+
+ def __init__(self, hs):
+ super(PerspectivesKeyFetcher, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+ self.perspective_servers = self.config.perspectives
+
+ @defer.inlineCallbacks
+ def get_keys(self, keys_to_fetch):
+ """see KeyFetcher.get_keys"""
+
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
- server_name_and_key_ids, perspective_name, perspective_keys
+ keys_to_fetch, perspective_name, perspective_keys
)
defer.returnValue(result)
except KeyLookupError as e:
- logger.warning(
- "Key lookup failed from %r: %s", perspective_name, e,
- )
+ logger.warning("Key lookup failed from %r: %s", perspective_name, e)
except Exception as e:
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
- type(e).__name__, str(e),
+ type(e).__name__,
+ str(e),
)
defer.returnValue({})
- results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(get_key, p_name, p_keys)
- for p_name, p_keys in self.perspective_servers.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
+ results = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(get_key, p_name, p_keys)
+ for p_name, p_keys in self.perspective_servers.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
union_of_keys = {}
for result in results:
@@ -397,36 +604,33 @@ class Keyring(object):
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
- def get_keys_from_server(self, server_name_and_key_ids):
- results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.get_server_verify_key_v2_direct,
- server_name,
- key_ids,
- )
- for server_name, key_ids in server_name_and_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
+ def get_server_verify_key_v2_indirect(
+ self, keys_to_fetch, perspective_name, perspective_keys
+ ):
+ """
+ Args:
+ keys_to_fetch (dict[str, dict[str, int]]):
+ the keys to be fetched. server_name -> key_id -> min_valid_ts
- merged = {}
- for result in results:
- merged.update(result)
+ perspective_name (str): name of the notary server to query for the keys
- defer.returnValue({
- server_name: keys
- for server_name, keys in merged.items()
- if keys
- })
+ perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
+ notary server
+
+ Returns:
+ Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
+ from server_name -> key_id -> FetchKeyResult
+
+ Raises:
+ KeyLookupError if there was an error processing the entire response from
+ the server
+ """
+ logger.info(
+ "Requesting keys %s from notary server %s",
+ keys_to_fetch.items(),
+ perspective_name,
+ )
- @defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
- perspective_name,
- perspective_keys):
- # TODO(mark): Set the minimum_valid_until_ts to that needed by
- # the events being validated or the current time if validating
- # an incoming request.
try:
query_response = yield self.client.post_json(
destination=perspective_name,
@@ -434,249 +638,213 @@ class Keyring(object):
data={
u"server_keys": {
server_name: {
- key_id: {
- u"minimum_valid_until_ts": 0
- } for key_id in key_ids
+ key_id: {u"minimum_valid_until_ts": min_valid_ts}
+ for key_id, min_valid_ts in server_keys.items()
}
- for server_name, key_ids in server_names_and_key_ids
+ for server_name, server_keys in keys_to_fetch.items()
}
},
- long_retries=True,
)
except (NotRetryingDestination, RequestSendFailed) as e:
- raise_from(
- KeyLookupError("Failed to connect to remote server"), e,
- )
+ raise_from(KeyLookupError("Failed to connect to remote server"), e)
except HttpResponseException as e:
- raise_from(
- KeyLookupError("Remote server returned an error"), e,
- )
+ raise_from(KeyLookupError("Remote server returned an error"), e)
keys = {}
+ added_keys = []
- responses = query_response["server_keys"]
+ time_now_ms = self.clock.time_msec()
- for response in responses:
- if (u"signatures" not in response
- or perspective_name not in response[u"signatures"]):
+ for response in query_response["server_keys"]:
+ # do this first, so that we can give useful errors thereafter
+ server_name = response.get("server_name")
+ if not isinstance(server_name, six.string_types):
raise KeyLookupError(
- "Key response not signed by perspective server"
- " %r" % (perspective_name,)
+ "Malformed response from key notary server %s: invalid server_name"
+ % (perspective_name,)
)
- verified = False
- for key_id in response[u"signatures"][perspective_name]:
- if key_id in perspective_keys:
- verify_signed_json(
- response,
- perspective_name,
- perspective_keys[key_id]
- )
- verified = True
-
- if not verified:
- logging.info(
- "Response from perspective server %r not signed with a"
- " known key, signed with: %r, known keys: %r",
+ try:
+ processed_response = yield self._process_perspectives_response(
perspective_name,
- list(response[u"signatures"][perspective_name]),
- list(perspective_keys)
+ perspective_keys,
+ response,
+ time_added_ms=time_now_ms,
)
- raise KeyLookupError(
- "Response not signed with a known key for perspective"
- " server %r" % (perspective_name,)
+ except KeyLookupError as e:
+ logger.warning(
+ "Error processing response from key notary server %s for origin "
+ "server %s: %s",
+ perspective_name,
+ server_name,
+ e,
)
+ # we continue to process the rest of the response
+ continue
- processed_response = yield self.process_v2_response(
- perspective_name, response
+ added_keys.extend(
+ (server_name, key_id, key) for key_id, key in processed_response.items()
)
- server_name = response["server_name"]
-
keys.setdefault(server_name, {}).update(processed_response)
- yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store_keys,
- server_name=server_name,
- from_server=perspective_name,
- verify_keys=response_keys,
- )
- for server_name, response_keys in keys.items()
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError))
+ yield self.store.store_server_verify_keys(
+ perspective_name, time_now_ms, added_keys
+ )
defer.returnValue(keys)
- @defer.inlineCallbacks
- def get_server_verify_key_v2_direct(self, server_name, key_ids):
- keys = {} # type: dict[str, nacl.signing.VerifyKey]
+ def _process_perspectives_response(
+ self, perspective_name, perspective_keys, response, time_added_ms
+ ):
+ """Parse a 'Server Keys' structure from the result of a /key/query request
- for requested_key_id in key_ids:
- if requested_key_id in keys:
- continue
+ Checks that the entry is correctly signed by the perspectives server, and then
+ passes over to process_v2_response
- try:
- response = yield self.client.get_json(
- destination=server_name,
- path="/_matrix/key/v2/server/" + urllib.parse.quote(requested_key_id),
- ignore_backoff=True,
- )
- except (NotRetryingDestination, RequestSendFailed) as e:
- raise_from(
- KeyLookupError("Failed to connect to remote server"), e,
- )
- except HttpResponseException as e:
- raise_from(
- KeyLookupError("Remote server returned an error"), e,
- )
+ Args:
+ perspective_name (str): the name of the notary server that produced this
+ result
- if (u"signatures" not in response
- or server_name not in response[u"signatures"]):
- raise KeyLookupError("Key response not signed by remote server")
+ perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
+ notary server
- if response["server_name"] != server_name:
- raise KeyLookupError("Expected a response for server %r not %r" % (
- server_name, response["server_name"]
- ))
+ response (dict): the json-decoded Server Keys response object
- response_keys = yield self.process_v2_response(
- from_server=server_name,
- requested_ids=[requested_key_id],
- response_json=response,
- )
+ time_added_ms (int): the timestamp to record in server_keys_json
- keys.update(response_keys)
+ Returns:
+ Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+ """
+ if (
+ u"signatures" not in response
+ or perspective_name not in response[u"signatures"]
+ ):
+ raise KeyLookupError("Response not signed by the notary server")
+
+ verified = False
+ for key_id in response[u"signatures"][perspective_name]:
+ if key_id in perspective_keys:
+ verify_signed_json(response, perspective_name, perspective_keys[key_id])
+ verified = True
+
+ if not verified:
+ raise KeyLookupError(
+ "Response not signed with a known key: signed with: %r, known keys: %r"
+ % (
+ list(response[u"signatures"][perspective_name].keys()),
+ list(perspective_keys.keys()),
+ )
+ )
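+
+ # The response is now known to be correctly signed by the notary;
+ # process_v2_response goes on to check the origin server's own
+ # signatures and extract the keys.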
- yield self.store_keys(
- server_name=server_name,
- from_server=server_name,
- verify_keys=keys,
+ return self.process_v2_response(
+ perspective_name, response, time_added_ms=time_added_ms
)
- defer.returnValue({server_name: keys})
- @defer.inlineCallbacks
- def process_v2_response(
- self, from_server, response_json, requested_ids=[],
- ):
- """Parse a 'Server Keys' structure from the result of a /key request
- This is used to parse either the entirety of the response from
- GET /_matrix/key/v2/server, or a single entry from the list returned by
- POST /_matrix/key/v2/query.
-
- Checks that each signature in the response that claims to come from the origin
- server is valid. (Does not check that there actually is such a signature, for
- some reason.)
+class ServerKeyFetcher(BaseV2KeyFetcher):
+ """KeyFetcher impl which fetches keys from the origin servers"""
- Stores the json in server_keys_json so that it can be used for future responses
- to /_matrix/key/v2/query.
+ def __init__(self, hs):
+ super(ServerKeyFetcher, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+ def get_keys(self, keys_to_fetch):
+ """
Args:
- from_server (str): the name of the server producing this result: either
- the origin server for a /_matrix/key/v2/server request, or the notary
- for a /_matrix/key/v2/query.
-
- response_json (dict): the json-decoded Server Keys response object
-
- requested_ids (iterable[str]): a list of the key IDs that were requested.
- We will store the json for these key ids as well as any that are
- actually in the response
+ keys_to_fetch (dict[str, iterable[str]]):
+ the keys to be fetched. server_name -> key_ids
Returns:
- Deferred[dict[str, nacl.signing.VerifyKey]]:
- map from key_id to key object
+ Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+ map from server_name -> key_id -> FetchKeyResult
"""
- time_now_ms = self.clock.time_msec()
- response_keys = {}
- verify_keys = {}
- for key_id, key_data in response_json["verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_base64 = key_data["key"]
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_key.time_added = time_now_ms
- verify_keys[key_id] = verify_key
- old_verify_keys = {}
- for key_id, key_data in response_json["old_verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_base64 = key_data["key"]
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_key.expired = key_data["expired_ts"]
- verify_key.time_added = time_now_ms
- old_verify_keys[key_id] = verify_key
+ results = {}
- server_name = response_json["server_name"]
- for key_id in response_json["signatures"].get(server_name, {}):
- if key_id not in response_json["verify_keys"]:
- raise KeyLookupError(
- "Key response must include verification keys for all"
- " signatures"
- )
- if key_id in verify_keys:
- verify_signed_json(
- response_json,
- server_name,
- verify_keys[key_id]
+ @defer.inlineCallbacks
+ def get_key(key_to_fetch_item):
+ server_name, key_ids = key_to_fetch_item
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+ results[server_name] = keys
+ except KeyLookupError as e:
+ logger.warning(
+ "Error looking up keys %s from %s: %s", key_ids, server_name, e
)
+ except Exception:
+ logger.exception("Error getting keys %s from %s", key_ids, server_name)
- signed_key_json = sign_json(
- response_json,
- self.config.server_name,
- self.config.signing_key[0],
+ return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
+ lambda _: results
)
- signed_key_json_bytes = encode_canonical_json(signed_key_json)
- ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
-
- updated_key_ids = set(requested_ids)
- updated_key_ids.update(verify_keys)
- updated_key_ids.update(old_verify_keys)
-
- response_keys.update(verify_keys)
- response_keys.update(old_verify_keys)
-
- yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store.store_server_keys_json,
- server_name=server_name,
- key_id=key_id,
- from_server=from_server,
- ts_now_ms=time_now_ms,
- ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
- )
- for key_id in updated_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
-
- defer.returnValue(response_keys)
+ @defer.inlineCallbacks
+ def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ """
- def store_keys(self, server_name, from_server, verify_keys):
- """Store a collection of verify keys for a given server
Args:
- server_name(str): The name of the server the keys are for.
- from_server(str): The server the keys were downloaded from.
- verify_keys(dict): A mapping of key_id to VerifyKey.
+ server_name (str):
+ key_ids (iterable[str]):
+
Returns:
- A deferred that completes when the keys are stored.
+ Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+
+ Raises:
+ KeyLookupError if there was a problem making the lookup
"""
- # TODO(markjh): Store whether the keys have expired.
- return logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- run_in_background(
- self.store.store_server_verify_key,
- server_name, server_name, key.time_added, key
+ keys = {} # type: dict[str, FetchKeyResult]
+
+ for requested_key_id in key_ids:
+ # we may have found this key as a side-effect of asking for another.
+ if requested_key_id in keys:
+ continue
+
+ time_now_ms = self.clock.time_msec()
+ try:
+ response = yield self.client.get_json(
+ destination=server_name,
+ path="/_matrix/key/v2/server/"
+ + urllib.parse.quote(requested_key_id),
+ ignore_backoff=True,
+
+ # we only give the remote server 10s to respond. It should be an
+ # easy request to handle, so if it doesn't reply within 10s, it's
+ # probably not going to.
+ #
+ # Furthermore, when we are acting as a notary server, we cannot
+ # wait all day for all of the origin servers, as the requesting
+ # server will otherwise time out before we can respond.
+ #
+ # (Note that get_json may make 4 attempts, so this can still take
+ # almost 45 seconds to fetch the headers, plus up to another 60s to
+ # read the response).
+ timeout=10000,
)
- for key_id, key in verify_keys.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError))
+ except (NotRetryingDestination, RequestSendFailed) as e:
+ raise_from(KeyLookupError("Failed to connect to remote server"), e)
+ except HttpResponseException as e:
+ raise_from(KeyLookupError("Remote server returned an error"), e)
+
+ if response["server_name"] != server_name:
+ raise KeyLookupError(
+ "Expected a response for server %r not %r"
+ % (server_name, response["server_name"])
+ )
+
+ response_keys = yield self.process_v2_response(
+ from_server=server_name,
+ response_json=response,
+ time_added_ms=time_now_ms,
+ )
+ yield self.store.store_server_verify_keys(
+ server_name,
+ time_now_ms,
+ ((server_name, key_id, key) for key_id, key in response_keys.items()),
+ )
+ keys.update(response_keys)
+
+ defer.returnValue(keys)
@defer.inlineCallbacks
@@ -693,48 +861,25 @@ def _handle_key_deferred(verify_request):
SynapseError if there was a problem performing the verification
"""
server_name = verify_request.server_name
- try:
- with PreserveLoggingContext():
- _, key_id, verify_key = yield verify_request.deferred
- except KeyLookupError as e:
- logger.warn(
- "Failed to download keys for %s: %s %s",
- server_name, type(e).__name__, str(e),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.exception(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, verify_request.key_ids),
- Codes.UNAUTHORIZED,
- )
+ with PreserveLoggingContext():
+ _, key_id, verify_key = yield verify_request.key_ready
json_object = verify_request.json_object
- logger.debug("Got key %s %s:%s for server %s, verifying" % (
- key_id, verify_key.alg, verify_key.version, server_name,
- ))
try:
verify_signed_json(json_object, server_name, verify_key)
except SignatureVerifyException as e:
logger.debug(
"Error verifying signature for %s:%s:%s with key %s: %s",
- server_name, verify_key.alg, verify_key.version,
+ server_name,
+ verify_key.alg,
+ verify_key.version,
encode_verify_key_base64(verify_key),
str(e),
)
raise SynapseError(
401,
- "Invalid signature for server %s with key %s:%s: %s" % (
- server_name, verify_key.alg, verify_key.version, str(e),
- ),
+ "Invalid signature for server %s with key %s:%s: %s"
+ % (server_name, verify_key.alg, verify_key.version, str(e)),
Codes.UNAUTHORIZED,
)
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 12056d5be2..1edd19cc13 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -21,6 +21,7 @@ import six
from unpaddedbase64 import encode_base64
+from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
from synapse.util.caches import intern_dict
from synapse.util.frozenutils import freeze
@@ -335,13 +336,32 @@ class FrozenEventV2(EventBase):
return self.__repr__()
def __repr__(self):
- return "<FrozenEventV2 event_id='%s', type='%s', state_key='%s'>" % (
+ return "<%s event_id='%s', type='%s', state_key='%s'>" % (
+ self.__class__.__name__,
self.event_id,
self.get("type", None),
self.get("state_key", None),
)
+class FrozenEventV3(FrozenEventV2):
+ """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
+ format_version = EventFormatVersions.V3 # All events of this type are V3
+
+ @property
+ def event_id(self):
+ # We have to import this here as otherwise we get an import loop which
+ # is hard to break.
+ from synapse.crypto.event_signing import compute_event_reference_hash
+
+ if self._event_id:
+ return self._event_id
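+
+ # MSC1884: the event ID is "$" followed by the URL-safe unpadded
+ # base64 encoding of the event's reference hash (a sha256 hash, so
+ # 43 base64 characters).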
+ self._event_id = "$" + encode_base64(
+ compute_event_reference_hash(self)[1], urlsafe=True
+ )
+ return self._event_id
+
+
def room_version_to_event_format(room_version):
"""Converts a room version string to the event format
@@ -350,12 +370,15 @@ def room_version_to_event_format(room_version):
Returns:
int
+
+ Raises:
+ UnsupportedRoomVersionError if the room version is unknown
"""
v = KNOWN_ROOM_VERSIONS.get(room_version)
if not v:
- # We should have already checked version, so this should not happen
- raise RuntimeError("Unrecognized room version %s" % (room_version,))
+ # this can happen if support is withdrawn for a room version
+ raise UnsupportedRoomVersionError()
return v.event_format
@@ -376,6 +399,8 @@ def event_type_from_format_version(format_version):
return FrozenEvent
elif format_version == EventFormatVersions.V2:
return FrozenEventV2
+ elif format_version == EventFormatVersions.V3:
+ return FrozenEventV3
else:
raise Exception(
"No event format %r" % (format_version,)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index fba27177c7..546b6f4982 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -18,6 +18,7 @@ import attr
from twisted.internet import defer
from synapse.api.constants import MAX_DEPTH
+from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import (
KNOWN_EVENT_FORMAT_VERSIONS,
KNOWN_ROOM_VERSIONS,
@@ -75,6 +76,7 @@ class EventBuilder(object):
# someone tries to get them when they don't exist.
_state_key = attr.ib(default=None)
_redacts = attr.ib(default=None)
+ _origin_server_ts = attr.ib(default=None)
internal_metadata = attr.ib(default=attr.Factory(lambda: _EventInternalMetadata({})))
@@ -141,6 +143,9 @@ class EventBuilder(object):
if self._redacts is not None:
event_dict["redacts"] = self._redacts
+ if self._origin_server_ts is not None:
+ event_dict["origin_server_ts"] = self._origin_server_ts
+
defer.returnValue(
create_local_event_from_event_dict(
clock=self._clock,
@@ -178,9 +183,8 @@ class EventBuilderFactory(object):
"""
v = KNOWN_ROOM_VERSIONS.get(room_version)
if not v:
- raise Exception(
- "No event format defined for version %r" % (room_version,)
- )
+ # this can happen if support is withdrawn for a room version
+ raise UnsupportedRoomVersionError()
return self.for_room_version(v, key_values)
def for_room_version(self, room_version, key_values):
@@ -209,6 +213,7 @@ class EventBuilderFactory(object):
content=key_values.get("content", {}),
unsigned=key_values.get("unsigned", {}),
redacts=key_values.get("redacts", None),
+ origin_server_ts=key_values.get("origin_server_ts", None),
)
@@ -245,7 +250,7 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
event_dict["event_id"] = _create_event_id(clock, hostname)
event_dict["origin"] = hostname
- event_dict["origin_server_ts"] = time_now
+ event_dict.setdefault("origin_server_ts", time_now)
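+ # (setdefault: an origin_server_ts supplied via the event builder takes
+ # precedence; otherwise we stamp the event with the current time)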
event_dict.setdefault("unsigned", {})
age = event_dict["unsigned"].pop("age", 0)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 368b5f6ae4..fa09c132a0 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -187,7 +187,9 @@ class EventContext(object):
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
- is None, which happens when the associated event is an outlier.
+ is None, which happens when the associated event is an outlier.
+ Maps a (type, state_key) to the event ID of the state event matching
+ this tuple.
"""
if not self._fetching_state_deferred:
@@ -205,7 +207,9 @@ class EventContext(object):
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
- is None, which happens when the associated event is an outlier.
+ is None, which happens when the associated event is an outlier.
+ Maps a (type, state_key) to the event ID of the state event matching
+ this tuple.
"""
if not self._fetching_state_deferred:
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 07fccdd8f9..e2d4384de1 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -19,7 +19,10 @@ from six import string_types
from frozendict import frozendict
-from synapse.api.constants import EventTypes
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.util.async_helpers import yieldable_gather_results
from . import EventBase
@@ -311,3 +314,93 @@ def serialize_event(e, time_now_ms, as_client_event=True,
d = only_fields(d, only_event_fields)
return d
+
+
+class EventClientSerializer(object):
+ """Serializes events that are to be sent to clients.
+
+ This is used for bundling extra information with any events to be sent to
+ clients.
+ """
+
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.experimental_msc1849_support_enabled = (
+ hs.config.experimental_msc1849_support_enabled
+ )
+
+ @defer.inlineCallbacks
+ def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
+ """Serializes a single event.
+
+ Args:
+ event (EventBase)
+ time_now (int): The current time in milliseconds
+ bundle_aggregations (bool): Whether to bundle in related events
+ **kwargs: Arguments to pass to `serialize_event`
+
+ Returns:
+ Deferred[dict]: The serialized event
+ """
+ # To handle the case of presence events and the like
+ if not isinstance(event, EventBase):
+ defer.returnValue(event)
+
+ event_id = event.event_id
+ serialized_event = serialize_event(event, time_now, **kwargs)
+
+ # If MSC1849 is enabled then we need to check whether there are any
+ # relations that we need to bundle in with the event
+ if self.experimental_msc1849_support_enabled and bundle_aggregations:
+ annotations = yield self.store.get_aggregation_groups_for_event(
+ event_id,
+ )
+ references = yield self.store.get_relations_for_event(
+ event_id, RelationTypes.REFERENCE, direction="f",
+ )
+
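+ # These get bundled under unsigned["m.relations"]; an annotation
+ # chunk entry might look like (illustrative):
+ # {"type": "m.reaction", "key": "👍", "count": 3}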
+ if annotations.chunk:
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
+ r[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+ if references.chunk:
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
+ r[RelationTypes.REFERENCE] = references.to_dict()
+
+ edit = None
+ if event.type == EventTypes.Message:
+ edit = yield self.store.get_applicable_edit(event_id)
+
+ if edit:
+ # If there is an edit, replace the content, preserving the
+ # existing relations.
+
+ relations = event.content.get("m.relates_to")
+ serialized_event["content"] = edit.content.get("m.new_content", {})
+ if relations:
+ serialized_event["content"]["m.relates_to"] = relations
+ else:
+ serialized_event["content"].pop("m.relates_to", None)
+
+ r = serialized_event["unsigned"].setdefault("m.relations", {})
+ r[RelationTypes.REPLACE] = {
+ "event_id": edit.event_id,
+ }
+
+ defer.returnValue(serialized_event)
+
+ def serialize_events(self, events, time_now, **kwargs):
+ """Serializes multiple events.
+
+ Args:
+ events (iter[EventBase])
+ time_now (int): The current time in milliseconds
+ **kwargs: Arguments to pass to `serialize_event`
+
+ Returns:
+ Deferred[list[dict]]: The list of serialized events
+ """
+ return yieldable_gather_results(
+ self.serialize_event, events,
+ time_now=time_now, **kwargs
+ )
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 514273c792..711af512b2 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -15,8 +15,8 @@
from six import string_types
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import SynapseError
+from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes, Membership
+from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import EventFormatVersions
from synapse.types import EventID, RoomID, UserID
@@ -56,6 +56,17 @@ class EventValidator(object):
if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "'%s' not a string type" % (s,))
+ if event.type == EventTypes.Aliases:
+ if "aliases" in event.content:
+ for alias in event.content["aliases"]:
+ if len(alias) > MAX_ALIAS_LENGTH:
+ raise SynapseError(
+ 400,
+ ("Can't create aliases longer than"
+ " %d characters" % (MAX_ALIAS_LENGTH,)),
+ Codes.INVALID_PARAM,
+ )
+
def validate_builder(self, event):
"""Validates that the builder/event has roughly the right format. Only
checks values that we expect a proto event to have, rather than all the
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index dfe6b4aa5c..4b38f7c759 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -265,11 +265,22 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
]
more_deferreds = keyring.verify_json_objects_for_server([
- (p.sender_domain, p.redacted_pdu_json)
+ (p.sender_domain, p.redacted_pdu_json, 0)
for p in pdus_to_check_sender
])
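+ # (the 0 is the minimum_valid_until_ts: we do not require the signing
+ # keys to still be valid now, only to have been valid at some point)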
+ def sender_err(e, pdu_to_check):
+ errmsg = "event id %s: unable to verify signature for sender %s: %s" % (
+ pdu_to_check.pdu.event_id,
+ pdu_to_check.sender_domain,
+ e.getErrorMessage(),
+ )
+ # XX not really sure if these are the right codes, but they are what
+ # we've done for ages
+ raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)
+
for p, d in zip(pdus_to_check_sender, more_deferreds):
+ d.addErrback(sender_err, p)
p.deferreds.append(d)
# now let's look for events where the sender's domain is different to the
@@ -287,11 +298,22 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
]
more_deferreds = keyring.verify_json_objects_for_server([
- (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
+ (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
for p in pdus_to_check_event_id
])
+ def event_err(e, pdu_to_check):
+ errmsg = (
+ "event id %s: unable to verify signature for event id domain: %s" % (
+ pdu_to_check.pdu.event_id,
+ e.getErrorMessage(),
+ )
+ )
+ # XX as above: not really sure if these are the right codes
+ raise SynapseError(400, errmsg, Codes.UNAUTHORIZED)
+
for p, d in zip(pdus_to_check_event_id, more_deferreds):
+ d.addErrback(event_err, p)
p.deferreds.append(d)
# replace lists of deferreds with single Deferreds
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index f3fc897a0a..70573746d6 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -17,7 +17,6 @@
import copy
import itertools
import logging
-import random
from six.moves import range
@@ -233,7 +232,8 @@ class FederationClient(FederationBase):
moving to the next destination. None indicates no timeout.
Returns:
- Deferred: Results in the requested PDU.
+ Deferred: Results in the requested PDU, or None if we were unable to find
+ it.
"""
# TODO: Rate limit the number of times we try and get the same event.
@@ -258,7 +258,12 @@ class FederationClient(FederationBase):
destination, event_id, timeout=timeout,
)
- logger.debug("transaction_data %r", transaction_data)
+ logger.debug(
+ "retrieved event id %s from %s: %r",
+ event_id,
+ destination,
+ transaction_data,
+ )
pdu_list = [
event_from_pdu_json(p, format_ver, outlier=outlier)
@@ -280,6 +285,7 @@ class FederationClient(FederationBase):
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
+ continue
except NotRetryingDestination as e:
logger.info(str(e))
continue
@@ -326,12 +332,16 @@ class FederationClient(FederationBase):
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
- fetched_events, failed_to_fetch = yield self.get_events(
- [destination], room_id, set(state_event_ids + auth_event_ids)
+ fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
+ destination, room_id, set(state_event_ids + auth_event_ids)
)
if failed_to_fetch:
- logger.warn("Failed to get %r", failed_to_fetch)
+ logger.warning(
+ "Failed to fetch missing state/auth events for %s: %s",
+ room_id,
+ failed_to_fetch
+ )
event_map = {
ev.event_id: ev for ev in fetched_events
@@ -397,27 +407,20 @@ class FederationClient(FederationBase):
defer.returnValue((signed_pdus, signed_auth))
@defer.inlineCallbacks
- def get_events(self, destinations, room_id, event_ids, return_local=True):
- """Fetch events from some remote destinations, checking if we already
- have them.
+ def get_events_from_store_or_dest(self, destination, room_id, event_ids):
+ """Fetch events from a remote destination, checking if we already have them.
Args:
- destinations (list)
+ destination (str)
room_id (str)
event_ids (list)
- return_local (bool): Whether to include events we already have in
- the DB in the returned list of events
Returns:
Deferred: A deferred resolving to a 2-tuple where the first is a list of
events and the second is a list of event ids that we failed to fetch.
"""
- if return_local:
- seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
- signed_events = list(seen_events.values())
- else:
- seen_events = yield self.store.have_seen_events(event_ids)
- signed_events = []
+ seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
+ signed_events = list(seen_events.values())
failed_to_fetch = set()
@@ -428,10 +431,11 @@ class FederationClient(FederationBase):
if not missing_events:
defer.returnValue((signed_events, failed_to_fetch))
- def random_server_list():
- srvs = list(destinations)
- random.shuffle(srvs)
- return srvs
+ logger.debug(
+ "Fetching unknown state/auth events %s for room %s",
+ missing_events,
+ room_id,
+ )
room_version = yield self.store.get_room_version(room_id)
@@ -443,7 +447,7 @@ class FederationClient(FederationBase):
deferreds = [
run_in_background(
self.get_pdu,
- destinations=random_server_list(),
+ destinations=[destination],
event_id=e_id,
room_version=room_version,
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index df60828dba..4c28c1dc3c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -33,6 +33,7 @@ from synapse.api.errors import (
IncompatibleRoomVersionError,
NotFoundError,
SynapseError,
+ UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.crypto.event_signing import compute_event_signature
@@ -198,11 +199,22 @@ class FederationServer(FederationBase):
try:
room_version = yield self.store.get_room_version(room_id)
- format_ver = room_version_to_event_format(room_version)
except NotFoundError:
logger.info("Ignoring PDU for unknown room_id: %s", room_id)
continue
+ try:
+ format_ver = room_version_to_event_format(room_version)
+ except UnsupportedRoomVersionError:
+ # this can happen if support for a given room version is withdrawn,
+ # in which case we may still receive events for that room.
+ logger.info(
+ "Ignoring PDU for room %s with unknown version %s",
+ room_id,
+ room_version,
+ )
+ continue
+
event = event_from_pdu_json(p, format_ver)
pdus_by_room.setdefault(room_id, []).append(event)
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index be99211003..fae8bea392 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -33,12 +33,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage import UserPresenceState
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
+# This is defined in the Matrix spec and enforced by the receiver.
+MAX_EDUS_PER_TRANSACTION = 100
+
logger = logging.getLogger(__name__)
sent_edus_counter = Counter(
- "synapse_federation_client_sent_edus",
- "Total number of EDUs successfully sent",
+ "synapse_federation_client_sent_edus", "Total number of EDUs successfully sent"
)
sent_edus_by_type = Counter(
@@ -58,6 +60,7 @@ class PerDestinationQueue(object):
destination (str): the server_name of the destination that we are managing
transmission for.
"""
+
def __init__(self, hs, transaction_manager, destination):
self._server_name = hs.hostname
self._clock = hs.get_clock()
@@ -68,17 +71,17 @@ class PerDestinationQueue(object):
self.transmission_loop_running = False
# a list of tuples of (pending pdu, order)
- self._pending_pdus = [] # type: list[tuple[EventBase, int]]
- self._pending_edus = [] # type: list[Edu]
+ self._pending_pdus = [] # type: list[tuple[EventBase, int]]
+ self._pending_edus = [] # type: list[Edu]
# Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
# based on their key (e.g. typing events by room_id)
# Map of (edu_type, key) -> Edu
- self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu]
+ self._pending_edus_keyed = {} # type: dict[tuple[str, str], Edu]
# Map of user_id -> UserPresenceState of pending presence to be sent to this
# destination
- self._pending_presence = {} # type: dict[str, UserPresenceState]
+ self._pending_presence = {} # type: dict[str, UserPresenceState]
# room_id -> receipt_type -> user_id -> receipt_dict
self._pending_rrs = {}
@@ -120,9 +123,7 @@ class PerDestinationQueue(object):
Args:
states (iterable[UserPresenceState]): presence to send
"""
- self._pending_presence.update({
- state.user_id: state for state in states
- })
+ self._pending_presence.update({state.user_id: state for state in states})
self.attempt_new_transaction()
def queue_read_receipt(self, receipt):
@@ -132,14 +133,9 @@ class PerDestinationQueue(object):
Args:
receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued
"""
- self._pending_rrs.setdefault(
- receipt.room_id, {},
- ).setdefault(
+ self._pending_rrs.setdefault(receipt.room_id, {}).setdefault(
receipt.receipt_type, {}
- )[receipt.user_id] = {
- "event_ids": receipt.event_ids,
- "data": receipt.data,
- }
+ )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data}
def flush_read_receipts_for_room(self, room_id):
# if we don't have any read-receipts for this room, it may be that we've already
@@ -170,10 +166,7 @@ class PerDestinationQueue(object):
# request at which point pending_pdus just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
- logger.debug(
- "TX [%s] Transaction already in progress",
- self._destination
- )
+ logger.debug("TX [%s] Transaction already in progress", self._destination)
return
logger.debug("TX [%s] Starting transaction loop", self._destination)
@@ -197,7 +190,8 @@ class PerDestinationQueue(object):
pending_pdus = []
while True:
device_message_edus, device_stream_id, dev_list_id = (
- yield self._get_new_device_messages()
+ # We have to keep 2 free slots for presence and rr_edus
+ yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2)
)
# BEGIN CRITICAL SECTION
@@ -216,19 +210,9 @@ class PerDestinationQueue(object):
pending_edus = []
- pending_edus.extend(self._get_rr_edus(force_flush=False))
-
# We can only include at most 100 EDUs per transaction
- pending_edus.extend(self._pop_pending_edus(100 - len(pending_edus)))
-
- pending_edus.extend(
- self._pending_edus_keyed.values()
- )
-
- self._pending_edus_keyed = {}
-
- pending_edus.extend(device_message_edus)
-
+ # rr_edus and pending_presence take at most one slot each
+ pending_edus.extend(self._get_rr_edus(force_flush=False))
pending_presence = self._pending_presence
self._pending_presence = {}
if pending_presence:
@@ -248,9 +232,23 @@ class PerDestinationQueue(object):
)
)
+ pending_edus.extend(device_message_edus)
+ pending_edus.extend(
+ self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
+ )
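+ # Finally fill any remaining slots with keyed EDUs (typing
+ # notifications etc). popitem() removes each one we take, so anything
+ # that doesn't fit stays queued for the next transaction.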
+ while (
+ len(pending_edus) < MAX_EDUS_PER_TRANSACTION
+ and self._pending_edus_keyed
+ ):
+ _, val = self._pending_edus_keyed.popitem()
+ pending_edus.append(val)
+
if pending_pdus:
- logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- self._destination, len(pending_pdus))
+ logger.debug(
+ "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ self._destination,
+ len(pending_pdus),
+ )
if not pending_pdus and not pending_edus:
logger.debug("TX [%s] Nothing to send", self._destination)
@@ -259,7 +257,7 @@ class PerDestinationQueue(object):
# if we've decided to send a transaction anyway, and we have room, we
# may as well send any pending RRs
- if len(pending_edus) < 100:
+ if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
pending_edus.extend(self._get_rr_edus(force_flush=True))
# END CRITICAL SECTION
@@ -303,22 +301,25 @@ class PerDestinationQueue(object):
except HttpResponseException as e:
logger.warning(
"TX [%s] Received %d response to transaction: %s",
- self._destination, e.code, e,
+ self._destination,
+ e.code,
+ e,
)
except RequestSendFailed as e:
- logger.warning("TX [%s] Failed to send transaction: %s", self._destination, e)
+ logger.warning(
+ "TX [%s] Failed to send transaction: %s", self._destination, e
+ )
for p, _ in pending_pdus:
- logger.info("Failed to send event %s to %s", p.event_id,
- self._destination)
+ logger.info(
+ "Failed to send event %s to %s", p.event_id, self._destination
+ )
except Exception:
- logger.exception(
- "TX [%s] Failed to send transaction",
- self._destination,
- )
+ logger.exception("TX [%s] Failed to send transaction", self._destination)
for p, _ in pending_pdus:
- logger.info("Failed to send event %s to %s", p.event_id,
- self._destination)
+ logger.info(
+ "Failed to send event %s to %s", p.event_id, self._destination
+ )
finally:
# We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False
@@ -346,33 +347,40 @@ class PerDestinationQueue(object):
return pending_edus
@defer.inlineCallbacks
- def _get_new_device_messages(self):
- last_device_stream_id = self._last_device_stream_id
- to_device_stream_id = self._store.get_to_device_stream_token()
- contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
- self._destination, last_device_stream_id, to_device_stream_id
+ def _get_new_device_messages(self, limit):
+ last_device_list = self._last_device_list_stream_id
+ # Will return at most 20 entries
+ now_stream_id, results = yield self._store.get_devices_by_remote(
+ self._destination, last_device_list
)
edus = [
Edu(
origin=self._server_name,
destination=self._destination,
- edu_type="m.direct_to_device",
+ edu_type="m.device_list_update",
content=content,
)
- for content in contents
+ for content in results
]
- last_device_list = self._last_device_list_stream_id
- now_stream_id, results = yield self._store.get_devices_by_remote(
- self._destination, last_device_list
+ assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+
+ last_device_stream_id = self._last_device_stream_id
+ to_device_stream_id = self._store.get_to_device_stream_token()
+ contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
+ self._destination,
+ last_device_stream_id,
+ to_device_stream_id,
+ limit - len(edus),
)
edus.extend(
Edu(
origin=self._server_name,
destination=self._destination,
- edu_type="m.device_list_update",
+ edu_type="m.direct_to_device",
content=content,
)
- for content in results
+ for content in contents
)
+
defer.returnValue((edus, stream_id, now_stream_id))
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 452599e1a1..0db8858cf1 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -23,7 +23,11 @@ from twisted.internet import defer
import synapse
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.room_versions import RoomVersions
-from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
+from synapse.api.urls import (
+ FEDERATION_UNSTABLE_PREFIX,
+ FEDERATION_V1_PREFIX,
+ FEDERATION_V2_PREFIX,
+)
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.http.server import JsonResource
from synapse.http.servlet import (
@@ -63,11 +67,7 @@ class TransportLayerServer(JsonResource):
self.authenticator = Authenticator(hs)
self.ratelimiter = FederationRateLimiter(
self.clock,
- window_size=hs.config.federation_rc_window_size,
- sleep_limit=hs.config.federation_rc_sleep_limit,
- sleep_msec=hs.config.federation_rc_sleep_delay,
- reject_limit=hs.config.federation_rc_reject_limit,
- concurrent_requests=hs.config.federation_rc_concurrent,
+ config=hs.config.rc_federation,
)
self.register_servlets()
@@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):
class Authenticator(object):
def __init__(self, hs):
+ self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
@@ -102,6 +103,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request, content):
+ now = self._clock.time_msec()
json_request = {
"method": request.method.decode('ascii'),
"uri": request.uri.decode('ascii'),
@@ -138,7 +140,7 @@ class Authenticator(object):
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
- yield self.keyring.verify_json_for_server(origin, json_request)
+ yield self.keyring.verify_json_for_server(origin, json_request, now)
logger.info("Request from %s", origin)
request.authenticated_entity = origin
@@ -716,8 +718,17 @@ class PublicRoomList(BaseFederationServlet):
PATH = "/publicRooms"
+ def __init__(self, handler, authenticator, ratelimiter, server_name, deny_access):
+ super(PublicRoomList, self).__init__(
+ handler, authenticator, ratelimiter, server_name,
+ )
+ self.deny_access = deny_access
+
@defer.inlineCallbacks
def on_GET(self, origin, content, query):
+ if self.deny_access:
+ raise FederationDeniedError(origin)
+
limit = parse_integer_from_args(query, "limit", 0)
since_token = parse_string_from_args(query, "since", None)
include_all_networks = parse_boolean_from_args(
@@ -1299,6 +1310,30 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
defer.returnValue((200, new_content))
+class RoomComplexityServlet(BaseFederationServlet):
+ """
+ Indicates to other servers how complex (and therefore how
+ resource-intensive) a given public room known to this server is.
+ """
+ PATH = "/rooms/(?P<room_id>[^/]*)/complexity"
+ PREFIX = FEDERATION_UNSTABLE_PREFIX
+
+ @defer.inlineCallbacks
+ def on_GET(self, origin, content, query, room_id):
+
+ store = self.handler.hs.get_datastore()
+
+ is_public = yield store.is_room_world_readable_or_publicly_joinable(
+ room_id
+ )
+
+ if not is_public:
+ raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM)
+
+ complexity = yield store.get_room_complexity(room_id)
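+ # complexity is a JSON-serializable dict; in this implementation it
+ # is of the form {"v1": <complexity score>}.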
+ defer.returnValue((200, complexity))
+
+
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationEventServlet,
@@ -1322,6 +1357,7 @@ FEDERATION_SERVLET_CLASSES = (
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
FederationVersionServlet,
+ RoomComplexityServlet,
)
OPENID_SERVLET_CLASSES = (
@@ -1417,6 +1453,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
authenticator=authenticator,
ratelimiter=ratelimiter,
server_name=hs.hostname,
+ deny_access=hs.config.restrict_public_rooms_to_local_users,
).register(resource)
if "group_server" in servlet_groups:
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 786149be65..fa6b641ee1 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -97,10 +97,11 @@ class GroupAttestationSigning(object):
# TODO: We also want to check that *new* attestations that people give
# us to store are valid for at least a little while.
- if valid_until_ms < self.clock.time_msec():
+ now = self.clock.time_msec()
+ if valid_until_ms < now:
raise SynapseError(400, "Attestation expired")
- yield self.keyring.verify_json_for_server(server_name, attestation)
+ yield self.keyring.verify_json_for_server(server_name, attestation, now)
def create_attestation(self, group_id, user_id):
"""Create an attestation for the group_id and user_id with default
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index ac09d03ba9..dca337ec61 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -90,8 +90,8 @@ class BaseHandler(object):
messages_per_second = override.messages_per_second
burst_count = override.burst_count
else:
- messages_per_second = self.hs.config.rc_messages_per_second
- burst_count = self.hs.config.rc_message_burst_count
+ messages_per_second = self.hs.config.rc_message.per_second
+ burst_count = self.hs.config.rc_message.burst_count
allowed, time_allowed = self.ratelimiter.can_do_action(
user_id, time_now,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 27bd06df5d..a12f9508d8 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -19,7 +19,7 @@ import string
from twisted.internet import defer
-from synapse.api.constants import EventTypes
+from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
from synapse.api.errors import (
AuthError,
CodeMessageException,
@@ -43,8 +43,10 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.store = hs.get_datastore()
self.config = hs.config
self.enable_room_list_search = hs.config.enable_room_list_search
+ self.require_membership = hs.config.require_membership_for_aliases
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -83,7 +85,7 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks
def create_association(self, requester, room_alias, room_id, servers=None,
- send_event=True):
+ send_event=True, check_membership=True):
"""Attempt to create a new alias
Args:
@@ -93,6 +95,8 @@ class DirectoryHandler(BaseHandler):
servers (list[str]|None): List of servers that other servers
should try to join via
send_event (bool): Whether to send an updated m.room.aliases event
+ check_membership (bool): Whether to check if the user is in the room
+ before the alias can be set (if the server's config requires it).
Returns:
Deferred
@@ -100,6 +104,13 @@ class DirectoryHandler(BaseHandler):
user_id = requester.user.to_string()
+ if len(room_alias.to_string()) > MAX_ALIAS_LENGTH:
+ raise SynapseError(
+ 400,
+ "Can't create aliases longer than %s characters" % MAX_ALIAS_LENGTH,
+ Codes.INVALID_PARAM,
+ )
+
service = requester.app_service
if service:
if not service.is_interested_in_alias(room_alias.to_string()):
@@ -108,6 +119,14 @@ class DirectoryHandler(BaseHandler):
" this kind of alias.", errcode=Codes.EXCLUSIVE
)
else:
+ if self.require_membership and check_membership:
+ rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+ if room_id not in rooms_for_user:
+ raise AuthError(
+ 403,
+ "You must be in the room to create an alias for it",
+ )
+
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
raise AuthError(
403, "This user is not permitted to create this alias",
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 1b4d8c74ae..eb525070cf 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -21,7 +21,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
-from synapse.events.utils import serialize_event
from synapse.types import UserID
from synapse.util.logutils import log_function
from synapse.visibility import filter_events_for_client
@@ -50,6 +49,7 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self._server_notices_sender = hs.get_server_notices_sender()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
@log_function
@@ -120,9 +120,12 @@ class EventStreamHandler(BaseHandler):
time_now = self.clock.time_msec()
- chunks = [
- serialize_event(e, time_now, as_client_event) for e in events
- ]
+ chunks = yield self._event_serializer.serialize_events(
+ events, time_now, as_client_event=as_client_event,
+ # We don't bundle "live" events, as otherwise clients
+ # will end up double counting annotations.
+ bundle_aggregations=False,
+ )
chunk = {
"chunk": chunks,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 0684778882..cf4fad7de0 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1916,6 +1916,11 @@ class FederationHandler(BaseHandler):
event.room_id, latest_event_ids=extrem_ids,
)
+ logger.debug(
+ "Doing soft-fail check for %s: state %s",
+ event.event_id, current_state_ids,
+ )
+
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(event)
current_state_ids = [
@@ -1932,7 +1937,7 @@ class FederationHandler(BaseHandler):
self.auth.check(room_version, event, auth_events=current_auth_events)
except AuthError as e:
logger.warn(
- "Failed current state auth resolution for %r because %s",
+ "Soft-failing %r because %s",
event, e,
)
event.internal_metadata.soft_failed = True
@@ -2008,15 +2013,44 @@ class FederationHandler(BaseHandler):
Args:
origin (str):
- event (synapse.events.FrozenEvent):
+ event (synapse.events.EventBase):
context (synapse.events.snapshot.EventContext):
- auth_events (dict[(str, str)->str]):
+ auth_events (dict[(str, str)->synapse.events.EventBase]):
+ Map from (event_type, state_key) to event
+
+ What we expect the event's auth_events to be, based on the event's
+ position in the DAG.
+
+ Also NB that this function adds entries to it.
+ Returns:
+ defer.Deferred[None]
+ """
+ room_version = yield self.store.get_room_version(event.room_id)
+
+ yield self._update_auth_events_and_context_for_auth(
+ origin, event, context, auth_events
+ )
+ try:
+ self.auth.check(room_version, event, auth_events=auth_events)
+ except AuthError as e:
+ logger.warn("Failed auth resolution for %r because %s", event, e)
+ raise e
+
+ @defer.inlineCallbacks
+ def _update_auth_events_and_context_for_auth(
+ self, origin, event, context, auth_events
+ ):
+ """Helper for do_auth. See there for docs.
+
+ Args:
+ origin (str):
+ event (synapse.events.EventBase):
+ context (synapse.events.snapshot.EventContext):
+ auth_events (dict[(str, str)->synapse.events.EventBase]):
Returns:
defer.Deferred[None]
"""
- # Check if we have all the auth events.
- current_state = set(e.event_id for e in auth_events.values())
event_auth_events = set(event.auth_event_ids())
if event.is_state():
@@ -2024,11 +2058,21 @@ class FederationHandler(BaseHandler):
else:
event_key = None
- if event_auth_events - current_state:
+ # if the event's auth_events refers to events which are not in our
+ # calculated auth_events, we need to fetch those events from somewhere.
+ #
+ # we start by fetching them from the store, and then try calling /event_auth/.
+ missing_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
+
+ if missing_auth:
# TODO: can we use store.have_seen_events here instead?
have_events = yield self.store.get_seen_events_with_rejections(
- event_auth_events - current_state
+ missing_auth
)
+ logger.debug("Got events %s from store", have_events)
+ missing_auth.difference_update(have_events.keys())
else:
have_events = {}
@@ -2037,13 +2081,12 @@ class FederationHandler(BaseHandler):
for e in auth_events.values()
})
- seen_events = set(have_events.keys())
-
- missing_auth = event_auth_events - seen_events - current_state
-
if missing_auth:
- logger.info("Missing auth: %s", missing_auth)
# If we don't have all the auth events, we need to get them.
+ logger.info(
+ "auth_events contains unknown events: %s",
+ missing_auth,
+ )
try:
remote_auth_chain = yield self.federation_client.get_event_auth(
origin, event.room_id, event.event_id
@@ -2084,145 +2127,168 @@ class FederationHandler(BaseHandler):
have_events = yield self.store.get_seen_events_with_rejections(
event.auth_event_ids()
)
- seen_events = set(have_events.keys())
except Exception:
# FIXME:
logger.exception("Failed to get auth chain")
+ if event.internal_metadata.is_outlier():
+ logger.info("Skipping auth_event fetch for outlier")
+ return
+
# FIXME: Assumes we have and stored all the state for all the
# prev_events
- current_state = set(e.event_id for e in auth_events.values())
- different_auth = event_auth_events - current_state
+ different_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
- room_version = yield self.store.get_room_version(event.room_id)
+ if not different_auth:
+ return
- if different_auth and not event.internal_metadata.is_outlier():
- # Do auth conflict res.
- logger.info("Different auth: %s", different_auth)
-
- different_events = yield logcontext.make_deferred_yieldable(
- defer.gatherResults([
- logcontext.run_in_background(
- self.store.get_event,
- d,
- allow_none=True,
- allow_rejected=False,
- )
- for d in different_auth
- if d in have_events and not have_events[d]
- ], consumeErrors=True)
- ).addErrback(unwrapFirstError)
-
- if different_events:
- local_view = dict(auth_events)
- remote_view = dict(auth_events)
- remote_view.update({
- (d.type, d.state_key): d for d in different_events if d
- })
+ logger.info(
+ "auth_events refers to events which are not in our calculated auth "
+ "chain: %s",
+ different_auth,
+ )
- new_state = yield self.state_handler.resolve_events(
- room_version,
- [list(local_view.values()), list(remote_view.values())],
- event
+ room_version = yield self.store.get_room_version(event.room_id)
+
+ different_events = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults([
+ logcontext.run_in_background(
+ self.store.get_event,
+ d,
+ allow_none=True,
+ allow_rejected=False,
)
+ for d in different_auth
+ if d in have_events and not have_events[d]
+ ], consumeErrors=True)
+ ).addErrback(unwrapFirstError)
+
+ if different_events:
+ local_view = dict(auth_events)
+ remote_view = dict(auth_events)
+ remote_view.update({
+ (d.type, d.state_key): d for d in different_events if d
+ })
- auth_events.update(new_state)
+ new_state = yield self.state_handler.resolve_events(
+ room_version,
+ [list(local_view.values()), list(remote_view.values())],
+ event
+ )
- current_state = set(e.event_id for e in auth_events.values())
- different_auth = event_auth_events - current_state
+ logger.info(
+ "After state res: updating auth_events with new state %s",
+ {
+ (d.type, d.state_key): d.event_id for d in new_state.values()
+ if auth_events.get((d.type, d.state_key)) != d
+ },
+ )
- yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
- )
+ auth_events.update(new_state)
- if different_auth and not event.internal_metadata.is_outlier():
- logger.info("Different auth after resolution: %s", different_auth)
+ different_auth = event_auth_events.difference(
+ e.event_id for e in auth_events.values()
+ )
- # Only do auth resolution if we have something new to say.
- # We can't rove an auth failure.
- do_resolution = False
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
- provable = [
- RejectedReason.NOT_ANCESTOR, RejectedReason.NOT_ANCESTOR,
- ]
+ if not different_auth:
+ # we're done
+ return
- for e_id in different_auth:
- if e_id in have_events:
- if have_events[e_id] in provable:
- do_resolution = True
- break
+ logger.info(
+ "auth_events still refers to events which are not in the calculated auth "
+ "chain after state resolution: %s",
+ different_auth,
+ )
- if do_resolution:
- prev_state_ids = yield context.get_prev_state_ids(self.store)
- # 1. Get what we think is the auth chain.
- auth_ids = yield self.auth.compute_auth_events(
- event, prev_state_ids
- )
- local_auth_chain = yield self.store.get_auth_chain(
- auth_ids, include_given=True
- )
+ # Only do auth resolution if we have something new to say.
+ # We can't prove an auth failure.
+ do_resolution = False
- try:
- # 2. Get remote difference.
- result = yield self.federation_client.query_auth(
- origin,
- event.room_id,
- event.event_id,
- local_auth_chain,
- )
+ for e_id in different_auth:
+ if e_id in have_events:
+ if have_events[e_id] == RejectedReason.NOT_ANCESTOR:
+ do_resolution = True
+ break
- seen_remotes = yield self.store.have_seen_events(
- [e.event_id for e in result["auth_chain"]]
- )
+ if not do_resolution:
+ logger.info(
+ "Skipping auth resolution due to lack of provable rejection reasons"
+ )
+ return
- # 3. Process any remote auth chain events we haven't seen.
- for ev in result["auth_chain"]:
- if ev.event_id in seen_remotes:
- continue
+ logger.info("Doing auth resolution")
- if ev.event_id == event.event_id:
- continue
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
- try:
- auth_ids = ev.auth_event_ids()
- auth = {
- (e.type, e.state_key): e
- for e in result["auth_chain"]
- if e.event_id in auth_ids
- or event.type == EventTypes.Create
- }
- ev.internal_metadata.outlier = True
+ # 1. Get what we think is the auth chain.
+ auth_ids = yield self.auth.compute_auth_events(
+ event, prev_state_ids
+ )
+ local_auth_chain = yield self.store.get_auth_chain(
+ auth_ids, include_given=True
+ )
- logger.debug(
- "do_auth %s different_auth: %s",
- event.event_id, e.event_id
- )
+ try:
+ # 2. Get remote difference.
+ result = yield self.federation_client.query_auth(
+ origin,
+ event.room_id,
+ event.event_id,
+ local_auth_chain,
+ )
- yield self._handle_new_event(
- origin, ev, auth_events=auth
- )
+ seen_remotes = yield self.store.have_seen_events(
+ [e.event_id for e in result["auth_chain"]]
+ )
- if ev.event_id in event_auth_events:
- auth_events[(ev.type, ev.state_key)] = ev
- except AuthError:
- pass
+ # 3. Process any remote auth chain events we haven't seen.
+ for ev in result["auth_chain"]:
+ if ev.event_id in seen_remotes:
+ continue
- except Exception:
- # FIXME:
- logger.exception("Failed to query auth chain")
+ if ev.event_id == event.event_id:
+ continue
- # 4. Look at rejects and their proofs.
- # TODO.
+ try:
+ auth_ids = ev.auth_event_ids()
+ auth = {
+ (e.type, e.state_key): e
+ for e in result["auth_chain"]
+ if e.event_id in auth_ids
+ or event.type == EventTypes.Create
+ }
+ ev.internal_metadata.outlier = True
+
+ logger.debug(
+ "do_auth %s different_auth: %s",
+ event.event_id, e.event_id
+ )
- yield self._update_context_for_auth_events(
- event, context, auth_events, event_key,
- )
+ yield self._handle_new_event(
+ origin, ev, auth_events=auth
+ )
- try:
- self.auth.check(room_version, event, auth_events=auth_events)
- except AuthError as e:
- logger.warn("Failed auth resolution for %r because %s", event, e)
- raise e
+ if ev.event_id in event_auth_events:
+ auth_events[(ev.type, ev.state_key)] = ev
+ except AuthError:
+ pass
+
+ except Exception:
+ # FIXME:
+ logger.exception("Failed to query auth chain")
+
+ # 4. Look at rejects and their proofs.
+ # TODO.
+
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
+ )
@defer.inlineCallbacks
def _update_context_for_auth_events(self, event, context, auth_events,
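
The refactored helper leans on one piece of set arithmetic: the auth event ids cited by the incoming event, minus the ids of our calculated auth events, is exactly what must be fetched. A worked toy example:

```python
# ids the incoming event cites in its auth_events
event_auth_events = {"$create", "$power", "$join"}
# ids of the events in our calculated auth_events map
our_auth_ids = {"$create", "$power"}

missing_auth = event_auth_events.difference(our_auth_ids)
assert missing_auth == {"$join"}   # fetch from the store, then /event_auth/
```
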
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 7dfae78db0..aaee5db0b7 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -19,7 +19,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
-from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.handlers.presence import format_user_presence_state
from synapse.streams.config import PaginationConfig
@@ -43,6 +42,7 @@ class InitialSyncHandler(BaseHandler):
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
+ self._event_serializer = hs.get_event_client_serializer()
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
as_client_event=True, include_archived=False):
@@ -138,7 +138,9 @@ class InitialSyncHandler(BaseHandler):
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
- d["invite"] = serialize_event(invite_event, time_now, as_client_event)
+ d["invite"] = yield self._event_serializer.serialize_event(
+ invite_event, time_now, as_client_event,
+ )
rooms_ret.append(d)
@@ -185,18 +187,21 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
d["messages"] = {
- "chunk": [
- serialize_event(m, time_now, as_client_event)
- for m in messages
- ],
+ "chunk": (
+ yield self._event_serializer.serialize_events(
+ messages, time_now=time_now,
+ as_client_event=as_client_event,
+ )
+ ),
"start": start_token.to_string(),
"end": end_token.to_string(),
}
- d["state"] = [
- serialize_event(c, time_now, as_client_event)
- for c in current_state.values()
- ]
+ d["state"] = yield self._event_serializer.serialize_events(
+ current_state.values(),
+ time_now=time_now,
+ as_client_event=as_client_event
+ )
account_data_events = []
tags = tags_by_room.get(event.room_id)
@@ -337,11 +342,15 @@ class InitialSyncHandler(BaseHandler):
"membership": membership,
"room_id": room_id,
"messages": {
- "chunk": [serialize_event(m, time_now) for m in messages],
+ "chunk": (yield self._event_serializer.serialize_events(
+ messages, time_now,
+ )),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
- "state": [serialize_event(s, time_now) for s in room_state.values()],
+ "state": (yield self._event_serializer.serialize_events(
+ room_state.values(), time_now,
+ )),
"presence": [],
"receipts": [],
})
@@ -355,10 +364,9 @@ class InitialSyncHandler(BaseHandler):
# TODO: These concurrently
time_now = self.clock.time_msec()
- state = [
- serialize_event(x, time_now)
- for x in current_state.values()
- ]
+ state = yield self._event_serializer.serialize_events(
+ current_state.values(), time_now,
+ )
now_token = yield self.hs.get_event_sources().get_current_token()
@@ -425,7 +433,9 @@ class InitialSyncHandler(BaseHandler):
ret = {
"room_id": room_id,
"messages": {
- "chunk": [serialize_event(m, time_now) for m in messages],
+ "chunk": (yield self._event_serializer.serialize_events(
+ messages, time_now,
+ )),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 224d34ef3a..0b02469ceb 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
from twisted.internet import defer
from twisted.internet.defer import succeed
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.api.errors import (
AuthError,
Codes,
@@ -32,7 +32,6 @@ from synapse.api.errors import (
)
from synapse.api.room_versions import RoomVersions
from synapse.api.urls import ConsentURIBuilder
-from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.state import StateFilter
@@ -57,6 +56,7 @@ class MessageHandler(object):
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.store = hs.get_datastore()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def get_room_data(self, user_id=None, room_id=None,
@@ -164,9 +164,13 @@ class MessageHandler(object):
room_state = room_state[membership_event_id]
now = self.clock.time_msec()
- defer.returnValue(
- [serialize_event(c, now) for c in room_state.values()]
+ events = yield self._event_serializer.serialize_events(
+ room_state.values(), now,
+ # We don't bother bundling aggregations in when asked for state
+ # events, as clients won't use them.
+ bundle_aggregations=False,
)
+ defer.returnValue(events)
@defer.inlineCallbacks
def get_joined_members(self, requester, room_id):
@@ -228,6 +232,7 @@ class EventCreationHandler(object):
self.ratelimiter = hs.get_ratelimiter()
self.notifier = hs.get_notifier()
self.config = hs.config
+ self.require_membership_for_aliases = hs.config.require_membership_for_aliases
self.send_event_to_master = ReplicationSendEventRestServlet.make_client(hs)
@@ -336,6 +341,35 @@ class EventCreationHandler(object):
prev_events_and_hashes=prev_events_and_hashes,
)
+ # In an ideal world we wouldn't need the second part of this condition.
+ # However, this behaviour isn't spec'd yet, so we need to be able to turn
+ # it off. This code is also evaluated each time a new m.room.aliases event
+ # is created, which includes hitting a /directory route; without this
+ # condition here, the similar check in synapse.handlers.directory would be
+ # pointless.
+ if builder.type == EventTypes.Aliases and self.require_membership_for_aliases:
+ # Ideally we'd do the membership check in event_auth.check(), which
+ # describes a spec'd algorithm for authenticating events received over
+ # federation as well as those created locally. As of room v3, aliases events
+ # can be created by users that are not in the room, therefore we have to
+ # tolerate them in event_auth.check().
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
+ prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
+ if not prev_event or prev_event.membership != Membership.JOIN:
+ logger.warning(
+ ("Attempt to send `m.room.aliases` in room %s by user %s but"
+ " membership is %s"),
+ event.room_id,
+ event.sender,
+ prev_event.membership if prev_event else None,
+ )
+
+ raise AuthError(
+ 403,
+ "You must be in the room to create an alias for it",
+ )
+
self.validator.validate_new(event)
defer.returnValue((event, context))
@@ -570,6 +604,20 @@ class EventCreationHandler(object):
self.validator.validate_new(event)
+ # If this event is an annotation then we check that the sender
+ # can't annotate the same way twice (e.g. to stop users from liking
+ # an event multiple times).
+ relation = event.content.get("m.relates_to", {})
+ if relation.get("rel_type") == RelationTypes.ANNOTATION:
+ relates_to = relation["event_id"]
+ aggregation_key = relation["key"]
+
+ already_exists = yield self.store.has_user_annotated_event(
+ relates_to, event.type, aggregation_key, event.sender,
+ )
+ if already_exists:
+ raise SynapseError(400, "Can't send same reaction twice")
+
logger.debug(
"Created event %s",
event.event_id,
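
For reference, the content shape the new annotation check inspects (field names per the relations proposals; the emoji key is just an example):

```python
reaction_content = {
    "m.relates_to": {
        "rel_type": "m.annotation",      # RelationTypes.ANNOTATION
        "event_id": "$parent_event_id",  # the event being reacted to
        "key": "👍",                     # the aggregation key
    },
}

relation = reaction_content.get("m.relates_to", {})
if relation.get("rel_type") == "m.annotation":
    relates_to = relation["event_id"]
    aggregation_key = relation["key"]
    # has_user_annotated_event(relates_to, type, key, sender) dedupes here
```
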
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index e4fdae9266..8f811e24fe 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -20,7 +20,6 @@ from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
-from synapse.events.utils import serialize_event
from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
@@ -78,6 +77,7 @@ class PaginationHandler(object):
self._purges_in_progress_by_room = set()
# map from purge id to PurgeStatus
self._purges_by_id = {}
+ self._event_serializer = hs.get_event_client_serializer()
def start_purge_history(self, room_id, token,
delete_local_events=False):
@@ -278,18 +278,22 @@ class PaginationHandler(object):
time_now = self.clock.time_msec()
chunk = {
- "chunk": [
- serialize_event(e, time_now, as_client_event)
- for e in events
- ],
+ "chunk": (
+ yield self._event_serializer.serialize_events(
+ events, time_now,
+ as_client_event=as_client_event,
+ )
+ ),
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
}
if state:
- chunk["state"] = [
- serialize_event(e, time_now, as_client_event)
- for e in state
- ]
+ chunk["state"] = (
+ yield self._event_serializer.serialize_events(
+ state, time_now,
+ as_client_event=as_client_event,
+ )
+ )
defer.returnValue(chunk)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index bd1285b15c..6209858bbb 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -182,17 +182,27 @@ class PresenceHandler(object):
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
+ def run_timeout_handler():
+ return run_as_background_process(
+ "handle_presence_timeouts", self._handle_timeouts
+ )
+
self.clock.call_later(
30,
self.clock.looping_call,
- self._handle_timeouts,
+ run_timeout_handler,
5000,
)
+ def run_persister():
+ return run_as_background_process(
+ "persist_presence_changes", self._persist_unpersisted_changes
+ )
+
self.clock.call_later(
60,
self.clock.looping_call,
- self._persist_unpersisted_changes,
+ run_persister,
60 * 1000,
)
@@ -229,6 +239,7 @@ class PresenceHandler(object):
)
if self.unpersisted_users_changes:
+
yield self.store.update_presence([
self.user_to_current_state[user_id]
for user_id in self.unpersisted_users_changes
@@ -240,30 +251,18 @@ class PresenceHandler(object):
"""We periodically persist the unpersisted changes, as otherwise they
may stack up and slow down shutdown times.
"""
- logger.info(
- "Performing _persist_unpersisted_changes. Persisting %d unpersisted changes",
- len(self.unpersisted_users_changes)
- )
-
unpersisted = self.unpersisted_users_changes
self.unpersisted_users_changes = set()
if unpersisted:
+ logger.info(
+ "Persisting %d upersisted presence updates", len(unpersisted)
+ )
yield self.store.update_presence([
self.user_to_current_state[user_id]
for user_id in unpersisted
])
- logger.info("Finished _persist_unpersisted_changes")
-
- @defer.inlineCallbacks
- def _update_states_and_catch_exception(self, new_states):
- try:
- res = yield self._update_states(new_states)
- defer.returnValue(res)
- except Exception:
- logger.exception("Error updating presence")
-
@defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
@@ -338,45 +337,41 @@ class PresenceHandler(object):
logger.info("Handling presence timeouts")
now = self.clock.time_msec()
- try:
- with Measure(self.clock, "presence_handle_timeouts"):
- # Fetch the list of users that *may* have timed out. Things may have
- # changed since the timeout was set, so we won't necessarily have to
- # take any action.
- users_to_check = set(self.wheel_timer.fetch(now))
-
- # Check whether the lists of syncing processes from an external
- # process have expired.
- expired_process_ids = [
- process_id for process_id, last_update
- in self.external_process_last_updated_ms.items()
- if now - last_update > EXTERNAL_PROCESS_EXPIRY
- ]
- for process_id in expired_process_ids:
- users_to_check.update(
- self.external_process_last_updated_ms.pop(process_id, ())
- )
- self.external_process_last_update.pop(process_id)
+ # Fetch the list of users that *may* have timed out. Things may have
+ # changed since the timeout was set, so we won't necessarily have to
+ # take any action.
+ users_to_check = set(self.wheel_timer.fetch(now))
+
+ # Check whether the lists of syncing processes from an external
+ # process have expired.
+ expired_process_ids = [
+ process_id for process_id, last_update
+ in self.external_process_last_updated_ms.items()
+ if now - last_update > EXTERNAL_PROCESS_EXPIRY
+ ]
+ for process_id in expired_process_ids:
+ users_to_check.update(
+ self.external_process_last_updated_ms.pop(process_id, ())
+ )
+ self.external_process_last_update.pop(process_id)
- states = [
- self.user_to_current_state.get(
- user_id, UserPresenceState.default(user_id)
- )
- for user_id in users_to_check
- ]
+ states = [
+ self.user_to_current_state.get(
+ user_id, UserPresenceState.default(user_id)
+ )
+ for user_id in users_to_check
+ ]
- timers_fired_counter.inc(len(states))
+ timers_fired_counter.inc(len(states))
- changes = handle_timeouts(
- states,
- is_mine_fn=self.is_mine_id,
- syncing_user_ids=self.get_currently_syncing_users(),
- now=now,
- )
+ changes = handle_timeouts(
+ states,
+ is_mine_fn=self.is_mine_id,
+ syncing_user_ids=self.get_currently_syncing_users(),
+ now=now,
+ )
- run_in_background(self._update_states_and_catch_exception, changes)
- except Exception:
- logger.exception("Exception in _handle_timeouts loop")
+ return self._update_states(changes)
@defer.inlineCallbacks
def bump_presence_active_time(self, user):
@@ -828,6 +823,11 @@ class PresenceHandler(object):
if typ != EventTypes.Member:
continue
+ if event_id is None:
+ # state has been deleted, so this is not a join. We only care about
+ # joins.
+ continue
+
event = yield self.store.get_event(event_id)
if event.content.get("membership") != Membership.JOIN:
# We only care about joins
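
The presence changes replace handing raw methods to `looping_call` with small wrappers, so each run gets its own logcontext and metrics via `run_as_background_process`. The pattern, sketched in isolation:

```python
from synapse.metrics.background_process_metrics import run_as_background_process

def schedule_timeout_handler(clock, handle_timeouts):
    def run_timeout_handler():
        return run_as_background_process(
            "handle_presence_timeouts", handle_timeouts,
        )

    # first run after 30s (grace period for reconnects), then every 5s
    clock.call_later(30, clock.looping_call, run_timeout_handler, 5000)
```
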
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index a65c98ff5c..a5fc6c5dbf 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -31,6 +31,9 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
+MAX_DISPLAYNAME_LEN = 100
+MAX_AVATAR_URL_LEN = 1000
+
class BaseProfileHandler(BaseHandler):
"""Handles fetching and updating user profile information.
@@ -53,6 +56,7 @@ class BaseProfileHandler(BaseHandler):
@defer.inlineCallbacks
def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
+
if self.hs.is_mine(target_user):
try:
displayname = yield self.store.get_profile_displayname(
@@ -161,6 +165,11 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
+ if len(new_displayname) > MAX_DISPLAYNAME_LEN:
+ raise SynapseError(
+ 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN, ),
+ )
+
if new_displayname == '':
new_displayname = None
@@ -216,6 +225,11 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
+ if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
+ raise SynapseError(
+ 400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN, ),
+ )
+
yield self.store.set_profile_avatar_url(
target_user.localpart, new_avatar_url
)
@@ -283,6 +297,48 @@ class BaseProfileHandler(BaseHandler):
room_id, str(e)
)
+ @defer.inlineCallbacks
+ def check_profile_query_allowed(self, target_user, requester=None):
+ """Checks whether a profile query is allowed. If the
+ 'require_auth_for_profile_requests' config flag is set to True and a
+ 'requester' is provided, the query is only allowed if the two users
+ share a room.
+
+ Args:
+ target_user (UserID): The owner of the queried profile.
+ requester (None|UserID): The user querying for the profile.
+
+ Raises:
+ SynapseError(403): The two users share no room, or one of the users
+ couldn't be found in any room this server knows about, so the
+ query is denied.
+ """
+ # Implementation of MSC1301: don't allow looking up profiles if the
+ # requester isn't in the same room as the target. We expect requester to
+ # be None when this function is called outside of a profile query, e.g.
+ # when building a membership event. In this case, we must allow the
+ # lookup.
+ if not self.hs.config.require_auth_for_profile_requests or not requester:
+ return
+
+ try:
+ requester_rooms = yield self.store.get_rooms_for_user(
+ requester.to_string()
+ )
+ target_user_rooms = yield self.store.get_rooms_for_user(
+ target_user.to_string(),
+ )
+
+ # Check if the room lists have no elements in common.
+ if requester_rooms.isdisjoint(target_user_rooms):
+ raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
+ except StoreError as e:
+ if e.code == 404:
+ # This likely means that one of the users doesn't exist,
+ # so we act as if we couldn't find the profile.
+ raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
+ raise
+
class MasterProfileHandler(BaseProfileHandler):
PROFILE_UPDATE_MS = 60 * 1000
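
The heart of `check_profile_query_allowed` is a single disjointness test over the two users' room sets; a toy illustration:

```python
def may_query_profile(requester_rooms, target_rooms):
    # allowed only if the two users share at least one room
    return not requester_rooms.isdisjoint(target_rooms)

assert may_query_profile({"!a:hs", "!b:hs"}, {"!b:hs", "!c:hs"})
assert not may_query_profile({"!a:hs"}, {"!c:hs"})
```
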
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a51d11a257..9a388ea013 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -19,7 +19,7 @@ import logging
from twisted.internet import defer
from synapse import types
-from synapse.api.constants import LoginType
+from synapse.api.constants import MAX_USERID_LENGTH, LoginType
from synapse.api.errors import (
AuthError,
Codes,
@@ -123,6 +123,15 @@ class RegistrationHandler(BaseHandler):
self.check_user_id_not_appservice_exclusive(user_id)
+ if len(user_id) > MAX_USERID_LENGTH:
+ raise SynapseError(
+ 400,
+ "User ID may not be longer than %s characters" % (
+ MAX_USERID_LENGTH,
+ ),
+ Codes.INVALID_USERNAME
+ )
+
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
if not guest_access_token:
@@ -522,6 +531,8 @@ class RegistrationHandler(BaseHandler):
A tuple of (user_id, access_token).
Raises:
RegistrationError if there was a problem registering.
+
+ NB this is only used in tests. TODO: move it to the test package!
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
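
Note the new registration guard measures the full user id, so the server name eats into the 255-character budget. A small sketch:

```python
MAX_USERID_LENGTH = 255  # mirrors synapse.api.constants

def check_user_id_length(localpart, server_name):
    user_id = "@%s:%s" % (localpart, server_name)
    if len(user_id) > MAX_USERID_LENGTH:
        raise ValueError(
            "User ID may not be longer than %s characters" % MAX_USERID_LENGTH
        )
    return user_id

check_user_id_length("alice", "example.com")   # "@alice:example.com"
```
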
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 17628e2684..4a17911a87 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -27,7 +27,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
-from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
@@ -70,6 +70,7 @@ class RoomCreationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
+ self.config = hs.config
# linearizer to stop two upgrades happening at once
self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
@@ -402,7 +403,7 @@ class RoomCreationHandler(BaseHandler):
yield directory_handler.create_association(
requester, RoomAlias.from_string(alias),
new_room_id, servers=(self.hs.hostname, ),
- send_event=False,
+ send_event=False, check_membership=False,
)
logger.info("Moved alias %s to new room", alias)
except SynapseError as e:
@@ -475,7 +476,11 @@ class RoomCreationHandler(BaseHandler):
if ratelimit:
yield self.ratelimit(requester)
- room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier)
+ room_version = config.get(
+ "room_version",
+ self.config.default_room_version.identifier,
+ )
+
if not isinstance(room_version, string_types):
raise SynapseError(
400,
@@ -538,6 +543,7 @@ class RoomCreationHandler(BaseHandler):
room_alias=room_alias,
servers=[self.hs.hostname],
send_event=False,
+ check_membership=False,
)
preset_config = config.get(
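
Room version selection now falls back to the configured default rather than a hardcoded constant; roughly (noting that `config.default_room_version` is a `RoomVersion` object, hence `.identifier`):

```python
def resolve_room_version(creation_config, default_room_version_identifier):
    # the client's requested version wins; otherwise the server default
    room_version = creation_config.get(
        "room_version", default_room_version_identifier,
    )
    if not isinstance(room_version, str):
        raise ValueError("room_version must be a string")
    return room_version

assert resolve_room_version({}, "4") == "4"
assert resolve_room_version({"room_version": "1"}, "4") == "1"
```
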
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 024d6db27a..93ac986c86 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,6 +34,8 @@ from synapse.types import RoomID, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
+from ._base import BaseHandler
+
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
@@ -71,6 +74,12 @@ class RoomMemberHandler(object):
self.spam_checker = hs.get_spam_checker()
self._server_notices_mxid = self.config.server_notices_mxid
self._enable_lookup = hs.config.enable_3pid_lookup
+ self.allow_per_room_profiles = self.config.allow_per_room_profiles
+
+ # This is only used to get at the ratelimit function and
+ # maybe_kick_guest_users. It's fine that there are multiple instances
+ # of this, as it doesn't store any state.
+ self.base_handler = BaseHandler(hs)
@abc.abstractmethod
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
@@ -350,6 +359,13 @@ class RoomMemberHandler(object):
# later on.
content = dict(content)
+ if not self.allow_per_room_profiles:
+ # Strip profile data, knowing that new profile data will be added to the
+ # event's content in event_creation_handler.create_event() using the target's
+ # global profile.
+ content.pop("displayname", None)
+ content.pop("avatar_url", None)
+
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
@@ -703,6 +719,10 @@ class RoomMemberHandler(object):
Codes.FORBIDDEN,
)
+ # We need to rate limit *before* we send out any 3PID invites, so we
+ # can't just rely on the standard ratelimiting of events.
+ yield self.base_handler.ratelimit(requester)
+
invitee = yield self._lookup_3pid(
id_server, medium, address
)
@@ -924,7 +944,7 @@ class RoomMemberHandler(object):
}
if self.config.invite_3pid_guest:
- guest_access_token, guest_user_id = yield self.get_or_register_3pid_guest(
+ guest_user_id, guest_access_token = yield self.get_or_register_3pid_guest(
requester=requester,
medium=medium,
address=address,
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 49c439313e..9bba74d6c9 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -23,7 +23,6 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
-from synapse.events.utils import serialize_event
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -36,6 +35,7 @@ class SearchHandler(BaseHandler):
def __init__(self, hs):
super(SearchHandler, self).__init__(hs)
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -401,14 +401,16 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = [
- serialize_event(e, time_now)
- for e in context["events_before"]
- ]
- context["events_after"] = [
- serialize_event(e, time_now)
- for e in context["events_after"]
- ]
+ context["events_before"] = (
+ yield self._event_serializer.serialize_events(
+ context["events_before"], time_now,
+ )
+ )
+ context["events_after"] = (
+ yield self._event_serializer.serialize_events(
+ context["events_after"], time_now,
+ )
+ )
state_results = {}
if include_state:
@@ -422,14 +424,13 @@ class SearchHandler(BaseHandler):
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
- results = [
- {
+ results = []
+ for e in allowed_events:
+ results.append({
"rank": rank_map[e.event_id],
- "result": serialize_event(e, time_now),
+ "result": (yield self._event_serializer.serialize_event(e, time_now)),
"context": contexts.get(e.event_id, {}),
- }
- for e in allowed_events
- ]
+ })
rooms_cat_res = {
"results": results,
@@ -438,10 +439,13 @@ class SearchHandler(BaseHandler):
}
if state_results:
- rooms_cat_res["state"] = {
- room_id: [serialize_event(e, time_now) for e in state]
- for room_id, state in state_results.items()
- }
+ s = {}
+ for room_id, state in state_results.items():
+ s[room_id] = yield self._event_serializer.serialize_events(
+ state, time_now,
+ )
+
+ rooms_cat_res["state"] = s
if room_groups and "room_id" in group_keys:
rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
new file mode 100644
index 0000000000..0e92b405ba
--- /dev/null
+++ b/synapse/handlers/stats.py
@@ -0,0 +1,325 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.handlers.state_deltas import StateDeltasHandler
+from synapse.metrics import event_processing_positions
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
+from synapse.util.metrics import Measure
+
+logger = logging.getLogger(__name__)
+
+
+class StatsHandler(StateDeltasHandler):
+ """Handles keeping the *_stats tables updated with a simple time-series of
+ information about the users, rooms and media on the server, such that admins
+ have some idea of who is consuming their resources.
+
+ Heavily derived from UserDirectoryHandler
+ """
+
+ def __init__(self, hs):
+ super(StatsHandler, self).__init__(hs)
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.server_name = hs.hostname
+ self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
+ self.is_mine_id = hs.is_mine_id
+ self.stats_bucket_size = hs.config.stats_bucket_size
+
+ # The current position in the current_state_delta stream
+ self.pos = None
+
+ # Guard to ensure we only process deltas one at a time
+ self._is_processing = False
+
+ if hs.config.stats_enabled:
+ self.notifier.add_replication_callback(self.notify_new_event)
+
+ # We kick this off so that we don't have to wait for a change before
+ # we start populating stats
+ self.clock.call_later(0, self.notify_new_event)
+
+ def notify_new_event(self):
+ """Called when there may be more deltas to process
+ """
+ if not self.hs.config.stats_enabled:
+ return
+
+ if self._is_processing:
+ return
+
+ @defer.inlineCallbacks
+ def process():
+ try:
+ yield self._unsafe_process()
+ finally:
+ self._is_processing = False
+
+ self._is_processing = True
+ run_as_background_process("stats.notify_new_event", process)
+
+ @defer.inlineCallbacks
+ def _unsafe_process(self):
+ # If self.pos is None, it means we haven't fetched it from the DB yet
+ if self.pos is None:
+ self.pos = yield self.store.get_stats_stream_pos()
+
+ # If still None then the initial background update hasn't happened yet
+ if self.pos is None:
+ defer.returnValue(None)
+
+ # Loop round handling deltas until we're up to date
+ while True:
+ with Measure(self.clock, "stats_delta"):
+ deltas = yield self.store.get_current_state_deltas(self.pos)
+ if not deltas:
+ return
+
+ logger.info("Handling %d state deltas", len(deltas))
+ yield self._handle_deltas(deltas)
+
+ self.pos = deltas[-1]["stream_id"]
+ yield self.store.update_stats_stream_pos(self.pos)
+
+ event_processing_positions.labels("stats").set(self.pos)
+
+ @defer.inlineCallbacks
+ def _handle_deltas(self, deltas):
+ """
+ Called with the state deltas to process
+ """
+ for delta in deltas:
+ typ = delta["type"]
+ state_key = delta["state_key"]
+ room_id = delta["room_id"]
+ event_id = delta["event_id"]
+ stream_id = delta["stream_id"]
+ prev_event_id = delta["prev_event_id"]
+
+ logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
+
+ token = yield self.store.get_earliest_token_for_room_stats(room_id)
+
+ # If the earliest token to begin from is larger than our current
+ # stream ID, skip processing this delta.
+ if token is not None and token >= stream_id:
+ logger.debug(
+ "Ignoring: %s as earlier than this room's initial ingestion event",
+ event_id,
+ )
+ continue
+
+ if event_id is None and prev_event_id is None:
+ # Neither a new event nor a previous one: nothing to do here.
+ continue
+
+ event_content = {}
+
+ if event_id is not None:
+ event_content = (yield self.store.get_event(event_id)).content or {}
+
+ # quantise time to the nearest bucket
+ now = yield self.store.get_received_ts(event_id)
+ now = (now // 1000 // self.stats_bucket_size) * self.stats_bucket_size
+
+ if typ == EventTypes.Member:
+ # we could use _get_key_change here but it's a bit inefficient
+ # given we're not testing for a specific result; might as well
+ # just grab the prev_membership and membership strings and
+ # compare them.
+ prev_event_content = {}
+ if prev_event_id is not None:
+ prev_event_content = (
+ yield self.store.get_event(prev_event_id)
+ ).content
+
+ membership = event_content.get("membership", Membership.LEAVE)
+ prev_membership = prev_event_content.get("membership", Membership.LEAVE)
+
+ if prev_membership == membership:
+ continue
+
+ if prev_membership == Membership.JOIN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "joined_members", -1
+ )
+ elif prev_membership == Membership.INVITE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "invited_members", -1
+ )
+ elif prev_membership == Membership.LEAVE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "left_members", -1
+ )
+ elif prev_membership == Membership.BAN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "banned_members", -1
+ )
+ else:
+ err = "%s is not a valid prev_membership" % (repr(prev_membership),)
+ logger.error(err)
+ raise ValueError(err)
+
+ if membership == Membership.JOIN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "joined_members", +1
+ )
+ elif membership == Membership.INVITE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "invited_members", +1
+ )
+ elif membership == Membership.LEAVE:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "left_members", +1
+ )
+ elif membership == Membership.BAN:
+ yield self.store.update_stats_delta(
+ now, "room", room_id, "banned_members", +1
+ )
+ else:
+ err = "%s is not a valid membership" % (repr(membership),)
+ logger.error(err)
+ raise ValueError(err)
+
+ user_id = state_key
+ if self.is_mine_id(user_id):
+ # update user_stats as it's one of our users
+ public = yield self._is_public_room(room_id)
+
+ if membership == Membership.LEAVE:
+ yield self.store.update_stats_delta(
+ now,
+ "user",
+ user_id,
+ "public_rooms" if public else "private_rooms",
+ -1,
+ )
+ elif membership == Membership.JOIN:
+ yield self.store.update_stats_delta(
+ now,
+ "user",
+ user_id,
+ "public_rooms" if public else "private_rooms",
+ +1,
+ )
+
+ elif typ == EventTypes.Create:
+ # Newly created room. Add it with all blank portions.
+ yield self.store.update_room_state(
+ room_id,
+ {
+ "join_rules": None,
+ "history_visibility": None,
+ "encryption": None,
+ "name": None,
+ "topic": None,
+ "avatar": None,
+ "canonical_alias": None,
+ },
+ )
+
+ elif typ == EventTypes.JoinRules:
+ yield self.store.update_room_state(
+ room_id, {"join_rules": event_content.get("join_rule")}
+ )
+
+ is_public = yield self._get_key_change(
+ prev_event_id, event_id, "join_rule", JoinRules.PUBLIC
+ )
+ if is_public is not None:
+ yield self.update_public_room_stats(now, room_id, is_public)
+
+ elif typ == EventTypes.RoomHistoryVisibility:
+ yield self.store.update_room_state(
+ room_id,
+ {"history_visibility": event_content.get("history_visibility")},
+ )
+
+ is_public = yield self._get_key_change(
+ prev_event_id, event_id, "history_visibility", "world_readable"
+ )
+ if is_public is not None:
+ yield self.update_public_room_stats(now, room_id, is_public)
+
+ elif typ == EventTypes.Encryption:
+ yield self.store.update_room_state(
+ room_id, {"encryption": event_content.get("algorithm")}
+ )
+ elif typ == EventTypes.Name:
+ yield self.store.update_room_state(
+ room_id, {"name": event_content.get("name")}
+ )
+ elif typ == EventTypes.Topic:
+ yield self.store.update_room_state(
+ room_id, {"topic": event_content.get("topic")}
+ )
+ elif typ == EventTypes.RoomAvatar:
+ yield self.store.update_room_state(
+ room_id, {"avatar": event_content.get("url")}
+ )
+ elif typ == EventTypes.CanonicalAlias:
+ yield self.store.update_room_state(
+ room_id, {"canonical_alias": event_content.get("alias")}
+ )
+
+ @defer.inlineCallbacks
+ def update_public_room_stats(self, ts, room_id, is_public):
+ """
+ Increment/decrement a user's number of public rooms when a room they are
+ in changes to/from public visibility.
+
+ Args:
+ ts (int): Timestamp in seconds
+ room_id (str)
+ is_public (bool)
+ """
+ # For now, blindly iterate over all local users in the room so that
+ # we can handle the whole problem of copying buckets over as needed
+ user_ids = yield self.store.get_users_in_room(room_id)
+
+ for user_id in user_ids:
+ if self.hs.is_mine(UserID.from_string(user_id)):
+ yield self.store.update_stats_delta(
+ ts, "user", user_id, "public_rooms", +1 if is_public else -1
+ )
+ yield self.store.update_stats_delta(
+ ts, "user", user_id, "private_rooms", -1 if is_public else +1
+ )
+
+ @defer.inlineCallbacks
+ def _is_public_room(self, room_id):
+ join_rules = yield self.state.get_current_state(room_id, EventTypes.JoinRules)
+ history_visibility = yield self.state.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility
+ )
+
+ if (join_rules and join_rules.content.get("join_rule") == JoinRules.PUBLIC) or (
+ (
+ history_visibility
+ and history_visibility.content.get("history_visibility")
+ == "world_readable"
+ )
+ ):
+ defer.returnValue(True)
+ else:
+ defer.returnValue(False)
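
The stats handler quantises event timestamps to bucket boundaries before writing deltas; a worked example of the arithmetic used in `_handle_deltas`:

```python
def quantise_to_bucket(ts_ms, bucket_size_s):
    # millisecond timestamp -> start of its bucket, in seconds
    return (ts_ms // 1000 // bucket_size_s) * bucket_size_s

# with hourly buckets, events 500s apart land in the same bucket
assert quantise_to_bucket(1557000000000, 3600) == 1557000000
assert quantise_to_bucket(1557000500000, 3600) == 1557000000
```
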
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7cf757f65a..72997d6d04 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -934,7 +934,7 @@ class SyncHandler(object):
res = yield self._generate_sync_entry_for_rooms(
sync_result_builder, account_data_by_room
)
- newly_joined_rooms, newly_joined_users, _, _ = res
+ newly_joined_rooms, newly_joined_or_invited_users, _, _ = res
_, _, newly_left_rooms, newly_left_users = res
block_all_presence_data = (
@@ -943,7 +943,7 @@ class SyncHandler(object):
)
if self.hs_config.use_presence and not block_all_presence_data:
yield self._generate_sync_entry_for_presence(
- sync_result_builder, newly_joined_rooms, newly_joined_users
+ sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
)
yield self._generate_sync_entry_for_to_device(sync_result_builder)
@@ -951,7 +951,7 @@ class SyncHandler(object):
device_lists = yield self._generate_sync_entry_for_device_list(
sync_result_builder,
newly_joined_rooms=newly_joined_rooms,
- newly_joined_users=newly_joined_users,
+ newly_joined_or_invited_users=newly_joined_or_invited_users,
newly_left_rooms=newly_left_rooms,
newly_left_users=newly_left_users,
)
@@ -1036,7 +1036,8 @@ class SyncHandler(object):
@measure_func("_generate_sync_entry_for_device_list")
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder,
- newly_joined_rooms, newly_joined_users,
+ newly_joined_rooms,
+ newly_joined_or_invited_users,
newly_left_rooms, newly_left_users):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -1050,7 +1051,7 @@ class SyncHandler(object):
# share a room with?
for room_id in newly_joined_rooms:
joined_users = yield self.state.get_current_users_in_room(room_id)
- newly_joined_users.update(joined_users)
+ newly_joined_or_invited_users.update(joined_users)
for room_id in newly_left_rooms:
left_users = yield self.state.get_current_users_in_room(room_id)
@@ -1058,7 +1059,7 @@ class SyncHandler(object):
# TODO: Check that these users are actually new, i.e. either they
# weren't in the previous sync *or* they left and rejoined.
- changed.update(newly_joined_users)
+ changed.update(newly_joined_or_invited_users)
if not changed and not newly_left_users:
defer.returnValue(DeviceLists(
@@ -1176,7 +1177,7 @@ class SyncHandler(object):
@defer.inlineCallbacks
def _generate_sync_entry_for_presence(self, sync_result_builder, newly_joined_rooms,
- newly_joined_users):
+ newly_joined_or_invited_users):
"""Generates the presence portion of the sync response. Populates the
`sync_result_builder` with the result.
@@ -1184,8 +1185,9 @@ class SyncHandler(object):
sync_result_builder(SyncResultBuilder)
newly_joined_rooms(list): List of rooms that the user has joined
since the last sync (or empty if an initial sync)
- newly_joined_users(list): List of users that have joined rooms
- since the last sync (or empty if an initial sync)
+ newly_joined_or_invited_users(list): List of users that have joined
+ or been invited to rooms since the last sync (or empty if an initial
+ sync)
"""
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
@@ -1211,7 +1213,7 @@ class SyncHandler(object):
"presence_key", presence_key
)
- extra_users_ids = set(newly_joined_users)
+ extra_users_ids = set(newly_joined_or_invited_users)
for room_id in newly_joined_rooms:
users = yield self.state.get_current_users_in_room(room_id)
extra_users_ids.update(users)
@@ -1243,7 +1245,8 @@ class SyncHandler(object):
Returns:
Deferred(tuple): Returns a 4-tuple of
- `(newly_joined_rooms, newly_joined_users, newly_left_rooms, newly_left_users)`
+ `(newly_joined_rooms, newly_joined_or_invited_users,
+ newly_left_rooms, newly_left_users)`
"""
user_id = sync_result_builder.sync_config.user.to_string()
block_all_room_ephemeral = (
@@ -1314,8 +1317,8 @@ class SyncHandler(object):
sync_result_builder.invited.extend(invited)
- # Now we want to get any newly joined users
- newly_joined_users = set()
+ # Now we want to get any newly joined or invited users
+ newly_joined_or_invited_users = set()
newly_left_users = set()
if since_token:
for joined_sync in sync_result_builder.joined:
@@ -1324,19 +1327,22 @@ class SyncHandler(object):
)
for event in it:
if event.type == EventTypes.Member:
- if event.membership == Membership.JOIN:
- newly_joined_users.add(event.state_key)
+ if (
+ event.membership == Membership.JOIN or
+ event.membership == Membership.INVITE
+ ):
+ newly_joined_or_invited_users.add(event.state_key)
else:
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership == Membership.JOIN:
newly_left_users.add(event.state_key)
- newly_left_users -= newly_joined_users
+ newly_left_users -= newly_joined_or_invited_users
defer.returnValue((
newly_joined_rooms,
- newly_joined_users,
+ newly_joined_or_invited_users,
newly_left_rooms,
newly_left_users,
))
@@ -1381,7 +1387,7 @@ class SyncHandler(object):
where:
room_entries is a list [RoomSyncResultBuilder]
invited_rooms is a list [InvitedSyncResult]
- newly_joined rooms is a list[str] of room ids
+ newly_joined_rooms is a list[str] of room ids
newly_left_rooms is a list[str] of room ids
"""
user_id = sync_result_builder.sync_config.user.to_string()
@@ -1422,7 +1428,7 @@ class SyncHandler(object):
if room_id in sync_result_builder.joined_room_ids and non_joins:
# Always include if the user (re)joined the room, especially
# important so that device list changes are calculated correctly.
- # If there are non join member events, but we are still in the room,
+ # If there are non-join member events, but we are still in the room,
# then the user must have left and joined
newly_joined_rooms.append(room_id)
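
The sync rename widens one set: device-list and presence calculations must now consider users who were newly invited, not just newly joined. Schematically:

```python
JOIN, INVITE = "join", "invite"

def collect_membership_changes(member_events):
    """member_events: iterable of (user_id, membership, prev_membership)."""
    newly_joined_or_invited_users = set()
    newly_left_users = set()
    for user_id, membership, prev_membership in member_events:
        if membership in (JOIN, INVITE):
            newly_joined_or_invited_users.add(user_id)
        elif prev_membership == JOIN:
            newly_left_users.add(user_id)
    # a user who left and rejoined in the window counts as joined, not left
    newly_left_users -= newly_joined_or_invited_users
    return newly_joined_or_invited_users, newly_left_users
```
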
diff --git a/synapse/http/client.py b/synapse/http/client.py
index ad454f4964..77fe68818b 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -90,45 +90,50 @@ class IPBlacklistingResolver(object):
def resolveHostName(self, recv, hostname, portNumber=0):
r = recv()
- d = defer.Deferred()
addresses = []
- @provider(IResolutionReceiver)
- class EndpointReceiver(object):
- @staticmethod
- def resolutionBegan(resolutionInProgress):
- pass
+ def _callback():
+ r.resolutionBegan(None)
- @staticmethod
- def addressResolved(address):
- ip_address = IPAddress(address.host)
+ has_bad_ip = False
+ for i in addresses:
+ ip_address = IPAddress(i.host)
if check_against_blacklist(
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
- "Dropped %s from DNS resolution to %s" % (ip_address, hostname)
+ "Dropped %s from DNS resolution to %s due to blacklist" %
+ (ip_address, hostname)
)
- raise SynapseError(403, "IP address blocked by IP blacklist entry")
+ has_bad_ip = True
+
+ # if we have a blacklisted IP, we'd like to raise an error to block the
+ # request, but all we can really do from here is claim that there were no
+ # valid results.
+ if not has_bad_ip:
+ for i in addresses:
+ r.addressResolved(i)
+ r.resolutionComplete()
+ @provider(IResolutionReceiver)
+ class EndpointReceiver(object):
+ @staticmethod
+ def resolutionBegan(resolutionInProgress):
+ pass
+
+ @staticmethod
+ def addressResolved(address):
addresses.append(address)
@staticmethod
def resolutionComplete():
- d.callback(addresses)
+ _callback()
self._reactor.nameResolver.resolveHostName(
EndpointReceiver, hostname, portNumber=portNumber
)
- def _callback(addrs):
- r.resolutionBegan(None)
- for i in addrs:
- r.addressResolved(i)
- r.resolutionComplete()
-
- d.addCallback(_callback)
-
return r
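
The rewrite above drops the never-fired Deferred and instead buffers addresses until resolution completes, then forwards either all of them or none. The all-or-nothing decision in isolation, assuming netaddr-style `IPSet` whitelist/blacklist objects as used elsewhere in this module:

```python
from netaddr import IPAddress, IPSet

def filter_resolved_hosts(hosts, ip_whitelist, ip_blacklist):
    """hosts: list of IP address strings produced by the resolver."""
    for host in hosts:
        ip = IPAddress(host)
        if ip in ip_blacklist and (ip_whitelist is None or ip not in ip_whitelist):
            # we cannot raise from inside the resolver callback, so the best
            # we can do is claim there were no valid results at all
            return []
    return hosts

assert filter_resolved_hosts(["10.0.0.1"], None, IPSet(["10.0.0.0/8"])) == []
```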
@@ -160,7 +165,8 @@ class BlacklistingAgentWrapper(Agent):
ip_address, self._ip_whitelist, self._ip_blacklist
):
logger.info(
- "Blocking access to %s because of blacklist" % (ip_address,)
+ "Blocking access to %s due to blacklist" %
+ (ip_address,)
)
e = SynapseError(403, "IP address blocked by IP blacklist entry")
return defer.fail(Failure(e))
@@ -258,9 +264,6 @@ class SimpleHttpClient(object):
uri (str): URI to query.
data (bytes): Data to send in the request body, if applicable.
headers (t.w.http_headers.Headers): Request headers.
-
- Raises:
- SynapseError: If the IP is blacklisted.
"""
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 1334c630cc..b4cbe97b41 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -149,7 +149,7 @@ class MatrixFederationAgent(object):
tls_options = None
else:
tls_options = self._tls_client_options_factory.get_options(
- res.tls_server_name.decode("ascii")
+ res.tls_server_name.decode("ascii"),
)
# make sure that the Host header is set correctly
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index ff63d0b2a8..663ea72a7a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -27,9 +27,11 @@ import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
+from zope.interface import implementer
from twisted.internet import defer, protocol
from twisted.internet.error import DNSLookupError
+from twisted.internet.interfaces import IReactorPluggableNameResolver
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
@@ -44,6 +46,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.http import QuieterFileBodyProducer
+from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.util.async_helpers import timeout_deferred
from synapse.util.logcontext import make_deferred_yieldable
@@ -172,19 +175,44 @@ class MatrixFederationHttpClient(object):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
- reactor = hs.get_reactor()
+
+ real_reactor = hs.get_reactor()
+
+ # We need to use a DNS resolver which filters out blacklisted IP
+ # addresses, to prevent DNS rebinding.
+ nameResolver = IPBlacklistingResolver(
+ real_reactor, None, hs.config.federation_ip_range_blacklist,
+ )
+
+ @implementer(IReactorPluggableNameResolver)
+ class Reactor(object):
+ def __getattr__(_self, attr):
+ if attr == "nameResolver":
+ return nameResolver
+ else:
+ return getattr(real_reactor, attr)
+
+ self.reactor = Reactor()
self.agent = MatrixFederationAgent(
- hs.get_reactor(),
+ self.reactor,
tls_client_options_factory,
)
+
+ # Use a BlacklistingAgentWrapper to prevent circumventing the IP
+ # blacklist via IP literals in server names
+ self.agent = BlacklistingAgentWrapper(
+ self.agent, self.reactor,
+ ip_blacklist=hs.config.federation_ip_range_blacklist,
+ )
+
self.clock = hs.get_clock()
self._store = hs.get_datastore()
self.version_string_bytes = hs.version_string.encode('ascii')
self.default_timeout = 60
def schedule(x):
- reactor.callLater(_EPSILON, x)
+ self.reactor.callLater(_EPSILON, x)
self._cooperator = Cooperator(scheduler=schedule)
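
The `Reactor` shim above is a plain attribute proxy: every lookup is delegated to the real reactor except `nameResolver`, which returns the blacklisting resolver. The same pattern in isolation (generic names are ours):

```python
class AttributeOverrideProxy(object):
    """Delegate every attribute to `wrapped`, except explicit overrides."""

    def __init__(self, wrapped, **overrides):
        self._wrapped = wrapped
        self._overrides = overrides

    def __getattr__(self, attr):
        # only invoked when normal lookup fails, i.e. for everything other
        # than _wrapped/_overrides themselves
        if attr in self._overrides:
            return self._overrides[attr]
        return getattr(self._wrapped, attr)

# e.g.: reactor = AttributeOverrideProxy(real_reactor, nameResolver=resolver)
```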
@@ -257,7 +285,24 @@ class MatrixFederationHttpClient(object):
request (MatrixFederationRequest): details of request to be sent
timeout (int|None): number of milliseconds to wait for the response headers
- (including connecting to the server). 60s by default.
+ (including connecting to the server), *for each attempt*.
+ 60s by default.
+
+ long_retries (bool): whether to use the long retry algorithm.
+
+ The regular retry algorithm makes 4 attempts, with intervals
+ [0.5s, 1s, 2s].
+
+ The long retry algorithm makes 11 attempts, with intervals
+ [4s, 16s, 60s, 60s, ...]
+
+ Both algorithms add -20%/+40% jitter to the retry intervals.
+
+ Note that the above intervals are *in addition* to the time spent
+ waiting for the request to complete (up to `timeout` ms).
+
+ NB: the long retry algorithm takes over 20 minutes to complete, with
+ a default timeout of 60s!
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
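
A back-of-the-envelope rendering of the documented schedules; the helper is ours, but the intervals and the -20%/+40% jitter follow the docstring above:

```python
import random

def retry_sleep_intervals(long_retries=False):
    # 11 attempts -> 10 sleeps; 4 attempts -> 3 sleeps
    base = [4, 16] + [60] * 8 if long_retries else [0.5, 1, 2]
    return [delay * random.uniform(0.8, 1.4) for delay in base]

# sum(retry_sleep_intervals(True)) plus 11 * 60s of per-attempt timeouts is
# where the "over 20 minutes" worst case in the docstring comes from.
```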
@@ -370,7 +415,7 @@ class MatrixFederationHttpClient(object):
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
- reactor=self.hs.get_reactor(),
+ reactor=self.reactor,
)
response = yield request_deferred
@@ -397,7 +442,7 @@ class MatrixFederationHttpClient(object):
d = timeout_deferred(
d,
timeout=_sec_timeout,
- reactor=self.hs.get_reactor(),
+ reactor=self.reactor,
)
try:
@@ -538,10 +583,14 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
use as the request body.
- long_retries (bool): A boolean that indicates whether we should
- retry for a short or long time.
- timeout(int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout.
+
+ long_retries (bool): whether to use the long retry algorithm. See
+ docs on _send_request for details.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self.default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as
@@ -586,7 +635,7 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.hs.get_reactor(), self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response,
)
defer.returnValue(body)
@@ -599,15 +648,22 @@ class MatrixFederationHttpClient(object):
Args:
destination (str): The remote server to send the HTTP request
to.
+
path (str): The HTTP path.
+
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- long_retries (bool): A boolean that indicates whether we should
- retry for a short or long time.
- timeout(int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout.
+
+ long_retries (bool): whether to use the long retry algorithm. See
+ docs on _send_request for details.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self.default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
+
args (dict): query params
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
@@ -645,7 +701,7 @@ class MatrixFederationHttpClient(object):
_sec_timeout = self.default_timeout
body = yield _handle_json_response(
- self.hs.get_reactor(), _sec_timeout, request, response,
+ self.reactor, _sec_timeout, request, response,
)
defer.returnValue(body)
@@ -658,14 +714,19 @@ class MatrixFederationHttpClient(object):
Args:
destination (str): The remote server to send the HTTP request
to.
+
path (str): The HTTP path.
+
args (dict|None): A dictionary used to create query strings, defaults to
None.
- timeout (int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout and that the request will
- be retried.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self.default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
+
try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
@@ -683,10 +744,6 @@ class MatrixFederationHttpClient(object):
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
"""
- logger.debug("get_json args: %s", args)
-
- logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
-
request = MatrixFederationRequest(
method="GET",
destination=destination,
@@ -704,7 +761,7 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.hs.get_reactor(), self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response,
)
defer.returnValue(body)
@@ -718,12 +775,18 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
- long_retries (bool): A boolean that indicates whether we should
- retry for a short or long time.
- timeout(int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout.
+
+ long_retries (bool): whether to use the long retry algorithm. See
+ docs on _send_request for details.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self.default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
+
+ args (dict): query params
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@@ -753,7 +816,7 @@ class MatrixFederationHttpClient(object):
)
body = yield _handle_json_response(
- self.hs.get_reactor(), self.default_timeout, request, response,
+ self.reactor, self.default_timeout, request, response,
)
defer.returnValue(body)
@@ -801,7 +864,7 @@ class MatrixFederationHttpClient(object):
try:
d = _readBodyToFile(response, output_stream, max_size)
- d.addTimeout(self.default_timeout, self.hs.get_reactor())
+ d.addTimeout(self.default_timeout, self.reactor)
length = yield make_deferred_yieldable(d)
except Exception as e:
logger.warn(
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 528125e737..197c652850 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -55,7 +55,7 @@ def parse_integer_from_args(args, name, default=None, required=False):
return int(args[name][0])
except Exception:
message = "Query parameter %r must be an integer" % (name,)
- raise SynapseError(400, message)
+ raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
else:
if required:
message = "Missing integer query parameter %r" % (name,)
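
With the errcode attached, a malformed integer query parameter now produces a structured Matrix error rather than a bare 400. The response body, reconstructed from the message format above and the standard value of `Codes.INVALID_PARAM`:

```python
# e.g. for GET /...?limit=abc
{
    "errcode": "M_INVALID_PARAM",
    "error": "Query parameter 'limit' must be an integer",
}
```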
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 8f0682c948..3523a40108 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -261,6 +261,23 @@ BASE_APPEND_OVERRIDE_RULES = [
'value': True,
}
]
+ },
+ {
+ 'rule_id': 'global/override/.m.rule.tombstone',
+ 'conditions': [
+ {
+ 'kind': 'event_match',
+ 'key': 'type',
+ 'pattern': 'm.room.tombstone',
+ '_id': '_tombstone',
+ }
+ ],
+ 'actions': [
+ 'notify', {
+ 'set_tweak': 'highlight',
+ 'value': True,
+ }
+ ]
}
]
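
The new override rule makes room-upgrade tombstones render as highlights. An event of the kind the `event_match` condition above fires on (values are illustrative):

```python
tombstone_event = {
    "type": "m.room.tombstone",  # matched by the rule's condition
    "state_key": "",
    "sender": "@admin:example.org",
    "content": {
        "body": "This room has been replaced",
        "replacement_room": "!newroom:example.org",
    },
}
```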
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 779f36dbed..f64baa4d58 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -16,7 +16,12 @@
import logging
-from pkg_resources import DistributionNotFound, VersionConflict, get_distribution
+from pkg_resources import (
+ DistributionNotFound,
+ Requirement,
+ VersionConflict,
+ get_provider,
+)
logger = logging.getLogger(__name__)
@@ -53,7 +58,7 @@ REQUIREMENTS = [
"pyasn1-modules>=0.0.7",
"daemonize>=2.3.1",
"bcrypt>=3.1.0",
- "pillow>=3.1.2",
+ "pillow>=4.3.0",
"sortedcontainers>=1.4.4",
"psutil>=2.0.0",
"pymacaroons>=0.13.0",
@@ -83,7 +88,13 @@ CONDITIONAL_REQUIREMENTS = {
# ACME support is required to provision TLS certificates from authorities
# that use the protocol, such as Let's Encrypt.
- "acme": ["txacme>=0.9.2"],
+ "acme": [
+ "txacme>=0.9.2",
+
+ # txacme depends on eliot. Eliot 1.8.0 is incompatible with
+ # python 3.5.2, as per https://github.com/itamarst/eliot/issues/418
+ 'eliot<1.8.0;python_version<"3.5.3"',
+ ],
"saml2": ["pysaml2>=4.5.0"],
"systemd": ["systemd-python>=231"],
@@ -117,10 +128,10 @@ class DependencyException(Exception):
@property
def dependencies(self):
for i in self.args[0]:
- yield '"' + i + '"'
+ yield "'" + i + "'"
-def check_requirements(for_feature=None, _get_distribution=get_distribution):
+def check_requirements(for_feature=None):
deps_needed = []
errors = []
@@ -131,7 +142,7 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution):
for dependency in reqs:
try:
- _get_distribution(dependency)
+ _check_requirement(dependency)
except VersionConflict as e:
deps_needed.append(dependency)
errors.append(
@@ -149,7 +160,7 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution):
for dependency in OPTS:
try:
- _get_distribution(dependency)
+ _check_requirement(dependency)
except VersionConflict as e:
deps_needed.append(dependency)
errors.append(
@@ -167,6 +178,23 @@ def check_requirements(for_feature=None, _get_distribution=get_distribution):
raise DependencyException(deps_needed)
+def _check_requirement(dependency_string):
+ """Parses a dependency string, and checks if the specified requirement is installed
+
+ Raises:
+ VersionConflict if the requirement is installed, but with the wrong version
+ DistributionNotFound if nothing is found to provide the requirement
+ """
+ req = Requirement.parse(dependency_string)
+
+ # first check if the markers specify that this requirement needs installing
+ if req.marker is not None and not req.marker.evaluate():
+ # not required for this environment
+ return
+
+ get_provider(req)
+
+
if __name__ == "__main__":
import sys
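
The switch from `get_distribution` to `_check_requirement` is what makes environment markers such as the `eliot` pin above work: a marker that evaluates false means the dependency simply is not required here. A quick illustration, assuming a setuptools recent enough that `Requirement` exposes a `marker` attribute:

```python
from pkg_resources import Requirement

req = Requirement.parse('eliot<1.8.0;python_version<"3.5.3"')
if req.marker is not None and not req.marker.evaluate():
    print("not required in this environment:", req.marker)
else:
    print("must be satisfied here:", req)
```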
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index b457c5563f..a3952506c1 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import (
from synapse.storage.event_federation import EventFederationWorkerStore
from synapse.storage.event_push_actions import EventPushActionsWorkerStore
from synapse.storage.events_worker import EventsWorkerStore
+from synapse.storage.relations import RelationsWorkerStore
from synapse.storage.roommember import RoomMemberWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
from synapse.storage.state import StateGroupWorkerStore
@@ -52,6 +53,7 @@ class SlavedEventStore(EventFederationWorkerStore,
EventsWorkerStore,
SignatureWorkerStore,
UserErasureWorkerStore,
+ RelationsWorkerStore,
BaseSlavedStore):
def __init__(self, db_conn, hs):
@@ -89,7 +91,7 @@ class SlavedEventStore(EventFederationWorkerStore,
for row in rows:
self.invalidate_caches_for_event(
-token, row.event_id, row.room_id, row.type, row.state_key,
- row.redacts,
+ row.redacts, row.relates_to,
backfilled=True,
)
return super(SlavedEventStore, self).process_replication_rows(
@@ -102,7 +104,7 @@ class SlavedEventStore(EventFederationWorkerStore,
if row.type == EventsStreamEventRow.TypeId:
self.invalidate_caches_for_event(
token, data.event_id, data.room_id, data.type, data.state_key,
- data.redacts,
+ data.redacts, data.relates_to,
backfilled=False,
)
elif row.type == EventsStreamCurrentStateRow.TypeId:
@@ -114,7 +116,8 @@ class SlavedEventStore(EventFederationWorkerStore,
raise Exception("Unknown events stream row type %s" % (row.type, ))
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
- etype, state_key, redacts, backfilled):
+ etype, state_key, redacts, relates_to,
+ backfilled):
self._invalidate_get_event_cache(event_id)
self.get_latest_event_ids_in_room.invalidate((room_id,))
@@ -136,3 +139,8 @@ class SlavedEventStore(EventFederationWorkerStore,
state_key, stream_ordering
)
self.get_invited_rooms_for_user.invalidate((state_key,))
+
+ if relates_to:
+ self.get_relations_for_event.invalidate_many((relates_to,))
+ self.get_aggregation_groups_for_event.invalidate_many((relates_to,))
+ self.get_applicable_edit.invalidate((relates_to,))
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 8971a6a22e..b6ce7a7bee 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -32,6 +32,7 @@ BackfillStreamRow = namedtuple("BackfillStreamRow", (
"type", # str
"state_key", # str, optional
"redacts", # str, optional
+ "relates_to", # str, optional
))
PresenceStreamRow = namedtuple("PresenceStreamRow", (
"user_id", # str
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index e0f6e29248..f1290d022a 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -80,11 +80,12 @@ class BaseEventsStreamRow(object):
class EventsStreamEventRow(BaseEventsStreamRow):
TypeId = "ev"
- event_id = attr.ib() # str
- room_id = attr.ib() # str
- type = attr.ib() # str
- state_key = attr.ib() # str, optional
- redacts = attr.ib() # str, optional
+ event_id = attr.ib() # str
+ room_id = attr.ib() # str
+ type = attr.ib() # str
+ state_key = attr.ib() # str, optional
+ redacts = attr.ib() # str, optional
+ relates_to = attr.ib() # str, optional
@attr.s(slots=True, frozen=True)
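
For reference, a row carrying the new field might be built like this (illustrative values, assuming the `@attr.s(slots=True, frozen=True)` decoration this module applies to its row classes):

```python
row = EventsStreamEventRow(
    event_id="$ev:example.org",
    room_id="!room:example.org",
    type="m.room.message",
    state_key=None,
    redacts=None,
    relates_to=None,  # the new field: the related event's ID, if any
)
```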
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index a66885d349..e6110ad9b1 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -13,11 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import synapse.rest.admin
from synapse.http.server import JsonResource
from synapse.rest.client import versions
from synapse.rest.client.v1 import (
- admin,
directory,
events,
initial_sync,
@@ -45,6 +44,7 @@ from synapse.rest.client.v2_alpha import (
read_marker,
receipts,
register,
+ relations,
report_event,
room_keys,
room_upgrade_rest_servlet,
@@ -58,8 +58,14 @@ from synapse.rest.client.v2_alpha import (
class ClientRestResource(JsonResource):
- """A resource for version 1 of the matrix client API."""
+ """Matrix Client API REST resource.
+ This gets mounted at various points under /_matrix/client, including:
+ * /_matrix/client/r0
+ * /_matrix/client/api/v1
+ * /_matrix/client/unstable
+ * etc
+ """
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
self.register_servlets(self, hs)
@@ -82,7 +88,6 @@ class ClientRestResource(JsonResource):
presence.register_servlets(hs, client_resource)
directory.register_servlets(hs, client_resource)
voip.register_servlets(hs, client_resource)
- admin.register_servlets(hs, client_resource)
pusher.register_servlets(hs, client_resource)
push_rule.register_servlets(hs, client_resource)
logout.register_servlets(hs, client_resource)
@@ -111,3 +116,9 @@ class ClientRestResource(JsonResource):
room_upgrade_rest_servlet.register_servlets(hs, client_resource)
capabilities.register_servlets(hs, client_resource)
account_validity.register_servlets(hs, client_resource)
+ relations.register_servlets(hs, client_resource)
+
+ # moving to /_synapse/admin
+ synapse.rest.admin.register_servlets_for_client_rest_resource(
+ hs, client_resource
+ )
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/admin/__init__.py
index 0a1e233b23..d6c4dcdb18 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ import hashlib
import hmac
import logging
import platform
+import re
from six import text_type
from six.moves import http_client
@@ -27,39 +28,56 @@ from twisted.internet import defer
import synapse
from synapse.api.constants import Membership, UserTypes
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.http.server import JsonResource
from synapse.http.servlet import (
+ RestServlet,
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.rest.admin._base import assert_requester_is_admin, assert_user_is_admin
+from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.types import UserID, create_requester
from synapse.util.versionstring import get_version_string
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
-class UsersRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/admin/users/(?P<user_id>[^/]*)")
+def historical_admin_path_patterns(path_regex):
+ """Returns the list of patterns for an admin endpoint, including historical ones
+
+ This is a backwards-compatibility hack. Previously, the Admin API was exposed at
+ various paths under /_matrix/client. This function returns a list of patterns
+ matching those paths (as well as the new one), so that existing scripts which rely
+ on the endpoints being available there are not broken.
+
+ Note that this should only be used for existing endpoints: new ones should just
+ register for the /_synapse/admin path.
+ """
+ return list(
+ re.compile(prefix + path_regex)
+ for prefix in (
+ "^/_synapse/admin/v1",
+ "^/_matrix/client/api/v1/admin",
+ "^/_matrix/client/unstable/admin",
+ "^/_matrix/client/r0/admin"
+ )
+ )
+
+
+class UsersRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)")
def __init__(self, hs):
- super(UsersRestServlet, self).__init__(hs)
+ self.hs = hs
+ self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
-
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
-
- # To allow all users to get the users list
- # if not is_admin and target_user != auth_user:
- # raise AuthError(403, "You are not a server admin")
+ yield assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
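
For reference, `historical_admin_path_patterns("/users/(?P<user_id>[^/]*)")` compiles to these four regexes, expanded by hand from the prefixes above:

```python
[
    "^/_synapse/admin/v1/users/(?P<user_id>[^/]*)",
    "^/_matrix/client/api/v1/admin/users/(?P<user_id>[^/]*)",
    "^/_matrix/client/unstable/admin/users/(?P<user_id>[^/]*)",
    "^/_matrix/client/r0/admin/users/(?P<user_id>[^/]*)",
]
```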
@@ -69,37 +87,30 @@ class UsersRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
-class VersionServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/admin/server_version")
+class VersionServlet(RestServlet):
+ PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"), )
- @defer.inlineCallbacks
- def on_GET(self, request):
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
-
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
-
- ret = {
+ def __init__(self, hs):
+ self.res = {
'server_version': get_version_string(synapse),
'python_version': platform.python_version(),
}
- defer.returnValue((200, ret))
+ def on_GET(self, request):
+ return 200, self.res
-class UserRegisterServlet(ClientV1RestServlet):
+class UserRegisterServlet(RestServlet):
"""
Attributes:
NONCE_TIMEOUT (int): Seconds until a generated nonce won't be accepted
nonces (dict[str, int]): The nonces that we will accept. A dict of
nonce to the time it was generated, in int seconds.
"""
- PATTERNS = client_path_patterns("/admin/register")
+ PATTERNS = historical_admin_path_patterns("/register")
NONCE_TIMEOUT = 60
def __init__(self, hs):
- super(UserRegisterServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.reactor = hs.get_reactor()
self.nonces = {}
@@ -226,11 +237,12 @@ class UserRegisterServlet(ClientV1RestServlet):
defer.returnValue((200, result))
-class WhoisRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/admin/whois/(?P<user_id>[^/]*)")
+class WhoisRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
def __init__(self, hs):
- super(WhoisRestServlet, self).__init__(hs)
+ self.hs = hs
+ self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
@defer.inlineCallbacks
@@ -238,10 +250,9 @@ class WhoisRestServlet(ClientV1RestServlet):
target_user = UserID.from_string(user_id)
requester = yield self.auth.get_user_by_req(request)
auth_user = requester.user
- is_admin = yield self.auth.is_server_admin(requester.user)
- if not is_admin and target_user != auth_user:
- raise AuthError(403, "You are not a server admin")
+ if target_user != auth_user:
+ yield assert_user_is_admin(self.auth, auth_user)
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only whois a local user")
@@ -251,20 +262,16 @@ class WhoisRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
-class PurgeMediaCacheRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/admin/purge_media_cache")
+class PurgeMediaCacheRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/purge_media_cache")
def __init__(self, hs):
self.media_repository = hs.get_media_repository()
- super(PurgeMediaCacheRestServlet, self).__init__(hs)
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
-
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ yield assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
logger.info("before_ts: %r", before_ts)
@@ -274,9 +281,9 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
-class PurgeHistoryRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
+class PurgeHistoryRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns(
+ "/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)
def __init__(self, hs):
@@ -285,17 +292,13 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
Args:
hs (synapse.server.HomeServer)
"""
- super(PurgeHistoryRestServlet, self).__init__(hs)
self.pagination_handler = hs.get_pagination_handler()
self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
-
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ yield assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
@@ -371,9 +374,9 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
}))
-class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/admin/purge_history_status/(?P<purge_id>[^/]+)"
+class PurgeHistoryStatusRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns(
+ "/purge_history_status/(?P<purge_id>[^/]+)"
)
def __init__(self, hs):
@@ -382,16 +385,12 @@ class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
Args:
hs (synapse.server.HomeServer)
"""
- super(PurgeHistoryStatusRestServlet, self).__init__(hs)
self.pagination_handler = hs.get_pagination_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, purge_id):
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
-
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ yield assert_requester_is_admin(self.auth, request)
purge_status = self.pagination_handler.get_purge_status(purge_id)
if purge_status is None:
@@ -400,15 +399,16 @@ class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
defer.returnValue((200, purge_status.asdict()))
-class DeactivateAccountRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)")
+class DeactivateAccountRestServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
- super(DeactivateAccountRestServlet, self).__init__(hs)
self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request, target_user_id):
+ yield assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@@ -419,11 +419,6 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
)
UserID.from_string(target_user_id)
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
-
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
result = yield self._deactivate_account_handler.deactivate_account(
target_user_id, erase,
@@ -438,13 +433,13 @@ class DeactivateAccountRestServlet(ClientV1RestServlet):
}))
-class ShutdownRoomRestServlet(ClientV1RestServlet):
+class ShutdownRoomRestServlet(RestServlet):
"""Shuts down a room by removing all local users from the room and blocking
all future invites and joins to the room. Any local aliases will be repointed
to a new room created by `new_room_user_id` and kicked users will be auto
joined to the new room.
"""
- PATTERNS = client_path_patterns("/admin/shutdown_room/(?P<room_id>[^/]+)")
+ PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
DEFAULT_MESSAGE = (
"Sharing illegal content on this server is not permitted and rooms in"
@@ -452,19 +447,18 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
)
def __init__(self, hs):
- super(ShutdownRoomRestServlet, self).__init__(hs)
+ self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self._room_creation_handler = hs.get_room_creation_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request, room_id):
requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ yield assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request)
assert_params_in_dict(content, ["new_room_user_id"])
@@ -564,22 +558,20 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
}))
-class QuarantineMediaInRoom(ClientV1RestServlet):
+class QuarantineMediaInRoom(RestServlet):
"""Quarantines all media in a room so that no one can download it via
this server.
"""
- PATTERNS = client_path_patterns("/admin/quarantine_media/(?P<room_id>[^/]+)")
+ PATTERNS = historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
def __init__(self, hs):
- super(QuarantineMediaInRoom, self).__init__(hs)
self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request, room_id):
requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ yield assert_user_is_admin(self.auth, requester.user)
num_quarantined = yield self.store.quarantine_media_ids_in_room(
room_id, requester.user.to_string(),
@@ -588,13 +580,12 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
defer.returnValue((200, {"num_quarantined": num_quarantined}))
-class ListMediaInRoom(ClientV1RestServlet):
+class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
- PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+ PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
- super(ListMediaInRoom, self).__init__(hs)
self.store = hs.get_datastore()
@defer.inlineCallbacks
@@ -609,11 +600,11 @@ class ListMediaInRoom(ClientV1RestServlet):
defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
-class ResetPasswordRestServlet(ClientV1RestServlet):
+class ResetPasswordRestServlet(RestServlet):
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
Example:
- http://localhost:8008/_matrix/client/api/v1/admin/reset_password/
+ http://localhost:8008/_synapse/admin/v1/reset_password/
@user:to_reset_password?access_token=admin_access_token
JsonBodyToSend:
{
@@ -622,11 +613,10 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
Returns:
200 OK with empty object if success otherwise an error.
"""
- PATTERNS = client_path_patterns("/admin/reset_password/(?P<target_user_id>[^/]*)")
+ PATTERNS = historical_admin_path_patterns("/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
- super(ResetPasswordRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self._set_password_handler = hs.get_set_password_handler()
@@ -636,12 +626,10 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
"""
- UserID.from_string(target_user_id)
requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
+ yield assert_user_is_admin(self.auth, requester.user)
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ UserID.from_string(target_user_id)
params = parse_json_object_from_request(request)
assert_params_in_dict(params, ["new_password"])
@@ -653,20 +641,19 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
-class GetUsersPaginatedRestServlet(ClientV1RestServlet):
+class GetUsersPaginatedRestServlet(RestServlet):
"""Get request to get specific number of users from Synapse.
This needs user to have administrator access in Synapse.
Example:
- http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
+ http://localhost:8008/_synapse/admin/v1/users_paginate/
@admin:user?access_token=admin_access_token&start=0&limit=10
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
- PATTERNS = client_path_patterns("/admin/users_paginate/(?P<target_user_id>[^/]*)")
+ PATTERNS = historical_admin_path_patterns("/users_paginate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
- super(GetUsersPaginatedRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
@@ -676,16 +663,9 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
"""Get request to get specific number of users from Synapse.
This needs user to have administrator access in Synapse.
"""
- target_user = UserID.from_string(target_user_id)
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
-
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ yield assert_requester_is_admin(self.auth, request)
- # To allow all users to get the users list
- # if not is_admin and target_user != auth_user:
- # raise AuthError(403, "You are not a server admin")
+ target_user = UserID.from_string(target_user_id)
if not self.hs.is_mine(target_user):
raise SynapseError(400, "Can only users a local user")
@@ -706,7 +686,7 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
"""Post request to get specific number of users from Synapse..
This needs user to have administrator access in Synapse.
Example:
- http://localhost:8008/_matrix/client/api/v1/admin/users_paginate/
+ http://localhost:8008/_synapse/admin/v1/users_paginate/
@admin:user?access_token=admin_access_token
JsonBodyToSend:
{
@@ -716,12 +696,8 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
+ yield assert_requester_is_admin(self.auth, request)
UserID.from_string(target_user_id)
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
-
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
order = "name" # order by name in user table
params = parse_json_object_from_request(request)
@@ -736,21 +712,20 @@ class GetUsersPaginatedRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
-class SearchUsersRestServlet(ClientV1RestServlet):
+class SearchUsersRestServlet(RestServlet):
"""Get request to search user table for specific users according to
search term.
This needs user to have administrator access in Synapse.
Example:
- http://localhost:8008/_matrix/client/api/v1/admin/search_users/
+ http://localhost:8008/_synapse/admin/v1/search_users/
@admin:user?access_token=admin_access_token&term=alice
Returns:
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
- PATTERNS = client_path_patterns("/admin/search_users/(?P<target_user_id>[^/]*)")
+ PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
- super(SearchUsersRestServlet, self).__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.handlers = hs.get_handlers()
@@ -761,12 +736,9 @@ class SearchUsersRestServlet(ClientV1RestServlet):
search term.
This needs user to have a administrator access in Synapse.
"""
- target_user = UserID.from_string(target_user_id)
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
+ yield assert_requester_is_admin(self.auth, request)
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ target_user = UserID.from_string(target_user_id)
# To allow all users to get the users list
# if not is_admin and target_user != auth_user:
@@ -784,23 +756,20 @@ class SearchUsersRestServlet(ClientV1RestServlet):
defer.returnValue((200, ret))
-class DeleteGroupAdminRestServlet(ClientV1RestServlet):
+class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups
"""
- PATTERNS = client_path_patterns("/admin/delete_group/(?P<group_id>[^/]*)")
+ PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
def __init__(self, hs):
- super(DeleteGroupAdminRestServlet, self).__init__(hs)
self.group_server = hs.get_groups_server_handler()
self.is_mine_id = hs.is_mine_id
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request, group_id):
requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
-
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ yield assert_user_is_admin(self.auth, requester.user)
if not self.is_mine_id(group_id):
raise SynapseError(400, "Can only delete local groups")
@@ -809,27 +778,21 @@ class DeleteGroupAdminRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
-class AccountValidityRenewServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/admin/account_validity/validity$")
+class AccountValidityRenewServlet(RestServlet):
+ PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
- super(AccountValidityRenewServlet, self).__init__(hs)
-
self.hs = hs
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request):
- requester = yield self.auth.get_user_by_req(request)
- is_admin = yield self.auth.is_server_admin(requester.user)
-
- if not is_admin:
- raise AuthError(403, "You are not a server admin")
+ yield assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
@@ -846,8 +809,33 @@ class AccountValidityRenewServlet(ClientV1RestServlet):
}
defer.returnValue((200, res))
+########################################################################################
+#
+# please don't add more servlets here: this file is already long and unwieldy. Put
+# them in separate files within the 'admin' package.
+#
+########################################################################################
+
+
+class AdminRestResource(JsonResource):
+ """The REST resource which gets mounted at /_synapse/admin"""
+
+ def __init__(self, hs):
+ JsonResource.__init__(self, hs, canonical_json=False)
+ register_servlets(hs, self)
+
def register_servlets(hs, http_server):
+ """
+ Register all the admin servlets.
+ """
+ register_servlets_for_client_rest_resource(hs, http_server)
+ SendServerNoticeServlet(hs).register(http_server)
+ VersionServlet(hs).register(http_server)
+
+
+def register_servlets_for_client_rest_resource(hs, http_server):
+ """Register only the servlets which need to be exposed on /_matrix/client/xxx"""
WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
@@ -861,6 +849,7 @@ def register_servlets(hs, http_server):
QuarantineMediaInRoom(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
- VersionServlet(hs).register(http_server)
DeleteGroupAdminRestServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)
+ # don't add more things here: new servlets should only be exposed on
+ # /_synapse/admin so should not go here. Instead register them in AdminRestResource.
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
new file mode 100644
index 0000000000..881d67b89c
--- /dev/null
+++ b/synapse/rest/admin/_base.py
@@ -0,0 +1,59 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+from synapse.api.errors import AuthError
+
+
+@defer.inlineCallbacks
+def assert_requester_is_admin(auth, request):
+ """Verify that the requester is an admin user
+
+ WARNING: MAKE SURE YOU YIELD ON THE RESULT!
+
+ Args:
+ auth (synapse.api.auth.Auth):
+ request (twisted.web.server.Request): incoming request
+
+ Returns:
+ Deferred
+
+ Raises:
+ AuthError if the requester is not an admin
+ """
+ requester = yield auth.get_user_by_req(request)
+ yield assert_user_is_admin(auth, requester.user)
+
+
+@defer.inlineCallbacks
+def assert_user_is_admin(auth, user_id):
+ """Verify that the given user is an admin user
+
+ WARNING: MAKE SURE YOU YIELD ON THE RESULT!
+
+ Args:
+ auth (synapse.api.auth.Auth):
+ user_id (UserID):
+
+ Returns:
+ Deferred
+
+ Raises:
+ AuthError if the user is not an admin
+ """
+
+ is_admin = yield auth.is_server_admin(user_id)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
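
A sketch of how these helpers are meant to be consumed from a servlet (cf. the conversions earlier in this diff); the class below is a stand-in, not real Synapse code:

```python
from twisted.internet import defer

class ExampleAdminServlet(object):
    def __init__(self, hs):
        self.auth = hs.get_auth()

    @defer.inlineCallbacks
    def on_POST(self, request):
        # without the yield the AuthError would be lost -- hence the
        # warnings in the docstrings above
        yield assert_requester_is_admin(self.auth, request)
        defer.returnValue((200, {}))
```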
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
new file mode 100644
index 0000000000..ae5aca9dac
--- /dev/null
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import re
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.rest.admin import assert_requester_is_admin
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.types import UserID
+
+
+class SendServerNoticeServlet(RestServlet):
+ """Servlet which will send a server notice to a given user
+
+ POST /_synapse/admin/v1/send_server_notice
+ {
+ "user_id": "@target_user:server_name",
+ "content": {
+ "msgtype": "m.text",
+ "body": "This is my message"
+ }
+ }
+
+ returns:
+
+ {
+ "event_id": "$1895723857jgskldgujpious"
+ }
+ """
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): server
+ """
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.txns = HttpTransactionCache(hs)
+ self.snm = hs.get_server_notices_manager()
+
+ def register(self, json_resource):
+ PATTERN = "^/_synapse/admin/v1/send_server_notice"
+ json_resource.register_paths(
+ "POST",
+ (re.compile(PATTERN + "$"), ),
+ self.on_POST,
+ )
+ json_resource.register_paths(
+ "PUT",
+ (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$",), ),
+ self.on_PUT,
+ )
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, txn_id=None):
+ yield assert_requester_is_admin(self.auth, request)
+ body = parse_json_object_from_request(request)
+ assert_params_in_dict(body, ("user_id", "content"))
+ event_type = body.get("type", EventTypes.Message)
+ state_key = body.get("state_key")
+
+ if not self.snm.is_enabled():
+ raise SynapseError(400, "Server notices are not enabled on this server")
+
+ user_id = body["user_id"]
+ UserID.from_string(user_id)
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Server notices can only be sent to local users")
+
+ event = yield self.snm.send_notice(
+ user_id=body["user_id"],
+ type=event_type,
+ state_key=state_key,
+ event_content=body["content"],
+ )
+
+ defer.returnValue((200, {"event_id": event.event_id}))
+
+ def on_PUT(self, request, txn_id):
+ return self.txns.fetch_or_execute_request(
+ request, self.on_POST, request, txn_id,
+ )
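
For illustration, the endpoint can be exercised like this; host, access token and user id are placeholders, while the body and response shape come from the docstring above:

```python
import json
import urllib.request

body = {
    "user_id": "@target_user:server_name",
    "content": {"msgtype": "m.text", "body": "This is my message"},
}
req = urllib.request.Request(
    "http://localhost:8008/_synapse/admin/v1/send_server_notice"
    "?access_token=admin_access_token",
    data=json.dumps(body).encode("utf-8"),
    headers={"Content-Type": "application/json"},
    method="POST",
)
print(json.load(urllib.request.urlopen(req)))  # -> {"event_id": "$..."}
```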
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
deleted file mode 100644
index c77d7aba68..0000000000
--- a/synapse/rest/client/v1/base.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This module contains base REST classes for constructing client v1 servlets.
-"""
-
-import logging
-import re
-
-from synapse.api.urls import CLIENT_PREFIX
-from synapse.http.servlet import RestServlet
-from synapse.rest.client.transactions import HttpTransactionCache
-
-logger = logging.getLogger(__name__)
-
-
-def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True):
- """Creates a regex compiled client path with the correct client path
- prefix.
-
- Args:
- path_regex (str): The regex string to match. This should NOT have a ^
- as this will be prefixed.
- Returns:
- SRE_Pattern
- """
- patterns = [re.compile("^" + CLIENT_PREFIX + path_regex)]
- if include_in_unstable:
- unstable_prefix = CLIENT_PREFIX.replace("/api/v1", "/unstable")
- patterns.append(re.compile("^" + unstable_prefix + path_regex))
- for release in releases:
- new_prefix = CLIENT_PREFIX.replace("/api/v1", "/r%d" % release)
- patterns.append(re.compile("^" + new_prefix + path_regex))
- return patterns
-
-
-class ClientV1RestServlet(RestServlet):
- """A base Synapse REST Servlet for the client version 1 API.
- """
-
- # This subclass was presumably created to allow the auth for the v1
- # protocol version to be different, however this behaviour was removed.
- # it may no longer be necessary
-
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
- self.hs = hs
- self.builder_factory = hs.get_event_builder_factory()
- self.auth = hs.get_auth()
- self.txns = HttpTransactionCache(hs)
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 0220acf644..0035182bb9 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -19,11 +19,10 @@ import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import RoomAlias
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
@@ -33,13 +32,14 @@ def register_servlets(hs, http_server):
ClientAppserviceDirectoryListServer(hs).register(http_server)
-class ClientDirectoryServer(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$")
+class ClientDirectoryServer(RestServlet):
+ PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True)
def __init__(self, hs):
- super(ClientDirectoryServer, self).__init__(hs)
+ super(ClientDirectoryServer, self).__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_alias):
@@ -120,13 +120,14 @@ class ClientDirectoryServer(ClientV1RestServlet):
defer.returnValue((200, {}))
-class ClientDirectoryListServer(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/directory/list/room/(?P<room_id>[^/]*)$")
+class ClientDirectoryListServer(RestServlet):
+ PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True)
def __init__(self, hs):
- super(ClientDirectoryListServer, self).__init__(hs)
+ super(ClientDirectoryListServer, self).__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -162,15 +163,16 @@ class ClientDirectoryListServer(ClientV1RestServlet):
defer.returnValue((200, {}))
-class ClientAppserviceDirectoryListServer(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$"
+class ClientAppserviceDirectoryListServer(RestServlet):
+ PATTERNS = client_patterns(
+ "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True
)
def __init__(self, hs):
- super(ClientAppserviceDirectoryListServer, self).__init__(hs)
+ super(ClientAppserviceDirectoryListServer, self).__init__()
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
+ self.auth = hs.get_auth()
def on_PUT(self, request, network_id, room_id):
content = parse_json_object_from_request(request)
diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index cd9b3bdbd1..84ca36270b 100644
--- a/synapse/rest/client/v1/events.py
+++ b/synapse/rest/client/v1/events.py
@@ -19,22 +19,22 @@ import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError
-from synapse.events.utils import serialize_event
+from synapse.http.servlet import RestServlet
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
-class EventStreamRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/events$")
+class EventStreamRestServlet(RestServlet):
+ PATTERNS = client_patterns("/events$", v1=True)
DEFAULT_LONGPOLL_TIME_MS = 30000
def __init__(self, hs):
- super(EventStreamRestServlet, self).__init__(hs)
+ super(EventStreamRestServlet, self).__init__()
self.event_stream_handler = hs.get_event_stream_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -77,13 +77,14 @@ class EventStreamRestServlet(ClientV1RestServlet):
# TODO: Unit test gets, with and without auth, with different kinds of events.
-class EventRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$")
+class EventRestServlet(RestServlet):
+ PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
def __init__(self, hs):
- super(EventRestServlet, self).__init__(hs)
+ super(EventRestServlet, self).__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request, event_id):
@@ -92,7 +93,8 @@ class EventRestServlet(ClientV1RestServlet):
time_now = self.clock.time_msec()
if event:
- defer.returnValue((200, serialize_event(event, time_now)))
+ event = yield self._event_serializer.serialize_event(event, time_now)
+ defer.returnValue((200, event))
else:
defer.returnValue((404, "Event not found."))
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 3ead75cb77..0fe5f2d79b 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -15,19 +15,19 @@
from twisted.internet import defer
-from synapse.http.servlet import parse_boolean
+from synapse.http.servlet import RestServlet, parse_boolean
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.streams.config import PaginationConfig
-from .base import ClientV1RestServlet, client_path_patterns
-
# TODO: Needs unit testing
-class InitialSyncRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/initialSync$")
+class InitialSyncRestServlet(RestServlet):
+ PATTERNS = client_patterns("/initialSync$", v1=True)
def __init__(self, hs):
- super(InitialSyncRestServlet, self).__init__(hs)
+ super(InitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 5180e9eaf1..3b60728628 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -29,12 +29,11 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
@@ -81,15 +80,16 @@ def login_id_thirdparty_from_phone(identifier):
}
-class LoginRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/login$")
+class LoginRestServlet(RestServlet):
+ PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "m.login.jwt"
def __init__(self, hs):
- super(LoginRestServlet, self).__init__(hs)
+ super(LoginRestServlet, self).__init__()
+ self.hs = hs
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
@@ -371,7 +371,7 @@ class LoginRestServlet(ClientV1RestServlet):
class CasRedirectServlet(RestServlet):
- PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")
+ PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
def __init__(self, hs):
super(CasRedirectServlet, self).__init__()
@@ -386,7 +386,7 @@ class CasRedirectServlet(RestServlet):
b"redirectUrl": args[b"redirectUrl"][0]
}).encode('ascii')
hs_redirect_url = (self.cas_service_url +
- b"/_matrix/client/api/v1/login/cas/ticket")
+ b"/_matrix/client/r0/login/cas/ticket")
service_param = urllib.parse.urlencode({
b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)
}).encode('ascii')
@@ -394,27 +394,27 @@ class CasRedirectServlet(RestServlet):
finish_request(request)
-class CasTicketServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/login/cas/ticket", releases=())
+class CasTicketServlet(RestServlet):
+ PATTERNS = client_patterns("/login/cas/ticket", v1=True)
def __init__(self, hs):
- super(CasTicketServlet, self).__init__(hs)
+ super(CasTicketServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
self._sso_auth_handler = SSOAuthHandler(hs)
+ self._http_client = hs.get_simple_http_client()
@defer.inlineCallbacks
def on_GET(self, request):
client_redirect_url = parse_string(request, "redirectUrl", required=True)
- http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
"ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url
}
try:
- body = yield http_client.get_raw(uri, args)
+ body = yield self._http_client.get_raw(uri, args)
except PartialDownloadError as pde:
# Twisted raises this error if the connection is closed,
# even if that's being used old-http style to signal end-of-data
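CasTicketServlet now resolves its HTTP client once in the constructor instead of reaching back through self.hs on every request, and the same dependency-resolution convention recurs in every servlet ported off ClientV1RestServlet below. A sketch of the pattern (ExampleRestServlet and its path are hypothetical):

    from synapse.http.servlet import RestServlet
    from synapse.rest.client.v2_alpha._base import client_patterns

    class ExampleRestServlet(RestServlet):
        PATTERNS = client_patterns("/example$", v1=True)

        def __init__(self, hs):
            super(ExampleRestServlet, self).__init__()
            # resolve collaborators from the HomeServer once, up front
            self.auth = hs.get_auth()
            self._http_client = hs.get_simple_http_client()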
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 430c692336..b8064f261e 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -17,19 +17,18 @@ import logging
from twisted.internet import defer
-from synapse.api.errors import AuthError
-
-from .base import ClientV1RestServlet, client_path_patterns
+from synapse.http.servlet import RestServlet
+from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
-class LogoutRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/logout$")
+class LogoutRestServlet(RestServlet):
+ PATTERNS = client_patterns("/logout$", v1=True)
def __init__(self, hs):
- super(LogoutRestServlet, self).__init__(hs)
- self._auth = hs.get_auth()
+ super(LogoutRestServlet, self).__init__()
+ self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
@@ -38,32 +37,25 @@ class LogoutRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_POST(self, request):
- try:
- requester = yield self.auth.get_user_by_req(request)
- except AuthError:
- # this implies the access token has already been deleted.
- defer.returnValue((401, {
- "errcode": "M_UNKNOWN_TOKEN",
- "error": "Access Token unknown or expired"
- }))
+ requester = yield self.auth.get_user_by_req(request)
+
+ if requester.device_id is None:
+ # the access token wasn't associated with a device.
+ # Just delete the access token
+ access_token = self.auth.get_access_token_from_request(request)
+ yield self._auth_handler.delete_access_token(access_token)
else:
- if requester.device_id is None:
- # the acccess token wasn't associated with a device.
- # Just delete the access token
- access_token = self._auth.get_access_token_from_request(request)
- yield self._auth_handler.delete_access_token(access_token)
- else:
- yield self._device_handler.delete_device(
- requester.user.to_string(), requester.device_id)
+ yield self._device_handler.delete_device(
+ requester.user.to_string(), requester.device_id)
defer.returnValue((200, {}))
-class LogoutAllRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/logout/all$")
+class LogoutAllRestServlet(RestServlet):
+ PATTERNS = client_patterns("/logout/all$", v1=True)
def __init__(self, hs):
- super(LogoutAllRestServlet, self).__init__(hs)
+ super(LogoutAllRestServlet, self).__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
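The try/except around get_user_by_req could be dropped because the auth layer already raises an AuthError that the servlet framework renders as a 401 with errcode M_UNKNOWN_TOKEN, so the hand-rolled error body was redundant. The surviving branch logic, in outline:

    requester = yield self.auth.get_user_by_req(request)  # 401s on a bad token
    if requester.device_id is None:
        # token not bound to a device: revoke just the token itself
        access_token = self.auth.get_access_token_from_request(request)
        yield self._auth_handler.delete_access_token(access_token)
    else:
        # deleting the device also revokes the tokens bound to it
        yield self._device_handler.delete_device(
            requester.user.to_string(), requester.device_id)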
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 045d5a20ac..e263da3cb7 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -23,21 +23,22 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, SynapseError
from synapse.handlers.presence import format_user_presence_state
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
-class PresenceStatusRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
+class PresenceStatusRestServlet(RestServlet):
+ PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True)
def __init__(self, hs):
- super(PresenceStatusRestServlet, self).__init__(hs)
+ super(PresenceStatusRestServlet, self).__init__()
+ self.hs = hs
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index a23edd8fe5..e15d9d82a6 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -16,26 +16,33 @@
""" This module contains REST servlets to do with profile: /profile/<paths> """
from twisted.internet import defer
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import UserID
-from .base import ClientV1RestServlet, client_path_patterns
-
-class ProfileDisplaynameRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname")
+class ProfileDisplaynameRestServlet(RestServlet):
+ PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True)
def __init__(self, hs):
- super(ProfileDisplaynameRestServlet, self).__init__(hs)
+ super(ProfileDisplaynameRestServlet, self).__init__()
+ self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
+ requester_user = None
+
+ if self.hs.config.require_auth_for_profile_requests:
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user = requester.user
+
user = UserID.from_string(user_id)
- displayname = yield self.profile_handler.get_displayname(
- user,
- )
+ yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+
+ displayname = yield self.profile_handler.get_displayname(user)
ret = {}
if displayname is not None:
@@ -65,20 +72,28 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
return (200, {})
-class ProfileAvatarURLRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url")
+class ProfileAvatarURLRestServlet(RestServlet):
+ PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True)
def __init__(self, hs):
- super(ProfileAvatarURLRestServlet, self).__init__(hs)
+ super(ProfileAvatarURLRestServlet, self).__init__()
+ self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
+ requester_user = None
+
+ if self.hs.config.require_auth_for_profile_requests:
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user = requester.user
+
user = UserID.from_string(user_id)
- avatar_url = yield self.profile_handler.get_avatar_url(
- user,
- )
+ yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+
+ avatar_url = yield self.profile_handler.get_avatar_url(user)
ret = {}
if avatar_url is not None:
@@ -107,23 +122,29 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
return (200, {})
-class ProfileRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)")
+class ProfileRestServlet(RestServlet):
+ PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True)
def __init__(self, hs):
- super(ProfileRestServlet, self).__init__(hs)
+ super(ProfileRestServlet, self).__init__()
+ self.hs = hs
self.profile_handler = hs.get_profile_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, user_id):
+ requester_user = None
+
+ if self.hs.config.require_auth_for_profile_requests:
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user = requester.user
+
user = UserID.from_string(user_id)
- displayname = yield self.profile_handler.get_displayname(
- user,
- )
- avatar_url = yield self.profile_handler.get_avatar_url(
- user,
- )
+ yield self.profile_handler.check_profile_query_allowed(user, requester_user)
+
+ displayname = yield self.profile_handler.get_displayname(user)
+ avatar_url = yield self.profile_handler.get_avatar_url(user)
ret = {}
if displayname is not None:
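All three profile servlets now apply the same gate: when require_auth_for_profile_requests is set the caller must authenticate before the lookup, and check_profile_query_allowed can then veto the query. The shared shape, sketched once (get_field stands in for get_displayname or get_avatar_url):

    @defer.inlineCallbacks
    def _gated_profile_lookup(self, request, user_id, get_field):
        requester_user = None
        if self.hs.config.require_auth_for_profile_requests:
            # raises if no valid access token accompanies the request
            requester = yield self.auth.get_user_by_req(request)
            requester_user = requester.user

        user = UserID.from_string(user_id)
        # may raise, e.g. if the requester may not see this profile
        yield self.profile_handler.check_profile_query_allowed(user, requester_user)
        result = yield get_field(user)
        defer.returnValue(result)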
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 506ec95ddd..3d6326fe2f 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -21,22 +21,22 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
-from synapse.http.servlet import parse_json_value_from_request, parse_string
+from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string
from synapse.push.baserules import BASE_RULE_IDS
from synapse.push.clientformat import format_push_rules_for_user
from synapse.push.rulekinds import PRIORITY_CLASS_MAP
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from .base import ClientV1RestServlet, client_path_patterns
-
-class PushRuleRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/(?P<path>pushrules/.*)$")
+class PushRuleRestServlet(RestServlet):
+ PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True)
SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
"Unrecognised request: You probably wanted a trailing slash")
def __init__(self, hs):
- super(PushRuleRestServlet, self).__init__(hs)
+ super(PushRuleRestServlet, self).__init__()
+ self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 4c07ae7f45..15d860db37 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -26,17 +26,18 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.push import PusherConfigException
-
-from .base import ClientV1RestServlet, client_path_patterns
+from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
-class PushersRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/pushers$")
+class PushersRestServlet(RestServlet):
+ PATTERNS = client_patterns("/pushers$", v1=True)
def __init__(self, hs):
- super(PushersRestServlet, self).__init__(hs)
+ super(PushersRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -69,11 +70,13 @@ class PushersRestServlet(ClientV1RestServlet):
return 200, {}
-class PushersSetRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/pushers/set$")
+class PushersSetRestServlet(RestServlet):
+ PATTERNS = client_patterns("/pushers/set$", v1=True)
def __init__(self, hs):
- super(PushersSetRestServlet, self).__init__(hs)
+ super(PushersSetRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
@@ -141,7 +144,7 @@ class PushersRemoveRestServlet(RestServlet):
"""
To allow a pusher to be deleted by clicking a link (i.e. a GET request)
"""
- PATTERNS = client_path_patterns("/pushers/remove$")
+ PATTERNS = client_patterns("/pushers/remove$", v1=True)
SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 48da4d557f..e8f672c4ba 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -26,39 +26,47 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.filtering import Filter
-from synapse.events.utils import format_event_for_client_v2, serialize_event
+from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
+ RestServlet,
assert_params_in_dict,
parse_integer,
parse_json_object_from_request,
parse_string,
)
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
-from .base import ClientV1RestServlet, client_path_patterns
-
logger = logging.getLogger(__name__)
-class RoomCreateRestServlet(ClientV1RestServlet):
+class TransactionRestServlet(RestServlet):
+ def __init__(self, hs):
+ super(TransactionRestServlet, self).__init__()
+ self.txns = HttpTransactionCache(hs)
+
+
+class RoomCreateRestServlet(TransactionRestServlet):
# No PATTERN; we have custom dispatch rules here
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
self._room_creation_handler = hs.get_room_creation_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
# define CORS for all of /rooms in RoomCreateRestServlet for simplicity
http_server.register_paths("OPTIONS",
- client_path_patterns("/rooms(?:/.*)?$"),
+ client_patterns("/rooms(?:/.*)?$", v1=True),
self.on_OPTIONS)
# define CORS for /createRoom[/txnid]
http_server.register_paths("OPTIONS",
- client_path_patterns("/createRoom(?:/.*)?$"),
+ client_patterns("/createRoom(?:/.*)?$", v1=True),
self.on_OPTIONS)
def on_PUT(self, request, txn_id):
@@ -85,13 +93,14 @@ class RoomCreateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events
-class RoomStateEventRestServlet(ClientV1RestServlet):
+class RoomStateEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
# /room/$roomid/state/$eventtype
@@ -102,16 +111,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
"(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$")
http_server.register_paths("GET",
- client_path_patterns(state_key),
+ client_patterns(state_key, v1=True),
self.on_GET)
http_server.register_paths("PUT",
- client_path_patterns(state_key),
+ client_patterns(state_key, v1=True),
self.on_PUT)
http_server.register_paths("GET",
- client_path_patterns(no_state_key),
+ client_patterns(no_state_key, v1=True),
self.on_GET_no_state_key)
http_server.register_paths("PUT",
- client_path_patterns(no_state_key),
+ client_patterns(no_state_key, v1=True),
self.on_PUT_no_state_key)
def on_GET_no_state_key(self, request, room_id, event_type):
@@ -185,11 +194,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for generic events + feedback
-class RoomSendEventRestServlet(ClientV1RestServlet):
+class RoomSendEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -229,10 +239,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing for room ID + alias joins
-class JoinRoomAliasServlet(ClientV1RestServlet):
+class JoinRoomAliasServlet(TransactionRestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@@ -291,8 +302,13 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
# TODO: Needs unit testing
-class PublicRoomListRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/publicRooms$")
+class PublicRoomListRestServlet(TransactionRestServlet):
+ PATTERNS = client_patterns("/publicRooms$", v1=True)
+
+ def __init__(self, hs):
+ super(PublicRoomListRestServlet, self).__init__(hs)
+ self.hs = hs
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -301,6 +317,12 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
try:
yield self.auth.get_user_by_req(request, allow_guest=True)
except AuthError as e:
+ # Option to allow servers to require auth when accessing
+ # /publicRooms via CS API. This is especially helpful in private
+ # federations.
+ if self.hs.config.restrict_public_rooms_to_local_users:
+ raise
+
# We allow people to not be authed if they're just looking at our
# room list, but require auth when we proxy the request.
# In both cases we call the auth function, as that has the side
@@ -376,12 +398,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
-class RoomMemberListRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$")
+class RoomMemberListRestServlet(RestServlet):
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True)
def __init__(self, hs):
- super(RoomMemberListRestServlet, self).__init__(hs)
+ super(RoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -430,12 +453,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
# deprecated in favour of /members?membership=join?
# except it does custom AS logic and has a simpler return format
-class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$")
+class JoinedRoomMemberListRestServlet(RestServlet):
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True)
def __init__(self, hs):
- super(JoinedRoomMemberListRestServlet, self).__init__(hs)
+ super(JoinedRoomMemberListRestServlet, self).__init__()
self.message_handler = hs.get_message_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -451,12 +475,13 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
# TODO: Needs better unit testing
-class RoomMessageListRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
+class RoomMessageListRestServlet(RestServlet):
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True)
def __init__(self, hs):
- super(RoomMessageListRestServlet, self).__init__(hs)
+ super(RoomMessageListRestServlet, self).__init__()
self.pagination_handler = hs.get_pagination_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -469,6 +494,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
event_filter = Filter(json.loads(filter_json))
+ if event_filter.filter_json.get("event_format", "client") == "federation":
+ as_client_event = False
else:
event_filter = None
msgs = yield self.pagination_handler.get_messages(
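A client opts into federation-format events on /messages by supplying a filter whose event_format is "federation"; as_client_event is then cleared so returned events keep their federation fields. A hypothetical request parameter, sketched:

    import json

    # passed as the "filter" query parameter of
    # GET /_matrix/client/r0/rooms/{roomId}/messages
    event_filter = json.dumps({"event_format": "federation"})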
@@ -483,12 +510,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
-class RoomStateRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$")
+class RoomStateRestServlet(RestServlet):
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True)
def __init__(self, hs):
- super(RoomStateRestServlet, self).__init__(hs)
+ super(RoomStateRestServlet, self).__init__()
self.message_handler = hs.get_message_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -503,12 +531,13 @@ class RoomStateRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
-class RoomInitialSyncRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$")
+class RoomInitialSyncRestServlet(RestServlet):
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True)
def __init__(self, hs):
- super(RoomInitialSyncRestServlet, self).__init__(hs)
+ super(RoomInitialSyncRestServlet, self).__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
@@ -522,15 +551,17 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content))
-class RoomEventServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
+class RoomEventServlet(RestServlet):
+ PATTERNS = client_patterns(
+ "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True
)
def __init__(self, hs):
- super(RoomEventServlet, self).__init__(hs)
+ super(RoomEventServlet, self).__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@@ -539,20 +570,23 @@ class RoomEventServlet(ClientV1RestServlet):
time_now = self.clock.time_msec()
if event:
- defer.returnValue((200, serialize_event(event, time_now)))
+ event = yield self._event_serializer.serialize_event(event, time_now)
+ defer.returnValue((200, event))
else:
defer.returnValue((404, "Event not found."))
-class RoomEventContextServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
+class RoomEventContextServlet(RestServlet):
+ PATTERNS = client_patterns(
+ "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True
)
def __init__(self, hs):
- super(RoomEventContextServlet, self).__init__(hs)
+ super(RoomEventContextServlet, self).__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request, room_id, event_id):
@@ -582,24 +616,27 @@ class RoomEventContextServlet(ClientV1RestServlet):
)
time_now = self.clock.time_msec()
- results["events_before"] = [
- serialize_event(event, time_now) for event in results["events_before"]
- ]
- results["event"] = serialize_event(results["event"], time_now)
- results["events_after"] = [
- serialize_event(event, time_now) for event in results["events_after"]
- ]
- results["state"] = [
- serialize_event(event, time_now) for event in results["state"]
- ]
+ results["events_before"] = yield self._event_serializer.serialize_events(
+ results["events_before"], time_now,
+ )
+ results["event"] = yield self._event_serializer.serialize_event(
+ results["event"], time_now,
+ )
+ results["events_after"] = yield self._event_serializer.serialize_events(
+ results["events_after"], time_now,
+ )
+ results["state"] = yield self._event_serializer.serialize_events(
+ results["state"], time_now,
+ )
defer.returnValue((200, results))
-class RoomForgetRestServlet(ClientV1RestServlet):
+class RoomForgetRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@@ -626,11 +663,12 @@ class RoomForgetRestServlet(ClientV1RestServlet):
# TODO: Needs unit testing
-class RoomMembershipRestServlet(ClientV1RestServlet):
+class RoomMembershipRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
self.room_member_handler = hs.get_room_member_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
@@ -709,11 +747,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
)
-class RoomRedactEventRestServlet(ClientV1RestServlet):
+class RoomRedactEventRestServlet(TransactionRestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.auth = hs.get_auth()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -744,15 +783,16 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
)
-class RoomTypingRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
+class RoomTypingRestServlet(RestServlet):
+ PATTERNS = client_patterns(
+ "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True
)
def __init__(self, hs):
- super(RoomTypingRestServlet, self).__init__(hs)
+ super(RoomTypingRestServlet, self).__init__()
self.presence_handler = hs.get_presence_handler()
self.typing_handler = hs.get_typing_handler()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_PUT(self, request, room_id, user_id):
@@ -785,14 +825,13 @@ class RoomTypingRestServlet(ClientV1RestServlet):
defer.returnValue((200, {}))
-class SearchRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns(
- "/search$"
- )
+class SearchRestServlet(RestServlet):
+ PATTERNS = client_patterns("/search$", v1=True)
def __init__(self, hs):
- super(SearchRestServlet, self).__init__(hs)
+ super(SearchRestServlet, self).__init__()
self.handlers = hs.get_handlers()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_POST(self, request):
@@ -810,12 +849,13 @@ class SearchRestServlet(ClientV1RestServlet):
defer.returnValue((200, results))
-class JoinedRoomsRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/joined_rooms$")
+class JoinedRoomsRestServlet(RestServlet):
+ PATTERNS = client_patterns("/joined_rooms$", v1=True)
def __init__(self, hs):
- super(JoinedRoomsRestServlet, self).__init__(hs)
+ super(JoinedRoomsRestServlet, self).__init__()
self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -840,18 +880,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
"""
http_server.register_paths(
"POST",
- client_path_patterns(regex_string + "$"),
+ client_patterns(regex_string + "$", v1=True),
servlet.on_POST
)
http_server.register_paths(
"PUT",
- client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"),
+ client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_PUT
)
if with_get:
http_server.register_paths(
"GET",
- client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"),
+ client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
servlet.on_GET
)
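register_txn_path wires on_POST to the bare path and on_PUT (plus on_GET when requested) to the transaction-suffixed path, so a retried PUT with the same txn_id can be answered from HttpTransactionCache instead of re-executing. For RoomCreateRestServlet the registration above is equivalent to:

    from synapse.rest.client.v2_alpha._base import client_patterns

    post_patterns = client_patterns("/createRoom$", v1=True)
    put_patterns = client_patterns("/createRoom/(?P<txn_id>[^/]*)$", v1=True)
    # http_server.register_paths("POST", post_patterns, servlet.on_POST)
    # http_server.register_paths("PUT", put_patterns, servlet.on_PUT)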
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 53da905eea..6381049210 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -19,11 +19,17 @@ import hmac
from twisted.internet import defer
-from .base import ClientV1RestServlet, client_path_patterns
+from synapse.http.servlet import RestServlet
+from synapse.rest.client.v2_alpha._base import client_patterns
-class VoipRestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/voip/turnServer$")
+class VoipRestServlet(RestServlet):
+ PATTERNS = client_patterns("/voip/turnServer$", v1=True)
+
+ def __init__(self, hs):
+ super(VoipRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py
index 77434937ff..5236d5d566 100644
--- a/synapse/rest/client/v2_alpha/_base.py
+++ b/synapse/rest/client/v2_alpha/_base.py
@@ -21,14 +21,12 @@ import re
from twisted.internet import defer
from synapse.api.errors import InteractiveAuthIncompleteError
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+from synapse.api.urls import CLIENT_API_PREFIX
logger = logging.getLogger(__name__)
-def client_v2_patterns(path_regex, releases=(0,),
- v2_alpha=True,
- unstable=True):
+def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
"""Creates a regex compiled client path with the correct client path
prefix.
@@ -39,13 +37,14 @@ def client_v2_patterns(path_regex, releases=(0,),
SRE_Pattern
"""
patterns = []
- if v2_alpha:
- patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex))
if unstable:
- unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable")
+ unstable_prefix = CLIENT_API_PREFIX + "/unstable"
patterns.append(re.compile("^" + unstable_prefix + path_regex))
+ if v1:
+ v1_prefix = CLIENT_API_PREFIX + "/api/v1"
+ patterns.append(re.compile("^" + v1_prefix + path_regex))
for release in releases:
- new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
+ new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,)
patterns.append(re.compile("^" + new_prefix + path_regex))
return patterns
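A worked example of the unified helper: with the default releases=(0,), client_patterns("/login$", v1=True) compiles three patterns, which is how the ported servlets keep their legacy /api/v1 routes alongside /r0 and /unstable:

    client_patterns("/login$", v1=True)
    # => [re.compile("^/_matrix/client/unstable/login$"),
    #     re.compile("^/_matrix/client/api/v1/login$"),
    #     re.compile("^/_matrix/client/r0/login$")]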
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index ee069179f0..ca35dc3c83 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -30,13 +30,13 @@ from synapse.http.servlet import (
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import check_3pid_allowed
-from ._base import client_v2_patterns, interactive_auth_handler
+from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/password/email/requestToken$")
+ PATTERNS = client_patterns("/account/password/email/requestToken$")
def __init__(self, hs):
super(EmailPasswordRequestTokenRestServlet, self).__init__()
@@ -70,7 +70,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
class MsisdnPasswordRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$")
+ PATTERNS = client_patterns("/account/password/msisdn/requestToken$")
def __init__(self, hs):
super(MsisdnPasswordRequestTokenRestServlet, self).__init__()
@@ -108,7 +108,7 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
class PasswordRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/password$")
+ PATTERNS = client_patterns("/account/password$")
def __init__(self, hs):
super(PasswordRestServlet, self).__init__()
@@ -180,7 +180,7 @@ class PasswordRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/deactivate$")
+ PATTERNS = client_patterns("/account/deactivate$")
def __init__(self, hs):
super(DeactivateAccountRestServlet, self).__init__()
@@ -228,7 +228,7 @@ class DeactivateAccountRestServlet(RestServlet):
class EmailThreepidRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$")
+ PATTERNS = client_patterns("/account/3pid/email/requestToken$")
def __init__(self, hs):
self.hs = hs
@@ -263,7 +263,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$")
+ PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
def __init__(self, hs):
self.hs = hs
@@ -300,7 +300,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
class ThreepidRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/3pid$")
+ PATTERNS = client_patterns("/account/3pid$")
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
@@ -364,7 +364,7 @@ class ThreepidRestServlet(RestServlet):
class ThreepidDeleteRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/3pid/delete$")
+ PATTERNS = client_patterns("/account/3pid/delete$")
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
@@ -401,7 +401,7 @@ class ThreepidDeleteRestServlet(RestServlet):
class WhoamiRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account/whoami$")
+ PATTERNS = client_patterns("/account/whoami$")
def __init__(self, hs):
super(WhoamiRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index f171b8d626..574a6298ce 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -30,7 +30,7 @@ class AccountDataServlet(RestServlet):
PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)"
)
@@ -79,7 +79,7 @@ class RoomAccountDataServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)"
"/rooms/(?P<room_id>[^/]*)"
"/account_data/(?P<account_data_type>[^/]*)"
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index fc8dbeb617..55c4ed5660 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, SynapseError
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class AccountValidityRenewServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account_validity/renew$")
+ PATTERNS = client_patterns("/account_validity/renew$")
SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>"
def __init__(self, hs):
@@ -60,7 +60,7 @@ class AccountValidityRenewServlet(RestServlet):
class AccountValiditySendMailServlet(RestServlet):
- PATTERNS = client_v2_patterns("/account_validity/send_mail$")
+ PATTERNS = client_patterns("/account_validity/send_mail$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index ac035c7735..8dfe5cba02 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -19,11 +19,11 @@ from twisted.internet import defer
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
-from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+from synapse.api.urls import CLIENT_API_PREFIX
from synapse.http.server import finish_request
from synapse.http.servlet import RestServlet, parse_string
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -122,7 +122,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth.
"""
- PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
+ PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
super(AuthRestServlet, self).__init__()
@@ -139,8 +139,8 @@ class AuthRestServlet(RestServlet):
if stagetype == LoginType.RECAPTCHA:
html = RECAPTCHA_TEMPLATE % {
'session': session,
- 'myurl': "%s/auth/%s/fallback/web" % (
- CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
+ 'myurl': "%s/r0/auth/%s/fallback/web" % (
+ CLIENT_API_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
@@ -159,8 +159,8 @@ class AuthRestServlet(RestServlet):
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
- 'myurl': "%s/auth/%s/fallback/web" % (
- CLIENT_V2_ALPHA_PREFIX, LoginType.TERMS
+ 'myurl': "%s/r0/auth/%s/fallback/web" % (
+ CLIENT_API_PREFIX, LoginType.TERMS
),
}
html_bytes = html.encode("utf8")
@@ -203,8 +203,8 @@ class AuthRestServlet(RestServlet):
else:
html = RECAPTCHA_TEMPLATE % {
'session': session,
- 'myurl': "%s/auth/%s/fallback/web" % (
- CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
+ 'myurl': "%s/r0/auth/%s/fallback/web" % (
+ CLIENT_API_PREFIX, LoginType.RECAPTCHA
),
'sitekey': self.hs.config.recaptcha_public_key,
}
@@ -240,8 +240,8 @@ class AuthRestServlet(RestServlet):
self.hs.config.public_baseurl,
self.hs.config.user_consent_version,
),
- 'myurl': "%s/auth/%s/fallback/web" % (
- CLIENT_V2_ALPHA_PREFIX, LoginType.TERMS
+ 'myurl': "%s/r0/auth/%s/fallback/web" % (
+ CLIENT_API_PREFIX, LoginType.TERMS
),
}
html_bytes = html.encode("utf8")
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index a868d06098..fc7e2f4dd5 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -16,10 +16,10 @@ import logging
from twisted.internet import defer
-from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
class CapabilitiesRestServlet(RestServlet):
"""End point to expose the capabilities of the server."""
- PATTERNS = client_v2_patterns("/capabilities$")
+ PATTERNS = client_patterns("/capabilities$")
def __init__(self, hs):
"""
@@ -36,6 +36,7 @@ class CapabilitiesRestServlet(RestServlet):
"""
super(CapabilitiesRestServlet, self).__init__()
self.hs = hs
+ self.config = hs.config
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -48,7 +49,7 @@ class CapabilitiesRestServlet(RestServlet):
response = {
"capabilities": {
"m.room_versions": {
- "default": DEFAULT_ROOM_VERSION.identifier,
+ "default": self.config.default_room_version.identifier,
"available": {
v.identifier: v.disposition
for v in KNOWN_ROOM_VERSIONS.values()
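With the default now read from config rather than a hard-coded constant, a homeserver whose default_room_version is "4" would answer /capabilities along these lines (an abridged sketch; the available map simply mirrors KNOWN_ROOM_VERSIONS):

    {
        "capabilities": {
            "m.room_versions": {
                "default": "4",
                "available": {"1": "stable", "2": "stable",
                              "3": "stable", "4": "stable"},
            },
        }
    }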
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 9b75bb1377..78665304a5 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
-from ._base import client_v2_patterns, interactive_auth_handler
+from ._base import client_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
class DevicesRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/devices$", v2_alpha=False)
+ PATTERNS = client_patterns("/devices$")
def __init__(self, hs):
"""
@@ -56,7 +56,7 @@ class DeleteDevicesRestServlet(RestServlet):
API for bulk deletion of devices. Accepts a JSON object with a devices
key which lists the device_ids to delete. Requires user interactive auth.
"""
- PATTERNS = client_v2_patterns("/delete_devices", v2_alpha=False)
+ PATTERNS = client_patterns("/delete_devices")
def __init__(self, hs):
super(DeleteDevicesRestServlet, self).__init__()
@@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet):
class DeviceRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", v2_alpha=False)
+ PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index ae86728879..65db48c3cc 100644
--- a/synapse/rest/client/v2_alpha/filter.py
+++ b/synapse/rest/client/v2_alpha/filter.py
@@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
-from ._base import client_v2_patterns, set_timeline_upper_limit
+from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__)
class GetFilterRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)")
def __init__(self, hs):
super(GetFilterRestServlet, self).__init__()
@@ -63,7 +63,7 @@ class GetFilterRestServlet(RestServlet):
class CreateFilterRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter")
+ PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter")
def __init__(self, hs):
super(CreateFilterRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 21e02c07c0..d082385ec7 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import GroupID
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
class GroupServlet(RestServlet):
"""Get the group profile
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
def __init__(self, hs):
super(GroupServlet, self).__init__()
@@ -65,7 +65,7 @@ class GroupServlet(RestServlet):
class GroupSummaryServlet(RestServlet):
"""Get the full group summary
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
def __init__(self, hs):
super(GroupSummaryServlet, self).__init__()
@@ -93,7 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
- /groups/:group/summary/rooms/:room_id
- /groups/:group/summary/categories/:category/rooms/:room_id
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/categories/(?P<category_id>[^/]+))?"
"/rooms/(?P<room_id>[^/]*)$"
@@ -137,7 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
class GroupCategoryServlet(RestServlet):
"""Get/add/update/delete a group category
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
@@ -189,7 +189,7 @@ class GroupCategoryServlet(RestServlet):
class GroupCategoriesServlet(RestServlet):
"""Get all group categories
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/$"
)
@@ -214,7 +214,7 @@ class GroupCategoriesServlet(RestServlet):
class GroupRoleServlet(RestServlet):
"""Get/add/update/delete a group role
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
)
@@ -266,7 +266,7 @@ class GroupRoleServlet(RestServlet):
class GroupRolesServlet(RestServlet):
"""Get all group roles
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/roles/$"
)
@@ -295,7 +295,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
- /groups/:group/summary/users/:room_id
- /groups/:group/summary/roles/:role/users/:user_id
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/summary"
"(/roles/(?P<role_id>[^/]+))?"
"/users/(?P<user_id>[^/]*)$"
@@ -339,7 +339,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
class GroupRoomServlet(RestServlet):
"""Get all rooms in a group
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
def __init__(self, hs):
super(GroupRoomServlet, self).__init__()
@@ -360,7 +360,7 @@ class GroupRoomServlet(RestServlet):
class GroupUsersServlet(RestServlet):
"""Get all users in a group
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
def __init__(self, hs):
super(GroupUsersServlet, self).__init__()
@@ -381,7 +381,7 @@ class GroupUsersServlet(RestServlet):
class GroupInvitedUsersServlet(RestServlet):
"""Get users invited to a group
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
def __init__(self, hs):
super(GroupInvitedUsersServlet, self).__init__()
@@ -405,7 +405,7 @@ class GroupInvitedUsersServlet(RestServlet):
class GroupSettingJoinPolicyServlet(RestServlet):
"""Set group join policy
"""
- PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
+ PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
def __init__(self, hs):
super(GroupSettingJoinPolicyServlet, self).__init__()
@@ -431,7 +431,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
class GroupCreateServlet(RestServlet):
"""Create a group
"""
- PATTERNS = client_v2_patterns("/create_group$")
+ PATTERNS = client_patterns("/create_group$")
def __init__(self, hs):
super(GroupCreateServlet, self).__init__()
@@ -462,7 +462,7 @@ class GroupCreateServlet(RestServlet):
class GroupAdminRoomsServlet(RestServlet):
"""Add a room to the group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
@@ -499,7 +499,7 @@ class GroupAdminRoomsServlet(RestServlet):
class GroupAdminRoomsConfigServlet(RestServlet):
"""Update the config of a room in a group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$"
)
@@ -526,7 +526,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
class GroupAdminUsersInviteServlet(RestServlet):
"""Invite a user to the group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
@@ -555,7 +555,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
class GroupAdminUsersKickServlet(RestServlet):
"""Kick a user from the group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
@@ -581,7 +581,7 @@ class GroupAdminUsersKickServlet(RestServlet):
class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/leave$"
)
@@ -607,7 +607,7 @@ class GroupSelfLeaveServlet(RestServlet):
class GroupSelfJoinServlet(RestServlet):
"""Attempt to join a group, or knock
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/join$"
)
@@ -633,7 +633,7 @@ class GroupSelfJoinServlet(RestServlet):
class GroupSelfAcceptInviteServlet(RestServlet):
"""Accept a group invite
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/accept_invite$"
)
@@ -659,7 +659,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
class GroupSelfUpdatePublicityServlet(RestServlet):
"""Update whether we publicise a users membership of a group
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/self/update_publicity$"
)
@@ -686,7 +686,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
class PublicisedGroupsForUserServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/publicised_groups/(?P<user_id>[^/]*)$"
)
@@ -711,7 +711,7 @@ class PublicisedGroupsForUserServlet(RestServlet):
class PublicisedGroupsForUsersServlet(RestServlet):
"""Get the list of groups a user is advertising
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/publicised_groups$"
)
@@ -739,7 +739,7 @@ class PublicisedGroupsForUsersServlet(RestServlet):
class GroupsForUserServlet(RestServlet):
"""Get all groups the logged in user is joined to
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/joined_groups$"
)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 8486086b51..4cbfbf5631 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -26,7 +26,7 @@ from synapse.http.servlet import (
)
from synapse.types import StreamToken
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ class KeyUploadServlet(RestServlet):
},
}
"""
- PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
+ PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
def __init__(self, hs):
"""
@@ -130,7 +130,7 @@ class KeyQueryServlet(RestServlet):
} } } } } }
"""
- PATTERNS = client_v2_patterns("/keys/query$")
+ PATTERNS = client_patterns("/keys/query$")
def __init__(self, hs):
"""
@@ -159,7 +159,7 @@ class KeyChangesServlet(RestServlet):
200 OK
{ "changed": ["@foo:example.com"] }
"""
- PATTERNS = client_v2_patterns("/keys/changes$")
+ PATTERNS = client_patterns("/keys/changes$")
def __init__(self, hs):
"""
@@ -209,7 +209,7 @@ class OneTimeKeyServlet(RestServlet):
} } } }
"""
- PATTERNS = client_v2_patterns("/keys/claim$")
+ PATTERNS = client_patterns("/keys/claim$")
def __init__(self, hs):
super(OneTimeKeyServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index 2a6ea3df5f..53e666989b 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -17,25 +17,23 @@ import logging
from twisted.internet import defer
-from synapse.events.utils import (
- format_event_for_client_v2_without_room_id,
- serialize_event,
-)
+from synapse.events.utils import format_event_for_client_v2_without_room_id
from synapse.http.servlet import RestServlet, parse_integer, parse_string
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class NotificationsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/notifications$")
+ PATTERNS = client_patterns("/notifications$")
def __init__(self, hs):
super(NotificationsServlet, self).__init__()
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -69,11 +67,11 @@ class NotificationsServlet(RestServlet):
"profile_tag": pa["profile_tag"],
"actions": pa["actions"],
"ts": pa["received_ts"],
- "event": serialize_event(
+ "event": (yield self._event_serializer.serialize_event(
notif_events[pa["event_id"]],
self.clock.time_msec(),
event_format=format_event_for_client_v2_without_room_id,
- ),
+ )),
}
if pa["room_id"] not in receipts_by_room:
diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index 01c90aa2a3..bb927d9f9d 100644
--- a/synapse/rest/client/v2_alpha/openid.py
+++ b/synapse/rest/client/v2_alpha/openid.py
@@ -22,7 +22,7 @@ from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.stringutils import random_string
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ class IdTokenServlet(RestServlet):
"expires_in": 3600,
}
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/openid/request_token"
)
diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index a6e582a5ae..f4bd0d077f 100644
--- a/synapse/rest/client/v2_alpha/read_marker.py
+++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -19,13 +19,13 @@ from twisted.internet import defer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class ReadMarkerRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
+ PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$")
def __init__(self, hs):
super(ReadMarkerRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index de370cac45..fa12ac3e4d 100644
--- a/synapse/rest/client/v2_alpha/receipts.py
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class ReceiptRestServlet(RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$"
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index dc3e265bcd..79c085408b 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -31,6 +31,7 @@ from synapse.api.errors import (
SynapseError,
UnrecognizedRequestError,
)
+from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import (
RestServlet,
@@ -42,7 +43,7 @@ from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.threepids import check_3pid_allowed
-from ._base import client_v2_patterns, interactive_auth_handler
+from ._base import client_patterns, interactive_auth_handler
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
@@ -59,7 +60,7 @@ logger = logging.getLogger(__name__)
class EmailRegisterRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/register/email/requestToken$")
+ PATTERNS = client_patterns("/register/email/requestToken$")
def __init__(self, hs):
"""
@@ -97,7 +98,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/register/msisdn/requestToken$")
+ PATTERNS = client_patterns("/register/msisdn/requestToken$")
def __init__(self, hs):
"""
@@ -141,7 +142,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
class UsernameAvailabilityRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/register/available")
+ PATTERNS = client_patterns("/register/available")
def __init__(self, hs):
"""
@@ -153,16 +154,18 @@ class UsernameAvailabilityRestServlet(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.ratelimiter = FederationRateLimiter(
hs.get_clock(),
- # Time window of 2s
- window_size=2000,
- # Artificially delay requests if rate > sleep_limit/window_size
- sleep_limit=1,
- # Amount of artificial delay to apply
- sleep_msec=1000,
- # Error with 429 if more than reject_limit requests are queued
- reject_limit=1,
- # Allow 1 request at a time
- concurrent_requests=1,
+ FederationRateLimitConfig(
+ # Time window of 2s
+ window_size=2000,
+ # Artificially delay requests if rate > sleep_limit/window_size
+ sleep_limit=1,
+ # Amount of artificial delay to apply
+ sleep_msec=1000,
+ # Error with 429 if more than reject_limit requests are queued
+ reject_limit=1,
+ # Allow 1 request at a time
+ concurrent_requests=1,
+ )
)
@defer.inlineCallbacks
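For context, the servlet consumes this limiter as a context manager; a minimal sketch of the intended usage, assuming the existing FederationRateLimiter.ratelimit() API (handler and helper names here are illustrative):

    @defer.inlineCallbacks
    def on_GET(self, request):
        # Sketch only: ratelimit() yields a context manager whose deferred
        # fires once the request may proceed; excess requests within the
        # 2s window are delayed by sleep_msec or rejected with a 429.
        ip = self.hs.get_ip_from_request(request)
        with self.ratelimiter.ratelimit(ip) as wait_deferred:
            yield wait_deferred
            username = parse_string(request, "username", required=True)
            yield self.registration_handler.check_username(username)
            defer.returnValue((200, {"available": True}))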
@@ -179,7 +182,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
class RegisterRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/register$")
+ PATTERNS = client_patterns("/register$")
def __init__(self, hs):
"""
@@ -345,18 +348,22 @@ class RegisterRestServlet(RestServlet):
if self.hs.config.enable_registration_captcha:
# only support 3PIDless registration if no 3PIDs are required
if not require_email and not require_msisdn:
- flows.extend([[LoginType.RECAPTCHA]])
+ # Also add a dummy flow here; otherwise, if a client completes
+ # recaptcha first, we'll assume they were going for this flow
+ # and complete the request, when they could have been trying to
+ # complete one of the flows with email/msisdn auth.
+ flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]])
# only support the email-only flow if we don't require MSISDN 3PIDs
if not require_msisdn:
- flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+ flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]])
if show_msisdn:
# only support the MSISDN-only flow if we don't require email 3PIDs
if not require_email:
- flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+ flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
# always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
+ [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
])
else:
# only support 3PIDless registration if no 3PIDs are required
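To make the captcha-enabled branch above concrete: with registration captcha on, neither 3PID required, and MSISDN shown, the resulting flow list is (illustration only):

    flows = [
        [LoginType.RECAPTCHA, LoginType.DUMMY],              # 3PID-less
        [LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY],     # email only
        [LoginType.RECAPTCHA, LoginType.MSISDN],             # msisdn only
        [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
    ]

Because every flow now starts with RECAPTCHA, completing it alone no longer identifies a flow; the explicit DUMMY stage is what lets the server distinguish the 3PID-less flow from a partially-completed email/msisdn flow.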
@@ -379,7 +386,15 @@ class RegisterRestServlet(RestServlet):
if self.hs.config.user_consent_at_registration:
new_flows = []
for flow in flows:
- flow.append(LoginType.TERMS)
+ inserted = False
+ # m.login.terms should go near the end but before msisdn or email auth
+ for i, stage in enumerate(flow):
+ if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN:
+ flow.insert(i, LoginType.TERMS)
+ inserted = True
+ break
+ if not inserted:
+ flow.append(LoginType.TERMS)
flows.extend(new_flows)
auth_result, params, session_id = yield self.auth_handler.check_auth(
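A worked example of the insertion logic above (a sketch; flows are plain lists of stage names):

    # m.login.terms is spliced in just before the first 3PID stage:
    [LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]
    # becomes:
    [LoginType.RECAPTCHA, LoginType.TERMS, LoginType.EMAIL_IDENTITY]

    # while a flow with no 3PID stage gains it at the end:
    [LoginType.RECAPTCHA, LoginType.DUMMY]
    # becomes:
    [LoginType.RECAPTCHA, LoginType.DUMMY, LoginType.TERMS]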
@@ -391,13 +406,6 @@ class RegisterRestServlet(RestServlet):
# the user-facing checks will probably already have happened in
# /register/email/requestToken when we requested a 3pid, but that's not
# guaranteed.
- #
- # Also check that we're not trying to register a 3pid that's already
- # been registered.
- #
- # This has probably happened in /register/email/requestToken as well,
- # but if a user hits this endpoint twice then clicks on each link from
- # the two activation emails, they would register the same 3pid twice.
if auth_result:
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
@@ -413,17 +421,6 @@ class RegisterRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
- existingUid = yield self.store.get_user_id_by_threepid(
- medium, address,
- )
-
- if existingUid is not None:
- raise SynapseError(
- 400,
- "%s is already in use" % medium,
- Codes.THREEPID_IN_USE,
- )
-
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
@@ -446,6 +443,28 @@ class RegisterRestServlet(RestServlet):
if auth_result:
threepid = auth_result.get(LoginType.EMAIL_IDENTITY)
+ # Also check that we're not trying to register a 3pid that's already
+ # been registered.
+ #
+ # This has probably happened in /register/email/requestToken as well,
+ # but if a user hits this endpoint twice then clicks on each link from
+ # the two activation emails, they would register the same 3pid twice.
+ for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+ if login_type in auth_result:
+ medium = auth_result[login_type]['medium']
+ address = auth_result[login_type]['address']
+
+ existingUid = yield self.store.get_user_id_by_threepid(
+ medium, address,
+ )
+
+ if existingUid is not None:
+ raise SynapseError(
+ 400,
+ "%s is already in use" % medium,
+ Codes.THREEPID_IN_USE,
+ )
+
(registered_user_id, _) = yield self.registration_handler.register(
localpart=desired_username,
password=new_password,
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
new file mode 100644
index 0000000000..f8f8742bdc
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -0,0 +1,338 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This class implements the proposed relation APIs from MSC 1849.
+
+Since the MSC has not been approved all APIs here are unstable and may change at
+any time to reflect changes in the MSC.
+"""
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ parse_integer,
+ parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.client.transactions import HttpTransactionCache
+from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken
+
+from ._base import client_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class RelationSendServlet(RestServlet):
+ """Helper API for sending events that have relation data.
+
+ Example API shape to send a 👍 reaction to a room:
+
+ POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D
+ {}
+
+ {
+ "event_id": "$foobar"
+ }
+ """
+
+ PATTERN = (
+ "/rooms/(?P<room_id>[^/]*)/send_relation"
+ "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)"
+ )
+
+ def __init__(self, hs):
+ super(RelationSendServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.txns = HttpTransactionCache(hs)
+
+ def register(self, http_server):
+ http_server.register_paths(
+ "POST",
+ client_patterns(self.PATTERN + "$", releases=()),
+ self.on_PUT_or_POST,
+ )
+ http_server.register_paths(
+ "PUT",
+ client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()),
+ self.on_PUT,
+ )
+
+ def on_PUT(self, request, *args, **kwargs):
+ return self.txns.fetch_or_execute_request(
+ request, self.on_PUT_or_POST, request, *args, **kwargs
+ )
+
+ @defer.inlineCallbacks
+ def on_PUT_or_POST(
+ self, request, room_id, parent_id, relation_type, event_type, txn_id=None
+ ):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ if event_type == EventTypes.Member:
+ # Adding a relation to a membership event is meaningless, so we just
+ # deny it at the CS API rather than trying to handle it correctly.
+ raise SynapseError(400, "Cannot send member events with relations")
+
+ content = parse_json_object_from_request(request)
+
+ aggregation_key = parse_string(request, "key", encoding="utf-8")
+
+ content["m.relates_to"] = {
+ "event_id": parent_id,
+ "key": aggregation_key,
+ "rel_type": relation_type,
+ }
+
+ event_dict = {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
+ requester, event_dict=event_dict, txn_id=txn_id
+ )
+
+ defer.returnValue((200, {"event_id": event.event_id}))
+
+
+class RelationPaginationServlet(RestServlet):
+ """API to paginate relations on an event by topological ordering, optionally
+ filtered by relation type and event type.
+ """
+
+ PATTERNS = client_patterns(
+ "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)"
+ "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+ releases=(),
+ )
+
+ def __init__(self, hs):
+ super(RelationPaginationServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ yield self.auth.check_in_room_or_world_readable(
+ room_id, requester.user.to_string()
+ )
+
+ # This checks that a) the event exists and b) the user is allowed to
+ # view it.
+ yield self.event_handler.get_event(requester.user, room_id, parent_id)
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token = parse_string(request, "from")
+ to_token = parse_string(request, "to")
+
+ if from_token:
+ from_token = RelationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = RelationPaginationToken.from_string(to_token)
+
+ result = yield self.store.get_relations_for_event(
+ event_id=parent_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = yield self.store.get_events_as_list(
+ [c["event_id"] for c in result.chunk]
+ )
+
+ now = self.clock.time_msec()
+ events = yield self._event_serializer.serialize_events(events, now)
+
+ return_value = result.to_dict()
+ return_value["chunk"] = events
+
+ defer.returnValue((200, return_value))
+
+
+class RelationAggregationPaginationServlet(RestServlet):
+ """API to paginate aggregation groups of relations, e.g. paginate the
+ types and counts of the reactions on the events.
+
+ Example request and response:
+
+ GET /rooms/{room_id}/aggregations/{parent_id}
+
+ {
+ "chunk": [
+ {
+ "type": "m.reaction",
+ "key": "👍",
+ "count": 3
+ }
+ ]
+ }
+ """
+
+ PATTERNS = client_patterns(
+ "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
+ "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$",
+ releases=(),
+ )
+
+ def __init__(self, hs):
+ super(RelationAggregationPaginationServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ yield self.auth.check_in_room_or_world_readable(
+ room_id, requester.user.to_string()
+ )
+
+ # This checks that a) the event exists and b) the user is allowed to
+ # view it.
+ yield self.event_handler.get_event(requester.user, room_id, parent_id)
+
+ if relation_type not in (RelationTypes.ANNOTATION, None):
+ raise SynapseError(400, "Relation type must be 'annotation'")
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token = parse_string(request, "from")
+ to_token = parse_string(request, "to")
+
+ if from_token:
+ from_token = AggregationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = AggregationPaginationToken.from_string(to_token)
+
+ res = yield self.store.get_aggregation_groups_for_event(
+ event_id=parent_id,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ defer.returnValue((200, res.to_dict()))
+
+
+class RelationAggregationGroupPaginationServlet(RestServlet):
+ """API to paginate within an aggregation group of relations, e.g. paginate
+ all the 👍 reactions on an event.
+
+ Example request and response:
+
+ GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍
+
+ {
+ "chunk": [
+ {
+ "type": "m.reaction",
+ "content": {
+ "m.relates_to": {
+ "rel_type": "m.annotation",
+ "key": "👍"
+ }
+ }
+ },
+ ...
+ ]
+ }
+ """
+
+ PATTERNS = client_patterns(
+ "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)"
+ "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$",
+ releases=(),
+ )
+
+ def __init__(self, hs):
+ super(RelationAggregationGroupPaginationServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, parent_id, relation_type, event_type, key):
+ requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+
+ yield self.auth.check_in_room_or_world_readable(
+ room_id, requester.user.to_string()
+ )
+
+ # This checks that a) the event exists and b) the user is allowed to
+ # view it.
+ yield self.event_handler.get_event(requester.user, room_id, parent_id)
+
+ if relation_type != RelationTypes.ANNOTATION:
+ raise SynapseError(400, "Relation type must be 'annotation'")
+
+ limit = parse_integer(request, "limit", default=5)
+ from_token = parse_string(request, "from")
+ to_token = parse_string(request, "to")
+
+ if from_token:
+ from_token = RelationPaginationToken.from_string(from_token)
+
+ if to_token:
+ to_token = RelationPaginationToken.from_string(to_token)
+
+ result = yield self.store.get_relations_for_event(
+ event_id=parent_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ aggregation_key=key,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = yield self.store.get_events_as_list(
+ [c["event_id"] for c in result.chunk]
+ )
+
+ now = self.clock.time_msec()
+ events = yield self._event_serializer.serialize_events(events, now)
+
+ return_value = result.to_dict()
+ return_value["chunk"] = events
+
+ defer.returnValue((200, return_value))
+
+
+def register_servlets(hs, http_server):
+ RelationSendServlet(hs).register(http_server)
+ RelationPaginationServlet(hs).register(http_server)
+ RelationAggregationPaginationServlet(hs).register(http_server)
+ RelationAggregationGroupPaginationServlet(hs).register(http_server)
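Taken together, and given that client_patterns(..., releases=()) registers unstable-only paths, the new servlets can be exercised roughly as follows (a sketch; the IDs and the /unstable prefix are assumptions based on the pattern helpers above):

    POST /_matrix/client/unstable/rooms/!foo:example.org/send_relation/$parent/m.annotation/m.reaction?key=%F0%9F%91%8D
    {}
    -> {"event_id": "$reaction_event"}

    GET /_matrix/client/unstable/rooms/!foo:example.org/aggregations/$parent
    -> {"chunk": [{"type": "m.reaction", "key": "👍", "count": 3}]}

    GET /_matrix/client/unstable/rooms/!foo:example.org/aggregations/$parent/m.annotation/m.reaction/👍?limit=5
    -> {"chunk": [ ...the individual 👍 reaction events... ]}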
diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index 95d2a71ec2..10198662a9 100644
--- a/synapse/rest/client/v2_alpha/report_event.py
+++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -27,13 +27,13 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class ReportEventRestServlet(RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$"
)
diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 220a0de30b..87779645f9 100644
--- a/synapse/rest/client/v2_alpha/room_keys.py
+++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -24,13 +24,13 @@ from synapse.http.servlet import (
parse_string,
)
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class RoomKeysServlet(RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
)
@@ -256,7 +256,7 @@ class RoomKeysServlet(RestServlet):
class RoomKeysNewVersionServlet(RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/room_keys/version$"
)
@@ -314,7 +314,7 @@ class RoomKeysNewVersionServlet(RestServlet):
class RoomKeysVersionServlet(RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/room_keys/version(/(?P<version>[^/]+))?$"
)
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index 3db7ff8d1b..c621a90fba 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -25,7 +25,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -47,10 +47,9 @@ class RoomUpgradeRestServlet(RestServlet):
Args:
hs (synapse.server.HomeServer):
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
# /rooms/$roomid/upgrade
"/rooms/(?P<room_id>[^/]*)/upgrade$",
- v2_alpha=False,
)
def __init__(self, hs):
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index a9e9a47a0b..120a713361 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -21,15 +21,14 @@ from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
from synapse.rest.client.transactions import HttpTransactionCache
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class SendToDeviceRestServlet(servlet.RestServlet):
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$",
- v2_alpha=False
)
def __init__(self, hs):
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 39d157a44b..148fc6c985 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -26,14 +26,13 @@ from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
format_event_raw,
- serialize_event,
)
from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
-from ._base import client_v2_patterns, set_timeline_upper_limit
+from ._base import client_patterns, set_timeline_upper_limit
logger = logging.getLogger(__name__)
@@ -74,7 +73,7 @@ class SyncRestServlet(RestServlet):
}
"""
- PATTERNS = client_v2_patterns("/sync$")
+ PATTERNS = client_patterns("/sync$")
ALLOWED_PRESENCE = set(["online", "offline", "unavailable"])
def __init__(self, hs):
@@ -86,6 +85,7 @@ class SyncRestServlet(RestServlet):
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_presence_handler()
self._server_notices_sender = hs.get_server_notices_sender()
+ self._event_serializer = hs.get_event_client_serializer()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -168,14 +168,14 @@ class SyncRestServlet(RestServlet):
)
time_now = self.clock.time_msec()
- response_content = self.encode_response(
+ response_content = yield self.encode_response(
time_now, sync_result, requester.access_token_id, filter
)
defer.returnValue((200, response_content))
- @staticmethod
- def encode_response(time_now, sync_result, access_token_id, filter):
+ @defer.inlineCallbacks
+ def encode_response(self, time_now, sync_result, access_token_id, filter):
if filter.event_format == 'client':
event_formatter = format_event_for_client_v2_without_room_id
elif filter.event_format == 'federation':
@@ -183,24 +183,24 @@ class SyncRestServlet(RestServlet):
else:
raise Exception("Unknown event format %s" % (filter.event_format, ))
- joined = SyncRestServlet.encode_joined(
+ joined = yield self.encode_joined(
sync_result.joined, time_now, access_token_id,
filter.event_fields,
event_formatter,
)
- invited = SyncRestServlet.encode_invited(
+ invited = yield self.encode_invited(
sync_result.invited, time_now, access_token_id,
event_formatter,
)
- archived = SyncRestServlet.encode_archived(
+ archived = yield self.encode_archived(
sync_result.archived, time_now, access_token_id,
filter.event_fields,
event_formatter,
)
- return {
+ defer.returnValue({
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
"device_lists": {
@@ -222,7 +222,7 @@ class SyncRestServlet(RestServlet):
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
- }
+ })
@staticmethod
def encode_presence(events, time_now):
@@ -239,8 +239,8 @@ class SyncRestServlet(RestServlet):
]
}
- @staticmethod
- def encode_joined(rooms, time_now, token_id, event_fields, event_formatter):
+ @defer.inlineCallbacks
+ def encode_joined(self, rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the joined rooms in a sync result
@@ -261,15 +261,15 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = SyncRestServlet.encode_room(
+ joined[room.room_id] = yield self.encode_room(
room, time_now, token_id, joined=True, only_fields=event_fields,
event_formatter=event_formatter,
)
- return joined
+ defer.returnValue(joined)
- @staticmethod
- def encode_invited(rooms, time_now, token_id, event_formatter):
+ @defer.inlineCallbacks
+ def encode_invited(self, rooms, time_now, token_id, event_formatter):
"""
Encode the invited rooms in a sync result
@@ -289,7 +289,7 @@ class SyncRestServlet(RestServlet):
"""
invited = {}
for room in rooms:
- invite = serialize_event(
+ invite = yield self._event_serializer.serialize_event(
room.invite, time_now, token_id=token_id,
event_format=event_formatter,
is_invite=True,
@@ -302,10 +302,10 @@ class SyncRestServlet(RestServlet):
"invite_state": {"events": invited_state}
}
- return invited
+ defer.returnValue(invited)
- @staticmethod
- def encode_archived(rooms, time_now, token_id, event_fields, event_formatter):
+ @defer.inlineCallbacks
+ def encode_archived(self, rooms, time_now, token_id, event_fields, event_formatter):
"""
Encode the archived rooms in a sync result
@@ -326,17 +326,17 @@ class SyncRestServlet(RestServlet):
"""
joined = {}
for room in rooms:
- joined[room.room_id] = SyncRestServlet.encode_room(
+ joined[room.room_id] = yield self.encode_room(
room, time_now, token_id, joined=False,
only_fields=event_fields,
event_formatter=event_formatter,
)
- return joined
+ defer.returnValue(joined)
- @staticmethod
+ @defer.inlineCallbacks
def encode_room(
- room, time_now, token_id, joined,
+ self, room, time_now, token_id, joined,
only_fields, event_formatter,
):
"""
@@ -355,9 +355,13 @@ class SyncRestServlet(RestServlet):
Returns:
dict[str, object]: the room, encoded in our response format
"""
- def serialize(event):
- return serialize_event(
- event, time_now, token_id=token_id,
+ def serialize(events):
+ return self._event_serializer.serialize_events(
+ events, time_now=time_now,
+ # We don't bundle "live" events, as otherwise clients
+ # will end up double-counting annotations.
+ bundle_aggregations=False,
+ token_id=token_id,
event_format=event_formatter,
only_event_fields=only_fields,
)
@@ -376,8 +380,8 @@ class SyncRestServlet(RestServlet):
event.event_id, room.room_id, event.room_id,
)
- serialized_state = [serialize(e) for e in state_events]
- serialized_timeline = [serialize(e) for e in timeline_events]
+ serialized_state = yield serialize(state_events)
+ serialized_timeline = yield serialize(timeline_events)
account_data = room.account_data
@@ -397,7 +401,7 @@ class SyncRestServlet(RestServlet):
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
- return result
+ defer.returnValue(result)
def register_servlets(hs, http_server):
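The thread running through the sync.py changes: event serialization is now a database-touching operation (EventClientSerializer can bundle relation aggregations into events), so every encode_* helper becomes a Deferred-returning instance method instead of a static function. A minimal sketch of the new calling convention, assuming the serializer API introduced elsewhere in this diff:

    @defer.inlineCallbacks
    def _encode(self, events, time_now):
        # serialize_events now returns a Deferred, so it must be yielded;
        # bundling is turned off for sync timelines so that clients don't
        # double-count annotations they also receive as live events.
        serialized = yield self._event_serializer.serialize_events(
            events, time_now, bundle_aggregations=False,
        )
        defer.returnValue(serialized)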
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index 4fea614e95..ebff7cff45 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ class TagListServlet(RestServlet):
"""
GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags"
)
@@ -54,7 +54,7 @@ class TagServlet(RestServlet):
PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1
"""
- PATTERNS = client_v2_patterns(
+ PATTERNS = client_patterns(
"/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)"
)
diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index b9b5d07677..e7a987466a 100644
--- a/synapse/rest/client/v2_alpha/thirdparty.py
+++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -21,13 +21,13 @@ from twisted.internet import defer
from synapse.api.constants import ThirdPartyEntityKind
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class ThirdPartyProtocolsServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocols")
+ PATTERNS = client_patterns("/thirdparty/protocols")
def __init__(self, hs):
super(ThirdPartyProtocolsServlet, self).__init__()
@@ -44,7 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet):
class ThirdPartyProtocolServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
+ PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$")
def __init__(self, hs):
super(ThirdPartyProtocolServlet, self).__init__()
@@ -66,7 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet):
class ThirdPartyUserServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
+ PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyUserServlet, self).__init__()
@@ -89,7 +89,7 @@ class ThirdPartyUserServlet(RestServlet):
class ThirdPartyLocationServlet(RestServlet):
- PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
+ PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$")
def __init__(self, hs):
super(ThirdPartyLocationServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 6e76b9e9c2..6c366142e1 100644
--- a/synapse/rest/client/v2_alpha/tokenrefresh.py
+++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -18,7 +18,7 @@ from twisted.internet import defer
from synapse.api.errors import AuthError
from synapse.http.servlet import RestServlet
-from ._base import client_v2_patterns
+from ._base import client_patterns
class TokenRefreshRestServlet(RestServlet):
@@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet):
Exchanges refresh tokens for a pair of an access token and a new refresh
token.
"""
- PATTERNS = client_v2_patterns("/tokenrefresh")
+ PATTERNS = client_patterns("/tokenrefresh")
def __init__(self, hs):
super(TokenRefreshRestServlet, self).__init__()
diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index 36b02de37f..69e4efc47a 100644
--- a/synapse/rest/client/v2_alpha/user_directory.py
+++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -20,13 +20,13 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from ._base import client_v2_patterns
+from ._base import client_patterns
logger = logging.getLogger(__name__)
class UserDirectorySearchRestServlet(RestServlet):
- PATTERNS = client_v2_patterns("/user_directory/search$")
+ PATTERNS = client_patterns("/user_directory/search$")
def __init__(self, hs):
"""
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index eb8782aa6e..8a730bbc35 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -20,7 +20,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
-from synapse.crypto.keyring import KeyLookupError
+from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.servlet import parse_integer, parse_json_object_from_request
@@ -89,7 +89,7 @@ class RemoteKey(Resource):
isLeaf = True
def __init__(self, hs):
- self.keyring = hs.get_keyring()
+ self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -215,15 +215,7 @@ class RemoteKey(Resource):
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
- for server_name, key_ids in cache_misses.items():
- try:
- yield self.keyring.get_server_verify_key_v2_direct(
- server_name, key_ids
- )
- except KeyLookupError as e:
- logger.info("Failed to fetch key: %s", e)
- except Exception:
- logger.exception("Failed to get key for %r", server_name)
+ yield self.fetcher.get_keys(cache_misses)
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index e2b5df701d..2dcc8f74d6 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -191,6 +191,10 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
# in that case.
logger.warning("Failed to write to consumer: %s %s", type(e), e)
+ # Unregister the producer, if it has one, so Twisted doesn't complain
+ if request.producer:
+ request.unregisterProducer()
+
finish_request(request)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index bdffa97805..8569677355 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -444,6 +444,9 @@ class MediaRepository(object):
)
return
+ if thumbnailer.transpose_method is not None:
+ m_width, m_height = thumbnailer.transpose()
+
if t_method == "crop":
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
@@ -578,6 +581,12 @@ class MediaRepository(object):
)
return
+ if thumbnailer.transpose_method is not None:
+ m_width, m_height = yield logcontext.defer_to_thread(
+ self.hs.get_reactor(),
+ thumbnailer.transpose
+ )
+
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails = {}
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index ba3ab1d37d..acf87709f2 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -31,6 +31,7 @@ from six.moves import urllib_parse as urlparse
from canonicaljson import json
from twisted.internet import defer
+from twisted.internet.error import DNSLookupError
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -328,9 +329,18 @@ class PreviewUrlResource(Resource):
# handler will return a SynapseError to the client instead of
# blank data or a 500.
raise
+ except DNSLookupError:
+ # DNS lookup returned no results
+ # Note: This will also be the case if one of the resolved IP
+ # addresses is blacklisted
+ raise SynapseError(
+ 502, "DNS resolution failure during URL preview generation",
+ Codes.UNKNOWN
+ )
except Exception as e:
# FIXME: pass through 404s and other error messages nicely
logger.warn("Error downloading %s: %r", url, e)
+
raise SynapseError(
500, "Failed to download content: %s" % (
traceback.format_exception_only(sys.exc_info()[0], e),
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 5aa03031f6..d90cbfb56a 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -108,6 +108,7 @@ class FileStorageProviderBackend(StorageProvider):
"""
def __init__(self, hs, config):
+ self.hs = hs
self.cache_directory = hs.config.media_store_path
self.base_directory = config
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 5305e9175f..35a750923b 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -56,8 +56,8 @@ class ThumbnailResource(Resource):
def _async_render_GET(self, request):
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
- width = parse_integer(request, "width")
- height = parse_integer(request, "height")
+ width = parse_integer(request, "width", required=True)
+ height = parse_integer(request, "height", required=True)
method = parse_string(request, "method", "scale")
m_type = parse_string(request, "type", "image/png")
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index a4b26c2587..3efd0d80fc 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -20,6 +20,17 @@ import PIL.Image as Image
logger = logging.getLogger(__name__)
+EXIF_ORIENTATION_TAG = 0x0112
+EXIF_TRANSPOSE_MAPPINGS = {
+ 2: Image.FLIP_LEFT_RIGHT,
+ 3: Image.ROTATE_180,
+ 4: Image.FLIP_TOP_BOTTOM,
+ 5: Image.TRANSPOSE,
+ 6: Image.ROTATE_270,
+ 7: Image.TRANSVERSE,
+ 8: Image.ROTATE_90
+}
+
class Thumbnailer(object):
@@ -31,6 +42,30 @@ class Thumbnailer(object):
def __init__(self, input_path):
self.image = Image.open(input_path)
self.width, self.height = self.image.size
+ self.transpose_method = None
+ try:
+ # We don't use ImageOps.exif_transpose since it crashes on images with
+ # large EXIF data
+ image_exif = self.image._getexif()
+ if image_exif is not None:
+ image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
+ self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
+ except Exception as e:
+ # A lot of parsing errors can happen when parsing EXIF
+ logger.info("Error parsing image EXIF information: %s", e)
+
+ def transpose(self):
+ """Transpose the image using its EXIF Orientation tag
+
+ Returns:
+ Tuple[int, int]: (width, height) containing the new image size in pixels.
+ """
+ if self.transpose_method is not None:
+ self.image = self.image.transpose(self.transpose_method)
+ self.width, self.height = self.image.size
+ self.transpose_method = None
+ # We don't need EXIF any more
+ self.image.info["exif"] = None
+ return self.image.size
def aspect(self, max_width, max_height):
"""Calculate the largest size that preserves aspect ratio which
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index ab901e63f2..a7fa4f39af 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -68,6 +68,6 @@ class WellKnownResource(Resource):
request.setHeader(b"Content-Type", b"text/plain")
return b'.well-known not available'
- logger.error("returning: %s", r)
+ logger.debug("returning: %s", r)
request.setHeader(b"Content-Type", b"application/json")
return json.dumps(r).encode("utf-8")
diff --git a/synapse/server.py b/synapse/server.py
index 8c30ac2fa5..9229a68a8d 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -35,6 +35,7 @@ from synapse.crypto import context_factory
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
from synapse.events.spamcheck import SpamChecker
+from synapse.events.utils import EventClientSerializer
from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import (
FederationHandlerRegistry,
@@ -71,6 +72,7 @@ from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.set_password import SetPasswordHandler
+from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.user_directory import UserDirectoryHandler
@@ -138,6 +140,7 @@ class HomeServer(object):
'acme_handler',
'auth_handler',
'device_handler',
+ 'stats_handler',
'e2e_keys_handler',
'e2e_room_keys_handler',
'event_handler',
@@ -185,10 +188,12 @@ class HomeServer(object):
'sendmail',
'registration_handler',
'account_validity_handler',
+ 'event_client_serializer',
]
REQUIRED_ON_MASTER_STARTUP = [
"user_directory_handler",
+ "stats_handler"
]
# This is overridden in derived application classes
@@ -472,6 +477,9 @@ class HomeServer(object):
def build_secrets(self):
return Secrets()
+ def build_stats_handler(self):
+ return StatsHandler(self)
+
def build_spam_checker(self):
return SpamChecker(self)
@@ -511,6 +519,9 @@ class HomeServer(object):
def build_account_validity_handler(self):
return AccountValidityHandler(self)
+ def build_event_client_serializer(self):
+ return EventClientSerializer(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 3ba3a967c2..9583e82d52 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -18,7 +18,6 @@ import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
-
class HomeServer(object):
@property
def config(self) -> synapse.config.homeserver.HomeServerConfig:
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index c432041b4e..71316f7d09 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -36,6 +36,7 @@ from .engines import PostgresEngine
from .event_federation import EventFederationStore
from .event_push_actions import EventPushActionsStore
from .events import EventsStore
+from .events_bg_updates import EventsBackgroundUpdatesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
@@ -49,11 +50,13 @@ from .pusher import PusherStore
from .receipts import ReceiptsStore
from .registration import RegistrationStore
from .rejections import RejectionsStore
+from .relations import RelationsStore
from .room import RoomStore
from .roommember import RoomMemberStore
from .search import SearchStore
from .signatures import SignatureStore
from .state import StateStore
+from .stats import StatsStore
from .stream import StreamStore
from .tags import TagsStore
from .transactions import TransactionStore
@@ -64,6 +67,7 @@ logger = logging.getLogger(__name__)
class DataStore(
+ EventsBackgroundUpdatesStore,
RoomMemberStore,
RoomStore,
RegistrationStore,
@@ -99,6 +103,8 @@ class DataStore(
GroupServerStore,
UserErasureStore,
MonthlyActiveUsersStore,
+ StatsStore,
+ RelationsStore,
):
def __init__(self, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 983ce026e1..52891bb9eb 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +16,7 @@
# limitations under the License.
import itertools
import logging
+import random
import sys
import threading
import time
@@ -227,6 +230,8 @@ class SQLBaseStore(object):
# A set of tables that are not safe to use native upserts in.
self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES.keys())
+ self._account_validity = self.hs.config.account_validity
+
# We add the user_directory_search table to the blacklist on SQLite
# because the existing search table does not have an index, making it
# unsafe to use native upserts.
@@ -243,6 +248,16 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
+ self.rand = random.SystemRandom()
+
+ if self._account_validity.enabled:
+ self._clock.call_later(
+ 0.0,
+ run_as_background_process,
+ "account_validity_set_expiration_dates",
+ self._set_expiration_date_when_missing,
+ )
+
@defer.inlineCallbacks
def _check_safe_to_upsert(self):
"""
@@ -275,6 +290,67 @@ class SQLBaseStore(object):
self._check_safe_to_upsert,
)
+ @defer.inlineCallbacks
+ def _set_expiration_date_when_missing(self):
+ """
+ Retrieves the list of registered users that don't have an expiration date, and
+ adds an expiration date for each of them.
+ """
+
+ def select_users_with_no_expiration_date_txn(txn):
+ """Retrieves the list of registered users with no expiration date from the
+ database.
+ """
+ sql = (
+ "SELECT users.name FROM users"
+ " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
+ " WHERE account_validity.user_id is NULL;"
+ )
+ txn.execute(sql, [])
+
+ res = self.cursor_to_dict(txn)
+ if res:
+ for user in res:
+ self.set_expiration_date_for_user_txn(
+ txn,
+ user["name"],
+ use_delta=True,
+ )
+
+ yield self.runInteraction(
+ "get_users_with_no_expiration_date",
+ select_users_with_no_expiration_date_txn,
+ )
+
+ def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+ """Sets an expiration date to the account with the given user ID.
+
+ Args:
+ user_id (str): User ID to set an expiration date for.
+ use_delta (bool): If set to False, the expiration date for the user will be
+ now + validity period. If set to True, this expiration date will be a
+ random value in the [now + period - d, now + period] range, d being a
+ delta equal to 10% of the validity period.
+ """
+ now_ms = self._clock.time_msec()
+ expiration_ts = now_ms + self._account_validity.period
+
+ if use_delta:
+ expiration_ts = self.rand.randrange(
+ expiration_ts - self._account_validity.startup_job_max_delta,
+ expiration_ts,
+ )
+
+ self._simple_insert_txn(
+ txn,
+ "account_validity",
+ values={
+ "user_id": user_id,
+ "expiration_ts_ms": expiration_ts,
+ "email_sent": False,
+ },
+ )
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
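To make the use_delta behaviour concrete: with a 30-day validity period and a max delta of 10% of that (as the docstring above describes), users picked up by the startup job get expiration timestamps spread across the final three days of the period, so pre-existing accounts don't all expire at the same instant. A sketch with made-up numbers:

    import random

    now_ms = 1556668800000                # an arbitrary "now", in ms
    period = 30 * 24 * 3600 * 1000        # 30-day validity period
    startup_job_max_delta = period // 10  # 10% of the period

    expiration_ts = now_ms + period
    # use_delta=True: land somewhere in the last 10% of the period
    expiration_ts = random.SystemRandom().randrange(
        expiration_ts - startup_job_max_delta, expiration_ts,
    )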
@@ -1203,7 +1279,8 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues),
)
- return txn.execute(sql, list(keyvalues.values()))
+ txn.execute(sql, list(keyvalues.values()))
+ return txn.rowcount
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
@@ -1222,9 +1299,12 @@ class SQLBaseStore(object):
column : column name to test for inclusion against `iterable`
iterable : list
keyvalues : dict of column names and values to select the rows with
+
+ Returns:
+ int: Number of rows deleted
"""
if not iterable:
- return
+ return 0
sql = "DELETE FROM %s" % table
@@ -1239,7 +1319,9 @@ class SQLBaseStore(object):
if clauses:
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
- return txn.execute(sql, values)
+ txn.execute(sql, values)
+
+ return txn.rowcount
def _get_cache_dict(
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
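Because the _simple_delete_* helpers now return txn.rowcount instead of the raw cursor, callers can act on the number of rows actually removed. A hedged usage sketch (the table name and description here are illustrative):

    @defer.inlineCallbacks
    def _prune_relations(self, event_ids):
        num_deleted = yield self._simple_delete_many(
            table="event_relations",
            column="event_id",
            iterable=event_ids,
            keyvalues={},
            desc="prune_relations",
        )
        # Previously this returned the cursor; now it's a plain count.
        logger.info("Deleted %d relation rows", num_deleted)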
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 6092f600ba..eb329ebd8b 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -302,7 +302,7 @@ class ApplicationServiceTransactionWorkerStore(
event_ids = json.loads(entry["event_ids"])
- events = yield self._get_events(event_ids)
+ events = yield self.get_events_as_list(event_ids)
defer.returnValue(
AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
@@ -358,7 +358,7 @@ class ApplicationServiceTransactionWorkerStore(
"get_new_events_for_appservice", get_new_events_for_appservice_txn
)
- events = yield self._get_events(event_ids)
+ events = yield self.get_events_as_list(event_ids)
defer.returnValue((upper_bound, events))
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index fed4ea3610..9b0a99cb49 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -118,7 +118,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
defer.returnValue(count)
def get_new_device_msgs_for_remote(
- self, destination, last_stream_id, current_stream_id, limit=100
+ self, destination, last_stream_id, current_stream_id, limit
):
"""
Args:
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 956f876572..09e39c2c28 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -45,7 +45,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
return self.get_auth_chain_ids(
event_ids, include_given=include_given
- ).addCallback(self._get_events)
+ ).addCallback(self.get_events_as_list)
def get_auth_chain_ids(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
@@ -316,7 +316,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_list,
limit,
)
- .addCallback(self._get_events)
+ .addCallback(self.get_events_as_list)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
)
@@ -382,7 +382,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
- events = yield self._get_events(ids)
+ events = yield self.get_events_as_list(ids)
defer.returnValue(events)
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 7a7f841c6c..f9162be9b9 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -219,41 +220,11 @@ class EventsStore(
EventsWorkerStore,
BackgroundUpdateStore,
):
- EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
- EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
def __init__(self, db_conn, hs):
super(EventsStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
- self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
- )
- self.register_background_update_handler(
- self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
- self._background_reindex_fields_sender,
- )
-
- self.register_background_index_update(
- "event_contains_url_index",
- index_name="event_contains_url_index",
- table="events",
- columns=["room_id", "topological_ordering", "stream_ordering"],
- where_clause="contains_url = true AND outlier = false",
- )
-
- # an event_id index on event_search is useful for the purge_history
- # api. Plus it means we get to enforce some integrity with a UNIQUE
- # clause
- self.register_background_index_update(
- "event_search_event_id_idx",
- index_name="event_search_event_id_idx",
- table="event_search",
- columns=["event_id"],
- unique=True,
- psql_only=True,
- )
self._event_persist_queue = _EventPeristenceQueue()
-
self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
@@ -554,10 +525,18 @@ class EventsStore(
e_id for event in new_events for e_id in event.prev_event_ids()
)
- # Finally, remove any events which are prev_events of any existing events.
+ # Remove any events which are prev_events of any existing events.
existing_prevs = yield self._get_events_which_are_prevs(result)
result.difference_update(existing_prevs)
+ # Finally, handle the case where the new events have soft-failed prev
+ # events. If they do, we need to remove them and their prev events,
+ # otherwise we end up with dangling extremities.
+ existing_prevs = yield self._get_prevs_before_rejected(
+ e_id for event in new_events for e_id in event.prev_event_ids()
+ )
+ result.difference_update(existing_prevs)
+
defer.returnValue(result)
@defer.inlineCallbacks
@@ -573,12 +552,13 @@ class EventsStore(
"""
results = []
- def _get_events(txn, batch):
+ def _get_events_which_are_prevs_txn(txn, batch):
sql = """
- SELECT prev_event_id
+ SELECT prev_event_id, internal_metadata
FROM event_edges
INNER JOIN events USING (event_id)
LEFT JOIN rejections USING (event_id)
+ LEFT JOIN event_json USING (event_id)
WHERE
prev_event_id IN (%s)
AND NOT events.outlier
@@ -588,14 +568,86 @@ class EventsStore(
)
txn.execute(sql, batch)
- results.extend(r[0] for r in txn)
+ results.extend(
+ r[0]
+ for r in txn
+ if not json.loads(r[1]).get("soft_failed")
+ )
for chunk in batch_iter(event_ids, 100):
- yield self.runInteraction("_get_events_which_are_prevs", _get_events, chunk)
+ yield self.runInteraction(
+ "_get_events_which_are_prevs",
+ _get_events_which_are_prevs_txn,
+ chunk,
+ )
defer.returnValue(results)
@defer.inlineCallbacks
+ def _get_prevs_before_rejected(self, event_ids):
+ """Get soft-failed ancestors to remove from the extremities.
+
+ Given a set of events, find all those that have been soft-failed or
+ rejected. Returns those soft-failed/rejected events and their prev
+ events (whether soft-failed/rejected or not), and recurses up the
+ prev-event graph until it finds no more soft-failed/rejected events.
+
+ This is used to find extremities that are ancestors of new events, but
+ are separated by soft failed events.
+
+ Args:
+ event_ids (Iterable[str]): Events to find prev events for. Note
+ that these must have already been persisted.
+
+ Returns:
+ Deferred[set[str]]
+ """
+
+ # The set of event_ids to return. This includes all soft-failed events
+ # and their prev events.
+ existing_prevs = set()
+
+ def _get_prevs_before_rejected_txn(txn, batch):
+ to_recursively_check = batch
+
+ while to_recursively_check:
+ sql = """
+ SELECT
+ event_id, prev_event_id, internal_metadata,
+ rejections.event_id IS NOT NULL
+ FROM event_edges
+ INNER JOIN events USING (event_id)
+ LEFT JOIN rejections USING (event_id)
+ LEFT JOIN event_json USING (event_id)
+ WHERE
+ event_id IN (%s)
+ AND NOT events.outlier
+ """ % (
+ ",".join("?" for _ in to_recursively_check),
+ )
+
+ txn.execute(sql, to_recursively_check)
+ to_recursively_check = []
+
+ for event_id, prev_event_id, metadata, rejected in txn:
+ if prev_event_id in existing_prevs:
+ continue
+
+ soft_failed = json.loads(metadata).get("soft_failed")
+ if soft_failed or rejected:
+ to_recursively_check.append(prev_event_id)
+ existing_prevs.add(prev_event_id)
+
+ for chunk in batch_iter(event_ids, 100):
+ yield self.runInteraction(
+ "_get_prevs_before_rejected",
+ _get_prevs_before_rejected_txn,
+ chunk,
+ )
+
+ defer.returnValue(existing_prevs)
+
+ @defer.inlineCallbacks
def _get_new_state_after_events(
self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
):
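A worked example of what the recursion above buys us. Consider a room graph A <- B <- C where B was soft-failed: A stayed a forward extremity because B never replaced it. When C (whose prev event is B) is persisted, plain prev-event filtering skips B, so A would linger as a dangling extremity; _get_prevs_before_rejected repairs this:

    # Sketch (illustration only):
    #
    #   A <- B <- C        B is soft-failed; A is still a forward extremity
    #
    #   _get_prevs_before_rejected(["B"]):
    #     check B  -> soft-failed, so add its prev event "A" to the result
    #     check A  -> neither soft-failed nor rejected, recursion stops
    #     returns {"A"}
    #
    # {"A"} is then subtracted from the candidate extremities, leaving C
    # as the sole forward extremity, as it should be.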
@@ -1325,6 +1377,9 @@ class EventsStore(
txn, event.room_id, event.redacts
)
+ # Remove from relations table.
+ self._handle_redaction(txn, event.redacts)
+
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
@@ -1351,6 +1406,8 @@ class EventsStore(
# Insert into the event_search table.
self._store_guest_access_txn(txn, event)
+ self._handle_event_relations(txn, event)
+
# Insert into the room_memberships table.
self._store_room_members_txn(
txn,
@@ -1493,153 +1550,6 @@ class EventsStore(
ret = yield self.runInteraction("count_daily_active_rooms", _count)
defer.returnValue(ret)
- @defer.inlineCallbacks
- def _background_reindex_fields_sender(self, progress, batch_size):
- target_min_stream_id = progress["target_min_stream_id_inclusive"]
- max_stream_id = progress["max_stream_id_exclusive"]
- rows_inserted = progress.get("rows_inserted", 0)
-
- INSERT_CLUMP_SIZE = 1000
-
- def reindex_txn(txn):
- sql = (
- "SELECT stream_ordering, event_id, json FROM events"
- " INNER JOIN event_json USING (event_id)"
- " WHERE ? <= stream_ordering AND stream_ordering < ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- )
-
- txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
-
- rows = txn.fetchall()
- if not rows:
- return 0
-
- min_stream_id = rows[-1][0]
-
- update_rows = []
- for row in rows:
- try:
- event_id = row[1]
- event_json = json.loads(row[2])
- sender = event_json["sender"]
- content = event_json["content"]
-
- contains_url = "url" in content
- if contains_url:
- contains_url &= isinstance(content["url"], text_type)
- except (KeyError, AttributeError):
- # If the event is missing a necessary field then
- # skip over it.
- continue
-
- update_rows.append((sender, contains_url, event_id))
-
- sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
-
- for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
- clump = update_rows[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
-
- progress = {
- "target_min_stream_id_inclusive": target_min_stream_id,
- "max_stream_id_exclusive": min_stream_id,
- "rows_inserted": rows_inserted + len(rows),
- }
-
- self._background_update_progress_txn(
- txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
- )
-
- return len(rows)
-
- result = yield self.runInteraction(
- self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
- )
-
- if not result:
- yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def _background_reindex_origin_server_ts(self, progress, batch_size):
- target_min_stream_id = progress["target_min_stream_id_inclusive"]
- max_stream_id = progress["max_stream_id_exclusive"]
- rows_inserted = progress.get("rows_inserted", 0)
-
- INSERT_CLUMP_SIZE = 1000
-
- def reindex_search_txn(txn):
- sql = (
- "SELECT stream_ordering, event_id FROM events"
- " WHERE ? <= stream_ordering AND stream_ordering < ?"
- " ORDER BY stream_ordering DESC"
- " LIMIT ?"
- )
-
- txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
-
- rows = txn.fetchall()
- if not rows:
- return 0
-
- min_stream_id = rows[-1][0]
- event_ids = [row[1] for row in rows]
-
- rows_to_update = []
-
- chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
- for chunk in chunks:
- ev_rows = self._simple_select_many_txn(
- txn,
- table="event_json",
- column="event_id",
- iterable=chunk,
- retcols=["event_id", "json"],
- keyvalues={},
- )
-
- for row in ev_rows:
- event_id = row["event_id"]
- event_json = json.loads(row["json"])
- try:
- origin_server_ts = event_json["origin_server_ts"]
- except (KeyError, AttributeError):
- # If the event is missing a necessary field then
- # skip over it.
- continue
-
- rows_to_update.append((origin_server_ts, event_id))
-
- sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
-
- for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
- clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
-
- progress = {
- "target_min_stream_id_inclusive": target_min_stream_id,
- "max_stream_id_exclusive": min_stream_id,
- "rows_inserted": rows_inserted + len(rows_to_update),
- }
-
- self._background_update_progress_txn(
- txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
- )
-
- return len(rows_to_update)
-
- result = yield self.runInteraction(
- self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
- )
-
- if not result:
- yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
-
- defer.returnValue(result)
-
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
@@ -1655,10 +1565,11 @@ class EventsStore(
def get_all_new_forward_event_rows(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
@@ -1673,11 +1584,12 @@ class EventsStore(
sql = (
"SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering DESC"
@@ -1698,10 +1610,11 @@ class EventsStore(
def get_all_new_backfill_event_rows(txn):
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
@@ -1716,11 +1629,12 @@ class EventsStore(
sql = (
"SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
- " state_key, redacts"
+ " state_key, redacts, relates_to_id"
" FROM events AS e"
" INNER JOIN ex_outlier_stream USING (event_id)"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
+ " LEFT JOIN event_relations USING (event_id)"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
diff --git a/synapse/storage/events_bg_updates.py b/synapse/storage/events_bg_updates.py
new file mode 100644
index 0000000000..75c1935bf3
--- /dev/null
+++ b/synapse/storage/events_bg_updates.py
@@ -0,0 +1,401 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from six import text_type
+
+from canonicaljson import json
+
+from twisted.internet import defer
+
+from synapse.storage.background_updates import BackgroundUpdateStore
+
+logger = logging.getLogger(__name__)
+
+
+class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
+
+ EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
+ EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
+ DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
+
+ def __init__(self, db_conn, hs):
+ super(EventsBackgroundUpdatesStore, self).__init__(db_conn, hs)
+
+ self.register_background_update_handler(
+ self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
+ )
+ self.register_background_update_handler(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
+ self._background_reindex_fields_sender,
+ )
+
+ self.register_background_index_update(
+ "event_contains_url_index",
+ index_name="event_contains_url_index",
+ table="events",
+ columns=["room_id", "topological_ordering", "stream_ordering"],
+ where_clause="contains_url = true AND outlier = false",
+ )
+
+ # an event_id index on event_search is useful for the purge_history
+ # api. Plus it means we get to enforce some integrity with a UNIQUE
+ # clause
+ self.register_background_index_update(
+ "event_search_event_id_idx",
+ index_name="event_search_event_id_idx",
+ table="event_search",
+ columns=["event_id"],
+ unique=True,
+ psql_only=True,
+ )
+
+ self.register_background_update_handler(
+ self.DELETE_SOFT_FAILED_EXTREMITIES,
+ self._cleanup_extremities_bg_update,
+ )
+
+ @defer.inlineCallbacks
+ def _background_reindex_fields_sender(self, progress, batch_size):
+ target_min_stream_id = progress["target_min_stream_id_inclusive"]
+ max_stream_id = progress["max_stream_id_exclusive"]
+ rows_inserted = progress.get("rows_inserted", 0)
+
+ INSERT_CLUMP_SIZE = 1000
+
+ def reindex_txn(txn):
+ sql = (
+ "SELECT stream_ordering, event_id, json FROM events"
+ " INNER JOIN event_json USING (event_id)"
+ " WHERE ? <= stream_ordering AND stream_ordering < ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT ?"
+ )
+
+ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ min_stream_id = rows[-1][0]
+
+ update_rows = []
+ for row in rows:
+ try:
+ event_id = row[1]
+ event_json = json.loads(row[2])
+ sender = event_json["sender"]
+ content = event_json["content"]
+
+ contains_url = "url" in content
+ if contains_url:
+ contains_url &= isinstance(content["url"], text_type)
+ except (KeyError, AttributeError):
+ # If the event is missing a necessary field then
+ # skip over it.
+ continue
+
+ update_rows.append((sender, contains_url, event_id))
+
+ sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
+
+ for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
+ clump = update_rows[index : index + INSERT_CLUMP_SIZE]
+ txn.executemany(sql, clump)
+
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ "rows_inserted": rows_inserted + len(rows),
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
+ )
+
+ return len(rows)
+
+ result = yield self.runInteraction(
+ self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME)
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _background_reindex_origin_server_ts(self, progress, batch_size):
+ target_min_stream_id = progress["target_min_stream_id_inclusive"]
+ max_stream_id = progress["max_stream_id_exclusive"]
+ rows_inserted = progress.get("rows_inserted", 0)
+
+ INSERT_CLUMP_SIZE = 1000
+
+ def reindex_search_txn(txn):
+ sql = (
+ "SELECT stream_ordering, event_id FROM events"
+ " WHERE ? <= stream_ordering AND stream_ordering < ?"
+ " ORDER BY stream_ordering DESC"
+ " LIMIT ?"
+ )
+
+ txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ min_stream_id = rows[-1][0]
+ event_ids = [row[1] for row in rows]
+
+ rows_to_update = []
+
+ chunks = [event_ids[i : i + 100] for i in range(0, len(event_ids), 100)]
+ for chunk in chunks:
+ ev_rows = self._simple_select_many_txn(
+ txn,
+ table="event_json",
+ column="event_id",
+ iterable=chunk,
+ retcols=["event_id", "json"],
+ keyvalues={},
+ )
+
+ for row in ev_rows:
+ event_id = row["event_id"]
+ event_json = json.loads(row["json"])
+ try:
+ origin_server_ts = event_json["origin_server_ts"]
+ except (KeyError, AttributeError):
+ # If the event is missing a necessary field then
+ # skip over it.
+ continue
+
+ rows_to_update.append((origin_server_ts, event_id))
+
+ sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
+
+ for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
+ clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
+ txn.executemany(sql, clump)
+
+ progress = {
+ "target_min_stream_id_inclusive": target_min_stream_id,
+ "max_stream_id_exclusive": min_stream_id,
+ "rows_inserted": rows_inserted + len(rows_to_update),
+ }
+
+ self._background_update_progress_txn(
+ txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
+ )
+
+ return len(rows_to_update)
+
+ result = yield self.runInteraction(
+ self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
+ )
+
+ if not result:
+ yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _cleanup_extremities_bg_update(self, progress, batch_size):
+ """Background update to clean out extremities that should have been
+ deleted previously.
+
+ Mainly used to deal with the aftermath of #5269.
+ """
+
+ # This works by first copying all existing forward extremities into the
+ # `_extremities_to_check` table at start up, and then checking each
+ # event in that table whether we have any descendants that are not
+ # soft-failed/rejected. If that is the case then we delete that event
+ # from the forward extremities table.
+ #
+ # For efficiency, we do this in batches by recursively pulling out all
+ # descendants of a batch until we find the non soft-failed/rejected
+ # events, i.e. the set of descendants whose chain of prev events back
+ # to the batch of extremities are all soft-failed or rejected.
+ # Typically, we won't find any such events as extremities will rarely
+ # have any descendants, but if they do then we should delete those
+ # extremities.
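+        #
+        # As a worked example: if extremity E has a soft-failed child S,
+        # which in turn has a non-soft-failed child N, then N ends up in
+        # the set of non-rejected leaves, the walk back up the graph marks
+        # S and E as ancestors, and E (being in the original set) is
+        # deleted from the forward extremities table.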
+
+ def _cleanup_extremities_bg_update_txn(txn):
+ # The set of extremity event IDs that we're checking this round
+ original_set = set()
+
+ # A dict[str, set[str]] of event ID to their prev events.
+ graph = {}
+
+            # The set of descendants of the original set that are neither
+            # rejected nor soft-failed. Ancestors of these events should be
+            # removed from the forward extremities table.
+ non_rejected_leaves = set()
+
+ # Set of event IDs that have been soft failed, and for which we
+ # should check if they have descendants which haven't been soft
+ # failed.
+ soft_failed_events_to_lookup = set()
+
+ # First, we get `batch_size` events from the table, pulling out
+ # their successor events, if any, and the successor events'
+ # rejection status.
+ txn.execute(
+ """SELECT prev_event_id, event_id, internal_metadata,
+ rejections.event_id IS NOT NULL, events.outlier
+ FROM (
+ SELECT event_id AS prev_event_id
+ FROM _extremities_to_check
+ LIMIT ?
+ ) AS f
+ LEFT JOIN event_edges USING (prev_event_id)
+ LEFT JOIN events USING (event_id)
+ LEFT JOIN event_json USING (event_id)
+ LEFT JOIN rejections USING (event_id)
+ """, (batch_size,)
+ )
+
+ for prev_event_id, event_id, metadata, rejected, outlier in txn:
+ original_set.add(prev_event_id)
+
+ if not event_id or outlier:
+ # Common case where the forward extremity doesn't have any
+ # descendants.
+ continue
+
+ graph.setdefault(event_id, set()).add(prev_event_id)
+
+ soft_failed = False
+ if metadata:
+ soft_failed = json.loads(metadata).get("soft_failed")
+
+ if soft_failed or rejected:
+ soft_failed_events_to_lookup.add(event_id)
+ else:
+ non_rejected_leaves.add(event_id)
+
+ # Now we recursively check all the soft-failed descendants we
+ # found above in the same way, until we have nothing left to
+ # check.
+ while soft_failed_events_to_lookup:
+                # We only want to do 100 at a time, so we split the given list
+                # into two.
+ batch = list(soft_failed_events_to_lookup)
+ to_check, to_defer = batch[:100], batch[100:]
+ soft_failed_events_to_lookup = set(to_defer)
+
+ sql = """SELECT prev_event_id, event_id, internal_metadata,
+ rejections.event_id IS NOT NULL
+ FROM event_edges
+ INNER JOIN events USING (event_id)
+ INNER JOIN event_json USING (event_id)
+ LEFT JOIN rejections USING (event_id)
+ WHERE
+ prev_event_id IN (%s)
+ AND NOT events.outlier
+ """ % (
+ ",".join("?" for _ in to_check),
+ )
+ txn.execute(sql, to_check)
+
+ for prev_event_id, event_id, metadata, rejected in txn:
+ if event_id in graph:
+ # Already handled this event previously, but we still
+ # want to record the edge.
+ graph[event_id].add(prev_event_id)
+ continue
+
+ graph[event_id] = {prev_event_id}
+
+ soft_failed = json.loads(metadata).get("soft_failed")
+ if soft_failed or rejected:
+ soft_failed_events_to_lookup.add(event_id)
+ else:
+ non_rejected_leaves.add(event_id)
+
+ # We have a set of non-soft-failed descendants, so we recurse up
+ # the graph to find all ancestors and add them to the set of event
+ # IDs that we can delete from forward extremities table.
+ to_delete = set()
+ while non_rejected_leaves:
+ event_id = non_rejected_leaves.pop()
+ prev_event_ids = graph.get(event_id, set())
+ non_rejected_leaves.update(prev_event_ids)
+ to_delete.update(prev_event_ids)
+
+ to_delete.intersection_update(original_set)
+
+ deleted = self._simple_delete_many_txn(
+ txn=txn,
+ table="event_forward_extremities",
+ column="event_id",
+ iterable=to_delete,
+ keyvalues={},
+ )
+
+ logger.info(
+ "Deleted %d forward extremities of %d checked, to clean up #5269",
+ deleted,
+ len(original_set),
+ )
+
+ if deleted:
+ # We now need to invalidate the caches of these rooms
+ rows = self._simple_select_many_txn(
+ txn,
+ table="events",
+ column="event_id",
+ iterable=to_delete,
+ keyvalues={},
+ retcols=("room_id",)
+ )
+ room_ids = set(row["room_id"] for row in rows)
+ for room_id in room_ids:
+ txn.call_after(
+ self.get_latest_event_ids_in_room.invalidate,
+ (room_id,)
+ )
+
+ self._simple_delete_many_txn(
+ txn=txn,
+ table="_extremities_to_check",
+ column="event_id",
+ iterable=original_set,
+ keyvalues={},
+ )
+
+ return len(original_set)
+
+ num_handled = yield self.runInteraction(
+ "_cleanup_extremities_bg_update", _cleanup_extremities_bg_update_txn,
+ )
+
+ if not num_handled:
+ yield self._end_background_update(self.DELETE_SOFT_FAILED_EXTREMITIES)
+
+ def _drop_table_txn(txn):
+ txn.execute("DROP TABLE _extremities_to_check")
+
+ yield self.runInteraction(
+ "_cleanup_extremities_bg_update_drop_table",
+ _drop_table_txn,
+ )
+
+ defer.returnValue(num_handled)
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 663991a9b6..1782428048 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from __future__ import division
+
import itertools
import logging
from collections import namedtuple
@@ -103,7 +105,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred : A FrozenEvent.
"""
- events = yield self._get_events(
+ events = yield self.get_events_as_list(
[event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@@ -142,7 +144,7 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
Deferred : Dict from event_id to event.
"""
- events = yield self._get_events(
+ events = yield self.get_events_as_list(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
@@ -152,13 +154,32 @@ class EventsWorkerStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@defer.inlineCallbacks
- def _get_events(
+ def get_events_as_list(
self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
):
+ """Get events from the database and return in a list in the same order
+ as given by `event_ids` arg.
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ check_redacted (bool): If True, check if event has been redacted
+ and redact it.
+ get_prev_content (bool): If True and event is a state event,
+ include the previous states content in the unsigned field.
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+            Deferred[list[EventBase]]: List of events fetched from the
+            database. The events are in the same order as the `event_ids` arg.
+
+ Note that the returned list may be smaller than the list of event
+ IDs if not all events could be fetched.
+ """
+
if not event_ids:
defer.returnValue([])
@@ -202,21 +223,22 @@ class EventsWorkerStore(SQLBaseStore):
#
# The problem is that we end up at this point when an event
# which has been redacted is pulled out of the database by
- # _enqueue_events, because _enqueue_events needs to check the
- # redaction before it can cache the redacted event. So obviously,
- # calling get_event to get the redacted event out of the database
- # gives us an infinite loop.
+ # _enqueue_events, because _enqueue_events needs to check
+ # the redaction before it can cache the redacted event. So
+ # obviously, calling get_event to get the redacted event out
+ # of the database gives us an infinite loop.
#
- # For now (quick hack to fix during 0.99 release cycle), we just
- # go and fetch the relevant row from the db, but it would be nice
- # to think about how we can cache this rather than hit the db
- # every time we access a redaction event.
+ # For now (quick hack to fix during 0.99 release cycle), we
+ # just go and fetch the relevant row from the db, but it
+ # would be nice to think about how we can cache this rather
+ # than hit the db every time we access a redaction event.
#
# One thought on how to do this:
- # 1. split _get_events up so that it is divided into (a) get the
- # rawish event from the db/cache, (b) do the redaction/rejection
- # filtering
- # 2. have _get_event_from_row just call the first half of that
+ # 1. split get_events_as_list up so that it is divided into
+ # (a) get the rawish event from the db/cache, (b) do the
+ # redaction/rejection filtering
+ # 2. have _get_event_from_row just call the first half of
+ # that
orig_sender = yield self._simple_select_one_onecol(
table="events",
@@ -590,4 +612,79 @@ class EventsWorkerStore(SQLBaseStore):
return res
- return self.runInteraction("get_rejection_reasons", f)
+ return self.runInteraction("get_seen_events_with_rejections", f)
+
+ def _get_total_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_total_state_event_counts.
+ """
+ # We join against the events table as that has an index on room_id
+ sql = """
+ SELECT COUNT(*) FROM state_events
+ INNER JOIN events USING (room_id, event_id)
+ WHERE room_id=?
+ """
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_total_state_event_counts(self, room_id):
+ """
+ Gets the total number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.runInteraction(
+ "get_total_state_event_counts",
+ self._get_total_state_event_counts_txn, room_id
+ )
+
+ def _get_current_state_event_counts_txn(self, txn, room_id):
+ """
+ See get_current_state_event_counts.
+ """
+ sql = "SELECT COUNT(*) FROM current_state_events WHERE room_id=?"
+ txn.execute(sql, (room_id,))
+ row = txn.fetchone()
+ return row[0] if row else 0
+
+ def get_current_state_event_counts(self, room_id):
+ """
+ Gets the current number of state events in a room.
+
+ Args:
+ room_id (str)
+
+ Returns:
+ Deferred[int]
+ """
+ return self.runInteraction(
+ "get_current_state_event_counts",
+ self._get_current_state_event_counts_txn, room_id
+ )
+
+ @defer.inlineCallbacks
+ def get_room_complexity(self, room_id):
+ """
+        Get a rough approximation of the complexity of the room. This is used
+        by remote servers to decide whether they wish to join the room or not.
+        A higher complexity value indicates that being in the room will consume
+        more resources.
+
+ Args:
+ room_id (str)
+
+ Returns:
+            Deferred[dict[str:int]]: map from complexity version to complexity value.
+ """
+ state_events = yield self.get_current_state_event_counts(room_id)
+
+ # Call this one "v1", so we can introduce new ones as we want to develop
+ # it.
+ complexity_v1 = round(state_events / 500, 2)
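+        # For example, a room with 1250 current state events has a v1
+        # complexity of 2.5.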
+
+ defer.returnValue({"v1": complexity_v1})
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 7036541792..5300720dbb 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -19,6 +19,7 @@ import logging
import six
+import attr
from signedjson.key import decode_verify_key_bytes
from synapse.util import batch_iter
@@ -36,6 +37,12 @@ else:
db_binary_type = memoryview
+@attr.s(slots=True, frozen=True)
+class FetchKeyResult(object):
+ verify_key = attr.ib() # VerifyKey: the key itself
+ valid_until_ts = attr.ib() # int: the timestamp (in ms) until which we can use this key
+
+
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys
"""
@@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore):
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
- Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
- map from (server_name, key_id) -> VerifyKey, or None if the key is
+ Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
+ map from (server_name, key_id) -> FetchKeyResult, or None if the key is
unknown
"""
keys = {}
@@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore):
# batch_iter always returns tuples so it's safe to do len(batch)
sql = (
- "SELECT server_name, key_id, verify_key FROM server_signature_keys "
- "WHERE 1=0"
+ "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
+ "FROM server_signature_keys WHERE 1=0"
) + " OR (server_name=? AND key_id=?)" * len(batch)
txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
for row in txn:
- server_name, key_id, key_bytes = row
- keys[(server_name, key_id)] = decode_verify_key_bytes(
- key_id, bytes(key_bytes)
+ server_name, key_id, key_bytes, ts_valid_until_ms = row
+ res = FetchKeyResult(
+ verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
+ valid_until_ts=ts_valid_until_ms,
)
+ keys[(server_name, key_id)] = res
def _txn(txn):
for batch in batch_iter(server_name_and_key_ids, 50):
@@ -84,38 +93,53 @@ class KeyStore(SQLBaseStore):
return self.runInteraction("get_server_verify_keys", _txn)
- def store_server_verify_key(
- self, server_name, from_server, time_now_ms, verify_key
- ):
- """Stores a NACL verification key for the given server.
+ def store_server_verify_keys(self, from_server, ts_added_ms, verify_keys):
+ """Stores NACL verification keys for remote servers.
Args:
- server_name (str): The name of the server.
- from_server (str): Where the verification key was looked up
- time_now_ms (int): The time now in milliseconds
- verify_key (nacl.signing.VerifyKey): The NACL verify key.
+ from_server (str): Where the verification keys were looked up
+ ts_added_ms (int): The time to record that the key was added
+ verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
+ keys to be stored. Each entry is a triplet of
+ (server_name, key_id, key).
"""
- key_id = "%s:%s" % (verify_key.alg, verify_key.version)
-
- # XXX fix this to not need a lock (#3819)
- def _txn(txn):
- self._simple_upsert_txn(
- txn,
- table="server_signature_keys",
- keyvalues={"server_name": server_name, "key_id": key_id},
- values={
- "from_server": from_server,
- "ts_added_ms": time_now_ms,
- "verify_key": db_binary_type(verify_key.encode()),
- },
+ key_values = []
+ value_values = []
+ invalidations = []
+ for server_name, key_id, fetch_result in verify_keys:
+ key_values.append((server_name, key_id))
+ value_values.append(
+ (
+ from_server,
+ ts_added_ms,
+ fetch_result.valid_until_ts,
+ db_binary_type(fetch_result.verify_key.encode()),
+ )
)
# invalidate takes a tuple corresponding to the params of
# _get_server_verify_key. _get_server_verify_key only takes one
# param, which is itself the 2-tuple (server_name, key_id).
- txn.call_after(
- self._get_server_verify_key.invalidate, ((server_name, key_id),)
- )
-
- return self.runInteraction("store_server_verify_key", _txn)
+ invalidations.append((server_name, key_id))
+
+ def _invalidate(res):
+ f = self._get_server_verify_key.invalidate
+ for i in invalidations:
+ f((i, ))
+ return res
+
+ return self.runInteraction(
+ "store_server_verify_keys",
+ self._simple_upsert_many_txn,
+ table="server_signature_keys",
+ key_names=("server_name", "key_id"),
+ key_values=key_values,
+ value_names=(
+ "from_server",
+ "ts_added_ms",
+ "ts_valid_until_ms",
+ "verify_key",
+ ),
+ value_values=value_values,
+ ).addCallback(_invalidate)
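+
+    # Illustrative usage (example values only):
+    #     yield store.store_server_verify_keys(
+    #         "matrix.org", ts_now_ms,
+    #         [(server_name, key_id, FetchKeyResult(verify_key, valid_until_ts))],
+    #     )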
def store_server_keys_json(
self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 03a06a83d6..4cf159ba81 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -725,17 +727,7 @@ class RegistrationStore(
raise StoreError(400, "User ID already taken.", errcode=Codes.USER_IN_USE)
if self._account_validity.enabled:
- now_ms = self.clock.time_msec()
- expiration_ts = now_ms + self._account_validity.period
- self._simple_insert_txn(
- txn,
- "account_validity",
- values={
- "user_id": user_id,
- "expiration_ts_ms": expiration_ts,
- "email_sent": False,
- }
- )
+ self.set_expiration_date_for_user_txn(txn, user_id)
if token:
# it's possible for this to get a conflict, but only for a single user
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
new file mode 100644
index 0000000000..4c83800cca
--- /dev/null
+++ b/synapse/storage/relations.py
@@ -0,0 +1,476 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+import attr
+
+from twisted.internet import defer
+
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.stream import generate_pagination_where_clause
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s
+class PaginationChunk(object):
+ """Returned by relation pagination APIs.
+
+ Attributes:
+ chunk (list): The rows returned by pagination
+ next_batch (Any|None): Token to fetch next set of results with, if
+ None then there are no more results.
+ prev_batch (Any|None): Token to fetch previous set of results with, if
+ None then there are no previous results.
+ """
+
+ chunk = attr.ib()
+ next_batch = attr.ib(default=None)
+ prev_batch = attr.ib(default=None)
+
+ def to_dict(self):
+ d = {"chunk": self.chunk}
+
+ if self.next_batch:
+ d["next_batch"] = self.next_batch.to_string()
+
+ if self.prev_batch:
+ d["prev_batch"] = self.prev_batch.to_string()
+
+ return d
+
+
+@attr.s(frozen=True, slots=True)
+class RelationPaginationToken(object):
+ """Pagination token for relation pagination API.
+
+    As the results are ordered by topological ordering, we can use the
+ `topological_ordering` and `stream_ordering` fields of the events at the
+ boundaries of the chunk as pagination tokens.
+
+ Attributes:
+ topological (int): The topological ordering of the boundary event
+ stream (int): The stream ordering of the boundary event.
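+
+        Tokens serialise as "<topological>-<stream>", e.g. "10-514".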
+ """
+
+ topological = attr.ib()
+ stream = attr.ib()
+
+ @staticmethod
+ def from_string(string):
+ try:
+ t, s = string.split("-")
+ return RelationPaginationToken(int(t), int(s))
+ except ValueError:
+ raise SynapseError(400, "Invalid token")
+
+ def to_string(self):
+ return "%d-%d" % (self.topological, self.stream)
+
+ def as_tuple(self):
+ return attr.astuple(self)
+
+
+@attr.s(frozen=True, slots=True)
+class AggregationPaginationToken(object):
+ """Pagination token for relation aggregation pagination API.
+
+    As the results are ordered by count and then by MAX(stream_ordering) of
+    the aggregation groups, we can just use those values as our pagination
+    token.
+
+ Attributes:
+        count (int): The count of relations in the boundary group.
+ stream (int): The MAX stream ordering in the boundary group.
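+
+        Tokens serialise as "<count>-<stream>", e.g. "3-204".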
+ """
+
+ count = attr.ib()
+ stream = attr.ib()
+
+ @staticmethod
+ def from_string(string):
+ try:
+ c, s = string.split("-")
+ return AggregationPaginationToken(int(c), int(s))
+ except ValueError:
+ raise SynapseError(400, "Invalid token")
+
+ def to_string(self):
+ return "%d-%d" % (self.count, self.stream)
+
+ def as_tuple(self):
+ return attr.astuple(self)
+
+
+class RelationsWorkerStore(SQLBaseStore):
+ @cached(tree=True)
+ def get_relations_for_event(
+ self,
+ event_id,
+ relation_type=None,
+ event_type=None,
+ aggregation_key=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of relations for an event, ordered by topological ordering.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ relation_type (str|None): Only fetch events with this relation
+ type, if given.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ aggregation_key (str|None): Only fetch events with this aggregation
+ key, if given.
+ limit (int): Only fetch the most recent `limit` events.
+ direction (str): Whether to fetch the most recent first (`"b"`) or
+ the oldest first (`"f"`).
+ from_token (RelationPaginationToken|None): Fetch rows from the given
+ token, or from the start if None.
+ to_token (RelationPaginationToken|None): Fetch rows up to the given
+ token, or up to the end if None.
+
+ Returns:
+            Deferred[PaginationChunk]: List of event IDs that match the
+            relations requested. The rows are of the form `{"event_id": "..."}`.
+ """
+
+ where_clause = ["relates_to_id = ?"]
+ where_args = [event_id]
+
+ if relation_type is not None:
+ where_clause.append("relation_type = ?")
+ where_args.append(relation_type)
+
+ if event_type is not None:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ if aggregation_key:
+ where_clause.append("aggregation_key = ?")
+ where_args.append(aggregation_key)
+
+ pagination_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if pagination_clause:
+ where_clause.append(pagination_clause)
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ sql = """
+ SELECT event_id, topological_ordering, stream_ordering
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE %s
+ ORDER BY topological_ordering %s, stream_ordering %s
+ LIMIT ?
+ """ % (
+ " AND ".join(where_clause),
+ order,
+ order,
+ )
+
+ def _get_recent_references_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ last_topo_id = None
+ last_stream_id = None
+ events = []
+ for row in txn:
+ events.append({"event_id": row[0]})
+ last_topo_id = row[1]
+ last_stream_id = row[2]
+
+ next_batch = None
+ if len(events) > limit and last_topo_id and last_stream_id:
+ next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.runInteraction(
+ "get_recent_references_for_event", _get_recent_references_for_event_txn
+ )
+
+ @cached(tree=True)
+ def get_aggregation_groups_for_event(
+ self,
+ event_id,
+ event_type=None,
+ limit=5,
+ direction="b",
+ from_token=None,
+ to_token=None,
+ ):
+ """Get a list of annotations on the event, grouped by event type and
+ aggregation key, sorted by count.
+
+        This is used, e.g., to get what reactions have happened on an event
+        and how many of each.
+
+ Args:
+ event_id (str): Fetch events that relate to this event ID.
+ event_type (str|None): Only fetch events with this event type, if
+ given.
+ limit (int): Only fetch the `limit` groups.
+ direction (str): Whether to fetch the highest count first (`"b"`) or
+ the lowest count first (`"f"`).
+ from_token (AggregationPaginationToken|None): Fetch rows from the
+ given token, or from the start if None.
+ to_token (AggregationPaginationToken|None): Fetch rows up to the
+ given token, or up to the end if None.
+
+
+ Returns:
+ Deferred[PaginationChunk]: List of groups of annotations that
+ match. Each row is a dict with `type`, `key` and `count` fields.
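+            For example (illustrative values), a row might look like
+            `{"type": "m.reaction", "key": "👍", "count": 3}`.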
+ """
+
+ where_clause = ["relates_to_id = ?", "relation_type = ?"]
+ where_args = [event_id, RelationTypes.ANNOTATION]
+
+ if event_type:
+ where_clause.append("type = ?")
+ where_args.append(event_type)
+
+ having_clause = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("COUNT(*)", "MAX(stream_ordering)"),
+ from_token=attr.astuple(from_token) if from_token else None,
+ to_token=attr.astuple(to_token) if to_token else None,
+ engine=self.database_engine,
+ )
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ if having_clause:
+ having_clause = "HAVING " + having_clause
+ else:
+ having_clause = ""
+
+ sql = """
+ SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+ FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE {where_clause}
+ GROUP BY relation_type, type, aggregation_key
+ {having_clause}
+ ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ LIMIT ?
+ """.format(
+ where_clause=" AND ".join(where_clause),
+ order=order,
+ having_clause=having_clause,
+ )
+
+ def _get_aggregation_groups_for_event_txn(txn):
+ txn.execute(sql, where_args + [limit + 1])
+
+ next_batch = None
+ events = []
+ for row in txn:
+ events.append({"type": row[0], "key": row[1], "count": row[2]})
+ next_batch = AggregationPaginationToken(row[2], row[3])
+
+ if len(events) <= limit:
+ next_batch = None
+
+ return PaginationChunk(
+ chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ )
+
+ return self.runInteraction(
+ "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+ )
+
+ @cachedInlineCallbacks()
+ def get_applicable_edit(self, event_id):
+ """Get the most recent edit (if any) that has happened for the given
+ event.
+
+ Correctly handles checking whether edits were allowed to happen.
+
+ Args:
+ event_id (str): The original event ID
+
+ Returns:
+ Deferred[EventBase|None]: Returns the most recent edit, if any.
+ """
+
+ # We only allow edits for `m.room.message` events that have the same sender
+ # and event type. We can't assert these things during regular event auth so
+ # we have to do the checks post hoc.
+
+ # Fetches latest edit that has the same type and sender as the
+ # original, and is an `m.room.message`.
+ sql = """
+ SELECT edit.event_id FROM events AS edit
+ INNER JOIN event_relations USING (event_id)
+ INNER JOIN events AS original ON
+ original.event_id = relates_to_id
+ AND edit.type = original.type
+ AND edit.sender = original.sender
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND edit.type = 'm.room.message'
+ ORDER by edit.origin_server_ts DESC, edit.event_id DESC
+ LIMIT 1
+ """
+
+ def _get_applicable_edit_txn(txn):
+ txn.execute(sql, (event_id, RelationTypes.REPLACE))
+ row = txn.fetchone()
+ if row:
+ return row[0]
+
+ edit_id = yield self.runInteraction(
+ "get_applicable_edit", _get_applicable_edit_txn
+ )
+
+ if not edit_id:
+ return
+
+ edit_event = yield self.get_event(edit_id, allow_none=True)
+ defer.returnValue(edit_event)
+
+ def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
+ """Check if a user has already annotated an event with the same key
+ (e.g. already liked an event).
+
+ Args:
+ parent_id (str): The event being annotated
+ event_type (str): The event type of the annotation
+ aggregation_key (str): The aggregation key of the annotation
+ sender (str): The sender of the annotation
+
+ Returns:
+ Deferred[bool]
+ """
+
+ sql = """
+ SELECT 1 FROM event_relations
+ INNER JOIN events USING (event_id)
+ WHERE
+ relates_to_id = ?
+ AND relation_type = ?
+ AND type = ?
+ AND sender = ?
+ AND aggregation_key = ?
+ LIMIT 1;
+ """
+
+ def _get_if_user_has_annotated_event(txn):
+ txn.execute(
+ sql,
+ (
+ parent_id,
+ RelationTypes.ANNOTATION,
+ event_type,
+ sender,
+ aggregation_key,
+ ),
+ )
+
+ return bool(txn.fetchone())
+
+ return self.runInteraction(
+ "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
+ )
+
+
+class RelationsStore(RelationsWorkerStore):
+ def _handle_event_relations(self, txn, event):
+ """Handles inserting relation data during peristence of events
+
+ Args:
+ txn
+ event (EventBase)
+ """
+ relation = event.content.get("m.relates_to")
+ if not relation:
+ # No relations
+ return
+
+ rel_type = relation.get("rel_type")
+ if rel_type not in (
+ RelationTypes.ANNOTATION,
+ RelationTypes.REFERENCE,
+ RelationTypes.REPLACE,
+ ):
+ # Unknown relation type
+ return
+
+ parent_id = relation.get("event_id")
+ if not parent_id:
+ # Invalid relation
+ return
+
+ aggregation_key = relation.get("key")
+
+ self._simple_insert_txn(
+ txn,
+ table="event_relations",
+ values={
+ "event_id": event.event_id,
+ "relates_to_id": parent_id,
+ "relation_type": rel_type,
+ "aggregation_key": aggregation_key,
+ },
+ )
+
+ txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
+ txn.call_after(
+ self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
+ )
+
+ if rel_type == RelationTypes.REPLACE:
+ txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))
+
+ def _handle_redaction(self, txn, redacted_event_id):
+ """Handles receiving a redaction and checking whether we need to remove
+ any redacted relations from the database.
+
+ Args:
+ txn
+ redacted_event_id (str): The event that was redacted.
+ """
+
+ self._simple_delete_txn(
+ txn,
+ table="event_relations",
+ keyvalues={
+ "event_id": redacted_event_id,
+ }
+ )
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 57df17bcc2..7617913326 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -142,6 +142,27 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return self.runInteraction("get_room_summary", _get_room_summary_txn)
+ def _get_user_counts_in_room_txn(self, txn, room_id):
+ """
+ Get the user count in a room by membership.
+
+ Args:
+ room_id (str)
+ membership (Membership)
+
+ Returns:
+ Deferred[int]
+ """
+ sql = """
+ SELECT m.membership, count(*) FROM room_memberships as m
+ INNER JOIN current_state_events as c USING(event_id)
+ WHERE c.type = 'm.room.member' AND c.room_id = ?
+ GROUP BY m.membership
+ """
+
+ txn.execute(sql, (room_id,))
+ return {row[0]: row[1] for row in txn}
+
@cached()
def get_invited_rooms_for_user(self, user_id):
""" Get all the rooms the user is invited to
diff --git a/synapse/storage/schema/delta/54/account_validity.sql b/synapse/storage/schema/delta/54/account_validity_with_renewal.sql
index 2357626000..0adb2ad55e 100644
--- a/synapse/storage/schema/delta/54/account_validity.sql
+++ b/synapse/storage/schema/delta/54/account_validity_with_renewal.sql
@@ -13,6 +13,9 @@
* limitations under the License.
*/
+-- We previously changed the schema for this table without renaming the file, which means
+-- that some databases might still be using the old schema. This ensures Synapse uses the
+-- right schema for the table.
DROP TABLE IF EXISTS account_validity;
-- Track what users are in public rooms.
diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
new file mode 100644
index 0000000000..c01aa9d2d9
--- /dev/null
+++ b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
@@ -0,0 +1,23 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* The time until which we can use this key, before we have to refresh it. */
+ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT;
+
+UPDATE server_signature_keys SET ts_valid_until_ms = (
+ SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE
+ skj.server_name = server_signature_keys.server_name AND
+ skj.key_id = server_signature_keys.key_id
+);
diff --git a/synapse/storage/schema/delta/54/delete_forward_extremities.sql b/synapse/storage/schema/delta/54/delete_forward_extremities.sql
new file mode 100644
index 0000000000..b062ec840c
--- /dev/null
+++ b/synapse/storage/schema/delta/54/delete_forward_extremities.sql
@@ -0,0 +1,23 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Start a background job to cleanup extremities that were incorrectly added
+-- by bug #5269.
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('delete_soft_failed_extremities', '{}');
+
+DROP TABLE IF EXISTS _extremities_to_check; -- To make this delta schema file idempotent.
+CREATE TABLE _extremities_to_check AS SELECT event_id FROM event_forward_extremities;
+CREATE INDEX _extremities_to_check_id ON _extremities_to_check(event_id);
diff --git a/synapse/storage/schema/delta/54/relations.sql b/synapse/storage/schema/delta/54/relations.sql
new file mode 100644
index 0000000000..134862b870
--- /dev/null
+++ b/synapse/storage/schema/delta/54/relations.sql
@@ -0,0 +1,27 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Tracks related events, like reactions, replies, edits, etc. Note that things
+-- in this table are not necessarily "valid", e.g. it may contain edits from
+-- people who don't have the power to edit other people's events.
+CREATE TABLE IF NOT EXISTS event_relations (
+ event_id TEXT NOT NULL,
+ relates_to_id TEXT NOT NULL,
+ relation_type TEXT NOT NULL,
+ aggregation_key TEXT
+);
+
+CREATE UNIQUE INDEX event_relations_id ON event_relations(event_id);
+CREATE INDEX event_relations_relates ON event_relations(relates_to_id, relation_type, aggregation_key);
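+
+-- For example (illustrative values), a 👍 reaction to an event $parent would be
+-- stored as ('$reaction_event_id', '$parent', 'm.annotation', '👍').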
diff --git a/synapse/storage/schema/delta/54/stats.sql b/synapse/storage/schema/delta/54/stats.sql
new file mode 100644
index 0000000000..652e58308e
--- /dev/null
+++ b/synapse/storage/schema/delta/54/stats.sql
@@ -0,0 +1,80 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stats_stream_pos (
+ Lock CHAR(1) NOT NULL DEFAULT 'X' UNIQUE, -- Makes sure this table only has one row.
+ stream_id BIGINT,
+ CHECK (Lock='X')
+);
+
+INSERT INTO stats_stream_pos (stream_id) VALUES (null);
+
+CREATE TABLE user_stats (
+ user_id TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ bucket_size INT NOT NULL,
+ public_rooms INT NOT NULL,
+ private_rooms INT NOT NULL
+);
+
+CREATE UNIQUE INDEX user_stats_user_ts ON user_stats(user_id, ts);
+
+CREATE TABLE room_stats (
+ room_id TEXT NOT NULL,
+ ts BIGINT NOT NULL,
+ bucket_size INT NOT NULL,
+ current_state_events INT NOT NULL,
+ joined_members INT NOT NULL,
+ invited_members INT NOT NULL,
+ left_members INT NOT NULL,
+ banned_members INT NOT NULL,
+ state_events INT NOT NULL
+);
+
+CREATE UNIQUE INDEX room_stats_room_ts ON room_stats(room_id, ts);
+
+-- cache of current room state; useful for the publicRooms list
+CREATE TABLE room_state (
+ room_id TEXT NOT NULL,
+ join_rules TEXT,
+ history_visibility TEXT,
+ encryption TEXT,
+ name TEXT,
+ topic TEXT,
+ avatar TEXT,
+ canonical_alias TEXT
+ -- get aliases straight from the right table
+);
+
+CREATE UNIQUE INDEX room_state_room ON room_state(room_id);
+
+CREATE TABLE room_stats_earliest_token (
+ room_id TEXT NOT NULL,
+ token BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX room_stats_earliest_token_idx ON room_stats_earliest_token(room_id);
+
+-- Set up staging tables
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('populate_stats_createtables', '{}');
+
+-- Run through each room and update stats
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_stats_process_rooms', '{}', 'populate_stats_createtables');
+
+-- Clean up staging tables
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('populate_stats_cleanup', '{}', 'populate_stats_process_rooms');
diff --git a/synapse/storage/schema/delta/54/stats2.sql b/synapse/storage/schema/delta/54/stats2.sql
new file mode 100644
index 0000000000..3b2d48447f
--- /dev/null
+++ b/synapse/storage/schema/delta/54/stats2.sql
@@ -0,0 +1,28 @@
+/* Copyright 2019 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This delta file gets run after the `54/stats.sql` delta.
+
+-- We want to add some indices to the temporary stats table, so we re-insert
+-- 'populate_stats_createtables' if we are still processing the rooms update.
+INSERT INTO background_updates (update_name, progress_json)
+ SELECT 'populate_stats_createtables', '{}'
+ WHERE
+ 'populate_stats_process_rooms' IN (
+ SELECT update_name FROM background_updates
+ )
+ AND 'populate_stats_createtables' NOT IN ( -- don't insert if already exists
+ SELECT update_name FROM background_updates
+ );
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 226f8f1b7e..ff49eaae02 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -460,7 +460,7 @@ class SearchStore(BackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self._get_events([r["event_id"] for r in results])
+ events = yield self.get_events_as_list([r["event_id"] for r in results])
event_map = {ev.event_id: ev for ev in events}
@@ -605,7 +605,7 @@ class SearchStore(BackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
- events = yield self._get_events([r["event_id"] for r in results])
+ events = yield self.get_events_as_list([r["event_id"] for r in results])
event_map = {ev.event_id: ev for ev in events}
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
index 56e42f583d..5fdb442104 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/state_deltas.py
@@ -22,6 +22,24 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id):
+ """Fetch a list of room state changes since the given stream id
+
+ Each entry in the result contains the following fields:
+ - stream_id (int)
+ - room_id (str)
+ - type (str): event type
+ - state_key (str):
+ - event_id (str|None): new event_id for this state key. None if the
+ state has been deleted.
+ - prev_event_id (str|None): previous event_id for this state key. None
+ if it's new state.
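+
+        For example (illustrative values), a membership change might look like:
+            {"stream_id": 5, "room_id": "!room:hs", "type": "m.room.member",
+             "state_key": "@user:hs", "event_id": "$new", "prev_event_id": "$old"}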
+
+ Args:
+ prev_stream_id (int): point to get changes since (exclusive)
+
+ Returns:
+ Deferred[list[dict]]: results
+ """
prev_stream_id = int(prev_stream_id)
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
@@ -66,10 +84,16 @@ class StateDeltasStore(SQLBaseStore):
"get_current_state_deltas", get_current_state_deltas_txn
)
- def get_max_stream_id_in_current_state_deltas(self):
- return self._simple_select_one_onecol(
+ def _get_max_stream_id_in_current_state_deltas_txn(self, txn):
+ return self._simple_select_one_onecol_txn(
+ txn,
table="current_state_delta_stream",
keyvalues={},
retcol="COALESCE(MAX(stream_id), -1)",
- desc="get_max_stream_id_in_current_state_deltas",
+ )
+
+ def get_max_stream_id_in_current_state_deltas(self):
+ return self.runInteraction(
+ "get_max_stream_id_in_current_state_deltas",
+ self._get_max_stream_id_in_current_state_deltas_txn,
)
diff --git a/synapse/storage/stats.py b/synapse/storage/stats.py
new file mode 100644
index 0000000000..ff266b09b0
--- /dev/null
+++ b/synapse/storage/stats.py
@@ -0,0 +1,468 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018, 2019 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.storage.prepare_database import get_statements
+from synapse.storage.state_deltas import StateDeltasStore
+from synapse.util.caches.descriptors import cached
+
+logger = logging.getLogger(__name__)
+
+# these fields track absolutes (e.g. total number of rooms on the server)
+ABSOLUTE_STATS_FIELDS = {
+ "room": (
+ "current_state_events",
+ "joined_members",
+ "invited_members",
+ "left_members",
+ "banned_members",
+ "state_events",
+ ),
+ "user": ("public_rooms", "private_rooms"),
+}
+
+TYPE_TO_ROOM = {"room": ("room_stats", "room_id"), "user": ("user_stats", "user_id")}
+
+TEMP_TABLE = "_temp_populate_stats"
+
+
+class StatsStore(StateDeltasStore):
+ def __init__(self, db_conn, hs):
+ super(StatsStore, self).__init__(db_conn, hs)
+
+ self.server_name = hs.hostname
+ self.clock = self.hs.get_clock()
+ self.stats_enabled = hs.config.stats_enabled
+ self.stats_bucket_size = hs.config.stats_bucket_size
+
+ self.register_background_update_handler(
+ "populate_stats_createtables", self._populate_stats_createtables
+ )
+ self.register_background_update_handler(
+ "populate_stats_process_rooms", self._populate_stats_process_rooms
+ )
+ self.register_background_update_handler(
+ "populate_stats_cleanup", self._populate_stats_cleanup
+ )
+
+ @defer.inlineCallbacks
+ def _populate_stats_createtables(self, progress, batch_size):
+
+ if not self.stats_enabled:
+ yield self._end_background_update("populate_stats_createtables")
+ defer.returnValue(1)
+
+ # Get all the rooms that we want to process.
+ def _make_staging_area(txn):
+ # Create the temporary tables
+ stmts = get_statements("""
+                -- We just recreate the table; we'll be reinserting the
+                -- correct entries again later anyway.
+ DROP TABLE IF EXISTS {temp}_rooms;
+
+ CREATE TABLE IF NOT EXISTS {temp}_rooms(
+ room_id TEXT NOT NULL,
+ events BIGINT NOT NULL
+ );
+
+ CREATE INDEX {temp}_rooms_events
+ ON {temp}_rooms(events);
+ CREATE INDEX {temp}_rooms_id
+ ON {temp}_rooms(room_id);
+ """.format(temp=TEMP_TABLE).splitlines())
+
+ for statement in stmts:
+ txn.execute(statement)
+
+ sql = (
+ "CREATE TABLE IF NOT EXISTS "
+ + TEMP_TABLE
+ + "_position(position TEXT NOT NULL)"
+ )
+ txn.execute(sql)
+
+            # Get rooms we want to process from the database, only adding
+            # those that we haven't already processed (i.e. those not in
+            # room_stats_earliest_token)
+ sql = """
+ INSERT INTO %s_rooms (room_id, events)
+ SELECT c.room_id, count(*) FROM current_state_events AS c
+ LEFT JOIN room_stats_earliest_token AS t USING (room_id)
+ WHERE t.room_id IS NULL
+ GROUP BY c.room_id
+ """ % (TEMP_TABLE,)
+ txn.execute(sql)
+
+ new_pos = yield self.get_max_stream_id_in_current_state_deltas()
+ yield self.runInteraction("populate_stats_temp_build", _make_staging_area)
+ yield self._simple_insert(TEMP_TABLE + "_position", {"position": new_pos})
+ self.get_earliest_token_for_room_stats.invalidate_all()
+
+ yield self._end_background_update("populate_stats_createtables")
+ defer.returnValue(1)
+
+ @defer.inlineCallbacks
+ def _populate_stats_cleanup(self, progress, batch_size):
+ """
+        Update the stats stream position, then clean up the old tables.
+ """
+ if not self.stats_enabled:
+ yield self._end_background_update("populate_stats_cleanup")
+ defer.returnValue(1)
+
+ position = yield self._simple_select_one_onecol(
+ TEMP_TABLE + "_position", None, "position"
+ )
+ yield self.update_stats_stream_pos(position)
+
+ def _delete_staging_area(txn):
+ txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
+ txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
+
+ yield self.runInteraction("populate_stats_cleanup", _delete_staging_area)
+
+ yield self._end_background_update("populate_stats_cleanup")
+ defer.returnValue(1)
+
+ @defer.inlineCallbacks
+ def _populate_stats_process_rooms(self, progress, batch_size):
+
+ if not self.stats_enabled:
+ yield self._end_background_update("populate_stats_process_rooms")
+ defer.returnValue(1)
+
+        # If we don't have a progress field yet, delete everything.
+ if not progress:
+ yield self.delete_all_stats()
+
+ def _get_next_batch(txn):
+ # Only fetch 250 rooms, so we don't fetch too many at once, even
+            # if those 250 rooms have fewer than batch_size state events.
+ sql = """
+ SELECT room_id, events FROM %s_rooms
+ ORDER BY events DESC
+ LIMIT 250
+ """ % (
+ TEMP_TABLE,
+ )
+ txn.execute(sql)
+ rooms_to_work_on = txn.fetchall()
+
+ if not rooms_to_work_on:
+ return None
+
+            # Get how many rooms are left to process, so we can report how far
+            # through the processing we are.
+ txn.execute("SELECT COUNT(*) FROM " + TEMP_TABLE + "_rooms")
+ progress["remaining"] = txn.fetchone()[0]
+
+ return rooms_to_work_on
+
+ rooms_to_work_on = yield self.runInteraction(
+ "populate_stats_temp_read", _get_next_batch
+ )
+
+        # No more rooms left -- the background update is complete.
+ if not rooms_to_work_on:
+ yield self._end_background_update("populate_stats_process_rooms")
+ defer.returnValue(1)
+
+ logger.info(
+ "Processing the next %d rooms of %d remaining",
+ len(rooms_to_work_on), progress["remaining"],
+ )
+
+ # Number of state events we've processed by going through each room
+ processed_event_count = 0
+
+ for room_id, event_count in rooms_to_work_on:
+
+ current_state_ids = yield self.get_current_state_ids(room_id)
+
+ join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
+ history_visibility_id = current_state_ids.get(
+ (EventTypes.RoomHistoryVisibility, "")
+ )
+ encryption_id = current_state_ids.get((EventTypes.RoomEncryption, ""))
+ name_id = current_state_ids.get((EventTypes.Name, ""))
+ topic_id = current_state_ids.get((EventTypes.Topic, ""))
+ avatar_id = current_state_ids.get((EventTypes.RoomAvatar, ""))
+ canonical_alias_id = current_state_ids.get((EventTypes.CanonicalAlias, ""))
+
+ state_events = yield self.get_events([
+ join_rules_id, history_visibility_id, encryption_id, name_id,
+ topic_id, avatar_id, canonical_alias_id,
+ ])
+
+ def _get_or_none(event_id, arg):
+ event = state_events.get(event_id)
+ if event:
+ return event.content.get(arg)
+ return None
+
+ yield self.update_room_state(
+ room_id,
+ {
+ "join_rules": _get_or_none(join_rules_id, "join_rule"),
+ "history_visibility": _get_or_none(
+ history_visibility_id, "history_visibility"
+ ),
+ "encryption": _get_or_none(encryption_id, "algorithm"),
+ "name": _get_or_none(name_id, "name"),
+ "topic": _get_or_none(topic_id, "topic"),
+ "avatar": _get_or_none(avatar_id, "url"),
+ "canonical_alias": _get_or_none(canonical_alias_id, "alias"),
+ },
+ )
+
+ now = self.hs.get_reactor().seconds()
+
+            # quantise time down to the start of the current bucket
+ now = (now // self.stats_bucket_size) * self.stats_bucket_size
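+            # e.g. with a bucket size of 3600s, a `now` of 7250 becomes 7200.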
+
+ def _fetch_data(txn):
+
+ # Get the current token of the room
+ current_token = self._get_max_stream_id_in_current_state_deltas_txn(txn)
+
+ current_state_events = len(current_state_ids)
+
+ membership_counts = self._get_user_counts_in_room_txn(txn, room_id)
+
+ total_state_events = self._get_total_state_event_counts_txn(
+ txn, room_id
+ )
+
+ self._update_stats_txn(
+ txn,
+ "room",
+ room_id,
+ now,
+ {
+ "bucket_size": self.stats_bucket_size,
+ "current_state_events": current_state_events,
+ "joined_members": membership_counts.get(Membership.JOIN, 0),
+ "invited_members": membership_counts.get(Membership.INVITE, 0),
+ "left_members": membership_counts.get(Membership.LEAVE, 0),
+ "banned_members": membership_counts.get(Membership.BAN, 0),
+ "state_events": total_state_events,
+ },
+ )
+ self._simple_insert_txn(
+ txn,
+ "room_stats_earliest_token",
+ {"room_id": room_id, "token": current_token},
+ )
+
+ # We've finished a room. Delete it from the table.
+ self._simple_delete_one_txn(
+ txn, TEMP_TABLE + "_rooms", {"room_id": room_id},
+ )
+
+ yield self.runInteraction("update_room_stats", _fetch_data)
+
+ # Update the remaining counter.
+ progress["remaining"] -= 1
+ yield self.runInteraction(
+ "populate_stats",
+ self._background_update_progress_txn,
+ "populate_stats_process_rooms",
+ progress,
+ )
+
+ processed_event_count += event_count
+
+ if processed_event_count > batch_size:
+ # Don't process any more rooms, we've hit our batch size.
+ defer.returnValue(processed_event_count)
+
+ defer.returnValue(processed_event_count)
+
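For readability, here is a minimal sketch of the batching strategy used by
_populate_stats_process_rooms above (fetch_rooms and process_room are
hypothetical stand-ins for the real storage calls): rooms are drained from the
temporary table at most 250 at a time, biggest first, and the update yields
back to the scheduler once the summed event counts exceed batch_size.

    def process_batch(fetch_rooms, process_room, batch_size):
        # Illustrative sketch only, not the actual storage API.
        processed_event_count = 0
        # Fetch at most 250 rooms per batch, ordered by event count, so a
        # single batch never pulls in an unbounded number of rooms.
        for room_id, event_count in fetch_rooms(limit=250):
            process_room(room_id)
            processed_event_count += event_count
            if processed_event_count > batch_size:
                # We've hit the batch size: stop and report how much work
                # was done, so the background-update scheduler can pace us.
                break
        return processed_event_count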
+ def delete_all_stats(self):
+ """
+ Delete all statistics records.
+ """
+
+ def _delete_all_stats_txn(txn):
+ txn.execute("DELETE FROM room_state")
+ txn.execute("DELETE FROM room_stats")
+ txn.execute("DELETE FROM room_stats_earliest_token")
+ txn.execute("DELETE FROM user_stats")
+
+ return self.runInteraction("delete_all_stats", _delete_all_stats_txn)
+
+ def get_stats_stream_pos(self):
+ return self._simple_select_one_onecol(
+ table="stats_stream_pos",
+ keyvalues={},
+ retcol="stream_id",
+ desc="stats_stream_pos",
+ )
+
+ def update_stats_stream_pos(self, stream_id):
+ return self._simple_update_one(
+ table="stats_stream_pos",
+ keyvalues={},
+ updatevalues={"stream_id": stream_id},
+ desc="update_stats_stream_pos",
+ )
+
+ def update_room_state(self, room_id, fields):
+ """
+ Args:
+ room_id (str)
+            fields (dict[str, Any])
+ """
+
+        # For whatever reason, some of the fields may contain null bytes, which
+ # postgres isn't a fan of, so we replace those fields with null.
+ for col in (
+ "join_rules",
+ "history_visibility",
+ "encryption",
+ "name",
+ "topic",
+ "avatar",
+ "canonical_alias"
+ ):
+ field = fields.get(col)
+ if field and "\0" in field:
+ fields[col] = None
+
+ return self._simple_upsert(
+ table="room_state",
+ keyvalues={"room_id": room_id},
+ values=fields,
+ desc="update_room_state",
+ )
+
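The null-byte scrubbing above matters because Postgres TEXT columns cannot
store NUL (0x00) characters: psycopg2 rejects such strings with a ValueError.
A quick illustration of the behaviour, using the same field names:

    fields = {"topic": "hello\0world", "name": "ok"}
    for col in ("topic", "name"):
        field = fields.get(col)
        if field and "\0" in field:
            # Drop the whole value rather than risk a database error.
            fields[col] = None
    # fields is now {"topic": None, "name": "ok"}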
+ def get_deltas_for_room(self, room_id, start, size=100):
+ """
+ Get statistics deltas for a given room.
+
+ Args:
+ room_id (str)
+ start (int): Pagination start. Number of entries, not timestamp.
+ size (int): How many entries to return.
+
+ Returns:
+ Deferred[list[dict]], where the dict has the keys of
+ ABSOLUTE_STATS_FIELDS["room"] and "ts".
+ """
+ return self._simple_select_list_paginate(
+ "room_stats",
+ {"room_id": room_id},
+ "ts",
+ start,
+ size,
+ retcols=(list(ABSOLUTE_STATS_FIELDS["room"]) + ["ts"]),
+ order_direction="DESC",
+ )
+
+ def get_all_room_state(self):
+ return self._simple_select_list(
+ "room_state", None, retcols=("name", "topic", "canonical_alias")
+ )
+
+ @cached()
+ def get_earliest_token_for_room_stats(self, room_id):
+ """
+ Fetch the "earliest token". This is used by the room stats delta
+ processor to ignore deltas that have been processed between the
+ start of the background task and any particular room's stats
+ being calculated.
+
+ Returns:
+ Deferred[int]
+ """
+ return self._simple_select_one_onecol(
+ "room_stats_earliest_token",
+ {"room_id": room_id},
+ retcol="token",
+ allow_none=True,
+ )
+
+ def update_stats(self, stats_type, stats_id, ts, fields):
+ table, id_col = TYPE_TO_ROOM[stats_type]
+ return self._simple_upsert(
+ table=table,
+ keyvalues={id_col: stats_id, "ts": ts},
+ values=fields,
+ desc="update_stats",
+ )
+
+ def _update_stats_txn(self, txn, stats_type, stats_id, ts, fields):
+ table, id_col = TYPE_TO_ROOM[stats_type]
+ return self._simple_upsert_txn(
+ txn, table=table, keyvalues={id_col: stats_id, "ts": ts}, values=fields
+ )
+
+ def update_stats_delta(self, ts, stats_type, stats_id, field, value):
+ def _update_stats_delta(txn):
+ table, id_col = TYPE_TO_ROOM[stats_type]
+
+ sql = (
+ "SELECT * FROM %s"
+ " WHERE %s=? and ts=("
+ " SELECT MAX(ts) FROM %s"
+ " WHERE %s=?"
+ ")"
+ ) % (table, id_col, table, id_col)
+ txn.execute(sql, (stats_id, stats_id))
+ rows = self.cursor_to_dict(txn)
+ if len(rows) == 0:
+ # silently skip as we don't have anything to apply a delta to yet.
+ # this tries to minimise any race between the initial sync and
+ # subsequent deltas arriving.
+ return
+
+ current_ts = ts
+ latest_ts = rows[0]["ts"]
+ if current_ts < latest_ts:
+ # This one is in the past, but we're just encountering it now.
+ # Mark it as part of the current bucket.
+ current_ts = latest_ts
+ elif ts != latest_ts:
+ # we have to copy our absolute counters over to the new entry.
+ values = {
+ key: rows[0][key] for key in ABSOLUTE_STATS_FIELDS[stats_type]
+ }
+ values[id_col] = stats_id
+ values["ts"] = ts
+ values["bucket_size"] = self.stats_bucket_size
+
+ self._simple_insert_txn(txn, table=table, values=values)
+
+ # actually update the new value
+            if field in ABSOLUTE_STATS_FIELDS[stats_type]:
+ self._simple_update_txn(
+ txn,
+ table=table,
+ keyvalues={id_col: stats_id, "ts": current_ts},
+ updatevalues={field: value},
+ )
+ else:
+ sql = ("UPDATE %s SET %s=%s+? WHERE %s=? AND ts=?") % (
+ table,
+ field,
+ field,
+ id_col,
+ )
+ txn.execute(sql, (value, stats_id, current_ts))
+
+ return self.runInteraction("update_stats_delta", _update_stats_delta)
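To make the bucketing scheme in update_stats_delta concrete, here is a minimal
pure-Python sketch: plain dicts stand in for database rows, and
absolute_fields plays the role of ABSOLUTE_STATS_FIELDS[stats_type].

    def quantise(ts, bucket_size):
        # Round a timestamp down to the start of its bucket, matching
        # `(now // self.stats_bucket_size) * self.stats_bucket_size`.
        return (ts // bucket_size) * bucket_size

    def apply_delta(rows, ts, field, value, absolute_fields, bucket_size):
        # `rows` is a list of dicts ordered by "ts"; the last entry is the
        # most recent bucket.
        if not rows:
            # Nothing to apply a delta to yet: drop it silently, as above.
            return
        ts = quantise(ts, bucket_size)
        latest = rows[-1]
        if ts > latest["ts"]:
            # A new bucket: copy the absolute counters forward first.
            new_row = {key: latest[key] for key in absolute_fields}
            new_row["ts"] = ts
            rows.append(new_row)
        # (If ts is older than the latest bucket, we simply fold the delta
        # into the current bucket, as the real code does.)
        target = rows[-1]
        if field in absolute_fields:
            target[field] = value  # absolute counters are replaced
        else:
            target[field] = target.get(field, 0) + value  # deltas accumulate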
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 9cd1e0f9fe..529ad4ea79 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -64,59 +64,135 @@ _EventDictReturn = namedtuple(
)
-def lower_bound(token, engine, inclusive=False):
- inclusive = "=" if inclusive else ""
- if token.topological is None:
- return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
- else:
- if isinstance(engine, PostgresEngine):
- # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
- # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
- # use the later form when running against postgres.
- return "((%d,%d) <%s (%s,%s))" % (
- token.topological,
- token.stream,
- inclusive,
- "topological_ordering",
- "stream_ordering",
+def generate_pagination_where_clause(
+ direction, column_names, from_token, to_token, engine,
+):
+ """Creates an SQL expression to bound the columns by the pagination
+ tokens.
+
+    For example, an SQL expression like:
+
+ (6, 7) >= (topological_ordering, stream_ordering)
+ AND (5, 3) < (topological_ordering, stream_ordering)
+
+    would be generated for direction="b", from_token=(6, 7) and to_token=(5, 3).
+
+ Note that tokens are considered to be after the row they are in, e.g. if
+ a row A has a token T, then we consider A to be before T. This convention
+ is important when figuring out inequalities for the generated SQL, and
+ produces the following result:
+ - If paginating forwards then we exclude any rows matching the from
+ token, but include those that match the to token.
+ - If paginating backwards then we include any rows matching the from
+        token, but exclude those that match the to token.
+
+ Args:
+        direction (str): Whether we're paginating backwards ("b") or
+ forwards ("f").
+ column_names (tuple[str, str]): The column names to bound. Must *not*
+ be user defined as these get inserted directly into the SQL
+ statement without escapes.
+ from_token (tuple[int, int]|None): The start point for the pagination.
+ This is an exclusive minimum bound if direction is "f", and an
+ inclusive maximum bound if direction is "b".
+        to_token (tuple[int, int]|None): The end point for the pagination.
+ This is an inclusive maximum bound if direction is "f", and an
+ exclusive minimum bound if direction is "b".
+ engine: The database engine to generate the clauses for
+
+ Returns:
+ str: The sql expression
+ """
+ assert direction in ("b", "f")
+
+ where_clause = []
+ if from_token:
+ where_clause.append(
+ _make_generic_sql_bound(
+ bound=">=" if direction == "b" else "<",
+ column_names=column_names,
+ values=from_token,
+ engine=engine,
)
- return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
- token.topological,
- "topological_ordering",
- token.topological,
- "topological_ordering",
- token.stream,
- inclusive,
- "stream_ordering",
- )
-
-
-def upper_bound(token, engine, inclusive=True):
- inclusive = "=" if inclusive else ""
- if token.topological is None:
- return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
- else:
- if isinstance(engine, PostgresEngine):
- # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
- # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
- # use the later form when running against postgres.
- return "((%d,%d) >%s (%s,%s))" % (
- token.topological,
- token.stream,
- inclusive,
- "topological_ordering",
- "stream_ordering",
+ )
+
+ if to_token:
+ where_clause.append(
+ _make_generic_sql_bound(
+ bound="<" if direction == "b" else ">=",
+ column_names=column_names,
+ values=to_token,
+ engine=engine,
)
- return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
- token.topological,
- "topological_ordering",
- token.topological,
- "topological_ordering",
- token.stream,
- inclusive,
- "stream_ordering",
)
+ return " AND ".join(where_clause)
+
+
+def _make_generic_sql_bound(bound, column_names, values, engine):
+ """Create an SQL expression that bounds the given column names by the
+ values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
+
+ Only works with two columns.
+
+    Older versions of SQLite don't support that syntax, so we have to expand it
+ out manually.
+
+ Args:
+ bound (str): The comparison operator to use. One of ">", "<", ">=",
+ "<=", where the values are on the left and columns on the right.
+        column_names (tuple[str, str]): The column names. Must *not* be user defined
+ as these get inserted directly into the SQL statement without
+ escapes.
+ values (tuple[int|None, int]): The values to bound the columns by. If
+ the first value is None then only creates a bound on the second
+ column.
+ engine: The database engine to generate the SQL for
+
+ Returns:
+ str
+ """
+
+    assert bound in (">", "<", ">=", "<=")
+
+ name1, name2 = column_names
+ val1, val2 = values
+
+ if val1 is None:
+ val2 = int(val2)
+ return "(%d %s %s)" % (val2, bound, name2)
+
+ val1 = int(val1)
+ val2 = int(val2)
+
+ if isinstance(engine, PostgresEngine):
+ # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
+ # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
+        # use the latter form when running against postgres.
+ return "((%d,%d) %s (%s,%s))" % (
+ val1, val2,
+ bound,
+ name1, name2,
+ )
+
+ # We want to generate queries of e.g. the form:
+ #
+ # (val1 < name1 OR (val1 = name1 AND val2 <= name2))
+ #
+ # which is equivalent to (val1, val2) < (name1, name2)
+
+ return """(
+ {val1:d} {strict_bound} {name1}
+ OR ({val1:d} = {name1} AND {val2:d} {bound} {name2})
+ )""".format(
+ name1=name1,
+ val1=val1,
+ name2=name2,
+ val2=val2,
+        strict_bound=bound[0],  # The first comparison must always be a strict inequality
+ bound=bound,
+ )
+
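For concreteness, this is roughly what the helpers above generate for a
backwards pagination between two tokens (engine here is an illustrative
database engine instance):

    clause = generate_pagination_where_clause(
        direction="b",
        column_names=("topological_ordering", "stream_ordering"),
        from_token=(6, 7),
        to_token=(5, 3),
        engine=engine,  # a PostgresEngine or Sqlite3Engine instance
    )

    # On Postgres, modulo whitespace:
    #   ((6,7) >= (topological_ordering,stream_ordering))
    #   AND ((5,3) < (topological_ordering,stream_ordering))
    #
    # On SQLite the tuple comparisons are expanded by hand; the from-token
    # bound, for example, becomes:
    #   (6 > topological_ordering
    #    OR (6 = topological_ordering AND 7 >= stream_ordering))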
def filter_to_clause(event_filter):
# NB: This may create SQL clauses that don't optimise well (and we don't
@@ -319,7 +395,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
- ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
+    ret = yield self.get_events_as_list(
+        [r.event_id for r in rows], get_prev_content=True,
+    )
self._set_before_and_after(ret, rows, topo_order=from_id is None)
@@ -367,7 +445,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rows = yield self.runInteraction("get_membership_changes_for_user", f)
- ret = yield self._get_events([r.event_id for r in rows], get_prev_content=True)
+ ret = yield self.get_events_as_list(
+ [r.event_id for r in rows], get_prev_content=True,
+ )
self._set_before_and_after(ret, rows, topo_order=False)
@@ -394,7 +474,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
logger.debug("stream before")
- events = yield self._get_events(
+ events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
logger.debug("stream after")
@@ -580,11 +660,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events_before = yield self._get_events(
+ events_before = yield self.get_events_as_list(
[e for e in results["before"]["event_ids"]], get_prev_content=True
)
- events_after = yield self._get_events(
+ events_after = yield self.get_events_as_list(
[e for e in results["after"]["event_ids"]], get_prev_content=True
)
@@ -697,7 +777,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_all_new_events_stream", get_all_new_events_stream_txn
)
- events = yield self._get_events(event_ids)
+ events = yield self.get_events_as_list(event_ids)
defer.returnValue((upper_bound, events))
@@ -758,20 +838,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
args = [False, room_id]
if direction == 'b':
order = "DESC"
- bounds = upper_bound(from_token, self.database_engine)
- if to_token:
- bounds = "%s AND %s" % (
- bounds,
- lower_bound(to_token, self.database_engine),
- )
else:
order = "ASC"
- bounds = lower_bound(from_token, self.database_engine)
- if to_token:
- bounds = "%s AND %s" % (
- bounds,
- upper_bound(to_token, self.database_engine),
- )
+
+ bounds = generate_pagination_where_clause(
+ direction=direction,
+ column_names=("topological_ordering", "stream_ordering"),
+ from_token=from_token,
+ to_token=to_token,
+ engine=self.database_engine,
+ )
filter_clause, filter_args = filter_to_clause(event_filter)
@@ -849,7 +925,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
event_filter,
)
- events = yield self._get_events(
+ events = yield self.get_events_as_list(
[r.event_id for r in rows], get_prev_content=True
)
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 2f16f23d91..7253ba120f 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -156,6 +156,25 @@ def concurrently_execute(func, args, limit):
], consumeErrors=True)).addErrback(unwrapFirstError)
+def yieldable_gather_results(func, iter, *args, **kwargs):
+ """Executes the function with each argument concurrently.
+
+ Args:
+ func (func): Function to execute that returns a Deferred
+ iter (iter): An iterable that yields items that get passed as the first
+ argument to the function
+        *args: Arguments to be passed to each call to func
+        **kwargs: Keyword arguments to be passed to each call to func
+
+    Returns:
+        Deferred[list]: Resolved when all functions have been invoked, or
+        fails if any of the function calls fails.
+ """
+ return logcontext.make_deferred_yieldable(defer.gatherResults([
+ run_in_background(func, item, *args, **kwargs)
+ for item in iter
+ ], consumeErrors=True)).addErrback(unwrapFirstError)
+
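A minimal usage sketch, where fetch_profile is a hypothetical function
returning a Deferred:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def get_profiles(user_ids):
        # Runs fetch_profile(user_id) for each id concurrently and waits
        # for all the results; the Deferred fails if any call fails.
        profiles = yield yieldable_gather_results(fetch_profile, user_ids)
        defer.returnValue(profiles)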
+
class Linearizer(object):
"""Limits concurrent access to resources based on a key. Useful to ensure
only a few things happen at a time on a given resource.
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index 194da87639..e14c8bdfda 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -27,6 +27,7 @@ def user_left_room(distributor, user, room_id):
distributor.fire("user_left_room", user=user, room_id=room_id)
+# XXX: this is no longer used. We should probably kill it.
def user_joined_room(distributor, user, room_id):
distributor.fire("user_joined_room", user=user, room_id=room_id)
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 311b49e18a..fe412355d8 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -226,6 +226,8 @@ class LoggingContext(object):
self.request = request
def __str__(self):
+ if self.request:
+ return str(self.request)
return "%s@%x" % (self.name, id(self))
@classmethod
@@ -274,12 +276,10 @@ class LoggingContext(object):
current = self.set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
- logger.warn("Expected logging context %s has been lost", self)
+ logger.warning("Expected logging context %s was lost", self)
else:
- logger.warn(
- "Current logging context %s is not expected context %s",
- current,
- self
+ logger.warning(
+ "Expected logging context %s but found %s", self, current
)
self.previous_context = None
self.alive = False
@@ -433,10 +433,14 @@ class PreserveLoggingContext(object):
context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
- logger.warn(
- "Unexpected logging context: %s is not %s",
- context, self.new_context,
- )
+ if context is LoggingContext.sentinel:
+ logger.warning("Expected logging context %s was lost", self.new_context)
+ else:
+ logger.warning(
+ "Expected logging context %s but found %s",
+ self.new_context,
+ context,
+ )
if self.current_context is not LoggingContext.sentinel:
if not self.current_context.alive:
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 7deb38f2a7..b146d137f4 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -30,31 +30,14 @@ logger = logging.getLogger(__name__)
class FederationRateLimiter(object):
- def __init__(self, clock, window_size, sleep_limit, sleep_msec,
- reject_limit, concurrent_requests):
+ def __init__(self, clock, config):
"""
Args:
clock (Clock)
- window_size (int): The window size in milliseconds.
- sleep_limit (int): The number of requests received in the last
- `window_size` milliseconds before we artificially start
- delaying processing of requests.
- sleep_msec (int): The number of milliseconds to delay processing
- of incoming requests by.
- reject_limit (int): The maximum number of requests that are can be
- queued for processing before we start rejecting requests with
- a 429 Too Many Requests response.
- concurrent_requests (int): The number of concurrent requests to
- process.
+ config (FederationRateLimitConfig)
"""
self.clock = clock
-
- self.window_size = window_size
- self.sleep_limit = sleep_limit
- self.sleep_msec = sleep_msec
- self.reject_limit = reject_limit
- self.concurrent_requests = concurrent_requests
-
+ self._config = config
self.ratelimiters = {}
def ratelimit(self, host):
@@ -76,25 +59,25 @@ class FederationRateLimiter(object):
host,
_PerHostRatelimiter(
clock=self.clock,
- window_size=self.window_size,
- sleep_limit=self.sleep_limit,
- sleep_msec=self.sleep_msec,
- reject_limit=self.reject_limit,
- concurrent_requests=self.concurrent_requests,
+ config=self._config,
)
).ratelimit()
class _PerHostRatelimiter(object):
- def __init__(self, clock, window_size, sleep_limit, sleep_msec,
- reject_limit, concurrent_requests):
+ def __init__(self, clock, config):
+ """
+ Args:
+ clock (Clock)
+ config (FederationRateLimitConfig)
+ """
self.clock = clock
- self.window_size = window_size
- self.sleep_limit = sleep_limit
- self.sleep_sec = sleep_msec / 1000.0
- self.reject_limit = reject_limit
- self.concurrent_requests = concurrent_requests
+ self.window_size = config.window_size
+ self.sleep_limit = config.sleep_limit
+ self.sleep_sec = config.sleep_delay / 1000.0
+ self.reject_limit = config.reject_limit
+ self.concurrent_requests = config.concurrent
# request_id objects for requests which have been slept
self.sleeping_requests = set()
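The FederationRateLimitConfig passed in above bundles the same five settings
that were previously individual constructor arguments. Its attribute names
can be read off the accesses in _PerHostRatelimiter; the attrs-based
definition and the defaults below are an illustrative sketch, not the
authoritative class:

    import attr

    @attr.s
    class FederationRateLimitConfig(object):
        window_size = attr.ib(default=10000)  # window size in milliseconds
        sleep_limit = attr.ib(default=10)     # requests/window before sleeping
        sleep_delay = attr.ib(default=500)    # sleep duration in milliseconds
        reject_limit = attr.ib(default=50)    # queued requests before 429s
        concurrent = attr.ib(default=3)       # concurrent requests to process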
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 26cce7d197..1a77456498 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -46,8 +46,7 @@ class NotRetryingDestination(Exception):
@defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, ignore_backoff=False,
- **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
@@ -60,8 +59,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
clock (synapse.util.clock): timing source
store (synapse.storage.transactions.TransactionStore): datastore
ignore_backoff (bool): true to ignore the historical backoff data and
- try the request anyway. We will still update the next
- retry_interval on success/failure.
+ try the request anyway. We will still reset the retry_interval on success.
Example usage:
@@ -75,13 +73,12 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
"""
retry_last_ts, retry_interval = (0, 0)
- retry_timings = yield store.get_destination_retry_timings(
- destination
- )
+ retry_timings = yield store.get_destination_retry_timings(destination)
if retry_timings:
retry_last_ts, retry_interval = (
- retry_timings["retry_last_ts"], retry_timings["retry_interval"]
+ retry_timings["retry_last_ts"],
+ retry_timings["retry_interval"],
)
now = int(clock.time_msec())
@@ -93,22 +90,36 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
destination=destination,
)
+ # if we are ignoring the backoff data, we should also not increment the backoff
+ # when we get another failure - otherwise a server can very quickly reach the
+ # maximum backoff even though it might only have been down briefly
+ backoff_on_failure = not ignore_backoff
+
defer.returnValue(
RetryDestinationLimiter(
destination,
clock,
store,
retry_interval,
+ backoff_on_failure=backoff_on_failure,
**kwargs
)
)
class RetryDestinationLimiter(object):
- def __init__(self, destination, clock, store, retry_interval,
- min_retry_interval=10 * 60 * 1000,
- max_retry_interval=24 * 60 * 60 * 1000,
- multiplier_retry_interval=5, backoff_on_404=False):
+ def __init__(
+ self,
+ destination,
+ clock,
+ store,
+ retry_interval,
+ min_retry_interval=10 * 60 * 1000,
+ max_retry_interval=24 * 60 * 60 * 1000,
+ multiplier_retry_interval=5,
+ backoff_on_404=False,
+ backoff_on_failure=True,
+ ):
"""Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500.
@@ -128,6 +139,9 @@ class RetryDestinationLimiter(object):
multiplier_retry_interval (int): The multiplier to use to increase
the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404
+
+ backoff_on_failure (bool): set to False if we should not increase the
+ retry interval on a failure.
"""
self.clock = clock
self.store = store
@@ -138,6 +152,7 @@ class RetryDestinationLimiter(object):
self.max_retry_interval = max_retry_interval
self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404
+ self.backoff_on_failure = backoff_on_failure
def __enter__(self):
pass
@@ -173,10 +188,13 @@ class RetryDestinationLimiter(object):
if not self.retry_interval:
return
- logger.debug("Connection to %s was successful; clearing backoff",
- self.destination)
+ logger.debug(
+ "Connection to %s was successful; clearing backoff", self.destination
+ )
retry_last_ts = 0
self.retry_interval = 0
+ elif not self.backoff_on_failure:
+ return
else:
# We couldn't connect.
if self.retry_interval:
@@ -190,7 +208,10 @@ class RetryDestinationLimiter(object):
logger.info(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
- self.destination, exc_type, exc_val, self.retry_interval
+ self.destination,
+ exc_type,
+ exc_val,
+ self.retry_interval,
)
retry_last_ts = int(self.clock.time_msec())
@@ -201,9 +222,7 @@ class RetryDestinationLimiter(object):
self.destination, retry_last_ts, self.retry_interval
)
except Exception:
- logger.exception(
- "Failed to store destination_retry_timings",
- )
+ logger.exception("Failed to store destination_retry_timings")
# we deliberately do this in the background.
synapse.util.logcontext.run_in_background(store_retry_timings)
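With the defaults above (minimum 10 minutes, maximum 24 hours, multiplier 5),
the retry interval grows as 10 min -> 50 min -> ~4.2 h -> ~20.8 h -> 24 h
(capped). A simplified model of that growth; the real logic lives in
RetryDestinationLimiter.__exit__:

    MIN_RETRY_MS = 10 * 60 * 1000       # 10 minutes
    MAX_RETRY_MS = 24 * 60 * 60 * 1000  # 24 hours
    MULTIPLIER = 5

    def next_retry_interval(retry_interval):
        # On the first failure, start at the minimum interval; on each
        # subsequent failure, multiply, capping at the maximum.
        if retry_interval:
            retry_interval *= MULTIPLIER
        else:
            retry_interval = MIN_RETRY_MS
        return min(retry_interval, MAX_RETRY_MS)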
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index fdcb375f95..69dffd8244 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -24,14 +24,19 @@ _string_with_symbols = (
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
)
+# random_string and random_string_with_symbols are used for a range of things,
+# some cryptographically important, some less so. We use SystemRandom to make sure
+# we get cryptographically secure random values.
+rand = random.SystemRandom()
+
def random_string(length):
- return ''.join(random.choice(string.ascii_letters) for _ in range(length))
+ return ''.join(rand.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
return ''.join(
- random.choice(_string_with_symbols) for _ in range(length)
+ rand.choice(_string_with_symbols) for _ in range(length)
)
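The switch matters because the module-level random functions are backed by a
Mersenne Twister, whose future output can be predicted after observing enough
of it, whereas random.SystemRandom draws from the OS entropy pool
(os.urandom). Usage of the helpers is unchanged, e.g. (illustrative, not a
specific call site):

    token = random_string(24)  # now cryptographically secure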
|