diff --git a/synapse/__init__.py b/synapse/__init__.py
index ef8853bd24..5bada5e290 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.26.0"
+__version__ = "0.30.0"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index ac0a3655a5..b052cf532b 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -57,7 +57,7 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
- register_cache("token_cache", self.token_cache)
+ register_cache("cache", "token_cache", self.token_cache)
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
@@ -204,8 +204,8 @@ class Auth(object):
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
- "User-Agent",
- default=[""]
+ b"User-Agent",
+ default=[b""]
)[0]
if user and access_token and ip_addr:
self.store.insert_client_ip(
@@ -672,7 +672,7 @@ def has_access_token(request):
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
- auth_headers = request.requestHeaders.getRawHeaders("Authorization")
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
@@ -692,8 +692,8 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
AuthError: If there isn't an access_token in the request.
"""
- auth_headers = request.requestHeaders.getRawHeaders("Authorization")
- query_params = request.args.get("access_token")
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+ query_params = request.args.get(b"access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 489efb7f86..5baba43966 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -16,6 +16,9 @@
"""Contains constants from the specification."""
+# the "depth" field on events is limited to 2**63 - 1
+MAX_DEPTH = 2**63 - 1
+
class Membership(object):
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index aa15f73f36..e6ad3768f0 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -15,9 +15,12 @@
"""Contains exceptions and error codes."""
-import json
import logging
+import simplejson as json
+from six import iteritems
+from six.moves import http_client
+
logger = logging.getLogger(__name__)
@@ -49,6 +52,8 @@ class Codes(object):
THREEPID_DENIED = "M_THREEPID_DENIED"
INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
+ CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
+ CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
class CodeMessageException(RuntimeError):
@@ -136,6 +141,32 @@ class SynapseError(CodeMessageException):
return res
+class ConsentNotGivenError(SynapseError):
+ """The error returned to the client when the user has not consented to the
+ privacy policy.
+ """
+ def __init__(self, msg, consent_uri):
+ """Constructs a ConsentNotGivenError
+
+ Args:
+ msg (str): The human-readable error message
+ consent_url (str): The URL where the user can give their consent
+ """
+ super(ConsentNotGivenError, self).__init__(
+ code=http_client.FORBIDDEN,
+ msg=msg,
+ errcode=Codes.CONSENT_NOT_GIVEN
+ )
+ self._consent_uri = consent_uri
+
+ def error_dict(self):
+ return cs_error(
+ self.msg,
+ self.errcode,
+ consent_uri=self._consent_uri
+ )
+
+
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
pass
@@ -290,13 +321,13 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
Args:
msg (str): The error message.
- code (int): The error code.
+ code (str): The error code.
kwargs : Additional keys to add to the response.
Returns:
A dict representing the error response JSON.
"""
err = {"error": msg, "errcode": code}
- for key, value in kwargs.iteritems():
+ for key, value in iteritems(kwargs):
err[key] = value
return err
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 339e4a31d6..7158dd75e9 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -17,7 +17,7 @@ from synapse.storage.presence import UserPresenceState
from synapse.types import UserID, RoomID
from twisted.internet import defer
-import ujson as json
+import simplejson as json
import jsonschema
from jsonschema import FormatChecker
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 91a33a3402..bb46b5da8a 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,12 @@
# limitations under the License.
"""Contains the URL paths to prefix various aspects of the server with. """
+from hashlib import sha256
+import hmac
+
+from six.moves.urllib.parse import urlencode
+
+from synapse.config import ConfigError
CLIENT_PREFIX = "/_matrix/client/api/v1"
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
@@ -25,3 +32,46 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
+
+
+class ConsentURIBuilder(object):
+ def __init__(self, hs_config):
+ """
+ Args:
+ hs_config (synapse.config.homeserver.HomeServerConfig):
+ """
+ if hs_config.form_secret is None:
+ raise ConfigError(
+ "form_secret not set in config",
+ )
+ if hs_config.public_baseurl is None:
+ raise ConfigError(
+ "public_baseurl not set in config",
+ )
+
+ self._hmac_secret = hs_config.form_secret.encode("utf-8")
+ self._public_baseurl = hs_config.public_baseurl
+
+ def build_user_consent_uri(self, user_id):
+ """Build a URI which we can give to the user to do their privacy
+ policy consent
+
+ Args:
+ user_id (str): mxid or username of user
+
+ Returns
+ (str) the URI where the user can do consent
+ """
+ mac = hmac.new(
+ key=self._hmac_secret,
+ msg=user_id,
+ digestmod=sha256,
+ ).hexdigest()
+ consent_uri = "%s_matrix/consent?%s" % (
+ self._public_baseurl,
+ urlencode({
+ "u": user_id,
+ "h": mac
+ }),
+ )
+ return consent_uri
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index c6fe4516d1..b1efacc9f8 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -32,11 +32,11 @@ from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
-from twisted.internet import reactor
-from twisted.web.resource import Resource
+from twisted.internet import reactor, defer
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.appservice")
@@ -64,7 +64,7 @@ class AppserviceServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
- root_resource = create_resource_tree(resources, Resource())
+ root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
bind_addresses,
@@ -74,6 +74,7 @@ class AppserviceServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -112,9 +113,14 @@ class ASReplicationHandler(ReplicationClientHandler):
if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering()
- preserve_fn(
- self.appservice_handler.notify_interested_services
- )(max_stream_id)
+ run_in_background(self._notify_app_services, max_stream_id)
+
+ @defer.inlineCallbacks
+ def _notify_app_services(self, room_stream_id):
+ try:
+ yield self.appservice_handler.notify_interested_services(room_stream_id)
+ except Exception:
+ logger.exception("Error notifying application services of event")
def start(config_options):
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 0a8ce9bc66..38b98382c6 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -44,7 +44,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.client_reader")
@@ -88,7 +88,7 @@ class ClientReaderServer(HomeServer):
"/_matrix/client/api/v1": resource,
})
- root_resource = create_resource_tree(resources, Resource())
+ root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
bind_addresses,
@@ -98,6 +98,7 @@ class ClientReaderServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index 172e989b54..bd7f3d5679 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -52,7 +52,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.event_creator")
@@ -104,7 +104,7 @@ class EventCreatorServer(HomeServer):
"/_matrix/client/api/v1": resource,
})
- root_resource = create_resource_tree(resources, Resource())
+ root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
bind_addresses,
@@ -114,6 +114,7 @@ class EventCreatorServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index 20d157911b..6e10b27b9e 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -41,7 +41,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.federation_reader")
@@ -77,7 +77,7 @@ class FederationReaderServer(HomeServer):
FEDERATION_PREFIX: TransportLayerServer(self),
})
- root_resource = create_resource_tree(resources, Resource())
+ root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
bind_addresses,
@@ -87,6 +87,7 @@ class FederationReaderServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index f760826d27..6f24e32d6d 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -38,11 +38,11 @@ from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import defer, reactor
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.federation_sender")
@@ -91,7 +91,7 @@ class FederationSenderServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
- root_resource = create_resource_tree(resources, Resource())
+ root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
bind_addresses,
@@ -101,6 +101,7 @@ class FederationSenderServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -229,7 +230,7 @@ class FederationSenderHandler(object):
# presence, typing, etc.
if stream_name == "federation":
send_queue.process_rows_for_federation(self.federation_sender, rows)
- preserve_fn(self.update_token)(token)
+ run_in_background(self.update_token, token)
# We also need to poke the federation sender when new events happen
elif stream_name == "events":
@@ -237,19 +238,22 @@ class FederationSenderHandler(object):
@defer.inlineCallbacks
def update_token(self, token):
- self.federation_position = token
-
- # We linearize here to ensure we don't have races updating the token
- with (yield self._fed_position_linearizer.queue(None)):
- if self._last_ack < self.federation_position:
- yield self.store.update_federation_out_pos(
- "federation", self.federation_position
- )
+ try:
+ self.federation_position = token
+
+ # We linearize here to ensure we don't have races updating the token
+ with (yield self._fed_position_linearizer.queue(None)):
+ if self._last_ack < self.federation_position:
+ yield self.store.update_federation_out_pos(
+ "federation", self.federation_position
+ )
- # We ACK this token over replication so that the master can drop
- # its in memory queues
- self.replication_client.send_federation_ack(self.federation_position)
- self._last_ack = self.federation_position
+ # We ACK this token over replication so that the master can drop
+ # its in memory queues
+ self.replication_client.send_federation_ack(self.federation_position)
+ self._last_ack = self.federation_position
+ except Exception:
+ logger.exception("Error updating federation stream position")
if __name__ == '__main__':
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index 816c080d18..0f700ee786 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -44,7 +44,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import defer, reactor
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.frontend_proxy")
@@ -90,7 +90,7 @@ class KeyUploadServlet(RestServlet):
# They're actually trying to upload something, proxy to main synapse.
# Pass through the auth headers, if any, in case the access token
# is there.
- auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization", [])
headers = {
"Authorization": auth_headers,
}
@@ -142,7 +142,7 @@ class FrontendProxyServer(HomeServer):
"/_matrix/client/api/v1": resource,
})
- root_resource = create_resource_tree(resources, Resource())
+ root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
bind_addresses,
@@ -152,6 +152,7 @@ class FrontendProxyServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e477c7ced6..449bfacdb9 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -34,8 +34,8 @@ from synapse.module_api import ModuleApi
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import RootRedirect
from synapse.http.site import SynapseSite
-from synapse.metrics import register_memory_metrics
-from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
+from synapse.metrics import RegistryProxy
+from synapse.metrics.resource import METRICS_PREFIX
from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \
check_requirements
from synapse.replication.http import ReplicationRestResource, REPLICATION_PREFIX
@@ -48,6 +48,7 @@ from synapse.server import HomeServer
from synapse.storage import are_all_users_on_domain
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
+from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
@@ -56,10 +57,12 @@ from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from twisted.application import service
from twisted.internet import defer, reactor
-from twisted.web.resource import EncodingResourceWrapper, Resource
+from twisted.web.resource import EncodingResourceWrapper, NoResource
from twisted.web.server import GzipEncoderFactory
from twisted.web.static import File
+from prometheus_client.twisted import MetricsResource
+
logger = logging.getLogger("synapse.app.homeserver")
@@ -126,7 +129,7 @@ class SynapseHomeServer(HomeServer):
if WEB_CLIENT_PREFIX in resources:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
- root_resource = Resource()
+ root_resource = NoResource()
root_resource = create_resource_tree(resources, root_resource)
@@ -139,6 +142,7 @@ class SynapseHomeServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
),
self.tls_server_context_factory,
)
@@ -152,6 +156,7 @@ class SynapseHomeServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
logger.info("Synapse now listening on port %d", port)
@@ -181,6 +186,15 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/versions": client_resource,
})
+ if name == "consent":
+ from synapse.rest.consent.consent_resource import ConsentResource
+ consent_resource = ConsentResource(self)
+ if compress:
+ consent_resource = gz_wrap(consent_resource)
+ resources.update({
+ "/_matrix/consent": consent_resource,
+ })
+
if name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
@@ -218,7 +232,7 @@ class SynapseHomeServer(HomeServer):
resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
if name == "metrics" and self.get_config().enable_metrics:
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy())
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
@@ -350,8 +364,6 @@ def setup(config_options):
hs.get_datastore().start_doing_background_updates()
hs.get_federation_client().start_get_pdu_cache()
- register_memory_metrics(hs)
-
reactor.callWhenRunning(start)
return hs
@@ -402,6 +414,10 @@ def run(hs):
stats = {}
+ # Contains the list of processes we will be monitoring
+ # currently either 0 or 1
+ stats_process = []
+
@defer.inlineCallbacks
def phone_stats_home():
logger.info("Gathering stats for reporting")
@@ -425,8 +441,21 @@ def run(hs):
stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
+ r30_results = yield hs.get_datastore().count_r30_users()
+ for name, count in r30_results.iteritems():
+ stats["r30_users_" + name] = count
+
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
+ stats["cache_factor"] = CACHE_SIZE_FACTOR
+ stats["event_cache_size"] = hs.config.event_cache_size
+
+ if len(stats_process) > 0:
+ stats["memory_rss"] = 0
+ stats["cpu_average"] = 0
+ for process in stats_process:
+ stats["memory_rss"] += process.memory_info().rss
+ stats["cpu_average"] += int(process.cpu_percent(interval=None))
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
@@ -437,10 +466,40 @@ def run(hs):
except Exception as e:
logger.warn("Error reporting stats: %s", e)
+ def performance_stats_init():
+ try:
+ import psutil
+ process = psutil.Process()
+ # Ensure we can fetch both, and make the initial request for cpu_percent
+ # so the next request will use this as the initial point.
+ process.memory_info().rss
+ process.cpu_percent(interval=None)
+ logger.info("report_stats can use psutil")
+ stats_process.append(process)
+ except (ImportError, AttributeError):
+ logger.warn(
+ "report_stats enabled but psutil is not installed or incorrect version."
+ " Disabling reporting of memory/cpu stats."
+ " Ensuring psutil is available will help matrix.org track performance"
+ " changes across releases."
+ )
+
+ def generate_user_daily_visit_stats():
+ hs.get_datastore().generate_user_daily_visits()
+
+ # Rather than update on per session basis, batch up the requests.
+ # If you increase the loop period, the accuracy of user_daily_visits
+ # table will decrease
+ clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
+
if hs.config.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals")
clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000)
+ # We need to defer this init for the cases that we daemonize
+ # otherwise the process ID we get is that of the non-daemon process
+ clock.call_later(0, performance_stats_init)
+
# We wait 5 minutes to send the first set of stats as the server can
# be quite busy the first few minutes
clock.call_later(5 * 60, phone_stats_home)
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index 84c5791b3b..9c93195f0a 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -43,7 +43,7 @@ from synapse.util.logcontext import LoggingContext
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import reactor
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.media_repository")
@@ -84,7 +84,7 @@ class MediaRepositoryServer(HomeServer):
),
})
- root_resource = create_resource_tree(resources, Resource())
+ root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
bind_addresses,
@@ -94,6 +94,7 @@ class MediaRepositoryServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 98a4a7c62c..3912eae48c 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -33,11 +33,11 @@ from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import defer, reactor
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.pusher")
@@ -94,7 +94,7 @@ class PusherServer(HomeServer):
if name == "metrics":
resources[METRICS_PREFIX] = MetricsResource(self)
- root_resource = create_resource_tree(resources, Resource())
+ root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
bind_addresses,
@@ -104,6 +104,7 @@ class PusherServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -140,24 +141,27 @@ class PusherReplicationHandler(ReplicationClientHandler):
def on_rdata(self, stream_name, token, rows):
super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
- preserve_fn(self.poke_pushers)(stream_name, token, rows)
+ run_in_background(self.poke_pushers, stream_name, token, rows)
@defer.inlineCallbacks
def poke_pushers(self, stream_name, token, rows):
- if stream_name == "pushers":
- for row in rows:
- if row.deleted:
- yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
- else:
- yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
- elif stream_name == "events":
- yield self.pusher_pool.on_new_notifications(
- token, token,
- )
- elif stream_name == "receipts":
- yield self.pusher_pool.on_new_receipts(
- token, token, set(row.room_id for row in rows)
- )
+ try:
+ if stream_name == "pushers":
+ for row in rows:
+ if row.deleted:
+ yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+ else:
+ yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
+ elif stream_name == "events":
+ yield self.pusher_pool.on_new_notifications(
+ token, token,
+ )
+ elif stream_name == "receipts":
+ yield self.pusher_pool.on_new_receipts(
+ token, token, set(row.room_id for row in rows)
+ )
+ except Exception:
+ logger.exception("Error poking pushers")
def stop_pusher(self, user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index abe91dcfbd..c6294a7a0c 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -51,12 +51,14 @@ from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.storage.roommember import RoomMemberStore
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
from twisted.internet import defer, reactor
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
+
+from six import iteritems
logger = logging.getLogger("synapse.app.synchrotron")
@@ -211,7 +213,7 @@ class SynchrotronPresence(object):
def get_currently_syncing_users(self):
return [
- user_id for user_id, count in self.user_to_num_current_syncs.iteritems()
+ user_id for user_id, count in iteritems(self.user_to_num_current_syncs)
if count > 0
]
@@ -269,7 +271,7 @@ class SynchrotronServer(HomeServer):
"/_matrix/client/api/v1": resource,
})
- root_resource = create_resource_tree(resources, Resource())
+ root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
bind_addresses,
@@ -279,6 +281,7 @@ class SynchrotronServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -325,8 +328,7 @@ class SyncReplicationHandler(ReplicationClientHandler):
def on_rdata(self, stream_name, token, rows):
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
-
- preserve_fn(self.process_and_notify)(stream_name, token, rows)
+ run_in_background(self.process_and_notify, stream_name, token, rows)
def get_streams_to_replicate(self):
args = super(SyncReplicationHandler, self).get_streams_to_replicate()
@@ -338,55 +340,58 @@ class SyncReplicationHandler(ReplicationClientHandler):
@defer.inlineCallbacks
def process_and_notify(self, stream_name, token, rows):
- if stream_name == "events":
- # We shouldn't get multiple rows per token for events stream, so
- # we don't need to optimise this for multiple rows.
- for row in rows:
- event = yield self.store.get_event(row.event_id)
- extra_users = ()
- if event.type == EventTypes.Member:
- extra_users = (event.state_key,)
- max_token = self.store.get_room_max_stream_ordering()
- self.notifier.on_new_room_event(
- event, token, max_token, extra_users
+ try:
+ if stream_name == "events":
+ # We shouldn't get multiple rows per token for events stream, so
+ # we don't need to optimise this for multiple rows.
+ for row in rows:
+ event = yield self.store.get_event(row.event_id)
+ extra_users = ()
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(
+ event, token, max_token, extra_users
+ )
+ elif stream_name == "push_rules":
+ self.notifier.on_new_event(
+ "push_rules_key", token, users=[row.user_id for row in rows],
)
- elif stream_name == "push_rules":
- self.notifier.on_new_event(
- "push_rules_key", token, users=[row.user_id for row in rows],
- )
- elif stream_name in ("account_data", "tag_account_data",):
- self.notifier.on_new_event(
- "account_data_key", token, users=[row.user_id for row in rows],
- )
- elif stream_name == "receipts":
- self.notifier.on_new_event(
- "receipt_key", token, rooms=[row.room_id for row in rows],
- )
- elif stream_name == "typing":
- self.typing_handler.process_replication_rows(token, rows)
- self.notifier.on_new_event(
- "typing_key", token, rooms=[row.room_id for row in rows],
- )
- elif stream_name == "to_device":
- entities = [row.entity for row in rows if row.entity.startswith("@")]
- if entities:
+ elif stream_name in ("account_data", "tag_account_data",):
self.notifier.on_new_event(
- "to_device_key", token, users=entities,
+ "account_data_key", token, users=[row.user_id for row in rows],
)
- elif stream_name == "device_lists":
- all_room_ids = set()
- for row in rows:
- room_ids = yield self.store.get_rooms_for_user(row.user_id)
- all_room_ids.update(room_ids)
- self.notifier.on_new_event(
- "device_list_key", token, rooms=all_room_ids,
- )
- elif stream_name == "presence":
- yield self.presence_handler.process_replication_rows(token, rows)
- elif stream_name == "receipts":
- self.notifier.on_new_event(
- "groups_key", token, users=[row.user_id for row in rows],
- )
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "receipt_key", token, rooms=[row.room_id for row in rows],
+ )
+ elif stream_name == "typing":
+ self.typing_handler.process_replication_rows(token, rows)
+ self.notifier.on_new_event(
+ "typing_key", token, rooms=[row.room_id for row in rows],
+ )
+ elif stream_name == "to_device":
+ entities = [row.entity for row in rows if row.entity.startswith("@")]
+ if entities:
+ self.notifier.on_new_event(
+ "to_device_key", token, users=entities,
+ )
+ elif stream_name == "device_lists":
+ all_room_ids = set()
+ for row in rows:
+ room_ids = yield self.store.get_rooms_for_user(row.user_id)
+ all_room_ids.update(room_ids)
+ self.notifier.on_new_event(
+ "device_list_key", token, rooms=all_room_ids,
+ )
+ elif stream_name == "presence":
+ yield self.presence_handler.process_replication_rows(token, rows)
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "groups_key", token, users=[row.user_id for row in rows],
+ )
+ except Exception:
+ logger.exception("Error processing replication")
def start(config_options):
diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py
index 0f0ddfa78a..712dfa870e 100755
--- a/synapse/app/synctl.py
+++ b/synapse/app/synctl.py
@@ -38,7 +38,7 @@ def pid_running(pid):
try:
os.kill(pid, 0)
return True
- except OSError, err:
+ except OSError as err:
if err.errno == errno.EPERM:
return True
return False
@@ -98,7 +98,7 @@ def stop(pidfile, app):
try:
os.kill(pid, signal.SIGTERM)
write("stopped %s" % (app,), colour=GREEN)
- except OSError, err:
+ except OSError as err:
if err.errno == errno.ESRCH:
write("%s not running" % (app,), colour=YELLOW)
elif err.errno == errno.EPERM:
@@ -252,6 +252,7 @@ def main():
for running_pid in running_pids:
while pid_running(running_pid):
time.sleep(0.2)
+ write("All processes exited; now restarting...")
if action == "start" or action == "restart":
if start_stop_synapse:
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 494ccb702c..53eb3474da 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -39,11 +39,11 @@ from synapse.storage.engines import create_engine
from synapse.storage.user_directory import UserDirectoryStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
-from twisted.internet import reactor
-from twisted.web.resource import Resource
+from twisted.internet import reactor, defer
+from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.user_dir")
@@ -116,7 +116,7 @@ class UserDirectoryServer(HomeServer):
"/_matrix/client/api/v1": resource,
})
- root_resource = create_resource_tree(resources, Resource())
+ root_resource = create_resource_tree(resources, NoResource())
_base.listen_tcp(
bind_addresses,
@@ -126,6 +126,7 @@ class UserDirectoryServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -164,7 +165,14 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
stream_name, token, rows
)
if stream_name == "current_state_deltas":
- preserve_fn(self.user_directory.notify_new_event)()
+ run_in_background(self._notify_directory)
+
+ @defer.inlineCallbacks
+ def _notify_directory(self):
+ try:
+ yield self.user_directory.notify_new_event()
+ except Exception:
+ logger.exception("Error notifiying user directory of state update")
def start(config_options):
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index d5a7a5ce2f..5fdb579723 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -21,6 +21,8 @@ from twisted.internet import defer
import logging
import re
+from six import string_types
+
logger = logging.getLogger(__name__)
@@ -146,7 +148,7 @@ class ApplicationService(object):
)
regex = regex_obj.get("regex")
- if isinstance(regex, basestring):
+ if isinstance(regex, string_types):
regex_obj["regex"] = re.compile(regex) # Pre-compile regex
else:
raise ValueError(
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 40c433d7ae..00efff1464 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -18,7 +18,6 @@ from synapse.api.constants import ThirdPartyEntityKind
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
-from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.caches.response_cache import ResponseCache
from synapse.types import ThirdPartyInstanceID
@@ -73,7 +72,8 @@ class ApplicationServiceApi(SimpleHttpClient):
super(ApplicationServiceApi, self).__init__(hs)
self.clock = hs.get_clock()
- self.protocol_meta_cache = ResponseCache(hs, timeout_ms=HOUR_IN_MS)
+ self.protocol_meta_cache = ResponseCache(hs, "as_protocol_meta",
+ timeout_ms=HOUR_IN_MS)
@defer.inlineCallbacks
def query_user(self, service, user_id):
@@ -193,12 +193,7 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(None)
key = (service.id, protocol)
- result = self.protocol_meta_cache.get(key)
- if not result:
- result = self.protocol_meta_cache.set(
- key, preserve_fn(_get)()
- )
- return make_deferred_yieldable(result)
+ return self.protocol_meta_cache.wrap(key, _get)
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 6da315473d..6eddbc0828 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -51,7 +51,7 @@ components.
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
import logging
@@ -106,7 +106,7 @@ class _ServiceQueuer(object):
def enqueue(self, service, event):
# if this service isn't being sent something
self.queued_events.setdefault(service.id, []).append(event)
- preserve_fn(self._send_request)(service)
+ run_in_background(self._send_request, service)
@defer.inlineCallbacks
def _send_request(self, service):
@@ -152,10 +152,10 @@ class _TransactionController(object):
if sent:
yield txn.complete(self.store)
else:
- preserve_fn(self._start_recoverer)(service)
- except Exception as e:
- logger.exception(e)
- preserve_fn(self._start_recoverer)(service)
+ run_in_background(self._start_recoverer, service)
+ except Exception:
+ logger.exception("Error creating appservice transaction")
+ run_in_background(self._start_recoverer, service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):
@@ -176,17 +176,20 @@ class _TransactionController(object):
@defer.inlineCallbacks
def _start_recoverer(self, service):
- yield self.store.set_appservice_state(
- service,
- ApplicationServiceState.DOWN
- )
- logger.info(
- "Application service falling behind. Starting recoverer. AS ID %s",
- service.id
- )
- recoverer = self.recoverer_fn(service, self.on_recovered)
- self.add_recoverers([recoverer])
- recoverer.recover()
+ try:
+ yield self.store.set_appservice_state(
+ service,
+ ApplicationServiceState.DOWN
+ )
+ logger.info(
+ "Application service falling behind. Starting recoverer. AS ID %s",
+ service.id
+ )
+ recoverer = self.recoverer_fn(service, self.on_recovered)
+ self.add_recoverers([recoverer])
+ recoverer.recover()
+ except Exception:
+ logger.exception("Error starting AS recoverer")
@defer.inlineCallbacks
def _is_service_up(self, service):
diff --git a/synapse/config/__init__.py b/synapse/config/__init__.py
index bfebb0f644..f2a5a41e92 100644
--- a/synapse/config/__init__.py
+++ b/synapse/config/__init__.py
@@ -12,3 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from ._base import ConfigError
+
+# export ConfigError if somebody does import *
+# this is largely a fudge to stop PEP8 moaning about the import
+__all__ = ["ConfigError"]
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index fa105bce72..b748ed2b0a 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -19,6 +19,8 @@ import os
import yaml
from textwrap import dedent
+from six import integer_types
+
class ConfigError(Exception):
pass
@@ -49,7 +51,7 @@ Missing mandatory `server_name` config option.
class Config(object):
@staticmethod
def parse_size(value):
- if isinstance(value, int) or isinstance(value, long):
+ if isinstance(value, integer_types):
return value
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
@@ -61,7 +63,7 @@ class Config(object):
@staticmethod
def parse_duration(value):
- if isinstance(value, int) or isinstance(value, long):
+ if isinstance(value, integer_types):
return value
second = 1000
minute = 60 * second
@@ -279,31 +281,31 @@ class Config(object):
)
if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path)
- with open(config_path, "wb") as config_file:
- config_bytes, config = obj.generate_config(
+ with open(config_path, "w") as config_file:
+ config_str, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
is_generating_file=True
)
obj.invoke_all("generate_files", config)
- config_file.write(config_bytes)
- print (
+ config_file.write(config_str)
+ print((
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it"
" to your needs."
- ) % (config_path, server_name)
- print (
+ ) % (config_path, server_name))
+ print(
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
return
else:
- print (
+ print((
"Config file %r already exists. Generating any missing key"
" files."
- ) % (config_path,)
+ ) % (config_path,))
generate_keys = True
parser = argparse.ArgumentParser(
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index aba0aec6e8..277305e184 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -17,10 +17,12 @@ from ._base import Config, ConfigError
from synapse.appservice import ApplicationService
from synapse.types import UserID
-import urllib
import yaml
import logging
+from six import string_types
+from six.moves.urllib import parse as urlparse
+
logger = logging.getLogger(__name__)
@@ -89,21 +91,21 @@ def _load_appservice(hostname, as_info, config_filename):
"id", "as_token", "hs_token", "sender_localpart"
]
for field in required_string_fields:
- if not isinstance(as_info.get(field), basestring):
+ if not isinstance(as_info.get(field), string_types):
raise KeyError("Required string field: '%s' (%s)" % (
field, config_filename,
))
# 'url' must either be a string or explicitly null, not missing
# to avoid accidentally turning off push for ASes.
- if (not isinstance(as_info.get("url"), basestring) and
+ if (not isinstance(as_info.get("url"), string_types) and
as_info.get("url", "") is not None):
raise KeyError(
"Required string field or explicit null: 'url' (%s)" % (config_filename,)
)
localpart = as_info["sender_localpart"]
- if urllib.quote(localpart) != localpart:
+ if urlparse.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
@@ -128,7 +130,7 @@ def _load_appservice(hostname, as_info, config_filename):
"Expected namespace entry in %s to be an object,"
" but got %s", ns, regex_obj
)
- if not isinstance(regex_obj.get("regex"), basestring):
+ if not isinstance(regex_obj.get("regex"), string_types):
raise ValueError(
"Missing/bad type 'regex' key in %s", regex_obj
)
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
new file mode 100644
index 0000000000..8f6ed73328
--- /dev/null
+++ b/synapse/config/consent_config.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+DEFAULT_CONFIG = """\
+# User Consent configuration
+#
+# Parts of this section are required if enabling the 'consent' resource under
+# 'listeners', in particular 'template_dir' and 'version'.
+#
+# 'template_dir' gives the location of the templates for the HTML forms.
+# This directory should contain one subdirectory per language (eg, 'en', 'fr'),
+# and each language directory should contain the policy document (named as
+# '<version>.html') and a success page (success.html).
+#
+# 'version' specifies the 'current' version of the policy document. It defines
+# the version to be served by the consent resource if there is no 'v'
+# parameter.
+#
+# 'server_notice_content', if enabled, will send a user a "Server Notice"
+# asking them to consent to the privacy policy. The 'server_notices' section
+# must also be configured for this to work. Notices will *not* be sent to
+# guest users unless 'send_server_notice_to_guests' is set to true.
+#
+# 'block_events_error', if set, will block any attempts to send events
+# until the user consents to the privacy policy. The value of the setting is
+# used as the text of the error.
+#
+# user_consent:
+# template_dir: res/templates/privacy
+# version: 1.0
+# server_notice_content:
+# msgtype: m.text
+# body: >-
+# To continue using this homeserver you must review and agree to the
+# terms and conditions at %(consent_uri)s
+# send_server_notice_to_guests: True
+# block_events_error: >-
+# To continue using this homeserver you must review and agree to the
+# terms and conditions at %(consent_uri)s
+#
+"""
+
+
+class ConsentConfig(Config):
+ def __init__(self):
+ super(ConsentConfig, self).__init__()
+
+ self.user_consent_version = None
+ self.user_consent_template_dir = None
+ self.user_consent_server_notice_content = None
+ self.user_consent_server_notice_to_guests = False
+ self.block_events_without_consent_error = None
+
+ def read_config(self, config):
+ consent_config = config.get("user_consent")
+ if consent_config is None:
+ return
+ self.user_consent_version = str(consent_config["version"])
+ self.user_consent_template_dir = consent_config["template_dir"]
+ self.user_consent_server_notice_content = consent_config.get(
+ "server_notice_content",
+ )
+ self.block_events_without_consent_error = consent_config.get(
+ "block_events_error",
+ )
+ self.user_consent_server_notice_to_guests = bool(consent_config.get(
+ "send_server_notice_to_guests", False,
+ ))
+
+ def default_config(self, **kwargs):
+ return DEFAULT_CONFIG
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index bf19cfee29..1dea2ad024 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from .tls import TlsConfig
from .server import ServerConfig
from .logger import LoggingConfig
@@ -37,6 +37,8 @@ from .push import PushConfig
from .spam_checker import SpamCheckerConfig
from .groups import GroupsConfig
from .user_directory import UserDirectoryConfig
+from .consent_config import ConsentConfig
+from .server_notices_config import ServerNoticesConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@@ -45,12 +47,15 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
- SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
+ SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
+ ConsentConfig,
+ ServerNoticesConfig,
+ ):
pass
if __name__ == '__main__':
import sys
sys.stdout.write(
- HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0]
+ HomeServerConfig().generate_config(sys.argv[1], sys.argv[2], True)[0]
)
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 4b8fc063d0..d1382ad9ac 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -59,14 +59,20 @@ class KeyConfig(Config):
self.expire_access_token = config.get("expire_access_token", False)
+ # a secret which is used to calculate HMACs for form values, to stop
+ # falsification of values
+ self.form_secret = config.get("form_secret", None)
+
def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
if is_generating_file:
macaroon_secret_key = random_string_with_symbols(50)
+ form_secret = '"%s"' % random_string_with_symbols(50)
else:
macaroon_secret_key = None
+ form_secret = 'null'
return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
@@ -74,6 +80,10 @@ class KeyConfig(Config):
# Used to enable access token expiration.
expire_access_token: False
+ # a secret which is used to calculate HMACs for form values, to stop
+ # falsification of values
+ form_secret: %(form_secret)s
+
## Signing Keys ##
# Path to the signing key to sign messages with
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 3f70039acd..6a7228dc2f 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -117,7 +117,7 @@ class LoggingConfig(Config):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")
- with open(log_config, "wb") as log_config_file:
+ with open(log_config, "w") as log_config_file:
log_config_file.write(
DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
)
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 336959094b..c5384b3ad4 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -77,7 +77,9 @@ class RegistrationConfig(Config):
# Set the number of bcrypt rounds used to generate password hash.
# Larger numbers increase the work factor needed to generate the hash.
- # The default number of rounds is 12.
+ # The default number is 12 (which equates to 2^12 rounds).
+ # N.B. that increasing this will exponentially increase the time required
+ # to register or login - e.g. 24 => 2^24 rounds which will take >20 mins.
bcrypt_rounds: 12
# Allows users to register as guests without a password/email/etc, and
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
new file mode 100644
index 0000000000..be1d1f762c
--- /dev/null
+++ b/synapse/config/server_notices_config.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import Config
+from synapse.types import UserID
+
+DEFAULT_CONFIG = """\
+# Server Notices room configuration
+#
+# Uncomment this section to enable a room which can be used to send notices
+# from the server to users. It is a special room which cannot be left; notices
+# come from a special "notices" user id.
+#
+# If you uncomment this section, you *must* define the system_mxid_localpart
+# setting, which defines the id of the user which will be used to send the
+# notices.
+#
+# It's also possible to override the room name, the display name of the
+# "notices" user, and the avatar for the user.
+#
+# server_notices:
+# system_mxid_localpart: notices
+# system_mxid_display_name: "Server Notices"
+# system_mxid_avatar_url: "mxc://server.com/oumMVlgDnLYFaPVkExemNVVZ"
+# room_name: "Server Notices"
+"""
+
+
+class ServerNoticesConfig(Config):
+ """Configuration for the server notices room.
+
+ Attributes:
+ server_notices_mxid (str|None):
+ The MXID to use for server notices.
+ None if server notices are not enabled.
+
+ server_notices_mxid_display_name (str|None):
+ The display name to use for the server notices user.
+ None if server notices are not enabled.
+
+ server_notices_mxid_avatar_url (str|None):
+ The display name to use for the server notices user.
+ None if server notices are not enabled.
+
+ server_notices_room_name (str|None):
+ The name to use for the server notices room.
+ None if server notices are not enabled.
+ """
+ def __init__(self):
+ super(ServerNoticesConfig, self).__init__()
+ self.server_notices_mxid = None
+ self.server_notices_mxid_display_name = None
+ self.server_notices_mxid_avatar_url = None
+ self.server_notices_room_name = None
+
+ def read_config(self, config):
+ c = config.get("server_notices")
+ if c is None:
+ return
+
+ mxid_localpart = c['system_mxid_localpart']
+ self.server_notices_mxid = UserID(
+ mxid_localpart, self.server_name,
+ ).to_string()
+ self.server_notices_mxid_display_name = c.get(
+ 'system_mxid_display_name', None,
+ )
+ self.server_notices_mxid_avatar_url = c.get(
+ 'system_mxid_avatar_url', None,
+ )
+ # todo: i18n
+ self.server_notices_room_name = c.get('room_name', "Server Notices")
+
+ def default_config(self, **kwargs):
+ return DEFAULT_CONFIG
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 29eb012ddb..b66154bc7c 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -133,7 +133,7 @@ class TlsConfig(Config):
tls_dh_params_path = config["tls_dh_params_path"]
if not self.path_exists(tls_private_key_path):
- with open(tls_private_key_path, "w") as private_key_file:
+ with open(tls_private_key_path, "wb") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
@@ -148,7 +148,7 @@ class TlsConfig(Config):
)
if not self.path_exists(tls_certificate_path):
- with open(tls_certificate_path, "w") as certificate_file:
+ with open(tls_certificate_path, "wb") as certificate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index cff3ca809a..0397f73ab4 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -13,8 +13,8 @@
# limitations under the License.
from twisted.internet import ssl
-from OpenSSL import SSL
-from twisted.internet._sslverify import _OpenSSLECCurve, _defaultCurveName
+from OpenSSL import SSL, crypto
+from twisted.internet._sslverify import _defaultCurveName
import logging
@@ -32,8 +32,9 @@ class ServerContextFactory(ssl.ContextFactory):
@staticmethod
def configure_context(context, config):
try:
- _ecCurve = _OpenSSLECCurve(_defaultCurveName)
- _ecCurve.addECKeyToContext(context)
+ _ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
+ context.set_tmp_ecdh(_ecCurve)
+
except Exception:
logger.exception("Failed to enable elliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 35f810b07b..22ee0fc93f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -19,7 +19,8 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import (
PreserveLoggingContext,
- preserve_fn
+ preserve_fn,
+ run_in_background,
)
from synapse.util.metrics import Measure
@@ -127,7 +128,7 @@ class Keyring(object):
verify_requests.append(verify_request)
- preserve_fn(self._start_key_lookups)(verify_requests)
+ run_in_background(self._start_key_lookups, verify_requests)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
@@ -146,53 +147,56 @@ class Keyring(object):
verify_requests (List[VerifyKeyRequest]):
"""
- # create a deferred for each server we're going to look up the keys
- # for; we'll resolve them once we have completed our lookups.
- # These will be passed into wait_for_previous_lookups to block
- # any other lookups until we have finished.
- # The deferreds are called with no logcontext.
- server_to_deferred = {
- rq.server_name: defer.Deferred()
- for rq in verify_requests
- }
-
- # We want to wait for any previous lookups to complete before
- # proceeding.
- yield self.wait_for_previous_lookups(
- [rq.server_name for rq in verify_requests],
- server_to_deferred,
- )
-
- # Actually start fetching keys.
- self._get_server_verify_keys(verify_requests)
-
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- #
- # map from server name to a set of request ids
- server_to_request_ids = {}
-
- for verify_request in verify_requests:
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids.setdefault(server_name, set()).add(request_id)
-
- def remove_deferreds(res, verify_request):
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids[server_name].discard(request_id)
- if not server_to_request_ids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
- return res
-
- for verify_request in verify_requests:
- verify_request.deferred.addBoth(
- remove_deferreds, verify_request,
+ try:
+ # create a deferred for each server we're going to look up the keys
+ # for; we'll resolve them once we have completed our lookups.
+ # These will be passed into wait_for_previous_lookups to block
+ # any other lookups until we have finished.
+ # The deferreds are called with no logcontext.
+ server_to_deferred = {
+ rq.server_name: defer.Deferred()
+ for rq in verify_requests
+ }
+
+ # We want to wait for any previous lookups to complete before
+ # proceeding.
+ yield self.wait_for_previous_lookups(
+ [rq.server_name for rq in verify_requests],
+ server_to_deferred,
)
+ # Actually start fetching keys.
+ self._get_server_verify_keys(verify_requests)
+
+ # When we've finished fetching all the keys for a given server_name,
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
+ # any lookups waiting will proceed.
+ #
+ # map from server name to a set of request ids
+ server_to_request_ids = {}
+
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids.setdefault(server_name, set()).add(request_id)
+
+ def remove_deferreds(res, verify_request):
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids[server_name].discard(request_id)
+ if not server_to_request_ids[server_name]:
+ d = server_to_deferred.pop(server_name, None)
+ if d:
+ d.callback(None)
+ return res
+
+ for verify_request in verify_requests:
+ verify_request.deferred.addBoth(
+ remove_deferreds, verify_request,
+ )
+ except Exception:
+ logger.exception("Error starting key lookups")
+
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish.
@@ -313,7 +317,7 @@ class Keyring(object):
if not verify_request.deferred.called:
verify_request.deferred.errback(err)
- preserve_fn(do_iterations)().addErrback(on_err)
+ run_in_background(do_iterations).addErrback(on_err)
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
@@ -329,8 +333,9 @@ class Keyring(object):
"""
res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.get_server_verify_keys)(
- server_name, key_ids
+ run_in_background(
+ self.store.get_server_verify_keys,
+ server_name, key_ids,
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
@@ -352,13 +357,13 @@ class Keyring(object):
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
- type(e).__name__, str(e.message),
+ type(e).__name__, str(e),
)
defer.returnValue({})
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(get_key)(p_name, p_keys)
+ run_in_background(get_key, p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
@@ -384,7 +389,7 @@ class Keyring(object):
logger.info(
"Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
- type(e).__name__, str(e.message),
+ type(e).__name__, str(e),
)
if not keys:
@@ -398,7 +403,7 @@ class Keyring(object):
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(get_key)(server_name, key_ids)
+ run_in_background(get_key, server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
@@ -481,7 +486,8 @@ class Keyring(object):
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store_keys)(
+ run_in_background(
+ self.store_keys,
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
@@ -539,7 +545,8 @@ class Keyring(object):
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store_keys)(
+ run_in_background(
+ self.store_keys,
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
@@ -615,7 +622,8 @@ class Keyring(object):
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.store_server_keys_json)(
+ run_in_background(
+ self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=server_name,
@@ -716,7 +724,8 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired.
return logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.store_server_verify_key)(
+ run_in_background(
+ self.store.store_server_verify_key,
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
@@ -734,7 +743,7 @@ def _handle_key_deferred(verify_request):
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
+ server_name, type(e).__name__, str(e),
)
raise SynapseError(
502,
@@ -744,7 +753,7 @@ def _handle_key_deferred(verify_request):
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
+ server_name, type(e).__name__, str(e),
)
raise SynapseError(
401,
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index e673e96cc0..c3ff85c49a 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -47,14 +47,26 @@ class _EventInternalMetadata(object):
def _event_dict_property(key):
+ # We want to be able to use hasattr with the event dict properties.
+ # However, (on python3) hasattr expects AttributeError to be raised. Hence,
+ # we need to transform the KeyError into an AttributeError
def getter(self):
- return self._event_dict[key]
+ try:
+ return self._event_dict[key]
+ except KeyError:
+ raise AttributeError(key)
def setter(self, v):
- self._event_dict[key] = v
+ try:
+ self._event_dict[key] = v
+ except KeyError:
+ raise AttributeError(key)
def delete(self):
- del self._event_dict[key]
+ try:
+ del self._event_dict[key]
+ except KeyError:
+ raise AttributeError(key)
return property(
getter,
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 824f4a42e3..29ae086786 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -20,6 +20,8 @@ from frozendict import frozendict
import re
+from six import string_types
+
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
@@ -277,7 +279,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if only_event_fields:
if (not isinstance(only_event_fields, list) or
- not all(isinstance(f, basestring) for f in only_event_fields)):
+ not all(isinstance(f, string_types) for f in only_event_fields)):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 2f4c8a1018..e0e5bf818c 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -17,6 +17,8 @@ from synapse.types import EventID, RoomID, UserID
from synapse.api.errors import SynapseError
from synapse.api.constants import EventTypes, Membership
+from six import string_types
+
class EventValidator(object):
@@ -49,7 +51,7 @@ class EventValidator(object):
strings.append("state_key")
for s in strings:
- if not isinstance(getattr(event, s), basestring):
+ if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member:
@@ -88,5 +90,5 @@ class EventValidator(object):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
- if not isinstance(d[s], basestring):
+ if not isinstance(d[s], string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 79eaa31031..4cc98a3fe8 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -14,7 +14,10 @@
# limitations under the License.
import logging
-from synapse.api.errors import SynapseError
+import six
+
+from synapse.api.constants import MAX_DEPTH
+from synapse.api.errors import SynapseError, Codes
from synapse.crypto.event_signing import check_event_content_hash
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
@@ -190,11 +193,23 @@ def event_from_pdu_json(pdu_json, outlier=False):
FrozenEvent
Raises:
- SynapseError: if the pdu is missing required fields
+ SynapseError: if the pdu is missing required fields or is otherwise
+ not a valid matrix event
"""
# we could probably enforce a bunch of other fields here (room_id, sender,
# origin, etc etc)
- assert_params_in_request(pdu_json, ('event_id', 'type'))
+ assert_params_in_request(pdu_json, ('event_id', 'type', 'depth'))
+
+ depth = pdu_json['depth']
+ if not isinstance(depth, six.integer_types):
+ raise SynapseError(400, "Depth %r not an intger" % (depth, ),
+ Codes.BAD_JSON)
+
+ if depth < 0:
+ raise SynapseError(400, "Depth too small", Codes.BAD_JSON)
+ elif depth > MAX_DEPTH:
+ raise SynapseError(400, "Depth too large", Codes.BAD_JSON)
+
event = FrozenEvent(
pdu_json
)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 38440da5b5..2761ffae07 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,6 +19,8 @@ import itertools
import logging
import random
+from six.moves import range
+
from twisted.internet import defer
from synapse.api.constants import Membership
@@ -30,20 +32,17 @@ from synapse.federation.federation_base import (
FederationBase,
event_from_pdu_json,
)
-import synapse.metrics
from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
-logger = logging.getLogger(__name__)
-
+from prometheus_client import Counter
-# synapse.federation.federation_client is a silly name
-metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
+logger = logging.getLogger(__name__)
-sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
+sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
PDU_RETRY_TIME_MS = 1 * 60 * 1000
@@ -106,7 +105,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc(query_type)
+ sent_queries_counter.labels(query_type).inc()
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
@@ -125,7 +124,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc("client_device_keys")
+ sent_queries_counter.labels("client_device_keys").inc()
return self.transport_layer.query_client_keys(
destination, content, timeout
)
@@ -135,7 +134,7 @@ class FederationClient(FederationBase):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
- sent_queries_counter.inc("user_devices")
+ sent_queries_counter.labels("user_devices").inc()
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@@ -152,7 +151,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc("client_one_time_keys")
+ sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(
destination, content, timeout
)
@@ -394,7 +393,7 @@ class FederationClient(FederationBase):
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = seen_events.values()
else:
- seen_events = yield self.store.have_events(event_ids)
+ seen_events = yield self.store.have_seen_events(event_ids)
signed_events = []
failed_to_fetch = set()
@@ -413,11 +412,12 @@ class FederationClient(FederationBase):
batch_size = 20
missing_events = list(missing_events)
- for i in xrange(0, len(missing_events), batch_size):
+ for i in range(0, len(missing_events), batch_size):
batch = set(missing_events[i:i + batch_size])
deferreds = [
- preserve_fn(self.get_pdu)(
+ run_in_background(
+ self.get_pdu,
destinations=random_server_list(),
event_id=e_id,
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index bea7fd0b71..2d420a58a2 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -26,30 +27,32 @@ from synapse.federation.federation_base import (
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
-import synapse.metrics
from synapse.types import get_domain_from_id
from synapse.util import async
from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.logutils import log_function
+from prometheus_client import Counter
+
+from six import iteritems
+
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
TRANSACTION_CONCURRENCY_LIMIT = 10
logger = logging.getLogger(__name__)
-# synapse.federation.federation_server is a silly name
-metrics = synapse.metrics.get_metrics_for("synapse.federation.server")
+received_pdus_counter = Counter("synapse_federation_server_received_pdus", "")
-received_pdus_counter = metrics.register_counter("received_pdus")
+received_edus_counter = Counter("synapse_federation_server_received_edus", "")
-received_edus_counter = metrics.register_counter("received_edus")
-
-received_queries_counter = metrics.register_counter("received_queries", labels=["type"])
+received_queries_counter = Counter(
+ "synapse_federation_server_received_queries", "", ["type"]
+)
class FederationServer(FederationBase):
+
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
@@ -65,7 +68,7 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
- self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
+ self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000)
@defer.inlineCallbacks
@log_function
@@ -129,7 +132,7 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
- received_pdus_counter.inc_by(len(transaction.pdus))
+ received_pdus_counter.inc(len(transaction.pdus))
pdus_by_room = {}
@@ -212,16 +215,17 @@ class FederationServer(FederationBase):
if not in_room:
raise AuthError(403, "Host not in room.")
- result = self._state_resp_cache.get((room_id, event_id))
- if not result:
- with (yield self._server_linearizer.queue((origin, room_id))):
- d = self._state_resp_cache.set(
- (room_id, event_id),
- preserve_fn(self._on_context_state_request_compute)(room_id, event_id)
- )
- resp = yield make_deferred_yieldable(d)
- else:
- resp = yield make_deferred_yieldable(result)
+ # we grab the linearizer to protect ourselves from servers which hammer
+ # us. In theory we might already have the response to this query
+ # in the cache so we could return it without waiting for the linearizer
+ # - but that's non-trivial to get right, and anyway somewhat defeats
+ # the point of the linearizer.
+ with (yield self._server_linearizer.queue((origin, room_id))):
+ resp = yield self._state_resp_cache.wrap(
+ (room_id, event_id),
+ self._on_context_state_request_compute,
+ room_id, event_id,
+ )
defer.returnValue((200, resp))
@@ -289,7 +293,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
- received_queries_counter.inc(query_type)
+ received_queries_counter.labels(query_type).inc()
resp = yield self.registry.on_query(query_type, args)
defer.returnValue((200, resp))
@@ -425,9 +429,9 @@ class FederationServer(FederationBase):
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in json_result.iteritems()
- for device_id, device_keys in user_keys.iteritems()
- for key_id, _ in device_keys.iteritems()
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
)),
)
@@ -494,13 +498,33 @@ class FederationServer(FederationBase):
def _handle_received_pdu(self, origin, pdu):
""" Process a PDU received in a federation /send/ transaction.
+ If the event is invalid, then this method throws a FederationError.
+ (The error will then be logged and sent back to the sender (which
+ probably won't do anything with it), and other events in the
+ transaction will be processed as normal).
+
+ It is likely that we'll then receive other events which refer to
+ this rejected_event in their prev_events, etc. When that happens,
+ we'll attempt to fetch the rejected event again, which will presumably
+ fail, so those second-generation events will also get rejected.
+
+ Eventually, we get to the point where there are more than 10 events
+ between any new events and the original rejected event. Since we
+ only try to backfill 10 events deep on received pdu, we then accept the
+ new event, possibly introducing a discontinuity in the DAG, with new
+ forward extremities, so normal service is approximately returned,
+ until we try to backfill across the discontinuity.
+
Args:
origin (str): server which sent the pdu
pdu (FrozenEvent): received pdu
Returns (Deferred): completes with None
- Raises: FederationError if the signatures / hash do not match
- """
+
+ Raises: FederationError if the signatures / hash do not match, or
+ if the event was unacceptable for any other reason (eg, too large,
+ too many prev_events, couldn't find the prev_events)
+ """
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if origin != get_domain_from_id(pdu.event_id):
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 93e5acebc1..c7ed465617 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -33,17 +33,16 @@ from .units import Edu
from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
-import synapse.metrics
+from synapse.metrics import LaterGauge
from blist import sorteddict
from collections import namedtuple
import logging
-logger = logging.getLogger(__name__)
-
+from six import itervalues, iteritems
-metrics = synapse.metrics.get_metrics_for(__name__)
+logger = logging.getLogger(__name__)
class FederationRemoteSendQueue(object):
@@ -75,10 +74,8 @@ class FederationRemoteSendQueue(object):
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
def register(name, queue):
- metrics.register_callback(
- queue_name + "_size",
- lambda: len(queue),
- )
+ LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,),
+ "", lambda: len(queue))
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
@@ -122,7 +119,7 @@ class FederationRemoteSendQueue(object):
user_ids = set(
user_id
- for uids in self.presence_changed.itervalues()
+ for uids in itervalues(self.presence_changed)
for user_id in uids
)
@@ -276,7 +273,7 @@ class FederationRemoteSendQueue(object):
# stream position.
keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]}
- for ((destination, edu_key), pos) in keyed_edus.iteritems():
+ for ((destination, edu_key), pos) in iteritems(keyed_edus):
rows.append((pos, KeyedEduRow(
key=edu_key,
edu=self.keyed_edu[(destination, edu_key)],
@@ -309,7 +306,7 @@ class FederationRemoteSendQueue(object):
j = keys.bisect_right(to_token) + 1
device_messages = {self.device_messages[k]: k for k in keys[i:j]}
- for (destination, pos) in device_messages.iteritems():
+ for (destination, pos) in iteritems(device_messages):
rows.append((pos, DeviceRow(
destination=destination,
)))
@@ -528,19 +525,19 @@ def process_rows_for_federation(transaction_queue, rows):
if buff.presence:
transaction_queue.send_presence(buff.presence)
- for destination, edu_map in buff.keyed_edus.iteritems():
+ for destination, edu_map in iteritems(buff.keyed_edus):
for key, edu in edu_map.items():
transaction_queue.send_edu(
edu.destination, edu.edu_type, edu.content, key=key,
)
- for destination, edu_list in buff.edus.iteritems():
+ for destination, edu_list in iteritems(buff.edus):
for edu in edu_list:
transaction_queue.send_edu(
edu.destination, edu.edu_type, edu.content, key=None,
)
- for destination, failure_list in buff.failures.iteritems():
+ for destination, failure_list in iteritems(buff.failures):
for failure in failure_list:
transaction_queue.send_failure(destination, failure)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index a141ec9953..69312ec233 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -26,23 +26,23 @@ from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
import synapse.metrics
+from synapse.metrics import LaterGauge
+from synapse.metrics import (
+ sent_edus_counter,
+ sent_transactions_counter,
+ events_processed_counter,
+)
+
+from prometheus_client import Counter
import logging
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
-sent_pdus_destination_dist = client_metrics.register_distribution(
- "sent_pdu_destinations"
+sent_pdus_destination_dist = Counter(
+ "synapse_federation_transaction_queue_sent_pdu_destinations", ""
)
-sent_edus_counter = client_metrics.register_counter("sent_edus")
-
-sent_transactions_counter = client_metrics.register_counter("sent_transactions")
-
-events_processed_counter = client_metrics.register_counter("events_processed")
class TransactionQueue(object):
@@ -69,8 +69,10 @@ class TransactionQueue(object):
# done
self.pending_transactions = {}
- metrics.register_callback(
- "pending_destinations",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_destinations",
+ "",
+ [],
lambda: len(self.pending_transactions),
)
@@ -94,12 +96,16 @@ class TransactionQueue(object):
# Map of destination -> (edu_type, key) -> Edu
self.pending_edus_keyed_by_dest = edus_keyed = {}
- metrics.register_callback(
- "pending_pdus",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_pdus",
+ "",
+ [],
lambda: sum(map(len, pdus.values())),
)
- metrics.register_callback(
- "pending_edus",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_edus",
+ "",
+ [],
lambda: (
sum(map(len, edus.values()))
+ sum(map(len, presence.values()))
@@ -169,7 +175,7 @@ class TransactionQueue(object):
while True:
last_token = yield self.store.get_federation_out_pos("events")
next_token, events = yield self.store.get_all_new_events_stream(
- last_token, self._last_poked_id, limit=20,
+ last_token, self._last_poked_id, limit=100,
)
logger.debug("Handling %s -> %s", last_token, next_token)
@@ -177,24 +183,33 @@ class TransactionQueue(object):
if not events and next_token >= self._last_poked_id:
break
- for event in events:
+ @defer.inlineCallbacks
+ def handle_event(event):
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.event_id)
if not is_mine and send_on_behalf_of is None:
- continue
-
- # Get the state from before the event.
- # We need to make sure that this is the state from before
- # the event and not from after it.
- # Otherwise if the last member on a server in a room is
- # banned then it won't receive the event because it won't
- # be in the room after the ban.
- destinations = yield self.state.get_current_hosts_in_room(
- event.room_id, latest_event_ids=[
- prev_id for prev_id, _ in event.prev_events
- ],
- )
+ return
+
+ try:
+ # Get the state from before the event.
+ # We need to make sure that this is the state from before
+ # the event and not from after it.
+ # Otherwise if the last member on a server in a room is
+ # banned then it won't receive the event because it won't
+ # be in the room after the ban.
+ destinations = yield self.state.get_current_hosts_in_room(
+ event.room_id, latest_event_ids=[
+ prev_id for prev_id, _ in event.prev_events
+ ],
+ )
+ except Exception:
+ logger.exception(
+ "Failed to calculate hosts in room for event: %s",
+ event.event_id,
+ )
+ return
+
destinations = set(destinations)
if send_on_behalf_of is not None:
@@ -207,12 +222,41 @@ class TransactionQueue(object):
self._send_pdu(event, destinations)
- events_processed_counter.inc_by(len(events))
+ @defer.inlineCallbacks
+ def handle_room_events(events):
+ for event in events:
+ yield handle_event(event)
+
+ events_by_room = {}
+ for event in events:
+ events_by_room.setdefault(event.room_id, []).append(event)
+
+ yield logcontext.make_deferred_yieldable(defer.gatherResults(
+ [
+ logcontext.run_in_background(handle_room_events, evs)
+ for evs in events_by_room.itervalues()
+ ],
+ consumeErrors=True
+ ))
yield self.store.update_federation_out_pos(
"events", next_token
)
+ if events:
+ now = self.clock.time_msec()
+ ts = yield self.store.get_received_ts(events[-1].event_id)
+
+ synapse.metrics.event_processing_lag.labels(
+ "federation_sender").set(now - ts)
+ synapse.metrics.event_processing_last_ts.labels(
+ "federation_sender").set(ts)
+
+ events_processed_counter.inc(len(events))
+
+ synapse.metrics.event_processing_positions.labels(
+ "federation_sender").set(next_token)
+
finally:
self._is_processing = False
@@ -234,7 +278,7 @@ class TransactionQueue(object):
if not destinations:
return
- sent_pdus_destination_dist.inc_by(len(destinations))
+ sent_pdus_destination_dist.inc(len(destinations))
for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append(
@@ -282,6 +326,8 @@ class TransactionQueue(object):
break
yield self._process_presence_inner(states_map.values())
+ except Exception:
+ logger.exception("Error sending presence states to servers")
finally:
self._processing_pending_presence = False
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 5488e82985..6db8efa6dd 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,6 +21,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.util.logutils import log_function
import logging
+import urllib
logger = logging.getLogger(__name__)
@@ -49,7 +51,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state dest=%s, room=%s",
destination, room_id)
- path = PREFIX + "/state/%s/" % room_id
+ path = _create_path(PREFIX, "/state/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -71,7 +73,7 @@ class TransportLayerClient(object):
logger.debug("get_room_state_ids dest=%s, room=%s",
destination, room_id)
- path = PREFIX + "/state_ids/%s/" % room_id
+ path = _create_path(PREFIX, "/state_ids/%s/", room_id)
return self.client.get_json(
destination, path=path, args={"event_id": event_id},
)
@@ -93,7 +95,7 @@ class TransportLayerClient(object):
logger.debug("get_pdu dest=%s, event_id=%s",
destination, event_id)
- path = PREFIX + "/event/%s/" % (event_id, )
+ path = _create_path(PREFIX, "/event/%s/", event_id)
return self.client.get_json(destination, path=path, timeout=timeout)
@log_function
@@ -119,7 +121,7 @@ class TransportLayerClient(object):
# TODO: raise?
return
- path = PREFIX + "/backfill/%s/" % (room_id,)
+ path = _create_path(PREFIX, "/backfill/%s/", room_id)
args = {
"v": event_tuples,
@@ -157,9 +159,11 @@ class TransportLayerClient(object):
# generated by the json_data_callback.
json_data = transaction.get_dict()
+ path = _create_path(PREFIX, "/send/%s/", transaction.transaction_id)
+
response = yield self.client.put_json(
transaction.destination,
- path=PREFIX + "/send/%s/" % transaction.transaction_id,
+ path=path,
data=json_data,
json_data_callback=json_data_callback,
long_retries=True,
@@ -177,7 +181,7 @@ class TransportLayerClient(object):
@log_function
def make_query(self, destination, query_type, args, retry_on_dns_fail,
ignore_backoff=False):
- path = PREFIX + "/query/%s" % query_type
+ path = _create_path(PREFIX, "/query/%s", query_type)
content = yield self.client.get_json(
destination=destination,
@@ -222,7 +226,7 @@ class TransportLayerClient(object):
"make_membership_event called with membership='%s', must be one of %s" %
(membership, ",".join(valid_memberships))
)
- path = PREFIX + "/make_%s/%s/%s" % (membership, room_id, user_id)
+ path = _create_path(PREFIX, "/make_%s/%s/%s", membership, room_id, user_id)
ignore_backoff = False
retry_on_dns_fail = False
@@ -248,7 +252,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_join(self, destination, room_id, event_id, content):
- path = PREFIX + "/send_join/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -261,7 +265,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
- path = PREFIX + "/send_leave/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -280,7 +284,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_invite(self, destination, room_id, event_id, content):
- path = PREFIX + "/invite/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/invite/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
@@ -322,7 +326,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def exchange_third_party_invite(self, destination, room_id, event_dict):
- path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,)
+ path = _create_path(PREFIX, "/exchange_third_party_invite/%s", room_id,)
response = yield self.client.put_json(
destination=destination,
@@ -335,7 +339,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def get_event_auth(self, destination, room_id, event_id):
- path = PREFIX + "/event_auth/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/event_auth/%s/%s", room_id, event_id)
content = yield self.client.get_json(
destination=destination,
@@ -347,7 +351,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_query_auth(self, destination, room_id, event_id, content):
- path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+ path = _create_path(PREFIX, "/query_auth/%s/%s", room_id, event_id)
content = yield self.client.post_json(
destination=destination,
@@ -409,7 +413,7 @@ class TransportLayerClient(object):
Returns:
A dict containg the device keys.
"""
- path = PREFIX + "/user/devices/" + user_id
+ path = _create_path(PREFIX, "/user/devices/%s", user_id)
content = yield self.client.get_json(
destination=destination,
@@ -459,7 +463,7 @@ class TransportLayerClient(object):
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth, timeout):
- path = PREFIX + "/get_missing_events/%s" % (room_id,)
+ path = _create_path(PREFIX, "/get_missing_events/%s", room_id,)
content = yield self.client.post_json(
destination=destination,
@@ -479,7 +483,7 @@ class TransportLayerClient(object):
def get_group_profile(self, destination, group_id, requester_user_id):
"""Get a group profile
"""
- path = PREFIX + "/groups/%s/profile" % (group_id,)
+ path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
return self.client.get_json(
destination=destination,
@@ -498,7 +502,7 @@ class TransportLayerClient(object):
requester_user_id (str)
content (dict): The new profile of the group
"""
- path = PREFIX + "/groups/%s/profile" % (group_id,)
+ path = _create_path(PREFIX, "/groups/%s/profile", group_id,)
return self.client.post_json(
destination=destination,
@@ -512,7 +516,7 @@ class TransportLayerClient(object):
def get_group_summary(self, destination, group_id, requester_user_id):
"""Get a group summary
"""
- path = PREFIX + "/groups/%s/summary" % (group_id,)
+ path = _create_path(PREFIX, "/groups/%s/summary", group_id,)
return self.client.get_json(
destination=destination,
@@ -525,7 +529,7 @@ class TransportLayerClient(object):
def get_rooms_in_group(self, destination, group_id, requester_user_id):
"""Get all rooms in a group
"""
- path = PREFIX + "/groups/%s/rooms" % (group_id,)
+ path = _create_path(PREFIX, "/groups/%s/rooms", group_id,)
return self.client.get_json(
destination=destination,
@@ -538,7 +542,7 @@ class TransportLayerClient(object):
content):
"""Add a room to a group
"""
- path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
+ path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
return self.client.post_json(
destination=destination,
@@ -552,7 +556,10 @@ class TransportLayerClient(object):
config_key, content):
"""Update room in group
"""
- path = PREFIX + "/groups/%s/room/%s/config/%s" % (group_id, room_id, config_key,)
+ path = _create_path(
+ PREFIX, "/groups/%s/room/%s/config/%s",
+ group_id, room_id, config_key,
+ )
return self.client.post_json(
destination=destination,
@@ -565,7 +572,7 @@ class TransportLayerClient(object):
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
"""Remove a room from a group
"""
- path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
+ path = _create_path(PREFIX, "/groups/%s/room/%s", group_id, room_id,)
return self.client.delete_json(
destination=destination,
@@ -578,7 +585,7 @@ class TransportLayerClient(object):
def get_users_in_group(self, destination, group_id, requester_user_id):
"""Get users in a group
"""
- path = PREFIX + "/groups/%s/users" % (group_id,)
+ path = _create_path(PREFIX, "/groups/%s/users", group_id,)
return self.client.get_json(
destination=destination,
@@ -591,7 +598,7 @@ class TransportLayerClient(object):
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
"""Get users that have been invited to a group
"""
- path = PREFIX + "/groups/%s/invited_users" % (group_id,)
+ path = _create_path(PREFIX, "/groups/%s/invited_users", group_id,)
return self.client.get_json(
destination=destination,
@@ -604,7 +611,23 @@ class TransportLayerClient(object):
def accept_group_invite(self, destination, group_id, user_id, content):
"""Accept a group invite
"""
- path = PREFIX + "/groups/%s/users/%s/accept_invite" % (group_id, user_id)
+ path = _create_path(
+ PREFIX, "/groups/%s/users/%s/accept_invite",
+ group_id, user_id,
+ )
+
+ return self.client.post_json(
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
+ def join_group(self, destination, group_id, user_id, content):
+ """Attempts to join a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/users/%s/join", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -617,7 +640,7 @@ class TransportLayerClient(object):
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
"""Invite a user to a group
"""
- path = PREFIX + "/groups/%s/users/%s/invite" % (group_id, user_id)
+ path = _create_path(PREFIX, "/groups/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -633,7 +656,7 @@ class TransportLayerClient(object):
invited.
"""
- path = PREFIX + "/groups/local/%s/users/%s/invite" % (group_id, user_id)
+ path = _create_path(PREFIX, "/groups/local/%s/users/%s/invite", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -647,7 +670,7 @@ class TransportLayerClient(object):
user_id, content):
"""Remove a user fron a group
"""
- path = PREFIX + "/groups/%s/users/%s/remove" % (group_id, user_id)
+ path = _create_path(PREFIX, "/groups/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -664,7 +687,7 @@ class TransportLayerClient(object):
kicked from the group.
"""
- path = PREFIX + "/groups/local/%s/users/%s/remove" % (group_id, user_id)
+ path = _create_path(PREFIX, "/groups/local/%s/users/%s/remove", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -679,7 +702,7 @@ class TransportLayerClient(object):
the attestations
"""
- path = PREFIX + "/groups/%s/renew_attestation/%s" % (group_id, user_id)
+ path = _create_path(PREFIX, "/groups/%s/renew_attestation/%s", group_id, user_id)
return self.client.post_json(
destination=destination,
@@ -694,11 +717,12 @@ class TransportLayerClient(object):
"""Update a room entry in a group summary
"""
if category_id:
- path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id,
)
else:
- path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
+ path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
return self.client.post_json(
destination=destination,
@@ -714,11 +738,12 @@ class TransportLayerClient(object):
"""Delete a room entry in a group summary
"""
if category_id:
- path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
+ path = _create_path(
+ PREFIX + "/groups/%s/summary/categories/%s/rooms/%s",
group_id, category_id, room_id,
)
else:
- path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
+ path = _create_path(PREFIX, "/groups/%s/summary/rooms/%s", group_id, room_id,)
return self.client.delete_json(
destination=destination,
@@ -731,7 +756,7 @@ class TransportLayerClient(object):
def get_group_categories(self, destination, group_id, requester_user_id):
"""Get all categories in a group
"""
- path = PREFIX + "/groups/%s/categories" % (group_id,)
+ path = _create_path(PREFIX, "/groups/%s/categories", group_id,)
return self.client.get_json(
destination=destination,
@@ -744,7 +769,7 @@ class TransportLayerClient(object):
def get_group_category(self, destination, group_id, requester_user_id, category_id):
"""Get category info in a group
"""
- path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
return self.client.get_json(
destination=destination,
@@ -758,7 +783,7 @@ class TransportLayerClient(object):
content):
"""Update a category in a group
"""
- path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
return self.client.post_json(
destination=destination,
@@ -773,7 +798,7 @@ class TransportLayerClient(object):
category_id):
"""Delete a category in a group
"""
- path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
+ path = _create_path(PREFIX, "/groups/%s/categories/%s", group_id, category_id,)
return self.client.delete_json(
destination=destination,
@@ -786,7 +811,7 @@ class TransportLayerClient(object):
def get_group_roles(self, destination, group_id, requester_user_id):
"""Get all roles in a group
"""
- path = PREFIX + "/groups/%s/roles" % (group_id,)
+ path = _create_path(PREFIX, "/groups/%s/roles", group_id,)
return self.client.get_json(
destination=destination,
@@ -799,7 +824,7 @@ class TransportLayerClient(object):
def get_group_role(self, destination, group_id, requester_user_id, role_id):
"""Get a roles info
"""
- path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
return self.client.get_json(
destination=destination,
@@ -813,7 +838,7 @@ class TransportLayerClient(object):
content):
"""Update a role in a group
"""
- path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
return self.client.post_json(
destination=destination,
@@ -827,7 +852,7 @@ class TransportLayerClient(object):
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
"""Delete a role in a group
"""
- path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
+ path = _create_path(PREFIX, "/groups/%s/roles/%s", group_id, role_id,)
return self.client.delete_json(
destination=destination,
@@ -842,11 +867,12 @@ class TransportLayerClient(object):
"""Update a users entry in a group
"""
if role_id:
- path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id,
)
else:
- path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
+ path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
return self.client.post_json(
destination=destination,
@@ -857,16 +883,32 @@ class TransportLayerClient(object):
)
@log_function
+ def set_group_join_policy(self, destination, group_id, requester_user_id,
+ content):
+ """Sets the join policy for a group
+ """
+ path = _create_path(PREFIX, "/groups/%s/settings/m.join_policy", group_id,)
+
+ return self.client.put_json(
+ destination=destination,
+ path=path,
+ args={"requester_user_id": requester_user_id},
+ data=content,
+ ignore_backoff=True,
+ )
+
+ @log_function
def delete_group_summary_user(self, destination, group_id, requester_user_id,
user_id, role_id):
"""Delete a users entry in a group
"""
if role_id:
- path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
+ path = _create_path(
+ PREFIX, "/groups/%s/summary/roles/%s/users/%s",
group_id, role_id, user_id,
)
else:
- path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
+ path = _create_path(PREFIX, "/groups/%s/summary/users/%s", group_id, user_id,)
return self.client.delete_json(
destination=destination,
@@ -889,3 +931,22 @@ class TransportLayerClient(object):
data=content,
ignore_backoff=True,
)
+
+
+def _create_path(prefix, path, *args):
+ """Creates a path from the prefix, path template and args. Ensures that
+ all args are url encoded.
+
+ Example:
+
+ _create_path(PREFIX, "/event/%s/", event_id)
+
+ Args:
+ prefix (str)
+ path (str): String template for the path
+ args: ([str]): Args to insert into path. Each arg will be url encoded
+
+ Returns:
+ str
+ """
+ return prefix + path % tuple(urllib.quote(arg, "") for arg in args)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a66a6b0692..19d09f5422 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -24,7 +25,7 @@ from synapse.http.servlet import (
)
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
import functools
@@ -93,12 +94,6 @@ class Authenticator(object):
"signatures": {},
}
- if (
- self.federation_domain_whitelist is not None and
- self.server_name not in self.federation_domain_whitelist
- ):
- raise FederationDeniedError(self.server_name)
-
if content is not None:
json_request["content"] = content
@@ -137,6 +132,12 @@ class Authenticator(object):
json_request["origin"] = origin
json_request["signatures"].setdefault(origin, {})[key] = sig
+ if (
+ self.federation_domain_whitelist is not None and
+ origin not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(origin)
+
if not json_request["signatures"]:
raise NoAuthenticationError(
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
@@ -151,11 +152,18 @@ class Authenticator(object):
# alive
retry_timings = yield self.store.get_destination_retry_timings(origin)
if retry_timings and retry_timings["retry_last_ts"]:
- logger.info("Marking origin %r as up", origin)
- preserve_fn(self.store.set_destination_retry_timings)(origin, 0, 0)
+ run_in_background(self._reset_retry_timings, origin)
defer.returnValue(origin)
+ @defer.inlineCallbacks
+ def _reset_retry_timings(self, origin):
+ try:
+ logger.info("Marking origin %r as up", origin)
+ yield self.store.set_destination_retry_timings(origin, 0, 0)
+ except Exception:
+ logger.exception("Error resetting retry timings on %s", origin)
+
class BaseFederationServlet(object):
REQUIRE_AUTH = True
@@ -802,6 +810,23 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
defer.returnValue((200, new_content))
+class FederationGroupsJoinServlet(BaseFederationServlet):
+ """Attempt to join a group
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join$"
+
+ @defer.inlineCallbacks
+ def on_POST(self, origin, content, query, group_id, user_id):
+ if get_domain_from_id(user_id) != origin:
+ raise SynapseError(403, "user_id doesn't match origin")
+
+ new_content = yield self.handler.join_group(
+ group_id, user_id, content,
+ )
+
+ defer.returnValue((200, new_content))
+
+
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
"""Leave or kick a user from the group
"""
@@ -1124,6 +1149,24 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
defer.returnValue((200, resp))
+class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
+ """Sets whether a group is joinable without an invite or knock
+ """
+ PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy$"
+
+ @defer.inlineCallbacks
+ def on_PUT(self, origin, content, query, group_id):
+ requester_user_id = parse_string_from_args(query, "requester_user_id")
+ if get_domain_from_id(requester_user_id) != origin:
+ raise SynapseError(403, "requester_user_id doesn't match origin")
+
+ new_content = yield self.handler.set_group_join_policy(
+ group_id, requester_user_id, content
+ )
+
+ defer.returnValue((200, new_content))
+
+
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
FederationPullServlet,
@@ -1163,6 +1206,7 @@ GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsInvitedUsersServlet,
FederationGroupsInviteServlet,
FederationGroupsAcceptInviteServlet,
+ FederationGroupsJoinServlet,
FederationGroupsRemoveUserServlet,
FederationGroupsSummaryRoomsServlet,
FederationGroupsCategoriesServlet,
@@ -1172,6 +1216,7 @@ GROUP_SERVER_SERVLET_CLASSES = (
FederationGroupsSummaryUsersServlet,
FederationGroupsAddRoomsServlet,
FederationGroupsAddRoomsConfigServlet,
+ FederationGroupsSettingJoinPolicyServlet,
)
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 3f645acc43..01c5b8fe17 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -74,8 +74,6 @@ class Transaction(JsonEncodedObject):
"previous_ids",
"pdus",
"edus",
- "transaction_id",
- "destination",
"pdu_failures",
]
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 1fb709e6c3..6f11fa374b 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -42,7 +42,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from signedjson.sign import sign_json
@@ -165,31 +165,35 @@ class GroupAttestionRenewer(object):
@defer.inlineCallbacks
def _renew_attestation(group_id, user_id):
- if not self.is_mine_id(group_id):
- destination = get_domain_from_id(group_id)
- elif not self.is_mine_id(user_id):
- destination = get_domain_from_id(user_id)
- else:
- logger.warn(
- "Incorrectly trying to do attestations for user: %r in %r",
- user_id, group_id,
+ try:
+ if not self.is_mine_id(group_id):
+ destination = get_domain_from_id(group_id)
+ elif not self.is_mine_id(user_id):
+ destination = get_domain_from_id(user_id)
+ else:
+ logger.warn(
+ "Incorrectly trying to do attestations for user: %r in %r",
+ user_id, group_id,
+ )
+ yield self.store.remove_attestation_renewal(group_id, user_id)
+ return
+
+ attestation = self.attestations.create_attestation(group_id, user_id)
+
+ yield self.transport_client.renew_group_attestation(
+ destination, group_id, user_id,
+ content={"attestation": attestation},
)
- yield self.store.remove_attestation_renewal(group_id, user_id)
- return
-
- attestation = self.attestations.create_attestation(group_id, user_id)
- yield self.transport_client.renew_group_attestation(
- destination, group_id, user_id,
- content={"attestation": attestation},
- )
-
- yield self.store.update_attestation_renewal(
- group_id, user_id, attestation
- )
+ yield self.store.update_attestation_renewal(
+ group_id, user_id, attestation
+ )
+ except Exception:
+ logger.exception("Error renewing attestation of %r in %r",
+ user_id, group_id)
for row in rows:
group_id = row["group_id"]
user_id = row["user_id"]
- preserve_fn(_renew_attestation)(group_id, user_id)
+ run_in_background(_renew_attestation, group_id, user_id)
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 0b995aed70..62d20ad130 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,6 +20,8 @@ from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from twisted.internet import defer
+from six import string_types
+
logger = logging.getLogger(__name__)
@@ -206,6 +209,28 @@ class GroupsServerHandler(object):
defer.returnValue({})
@defer.inlineCallbacks
+ def set_group_join_policy(self, group_id, requester_user_id, content):
+ """Sets the group join policy.
+
+ Currently supported policies are:
+ - "invite": an invite must be received and accepted in order to join.
+ - "open": anyone can join.
+ """
+ yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
+ )
+
+ join_policy = _parse_join_policy_from_contents(content)
+ if join_policy is None:
+ raise SynapseError(
+ 400, "No value specified for 'm.join_policy'"
+ )
+
+ yield self.store.set_group_join_policy(group_id, join_policy=join_policy)
+
+ defer.returnValue({})
+
+ @defer.inlineCallbacks
def get_group_categories(self, group_id, requester_user_id):
"""Get all categories in a group (as seen by user)
"""
@@ -381,9 +406,16 @@ class GroupsServerHandler(object):
yield self.check_group_is_ours(group_id, requester_user_id)
- group_description = yield self.store.get_group(group_id)
+ group = yield self.store.get_group(group_id)
+
+ if group:
+ cols = [
+ "name", "short_description", "long_description",
+ "avatar_url", "is_public",
+ ]
+ group_description = {key: group[key] for key in cols}
+ group_description["is_openly_joinable"] = group["join_policy"] == "open"
- if group_description:
defer.returnValue(group_description)
else:
raise SynapseError(404, "Unknown group")
@@ -401,7 +433,7 @@ class GroupsServerHandler(object):
"long_description"):
if keyname in content:
value = content[keyname]
- if not isinstance(value, basestring):
+ if not isinstance(value, string_types):
raise SynapseError(400, "%r value is not a string" % (keyname,))
profile[keyname] = value
@@ -655,30 +687,21 @@ class GroupsServerHandler(object):
raise SynapseError(502, "Unknown state returned by HS")
@defer.inlineCallbacks
- def accept_invite(self, group_id, requester_user_id, content):
- """User tries to accept an invite to the group.
+ def _add_user(self, group_id, user_id, content):
+ """Add a user to a group based on a content dict.
- This is different from them asking to join, and so should error if no
- invite exists (and they're not a member of the group)
+ See accept_invite, join_group.
"""
-
- yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
- is_invited = yield self.store.is_user_invited_to_local_group(
- group_id, requester_user_id,
- )
- if not is_invited:
- raise SynapseError(403, "User not invited to group")
-
- if not self.hs.is_mine_id(requester_user_id):
+ if not self.hs.is_mine_id(user_id):
local_attestation = self.attestations.create_attestation(
- group_id, requester_user_id,
+ group_id, user_id,
)
+
remote_attestation = content["attestation"]
yield self.attestations.verify_attestation(
remote_attestation,
- user_id=requester_user_id,
+ user_id=user_id,
group_id=group_id,
)
else:
@@ -688,13 +711,53 @@ class GroupsServerHandler(object):
is_public = _parse_visibility_from_contents(content)
yield self.store.add_user_to_group(
- group_id, requester_user_id,
+ group_id, user_id,
is_admin=False,
is_public=is_public,
local_attestation=local_attestation,
remote_attestation=remote_attestation,
)
+ defer.returnValue(local_attestation)
+
+ @defer.inlineCallbacks
+ def accept_invite(self, group_id, requester_user_id, content):
+ """User tries to accept an invite to the group.
+
+ This is different from them asking to join, and so should error if no
+ invite exists (and they're not a member of the group)
+ """
+
+ yield self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
+
+ is_invited = yield self.store.is_user_invited_to_local_group(
+ group_id, requester_user_id,
+ )
+ if not is_invited:
+ raise SynapseError(403, "User not invited to group")
+
+ local_attestation = yield self._add_user(group_id, requester_user_id, content)
+
+ defer.returnValue({
+ "state": "join",
+ "attestation": local_attestation,
+ })
+
+ @defer.inlineCallbacks
+ def join_group(self, group_id, requester_user_id, content):
+ """User tries to join the group.
+
+ This will error if the group requires an invite/knock to join
+ """
+
+ group_info = yield self.check_group_is_ours(
+ group_id, requester_user_id, and_exists=True
+ )
+ if group_info['join_policy'] != "open":
+ raise SynapseError(403, "Group is not publicly joinable")
+
+ local_attestation = yield self._add_user(group_id, requester_user_id, content)
+
defer.returnValue({
"state": "join",
"attestation": local_attestation,
@@ -835,6 +898,31 @@ class GroupsServerHandler(object):
})
+def _parse_join_policy_from_contents(content):
+ """Given a content for a request, return the specified join policy or None
+ """
+
+ join_policy_dict = content.get("m.join_policy")
+ if join_policy_dict:
+ return _parse_join_policy_dict(join_policy_dict)
+ else:
+ return None
+
+
+def _parse_join_policy_dict(join_policy_dict):
+ """Given a dict for the "m.join_policy" config return the join policy specified
+ """
+ join_policy_type = join_policy_dict.get("type")
+ if not join_policy_type:
+ return "invite"
+
+ if join_policy_type not in ("invite", "open"):
+ raise SynapseError(
+ 400, "Synapse only supports 'invite'/'open' join rule"
+ )
+ return join_policy_type
+
+
def _parse_visibility_from_contents(content):
"""Given a content for a request parse out whether the entity should be
public or not
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 8f8fd82eb0..d358842b3e 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -14,9 +14,7 @@
# limitations under the License.
from .register import RegistrationHandler
-from .room import (
- RoomCreationHandler, RoomContextHandler,
-)
+from .room import RoomContextHandler
from .message import MessageHandler
from .federation import FederationHandler
from .directory import DirectoryHandler
@@ -47,7 +45,6 @@ class Handlers(object):
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
self.message_handler = MessageHandler(hs)
- self.room_creation_handler = RoomCreationHandler(hs)
self.federation_handler = FederationHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 3dd3fa2a27..d9f35a5dba 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -18,15 +18,16 @@ from twisted.internet import defer
import synapse
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import (
+ make_deferred_yieldable, run_in_background,
+)
+from prometheus_client import Counter
import logging
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-events_processed_counter = metrics.register_counter("events_processed")
+events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
def log_failure(failure):
@@ -84,11 +85,16 @@ class ApplicationServicesHandler(object):
if not events:
break
+ events_by_room = {}
for event in events:
+ events_by_room.setdefault(event.room_id, []).append(event)
+
+ @defer.inlineCallbacks
+ def handle_event(event):
# Gather interested services
services = yield self._get_services_for_event(event)
if len(services) == 0:
- continue # no services need notifying
+ return # no services need notifying
# Do we know this user exists? If not, poke the user
# query API for all services which match that user regex.
@@ -104,13 +110,32 @@ class ApplicationServicesHandler(object):
# Fork off pushes to these services
for service in services:
- preserve_fn(self.scheduler.submit_event_for_as)(
- service, event
- )
+ self.scheduler.submit_event_for_as(service, event)
- events_processed_counter.inc_by(len(events))
+ @defer.inlineCallbacks
+ def handle_room_events(events):
+ for event in events:
+ yield handle_event(event)
+
+ yield make_deferred_yieldable(defer.gatherResults([
+ run_in_background(handle_room_events, evs)
+ for evs in events_by_room.itervalues()
+ ], consumeErrors=True))
yield self.store.set_appservice_last_pos(upper_bound)
+
+ now = self.clock.time_msec()
+ ts = yield self.store.get_received_ts(events[-1].event_id)
+
+ synapse.metrics.event_processing_positions.labels(
+ "appservice_sender").set(upper_bound)
+
+ events_processed_counter.inc(len(events))
+
+ synapse.metrics.event_processing_lag.labels(
+ "appservice_sender").set(now - ts)
+ synapse.metrics.event_processing_last_ts.labels(
+ "appservice_sender").set(ts)
finally:
self.is_processing = False
@@ -167,7 +192,10 @@ class ApplicationServicesHandler(object):
services = yield self._get_services_for_3pn(protocol)
results = yield make_deferred_yieldable(defer.DeferredList([
- preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
+ run_in_background(
+ self.appservice_api.query_3pe,
+ service, kind, protocol, fields,
+ )
for service in services
], consumeErrors=True))
@@ -228,11 +256,15 @@ class ApplicationServicesHandler(object):
event based on the service regex.
"""
services = self.store.get_app_services()
- interested_list = [
- s for s in services if (
- yield s.is_interested(event, self.store)
- )
- ]
+
+ # we can't use a list comprehension here. Since python 3, list
+ # comprehensions use a generator internally. This means you can't yield
+ # inside of a list comprehension anymore.
+ interested_list = []
+ for s in services:
+ if (yield s.is_interested(event, self.store)):
+ interested_list.append(s)
+
defer.returnValue(interested_list)
def _get_services_for_user(self, user_id):
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index b1d3814909..c5e92f6214 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017, 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,9 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from ._base import BaseHandler
+from synapse.types import UserID, create_requester
+from synapse.util.logcontext import run_in_background
import logging
@@ -27,6 +29,15 @@ class DeactivateAccountHandler(BaseHandler):
super(DeactivateAccountHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
+ self._room_member_handler = hs.get_room_member_handler()
+ self.user_directory_handler = hs.get_user_directory_handler()
+
+ # Flag that indicates whether the process to part users from rooms is running
+ self._user_parter_running = False
+
+ # Start the user parter loop so it can resume parting users from rooms where
+ # it left off (if it has work left to do).
+ reactor.callWhenRunning(self._start_user_parting)
@defer.inlineCallbacks
def deactivate_account(self, user_id):
@@ -50,3 +61,73 @@ class DeactivateAccountHandler(BaseHandler):
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
+
+ # Add the user to a table of users pending deactivation (ie.
+ # removal from all the rooms they're a member of)
+ yield self.store.add_user_pending_deactivation(user_id)
+
+ # delete from user directory
+ yield self.user_directory_handler.handle_user_deactivated(user_id)
+
+ # Now start the process that goes through that list and
+ # parts users from rooms (if it isn't already running)
+ self._start_user_parting()
+
+ def _start_user_parting(self):
+ """
+ Start the process that goes through the table of users
+ pending deactivation, if it isn't already running.
+
+ Returns:
+ None
+ """
+ if not self._user_parter_running:
+ run_in_background(self._user_parter_loop)
+
+ @defer.inlineCallbacks
+ def _user_parter_loop(self):
+ """Loop that parts deactivated users from rooms
+
+ Returns:
+ None
+ """
+ self._user_parter_running = True
+ logger.info("Starting user parter")
+ try:
+ while True:
+ user_id = yield self.store.get_user_pending_deactivation()
+ if user_id is None:
+ break
+ logger.info("User parter parting %r", user_id)
+ yield self._part_user(user_id)
+ yield self.store.del_user_pending_deactivation(user_id)
+ logger.info("User parter finished parting %r", user_id)
+ logger.info("User parter finished: stopping")
+ finally:
+ self._user_parter_running = False
+
+ @defer.inlineCallbacks
+ def _part_user(self, user_id):
+ """Causes the given user_id to leave all the rooms they're joined to
+
+ Returns:
+ None
+ """
+ user = UserID.from_string(user_id)
+
+ rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+ for room_id in rooms_for_user:
+ logger.info("User parter parting %r from %r", user_id, room_id)
+ try:
+ yield self._room_member_handler.update_membership(
+ create_requester(user),
+ user,
+ room_id,
+ "leave",
+ ratelimit=False,
+ )
+ except Exception:
+ logger.exception(
+ "Failed to part user %r from room %r: ignoring and continuing",
+ user_id, room_id,
+ )
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 40f3d24678..31bd0e60c6 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -26,6 +26,8 @@ from ._base import BaseHandler
import logging
+from six import itervalues, iteritems
+
logger = logging.getLogger(__name__)
@@ -155,7 +157,7 @@ class DeviceHandler(BaseHandler):
try:
yield self.store.delete_device(user_id, device_id)
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
# no match
pass
@@ -204,7 +206,7 @@ class DeviceHandler(BaseHandler):
try:
yield self.store.delete_devices(user_id, device_ids)
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
# no match
pass
@@ -243,7 +245,7 @@ class DeviceHandler(BaseHandler):
new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
- except errors.StoreError, e:
+ except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
else:
@@ -318,7 +320,7 @@ class DeviceHandler(BaseHandler):
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -338,7 +340,7 @@ class DeviceHandler(BaseHandler):
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -354,10 +356,10 @@ class DeviceHandler(BaseHandler):
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
- for state_dict in prev_state_ids.itervalues():
+ for state_dict in itervalues(prev_state_ids):
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -367,14 +369,14 @@ class DeviceHandler(BaseHandler):
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
- for state_dict in prev_state_ids.itervalues():
+ for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
if state_key != user_id:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 31b1ece13e..8a2d177539 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,17 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import ujson as json
+import simplejson as json
import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer
+from six import iteritems
from synapse.api.errors import (
SynapseError, CodeMessageException, FederationDeniedError,
)
from synapse.types import get_domain_from_id, UserID
-from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -91,7 +93,7 @@ class E2eKeysHandler(object):
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
- for user_id, device_ids in remote_queries.iteritems():
+ for user_id, device_ids in iteritems(remote_queries):
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
@@ -102,9 +104,9 @@ class E2eKeysHandler(object):
query_list
)
)
- for user_id, devices in remote_results.iteritems():
+ for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
- for device_id, device in devices.iteritems():
+ for device_id, device in iteritems(devices):
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
@@ -134,28 +136,13 @@ class E2eKeysHandler(object):
if user_id in destination_query:
results[user_id] = keys
- except CodeMessageException as e:
- failures[destination] = {
- "status": e.code, "message": e.message
- }
- except NotRetryingDestination as e:
- failures[destination] = {
- "status": 503, "message": "Not ready for retry",
- }
- except FederationDeniedError as e:
- failures[destination] = {
- "status": 403, "message": "Federation Denied",
- }
except Exception as e:
- # include ConnectionRefused and other errors
- failures[destination] = {
- "status": 503, "message": e.message
- }
+ failures[destination] = _exception_to_failure(e)
yield make_deferred_yieldable(defer.gatherResults([
- preserve_fn(do_remote_query)(destination)
+ run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
- ]))
+ ], consumeErrors=True))
defer.returnValue({
"device_keys": results, "failures": failures,
@@ -252,32 +239,21 @@ class E2eKeysHandler(object):
for user_id, keys in remote_result["one_time_keys"].items():
if user_id in device_keys:
json_result[user_id] = keys
- except CodeMessageException as e:
- failures[destination] = {
- "status": e.code, "message": e.message
- }
- except NotRetryingDestination as e:
- failures[destination] = {
- "status": 503, "message": "Not ready for retry",
- }
except Exception as e:
- # include ConnectionRefused and other errors
- failures[destination] = {
- "status": 503, "message": e.message
- }
+ failures[destination] = _exception_to_failure(e)
yield make_deferred_yieldable(defer.gatherResults([
- preserve_fn(claim_client_keys)(destination)
+ run_in_background(claim_client_keys, destination)
for destination in remote_queries
- ]))
+ ], consumeErrors=True))
logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in json_result.iteritems()
- for device_id, device_keys in user_keys.iteritems()
- for key_id, _ in device_keys.iteritems()
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
)),
)
@@ -362,6 +338,31 @@ class E2eKeysHandler(object):
)
+def _exception_to_failure(e):
+ if isinstance(e, CodeMessageException):
+ return {
+ "status": e.code, "message": e.message,
+ }
+
+ if isinstance(e, NotRetryingDestination):
+ return {
+ "status": 503, "message": "Not ready for retry",
+ }
+
+ if isinstance(e, FederationDeniedError):
+ return {
+ "status": 403, "message": "Federation Denied",
+ }
+
+ # include ConnectionRefused and other errors
+ #
+ # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
+ # give a string for e.message, which simplejson then fails to serialize.
+ return {
+ "status": 503, "message": str(e.message),
+ }
+
+
def _one_time_keys_match(old_key_json, new_key):
old_key = json.loads(old_key_json)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d3685fb12a..8bc642675f 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -48,6 +48,7 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
+ self._server_notices_sender = hs.get_server_notices_sender()
@defer.inlineCallbacks
@log_function
@@ -58,6 +59,10 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down.
"""
+
+ # send any outstanding server notices to the user.
+ yield self._server_notices_sender.on_user_syncing(auth_user_id)
+
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 080aca3d71..87c0615820 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -15,8 +15,17 @@
# limitations under the License.
"""Contains handlers for federation events."""
+
+import itertools
+import logging
+import sys
+
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
+import six
+from six.moves import http_client
+from six import iteritems
+from twisted.internet import defer
from unpaddedbase64 import decode_base64
from ._base import BaseHandler
@@ -43,10 +52,6 @@ from synapse.util.retryutils import NotRetryingDestination
from synapse.util.distributor import user_joined_room
-from twisted.internet import defer
-
-import itertools
-import logging
logger = logging.getLogger(__name__)
@@ -77,6 +82,7 @@ class FederationHandler(BaseHandler):
self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
+ self._server_notices_mxid = hs.config.server_notices_mxid
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@@ -115,6 +121,19 @@ class FederationHandler(BaseHandler):
logger.debug("Already seen pdu %s", pdu.event_id)
return
+ # do some initial sanity-checking of the event. In particular, make
+ # sure it doesn't have hundreds of prev_events or auth_events, which
+ # could cause a huge state resolution or cascade of event fetches.
+ try:
+ self._sanity_check_event(pdu)
+ except SynapseError as err:
+ raise FederationError(
+ "ERROR",
+ err.code,
+ err.msg,
+ affected=pdu.event_id,
+ )
+
# If we are currently in the process of joining this room, then we
# queue up events for later processing.
if pdu.room_id in self.room_queues:
@@ -149,10 +168,6 @@ class FederationHandler(BaseHandler):
auth_chain = []
- have_seen = yield self.store.have_events(
- [ev for ev, _ in pdu.prev_events]
- )
-
fetch_state = False
# Get missing pdus if necessary.
@@ -168,7 +183,7 @@ class FederationHandler(BaseHandler):
)
prevs = {e_id for e_id, _ in pdu.prev_events}
- seen = set(have_seen.keys())
+ seen = yield self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@@ -196,8 +211,7 @@ class FederationHandler(BaseHandler):
# Update the set of things we've seen after trying to
# fetch the missing stuff
- have_seen = yield self.store.have_events(prevs)
- seen = set(have_seen.iterkeys())
+ seen = yield self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -248,8 +262,7 @@ class FederationHandler(BaseHandler):
min_depth (int): Minimum depth of events to return.
"""
# We recalculate seen, since it may have changed.
- have_seen = yield self.store.have_events(prevs)
- seen = set(have_seen.keys())
+ seen = yield self.store.have_seen_events(prevs)
if not prevs - seen:
return
@@ -361,9 +374,7 @@ class FederationHandler(BaseHandler):
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
- seen_ids = set(
- (yield self.store.have_events(event_ids)).keys()
- )
+ seen_ids = yield self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
@@ -469,18 +480,18 @@ class FederationHandler(BaseHandler):
# to get all state ids that we're interested in.
event_map = yield self.store.get_events([
e_id
- for key_to_eid in event_to_state_ids.values()
- for key, e_id in key_to_eid.items()
+ for key_to_eid in event_to_state_ids.itervalues()
+ for key, e_id in key_to_eid.iteritems()
if key[0] != EventTypes.Member or check_match(key[1])
])
event_to_state = {
e_id: {
key: event_map[inner_e_id]
- for key, inner_e_id in key_to_eid.items()
+ for key, inner_e_id in key_to_eid.iteritems()
if inner_e_id in event_map
}
- for e_id, key_to_eid in event_to_state_ids.items()
+ for e_id, key_to_eid in event_to_state_ids.iteritems()
}
def redact_disallowed(event, state):
@@ -495,7 +506,7 @@ class FederationHandler(BaseHandler):
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
- for ev in state.values():
+ for ev in state.itervalues():
if ev.type != EventTypes.Member:
continue
try:
@@ -527,9 +538,16 @@ class FederationHandler(BaseHandler):
def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
- This will attempt to get more events from the remote. This may return
- be successfull and still return no events if the other side has no new
- events to offer.
+ This will attempt to get more events from the remote. If the other side
+ has no new events to offer, this will return an empty list.
+
+ As the events are received, we check their signatures, and also do some
+ sanity-checking on them. If any of the backfilled events are invalid,
+ this method throws a SynapseError.
+
+ TODO: make this more useful to distinguish failures of the remote
+ server from invalid events (there is probably no point in trying to
+ re-fetch invalid events from every other HS in the room.)
"""
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
@@ -541,6 +559,16 @@ class FederationHandler(BaseHandler):
extremities=extremities,
)
+ # ideally we'd sanity check the events here for excess prev_events etc,
+ # but it's hard to reject events at this point without completely
+ # breaking backfill in the same way that it is currently broken by
+ # events whose signature we cannot verify (#3121).
+ #
+ # So for now we accept the events anyway. #3124 tracks this.
+ #
+ # for ev in events:
+ # self._sanity_check_event(ev)
+
# Don't bother processing events we already have.
seen_events = yield self.store.have_events_in_timeline(
set(e.event_id for e in events)
@@ -613,7 +641,8 @@ class FederationHandler(BaseHandler):
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- logcontext.preserve_fn(self.replication_layer.get_pdu)(
+ logcontext.run_in_background(
+ self.replication_layer.get_pdu,
[dest],
event_id,
outlier=True,
@@ -633,7 +662,7 @@ class FederationHandler(BaseHandler):
failed_to_fetch = missing_auth - set(auth_events)
- seen_events = yield self.store.have_events(
+ seen_events = yield self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)
@@ -723,9 +752,19 @@ class FederationHandler(BaseHandler):
curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
+ """Get joined domains from state
+
+ Args:
+ state (dict[tuple, FrozenEvent]): State map from type/state
+ key to event.
+
+ Returns:
+ list[tuple[str, int]]: Returns a list of servers with the
+ lowest depth of their joins. Sorted by lowest depth first.
+ """
joined_users = [
(state_key, int(event.depth))
- for (e_type, state_key), event in state.items()
+ for (e_type, state_key), event in state.iteritems()
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
@@ -742,7 +781,7 @@ class FederationHandler(BaseHandler):
except Exception:
pass
- return sorted(joined_domains.items(), key=lambda d: d[1])
+ return sorted(joined_domains.iteritems(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state)
@@ -759,7 +798,7 @@ class FederationHandler(BaseHandler):
yield self.backfill(
dom, room_id,
limit=100,
- extremities=[e for e in extremities.keys()]
+ extremities=extremities,
)
# If this succeeded then we probably already have the
# appropriate stuff.
@@ -805,7 +844,7 @@ class FederationHandler(BaseHandler):
tried_domains = set(likely_domains)
tried_domains.add(self.server_name)
- event_ids = list(extremities.keys())
+ event_ids = list(extremities.iterkeys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn(
@@ -815,34 +854,69 @@ class FederationHandler(BaseHandler):
[resolve(room_id, [e]) for e in event_ids],
consumeErrors=True,
))
+
+ # dict[str, dict[tuple, str]], a map from event_id to state map of
+ # event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
- [e_id for ids in states.values() for e_id in ids],
+ [e_id for ids in states.itervalues() for e_id in ids.itervalues()],
get_prev_content=False
)
states = {
key: {
k: state_map[e_id]
- for k, e_id in state_dict.items()
+ for k, e_id in state_dict.iteritems()
if e_id in state_map
- } for key, state_dict in states.items()
+ } for key, state_dict in states.iteritems()
}
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([
- dom for dom in likely_domains
+ dom for dom, _ in likely_domains
if dom not in tried_domains
])
if success:
defer.returnValue(True)
- tried_domains.update(likely_domains)
+ tried_domains.update(dom for dom, _ in likely_domains)
defer.returnValue(False)
+ def _sanity_check_event(self, ev):
+ """
+ Do some early sanity checks of a received event
+
+ In particular, checks it doesn't have an excessive number of
+ prev_events or auth_events, which could cause a huge state resolution
+ or cascade of event fetches.
+
+ Args:
+ ev (synapse.events.EventBase): event to be checked
+
+ Returns: None
+
+ Raises:
+ SynapseError if the event does not pass muster
+ """
+ if len(ev.prev_events) > 20:
+ logger.warn("Rejecting event %s which has %i prev_events",
+ ev.event_id, len(ev.prev_events))
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Too many prev_events",
+ )
+
+ if len(ev.auth_events) > 10:
+ logger.warn("Rejecting event %s which has %i auth_events",
+ ev.event_id, len(ev.auth_events))
+ raise SynapseError(
+ http_client.BAD_REQUEST,
+ "Too many auth_events",
+ )
+
@defer.inlineCallbacks
def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing.
@@ -967,7 +1041,7 @@ class FederationHandler(BaseHandler):
# lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it.
- logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
+ logcontext.run_in_background(self._handle_queued_pdus, room_queue)
defer.returnValue(True)
@@ -1121,6 +1195,13 @@ class FederationHandler(BaseHandler):
if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server")
+ # block any attempts to invite the server notices mxid
+ if event.state_key == self._server_notices_mxid:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "Cannot invite this user",
+ )
+
event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
@@ -1308,7 +1389,7 @@ class FederationHandler(BaseHandler):
)
if state_groups:
- _, state = state_groups.items().pop()
+ _, state = list(iteritems(state_groups)).pop()
results = {
(e.type, e.state_key): e for e in state
}
@@ -1457,18 +1538,21 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
except: # noqa: E722, as we reraise the exception this is fine.
- # Ensure that we actually remove the entries in the push actions
- # staging area
- logcontext.preserve_fn(
- self.store.remove_push_actions_from_staging
- )(event.event_id)
- raise
+ tp, value, tb = sys.exc_info()
+
+ logcontext.run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
+
+ six.reraise(tp, value, tb)
if not backfilled:
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
- event_stream_id, max_stream_id
+ logcontext.run_in_background(
+ self.pusher_pool.on_new_notifications,
+ event_stream_id, max_stream_id,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1482,7 +1566,8 @@ class FederationHandler(BaseHandler):
"""
contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- logcontext.preserve_fn(self._prep_event)(
+ logcontext.run_in_background(
+ self._prep_event,
origin,
ev_info["event"],
state=ev_info.get("state"),
@@ -1736,7 +1821,8 @@ class FederationHandler(BaseHandler):
event_key = None
if event_auth_events - current_state:
- have_events = yield self.store.have_events(
+ # TODO: can we use store.have_seen_events here instead?
+ have_events = yield self.store.get_seen_events_with_rejections(
event_auth_events - current_state
)
else:
@@ -1759,12 +1845,12 @@ class FederationHandler(BaseHandler):
origin, event.room_id, event.event_id
)
- seen_remotes = yield self.store.have_events(
+ seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in remote_auth_chain]
)
for e in remote_auth_chain:
- if e.event_id in seen_remotes.keys():
+ if e.event_id in seen_remotes:
continue
if e.event_id == event.event_id:
@@ -1791,7 +1877,7 @@ class FederationHandler(BaseHandler):
except AuthError:
pass
- have_events = yield self.store.have_events(
+ have_events = yield self.store.get_seen_events_with_rejections(
[e_id for e_id, _ in event.auth_events]
)
seen_events = set(have_events.keys())
@@ -1810,7 +1896,8 @@ class FederationHandler(BaseHandler):
different_events = yield logcontext.make_deferred_yieldable(
defer.gatherResults([
- logcontext.preserve_fn(self.store.get_event)(
+ logcontext.run_in_background(
+ self.store.get_event,
d,
allow_none=True,
allow_rejected=False,
@@ -1876,13 +1963,13 @@ class FederationHandler(BaseHandler):
local_auth_chain,
)
- seen_remotes = yield self.store.have_events(
+ seen_remotes = yield self.store.have_seen_events(
[e.event_id for e in result["auth_chain"]]
)
# 3. Process any remote auth chain events we haven't seen.
for ev in result["auth_chain"]:
- if ev.event_id in seen_remotes.keys():
+ if ev.event_id in seen_remotes:
continue
if ev.event_id == event.event_id:
@@ -1948,7 +2035,7 @@ class FederationHandler(BaseHandler):
this will not be included in the current_state in the context.
"""
state_updates = {
- k: a.event_id for k, a in auth_events.iteritems()
+ k: a.event_id for k, a in iteritems(auth_events)
if k != event_key
}
context.current_state_ids = dict(context.current_state_ids)
@@ -1958,7 +2045,7 @@ class FederationHandler(BaseHandler):
context.delta_ids.update(state_updates)
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({
- k: a.event_id for k, a in auth_events.iteritems()
+ k: a.event_id for k, a in iteritems(auth_events)
})
context.state_group = yield self.store.store_state_group(
event.event_id,
@@ -2010,7 +2097,7 @@ class FederationHandler(BaseHandler):
def get_next(it, opt=None):
try:
- return it.next()
+ return next(it)
except Exception:
return opt
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index e4d0cc8b02..dcae083734 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
# limitations under the License.
from twisted.internet import defer
+from six import iteritems
from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id
@@ -90,6 +92,8 @@ class GroupsLocalHandler(object):
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
+ set_group_join_policy = _create_rerouter("set_group_join_policy")
+
@defer.inlineCallbacks
def get_group_summary(self, group_id, requester_user_id):
"""Get the group summary for a group.
@@ -226,7 +230,45 @@ class GroupsLocalHandler(object):
def join_group(self, group_id, user_id, content):
"""Request to join a group
"""
- raise NotImplementedError() # TODO
+ if self.is_mine_id(group_id):
+ yield self.groups_server_handler.join_group(
+ group_id, user_id, content
+ )
+ local_attestation = None
+ remote_attestation = None
+ else:
+ local_attestation = self.attestations.create_attestation(group_id, user_id)
+ content["attestation"] = local_attestation
+
+ res = yield self.transport_client.join_group(
+ get_domain_from_id(group_id), group_id, user_id, content,
+ )
+
+ remote_attestation = res["attestation"]
+
+ yield self.attestations.verify_attestation(
+ remote_attestation,
+ group_id=group_id,
+ user_id=user_id,
+ server_name=get_domain_from_id(group_id),
+ )
+
+ # TODO: Check that the group is public and we're being added publically
+ is_publicised = content.get("publicise", False)
+
+ token = yield self.store.register_user_group_membership(
+ group_id, user_id,
+ membership="join",
+ is_admin=False,
+ local_attestation=local_attestation,
+ remote_attestation=remote_attestation,
+ is_publicised=is_publicised,
+ )
+ self.notifier.on_new_event(
+ "groups_key", token, users=[user_id],
+ )
+
+ defer.returnValue({})
@defer.inlineCallbacks
def accept_invite(self, group_id, user_id, content):
@@ -408,7 +450,7 @@ class GroupsLocalHandler(object):
results = {}
failed_results = []
- for destination, dest_user_ids in destinations.iteritems():
+ for destination, dest_user_ids in iteritems(destinations):
try:
r = yield self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids),
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9efcdff1d6..91a0898860 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -15,6 +15,11 @@
# limitations under the License.
"""Utilities for interacting with Identity Servers"""
+
+import logging
+
+import simplejson as json
+
from twisted.internet import defer
from synapse.api.errors import (
@@ -24,9 +29,6 @@ from ._base import BaseHandler
from synapse.util.async import run_on_reactor
from synapse.api.errors import SynapseError, Codes
-import json
-import logging
-
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index c5267b4b84..71af86fe21 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -27,7 +27,7 @@ from synapse.types import (
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -166,7 +166,8 @@ class InitialSyncHandler(BaseHandler):
(messages, token), current_state = yield make_deferred_yieldable(
defer.gatherResults(
[
- preserve_fn(self.store.get_recent_events_for_room)(
+ run_in_background(
+ self.store.get_recent_events_for_room,
event.room_id,
limit=limit,
end_token=room_end_token,
@@ -180,8 +181,8 @@ class InitialSyncHandler(BaseHandler):
self.store, user_id, messages
)
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
+ start_token = now_token.copy_and_replace("room_key", token)
+ end_token = now_token.copy_and_replace("room_key", room_end_token)
time_now = self.clock.time_msec()
d["messages"] = {
@@ -324,8 +325,8 @@ class InitialSyncHandler(BaseHandler):
self.store, user_id, messages, is_peeking=is_peeking
)
- start_token = StreamToken.START.copy_and_replace("room_key", token[0])
- end_token = StreamToken.START.copy_and_replace("room_key", token[1])
+ start_token = StreamToken.START.copy_and_replace("room_key", token)
+ end_token = StreamToken.START.copy_and_replace("room_key", stream_token)
time_now = self.clock.time_msec()
@@ -391,9 +392,10 @@ class InitialSyncHandler(BaseHandler):
presence, receipts, (messages, token) = yield defer.gatherResults(
[
- preserve_fn(get_presence)(),
- preserve_fn(get_receipts)(),
- preserve_fn(self.store.get_recent_events_for_room)(
+ run_in_background(get_presence),
+ run_in_background(get_receipts),
+ run_in_background(
+ self.store.get_recent_events_for_room,
room_id,
limit=limit,
end_token=now_token.room_key,
@@ -406,8 +408,8 @@ class InitialSyncHandler(BaseHandler):
self.store, user_id, messages, is_peeking=is_peeking,
)
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
+ start_token = now_token.copy_and_replace("room_key", token)
+ end_token = now_token
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4f97c8db79..81cff0870e 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -13,11 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import simplejson
+import sys
+
+from canonicaljson import encode_canonical_json
+import six
+from six import string_types, itervalues, iteritems
from twisted.internet import defer, reactor
+from twisted.internet.defer import succeed
from twisted.python.failure import Failure
-from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.constants import EventTypes, Membership, MAX_DEPTH
+from synapse.api.errors import (
+ AuthError, Codes, SynapseError,
+ ConsentNotGivenError,
+)
+from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
@@ -25,21 +37,15 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken,
)
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
-from synapse.util.logcontext import preserve_fn, run_in_background
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
-from synapse.util.frozenutils import unfreeze
+from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.stringutils import random_string
from synapse.visibility import filter_events_for_client
from synapse.replication.http.send_event import send_event_to_master
from ._base import BaseHandler
-from canonicaljson import encode_canonical_json
-
-import logging
-import random
-import ujson
-
logger = logging.getLogger(__name__)
@@ -86,14 +92,14 @@ class MessageHandler(BaseHandler):
# map from purge id to PurgeStatus
self._purges_by_id = {}
- def start_purge_history(self, room_id, topological_ordering,
+ def start_purge_history(self, room_id, token,
delete_local_events=False):
"""Start off a history purge on a room.
Args:
room_id (str): The room to purge from
- topological_ordering (int): minimum topo ordering to preserve
+ token (str): topological token to delete events before
delete_local_events (bool): True to delete local events as well as
remote ones
@@ -115,19 +121,19 @@ class MessageHandler(BaseHandler):
self._purges_by_id[purge_id] = PurgeStatus()
run_in_background(
self._purge_history,
- purge_id, room_id, topological_ordering, delete_local_events,
+ purge_id, room_id, token, delete_local_events,
)
return purge_id
@defer.inlineCallbacks
- def _purge_history(self, purge_id, room_id, topological_ordering,
+ def _purge_history(self, purge_id, room_id, token,
delete_local_events):
"""Carry out a history purge on a room.
Args:
purge_id (str): The id for this purge
room_id (str): The room to purge from
- topological_ordering (int): minimum topo ordering to preserve
+ token (str): topological token to delete events before
delete_local_events (bool): True to delete local events as well as
remote ones
@@ -138,7 +144,7 @@ class MessageHandler(BaseHandler):
try:
with (yield self.pagination_lock.write(room_id)):
yield self.store.purge_history(
- room_id, topological_ordering, delete_local_events,
+ room_id, token, delete_local_events,
)
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
@@ -397,7 +403,7 @@ class MessageHandler(BaseHandler):
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
}
- for user_id, profile in users_with_profile.iteritems()
+ for user_id, profile in iteritems(users_with_profile)
})
@@ -431,9 +437,12 @@ class EventCreationHandler(object):
self.spam_checker = hs.get_spam_checker()
+ if self.config.block_events_without_consent_error is not None:
+ self._consent_uri_builder = ConsentURIBuilder(self.config)
+
@defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
- prev_event_ids=None):
+ prev_events_and_hashes=None):
"""
Given a dict from a client, create a new event.
@@ -447,50 +456,136 @@ class EventCreationHandler(object):
event_dict (dict): An entire event
token_id (str)
txn_id (str)
- prev_event_ids (list): The prev event ids to use when creating the event
+
+ prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+ the forward extremities to use as the prev_events for the
+ new event. For each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+
+ If None, they will be requested from the database.
Returns:
Tuple of created event (FrozenEvent), Context
"""
builder = self.event_builder_factory.new(event_dict)
- with (yield self.limiter.queue(builder.room_id)):
- self.validator.validate_new(builder)
-
- if builder.type == EventTypes.Member:
- membership = builder.content.get("membership", None)
- target = UserID.from_string(builder.state_key)
-
- if membership in {Membership.JOIN, Membership.INVITE}:
- # If event doesn't include a display name, add one.
- profile = self.profile_handler
- content = builder.content
-
- try:
- if "displayname" not in content:
- content["displayname"] = yield profile.get_displayname(target)
- if "avatar_url" not in content:
- content["avatar_url"] = yield profile.get_avatar_url(target)
- except Exception as e:
- logger.info(
- "Failed to get profile information for %r: %s",
- target, e
- )
+ self.validator.validate_new(builder)
+
+ if builder.type == EventTypes.Member:
+ membership = builder.content.get("membership", None)
+ target = UserID.from_string(builder.state_key)
+
+ if membership in {Membership.JOIN, Membership.INVITE}:
+ # If event doesn't include a display name, add one.
+ profile = self.profile_handler
+ content = builder.content
+
+ try:
+ if "displayname" not in content:
+ content["displayname"] = yield profile.get_displayname(target)
+ if "avatar_url" not in content:
+ content["avatar_url"] = yield profile.get_avatar_url(target)
+ except Exception as e:
+ logger.info(
+ "Failed to get profile information for %r: %s",
+ target, e
+ )
- if token_id is not None:
- builder.internal_metadata.token_id = token_id
+ is_exempt = yield self._is_exempt_from_privacy_policy(builder)
+ if not is_exempt:
+ yield self.assert_accepted_privacy_policy(requester)
- if txn_id is not None:
- builder.internal_metadata.txn_id = txn_id
+ if token_id is not None:
+ builder.internal_metadata.token_id = token_id
- event, context = yield self.create_new_client_event(
- builder=builder,
- requester=requester,
- prev_event_ids=prev_event_ids,
- )
+ if txn_id is not None:
+ builder.internal_metadata.txn_id = txn_id
+
+ event, context = yield self.create_new_client_event(
+ builder=builder,
+ requester=requester,
+ prev_events_and_hashes=prev_events_and_hashes,
+ )
defer.returnValue((event, context))
+ def _is_exempt_from_privacy_policy(self, builder):
+ """"Determine if an event to be sent is exempt from having to consent
+ to the privacy policy
+
+ Args:
+ builder (synapse.events.builder.EventBuilder): event being created
+
+ Returns:
+ Deferred[bool]: true if the event can be sent without the user
+ consenting
+ """
+ # the only thing the user can do is join the server notices room.
+ if builder.type == EventTypes.Member:
+ membership = builder.content.get("membership", None)
+ if membership == Membership.JOIN:
+ return self._is_server_notices_room(builder.room_id)
+ return succeed(False)
+
+ @defer.inlineCallbacks
+ def _is_server_notices_room(self, room_id):
+ if self.config.server_notices_mxid is None:
+ defer.returnValue(False)
+ user_ids = yield self.store.get_users_in_room(room_id)
+ defer.returnValue(self.config.server_notices_mxid in user_ids)
+
+ @defer.inlineCallbacks
+ def assert_accepted_privacy_policy(self, requester):
+ """Check if a user has accepted the privacy policy
+
+ Called when the given user is about to do something that requires
+ privacy consent. We see if the user is exempt and otherwise check that
+ they have given consent. If they have not, a ConsentNotGiven error is
+ raised.
+
+ Args:
+ requester (synapse.types.Requester):
+ The user making the request
+
+ Returns:
+ Deferred[None]: returns normally if the user has consented or is
+ exempt
+
+ Raises:
+ ConsentNotGivenError: if the user has not given consent yet
+ """
+ if self.config.block_events_without_consent_error is None:
+ return
+
+ # exempt AS users from needing consent
+ if requester.app_service is not None:
+ return
+
+ user_id = requester.user.to_string()
+
+ # exempt the system notices user
+ if (
+ self.config.server_notices_mxid is not None and
+ user_id == self.config.server_notices_mxid
+ ):
+ return
+
+ u = yield self.store.get_user_by_id(user_id)
+ assert u is not None
+ if u["consent_version"] == self.config.user_consent_version:
+ return
+
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(
+ requester.user.localpart,
+ )
+ msg = self.config.block_events_without_consent_error % {
+ 'consent_uri': consent_uri,
+ }
+ raise ConsentNotGivenError(
+ msg=msg,
+ consent_uri=consent_uri,
+ )
+
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
@@ -557,64 +652,80 @@ class EventCreationHandler(object):
See self.create_event and self.send_nonmember_event.
"""
- event, context = yield self.create_event(
- requester,
- event_dict,
- token_id=requester.access_token_id,
- txn_id=txn_id
- )
- spam_error = self.spam_checker.check_event_for_spam(event)
- if spam_error:
- if not isinstance(spam_error, basestring):
- spam_error = "Spam is not permitted here"
- raise SynapseError(
- 403, spam_error, Codes.FORBIDDEN
+ # We limit the number of concurrent event sends in a room so that we
+ # don't fork the DAG too much. If we don't limit then we can end up in
+ # a situation where event persistence can't keep up, causing
+ # extremities to pile up, which in turn leads to state resolution
+ # taking longer.
+ with (yield self.limiter.queue(event_dict["room_id"])):
+ event, context = yield self.create_event(
+ requester,
+ event_dict,
+ token_id=requester.access_token_id,
+ txn_id=txn_id
)
- yield self.send_nonmember_event(
- requester,
- event,
- context,
- ratelimit=ratelimit,
- )
+ spam_error = self.spam_checker.check_event_for_spam(event)
+ if spam_error:
+ if not isinstance(spam_error, string_types):
+ spam_error = "Spam is not permitted here"
+ raise SynapseError(
+ 403, spam_error, Codes.FORBIDDEN
+ )
+
+ yield self.send_nonmember_event(
+ requester,
+ event,
+ context,
+ ratelimit=ratelimit,
+ )
defer.returnValue(event)
@measure_func("create_new_client_event")
@defer.inlineCallbacks
- def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
- if prev_event_ids:
- prev_events = yield self.store.add_event_hashes(prev_event_ids)
- prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
- depth = prev_max_depth + 1
- else:
- latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
- builder.room_id,
- )
+ def create_new_client_event(self, builder, requester=None,
+ prev_events_and_hashes=None):
+ """Create a new event for a local client
- # We want to limit the max number of prev events we point to in our
- # new event
- if len(latest_ret) > 10:
- # Sort by reverse depth, so we point to the most recent.
- latest_ret.sort(key=lambda a: -a[2])
- new_latest_ret = latest_ret[:5]
-
- # We also randomly point to some of the older events, to make
- # sure that we don't completely ignore the older events.
- if latest_ret[5:]:
- sample_size = min(5, len(latest_ret[5:]))
- new_latest_ret.extend(random.sample(latest_ret[5:], sample_size))
- latest_ret = new_latest_ret
-
- if latest_ret:
- depth = max([d for _, _, d in latest_ret]) + 1
- else:
- depth = 1
+ Args:
+ builder (EventBuilder):
+
+ requester (synapse.types.Requester|None):
+
+ prev_events_and_hashes (list[(str, dict[str, str], int)]|None):
+ the forward extremities to use as the prev_events for the
+ new event. For each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
- prev_events = [
- (event_id, prev_hashes)
- for event_id, prev_hashes, _ in latest_ret
- ]
+ If None, they will be requested from the database.
+
+ Returns:
+ Deferred[(synapse.events.EventBase, synapse.events.snapshot.EventContext)]
+ """
+
+ if prev_events_and_hashes is not None:
+ assert len(prev_events_and_hashes) <= 10, \
+ "Attempting to create an event with %i prev_events" % (
+ len(prev_events_and_hashes),
+ )
+ else:
+ prev_events_and_hashes = \
+ yield self.store.get_prev_events_for_room(builder.room_id)
+
+ if prev_events_and_hashes:
+ depth = max([d for _, _, d in prev_events_and_hashes]) + 1
+ # we cap depth of generated events, to ensure that they are not
+ # rejected by other servers (and so that they can be persisted in
+ # the db)
+ depth = min(depth, MAX_DEPTH)
+ else:
+ depth = 1
+
+ prev_events = [
+ (event_id, prev_hashes)
+ for event_id, prev_hashes, _ in prev_events_and_hashes
+ ]
builder.prev_events = prev_events
builder.depth = depth
@@ -678,8 +789,8 @@ class EventCreationHandler(object):
# Ensure that we can round trip before trying to persist in db
try:
- dump = ujson.dumps(unfreeze(event.content))
- ujson.loads(dump)
+ dump = frozendict_json_encoder.encode(event.content)
+ simplejson.loads(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
raise
@@ -713,8 +824,14 @@ class EventCreationHandler(object):
except: # noqa: E722, as we reraise the exception this is fine.
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
- preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id)
- raise
+ tp, value, tb = sys.exc_info()
+
+ run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
+
+ six.reraise(tp, value, tb)
@defer.inlineCallbacks
def persist_and_notify_client_event(
@@ -765,7 +882,7 @@ class EventCreationHandler(object):
state_to_include_ids = [
e_id
- for k, e_id in context.current_state_ids.iteritems()
+ for k, e_id in iteritems(context.current_state_ids)
if k[0] in self.hs.config.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
@@ -779,7 +896,7 @@ class EventCreationHandler(object):
"content": e.content,
"sender": e.sender,
}
- for e in state_to_include.itervalues()
+ for e in itervalues(state_to_include)
]
invitee = UserID.from_string(event.state_key)
@@ -834,22 +951,33 @@ class EventCreationHandler(object):
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- preserve_fn(self.pusher_pool.on_new_notifications)(
+ run_in_background(
+ self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id
)
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
- )
+ try:
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=extra_users
+ )
+ except Exception:
+ logger.exception("Error notifying about new room event")
- preserve_fn(_notify)()
+ run_in_background(_notify)
if event.type == EventTypes.Message:
- presence = self.hs.get_presence_handler()
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
- preserve_fn(presence.bump_presence_active_time)(requester.user)
+ run_in_background(self._bump_active_time, requester.user)
+
+ @defer.inlineCallbacks
+ def _bump_active_time(self, user):
+ try:
+ presence = self.hs.get_presence_handler()
+ yield presence.bump_presence_active_time(user)
+ except Exception:
+ logger.exception("Error bumping presence active time")
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index a5e501897c..26fc0d3ec7 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,38 +25,42 @@ The methods that define policy are:
from twisted.internet import defer, reactor
from contextlib import contextmanager
+from six import itervalues, iteritems
+
from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.async import Linearizer
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id
-import synapse.metrics
+from synapse.metrics import LaterGauge
import logging
+from prometheus_client import Counter
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-notified_presence_counter = metrics.register_counter("notified_presence")
-federation_presence_out_counter = metrics.register_counter("federation_presence_out")
-presence_updates_counter = metrics.register_counter("presence_updates")
-timers_fired_counter = metrics.register_counter("timers_fired")
-federation_presence_counter = metrics.register_counter("federation_presence")
-bump_active_time_counter = metrics.register_counter("bump_active_time")
+notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "")
+federation_presence_out_counter = Counter(
+ "synapse_handler_presence_federation_presence_out", "")
+presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "")
+timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "")
+federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "")
+bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "")
-get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
+get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"])
-notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"])
-state_transition_counter = metrics.register_counter(
- "state_transition", labels=["from", "to"]
+notify_reason_counter = Counter(
+ "synapse_handler_presence_notify_reason", "", ["reason"])
+state_transition_counter = Counter(
+ "synapse_handler_presence_state_transition", "", ["from", "to"]
)
@@ -87,6 +91,11 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object):
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()
@@ -94,7 +103,6 @@ class PresenceHandler(object):
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
self.federation = hs.get_federation_sender()
-
self.state = hs.get_state_handler()
federation_registry = hs.get_federation_registry()
@@ -137,8 +145,9 @@ class PresenceHandler(object):
for state in active_presence
}
- metrics.register_callback(
- "user_to_current_state_size", lambda: len(self.user_to_current_state)
+ LaterGauge(
+ "synapse_handlers_presence_user_to_current_state_size", "", [],
+ lambda: len(self.user_to_current_state)
)
now = self.clock.time_msec()
@@ -208,7 +217,8 @@ class PresenceHandler(object):
60 * 1000,
)
- metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
+ LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
+ lambda: len(self.wheel_timer))
@defer.inlineCallbacks
def _on_shutdown(self):
@@ -255,6 +265,14 @@ class PresenceHandler(object):
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks
+ def _update_states_and_catch_exception(self, new_states):
+ try:
+ res = yield self._update_states(new_states)
+ defer.returnValue(res)
+ except Exception:
+ logger.exception("Error updating presence")
+
+ @defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
@@ -303,10 +321,10 @@ class PresenceHandler(object):
# TODO: We should probably ensure there are no races hereafter
- presence_updates_counter.inc_by(len(new_states))
+ presence_updates_counter.inc(len(new_states))
if to_notify:
- notified_presence_counter.inc_by(len(to_notify))
+ notified_presence_counter.inc(len(to_notify))
yield self._persist_and_notify(to_notify.values())
self.unpersisted_users_changes |= set(s.user_id for s in new_states)
@@ -317,7 +335,7 @@ class PresenceHandler(object):
if user_id not in to_notify
}
if to_federation_ping:
- federation_presence_out_counter.inc_by(len(to_federation_ping))
+ federation_presence_out_counter.inc(len(to_federation_ping))
self._push_to_remotes(to_federation_ping.values())
@@ -355,7 +373,7 @@ class PresenceHandler(object):
for user_id in users_to_check
]
- timers_fired_counter.inc_by(len(states))
+ timers_fired_counter.inc(len(states))
changes = handle_timeouts(
states,
@@ -364,7 +382,7 @@ class PresenceHandler(object):
now=now,
)
- preserve_fn(self._update_states)(changes)
+ run_in_background(self._update_states_and_catch_exception, changes)
except Exception:
logger.exception("Exception in _handle_timeouts loop")
@@ -422,20 +440,23 @@ class PresenceHandler(object):
@defer.inlineCallbacks
def _end():
- if affect_presence:
+ try:
self.user_to_num_current_syncs[user_id] -= 1
prev_state = yield self.current_state_for_user(user_id)
yield self._update_states([prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec(),
)])
+ except Exception:
+ logger.exception("Error updating presence after sync")
@contextmanager
def _user_syncing():
try:
yield
finally:
- preserve_fn(_end)()
+ if affect_presence:
+ run_in_background(_end)
defer.returnValue(_user_syncing())
@@ -453,61 +474,6 @@ class PresenceHandler(object):
return syncing_user_ids
@defer.inlineCallbacks
- def update_external_syncs(self, process_id, syncing_user_ids):
- """Update the syncing users for an external process
-
- Args:
- process_id(str): An identifier for the process the users are
- syncing against. This allows synapse to process updates
- as user start and stop syncing against a given process.
- syncing_user_ids(set(str)): The set of user_ids that are
- currently syncing on that server.
- """
-
- # Grab the previous list of user_ids that were syncing on that process
- prev_syncing_user_ids = (
- self.external_process_to_current_syncs.get(process_id, set())
- )
- # Grab the current presence state for both the users that are syncing
- # now and the users that were syncing before this update.
- prev_states = yield self.current_state_for_users(
- syncing_user_ids | prev_syncing_user_ids
- )
- updates = []
- time_now_ms = self.clock.time_msec()
-
- # For each new user that is syncing check if we need to mark them as
- # being online.
- for new_user_id in syncing_user_ids - prev_syncing_user_ids:
- prev_state = prev_states[new_user_id]
- if prev_state.state == PresenceState.OFFLINE:
- updates.append(prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=time_now_ms,
- last_user_sync_ts=time_now_ms,
- ))
- else:
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- ))
-
- # For each user that is still syncing or stopped syncing update the
- # last sync time so that we will correctly apply the grace period when
- # they stop syncing.
- for old_user_id in prev_syncing_user_ids:
- prev_state = prev_states[old_user_id]
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- ))
-
- yield self._update_states(updates)
-
- # Update the last updated time for the process. We expire the entries
- # if we don't receive an update in the given timeframe.
- self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
- self.external_process_to_current_syncs[process_id] = syncing_user_ids
-
- @defer.inlineCallbacks
def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
"""Update the syncing users for an external process as a delta.
@@ -570,7 +536,7 @@ class PresenceHandler(object):
prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
)
- for prev_state in prev_states.itervalues()
+ for prev_state in itervalues(prev_states)
])
self.external_process_last_updated_ms.pop(process_id, None)
@@ -593,14 +559,14 @@ class PresenceHandler(object):
for user_id in user_ids
}
- missing = [user_id for user_id, state in states.iteritems() if not state]
+ missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = yield self.store.get_presence_for_users(missing)
states.update(res)
- missing = [user_id for user_id, state in states.iteritems() if not state]
+ missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id)
@@ -696,7 +662,7 @@ class PresenceHandler(object):
updates.append(prev_state.copy_and_replace(**new_fields))
if updates:
- federation_presence_counter.inc_by(len(updates))
+ federation_presence_counter.inc(len(updates))
yield self._update_states(updates)
@defer.inlineCallbacks
@@ -971,28 +937,28 @@ def should_notify(old_state, new_state):
return False
if old_state.status_msg != new_state.status_msg:
- notify_reason_counter.inc("status_msg_change")
+ notify_reason_counter.labels("status_msg_change").inc()
return True
if old_state.state != new_state.state:
- notify_reason_counter.inc("state_change")
- state_transition_counter.inc(old_state.state, new_state.state)
+ notify_reason_counter.labels("state_change").inc()
+ state_transition_counter.labels(old_state.state, new_state.state).inc()
return True
if old_state.state == PresenceState.ONLINE:
if new_state.currently_active != old_state.currently_active:
- notify_reason_counter.inc("current_active_change")
+ notify_reason_counter.labels("current_active_change").inc()
return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Only notify about last active bumps if we're not currently acive
if not new_state.currently_active:
- notify_reason_counter.inc("last_active_change_online")
+ notify_reason_counter.labels("last_active_change_online").inc()
return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped.
- notify_reason_counter.inc("last_active_change_not_online")
+ notify_reason_counter.labels("last_active_change_not_online").inc()
return True
return False
@@ -1066,14 +1032,14 @@ class PresenceEventSource(object):
if changed is not None and len(changed) < 500:
# For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
- get_updates_counter.inc("stream")
+ get_updates_counter.labels("stream").inc()
for other_user_id in changed:
if other_user_id in users_interested_in:
user_ids_changed.add(other_user_id)
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
- get_updates_counter.inc("full")
+ get_updates_counter.labels("full").inc()
if from_key:
user_ids_changed = stream_change_cache.get_entities_changed(
@@ -1088,7 +1054,7 @@ class PresenceEventSource(object):
defer.returnValue((updates.values(), max_token))
else:
defer.returnValue(([
- s for s in updates.itervalues()
+ s for s in itervalues(updates)
if s.state != PresenceState.OFFLINE
], max_token))
@@ -1345,11 +1311,11 @@ def get_interested_remotes(store, states, state_handler):
# hosts in those rooms.
room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
- for room_id, states in room_ids_to_states.iteritems():
+ for room_id, states in iteritems(room_ids_to_states):
hosts = yield state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states))
- for user_id, states in users_to_states.iteritems():
+ for user_id, states in iteritems(users_to_states):
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 3f215c2b4e..2e0672161c 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -135,37 +135,40 @@ class ReceiptsHandler(BaseHandler):
"""Given a list of receipts, works out which remote servers should be
poked and pokes them.
"""
- # TODO: Some of this stuff should be coallesced.
- for receipt in receipts:
- room_id = receipt["room_id"]
- receipt_type = receipt["receipt_type"]
- user_id = receipt["user_id"]
- event_ids = receipt["event_ids"]
- data = receipt["data"]
-
- users = yield self.state.get_current_user_in_room(room_id)
- remotedomains = set(get_domain_from_id(u) for u in users)
- remotedomains = remotedomains.copy()
- remotedomains.discard(self.server_name)
-
- logger.debug("Sending receipt to: %r", remotedomains)
-
- for domain in remotedomains:
- self.federation.send_edu(
- destination=domain,
- edu_type="m.receipt",
- content={
- room_id: {
- receipt_type: {
- user_id: {
- "event_ids": event_ids,
- "data": data,
+ try:
+ # TODO: Some of this stuff should be coallesced.
+ for receipt in receipts:
+ room_id = receipt["room_id"]
+ receipt_type = receipt["receipt_type"]
+ user_id = receipt["user_id"]
+ event_ids = receipt["event_ids"]
+ data = receipt["data"]
+
+ users = yield self.state.get_current_user_in_room(room_id)
+ remotedomains = set(get_domain_from_id(u) for u in users)
+ remotedomains = remotedomains.copy()
+ remotedomains.discard(self.server_name)
+
+ logger.debug("Sending receipt to: %r", remotedomains)
+
+ for domain in remotedomains:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.receipt",
+ content={
+ room_id: {
+ receipt_type: {
+ user_id: {
+ "event_ids": event_ids,
+ "data": data,
+ }
}
- }
+ },
},
- },
- key=(room_id, receipt_type, user_id),
- )
+ key=(room_id, receipt_type, user_id),
+ )
+ except Exception:
+ logger.exception("Error pushing receipts to remote servers")
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ed5939880a..7e52adda3c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -23,8 +23,8 @@ from synapse.api.errors import (
)
from synapse.http.client import CaptchaServerHttpClient
from synapse import types
-from synapse.types import UserID
-from synapse.util.async import run_on_reactor
+from synapse.types import UserID, create_requester, RoomID, RoomAlias
+from synapse.util.async import run_on_reactor, Linearizer
from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler
@@ -34,6 +34,11 @@ logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
@@ -46,6 +51,11 @@ class RegistrationHandler(BaseHandler):
self.macaroon_gen = hs.get_macaroon_generator()
+ self._generate_user_id_linearizer = Linearizer(
+ name="_generate_user_id_linearizer",
+ )
+ self._server_notices_mxid = hs.config.server_notices_mxid
+
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None,
assigned_user_id=None):
@@ -201,10 +211,17 @@ class RegistrationHandler(BaseHandler):
token = None
attempts += 1
+ # auto-join the user to any rooms we're supposed to dump them into
+ fake_requester = create_requester(user_id)
+ for r in self.hs.config.auto_join_rooms:
+ try:
+ yield self._join_user_to_room(fake_requester, r)
+ except Exception as e:
+ logger.error("Failed to join new user to %r: %r", r, e)
+
# We used to generate default identicons here, but nowadays
# we want clients to generate their own as part of their branding
# rather than there being consistent matrix-wide ones, so we don't.
-
defer.returnValue((user_id, token))
@defer.inlineCallbacks
@@ -327,6 +344,14 @@ class RegistrationHandler(BaseHandler):
yield identity_handler.bind_threepid(c, user_id)
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+ # don't allow people to register the server notices mxid
+ if self._server_notices_mxid is not None:
+ if user_id == self._server_notices_mxid:
+ raise SynapseError(
+ 400, "This user ID is reserved.",
+ errcode=Codes.EXCLUSIVE
+ )
+
# valid user IDs must not clash with any user ID namespaces claimed by
# application services.
services = self.store.get_app_services()
@@ -345,9 +370,11 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
- self._next_generated_user_id = (
- yield self.store.find_next_generated_user_id_localpart()
- )
+ with (yield self._generate_user_id_linearizer.queue(())):
+ if reseed or self._next_generated_user_id is None:
+ self._next_generated_user_id = (
+ yield self.store.find_next_generated_user_id_localpart()
+ )
id = self._next_generated_user_id
self._next_generated_user_id += 1
@@ -477,3 +504,28 @@ class RegistrationHandler(BaseHandler):
)
defer.returnValue((user_id, access_token))
+
+ @defer.inlineCallbacks
+ def _join_user_to_room(self, requester, room_identifier):
+ room_id = None
+ room_member_handler = self.hs.get_room_member_handler()
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, remote_room_hosts = (
+ yield room_member_handler.lookup_room_alias(room_alias)
+ )
+ room_id = room_id.to_string()
+ else:
+ raise SynapseError(400, "%s was not legal room ID or room alias" % (
+ room_identifier,
+ ))
+
+ yield room_member_handler.update_membership(
+ requester=requester,
+ target=requester.user,
+ room_id=room_id,
+ remote_room_hosts=remote_room_hosts,
+ action="join",
+ )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 8df8fcbbad..b5850db42f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -68,14 +68,27 @@ class RoomCreationHandler(BaseHandler):
self.event_creation_handler = hs.get_event_creation_handler()
@defer.inlineCallbacks
- def create_room(self, requester, config, ratelimit=True):
+ def create_room(self, requester, config, ratelimit=True,
+ creator_join_profile=None):
""" Creates a new room.
Args:
- requester (Requester): The user who requested the room creation.
+ requester (synapse.types.Requester):
+ The user who requested the room creation.
config (dict) : A dict of configuration options.
+ ratelimit (bool): set to False to disable the rate limiter
+
+ creator_join_profile (dict|None):
+ Set to override the displayname and avatar for the creating
+ user in this room. If unset, displayname and avatar will be
+ derived from the user's profile. If set, should contain the
+ values to go in the body of the 'join' event (typically
+ `avatar_url` and/or `displayname`.
+
Returns:
- The new room ID.
+ Deferred[dict]:
+ a dict containing the keys `room_id` and, if an alias was
+ requested, `room_alias`.
Raises:
SynapseError if the room ID couldn't be stored, or something went
horribly wrong.
@@ -113,6 +126,10 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
+ yield self.event_creation_handler.assert_accepted_privacy_policy(
+ requester,
+ )
+
invite_3pid_list = config.get("invite_3pid", [])
visibility = config.get("visibility", None)
@@ -176,7 +193,8 @@ class RoomCreationHandler(BaseHandler):
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
- power_level_content_override=config.get("power_level_content_override", {})
+ power_level_content_override=config.get("power_level_content_override", {}),
+ creator_join_profile=creator_join_profile,
)
if "name" in config:
@@ -256,6 +274,7 @@ class RoomCreationHandler(BaseHandler):
creation_content,
room_alias,
power_level_content_override,
+ creator_join_profile,
):
def create(etype, content, **kwargs):
e = {
@@ -299,6 +318,7 @@ class RoomCreationHandler(BaseHandler):
room_id,
"join",
ratelimit=False,
+ content=creator_join_profile,
)
# We treat the power levels override specially as this needs to be one
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5d81f59b44..5757bb7f8a 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,12 +15,13 @@
from twisted.internet import defer
+from six.moves import range
+
from ._base import BaseHandler
from synapse.api.constants import (
EventTypes, JoinRules,
)
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.async import concurrently_execute
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.response_cache import ResponseCache
@@ -44,8 +45,9 @@ EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
- self.response_cache = ResponseCache(hs)
- self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
+ self.response_cache = ResponseCache(hs, "room_list")
+ self.remote_response_cache = ResponseCache(hs, "remote_room_list",
+ timeout_ms=30 * 1000)
def get_local_public_room_list(self, limit=None, since_token=None,
search_filter=None,
@@ -77,18 +79,11 @@ class RoomListHandler(BaseHandler):
)
key = (limit, since_token, network_tuple)
- result = self.response_cache.get(key)
- if not result:
- logger.info("No cached result, calculating one.")
- result = self.response_cache.set(
- key,
- preserve_fn(self._get_public_room_list)(
- limit, since_token, network_tuple=network_tuple
- )
- )
- else:
- logger.info("Using cached deferred result.")
- return make_deferred_yieldable(result)
+ return self.response_cache.wrap(
+ key,
+ self._get_public_room_list,
+ limit, since_token, network_tuple=network_tuple,
+ )
@defer.inlineCallbacks
def _get_public_room_list(self, limit=None, since_token=None,
@@ -207,7 +202,7 @@ class RoomListHandler(BaseHandler):
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
chunk = []
- for i in xrange(0, len(rooms_to_scan), step):
+ for i in range(0, len(rooms_to_scan), step):
batch = rooms_to_scan[i:i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
@@ -422,18 +417,14 @@ class RoomListHandler(BaseHandler):
server_name, limit, since_token, include_all_networks,
third_party_instance_id,
)
- result = self.remote_response_cache.get(key)
- if not result:
- result = self.remote_response_cache.set(
- key,
- repl_layer.get_public_rooms(
- server_name, limit=limit, since_token=since_token,
- search_filter=search_filter,
- include_all_networks=include_all_networks,
- third_party_instance_id=third_party_instance_id,
- )
- )
- return result
+ return self.remote_response_cache.wrap(
+ key,
+ repl_layer.get_public_rooms,
+ server_name, limit=limit, since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ )
class RoomListNextBatch(namedtuple("RoomListNextBatch", (
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 9977be8831..f930e939e8 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,11 +17,14 @@
import abc
import logging
+from six.moves import http_client
+
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from twisted.internet import defer
from unpaddedbase64 import decode_base64
+import synapse.server
import synapse.types
from synapse.api.constants import (
EventTypes, Membership,
@@ -46,6 +49,11 @@ class RoomMemberHandler(object):
__metaclass__ = abc.ABCMeta
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -63,6 +71,7 @@ class RoomMemberHandler(object):
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
+ self._server_notices_mxid = self.config.server_notices_mxid
@abc.abstractmethod
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
@@ -149,7 +158,7 @@ class RoomMemberHandler(object):
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
- prev_event_ids,
+ prev_events_and_hashes,
txn_id=None,
ratelimit=True,
content=None,
@@ -175,7 +184,7 @@ class RoomMemberHandler(object):
},
token_id=requester.access_token_id,
txn_id=txn_id,
- prev_event_ids=prev_event_ids,
+ prev_events_and_hashes=prev_events_and_hashes,
)
# Check if this event matches the previous membership event for the user.
@@ -290,11 +299,26 @@ class RoomMemberHandler(object):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- if effective_membership_state == "invite":
+ if effective_membership_state == Membership.INVITE:
+ # block any attempts to invite the server notices mxid
+ if target.to_string() == self._server_notices_mxid:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "Cannot invite this user",
+ )
+
block_invite = False
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+
+ if (self._server_notices_mxid is not None and
+ requester.user.to_string() == self._server_notices_mxid):
+ # allow the server notices mxid to send invites
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(
+ requester.user,
+ )
+
if not is_requester_admin:
if self.config.block_non_admin_invites:
logger.info(
@@ -314,7 +338,12 @@ class RoomMemberHandler(object):
403, "Invites have been disabled on this server",
)
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ prev_events_and_hashes = yield self.store.get_prev_events_for_room(
+ room_id,
+ )
+ latest_event_ids = (
+ event_id for (event_id, _, _) in prev_events_and_hashes
+ )
current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids,
)
@@ -344,6 +373,20 @@ class RoomMemberHandler(object):
if same_sender and same_membership and same_content:
defer.returnValue(old_state)
+ # we don't allow people to reject invites to the server notice
+ # room, but they can leave it once they are joined.
+ if (
+ old_membership == Membership.INVITE and
+ effective_membership_state == Membership.LEAVE
+ ):
+ is_blocked = yield self._is_server_notice_room(room_id)
+ if is_blocked:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "You cannot reject this invite",
+ errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
+ )
+
is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
@@ -403,7 +446,7 @@ class RoomMemberHandler(object):
membership=effective_membership_state,
txn_id=txn_id,
ratelimit=ratelimit,
- prev_event_ids=latest_event_ids,
+ prev_events_and_hashes=prev_events_and_hashes,
content=content,
)
defer.returnValue(res)
@@ -839,6 +882,13 @@ class RoomMemberHandler(object):
defer.returnValue(False)
+ @defer.inlineCallbacks
+ def _is_server_notice_room(self, room_id):
+ if self._server_notices_mxid is None:
+ defer.returnValue(False)
+ user_ids = yield self.store.get_users_in_room(room_id)
+ defer.returnValue(self._server_notices_mxid in user_ids)
+
class RoomMemberMasterHandler(RoomMemberHandler):
def __init__(self, hs):
@@ -852,6 +902,14 @@ class RoomMemberMasterHandler(RoomMemberHandler):
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
"""
+ # filter ourselves out of remote_room_hosts: do_invite_join ignores it
+ # and if it is the only entry we'd like to return a 404 rather than a
+ # 500.
+
+ remote_room_hosts = [
+ host for host in remote_room_hosts if host != self.hs.hostname
+ ]
+
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f521d22e91..05bf6d46dd 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -16,7 +16,7 @@
from synapse.api.constants import Membership, EventTypes
from synapse.util.async import concurrently_execute
-from synapse.util.logcontext import LoggingContext, make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure, measure_func
from synapse.util.caches.response_cache import ResponseCache
from synapse.push.clientformat import format_push_rules_for_user
@@ -29,6 +29,8 @@ import collections
import logging
import itertools
+from six import itervalues, iteritems
+
logger = logging.getLogger(__name__)
@@ -53,6 +55,7 @@ class TimelineBatch(collections.namedtuple("TimelineBatch", [
to tell if room needs to be part of the sync result.
"""
return bool(self.events)
+ __bool__ = __nonzero__ # python3
class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
@@ -77,6 +80,7 @@ class JoinedSyncResult(collections.namedtuple("JoinedSyncResult", [
# nb the notification count does not, er, count: if there's nothing
# else in the result, we don't need to send it.
)
+ __bool__ = __nonzero__ # python3
class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
@@ -96,6 +100,7 @@ class ArchivedSyncResult(collections.namedtuple("ArchivedSyncResult", [
or self.state
or self.account_data
)
+ __bool__ = __nonzero__ # python3
class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
@@ -107,6 +112,7 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
def __nonzero__(self):
"""Invited rooms should always be reported to the client"""
return True
+ __bool__ = __nonzero__ # python3
class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
@@ -118,6 +124,7 @@ class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
def __nonzero__(self):
return bool(self.join or self.invite or self.leave)
+ __bool__ = __nonzero__ # python3
class DeviceLists(collections.namedtuple("DeviceLists", [
@@ -128,6 +135,7 @@ class DeviceLists(collections.namedtuple("DeviceLists", [
def __nonzero__(self):
return bool(self.changed or self.left)
+ __bool__ = __nonzero__ # python3
class SyncResult(collections.namedtuple("SyncResult", [
@@ -160,6 +168,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.device_lists or
self.groups
)
+ __bool__ = __nonzero__ # python3
class SyncHandler(object):
@@ -170,7 +179,7 @@ class SyncHandler(object):
self.presence_handler = hs.get_presence_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
- self.response_cache = ResponseCache(hs)
+ self.response_cache = ResponseCache(hs, "sync")
self.state = hs.get_state_handler()
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
@@ -181,15 +190,11 @@ class SyncHandler(object):
Returns:
A Deferred SyncResult.
"""
- result = self.response_cache.get(sync_config.request_key)
- if not result:
- result = self.response_cache.set(
- sync_config.request_key,
- preserve_fn(self._wait_for_sync_for_user)(
- sync_config, since_token, timeout, full_state
- )
- )
- return make_deferred_yieldable(result)
+ return self.response_cache.wrap(
+ sync_config.request_key,
+ self._wait_for_sync_for_user,
+ sync_config, since_token, timeout, full_state,
+ )
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
@@ -273,7 +278,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
- event_copy = {k: v for (k, v) in event.iteritems()
+ event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@@ -292,7 +297,7 @@ class SyncHandler(object):
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
- event_copy = {k: v for (k, v) in event.iteritems()
+ event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@@ -323,7 +328,7 @@ class SyncHandler(object):
current_state_ids = frozenset()
if any(e.is_state() for e in recents):
current_state_ids = yield self.state.get_current_state_ids(room_id)
- current_state_ids = frozenset(current_state_ids.itervalues())
+ current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client(
self.store,
@@ -352,12 +357,24 @@ class SyncHandler(object):
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat:
- events, end_key = yield self.store.get_room_events_stream_for_room(
- room_id,
- limit=load_limit + 1,
- from_key=since_key,
- to_key=end_key,
- )
+ # If we have a since_key then we are trying to get any events
+ # that have happened since `since_key` up to `end_key`, so we
+ # can just use `get_room_events_stream_for_room`.
+ # Otherwise, we want to return the last N events in the room
+ # in toplogical ordering.
+ if since_key:
+ events, end_key = yield self.store.get_room_events_stream_for_room(
+ room_id,
+ limit=load_limit + 1,
+ from_key=since_key,
+ to_key=end_key,
+ )
+ else:
+ events, end_key = yield self.store.get_recent_events_for_room(
+ room_id,
+ limit=load_limit + 1,
+ end_token=end_key,
+ )
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
)
@@ -368,7 +385,7 @@ class SyncHandler(object):
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
current_state_ids = yield self.state.get_current_state_ids(room_id)
- current_state_ids = frozenset(current_state_ids.itervalues())
+ current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client(
self.store,
@@ -427,7 +444,7 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
- last_events, token = yield self.store.get_recent_events_for_room(
+ last_events, _ = yield self.store.get_recent_events_for_room(
room_id, end_token=stream_position.room_key, limit=1,
)
@@ -1026,7 +1043,7 @@ class SyncHandler(object):
if since_token:
for joined_sync in sync_result_builder.joined:
it = itertools.chain(
- joined_sync.timeline.events, joined_sync.state.itervalues()
+ joined_sync.timeline.events, itervalues(joined_sync.state)
)
for event in it:
if event.type == EventTypes.Member:
@@ -1104,7 +1121,7 @@ class SyncHandler(object):
newly_left_rooms = []
room_entries = []
invited = []
- for room_id, events in mem_change_events_by_room_id.iteritems():
+ for room_id, events in iteritems(mem_change_events_by_room_id):
non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 77c0cf146f..5d9736e88f 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id
@@ -97,7 +97,8 @@ class TypingHandler(object):
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
- preserve_fn(self._push_remote)(
+ run_in_background(
+ self._push_remote,
member=member,
typing=True
)
@@ -196,7 +197,7 @@ class TypingHandler(object):
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
- preserve_fn(self._push_remote)(member, typing)
+ run_in_background(self._push_remote, member, typing)
self._push_update_local(
member=member,
@@ -205,28 +206,31 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _push_remote(self, member, typing):
- users = yield self.state.get_current_user_in_room(member.room_id)
- self._member_last_federation_poke[member] = self.clock.time_msec()
+ try:
+ users = yield self.state.get_current_user_in_room(member.room_id)
+ self._member_last_federation_poke[member] = self.clock.time_msec()
- now = self.clock.time_msec()
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_PING_INTERVAL,
- )
+ now = self.clock.time_msec()
+ self.wheel_timer.insert(
+ now=now,
+ obj=member,
+ then=now + FEDERATION_PING_INTERVAL,
+ )
- for domain in set(get_domain_from_id(u) for u in users):
- if domain != self.server_name:
- self.federation.send_edu(
- destination=domain,
- edu_type="m.typing",
- content={
- "room_id": member.room_id,
- "user_id": member.user_id,
- "typing": typing,
- },
- key=member,
- )
+ for domain in set(get_domain_from_id(u) for u in users):
+ if domain != self.server_name:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.typing",
+ content={
+ "room_id": member.room_id,
+ "user_id": member.user_id,
+ "typing": typing,
+ },
+ key=member,
+ )
+ except Exception:
+ logger.exception("Error pushing typing notif to remotes")
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 714f0195c8..a39f0f7343 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -22,6 +22,7 @@ from synapse.util.metrics import Measure
from synapse.util.async import sleep
from synapse.types import get_localpart_from_id
+from six import iteritems
logger = logging.getLogger(__name__)
@@ -123,6 +124,13 @@ class UserDirectoryHandler(object):
)
@defer.inlineCallbacks
+ def handle_user_deactivated(self, user_id):
+ """Called when a user ID is deactivated
+ """
+ yield self.store.remove_from_user_dir(user_id)
+ yield self.store.remove_from_user_in_public_room(user_id)
+
+ @defer.inlineCallbacks
def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
@@ -403,7 +411,7 @@ class UserDirectoryHandler(object):
if change:
users_with_profile = yield self.state.get_current_user_in_room(room_id)
- for user_id, profile in users_with_profile.iteritems():
+ for user_id, profile in iteritems(users_with_profile):
yield self._handle_new_user(room_id, user_id, profile)
else:
users = yield self.store.get_users_in_public_due_to_room(room_id)
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index bfebb0f644..054372e179 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
+
+from synapse.api.errors import SynapseError
+
+
+class RequestTimedOutError(SynapseError):
+ """Exception representing timeout of an outbound request"""
+ def __init__(self):
+ super(RequestTimedOutError, self).__init__(504, "Timed out")
+
+
+def cancelled_to_request_timed_out_error(value, timeout):
+ """Turns CancelledErrors into RequestTimedOutErrors.
+
+ For use with async.add_timeout_to_deferred
+ """
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise RequestTimedOutError()
+ return value
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 343e932cb1..a797396ade 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import wrap_request_handler
+from synapse.http.server import wrap_json_request_handler
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -42,14 +42,13 @@ class AdditionalResource(Resource):
Resource.__init__(self)
self._handler = handler
- # these are required by the request_handler wrapper
- self.version_string = hs.version_string
+ # required by the request_handler wrapper
self.clock = hs.get_clock()
def render(self, request):
self._async_render(request)
return NOT_DONE_YET
- @wrap_request_handler
+ @wrap_json_request_handler
def _async_render(self, request):
return self._handler(request)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index f3e4973c2e..4d4eee3d64 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,10 +19,10 @@ from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
+from synapse.http import cancelled_to_request_timed_out_error
+from synapse.util.async import add_timeout_to_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
-from synapse.util import logcontext
-import synapse.metrics
from synapse.http.endpoint import SpiderEndpoint
from canonicaljson import encode_canonical_json
@@ -38,8 +39,9 @@ from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
-from StringIO import StringIO
+from six import StringIO
+from prometheus_client import Counter
import simplejson as json
import logging
import urllib
@@ -47,16 +49,9 @@ import urllib
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-outgoing_requests_counter = metrics.register_counter(
- "requests",
- labels=["method"],
-)
-incoming_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
-)
+outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
+incoming_responses_counter = Counter("synapse_http_client_responses", "",
+ ["method", "code"])
class SimpleHttpClient(object):
@@ -93,32 +88,28 @@ class SimpleHttpClient(object):
def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
- outgoing_requests_counter.inc(method)
+ outgoing_requests_counter.labels(method).inc()
- def send_request():
+ logger.info("Sending request %s %s", method, uri)
+
+ try:
request_deferred = self.agent.request(
method, uri, *args, **kwargs
)
-
- return self.clock.time_bound_deferred(
+ add_timeout_to_deferred(
request_deferred,
- time_out=60,
+ 60, cancelled_to_request_timed_out_error,
)
+ response = yield make_deferred_yieldable(request_deferred)
- logger.info("Sending request %s %s", method, uri)
-
- try:
- with logcontext.PreserveLoggingContext():
- response = yield send_request()
-
- incoming_responses_counter.inc(method, response.code)
+ incoming_responses_counter.labels(method, response.code).inc()
logger.info(
"Received response to %s %s: %s",
method, uri, response.code
)
defer.returnValue(response)
except Exception as e:
- incoming_responses_counter.inc(method, "ERR")
+ incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
"Error sending request to %s %s: %s %s",
method, uri, type(e).__name__, e.message
@@ -509,7 +500,7 @@ class SpiderHttpClient(SimpleHttpClient):
reactor,
SpiderEndpointFactory(hs)
)
- ), [('gzip', GzipDecoder)]
+ ), [(b'gzip', GzipDecoder)]
)
# We could look like Chrome:
# self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 87639b9151..87a482650d 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import socket
-
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor
from twisted.internet.error import ConnectError
@@ -33,7 +31,7 @@ SERVER_CACHE = {}
# our record of an individual server which can be tried to reach a destination.
#
-# "host" is actually a dotted-quad or ipv6 address string. Except when there's
+# "host" is the hostname acquired from the SRV record. Except when there's
# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
@@ -117,10 +115,15 @@ class _WrappedConnection(object):
if time.time() - self.last_request >= 2.5 * 60:
self.abort()
# Abort the underlying TLS connection. The abort() method calls
- # loseConnection() on the underlying TLS connection which tries to
+ # loseConnection() on the TLS connection which tries to
# shutdown the connection cleanly. We call abortConnection()
- # since that will promptly close the underlying TCP connection.
- self.transport.abortConnection()
+ # since that will promptly close the TLS connection.
+ #
+ # In Twisted >18.4; the TLS connection will be None if it has closed
+ # which will make abortConnection() throw. Check that the TLS connection
+ # is not None before trying to close it.
+ if self.transport.getHandle() is not None:
+ self.transport.abortConnection()
def request(self, request):
self.last_request = time.time()
@@ -288,7 +291,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
- and answers[0].payload.target == dns.Name('.')):
+ and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
@@ -297,20 +300,13 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
payload = answer.payload
- hosts = yield _get_hosts_for_srv_record(
- dns_client, str(payload.target)
- )
-
- for (ip, ttl) in hosts:
- host_ttl = min(answer.ttl, ttl)
-
- servers.append(_Server(
- host=ip,
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight),
- expires=int(clock.time()) + host_ttl,
- ))
+ servers.append(_Server(
+ host=str(payload.target),
+ port=int(payload.port),
+ priority=int(payload.priority),
+ weight=int(payload.weight),
+ expires=int(clock.time()) + answer.ttl,
+ ))
servers.sort()
cache[service_name] = list(servers)
@@ -328,81 +324,3 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
raise e
defer.returnValue(servers)
-
-
-@defer.inlineCallbacks
-def _get_hosts_for_srv_record(dns_client, host):
- """Look up each of the hosts in a SRV record
-
- Args:
- dns_client (twisted.names.dns.IResolver):
- host (basestring): host to look up
-
- Returns:
- Deferred[list[(str, int)]]: a list of (host, ttl) pairs
-
- """
- ip4_servers = []
- ip6_servers = []
-
- def cb(res):
- # lookupAddress and lookupIP6Address return a three-tuple
- # giving the answer, authority, and additional sections of the
- # response.
- #
- # we only care about the answers.
-
- return res[0]
-
- def eb(res, record_type):
- if res.check(DNSNameError):
- return []
- logger.warn("Error looking up %s for %s: %s", record_type, host, res)
- return res
-
- # no logcontexts here, so we can safely fire these off and gatherResults
- d1 = dns_client.lookupAddress(host).addCallbacks(
- cb, eb, errbackArgs=("A", ))
- d2 = dns_client.lookupIPV6Address(host).addCallbacks(
- cb, eb, errbackArgs=("AAAA", ))
- results = yield defer.DeferredList(
- [d1, d2], consumeErrors=True)
-
- # if all of the lookups failed, raise an exception rather than blowing out
- # the cache with an empty result.
- if results and all(s == defer.FAILURE for (s, _) in results):
- defer.returnValue(results[0][1])
-
- for (success, result) in results:
- if success == defer.FAILURE:
- continue
-
- for answer in result:
- if not answer.payload:
- continue
-
- try:
- if answer.type == dns.A:
- ip = answer.payload.dottedQuad()
- ip4_servers.append((ip, answer.ttl))
- elif answer.type == dns.AAAA:
- ip = socket.inet_ntop(
- socket.AF_INET6, answer.payload.address,
- )
- ip6_servers.append((ip, answer.ttl))
- else:
- # the most likely candidate here is a CNAME record.
- # rfc2782 says srvs may not point to aliases.
- logger.warn(
- "Ignoring unexpected DNS record type %s for %s",
- answer.type, host,
- )
- continue
- except Exception as e:
- logger.warn("Ignoring invalid DNS response for %s: %s",
- host, e)
- continue
-
- # keep the ipv4 results before the ipv6 results, mostly to match historical
- # behaviour.
- defer.returnValue(ip4_servers + ip6_servers)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 9145405cb0..821aed362b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,17 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.util.retryutils
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
+from synapse.http import cancelled_to_request_timed_out_error
from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.async import sleep
-from synapse.util import logcontext
import synapse.metrics
+from synapse.util.async import sleep, add_timeout_to_deferred
+from synapse.util import logcontext
+from synapse.util.logcontext import make_deferred_yieldable
+import synapse.util.retryutils
from canonicaljson import encode_canonical_json
@@ -38,22 +41,19 @@ import logging
import random
import sys
import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
+from six import string_types
+
+from prometheus_client import Counter
logger = logging.getLogger(__name__)
outbound_logger = logging.getLogger("synapse.http.outbound")
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-outgoing_requests_counter = metrics.register_counter(
- "requests",
- labels=["method"],
-)
-incoming_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
-)
+outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
+ "", ["method"])
+incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses",
+ "", ["method", "code"])
MAX_LONG_RETRIES = 10
@@ -184,21 +184,20 @@ class MatrixFederationHttpClient(object):
producer = body_callback(method, http_url_bytes, headers_dict)
try:
- def send_request():
- request_deferred = self.agent.request(
- method,
- url_bytes,
- Headers(headers_dict),
- producer
- )
-
- return self.clock.time_bound_deferred(
- request_deferred,
- time_out=timeout / 1000. if timeout else 60,
- )
-
- with logcontext.PreserveLoggingContext():
- response = yield send_request()
+ request_deferred = self.agent.request(
+ method,
+ url_bytes,
+ Headers(headers_dict),
+ producer
+ )
+ add_timeout_to_deferred(
+ request_deferred,
+ timeout / 1000. if timeout else 60,
+ cancelled_to_request_timed_out_error,
+ )
+ response = yield make_deferred_yieldable(
+ request_deferred,
+ )
log_result = "%d %s" % (response.code, response.phrase,)
break
@@ -286,7 +285,8 @@ class MatrixFederationHttpClient(object):
headers_dict[b"Authorization"] = auth_headers
@defer.inlineCallbacks
- def put_json(self, destination, path, data={}, json_data_callback=None,
+ def put_json(self, destination, path, args={}, data={},
+ json_data_callback=None,
long_retries=False, timeout=None,
ignore_backoff=False,
backoff_on_404=False):
@@ -296,6 +296,7 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
+ args (dict): query params
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
@@ -342,6 +343,7 @@ class MatrixFederationHttpClient(object):
path,
body_callback=body_callback,
headers_dict={"Content-Type": ["application/json"]},
+ query_bytes=encode_query_args(args),
long_retries=long_retries,
timeout=timeout,
ignore_backoff=ignore_backoff,
@@ -373,6 +375,7 @@ class MatrixFederationHttpClient(object):
giving up. None indicates no timeout.
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
+ args (dict): query params
Returns:
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
@@ -548,7 +551,7 @@ class MatrixFederationHttpClient(object):
encoded_args = {}
for k, vs in args.items():
- if isinstance(vs, basestring):
+ if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
@@ -663,7 +666,7 @@ def check_content_type_is_json(headers):
RuntimeError if the
"""
- c_type = headers.getRawHeaders("Content-Type")
+ c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
raise RuntimeError(
"No Content-Type header"
@@ -680,7 +683,7 @@ def check_content_type_is_json(headers):
def encode_query_args(args):
encoded_args = {}
for k, vs in args.items():
- if isinstance(vs, basestring):
+ if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
new file mode 100644
index 0000000000..dc06f6c443
--- /dev/null
+++ b/synapse/http/request_metrics.py
@@ -0,0 +1,273 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from prometheus_client.core import Counter, Histogram
+from synapse.metrics import LaterGauge
+
+from synapse.util.logcontext import LoggingContext
+
+logger = logging.getLogger(__name__)
+
+
+# total number of responses served, split by method/servlet/tag
+response_count = Counter(
+ "synapse_http_server_response_count", "", ["method", "servlet", "tag"]
+)
+
+requests_counter = Counter(
+ "synapse_http_server_requests_received", "", ["method", "servlet"]
+)
+
+outgoing_responses_counter = Counter(
+ "synapse_http_server_responses", "", ["method", "code"]
+)
+
+response_timer = Histogram(
+ "synapse_http_server_response_time_seconds", "sec", ["method", "servlet", "tag"]
+)
+
+response_ru_utime = Counter(
+ "synapse_http_server_response_ru_utime_seconds", "sec", ["method", "servlet", "tag"]
+)
+
+response_ru_stime = Counter(
+ "synapse_http_server_response_ru_stime_seconds", "sec", ["method", "servlet", "tag"]
+)
+
+response_db_txn_count = Counter(
+ "synapse_http_server_response_db_txn_count", "", ["method", "servlet", "tag"]
+)
+
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = Counter(
+ "synapse_http_server_response_db_txn_duration_seconds",
+ "",
+ ["method", "servlet", "tag"],
+)
+
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = Counter(
+ "synapse_http_server_response_db_sched_duration_seconds",
+ "",
+ ["method", "servlet", "tag"],
+)
+
+# size in bytes of the response written
+response_size = Counter(
+ "synapse_http_server_response_size", "", ["method", "servlet", "tag"]
+)
+
+# In flight metrics are incremented while the requests are in flight, rather
+# than when the response was written.
+
+in_flight_requests_ru_utime = Counter(
+ "synapse_http_server_in_flight_requests_ru_utime_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+in_flight_requests_ru_stime = Counter(
+ "synapse_http_server_in_flight_requests_ru_stime_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+in_flight_requests_db_txn_count = Counter(
+ "synapse_http_server_in_flight_requests_db_txn_count", "", ["method", "servlet"]
+)
+
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+in_flight_requests_db_txn_duration = Counter(
+ "synapse_http_server_in_flight_requests_db_txn_duration_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+# seconds spent waiting for a db connection, when processing this request
+in_flight_requests_db_sched_duration = Counter(
+ "synapse_http_server_in_flight_requests_db_sched_duration_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+# The set of all in flight requests, set[RequestMetrics]
+_in_flight_requests = set()
+
+
+def _get_in_flight_counts():
+ """Returns a count of all in flight requests by (method, server_name)
+
+ Returns:
+ dict[tuple[str, str], int]
+ """
+ for rm in _in_flight_requests:
+ rm.update_metrics()
+
+ # Map from (method, name) -> int, the number of in flight requests of that
+ # type
+ counts = {}
+ for rm in _in_flight_requests:
+ key = (rm.method, rm.name,)
+ counts[key] = counts.get(key, 0) + 1
+
+ return counts
+
+
+LaterGauge(
+ "synapse_http_request_metrics_in_flight_requests_count",
+ "",
+ ["method", "servlet"],
+ _get_in_flight_counts,
+)
+
+
+class RequestMetrics(object):
+ def start(self, time_sec, name, method):
+ self.start = time_sec
+ self.start_context = LoggingContext.current_context()
+ self.name = name
+ self.method = method
+
+ self._request_stats = _RequestStats.from_context(self.start_context)
+
+ _in_flight_requests.add(self)
+
+ def stop(self, time_sec, request):
+ _in_flight_requests.discard(self)
+
+ context = LoggingContext.current_context()
+
+ tag = ""
+ if context:
+ tag = context.tag
+
+ if context != self.start_context:
+ logger.warn(
+ "Context have unexpectedly changed %r, %r",
+ context, self.start_context
+ )
+ return
+
+ outgoing_responses_counter.labels(request.method, str(request.code)).inc()
+
+ response_count.labels(request.method, self.name, tag).inc()
+
+ response_timer.labels(request.method, self.name, tag).observe(
+ time_sec - self.start
+ )
+
+ ru_utime, ru_stime = context.get_resource_usage()
+
+ response_ru_utime.labels(request.method, self.name, tag).inc(ru_utime)
+ response_ru_stime.labels(request.method, self.name, tag).inc(ru_stime)
+ response_db_txn_count.labels(request.method, self.name, tag).inc(
+ context.db_txn_count
+ )
+ response_db_txn_duration.labels(request.method, self.name, tag).inc(
+ context.db_txn_duration_sec
+ )
+ response_db_sched_duration.labels(request.method, self.name, tag).inc(
+ context.db_sched_duration_sec
+ )
+
+ response_size.labels(request.method, self.name, tag).inc(request.sentLength)
+
+ # We always call this at the end to ensure that we update the metrics
+ # regardless of whether a call to /metrics while the request was in
+ # flight.
+ self.update_metrics()
+
+ def update_metrics(self):
+ """Updates the in flight metrics with values from this request.
+ """
+ diff = self._request_stats.update(self.start_context)
+
+ in_flight_requests_ru_utime.labels(self.method, self.name).inc(diff.ru_utime)
+ in_flight_requests_ru_stime.labels(self.method, self.name).inc(diff.ru_stime)
+
+ in_flight_requests_db_txn_count.labels(self.method, self.name).inc(
+ diff.db_txn_count
+ )
+
+ in_flight_requests_db_txn_duration.labels(self.method, self.name).inc(
+ diff.db_txn_duration_sec
+ )
+
+ in_flight_requests_db_sched_duration.labels(self.method, self.name).inc(
+ diff.db_sched_duration_sec
+ )
+
+
+class _RequestStats(object):
+ """Keeps tracks of various metrics for an in flight request.
+ """
+
+ __slots__ = [
+ "ru_utime",
+ "ru_stime",
+ "db_txn_count",
+ "db_txn_duration_sec",
+ "db_sched_duration_sec",
+ ]
+
+ def __init__(
+ self, ru_utime, ru_stime, db_txn_count, db_txn_duration_sec, db_sched_duration_sec
+ ):
+ self.ru_utime = ru_utime
+ self.ru_stime = ru_stime
+ self.db_txn_count = db_txn_count
+ self.db_txn_duration_sec = db_txn_duration_sec
+ self.db_sched_duration_sec = db_sched_duration_sec
+
+ @staticmethod
+ def from_context(context):
+ ru_utime, ru_stime = context.get_resource_usage()
+
+ return _RequestStats(
+ ru_utime, ru_stime,
+ context.db_txn_count,
+ context.db_txn_duration_sec,
+ context.db_sched_duration_sec,
+ )
+
+ def update(self, context):
+ """Updates the current values and returns the difference between the
+ old and new values.
+
+ Returns:
+ _RequestStats: The difference between the old and new values
+ """
+ new = _RequestStats.from_context(context)
+
+ diff = _RequestStats(
+ new.ru_utime - self.ru_utime,
+ new.ru_stime - self.ru_stime,
+ new.db_txn_count - self.db_txn_count,
+ new.db_txn_duration_sec - self.db_txn_duration_sec,
+ new.db_sched_duration_sec - self.db_sched_duration_sec,
+ )
+
+ self.ru_utime = new.ru_utime
+ self.ru_stime = new.ru_stime
+ self.db_txn_count = new.db_txn_count
+ self.db_txn_duration_sec = new.db_txn_duration_sec
+ self.db_sched_duration_sec = new.db_sched_duration_sec
+
+ return diff
diff --git a/synapse/http/server.py b/synapse/http/server.py
index e64aa92729..bc09b8b2be 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -13,11 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import cgi
+from six.moves import http_client
from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes
)
+from synapse.http.request_metrics import (
+ requests_counter,
+)
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import intern_dict
from synapse.util.metrics import Measure
@@ -41,173 +45,174 @@ import simplejson
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-# total number of responses served, split by method/servlet/tag
-response_count = metrics.register_counter(
- "response_count",
- labels=["method", "servlet", "tag"],
- alternative_names=(
- # the following are all deprecated aliases for the same metric
- metrics.name_prefix + x for x in (
- "_requests",
- "_response_time:count",
- "_response_ru_utime:count",
- "_response_ru_stime:count",
- "_response_db_txn_count:count",
- "_response_db_txn_duration:count",
- )
- )
-)
+HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
+<html lang=en>
+ <head>
+ <meta charset="utf-8">
+ <title>Error {code}</title>
+ </head>
+ <body>
+ <p>{msg}</p>
+ </body>
+</html>
+"""
-requests_counter = metrics.register_counter(
- "requests_received",
- labels=["method", "servlet", ],
-)
-outgoing_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
-)
+def wrap_json_request_handler(h):
+ """Wraps a request handler method with exception handling.
-response_timer = metrics.register_counter(
- "response_time_seconds",
- labels=["method", "servlet", "tag"],
- alternative_names=(
- metrics.name_prefix + "_response_time:total",
- ),
-)
+ Also adds logging as per wrap_request_handler_with_logging.
-response_ru_utime = metrics.register_counter(
- "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
- alternative_names=(
- metrics.name_prefix + "_response_ru_utime:total",
- ),
-)
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
-response_ru_stime = metrics.register_counter(
- "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
- alternative_names=(
- metrics.name_prefix + "_response_ru_stime:total",
- ),
-)
+ The handler must return a deferred. If the deferred succeeds we assume that
+ a response has been sent. If the deferred fails with a SynapseError we use
+ it to send a JSON response with the appropriate HTTP reponse code. If the
+ deferred fails with any other type of error we send a 500 reponse.
+ """
-response_db_txn_count = metrics.register_counter(
- "response_db_txn_count", labels=["method", "servlet", "tag"],
- alternative_names=(
- metrics.name_prefix + "_response_db_txn_count:total",
- ),
-)
+ @defer.inlineCallbacks
+ def wrapped_request_handler(self, request):
+ try:
+ yield h(self, request)
+ except CodeMessageException as e:
+ code = e.code
+ if isinstance(e, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, e.msg
+ )
+ else:
+ logger.exception(e)
+ respond_with_json(
+ request, code, cs_exception(e), send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
-# seconds spent waiting for db txns, excluding scheduling time, when processing
-# this request
-response_db_txn_duration = metrics.register_counter(
- "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
- alternative_names=(
- metrics.name_prefix + "_response_db_txn_duration:total",
- ),
-)
+ except Exception:
+ # failure.Failure() fishes the original Failure out
+ # of our stack, and thus gives us a sensible stack
+ # trace.
+ f = failure.Failure()
+ logger.error(
+ "Failed handle request via %r: %r: %s",
+ h,
+ request,
+ f.getTraceback().rstrip(),
+ )
+ respond_with_json(
+ request,
+ 500,
+ {
+ "error": "Internal server error",
+ "errcode": Codes.UNKNOWN,
+ },
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
-# seconds spent waiting for a db connection, when processing this request
-response_db_sched_duration = metrics.register_counter(
- "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
-)
+ return wrap_request_handler_with_logging(wrapped_request_handler)
-_next_request_id = 0
+def wrap_html_request_handler(h):
+ """Wraps a request handler method with exception handling.
-def request_handler(include_metrics=False):
- """Decorator for ``wrap_request_handler``"""
- return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
+ Also adds logging as per wrap_request_handler_with_logging.
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
+ """
+ def wrapped_request_handler(self, request):
+ d = defer.maybeDeferred(h, self, request)
+ d.addErrback(_return_html_error, request)
+ return d
-def wrap_request_handler(request_handler, include_metrics=False):
- """Wraps a method that acts as a request handler with the necessary logging
- and exception handling.
+ return wrap_request_handler_with_logging(wrapped_request_handler)
- The method must have a signature of "handle_foo(self, request)". The
- argument "self" must have "version_string" and "clock" attributes. The
- argument "request" must be a twisted HTTP request.
- The method must return a deferred. If the deferred succeeds we assume that
- a response has been sent. If the deferred fails with a SynapseError we use
- it to send a JSON response with the appropriate HTTP reponse code. If the
- deferred fails with any other type of error we send a 500 reponse.
+def _return_html_error(f, request):
+ """Sends an HTML error page corresponding to the given failure
- We insert a unique request-id into the logging context for this request and
- log the response and duration for this request.
+ Args:
+ f (twisted.python.failure.Failure):
+ request (twisted.web.iweb.IRequest):
"""
+ if f.check(CodeMessageException):
+ cme = f.value
+ code = cme.code
+ msg = cme.msg
+
+ if isinstance(cme, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, msg
+ )
+ else:
+ logger.error(
+ "Failed handle request %r: %s",
+ request,
+ f.getTraceback().rstrip(),
+ )
+ else:
+ code = http_client.INTERNAL_SERVER_ERROR
+ msg = "Internal server error"
+
+ logger.error(
+ "Failed handle request %r: %s",
+ request,
+ f.getTraceback().rstrip(),
+ )
+
+ body = HTML_ERROR_TEMPLATE.format(
+ code=code, msg=cgi.escape(msg),
+ ).encode("utf-8")
+ request.setResponseCode(code)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%i" % (len(body),))
+ request.write(body)
+ finish_request(request)
+
+
+def wrap_request_handler_with_logging(h):
+ """Wraps a request handler to provide logging and metrics
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
+
+ As well as calling `request.processing` (which will log the response and
+ duration for this request), the wrapped request handler will insert the
+ request id into the logging context.
+ """
@defer.inlineCallbacks
def wrapped_request_handler(self, request):
- global _next_request_id
- request_id = "%s-%s" % (request.method, _next_request_id)
- _next_request_id += 1
+ """
+ Args:
+ self:
+ request (synapse.http.site.SynapseRequest):
+ """
+ request_id = request.get_request_id()
with LoggingContext(request_id) as request_context:
+ request_context.request = request_id
with Measure(self.clock, "wrapped_request_handler"):
- request_metrics = RequestMetrics()
# we start the request metrics timer here with an initial stab
# at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet.
servlet_name = self.__class__.__name__
- request_metrics.start(self.clock, name=servlet_name)
-
- request_context.request = request_id
- with request.processing():
- try:
- with PreserveLoggingContext(request_context):
- if include_metrics:
- yield request_handler(self, request, request_metrics)
- else:
- requests_counter.inc(request.method, servlet_name)
- yield request_handler(self, request)
- except CodeMessageException as e:
- code = e.code
- if isinstance(e, SynapseError):
- logger.info(
- "%s SynapseError: %s - %s", request, code, e.msg
- )
- else:
- logger.exception(e)
- outgoing_responses_counter.inc(request.method, str(code))
- respond_with_json(
- request, code, cs_exception(e), send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
- )
- except Exception:
- # failure.Failure() fishes the original Failure out
- # of our stack, and thus gives us a sensible stack
- # trace.
- f = failure.Failure()
- logger.error(
- "Failed handle request %s.%s on %r: %r: %s",
- request_handler.__module__,
- request_handler.__name__,
- self,
- request,
- f.getTraceback().rstrip(),
- )
- respond_with_json(
- request,
- 500,
- {
- "error": "Internal server error",
- "errcode": Codes.UNKNOWN,
- },
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
- )
- finally:
- try:
- request_metrics.stop(
- self.clock, request
- )
- except Exception as e:
- logger.warn("Failed to stop metrics: %r", e)
+ with request.processing(servlet_name):
+ with PreserveLoggingContext(request_context):
+ d = defer.maybeDeferred(h, self, request)
+
+ # record the arrival of the request *after*
+ # dispatching to the handler, so that the handler
+ # can update the servlet name in the request
+ # metrics
+ requests_counter.labels(request.method,
+ request.request_metrics.name).inc()
+ yield d
return wrapped_request_handler
@@ -257,7 +262,6 @@ class JsonResource(HttpServer, resource.Resource):
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
- self.version_string = hs.version_string
self.hs = hs
def register_paths(self, method, path_patterns, callback):
@@ -273,13 +277,9 @@ class JsonResource(HttpServer, resource.Resource):
self._async_render(request)
return server.NOT_DONE_YET
- # Disable metric reporting because _async_render does its own metrics.
- # It does its own metric reporting because _async_render dispatches to
- # a callback and it's the class name of that callback we want to report
- # against rather than the JsonResource itself.
- @request_handler(include_metrics=True)
+ @wrap_json_request_handler
@defer.inlineCallbacks
- def _async_render(self, request, request_metrics):
+ def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
@@ -291,9 +291,7 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
-
- request_metrics.name = servlet_classname
- requests_counter.inc(request.method, servlet_classname)
+ request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
@@ -324,7 +322,7 @@ class JsonResource(HttpServer, resource.Resource):
register_paths, so will return (possibly via Deferred) either
None, or a tuple of (http code, response body).
"""
- if request.method == "OPTIONS":
+ if request.method == b"OPTIONS":
return _options_handler, {}
# Loop through all the registered callbacks to check if the method
@@ -340,15 +338,12 @@ class JsonResource(HttpServer, resource.Resource):
def _send_response(self, request, code, response_json_object,
response_code_message=None):
- outgoing_responses_counter.inc(request.method, str(code))
-
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request, code, response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
canonical_json=self.canonical_json,
)
@@ -381,52 +376,6 @@ def _unrecognised_request_handler(request):
raise UnrecognizedRequestError()
-class RequestMetrics(object):
- def start(self, clock, name):
- self.start = clock.time_msec()
- self.start_context = LoggingContext.current_context()
- self.name = name
-
- def stop(self, clock, request):
- context = LoggingContext.current_context()
-
- tag = ""
- if context:
- tag = context.tag
-
- if context != self.start_context:
- logger.warn(
- "Context have unexpectedly changed %r, %r",
- context, self.start_context
- )
- return
-
- response_count.inc(request.method, self.name, tag)
-
- response_timer.inc_by(
- clock.time_msec() - self.start, request.method,
- self.name, tag
- )
-
- ru_utime, ru_stime = context.get_resource_usage()
-
- response_ru_utime.inc_by(
- ru_utime, request.method, self.name, tag
- )
- response_ru_stime.inc_by(
- ru_stime, request.method, self.name, tag
- )
- response_db_txn_count.inc_by(
- context.db_txn_count, request.method, self.name, tag
- )
- response_db_txn_duration.inc_by(
- context.db_txn_duration_ms / 1000., request.method, self.name, tag
- )
- response_db_sched_duration.inc_by(
- context.db_sched_duration_ms / 1000., request.method, self.name, tag
- )
-
-
class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""
@@ -445,7 +394,7 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
- version_string="", canonical_json=True):
+ canonical_json=True):
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
@@ -461,19 +410,17 @@ def respond_with_json(request, code, json_object, send_cors=False,
if canonical_json or synapse.events.USE_FROZEN_DICTS:
json_bytes = encode_canonical_json(json_object)
else:
- # ujson doesn't like frozen_dicts.
json_bytes = simplejson.dumps(json_object)
return respond_with_json_bytes(
request, code, json_bytes,
send_cors=send_cors,
response_code_message=response_code_message,
- version_string=version_string
)
def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
- version_string="", response_code_message=None):
+ response_code_message=None):
"""Sends encoded JSON in response to the given request.
Args:
@@ -487,8 +434,8 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.setResponseCode(code, message=response_code_message)
request.setHeader(b"Content-Type", b"application/json")
- request.setHeader(b"Server", version_string)
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
+ request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
if send_cors:
set_cors_headers(request)
@@ -536,9 +483,9 @@ def finish_request(request):
def _request_user_agent_is_curl(request):
user_agents = request.requestHeaders.getRawHeaders(
- "User-Agent", default=[]
+ b"User-Agent", default=[]
)
for user_agent in user_agents:
- if "curl" in user_agent:
+ if b"curl" in user_agent:
return True
return False
diff --git a/synapse/http/site.py b/synapse/http/site.py
index e422c8dfae..60299657b9 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -12,27 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext
-from twisted.web.server import Site, Request
-
import contextlib
import logging
import re
import time
-ACCESS_TOKEN_RE = re.compile(r'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+from twisted.web.server import Site, Request
+
+from synapse.http.request_metrics import RequestMetrics
+from synapse.util.logcontext import LoggingContext
+
+logger = logging.getLogger(__name__)
+
+ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+
+_next_request_seq = 0
class SynapseRequest(Request):
+ """Class which encapsulates an HTTP request to synapse.
+
+ All of the requests processed in synapse are of this type.
+
+ It extends twisted's twisted.web.server.Request, and adds:
+ * Unique request ID
+ * Redaction of access_token query-params in __repr__
+ * Logging at start and end
+ * Metrics to record CPU, wallclock and DB time by endpoint.
+
+ It provides a method `processing` which should be called by the Resource
+ which is handling the request, and returns a context manager.
+
+ """
def __init__(self, site, *args, **kw):
Request.__init__(self, *args, **kw)
self.site = site
self.authenticated_entity = None
self.start_time = 0
+ global _next_request_seq
+ self.request_seq = _next_request_seq
+ _next_request_seq += 1
+
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
- return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
+ return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
self.__class__.__name__,
id(self),
self.method,
@@ -41,16 +65,30 @@ class SynapseRequest(Request):
self.site.site_tag,
)
+ def get_request_id(self):
+ return "%s-%i" % (self.method, self.request_seq)
+
def get_redacted_uri(self):
return ACCESS_TOKEN_RE.sub(
- r'\1<redacted>\3',
+ br'\1<redacted>\3',
self.uri
)
def get_user_agent(self):
- return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
+ return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
+
+ def render(self, resrc):
+ # override the Server header which is set by twisted
+ self.setHeader("Server", self.site.server_version_string)
+ return Request.render(self, resrc)
+
+ def _started_processing(self, servlet_name):
+ self.start_time = time.time()
+ self.request_metrics = RequestMetrics()
+ self.request_metrics.start(
+ self.start_time, name=servlet_name, method=self.method,
+ )
- def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
@@ -58,32 +96,32 @@ class SynapseRequest(Request):
self.method,
self.get_redacted_uri()
)
- self.start_time = int(time.time() * 1000)
-
- def finished_processing(self):
+ def _finished_processing(self):
try:
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
- db_txn_duration_ms = context.db_txn_duration_ms
- db_sched_duration_ms = context.db_sched_duration_ms
+ db_txn_duration_sec = context.db_txn_duration_sec
+ db_sched_duration_sec = context.db_sched_duration_sec
except Exception:
ru_utime, ru_stime = (0, 0)
- db_txn_count, db_txn_duration_ms = (0, 0)
+ db_txn_count, db_txn_duration_sec = (0, 0)
+
+ end_time = time.time()
self.site.access_logger.info(
"%s - %s - {%s}"
- " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
+ " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
self.authenticated_entity,
- int(time.time() * 1000) - self.start_time,
- int(ru_utime * 1000),
- int(ru_stime * 1000),
- db_sched_duration_ms,
- db_txn_duration_ms,
+ end_time - self.start_time,
+ ru_utime,
+ ru_stime,
+ db_sched_duration_sec,
+ db_txn_duration_sec,
int(db_txn_count),
self.sentLength,
self.code,
@@ -93,11 +131,38 @@ class SynapseRequest(Request):
self.get_user_agent(),
)
+ try:
+ self.request_metrics.stop(end_time, self)
+ except Exception as e:
+ logger.warn("Failed to stop metrics: %r", e)
+
@contextlib.contextmanager
- def processing(self):
- self.started_processing()
+ def processing(self, servlet_name):
+ """Record the fact that we are processing this request.
+
+ Returns a context manager; the correct way to use this is:
+
+ @defer.inlineCallbacks
+ def handle_request(request):
+ with request.processing("FooServlet"):
+ yield really_handle_the_request()
+
+ This will log the request's arrival. Once the context manager is
+ closed, the completion of the request will be logged, and the various
+ metrics will be updated.
+
+ Args:
+ servlet_name (str): the name of the servlet which will be
+ processing this request. This is used in the metrics.
+
+ It is possible to update this afterwards by updating
+ self.request_metrics.servlet_name.
+ """
+ # TODO: we should probably just move this into render() and finish(),
+ # to save having to call a separate method.
+ self._started_processing(servlet_name)
yield
- self.finished_processing()
+ self._finished_processing()
class XForwardedForRequest(SynapseRequest):
@@ -135,7 +200,8 @@ class SynapseSite(Site):
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
- def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
+ def __init__(self, logger_name, site_tag, config, resource,
+ server_version_string, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
@@ -143,6 +209,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
+ self.server_version_string = server_version_string
def log(self, request):
pass
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 50d99d7a5c..371c300cac 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -17,130 +17,170 @@ import logging
import functools
import time
import gc
+import os
+import platform
+import attr
-from twisted.internet import reactor
+from prometheus_client import Gauge, Histogram, Counter
+from prometheus_client.core import GaugeMetricFamily, REGISTRY
-from .metric import (
- CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
- MemoryUsageMetric,
-)
-from .process_collector import register_process_collector
+from twisted.internet import reactor
logger = logging.getLogger(__name__)
-
+running_on_pypy = platform.python_implementation() == "PyPy"
all_metrics = []
all_collectors = []
+all_gauges = {}
+
+
+class RegistryProxy(object):
+
+ def collect(self):
+ for metric in REGISTRY.collect():
+ if not metric.name.startswith("__"):
+ yield metric
+
+@attr.s(hash=True)
+class LaterGauge(object):
-class Metrics(object):
- """ A single Metrics object gives a (mutable) slice view of the all_metrics
- dict, allowing callers to easily register new metrics that are namespaced
- nicely."""
+ name = attr.ib()
+ desc = attr.ib()
+ labels = attr.ib(hash=False)
+ caller = attr.ib()
- def __init__(self, name):
- self.name_prefix = name
+ def collect(self):
- def make_subspace(self, name):
- return Metrics("%s_%s" % (self.name_prefix, name))
+ g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
- def register_collector(self, func):
- all_collectors.append(func)
+ try:
+ calls = self.caller()
+ except Exception as e:
+ print(e)
+ logger.err()
+ yield g
+
+ if isinstance(calls, dict):
+ for k, v in calls.items():
+ g.add_metric(k, v)
+ else:
+ g.add_metric([], calls)
- def _register(self, metric_class, name, *args, **kwargs):
- full_name = "%s_%s" % (self.name_prefix, name)
+ yield g
- metric = metric_class(full_name, *args, **kwargs)
+ def __attrs_post_init__(self):
+ self._register()
- all_metrics.append(metric)
- return metric
+ def _register(self):
+ if self.name in all_gauges.keys():
+ logger.warning("%s already registered, reregistering" % (self.name,))
+ REGISTRY.unregister(all_gauges.pop(self.name))
- def register_counter(self, *args, **kwargs):
- """
- Returns:
- CounterMetric
- """
- return self._register(CounterMetric, *args, **kwargs)
+ REGISTRY.register(self)
+ all_gauges[self.name] = self
- def register_callback(self, *args, **kwargs):
- """
- Returns:
- CallbackMetric
- """
- return self._register(CallbackMetric, *args, **kwargs)
- def register_distribution(self, *args, **kwargs):
- """
- Returns:
- DistributionMetric
- """
- return self._register(DistributionMetric, *args, **kwargs)
+#
+# Detailed CPU metrics
+#
- def register_cache(self, *args, **kwargs):
- """
- Returns:
- CacheMetric
- """
- return self._register(CacheMetric, *args, **kwargs)
+class CPUMetrics(object):
+ def __init__(self):
+ ticks_per_sec = 100
+ try:
+ # Try and get the system config
+ ticks_per_sec = os.sysconf('SC_CLK_TCK')
+ except (ValueError, TypeError, AttributeError):
+ pass
-def register_memory_metrics(hs):
- try:
- import psutil
- process = psutil.Process()
- process.memory_info().rss
- except (ImportError, AttributeError):
- logger.warn(
- "psutil is not installed or incorrect version."
- " Disabling memory metrics."
- )
- return
- metric = MemoryUsageMetric(hs, psutil)
- all_metrics.append(metric)
+ self.ticks_per_sec = ticks_per_sec
+ def collect(self):
-def get_metrics_for(pkg_name):
- """ Returns a Metrics instance for conveniently creating metrics
- namespaced with the given name prefix. """
+ with open("/proc/self/stat") as s:
+ line = s.read()
+ raw_stats = line.split(") ", 1)[1].split(" ")
- # Convert a "package.name" to "package_name" because Prometheus doesn't
- # let us use . in metric names
- return Metrics(pkg_name.replace(".", "_"))
+ user = GaugeMetricFamily("process_cpu_user_seconds_total", "")
+ user.add_metric([], float(raw_stats[11]) / self.ticks_per_sec)
+ yield user
+ sys = GaugeMetricFamily("process_cpu_system_seconds_total", "")
+ sys.add_metric([], float(raw_stats[12]) / self.ticks_per_sec)
+ yield sys
-def render_all():
- strs = []
- for collector in all_collectors:
- collector()
+REGISTRY.register(CPUMetrics())
- for metric in all_metrics:
- try:
- strs += metric.render()
- except Exception:
- strs += ["# FAILED to render"]
- logger.exception("Failed to render metric")
+#
+# Python GC metrics
+#
+
+gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"])
+gc_time = Histogram(
+ "python_gc_time",
+ "Time taken to GC (sec)",
+ ["gen"],
+ buckets=[0.0025, 0.005, 0.01, 0.025, 0.05, 0.10, 0.25, 0.50, 1.00, 2.50,
+ 5.00, 7.50, 15.00, 30.00, 45.00, 60.00],
+)
- strs.append("") # to generate a final CRLF
- return "\n".join(strs)
+class GCCounts(object):
+ def collect(self):
+ cm = GaugeMetricFamily("python_gc_counts", "GC cycle counts", labels=["gen"])
+ for n, m in enumerate(gc.get_count()):
+ cm.add_metric([str(n)], m)
-register_process_collector(get_metrics_for("process"))
+ yield cm
-python_metrics = get_metrics_for("python")
+REGISTRY.register(GCCounts())
+
+#
+# Twisted reactor metrics
+#
-gc_time = python_metrics.register_distribution("gc_time", labels=["gen"])
-gc_unreachable = python_metrics.register_counter("gc_unreachable_total", labels=["gen"])
-python_metrics.register_callback(
- "gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
+tick_time = Histogram(
+ "python_twisted_reactor_tick_time",
+ "Tick time of the Twisted reactor (sec)",
+ buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5],
+)
+pending_calls_metric = Histogram(
+ "python_twisted_reactor_pending_calls",
+ "Pending calls",
+ buckets=[1, 2, 5, 10, 25, 50, 100, 250, 500, 1000],
)
-reactor_metrics = get_metrics_for("python.twisted.reactor")
-tick_time = reactor_metrics.register_distribution("tick_time")
-pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
+#
+# Federation Metrics
+#
+
+sent_edus_counter = Counter("synapse_federation_client_sent_edus", "")
+
+sent_transactions_counter = Counter("synapse_federation_client_sent_transactions", "")
+
+events_processed_counter = Counter("synapse_federation_client_events_processed", "")
+
+# Used to track where various components have processed in the event stream,
+# e.g. federation sending, appservice sending, etc.
+event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"])
+
+# Used to track the current max events stream position
+event_persisted_position = Gauge("synapse_event_persisted_position", "")
+
+# Used to track the received_ts of the last event processed by various
+# components
+event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name"])
+
+# Used to track the lag processing events. This is the time difference
+# between the last processed event's received_ts and the time it was
+# finished being processed.
+event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"])
def runUntilCurrentTimer(func):
@@ -162,17 +202,20 @@ def runUntilCurrentTimer(func):
num_pending += 1
num_pending += len(reactor.threadCallQueue)
- start = time.time() * 1000
+ start = time.time()
ret = func(*args, **kwargs)
- end = time.time() * 1000
+ end = time.time()
# record the amount of wallclock time spent running pending calls.
# This is a proxy for the actual amount of time between reactor polls,
# since about 25% of time is actually spent running things triggered by
# I/O events, but that is harder to capture without rewriting half the
# reactor.
- tick_time.inc_by(end - start)
- pending_calls_metric.inc_by(num_pending)
+ tick_time.observe(end - start)
+ pending_calls_metric.observe(num_pending)
+
+ if running_on_pypy:
+ return ret
# Check if we need to do a manual GC (since its been disabled), and do
# one if necessary.
@@ -182,12 +225,12 @@ def runUntilCurrentTimer(func):
if threshold[i] < counts[i]:
logger.info("Collecting gc %d", i)
- start = time.time() * 1000
+ start = time.time()
unreachable = gc.collect(i)
- end = time.time() * 1000
+ end = time.time()
- gc_time.inc_by(end - start, i)
- gc_unreachable.inc_by(unreachable, i)
+ gc_time.labels(i).observe(end - start)
+ gc_unreachable.labels(i).set(unreachable)
return ret
@@ -206,6 +249,7 @@ try:
# We manually run the GC each reactor tick so that we can get some metrics
# about time spent doing GC,
- gc.disable()
+ if not running_on_pypy:
+ gc.disable()
except AttributeError:
pass
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
deleted file mode 100644
index ff5aa8c0e1..0000000000
--- a/synapse/metrics/metric.py
+++ /dev/null
@@ -1,271 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from itertools import chain
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-def flatten(items):
- """Flatten a list of lists
-
- Args:
- items: iterable[iterable[X]]
-
- Returns:
- list[X]: flattened list
- """
- return list(chain.from_iterable(items))
-
-
-class BaseMetric(object):
- """Base class for metrics which report a single value per label set
- """
-
- def __init__(self, name, labels=[], alternative_names=[]):
- """
- Args:
- name (str): principal name for this metric
- labels (list(str)): names of the labels which will be reported
- for this metric
- alternative_names (iterable(str)): list of alternative names for
- this metric. This can be useful to provide a migration path
- when renaming metrics.
- """
- self._names = [name] + list(alternative_names)
- self.labels = labels # OK not to clone as we never write it
-
- def dimension(self):
- return len(self.labels)
-
- def is_scalar(self):
- return not len(self.labels)
-
- def _render_labelvalue(self, value):
- # TODO: escape backslashes, quotes and newlines
- return '"%s"' % (value)
-
- def _render_key(self, values):
- if self.is_scalar():
- return ""
- return "{%s}" % (
- ",".join(["%s=%s" % (k, self._render_labelvalue(v))
- for k, v in zip(self.labels, values)])
- )
-
- def _render_for_labels(self, label_values, value):
- """Render this metric for a single set of labels
-
- Args:
- label_values (list[str]): values for each of the labels
- value: value of the metric at with these labels
-
- Returns:
- iterable[str]: rendered metric
- """
- rendered_labels = self._render_key(label_values)
- return (
- "%s%s %.12g" % (name, rendered_labels, value)
- for name in self._names
- )
-
- def render(self):
- """Render this metric
-
- Each metric is rendered as:
-
- name{label1="val1",label2="val2"} value
-
- https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
-
- Returns:
- iterable[str]: rendered metrics
- """
- raise NotImplementedError()
-
-
-class CounterMetric(BaseMetric):
- """The simplest kind of metric; one that stores a monotonically-increasing
- value that counts events or running totals.
-
- Example use cases for Counters:
- - Number of requests processed
- - Number of items that were inserted into a queue
- - Total amount of data that a system has processed
- Counters can only go up (and be reset when the process restarts).
- """
-
- def __init__(self, *args, **kwargs):
- super(CounterMetric, self).__init__(*args, **kwargs)
-
- # dict[list[str]]: value for each set of label values. the keys are the
- # label values, in the same order as the labels in self.labels.
- #
- # (if the metric is a scalar, the (single) key is the empty list).
- self.counts = {}
-
- # Scalar metrics are never empty
- if self.is_scalar():
- self.counts[()] = 0.
-
- def inc_by(self, incr, *values):
- if len(values) != self.dimension():
- raise ValueError(
- "Expected as many values to inc() as labels (%d)" % (self.dimension())
- )
-
- # TODO: should assert that the tag values are all strings
-
- if values not in self.counts:
- self.counts[values] = incr
- else:
- self.counts[values] += incr
-
- def inc(self, *values):
- self.inc_by(1, *values)
-
- def render(self):
- return flatten(
- self._render_for_labels(k, self.counts[k])
- for k in sorted(self.counts.keys())
- )
-
-
-class CallbackMetric(BaseMetric):
- """A metric that returns the numeric value returned by a callback whenever
- it is rendered. Typically this is used to implement gauges that yield the
- size or other state of some in-memory object by actively querying it."""
-
- def __init__(self, name, callback, labels=[]):
- super(CallbackMetric, self).__init__(name, labels=labels)
-
- self.callback = callback
-
- def render(self):
- try:
- value = self.callback()
- except Exception:
- logger.exception("Failed to render %s", self.name)
- return ["# FAILED to render " + self.name]
-
- if self.is_scalar():
- return list(self._render_for_labels([], value))
-
- return flatten(
- self._render_for_labels(k, value[k])
- for k in sorted(value.keys())
- )
-
-
-class DistributionMetric(object):
- """A combination of an event counter and an accumulator, which counts
- both the number of events and accumulates the total value. Typically this
- could be used to keep track of method-running times, or other distributions
- of values that occur in discrete occurances.
-
- TODO(paul): Try to export some heatmap-style stats?
- """
-
- def __init__(self, name, *args, **kwargs):
- self.counts = CounterMetric(name + ":count", **kwargs)
- self.totals = CounterMetric(name + ":total", **kwargs)
-
- def inc_by(self, inc, *values):
- self.counts.inc(*values)
- self.totals.inc_by(inc, *values)
-
- def render(self):
- return self.counts.render() + self.totals.render()
-
-
-class CacheMetric(object):
- __slots__ = (
- "name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
- )
-
- def __init__(self, name, size_callback, cache_name):
- self.name = name
- self.cache_name = cache_name
-
- self.hits = 0
- self.misses = 0
- self.evicted_size = 0
-
- self.size_callback = size_callback
-
- def inc_hits(self):
- self.hits += 1
-
- def inc_misses(self):
- self.misses += 1
-
- def inc_evictions(self, size=1):
- self.evicted_size += size
-
- def render(self):
- size = self.size_callback()
- hits = self.hits
- total = self.misses + self.hits
-
- return [
- """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
- """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
- """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
- """%s:evicted_size{name="%s"} %d""" % (
- self.name, self.cache_name, self.evicted_size
- ),
- ]
-
-
-class MemoryUsageMetric(object):
- """Keeps track of the current memory usage, using psutil.
-
- The class will keep the current min/max/sum/counts of rss over the last
- WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
- """
-
- UPDATE_HZ = 2 # number of times to get memory per second
- WINDOW_SIZE_SEC = 30 # the size of the window in seconds
-
- def __init__(self, hs, psutil):
- clock = hs.get_clock()
- self.memory_snapshots = []
-
- self.process = psutil.Process()
-
- clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
-
- def _update_curr_values(self):
- max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
- self.memory_snapshots.append(self.process.memory_info().rss)
- self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
-
- def render(self):
- if not self.memory_snapshots:
- return []
-
- max_rss = max(self.memory_snapshots)
- min_rss = min(self.memory_snapshots)
- sum_rss = sum(self.memory_snapshots)
- len_rss = len(self.memory_snapshots)
-
- return [
- "process_psutil_rss:max %d" % max_rss,
- "process_psutil_rss:min %d" % min_rss,
- "process_psutil_rss:total %d" % sum_rss,
- "process_psutil_rss:count %d" % len_rss,
- ]
diff --git a/synapse/metrics/process_collector.py b/synapse/metrics/process_collector.py
deleted file mode 100644
index 6fec3de399..0000000000
--- a/synapse/metrics/process_collector.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-
-TICKS_PER_SEC = 100
-BYTES_PER_PAGE = 4096
-
-HAVE_PROC_STAT = os.path.exists("/proc/stat")
-HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
-HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits")
-HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd")
-
-# Field indexes from /proc/self/stat, taken from the proc(5) manpage
-STAT_FIELDS = {
- "utime": 14,
- "stime": 15,
- "starttime": 22,
- "vsize": 23,
- "rss": 24,
-}
-
-
-stats = {}
-
-# In order to report process_start_time_seconds we need to know the
-# machine's boot time, because the value in /proc/self/stat is relative to
-# this
-boot_time = None
-if HAVE_PROC_STAT:
- with open("/proc/stat") as _procstat:
- for line in _procstat:
- if line.startswith("btime "):
- boot_time = int(line.split()[1])
-
-
-def update_resource_metrics():
- if HAVE_PROC_SELF_STAT:
- global stats
- with open("/proc/self/stat") as s:
- line = s.read()
- # line is PID (command) more stats go here ...
- raw_stats = line.split(") ", 1)[1].split(" ")
-
- for (name, index) in STAT_FIELDS.iteritems():
- # subtract 3 from the index, because proc(5) is 1-based, and
- # we've lost the first two fields in PID and COMMAND above
- stats[name] = int(raw_stats[index - 3])
-
-
-def _count_fds():
- # Not every OS will have a /proc/self/fd directory
- if not HAVE_PROC_SELF_FD:
- return 0
-
- return len(os.listdir("/proc/self/fd"))
-
-
-def register_process_collector(process_metrics):
- process_metrics.register_collector(update_resource_metrics)
-
- if HAVE_PROC_SELF_STAT:
- process_metrics.register_callback(
- "cpu_user_seconds_total",
- lambda: float(stats["utime"]) / TICKS_PER_SEC
- )
- process_metrics.register_callback(
- "cpu_system_seconds_total",
- lambda: float(stats["stime"]) / TICKS_PER_SEC
- )
- process_metrics.register_callback(
- "cpu_seconds_total",
- lambda: (float(stats["utime"] + stats["stime"])) / TICKS_PER_SEC
- )
-
- process_metrics.register_callback(
- "virtual_memory_bytes",
- lambda: int(stats["vsize"])
- )
- process_metrics.register_callback(
- "resident_memory_bytes",
- lambda: int(stats["rss"]) * BYTES_PER_PAGE
- )
-
- process_metrics.register_callback(
- "start_time_seconds",
- lambda: boot_time + int(stats["starttime"]) / TICKS_PER_SEC
- )
-
- if HAVE_PROC_SELF_FD:
- process_metrics.register_callback(
- "open_fds",
- lambda: _count_fds()
- )
-
- if HAVE_PROC_SELF_LIMITS:
- def _get_max_fds():
- with open("/proc/self/limits") as limits:
- for line in limits:
- if not line.startswith("Max open files "):
- continue
- # Line is Max open files $SOFT $HARD
- return int(line.split()[3])
- return None
-
- process_metrics.register_callback(
- "max_fds",
- lambda: _get_max_fds()
- )
diff --git a/synapse/metrics/resource.py b/synapse/metrics/resource.py
index 870f400600..7996e6ab66 100644
--- a/synapse/metrics/resource.py
+++ b/synapse/metrics/resource.py
@@ -13,27 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.web.resource import Resource
-
-import synapse.metrics
-
-
METRICS_PREFIX = "/_synapse/metrics"
-
-
-class MetricsResource(Resource):
- isLeaf = True
-
- def __init__(self, hs):
- Resource.__init__(self) # Resource is old-style, so no super()
-
- self.hs = hs
-
- def render_GET(self, request):
- response = synapse.metrics.render_all()
-
- request.setHeader("Content-Type", "text/plain")
- request.setHeader("Content-Length", str(len(response)))
-
- # Encode as UTF-8 (default)
- return response.encode()
diff --git a/synapse/notifier.py b/synapse/notifier.py
index ef042681bc..6dce20a284 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -14,33 +14,34 @@
# limitations under the License.
from twisted.internet import defer
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state
-from synapse.util import DeferredTimedOutError
from synapse.util.logutils import log_function
-from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.async import (
+ ObservableDeferred, add_timeout_to_deferred,
+ DeferredTimeoutError,
+)
+from synapse.util.logcontext import PreserveLoggingContext, run_in_background
from synapse.util.metrics import Measure
from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client
-import synapse.metrics
+from synapse.metrics import LaterGauge
from collections import namedtuple
+from prometheus_client import Counter
import logging
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-notified_events_counter = metrics.register_counter("notified_events")
+notified_events_counter = Counter("synapse_notifier_notified_events", "")
-users_woken_by_stream_counter = metrics.register_counter(
- "users_woken_by_stream", labels=["stream"]
-)
+users_woken_by_stream_counter = Counter(
+ "synapse_notifier_users_woken_by_stream", "", ["stream"])
# TODO(paul): Should be shared somewhere
@@ -105,7 +106,7 @@ class _NotifierUserStream(object):
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
- users_woken_by_stream_counter.inc(stream_key)
+ users_woken_by_stream_counter.labels(stream_key).inc()
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
@@ -144,6 +145,7 @@ class _NotifierUserStream(object):
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
def __nonzero__(self):
return bool(self.events)
+ __bool__ = __nonzero__ # python3
class Notifier(object):
@@ -193,14 +195,14 @@ class Notifier(object):
all_user_streams.add(x)
return sum(stream.count_listeners() for stream in all_user_streams)
- metrics.register_callback("listeners", count_listeners)
+ LaterGauge("synapse_notifier_listeners", "", [], count_listeners)
- metrics.register_callback(
- "rooms",
+ LaterGauge(
+ "synapse_notifier_rooms", "", [],
lambda: count(bool, self.room_to_user_streams.values()),
)
- metrics.register_callback(
- "users",
+ LaterGauge(
+ "synapse_notifier_users", "", [],
lambda: len(self.user_to_user_stream),
)
@@ -250,9 +252,7 @@ class Notifier(object):
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
- preserve_fn(self.appservice_handler.notify_interested_services)(
- room_stream_id
- )
+ run_in_background(self._notify_app_services, room_stream_id)
if self.federation_sender:
self.federation_sender.notify_new_events(room_stream_id)
@@ -266,6 +266,13 @@ class Notifier(object):
rooms=[event.room_id],
)
+ @defer.inlineCallbacks
+ def _notify_app_services(self, room_stream_id):
+ try:
+ yield self.appservice_handler.notify_interested_services(room_stream_id)
+ except Exception:
+ logger.exception("Error notifying application services of event")
+
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise.
@@ -330,11 +337,12 @@ class Notifier(object):
# Now we wait for the _NotifierUserStream to be told there
# is a new token.
listener = user_stream.new_listener(prev_token)
+ add_timeout_to_deferred(
+ listener.deferred,
+ (end_time - now) / 1000.,
+ )
with PreserveLoggingContext():
- yield self.clock.time_bound_deferred(
- listener.deferred,
- time_out=(end_time - now) / 1000.
- )
+ yield listener.deferred
current_token = user_stream.current_token
@@ -345,7 +353,7 @@ class Notifier(object):
# Update the prev_token to the current_token since nothing
# has happened between the old prev_token and the current_token
prev_token = current_token
- except DeferredTimedOutError:
+ except DeferredTimeoutError:
break
except defer.CancelledError:
break
@@ -550,13 +558,14 @@ class Notifier(object):
if end_time <= now:
break
+ add_timeout_to_deferred(
+ listener.deferred.addTimeout,
+ (end_time - now) / 1000.,
+ )
try:
with PreserveLoggingContext():
- yield self.clock.time_bound_deferred(
- listener.deferred,
- time_out=(end_time - now) / 1000.
- )
- except DeferredTimedOutError:
+ yield listener.deferred
+ except DeferredTimeoutError:
break
except defer.CancelledError:
break
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 7c680659b6..a5cab1f043 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -22,35 +22,32 @@ from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.event_auth import get_user_power_level
from synapse.api.constants import EventTypes, Membership
-from synapse.metrics import get_metrics_for
-from synapse.util.caches import metrics as cache_metrics
+from synapse.util.caches import register_cache
from synapse.util.caches.descriptors import cached
from synapse.util.async import Linearizer
from synapse.state import POWER_KEY
from collections import namedtuple
-
+from prometheus_client import Counter
+from six import itervalues, iteritems
logger = logging.getLogger(__name__)
rules_by_room = {}
-push_metrics = get_metrics_for(__name__)
-push_rules_invalidation_counter = push_metrics.register_counter(
- "push_rules_invalidation_counter"
-)
-push_rules_state_size_counter = push_metrics.register_counter(
- "push_rules_state_size_counter"
-)
+push_rules_invalidation_counter = Counter(
+ "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "")
+push_rules_state_size_counter = Counter(
+ "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "")
# Measures whether we use the fast path of using state deltas, or if we have to
# recalculate from scratch
-push_rules_delta_state_cache_metric = cache_metrics.register_cache(
+push_rules_delta_state_cache_metric = register_cache(
"cache",
- size_callback=lambda: 0, # Meaningless size, as this isn't a cache that stores values
- cache_name="push_rules_delta_state_cache_metric",
+ "push_rules_delta_state_cache_metric",
+ cache=[], # Meaningless size, as this isn't a cache that stores values
)
@@ -64,10 +61,10 @@ class BulkPushRuleEvaluator(object):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.room_push_rule_cache_metrics = cache_metrics.register_cache(
+ self.room_push_rule_cache_metrics = register_cache(
"cache",
- size_callback=lambda: 0, # There's not good value for this
- cache_name="room_push_rule_cache",
+ "room_push_rule_cache",
+ cache=[], # Meaningless size, as this isn't a cache that stores values
)
@defer.inlineCallbacks
@@ -126,7 +123,7 @@ class BulkPushRuleEvaluator(object):
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
- (e.type, e.state_key): e for e in auth_events.itervalues()
+ (e.type, e.state_key): e for e in itervalues(auth_events)
}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -160,7 +157,7 @@ class BulkPushRuleEvaluator(object):
condition_cache = {}
- for uid, rules in rules_by_user.iteritems():
+ for uid, rules in iteritems(rules_by_user):
if event.sender == uid:
continue
@@ -309,7 +306,7 @@ class RulesForRoom(object):
current_state_ids = context.current_state_ids
push_rules_delta_state_cache_metric.inc_misses()
- push_rules_state_size_counter.inc_by(len(current_state_ids))
+ push_rules_state_size_counter.inc(len(current_state_ids))
logger.debug(
"Looking for member changes in %r %r", state_group, current_state_ids
@@ -406,7 +403,7 @@ class RulesForRoom(object):
# If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
- for event_id in member_event_ids.itervalues():
+ for event_id in itervalues(member_event_ids):
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
@@ -414,7 +411,7 @@ class RulesForRoom(object):
logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = set(
- user_id for user_id, membership in members.itervalues()
+ user_id for user_id, membership in itervalues(members)
if membership == Membership.JOIN
)
@@ -426,7 +423,7 @@ class RulesForRoom(object):
)
user_ids = set(
- uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
+ uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
)
logger.debug("With pushers: %r", user_ids)
@@ -447,7 +444,7 @@ class RulesForRoom(object):
)
ret_rules_by_user.update(
- item for item in rules_by_user.iteritems() if item[0] is not None
+ item for item in iteritems(rules_by_user) if item[0] is not None
)
self.update_cache(sequence, members, ret_rules_by_user, state_group)
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 58df98a793..ba7286cb72 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -77,10 +77,13 @@ class EmailPusher(object):
@defer.inlineCallbacks
def on_started(self):
if self.mailer is not None:
- self.throttle_params = yield self.store.get_throttle_params_by_room(
- self.pusher_id
- )
- yield self._process()
+ try:
+ self.throttle_params = yield self.store.get_throttle_params_by_room(
+ self.pusher_id
+ )
+ yield self._process()
+ except Exception:
+ logger.exception("Error starting email pusher")
def on_stop(self):
if self.timed_call:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 2cbac571b8..bf7ff74a1a 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -18,24 +18,19 @@ import logging
from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
-import push_rule_evaluator
-import push_tools
-import synapse
+from . import push_rule_evaluator
+from . import push_tools
from synapse.push import PusherConfigException
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
-logger = logging.getLogger(__name__)
+from prometheus_client import Counter
-metrics = synapse.metrics.get_metrics_for(__name__)
+logger = logging.getLogger(__name__)
-http_push_processed_counter = metrics.register_counter(
- "http_pushes_processed",
-)
+http_push_processed_counter = Counter("synapse_http_httppusher_http_pushes_processed", "")
-http_push_failed_counter = metrics.register_counter(
- "http_pushes_failed",
-)
+http_push_failed_counter = Counter("synapse_http_httppusher_http_pushes_failed", "")
class HttpPusher(object):
@@ -94,7 +89,10 @@ class HttpPusher(object):
@defer.inlineCallbacks
def on_started(self):
- yield self._process()
+ try:
+ yield self._process()
+ except Exception:
+ logger.exception("Error starting http pusher")
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 3601f2d365..cf735f7468 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -21,6 +21,8 @@ from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
+from six import string_types
+
logger = logging.getLogger(__name__)
@@ -150,7 +152,7 @@ class PushRuleEvaluatorForEvent(object):
# Caches (glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
-register_cache("regex_push_cache", regex_cache)
+register_cache("cache", "regex_push_cache", regex_cache)
def _glob_matches(glob, value, word_boundary=False):
@@ -238,7 +240,7 @@ def _flatten_dict(d, prefix=[], result=None):
if result is None:
result = {}
for key, value in d.items():
- if isinstance(value, basestring):
+ if isinstance(value, string_types):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 71576330a9..5aa6667e91 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from httppusher import HttpPusher
+from .httppusher import HttpPusher
import logging
logger = logging.getLogger(__name__)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 134e89b371..750d11ca38 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from .pusher import PusherFactory
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.push.pusher import PusherFactory
from synapse.util.async import run_on_reactor
-
-import logging
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -137,12 +137,15 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
- preserve_fn(p.on_new_notifications)(
- min_stream_id, max_stream_id
+ run_in_background(
+ p.on_new_notifications,
+ min_stream_id, max_stream_id,
)
)
- yield make_deferred_yieldable(defer.gatherResults(deferreds))
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True),
+ )
except Exception:
logger.exception("Exception in pusher on_new_notifications")
@@ -164,10 +167,15 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
- preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
+ run_in_background(
+ p.on_new_receipts,
+ min_stream_id, max_stream_id,
+ )
)
- yield make_deferred_yieldable(defer.gatherResults(deferreds))
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True),
+ )
except Exception:
logger.exception("Exception in pusher on_new_receipts")
@@ -207,7 +215,7 @@ class PusherPool:
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
- preserve_fn(p.on_started)()
+ run_in_background(p.on_started)
logger.info("Started pushers")
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 91179ce532..478c497722 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -1,5 +1,6 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,28 +19,44 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
+# this dict maps from python package name to a list of modules we expect it to
+# provide.
+#
+# the key is a "requirement specifier", as used as a parameter to `pip
+# install`[1], or an `install_requires` argument to `setuptools.setup` [2].
+#
+# the value is a sequence of strings; each entry should be the name of the
+# python module, optionally followed by a version assertion which can be either
+# ">=<ver>" or "==<ver>".
+#
+# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
+# [2] https://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-dependencies
REQUIREMENTS = {
"jsonschema>=2.5.1": ["jsonschema>=2.5.1"],
"frozendict>=0.4": ["frozendict"],
"unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
- "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
+ "canonicaljson>=1.1.3": ["canonicaljson>=1.1.3"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=16.0.0": ["twisted>=16.0.0"],
- "pyopenssl>=0.14": ["OpenSSL>=0.14"],
+
+ # We use crypto.get_elliptic_curve which is only supported in >=0.15
+ "pyopenssl>=0.15": ["OpenSSL>=0.15"],
+
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],
"daemonize": ["daemonize"],
"bcrypt": ["bcrypt>=3.1.0"],
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
- "ujson": ["ujson"],
"blist": ["blist"],
"pysaml2>=3.0.0": ["saml2>=3.0.0"],
"pymacaroons-pynacl": ["pymacaroons"],
"msgpack-python>=0.3.0": ["msgpack"],
"phonenumbers>=8.2.0": ["phonenumbers"],
+ "six": ["six"],
+ "prometheus_client": ["prometheus_client"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index bbe2f967b7..a9baa2c1c3 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -23,7 +23,6 @@ from synapse.events.snapshot import EventContext
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.util.async import sleep
from synapse.util.caches.response_cache import ResponseCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
from synapse.util.metrics import Measure
from synapse.types import Requester, UserID
@@ -115,20 +114,15 @@ class ReplicationSendEventRestServlet(RestServlet):
self.clock = hs.get_clock()
# The responses are tiny, so we may as well cache them for a while
- self.response_cache = ResponseCache(hs, timeout_ms=30 * 60 * 1000)
+ self.response_cache = ResponseCache(hs, "send_event", timeout_ms=30 * 60 * 1000)
def on_PUT(self, request, event_id):
- result = self.response_cache.get(event_id)
- if not result:
- result = self.response_cache.set(
- event_id,
- self._handle_request(request)
- )
- else:
- logger.warn("Returning cached response")
- return make_deferred_yieldable(result)
-
- @preserve_fn
+ return self.response_cache.wrap(
+ event_id,
+ self._handle_request,
+ request
+ )
+
@defer.inlineCallbacks
def _handle_request(self, request):
with Measure(self.clock, "repl_send_event_parse"):
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 171227cce2..12aac3cc6b 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -19,11 +19,13 @@ allowed to be sent by which side.
"""
import logging
-import ujson as json
+import simplejson
logger = logging.getLogger(__name__)
+_json_encoder = simplejson.JSONEncoder(namedtuple_as_object=False)
+
class Command(object):
"""The base command class.
@@ -100,14 +102,14 @@ class RdataCommand(Command):
return cls(
stream_name,
None if token == "batch" else int(token),
- json.loads(row_json)
+ simplejson.loads(row_json)
)
def to_line(self):
return " ".join((
self.stream_name,
str(self.token) if self.token is not None else "batch",
- json.dumps(self.row),
+ _json_encoder.encode(self.row),
))
@@ -298,10 +300,12 @@ class InvalidateCacheCommand(Command):
def from_line(cls, line):
cache_func, keys_json = line.split(" ", 1)
- return cls(cache_func, json.loads(keys_json))
+ return cls(cache_func, simplejson.loads(keys_json))
def to_line(self):
- return " ".join((self.cache_func, json.dumps(self.keys)))
+ return " ".join((
+ self.cache_func, _json_encoder.encode(self.keys),
+ ))
class UserIpCommand(Command):
@@ -325,14 +329,14 @@ class UserIpCommand(Command):
def from_line(cls, line):
user_id, jsn = line.split(" ", 1)
- access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
+ access_token, ip, user_agent, device_id, last_seen = simplejson.loads(jsn)
return cls(
user_id, access_token, ip, user_agent, device_id, last_seen
)
def to_line(self):
- return self.user_id + " " + json.dumps((
+ return self.user_id + " " + _json_encoder.encode((
self.access_token, self.ip, self.user_agent, self.device_id,
self.last_seen,
))
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0a9a290af4..a6280aae70 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -53,28 +53,28 @@ from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
-from commands import (
+from .commands import (
COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
)
-from streams import STREAMS_MAP
+from .streams import STREAMS_MAP
+from synapse.metrics import LaterGauge
from synapse.util.stringutils import random_string
-from synapse.metrics.metric import CounterMetric
-import logging
-import synapse.metrics
-import struct
-import fcntl
+from prometheus_client import Counter
+from collections import defaultdict
-metrics = synapse.metrics.get_metrics_for(__name__)
+from six import iterkeys, iteritems
-connection_close_counter = metrics.register_counter(
- "close_reason", labels=["reason_type"],
-)
+import logging
+import struct
+import fcntl
+connection_close_counter = Counter(
+ "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
# A list of all connected protocols. This allows us to send metrics about the
# connections.
@@ -136,12 +136,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = CounterMetric(
- "inbound_commands", labels=["command"],
- )
- self.outbound_commands_counter = CounterMetric(
- "outbound_commands", labels=["command"],
- )
+ self.inbound_commands_counter = defaultdict(int)
+ self.outbound_commands_counter = defaultdict(int)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -201,7 +197,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
- self.inbound_commands_counter.inc(cmd_name)
+ self.inbound_commands_counter[cmd_name] = (
+ self.inbound_commands_counter[cmd_name] + 1)
cmd_cls = COMMAND_MAP[cmd_name]
try:
@@ -251,8 +248,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self._queue_command(cmd)
return
- self.outbound_commands_counter.inc(cmd.NAME)
-
+ self.outbound_commands_counter[cmd.NAME] = (
+ self.outbound_commands_counter[cmd.NAME] + 1)
string = "%s %s" % (cmd.NAME, cmd.to_line(),)
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -317,9 +314,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
- connection_close_counter.inc(reason.type.__name__)
+ connection_close_counter.labels(reason.type.__name__).inc()
else:
- connection_close_counter.inc(reason.__class__.__name__)
+ connection_close_counter.labels(reason.__class__.__name__).inc()
try:
# Remove us from list of connections to be monitored
@@ -392,7 +389,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
- for stream in self.streamer.streams_by_name.iterkeys():
+ for stream in iterkeys(self.streamer.streams_by_name):
self.subscribe_to_stream(stream, token)
else:
self.subscribe_to_stream(stream_name, token)
@@ -498,7 +495,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
- for stream_name, token in self.handler.get_streams_to_replicate().iteritems():
+ for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token)
# Tell the server if we have any users currently syncing (should only
@@ -518,7 +515,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def on_RDATA(self, cmd):
stream_name = cmd.stream_name
- inbound_rdata_count.inc(stream_name)
+ inbound_rdata_count.labels(stream_name).inc()
try:
row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
@@ -566,14 +563,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# The following simply registers metrics for the replication connections
-metrics.register_callback(
- "pending_commands",
+pending_commands = LaterGauge(
+ "pending_commands", "", ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): len(p.pending_commands)
for p in connected_connections
- },
- labels=["name", "conn_id"],
-)
+ })
def transport_buffer_size(protocol):
@@ -583,14 +578,12 @@ def transport_buffer_size(protocol):
return 0
-metrics.register_callback(
- "transport_send_buffer",
+transport_send_buffer = LaterGauge(
+ "synapse_replication_tcp_transport_send_buffer", "", ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_buffer_size(p)
for p in connected_connections
- },
- labels=["name", "conn_id"],
-)
+ })
def transport_kernel_read_buffer_size(protocol, read=True):
@@ -608,48 +601,38 @@ def transport_kernel_read_buffer_size(protocol, read=True):
return 0
-metrics.register_callback(
- "transport_kernel_send_buffer",
+tcp_transport_kernel_send_buffer = LaterGauge(
+ "synapse_replication_tcp_transport_kernel_send_buffer", "", ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
- },
- labels=["name", "conn_id"],
-)
+ })
-metrics.register_callback(
- "transport_kernel_read_buffer",
+tcp_transport_kernel_read_buffer = LaterGauge(
+ "synapse_replication_tcp_transport_kernel_read_buffer", "", ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
- },
- labels=["name", "conn_id"],
-)
+ })
-metrics.register_callback(
- "inbound_commands",
+tcp_inbound_commands = LaterGauge(
+ "synapse_replication_tcp_inbound_commands", "", ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
- for k, count in p.inbound_commands_counter.counts.iteritems()
- },
- labels=["command", "name", "conn_id"],
-)
+ for k, count in iteritems(p.inbound_commands_counter.counts)
+ })
-metrics.register_callback(
- "outbound_commands",
+tcp_outbound_commands = LaterGauge(
+ "synapse_replication_tcp_outbound_commands", "", ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
- for k, count in p.outbound_commands_counter.counts.iteritems()
- },
- labels=["command", "name", "conn_id"],
-)
+ for k, count in iteritems(p.outbound_commands_counter.counts)
+ })
# number of updates received for each RDATA stream
-inbound_rdata_count = metrics.register_counter(
- "inbound_rdata_count",
- labels=["stream_name"],
-)
+inbound_rdata_count = Counter("synapse_replication_tcp_inbound_rdata_count", "",
+ ["stream_name"])
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 786c3fe864..63bd6d2652 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -18,24 +18,25 @@
from twisted.internet import defer, reactor
from twisted.internet.protocol import Factory
-from streams import STREAMS_MAP, FederationStream
-from protocol import ServerReplicationStreamProtocol
+from .streams import STREAMS_MAP, FederationStream
+from .protocol import ServerReplicationStreamProtocol
from synapse.util.metrics import Measure, measure_func
+from synapse.metrics import LaterGauge
import logging
-import synapse.metrics
+from prometheus_client import Counter
+from six import itervalues
-metrics = synapse.metrics.get_metrics_for(__name__)
-stream_updates_counter = metrics.register_counter(
- "stream_updates", labels=["stream_name"]
-)
-user_sync_counter = metrics.register_counter("user_sync")
-federation_ack_counter = metrics.register_counter("federation_ack")
-remove_pusher_counter = metrics.register_counter("remove_pusher")
-invalidate_cache_counter = metrics.register_counter("invalidate_cache")
-user_ip_cache_counter = metrics.register_counter("user_ip_cache")
+stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
+ "", ["stream_name"])
+user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
+federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
+remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
+invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
+ "")
+user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -69,33 +70,34 @@ class ReplicationStreamer(object):
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
+ self._server_notices_sender = hs.get_server_notices_sender()
# Current connections.
self.connections = []
- metrics.register_callback("total_connections", lambda: len(self.connections))
+ LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
+ lambda: len(self.connections))
# List of streams that clients can subscribe to.
# We only support federation stream if federation sending hase been
# disabled on the master.
self.streams = [
- stream(hs) for stream in STREAMS_MAP.itervalues()
+ stream(hs) for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
- metrics.register_callback(
- "connections_per_stream",
+ LaterGauge(
+ "synapse_replication_tcp_resource_connections_per_stream", "",
+ ["stream_name"],
lambda: {
(stream_name,): len([
conn for conn in self.connections
if stream_name in conn.replication_streams
])
for stream_name in self.streams_by_name
- },
- labels=["stream_name"],
- )
+ })
self.federation_sender = None
if not hs.config.send_federation:
@@ -175,7 +177,7 @@ class ReplicationStreamer(object):
logger.info(
"Streaming: %s -> %s", stream.NAME, updates[-1][0]
)
- stream_updates_counter.inc_by(len(updates), stream.NAME)
+ stream_updates_counter.labels(stream.NAME).inc(len(updates))
# Some streams return multiple rows with the same stream IDs,
# we need to make sure they get sent out in batches. We do
@@ -253,6 +255,7 @@ class ReplicationStreamer(object):
yield self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen,
)
+ yield self._server_notices_sender.on_user_ip(user_id)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index fceca2edeb..20fa6678ef 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -19,6 +19,7 @@ import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -80,27 +81,26 @@ class HttpTransactionCache(object):
Returns:
Deferred which resolves to a tuple of (response_code, response_dict).
"""
- try:
- return self.transactions[txn_key][0].observe()
- except (KeyError, IndexError):
- pass # execute the function instead.
-
- deferred = fn(*args, **kwargs)
-
- # if the request fails with a Twisted failure, remove it
- # from the transaction map. This is done to ensure that we don't
- # cache transient errors like rate-limiting errors, etc.
- def remove_from_map(err):
- self.transactions.pop(txn_key, None)
- return err
- deferred.addErrback(remove_from_map)
-
- # We don't add any other errbacks to the raw deferred, so we ask
- # ObservableDeferred to swallow the error. This is fine as the error will
- # still be reported to the observers.
- observable = ObservableDeferred(deferred, consumeErrors=True)
- self.transactions[txn_key] = (observable, self.clock.time_msec())
- return observable.observe()
+ if txn_key in self.transactions:
+ observable = self.transactions[txn_key][0]
+ else:
+ # execute the function instead.
+ deferred = run_in_background(fn, *args, **kwargs)
+
+ observable = ObservableDeferred(deferred)
+ self.transactions[txn_key] = (observable, self.clock.time_msec())
+
+ # if the request fails with an exception, remove it
+ # from the transaction map. This is done to ensure that we don't
+ # cache transient errors like rate-limiting errors, etc.
+ def remove_from_map(err):
+ self.transactions.pop(txn_key, None)
+ # we deliberately do not propagate the error any further, as we
+ # expect the observers to have reported it.
+
+ deferred.addErrback(remove_from_map)
+
+ return make_deferred_yieldable(observable.observe())
def _cleanup(self):
now = self.clock.time_msec()
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 303419d281..6835a7bba2 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -151,10 +151,11 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
- depth = event.depth
+ token = yield self.store.get_topological_token_for_event(event_id)
+
logger.info(
- "[purge] purging up to depth %i (event_id %s)",
- depth, event_id,
+ "[purge] purging up to token %s (event_id %s)",
+ token, event_id,
)
elif 'purge_up_to_ts' in body:
ts = body['purge_up_to_ts']
@@ -168,15 +169,30 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
yield self.store.find_first_stream_ordering_after_ts(ts)
)
- (_, depth, _) = (
+ room_event_after_stream_ordering = (
yield self.store.get_room_event_after_stream_ordering(
room_id, stream_ordering,
)
)
+ if room_event_after_stream_ordering:
+ token = yield self.store.get_topological_token_for_event(
+ room_event_after_stream_ordering,
+ )
+ else:
+ logger.warn(
+ "[purge] purging events not possible: No event found "
+ "(received_ts %i => stream_ordering %i)",
+ ts, stream_ordering,
+ )
+ raise SynapseError(
+ 404,
+ "there is no event to be purged",
+ errcode=Codes.NOT_FOUND,
+ )
logger.info(
- "[purge] purging up to depth %i (received_ts %i => "
+ "[purge] purging up to token %d (received_ts %i => "
"stream_ordering %i)",
- depth, ts, stream_ordering,
+ token, ts, stream_ordering,
)
else:
raise SynapseError(
@@ -186,7 +202,7 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
)
purge_id = yield self.handlers.message_handler.start_purge_history(
- room_id, depth,
+ room_id, token,
delete_local_events=delete_local_events,
)
@@ -260,8 +276,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ShutdownRoomRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
self.state = hs.get_state_handler()
+ self._room_creation_handler = hs.get_room_creation_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
@@ -283,7 +299,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification")
- info = yield self.handlers.room_creation_handler.create_room(
+ info = yield self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": "public_chat",
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c7aa0bbf59..197335d7aa 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
+ # This subclass was presumably created to allow the auth for the v1
+ # protocol version to be different, however this behaviour was removed.
+ # it may no longer be necessary
+
def __init__(self, hs):
"""
Args:
@@ -59,5 +63,5 @@ class ClientV1RestServlet(RestServlet):
"""
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 45844aa2d2..34df5be4e9 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -25,7 +25,7 @@ from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
import logging
from saml2 import BINDING_HTTP_POST
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index ca49955935..e092158cb7 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -44,7 +44,10 @@ class LogoutRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
except AuthError:
# this implies the access token has already been deleted.
- pass
+ defer.returnValue((401, {
+ "errcode": "M_UNKNOWN_TOKEN",
+ "error": "Access Token unknown or expired"
+ }))
else:
if requester.device_id is None:
# the acccess token wasn't associated with a device.
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 4a73813c58..647994bd53 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -23,6 +23,8 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
+from six import string_types
+
import logging
logger = logging.getLogger(__name__)
@@ -71,7 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
if "status_msg" in content:
state["status_msg"] = content.pop("status_msg")
- if not isinstance(state["status_msg"], basestring):
+ if not isinstance(state["status_msg"], string_types):
raise SynapseError(400, "status_msg must be a string.")
if content:
@@ -129,7 +131,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "invite" in content:
for u in content["invite"]:
- if not isinstance(u, basestring):
+ if not isinstance(u, string_types):
raise SynapseError(400, "Bad invite value.")
if len(u) == 0:
continue
@@ -140,7 +142,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "drop" in content:
for u in content["drop"]:
- if not isinstance(u, basestring):
+ if not isinstance(u, string_types):
raise SynapseError(400, "Bad drop value.")
if len(u) == 0:
continue
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 1819a560cb..40e523cc5f 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -150,7 +150,7 @@ class PushersRemoveRestServlet(RestServlet):
super(RestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks
@@ -176,7 +176,6 @@ class PushersRemoveRestServlet(RestServlet):
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (
len(PushersRemoveRestServlet.SUCCESS_HTML),
))
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 5c5fa8f7ab..9b3022e0b0 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -30,6 +30,8 @@ from hashlib import sha1
import hmac
import logging
+from six import string_types
+
logger = logging.getLogger(__name__)
@@ -333,11 +335,11 @@ class RegisterRestServlet(ClientV1RestServlet):
def _do_shared_secret(self, request, register_json, session):
yield run_on_reactor()
- if not isinstance(register_json.get("mac", None), basestring):
+ if not isinstance(register_json.get("mac", None), string_types):
raise SynapseError(400, "Expected mac.")
- if not isinstance(register_json.get("user", None), basestring):
+ if not isinstance(register_json.get("user", None), string_types):
raise SynapseError(400, "Expected 'user' key.")
- if not isinstance(register_json.get("password", None), basestring):
+ if not isinstance(register_json.get("password", None), string_types):
raise SynapseError(400, "Expected 'password' key.")
if not self.hs.config.registration_shared_secret:
@@ -348,9 +350,9 @@ class RegisterRestServlet(ClientV1RestServlet):
admin = register_json.get("admin", None)
# Its important to check as we use null bytes as HMAC field separators
- if "\x00" in user:
+ if b"\x00" in user:
raise SynapseError(400, "Invalid user")
- if "\x00" in password:
+ if b"\x00" in password:
raise SynapseError(400, "Invalid password")
# str() because otherwise hmac complains that 'unicode' does not
@@ -358,14 +360,14 @@ class RegisterRestServlet(ClientV1RestServlet):
got_mac = str(register_json["mac"])
want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret,
+ key=self.hs.config.registration_shared_secret.encode(),
digestmod=sha1,
)
want_mac.update(user)
- want_mac.update("\x00")
+ want_mac.update(b"\x00")
want_mac.update(password)
- want_mac.update("\x00")
- want_mac.update("admin" if admin else "notadmin")
+ want_mac.update(b"\x00")
+ want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f8999d64d7..0b984987ed 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -28,9 +28,10 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer
)
+from six.moves.urllib import parse as urlparse
+
import logging
-import urllib
-import ujson as json
+import simplejson as json
logger = logging.getLogger(__name__)
@@ -40,7 +41,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self._room_creation_handler = hs.get_room_creation_handler()
def register(self, http_server):
PATTERNS = "/createRoom"
@@ -63,8 +64,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
- handler = self.handlers.room_creation_handler
- info = yield handler.create_room(
+ info = yield self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -165,17 +165,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content,
)
else:
- event, context = yield self.event_creation_hander.create_event(
+ event = yield self.event_creation_hander.create_and_send_nonmember_event(
requester,
event_dict,
- token_id=requester.access_token_id,
txn_id=txn_id,
)
- yield self.event_creation_hander.send_nonmember_event(
- requester, event, context,
- )
-
ret = {}
if event:
ret = {"event_id": event.event_id}
@@ -438,7 +433,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None)
if filter_bytes:
- filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
+ filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
@@ -655,7 +650,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content=event_content,
)
- defer.returnValue((200, {}))
+ return_value = {}
+
+ if membership_action == "join":
+ return_value["room_id"] = room_id
+
+ defer.returnValue((200, return_value))
def _has_3pid_invite_keys(self, content):
for key in {"id_server", "medium", "address"}:
@@ -718,8 +718,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, user_id):
requester = yield self.auth.get_user_by_req(request)
- room_id = urllib.unquote(room_id)
- target_user = UserID.from_string(urllib.unquote(user_id))
+ room_id = urlparse.unquote(room_id)
+ target_user = UserID.from_string(urlparse.unquote(user_id))
content = parse_json_object_from_request(request)
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8e5577148f..d6f3a19648 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -129,7 +129,6 @@ class AuthRestServlet(RestServlet):
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
@@ -175,7 +174,6 @@ class AuthRestServlet(RestServlet):
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index f762dbfa9a..3bb1ec2af6 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -401,6 +402,32 @@ class GroupInvitedUsersServlet(RestServlet):
defer.returnValue((200, result))
+class GroupSettingJoinPolicyServlet(RestServlet):
+ """Set group join policy
+ """
+ PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
+
+ def __init__(self, hs):
+ super(GroupSettingJoinPolicyServlet, self).__init__()
+ self.auth = hs.get_auth()
+ self.groups_handler = hs.get_groups_local_handler()
+
+ @defer.inlineCallbacks
+ def on_PUT(self, request, group_id):
+ requester = yield self.auth.get_user_by_req(request)
+ requester_user_id = requester.user.to_string()
+
+ content = parse_json_object_from_request(request)
+
+ result = yield self.groups_handler.set_group_join_policy(
+ group_id,
+ requester_user_id,
+ content,
+ )
+
+ defer.returnValue((200, result))
+
+
class GroupCreateServlet(RestServlet):
"""Create a group
"""
@@ -738,6 +765,7 @@ def register_servlets(hs, http_server):
GroupInvitedUsersServlet(hs).register(http_server)
GroupUsersServlet(hs).register(http_server)
GroupRoomServlet(hs).register(http_server)
+ GroupSettingJoinPolicyServlet(hs).register(http_server)
GroupCreateServlet(hs).register(http_server)
GroupAdminRoomsServlet(hs).register(http_server)
GroupAdminRoomsConfigServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index ec170109fe..66583d6778 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -88,7 +88,7 @@ class NotificationsServlet(RestServlet):
pa["topological_ordering"], pa["stream_ordering"]
)
returned_push_actions.append(returned_pa)
- next_token = pa["stream_ordering"]
+ next_token = str(pa["stream_ordering"])
defer.returnValue((200, {
"notifications": returned_push_actions,
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 0ba62bddc1..5cab00aea9 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -20,7 +20,6 @@ import synapse
import synapse.types
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
-from synapse.types import RoomID, RoomAlias
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
@@ -36,6 +35,8 @@ from hashlib import sha1
from synapse.util.async import run_on_reactor
from synapse.util.ratelimitutils import FederationRateLimiter
+from six import string_types
+
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
@@ -211,14 +212,14 @@ class RegisterRestServlet(RestServlet):
# in sessions. Pull out the username/password provided to us.
desired_password = None
if 'password' in body:
- if (not isinstance(body['password'], basestring) or
+ if (not isinstance(body['password'], string_types) or
len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if 'username' in body:
- if (not isinstance(body['username'], basestring) or
+ if (not isinstance(body['username'], string_types) or
len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
desired_username = body['username']
@@ -244,7 +245,7 @@ class RegisterRestServlet(RestServlet):
access_token = get_access_token_from_request(request)
- if isinstance(desired_username, basestring):
+ if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
desired_username, access_token, body
)
@@ -405,14 +406,6 @@ class RegisterRestServlet(RestServlet):
generate_token=False,
)
- # auto-join the user to any rooms we're supposed to dump them into
- fake_requester = synapse.types.create_requester(registered_user_id)
- for r in self.hs.config.auto_join_rooms:
- try:
- yield self._join_user_to_room(fake_requester, r)
- except Exception as e:
- logger.error("Failed to join new user to %r: %r", r, e)
-
# remember that we've now registered that user account, and with
# what user ID (since the user may not have specified)
self.auth_handler.set_session_data(
@@ -446,29 +439,6 @@ class RegisterRestServlet(RestServlet):
return 200, {}
@defer.inlineCallbacks
- def _join_user_to_room(self, requester, room_identifier):
- room_id = None
- if RoomID.is_valid(room_identifier):
- room_id = room_identifier
- elif RoomAlias.is_valid(room_identifier):
- room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = (
- yield self.room_member_handler.lookup_room_alias(room_alias)
- )
- room_id = room_id.to_string()
- else:
- raise SynapseError(400, "%s was not legal room ID or room alias" % (
- room_identifier,
- ))
-
- yield self.room_member_handler.update_membership(
- requester=requester,
- target=requester.user,
- room_id=room_id,
- action="join",
- )
-
- @defer.inlineCallbacks
def _do_appservice_registration(self, username, as_token, body):
user_id = yield self.registration_handler.appservice_register(
username, as_token
@@ -496,7 +466,7 @@ class RegisterRestServlet(RestServlet):
# includes the password and admin flag in the hashed text. Why are
# these different?
want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret,
+ key=self.hs.config.registration_shared_secret.encode(),
msg=user,
digestmod=sha1,
).hexdigest()
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index eb91c0b293..a291cffbf1 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -85,6 +85,7 @@ class SyncRestServlet(RestServlet):
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_presence_handler()
+ self._server_notices_sender = hs.get_server_notices_sender()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -149,6 +150,9 @@ class SyncRestServlet(RestServlet):
else:
since_token = None
+ # send any outstanding server notices to the user.
+ yield self._server_notices_sender.on_user_syncing(user.to_string())
+
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
diff --git a/synapse/rest/consent/__init__.py b/synapse/rest/consent/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rest/consent/__init__.py
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
new file mode 100644
index 0000000000..724911d1e6
--- /dev/null
+++ b/synapse/rest/consent/consent_resource.py
@@ -0,0 +1,222 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from hashlib import sha256
+import hmac
+import logging
+from os import path
+from six.moves import http_client
+
+import jinja2
+from jinja2 import TemplateNotFound
+from twisted.internet import defer
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from synapse.api.errors import NotFoundError, SynapseError, StoreError
+from synapse.config import ConfigError
+from synapse.http.server import (
+ finish_request,
+ wrap_html_request_handler,
+)
+from synapse.http.servlet import parse_string
+from synapse.types import UserID
+
+
+# language to use for the templates. TODO: figure this out from Accept-Language
+TEMPLATE_LANGUAGE = "en"
+
+logger = logging.getLogger(__name__)
+
+# use hmac.compare_digest if we have it (python 2.7.7), else just use equality
+if hasattr(hmac, "compare_digest"):
+ compare_digest = hmac.compare_digest
+else:
+ def compare_digest(a, b):
+ return a == b
+
+
+class ConsentResource(Resource):
+ """A twisted Resource to display a privacy policy and gather consent to it
+
+ When accessed via GET, returns the privacy policy via a template.
+
+ When accessed via POST, records the user's consent in the database and
+ displays a success page.
+
+ The config should include a template_dir setting which contains templates
+ for the HTML. The directory should contain one subdirectory per language
+ (eg, 'en', 'fr'), and each language directory should contain the policy
+ document (named as '<version>.html') and a success page (success.html).
+
+ Both forms take a set of parameters from the browser. For the POST form,
+ these are normally sent as form parameters (but may be query-params); for
+ GET requests they must be query params. These are:
+
+ u: the complete mxid, or the localpart of the user giving their
+ consent. Required for both GET (where it is used as an input to the
+ template) and for POST (where it is used to find the row in the db
+ to update).
+
+ h: hmac_sha256(secret, u), where 'secret' is the privacy_secret in the
+ config file. If it doesn't match, the request is 403ed.
+
+ v: the version of the privacy policy being agreed to.
+
+ For GET: optional, and defaults to whatever was set in the config
+ file. Used to choose the version of the policy to pick from the
+ templates directory.
+
+ For POST: required; gives the value to be recorded in the database
+ against the user.
+ """
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ """
+ Resource.__init__(self)
+
+ self.hs = hs
+ self.store = hs.get_datastore()
+
+ # this is required by the request_handler wrapper
+ self.clock = hs.get_clock()
+
+ self._default_consent_version = hs.config.user_consent_version
+ if self._default_consent_version is None:
+ raise ConfigError(
+ "Consent resource is enabled but user_consent section is "
+ "missing in config file.",
+ )
+
+ # daemonize changes the cwd to /, so make the path absolute now.
+ consent_template_directory = path.abspath(
+ hs.config.user_consent_template_dir,
+ )
+ if not path.isdir(consent_template_directory):
+ raise ConfigError(
+ "Could not find template directory '%s'" % (
+ consent_template_directory,
+ ),
+ )
+
+ loader = jinja2.FileSystemLoader(consent_template_directory)
+ self._jinja_env = jinja2.Environment(
+ loader=loader,
+ autoescape=jinja2.select_autoescape(['html', 'htm', 'xml']),
+ )
+
+ if hs.config.form_secret is None:
+ raise ConfigError(
+ "Consent resource is enabled but form_secret is not set in "
+ "config file. It should be set to an arbitrary secret string.",
+ )
+
+ self._hmac_secret = hs.config.form_secret.encode("utf-8")
+
+ def render_GET(self, request):
+ self._async_render_GET(request)
+ return NOT_DONE_YET
+
+ @wrap_html_request_handler
+ @defer.inlineCallbacks
+ def _async_render_GET(self, request):
+ """
+ Args:
+ request (twisted.web.http.Request):
+ """
+
+ version = parse_string(request, "v",
+ default=self._default_consent_version)
+ username = parse_string(request, "u", required=True)
+ userhmac = parse_string(request, "h", required=True)
+
+ self._check_hash(username, userhmac)
+
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ u = yield self.store.get_user_by_id(qualified_user_id)
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ try:
+ self._render_template(
+ request, "%s.html" % (version,),
+ user=username, userhmac=userhmac, version=version,
+ has_consented=(u["consent_version"] == version),
+ )
+ except TemplateNotFound:
+ raise NotFoundError("Unknown policy version")
+
+ def render_POST(self, request):
+ self._async_render_POST(request)
+ return NOT_DONE_YET
+
+ @wrap_html_request_handler
+ @defer.inlineCallbacks
+ def _async_render_POST(self, request):
+ """
+ Args:
+ request (twisted.web.http.Request):
+ """
+ version = parse_string(request, "v", required=True)
+ username = parse_string(request, "u", required=True)
+ userhmac = parse_string(request, "h", required=True)
+
+ self._check_hash(username, userhmac)
+
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ try:
+ yield self.store.user_set_consent_version(qualified_user_id, version)
+ except StoreError as e:
+ if e.code != 404:
+ raise
+ raise NotFoundError("Unknown user")
+
+ try:
+ self._render_template(request, "success.html")
+ except TemplateNotFound:
+ raise NotFoundError("success.html not found")
+
+ def _render_template(self, request, template_name, **template_args):
+ # get_template checks for ".." so we don't need to worry too much
+ # about path traversal here.
+ template_html = self._jinja_env.get_template(
+ path.join(TEMPLATE_LANGUAGE, template_name)
+ )
+ html_bytes = template_html.render(**template_args).encode("utf8")
+
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%i" % len(html_bytes))
+ request.write(html_bytes)
+ finish_request(request)
+
+ def _check_hash(self, userid, userhmac):
+ want_mac = hmac.new(
+ key=self._hmac_secret,
+ msg=userid,
+ digestmod=sha256,
+ ).hexdigest()
+
+ if not compare_digest(want_mac, userhmac):
+ raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect")
diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py
index bd4fea5774..1498d188c1 100644
--- a/synapse/rest/key/v1/server_key_resource.py
+++ b/synapse/rest/key/v1/server_key_resource.py
@@ -49,7 +49,6 @@ class LocalKey(Resource):
"""
def __init__(self, hs):
- self.version_string = hs.version_string
self.response_body = encode_canonical_json(
self.response_json_object(hs.config)
)
@@ -84,7 +83,6 @@ class LocalKey(Resource):
def render_GET(self, request):
return respond_with_json_bytes(
request, 200, self.response_body,
- version_string=self.version_string
)
def getChild(self, name, request):
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index be68d9a096..04775b3c45 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -63,7 +63,6 @@ class LocalKey(Resource):
isLeaf = True
def __init__(self, hs):
- self.version_string = hs.version_string
self.config = hs.config
self.clock = hs.clock
self.update_response_body(self.clock.time_msec())
@@ -115,5 +114,4 @@ class LocalKey(Resource):
self.update_response_body(time_now)
return respond_with_json_bytes(
request, 200, self.response_body,
- version_string=self.version_string
)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 17e6079cba..21b4c1175e 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import request_handler, respond_with_json_bytes
+from synapse.http.server import (
+ respond_with_json_bytes, wrap_json_request_handler,
+)
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.api.errors import SynapseError, Codes
from synapse.crypto.keyring import KeyLookupError
@@ -91,7 +93,6 @@ class RemoteKey(Resource):
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.store = hs.get_datastore()
- self.version_string = hs.version_string
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -99,7 +100,7 @@ class RemoteKey(Resource):
self.async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def async_render_GET(self, request):
if len(request.postpath) == 1:
@@ -124,7 +125,7 @@ class RemoteKey(Resource):
self.async_render_POST(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def async_render_POST(self, request):
content = parse_json_object_from_request(request)
@@ -240,5 +241,4 @@ class RemoteKey(Resource):
respond_with_json_bytes(
request, 200, result_io.getvalue(),
- version_string=self.version_string
)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index e7ac01da01..c0d2f06855 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -28,7 +28,7 @@ import os
import logging
import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
logger = logging.getLogger(__name__)
@@ -143,6 +143,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
respond_404(request)
return
+ logger.debug("Responding to media request with responder %s")
add_file_headers(request, media_type, file_size, upload_name)
with responder:
yield responder.write_to_consumer(request)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index fe7e17596f..8cf8820c31 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -12,16 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.http.servlet
+import logging
-from ._base import parse_media_id, respond_404
+from twisted.internet import defer
from twisted.web.resource import Resource
-from synapse.http.server import request_handler, set_cors_headers
-
from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-import logging
+from synapse.http.server import (
+ set_cors_headers,
+ wrap_json_request_handler,
+)
+import synapse.http.servlet
+from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__)
@@ -35,15 +37,14 @@ class DownloadResource(Resource):
self.media_repo = media_repo
self.server_name = hs.hostname
- # Both of these are expected by @request_handler()
+ # this is expected by @wrap_json_request_handler
self.clock = hs.get_clock()
- self.version_string = hs.version_string
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
set_cors_headers(request)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index bb79599379..2ac767d2dc 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -47,7 +47,8 @@ import shutil
import cgi
import logging
-import urlparse
+from six.moves.urllib import parse as urlparse
+from six import iteritems
logger = logging.getLogger(__name__)
@@ -603,7 +604,7 @@ class MediaRepository(object):
thumbnails[(t_width, t_height, r_type)] = r_method
# Now we generate the thumbnails for each dimension, store it
- for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
+ for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 83471b3173..d23fe10b07 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -16,6 +16,8 @@
from twisted.internet import defer, threads
from twisted.protocols.basic import FileSender
+import six
+
from ._base import Responder
from synapse.util.file_consumer import BackgroundFileConsumer
@@ -119,7 +121,7 @@ class MediaStorage(object):
os.remove(fname)
except Exception:
pass
- raise t, v, tb
+ six.reraise(t, v, tb)
if not finished_called:
raise Exception("Finished callback not called")
@@ -253,7 +255,9 @@ class FileResponder(Responder):
self.open_file = open_file
def write_to_consumer(self, consumer):
- return FileSender().beginFileTransfer(self.open_file, consumer)
+ return make_deferred_yieldable(
+ FileSender().beginFileTransfer(self.open_file, consumer)
+ )
def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 31fe7aa75c..565cef2b8d 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -23,8 +23,10 @@ import re
import shutil
import sys
import traceback
-import ujson as json
-import urlparse
+import simplejson as json
+
+from six.moves import urllib_parse as urlparse
+from six import string_types
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
@@ -35,13 +37,14 @@ from ._base import FileInfo
from synapse.api.errors import (
SynapseError, Codes,
)
-from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.http.client import SpiderHttpClient
from synapse.http.server import (
- request_handler, respond_with_json_bytes,
+ respond_with_json_bytes,
respond_with_json,
+ wrap_json_request_handler,
)
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
@@ -57,7 +60,6 @@ class PreviewUrlResource(Resource):
self.auth = hs.get_auth()
self.clock = hs.get_clock()
- self.version_string = hs.version_string
self.filepaths = media_repo.filepaths
self.max_spider_size = hs.config.max_spider_size
self.server_name = hs.hostname
@@ -90,7 +92,7 @@ class PreviewUrlResource(Resource):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
@@ -144,7 +146,8 @@ class PreviewUrlResource(Resource):
observable = self._cache.get(url)
if not observable:
- download = preserve_fn(self._do_preview)(
+ download = run_in_background(
+ self._do_preview,
url, requester.user, ts,
)
observable = ObservableDeferred(
@@ -589,8 +592,8 @@ def _iterate_over_text(tree, *tags_to_ignore):
# to be returned.
elements = iter([tree])
while True:
- el = elements.next()
- if isinstance(el, basestring):
+ el = next(elements)
+ if isinstance(el, string_types):
yield el
elif el is not None and el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index c188192f2b..0252afd9d3 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -18,7 +18,7 @@ from twisted.internet import defer, threads
from .media_storage import FileResponder
from synapse.config._base import Config
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
import logging
import os
@@ -87,7 +87,12 @@ class StorageProviderWrapper(StorageProvider):
return self.backend.store_file(path, file_info)
else:
# TODO: Handle errors.
- preserve_fn(self.backend.store_file)(path, file_info)
+ def store():
+ try:
+ return self.backend.store_file(path, file_info)
+ except Exception:
+ logger.exception("Error storing file")
+ run_in_background(store)
return defer.succeed(None)
def fetch(self, path, file_info):
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 58ada49711..aae6e464e8 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,18 +14,21 @@
# limitations under the License.
-from ._base import (
- parse_media_id, respond_404, respond_with_file, FileInfo,
- respond_with_responder,
-)
-from twisted.web.resource import Resource
-from synapse.http.servlet import parse_string, parse_integer
-from synapse.http.server import request_handler, set_cors_headers
+import logging
-from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
-import logging
+from synapse.http.server import (
+ set_cors_headers,
+ wrap_json_request_handler,
+)
+from synapse.http.servlet import parse_integer, parse_string
+from ._base import (
+ FileInfo, parse_media_id, respond_404, respond_with_file,
+ respond_with_responder,
+)
logger = logging.getLogger(__name__)
@@ -41,14 +44,13 @@ class ThumbnailResource(Resource):
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
- self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
set_cors_headers(request)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index f6f498cdc5..7567476fce 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -13,16 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import respond_with_json, request_handler
-
-from synapse.api.errors import SynapseError
+import logging
-from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
-
from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
-import logging
+from synapse.api.errors import SynapseError
+from synapse.http.server import (
+ respond_with_json,
+ wrap_json_request_handler,
+)
logger = logging.getLogger(__name__)
@@ -40,7 +41,6 @@ class UploadResource(Resource):
self.server_name = hs.hostname
self.auth = hs.get_auth()
self.max_upload_size = hs.config.max_upload_size
- self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_POST(self, request):
@@ -51,7 +51,7 @@ class UploadResource(Resource):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
@@ -81,15 +81,15 @@ class UploadResource(Resource):
headers = request.requestHeaders
if headers.hasHeader("Content-Type"):
- media_type = headers.getRawHeaders("Content-Type")[0]
+ media_type = headers.getRawHeaders(b"Content-Type")[0]
else:
raise SynapseError(
msg="Upload request missing 'Content-Type'",
code=400,
)
- # if headers.hasHeader("Content-Disposition"):
- # disposition = headers.getRawHeaders("Content-Disposition")[0]
+ # if headers.hasHeader(b"Content-Disposition"):
+ # disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
content_uri = yield self.media_repo.create_content(
diff --git a/synapse/server.py b/synapse/server.py
index cd0c1a51be..58dbf78437 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -46,6 +46,7 @@ from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler
+from synapse.handlers.room import RoomCreationHandler
from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
@@ -71,6 +72,11 @@ from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
)
+from synapse.server_notices.server_notices_manager import ServerNoticesManager
+from synapse.server_notices.server_notices_sender import ServerNoticesSender
+from synapse.server_notices.worker_server_notices_sender import (
+ WorkerServerNoticesSender,
+)
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
@@ -97,6 +103,9 @@ class HomeServer(object):
which must be implemented by the subclass. This code may call any of the
required "get" methods on the instance to obtain the sub-dependencies that
one requires.
+
+ Attributes:
+ config (synapse.config.homeserver.HomeserverConfig):
"""
DEPENDENCIES = [
@@ -105,8 +114,8 @@ class HomeServer(object):
'federation_client',
'federation_server',
'handlers',
- 'v1auth',
'auth',
+ 'room_creation_handler',
'state_handler',
'state_resolution_handler',
'presence_handler',
@@ -152,6 +161,8 @@ class HomeServer(object):
'spam_checker',
'room_member_handler',
'federation_registry',
+ 'server_notices_manager',
+ 'server_notices_sender',
]
def __init__(self, hostname, **kwargs):
@@ -225,14 +236,8 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
- def build_v1auth(self):
- orf = Auth(self)
- # Matrix spec makes no reference to what HTTP status code is returned,
- # but the V1 API uses 403 where it means 401, and the webclient
- # relies on this behaviour, so V1 gets its own copy of the auth
- # with backwards compat behaviour.
- orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
- return orf
+ def build_room_creation_handler(self):
+ return RoomCreationHandler(self)
def build_state_handler(self):
return StateHandler(self)
@@ -400,6 +405,16 @@ class HomeServer(object):
def build_federation_registry(self):
return FederationHandlerRegistry()
+ def build_server_notices_manager(self):
+ if self.config.worker_app:
+ raise Exception("Workers cannot send server notices")
+ return ServerNoticesManager(self)
+
+ def build_server_notices_sender(self):
+ if self.config.worker_app:
+ return WorkerServerNoticesSender(self)
+ return ServerNoticesSender(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index c3a9a3847b..ce28486233 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,4 +1,5 @@
import synapse.api.auth
+import synapse.config.homeserver
import synapse.federation.transaction_queue
import synapse.federation.transport.client
import synapse.handlers
@@ -8,11 +9,17 @@ import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.handlers.set_password
import synapse.rest.media.v1.media_repository
+import synapse.server_notices.server_notices_manager
+import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
class HomeServer(object):
+ @property
+ def config(self) -> synapse.config.homeserver.HomeServerConfig:
+ pass
+
def get_auth(self) -> synapse.api.auth.Auth:
pass
@@ -40,6 +47,12 @@ class HomeServer(object):
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
+ def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
+ pass
+
+ def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler:
+ pass
+
def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
pass
@@ -54,3 +67,9 @@ class HomeServer(object):
def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
pass
+
+ def get_server_notices_manager(self) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
+ pass
+
+ def get_server_notices_sender(self) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
+ pass
diff --git a/synapse/server_notices/__init__.py b/synapse/server_notices/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/server_notices/__init__.py
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
new file mode 100644
index 0000000000..bb74af1af5
--- /dev/null
+++ b/synapse/server_notices/consent_server_notices.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from six import (iteritems, string_types)
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.api.urls import ConsentURIBuilder
+from synapse.config import ConfigError
+from synapse.types import get_localpart_from_id
+
+logger = logging.getLogger(__name__)
+
+
+class ConsentServerNotices(object):
+ """Keeps track of whether we need to send users server_notices about
+ privacy policy consent, and sends one if we do.
+ """
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+ self._server_notices_manager = hs.get_server_notices_manager()
+ self._store = hs.get_datastore()
+
+ self._users_in_progress = set()
+
+ self._current_consent_version = hs.config.user_consent_version
+ self._server_notice_content = hs.config.user_consent_server_notice_content
+ self._send_to_guests = hs.config.user_consent_server_notice_to_guests
+
+ if self._server_notice_content is not None:
+ if not self._server_notices_manager.is_enabled():
+ raise ConfigError(
+ "user_consent configuration requires server notices, but "
+ "server notices are not enabled.",
+ )
+ if 'body' not in self._server_notice_content:
+ raise ConfigError(
+ "user_consent server_notice_consent must contain a 'body' "
+ "key.",
+ )
+
+ self._consent_uri_builder = ConsentURIBuilder(hs.config)
+
+ @defer.inlineCallbacks
+ def maybe_send_server_notice_to_user(self, user_id):
+ """Check if we need to send a notice to this user, and does so if so
+
+ Args:
+ user_id (str): user to check
+
+ Returns:
+ Deferred
+ """
+ if self._server_notice_content is None:
+ # not enabled
+ return
+
+ # make sure we don't send two messages to the same user at once
+ if user_id in self._users_in_progress:
+ return
+ self._users_in_progress.add(user_id)
+ try:
+ u = yield self._store.get_user_by_id(user_id)
+
+ if u["is_guest"] and not self._send_to_guests:
+ # don't send to guests
+ return
+
+ if u["consent_version"] == self._current_consent_version:
+ # user has already consented
+ return
+
+ if u["consent_server_notice_sent"] == self._current_consent_version:
+ # we've already sent a notice to the user
+ return
+
+ # need to send a message.
+ try:
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(
+ get_localpart_from_id(user_id),
+ )
+ content = copy_with_str_subst(
+ self._server_notice_content, {
+ 'consent_uri': consent_uri,
+ },
+ )
+ yield self._server_notices_manager.send_notice(
+ user_id, content,
+ )
+ yield self._store.user_set_consent_server_notice_sent(
+ user_id, self._current_consent_version,
+ )
+ except SynapseError as e:
+ logger.error("Error sending server notice about user consent: %s", e)
+ finally:
+ self._users_in_progress.remove(user_id)
+
+
+def copy_with_str_subst(x, substitutions):
+ """Deep-copy a structure, carrying out string substitions on any strings
+
+ Args:
+ x (object): structure to be copied
+ substitutions (object): substitutions to be made - passed into the
+ string '%' operator
+
+ Returns:
+ copy of x
+ """
+ if isinstance(x, string_types):
+ return x % substitutions
+ if isinstance(x, dict):
+ return {
+ k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)
+ }
+ if isinstance(x, (list, tuple)):
+ return [copy_with_str_subst(y) for y in x]
+
+ # assume it's uninterested and can be shallow-copied.
+ return x
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
new file mode 100644
index 0000000000..a26deace53
--- /dev/null
+++ b/synapse/server_notices/server_notices_manager.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
+from synapse.types import create_requester
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+class ServerNoticesManager(object):
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+
+ self._store = hs.get_datastore()
+ self._config = hs.config
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self._event_creation_handler = hs.get_event_creation_handler()
+ self._is_mine_id = hs.is_mine_id
+
+ def is_enabled(self):
+ """Checks if server notices are enabled on this server.
+
+ Returns:
+ bool
+ """
+ return self._config.server_notices_mxid is not None
+
+ @defer.inlineCallbacks
+ def send_notice(self, user_id, event_content):
+ """Send a notice to the given user
+
+ Creates the server notices room, if none exists.
+
+ Args:
+ user_id (str): mxid of user to send event to.
+ event_content (dict): content of event to send
+
+ Returns:
+ Deferred[None]
+ """
+ room_id = yield self.get_notice_room_for_user(user_id)
+
+ system_mxid = self._config.server_notices_mxid
+ requester = create_requester(system_mxid)
+
+ logger.info("Sending server notice to %s", user_id)
+
+ yield self._event_creation_handler.create_and_send_nonmember_event(
+ requester, {
+ "type": EventTypes.Message,
+ "room_id": room_id,
+ "sender": system_mxid,
+ "content": event_content,
+ },
+ ratelimit=False,
+ )
+
+ @cachedInlineCallbacks()
+ def get_notice_room_for_user(self, user_id):
+ """Get the room for notices for a given user
+
+ If we have not yet created a notice room for this user, create it
+
+ Args:
+ user_id (str): complete user id for the user we want a room for
+
+ Returns:
+ str: room id of notice room.
+ """
+ if not self.is_enabled():
+ raise Exception("Server notices not enabled")
+
+ assert self._is_mine_id(user_id), \
+ "Cannot send server notices to remote users"
+
+ rooms = yield self._store.get_rooms_for_user_where_membership_is(
+ user_id, [Membership.INVITE, Membership.JOIN],
+ )
+ system_mxid = self._config.server_notices_mxid
+ for room in rooms:
+ # it's worth noting that there is an asymmetry here in that we
+ # expect the user to be invited or joined, but the system user must
+ # be joined. This is kinda deliberate, in that if somebody somehow
+ # manages to invite the system user to a room, that doesn't make it
+ # the server notices room.
+ user_ids = yield self._store.get_users_in_room(room.room_id)
+ if system_mxid in user_ids:
+ # we found a room which our user shares with the system notice
+ # user
+ logger.info("Using room %s", room.room_id)
+ defer.returnValue(room.room_id)
+
+ # apparently no existing notice room: create a new one
+ logger.info("Creating server notices room for %s", user_id)
+
+ # see if we want to override the profile info for the server user.
+ # note that if we want to override either the display name or the
+ # avatar, we have to use both.
+ join_profile = None
+ if (
+ self._config.server_notices_mxid_display_name is not None or
+ self._config.server_notices_mxid_avatar_url is not None
+ ):
+ join_profile = {
+ "displayname": self._config.server_notices_mxid_display_name,
+ "avatar_url": self._config.server_notices_mxid_avatar_url,
+ }
+
+ requester = create_requester(system_mxid)
+ info = yield self._room_creation_handler.create_room(
+ requester,
+ config={
+ "preset": RoomCreationPreset.PRIVATE_CHAT,
+ "name": self._config.server_notices_room_name,
+ "power_level_content_override": {
+ "users_default": -10,
+ },
+ "invite": (user_id,)
+ },
+ ratelimit=False,
+ creator_join_profile=join_profile,
+ )
+ room_id = info['room_id']
+
+ logger.info("Created server notices room %s for %s", room_id, user_id)
+ defer.returnValue(room_id)
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
new file mode 100644
index 0000000000..5d23965f34
--- /dev/null
+++ b/synapse/server_notices/server_notices_sender.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.server_notices.consent_server_notices import ConsentServerNotices
+
+
+class ServerNoticesSender(object):
+ """A centralised place which sends server notices automatically when
+ Certain Events take place
+ """
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+ # todo: it would be nice to make this more dynamic
+ self._consent_server_notices = ConsentServerNotices(hs)
+
+ def on_user_syncing(self, user_id):
+ """Called when the user performs a sync operation.
+
+ Args:
+ user_id (str): mxid of user who synced
+
+ Returns:
+ Deferred
+ """
+ return self._consent_server_notices.maybe_send_server_notice_to_user(
+ user_id,
+ )
+
+ def on_user_ip(self, user_id):
+ """Called on the master when a worker process saw a client request.
+
+ Args:
+ user_id (str): mxid
+
+ Returns:
+ Deferred
+ """
+ # The synchrotrons use a stubbed version of ServerNoticesSender, so
+ # we check for notices to send to the user in on_user_ip as well as
+ # in on_user_syncing
+ return self._consent_server_notices.maybe_send_server_notice_to_user(
+ user_id,
+ )
diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py
new file mode 100644
index 0000000000..4a133026c3
--- /dev/null
+++ b/synapse/server_notices/worker_server_notices_sender.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+
+class WorkerServerNoticesSender(object):
+ """Stub impl of ServerNoticesSender which does nothing"""
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+
+ def on_user_syncing(self, user_id):
+ """Called when the user performs a sync operation.
+
+ Args:
+ user_id (str): mxid of user who synced
+
+ Returns:
+ Deferred
+ """
+ return defer.succeed(None)
+
+ def on_user_ip(self, user_id):
+ """Called on the master when a worker process saw a client request.
+
+ Args:
+ user_id (str): mxid
+
+ Returns:
+ Deferred
+ """
+ raise AssertionError("on_user_ip unexpectedly called on worker")
diff --git a/synapse/state.py b/synapse/state.py
index 932f602508..b8c27c6815 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -32,6 +32,8 @@ from frozendict import frozendict
import logging
import hashlib
+from six import iteritems, itervalues
+
logger = logging.getLogger(__name__)
@@ -132,7 +134,7 @@ class StateHandler(object):
state_map = yield self.store.get_events(state.values(), get_prev_content=False)
state = {
- key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
+ key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
}
defer.returnValue(state)
@@ -338,7 +340,7 @@ class StateHandler(object):
)
if len(state_groups_ids) == 1:
- name, state_list = state_groups_ids.items().pop()
+ name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -378,7 +380,7 @@ class StateHandler(object):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
- key: state_map[ev_id] for key, ev_id in new_state.items()
+ key: state_map[ev_id] for key, ev_id in iteritems(new_state)
}
return new_state
@@ -458,15 +460,15 @@ class StateResolutionHandler(object):
# build a map from state key to the event_ids which set that state.
# dict[(str, str), set[str])
state = {}
- for st in state_groups_ids.values():
- for key, e_id in st.items():
+ for st in itervalues(state_groups_ids):
+ for key, e_id in iteritems(st):
state.setdefault(key, set()).add(e_id)
# build a map from state key to the event_ids which set that state,
# including only those where there are state keys in conflict.
conflicted_state = {
k: list(v)
- for k, v in state.items()
+ for k, v in iteritems(state)
if len(v) > 1
}
@@ -474,42 +476,43 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
- state_groups_ids.values(),
+ list(state_groups_ids.values()),
event_map=event_map,
state_map_factory=state_map_factory,
)
else:
new_state = {
- key: e_ids.pop() for key, e_ids in state.items()
+ key: e_ids.pop() for key, e_ids in iteritems(state)
}
- # if the new state matches any of the input state groups, we can
- # use that state group again. Otherwise we will generate a state_id
- # which will be used as a cache key for future resolutions, but
- # not get persisted.
- state_group = None
- new_state_event_ids = frozenset(new_state.values())
- for sg, events in state_groups_ids.items():
- if new_state_event_ids == frozenset(e_id for e_id in events):
- state_group = sg
- break
-
- # TODO: We want to create a state group for this set of events, to
- # increase cache hits, but we need to make sure that it doesn't
- # end up as a prev_group without being added to the database
-
- prev_group = None
- delta_ids = None
- for old_group, old_ids in state_groups_ids.iteritems():
- if not set(new_state) - set(old_ids):
- n_delta_ids = {
- k: v
- for k, v in new_state.iteritems()
- if old_ids.get(k) != v
- }
- if not delta_ids or len(n_delta_ids) < len(delta_ids):
- prev_group = old_group
- delta_ids = n_delta_ids
+ with Measure(self.clock, "state.create_group_ids"):
+ # if the new state matches any of the input state groups, we can
+ # use that state group again. Otherwise we will generate a state_id
+ # which will be used as a cache key for future resolutions, but
+ # not get persisted.
+ state_group = None
+ new_state_event_ids = frozenset(itervalues(new_state))
+ for sg, events in iteritems(state_groups_ids):
+ if new_state_event_ids == frozenset(e_id for e_id in events):
+ state_group = sg
+ break
+
+ # TODO: We want to create a state group for this set of events, to
+ # increase cache hits, but we need to make sure that it doesn't
+ # end up as a prev_group without being added to the database
+
+ prev_group = None
+ delta_ids = None
+ for old_group, old_ids in iteritems(state_groups_ids):
+ if not set(new_state) - set(old_ids):
+ n_delta_ids = {
+ k: v
+ for k, v in iteritems(new_state)
+ if old_ids.get(k) != v
+ }
+ if not delta_ids or len(n_delta_ids) < len(delta_ids):
+ prev_group = old_group
+ delta_ids = n_delta_ids
cache = _StateCacheEntry(
state=new_state,
@@ -526,7 +529,7 @@ class StateResolutionHandler(object):
def _ordered_events(events):
def key_func(e):
- return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
+ return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
return sorted(events, key=key_func)
@@ -583,7 +586,7 @@ def _seperate(state_sets):
conflicted_state = {}
for state_set in state_sets[1:]:
- for key, value in state_set.iteritems():
+ for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
@@ -639,7 +642,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
needed_events = set(
event_id
- for event_ids in conflicted_state.itervalues()
+ for event_ids in itervalues(conflicted_state)
for event_id in event_ids
)
if event_map is not None:
@@ -661,7 +664,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
unconflicted_state, conflicted_state, state_map
)
- new_needed_events = set(auth_events.itervalues())
+ new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(event_map.iterkeys())
@@ -678,7 +681,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
- for event_ids in conflicted_state.itervalues():
+ for event_ids in itervalues(conflicted_state):
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
@@ -693,7 +696,7 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map):
conflicted_state = {}
- for key, event_ids in conflicted_state_ds.iteritems():
+ for key, event_ids in iteritems(conflicted_state_ds):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
@@ -702,7 +705,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
auth_events = {
key: state_map[ev_id]
- for key, ev_id in auth_event_ids.items()
+ for key, ev_id in iteritems(auth_event_ids)
if ev_id in state_map
}
@@ -715,7 +718,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
raise
new_state = unconflicted_state_ids
- for key, event in resolved_state.iteritems():
+ for key, event in iteritems(resolved_state):
new_state[key] = event.event_id
return new_state
@@ -740,7 +743,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.items():
+ for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
@@ -750,7 +753,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.items():
+ for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
@@ -760,7 +763,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.items():
+ for key, events in iteritems(conflicted_state):
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index de00cae447..979fa22438 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+import datetime
+from dateutil import tz
+import time
+import logging
from synapse.storage.devices import DeviceStore
from .appservice import (
@@ -57,10 +60,6 @@ from .engines import PostgresEngine
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -215,6 +214,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
+ # Used in _generate_user_daily_visits to keep track of progress
+ self._last_user_visit_update = self._get_start_of_day()
+
super(DataStore, self).__init__(db_conn, hs)
def take_presence_startup_info(self):
@@ -244,13 +246,12 @@ class DataStore(RoomMemberStore, RoomStore,
return [UserPresenceState(**row) for row in rows]
- @defer.inlineCallbacks
def count_daily_users(self):
"""
Counts the number of users who used this homeserver in the last 24 hours.
"""
def _count_users(txn):
- yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24),
+ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
sql = """
SELECT COALESCE(count(*), 0) FROM (
@@ -264,8 +265,154 @@ class DataStore(RoomMemberStore, RoomStore,
count, = txn.fetchone()
return count
- ret = yield self.runInteraction("count_users", _count_users)
- defer.returnValue(ret)
+ return self.runInteraction("count_users", _count_users)
+
+ def count_r30_users(self):
+ """
+ Counts the number of 30 day retained users, defined as:-
+ * Users who have created their accounts more than 30 days ago
+ * Where last seen at most 30 days ago
+ * Where account creation and last_seen are > 30 days apart
+
+ Returns counts globaly for a given user as well as breaking
+ by platform
+ """
+ def _count_r30_users(txn):
+ thirty_days_in_secs = 86400 * 30
+ now = int(self._clock.time())
+ thirty_days_ago_in_secs = now - thirty_days_in_secs
+
+ sql = """
+ SELECT platform, COALESCE(count(*), 0) FROM (
+ SELECT
+ users.name, platform, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen,
+ CASE
+ WHEN user_agent LIKE '%%Android%%' THEN 'android'
+ WHEN user_agent LIKE '%%iOS%%' THEN 'ios'
+ WHEN user_agent LIKE '%%Electron%%' THEN 'electron'
+ WHEN user_agent LIKE '%%Mozilla%%' THEN 'web'
+ WHEN user_agent LIKE '%%Gecko%%' THEN 'web'
+ ELSE 'unknown'
+ END
+ AS platform
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND users.appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, platform, users.creation_ts
+ ) u GROUP BY platform
+ """
+
+ results = {}
+ txn.execute(sql, (thirty_days_ago_in_secs,
+ thirty_days_ago_in_secs))
+
+ for row in txn:
+ if row[0] is 'unknown':
+ pass
+ results[row[0]] = row[1]
+
+ sql = """
+ SELECT COALESCE(count(*), 0) FROM (
+ SELECT users.name, users.creation_ts * 1000,
+ MAX(uip.last_seen)
+ FROM users
+ INNER JOIN (
+ SELECT
+ user_id,
+ last_seen
+ FROM user_ips
+ ) uip
+ ON users.name = uip.user_id
+ AND appservice_id is NULL
+ AND users.creation_ts < ?
+ AND uip.last_seen/1000 > ?
+ AND (uip.last_seen/1000) - users.creation_ts > 86400 * 30
+ GROUP BY users.name, users.creation_ts
+ ) u
+ """
+
+ txn.execute(sql, (thirty_days_ago_in_secs,
+ thirty_days_ago_in_secs))
+
+ count, = txn.fetchone()
+ results['all'] = count
+
+ return results
+
+ return self.runInteraction("count_r30_users", _count_r30_users)
+
+ def _get_start_of_day(self):
+ """
+ Returns millisecond unixtime for start of UTC day.
+ """
+ now = datetime.datetime.utcnow()
+ today_start = datetime.datetime(now.year, now.month,
+ now.day, tzinfo=tz.tzutc())
+ return int(time.mktime(today_start.timetuple())) * 1000
+
+ def generate_user_daily_visits(self):
+ """
+ Generates daily visit data for use in cohort/ retention analysis
+ """
+ def _generate_user_daily_visits(txn):
+ logger.info("Calling _generate_user_daily_visits")
+ today_start = self._get_start_of_day()
+ a_day_in_milliseconds = 24 * 60 * 60 * 1000
+ now = self.clock.time_msec()
+
+ sql = """
+ INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+ SELECT u.user_id, u.device_id, ?
+ FROM user_ips AS u
+ LEFT JOIN (
+ SELECT user_id, device_id, timestamp FROM user_daily_visits
+ WHERE timestamp = ?
+ ) udv
+ ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+ INNER JOIN users ON users.name=u.user_id
+ WHERE last_seen > ? AND last_seen <= ?
+ AND udv.timestamp IS NULL AND users.is_guest=0
+ AND users.appservice_id IS NULL
+ GROUP BY u.user_id, u.device_id
+ """
+
+ # This means that the day has rolled over but there could still
+ # be entries from the previous day. There is an edge case
+ # where if the user logs in at 23:59 and overwrites their
+ # last_seen at 00:01 then they will not be counted in the
+ # previous day's stats - it is important that the query is run
+ # often to minimise this case.
+ if today_start > self._last_user_visit_update:
+ yesterday_start = today_start - a_day_in_milliseconds
+ txn.execute(sql, (
+ yesterday_start, yesterday_start,
+ self._last_user_visit_update, today_start
+ ))
+ self._last_user_visit_update = today_start
+
+ txn.execute(sql, (
+ today_start, today_start,
+ self._last_user_visit_update,
+ now
+ ))
+ # Update _last_user_visit_update to now. The reason to do this
+ # rather just clamping to the beginning of the day is to limit
+ # the size of the join - meaning that the query can be run more
+ # frequently
+ self._last_user_visit_update = now
+
+ return self.runInteraction("generate_user_daily_visits",
+ _generate_user_daily_visits)
def get_users(self):
"""Function to reterive a list of users in users table.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2fbebd4907..22d6257a9f 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,8 +18,8 @@ from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.descriptors import Cache
from synapse.storage.engines import PostgresEngine
-import synapse.metrics
+from prometheus_client import Histogram
from twisted.internet import defer
@@ -27,20 +27,25 @@ import sys
import time
import threading
+from six import itervalues, iterkeys, iteritems
+from six.moves import intern, range
logger = logging.getLogger(__name__)
+try:
+ MAX_TXN_ID = sys.maxint - 1
+except AttributeError:
+ # python 3 does not have a maximum int value
+ MAX_TXN_ID = 2**63 - 1
+
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
-metrics = synapse.metrics.get_metrics_for("synapse.storage")
-
-sql_scheduling_timer = metrics.register_distribution("schedule_time")
-
-sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
-sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
class LoggingTransaction(object):
@@ -105,7 +110,7 @@ class LoggingTransaction(object):
# Don't let logging failures stop SQL from working
pass
- start = time.time() * 1000
+ start = time.time()
try:
return func(
@@ -115,9 +120,9 @@ class LoggingTransaction(object):
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
finally:
- msecs = (time.time() * 1000) - start
- sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
- sql_query_timer.inc_by(msecs, sql.split()[0])
+ secs = time.time() - start
+ sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+ sql_query_timer.labels(sql.split()[0]).observe(secs)
class PerformanceCounters(object):
@@ -127,7 +132,7 @@ class PerformanceCounters(object):
def update(self, key, start_time, end_time=None):
if end_time is None:
- end_time = time.time() * 1000
+ end_time = time.time()
duration = end_time - start_time
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
@@ -137,7 +142,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3):
counters = []
- for name, (count, cum_time) in self.current_counters.iteritems():
+ for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
@@ -217,12 +222,12 @@ class SQLBaseStore(object):
def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
logging_context, func, *args, **kwargs):
- start = time.time() * 1000
+ start = time.time()
txn_id = self._TXN_ID
# We don't really need these to be unique, so lets stop it from
# growing really large.
- self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+ self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
name = "%s-%x" % (desc, txn_id, )
@@ -277,17 +282,17 @@ class SQLBaseStore(object):
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
- end = time.time() * 1000
+ end = time.time()
duration = end - start
if logging_context is not None:
logging_context.add_database_transaction(duration)
- transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+ transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
- sql_txn_timer.inc_by(duration, desc)
+ sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
@@ -344,13 +349,13 @@ class SQLBaseStore(object):
"""
current_context = LoggingContext.current_context()
- start_time = time.time() * 1000
+ start_time = time.time()
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
- sched_duration_ms = time.time() * 1000 - start_time
- sql_scheduling_timer.inc_by(sched_duration_ms)
- current_context.add_database_scheduled(sched_duration_ms)
+ sched_duration_sec = time.time() - start_time
+ sql_scheduling_timer.observe(sched_duration_sec)
+ current_context.add_database_scheduled(sched_duration_sec)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
@@ -376,7 +381,7 @@ class SQLBaseStore(object):
Returns:
A list of dicts where the key is the column header.
"""
- col_headers = list(intern(column[0]) for column in cursor.description)
+ col_headers = list(intern(str(column[0])) for column in cursor.description)
results = list(
dict(zip(col_headers, row)) for row in cursor
)
@@ -543,7 +548,7 @@ class SQLBaseStore(object):
", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
- sqlargs = values.values() + keyvalues.values()
+ sqlargs = list(values.values()) + list(keyvalues.values())
txn.execute(sql, sqlargs)
if txn.rowcount > 0:
@@ -561,7 +566,7 @@ class SQLBaseStore(object):
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
- txn.execute(sql, allvalues.values())
+ txn.execute(sql, list(allvalues.values()))
# successfully inserted
return True
@@ -629,8 +634,8 @@ class SQLBaseStore(object):
}
if keyvalues:
- sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
- txn.execute(sql, keyvalues.values())
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ txn.execute(sql, list(keyvalues.values()))
else:
txn.execute(sql)
@@ -694,7 +699,7 @@ class SQLBaseStore(object):
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- txn.execute(sql, keyvalues.values())
+ txn.execute(sql, list(keyvalues.values()))
else:
sql = "SELECT %s FROM %s" % (
", ".join(retcols),
@@ -725,9 +730,12 @@ class SQLBaseStore(object):
if not iterable:
defer.returnValue(results)
+ # iterables can not be sliced, so convert it to a list first
+ it_list = list(iterable)
+
chunks = [
- iterable[i:i + batch_size]
- for i in xrange(0, len(iterable), batch_size)
+ it_list[i:i + batch_size]
+ for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
@@ -767,7 +775,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.iteritems():
+ for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -790,7 +798,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
else:
where = ""
@@ -802,7 +810,7 @@ class SQLBaseStore(object):
txn.execute(
update_sql,
- updatevalues.values() + keyvalues.values()
+ list(updatevalues.values()) + list(keyvalues.values())
)
return txn.rowcount
@@ -850,7 +858,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
- txn.execute(select_sql, keyvalues.values())
+ txn.execute(select_sql, list(keyvalues.values()))
row = txn.fetchone()
if not row:
@@ -888,7 +896,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- txn.execute(sql, keyvalues.values())
+ txn.execute(sql, list(keyvalues.values()))
if txn.rowcount == 0:
raise StoreError(404, "No row found")
if txn.rowcount > 1:
@@ -906,7 +914,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- return txn.execute(sql, keyvalues.values())
+ return txn.execute(sql, list(keyvalues.values()))
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
@@ -938,7 +946,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.iteritems():
+ for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -978,7 +986,7 @@ class SQLBaseStore(object):
txn.close()
if cache:
- min_val = min(cache.itervalues())
+ min_val = min(itervalues(cache))
else:
min_val = max_value
@@ -1093,7 +1101,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?"
)
- txn.execute(sql, keyvalues.values() + pagevalues)
+ txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
else:
sql = "SELECT %s FROM %s ORDER BY %s" % (
", ".join(retcols),
diff --git a/synapse/storage/account_data.py b/synapse/storage/account_data.py
index e70c9423e3..f83ff0454a 100644
--- a/synapse/storage/account_data.py
+++ b/synapse/storage/account_data.py
@@ -23,7 +23,7 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
import abc
-import ujson as json
+import simplejson as json
import logging
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index c88759bf2c..8af325a9f5 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -19,7 +19,7 @@ from . import engines
from twisted.internet import defer
-import ujson as json
+import simplejson as json
import logging
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index a03d1d6104..ce338514e8 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -22,6 +22,8 @@ from . import background_updates
from synapse.util.caches import CACHE_SIZE_FACTOR
+from six import iteritems
+
logger = logging.getLogger(__name__)
@@ -48,6 +50,20 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "device_id", "last_seen"],
)
+ self.register_background_index_update(
+ "user_ips_last_seen_index",
+ index_name="user_ips_last_seen",
+ table="user_ips",
+ columns=["user_id", "last_seen"],
+ )
+
+ self.register_background_index_update(
+ "user_ips_last_seen_only_index",
+ index_name="user_ips_last_seen_only",
+ table="user_ips",
+ columns=["last_seen"],
+ )
+
# (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
@@ -85,7 +101,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def _update_client_ips_batch_txn(self, txn, to_update):
self.database_engine.lock_table(txn, "user_ips")
- for entry in to_update.iteritems():
+ for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
self._simple_upsert_txn(
@@ -217,5 +233,5 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"user_agent": user_agent,
"last_seen": last_seen,
}
- for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+ for (access_token, ip), (user_agent, last_seen) in iteritems(results)
))
diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py
index 548e795daf..a879e5bfc1 100644
--- a/synapse/storage/deviceinbox.py
+++ b/synapse/storage/deviceinbox.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-import ujson
+import simplejson
from twisted.internet import defer
@@ -85,7 +85,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
)
rows = []
for destination, edu in remote_messages_by_destination.items():
- edu_json = ujson.dumps(edu)
+ edu_json = simplejson.dumps(edu)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
@@ -177,7 +177,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
" WHERE user_id = ?"
)
txn.execute(sql, (user_id,))
- message_json = ujson.dumps(messages_by_device["*"])
+ message_json = simplejson.dumps(messages_by_device["*"])
for row in txn:
# Add the message for all devices for this user on this
# server.
@@ -199,7 +199,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
# Only insert into the local inbox if the device exists on
# this server
device = row[0]
- message_json = ujson.dumps(messages_by_device[device])
+ message_json = simplejson.dumps(messages_by_device[device])
messages_json_for_user[device] = message_json
if messages_json_for_user:
@@ -253,7 +253,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(ujson.loads(row[1]))
+ messages.append(simplejson.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
@@ -389,7 +389,7 @@ class DeviceInboxStore(BackgroundUpdateStore):
messages = []
for row in txn:
stream_pos = row[0]
- messages.append(ujson.loads(row[1]))
+ messages.append(simplejson.loads(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return (messages, stream_pos)
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index bd2effdf34..d149d8392e 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import ujson as json
+import simplejson as json
from twisted.internet import defer
@@ -21,6 +21,7 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore, Cache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from six import itervalues, iteritems
logger = logging.getLogger(__name__)
@@ -360,7 +361,7 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, [])
if len(query_map) >= 20:
- now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+ now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
@@ -373,13 +374,13 @@ class DeviceStore(SQLBaseStore):
"""
results = []
- for user_id, user_devices in devices.iteritems():
+ for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
- for device_id, device in user_devices.iteritems():
+ for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
@@ -483,7 +484,7 @@ class DeviceStore(SQLBaseStore):
if devices:
user_devices = devices[user_id]
results = []
- for device_id, device in user_devices.iteritems():
+ for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index 2cebb203c6..b146487943 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -17,10 +17,12 @@ from twisted.internet import defer
from synapse.util.caches.descriptors import cached
from canonicaljson import encode_canonical_json
-import ujson as json
+import simplejson as json
from ._base import SQLBaseStore
+from six import iteritems
+
class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
@@ -81,8 +83,8 @@ class EndToEndKeyStore(SQLBaseStore):
query_list, include_all_devices,
)
- for user_id, device_keys in results.iteritems():
- for device_id, device_info in device_keys.iteritems():
+ for user_id, device_keys in iteritems(results):
+ for device_id, device_info in iteritems(device_keys):
device_info["keys"] = json.loads(device_info.pop("key_json"))
defer.returnValue(results)
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 338b495611..8c868ece75 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -18,6 +18,7 @@ from .postgres import PostgresEngine
from .sqlite3 import Sqlite3Engine
import importlib
+import platform
SUPPORTED_MODULE = {
@@ -31,6 +32,10 @@ def create_engine(database_config):
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
+ # pypy requires psycopg2cffi rather than psycopg2
+ if (name == "psycopg2" and
+ platform.python_implementation() == "PyPy"):
+ name = "psycopg2cffi"
module = importlib.import_module(name)
return engine_class(module, database_config)
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 00ee82d300..8fbf7ffba7 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import random
from twisted.internet import defer
@@ -24,7 +25,9 @@ from synapse.util.caches.descriptors import cached
from unpaddedbase64 import encode_base64
import logging
-from Queue import PriorityQueue, Empty
+from six.moves.queue import PriorityQueue, Empty
+
+from six.moves import range
logger = logging.getLogger(__name__)
@@ -78,7 +81,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
front_list = list(front)
chunks = [
front_list[x:x + 100]
- for x in xrange(0, len(front), 100)
+ for x in range(0, len(front), 100)
]
for chunk in chunks:
txn.execute(
@@ -133,7 +136,47 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
retcol="event_id",
)
+ @defer.inlineCallbacks
+ def get_prev_events_for_room(self, room_id):
+ """
+ Gets a subset of the current forward extremities in the given room.
+
+ Limits the result to 10 extremities, so that we can avoid creating
+ events which refer to hundreds of prev_events.
+
+ Args:
+ room_id (str): room_id
+
+ Returns:
+ Deferred[list[(str, dict[str, str], int)]]
+ for each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+ """
+ res = yield self.get_latest_event_ids_and_hashes_in_room(room_id)
+ if len(res) > 10:
+ # Sort by reverse depth, so we point to the most recent.
+ res.sort(key=lambda a: -a[2])
+
+ # we use half of the limit for the actual most recent events, and
+ # the other half to randomly point to some of the older events, to
+ # make sure that we don't completely ignore the older events.
+ res = res[0:5] + random.sample(res[5:], 5)
+
+ defer.returnValue(res)
+
def get_latest_event_ids_and_hashes_in_room(self, room_id):
+ """
+ Gets the current forward extremities in the given room
+
+ Args:
+ room_id (str): room_id
+
+ Returns:
+ Deferred[list[(str, dict[str, str], int)]]
+ for each event, a tuple of (event_id, hashes, depth)
+ where *hashes* is a map from algorithm to hash.
+ """
+
return self.runInteraction(
"get_latest_event_ids_and_hashes_in_room",
self._get_latest_event_ids_and_hashes_in_room,
@@ -182,22 +225,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore,
room_id,
)
- @defer.inlineCallbacks
- def get_max_depth_of_events(self, event_ids):
- sql = (
- "SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
- ) % (",".join(["?"] * len(event_ids)),)
-
- rows = yield self._execute(
- "get_max_depth_of_events", None,
- sql, *event_ids
- )
-
- if rows:
- defer.returnValue(rows[0][0])
- else:
- defer.returnValue(1)
-
def _get_min_depth_interaction(self, txn, room_id):
min_depth = self._simple_select_one_onecol_txn(
txn,
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index 01f8339825..d0350ee5fe 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -18,11 +18,11 @@ from synapse.storage._base import SQLBaseStore, LoggingTransaction
from twisted.internet import defer
from synapse.util.async import sleep
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from synapse.types import RoomStreamToken
-from .stream import lower_bound
import logging
-import ujson as json
+import simplejson as json
+
+from six import iteritems
logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
last_read_event_id):
sql = (
- "SELECT stream_ordering, topological_ordering"
+ "SELECT stream_ordering"
" FROM events"
" WHERE room_id = ? AND event_id = ?"
)
@@ -111,17 +111,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return {"notify_count": 0, "highlight_count": 0}
stream_ordering = results[0][0]
- topological_ordering = results[0][1]
return self._get_unread_counts_by_pos_txn(
- txn, room_id, user_id, topological_ordering, stream_ordering
+ txn, room_id, user_id, stream_ordering
)
- def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering,
- stream_ordering):
- token = RoomStreamToken(
- topological_ordering, stream_ordering
- )
+ def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
# First get number of notifications.
# We don't need to put a notif=1 clause as all rows always have
@@ -132,10 +127,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" WHERE"
" user_id = ?"
" AND room_id = ?"
- " AND %s"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
+ " AND stream_ordering > ?"
+ )
- txn.execute(sql, (user_id, room_id))
+ txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
notify_count = row[0] if row else 0
@@ -155,10 +150,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" highlight = 1"
" AND user_id = ?"
" AND room_id = ?"
- " AND %s"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
+ " AND stream_ordering > ?"
+ )
- txn.execute(sql, (user_id, room_id))
+ txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
highlight_count = row[0] if row else 0
@@ -209,7 +204,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" ep.highlight "
" FROM ("
" SELECT room_id,"
- " MAX(topological_ordering) as topological_ordering,"
" MAX(stream_ordering) as stream_ordering"
" FROM events"
" INNER JOIN receipts_linearized USING (room_id, event_id)"
@@ -219,13 +213,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" event_push_actions AS ep"
" WHERE"
" ep.room_id = rl.room_id"
- " AND ("
- " ep.topological_ordering > rl.topological_ordering"
- " OR ("
- " ep.topological_ordering = rl.topological_ordering"
- " AND ep.stream_ordering > rl.stream_ordering"
- " )"
- " )"
+ " AND ep.stream_ordering > rl.stream_ordering"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
@@ -318,7 +306,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" ep.highlight, e.received_ts"
" FROM ("
" SELECT room_id,"
- " MAX(topological_ordering) as topological_ordering,"
" MAX(stream_ordering) as stream_ordering"
" FROM events"
" INNER JOIN receipts_linearized USING (room_id, event_id)"
@@ -329,13 +316,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
" ep.room_id = rl.room_id"
- " AND ("
- " ep.topological_ordering > rl.topological_ordering"
- " OR ("
- " ep.topological_ordering = rl.topological_ordering"
- " AND ep.stream_ordering > rl.stream_ordering"
- " )"
- " )"
+ " AND ep.stream_ordering > rl.stream_ordering"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
@@ -441,13 +422,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.executemany(sql, (
_gen_entry(user_id, actions)
- for user_id, actions in user_id_actions.iteritems()
+ for user_id, actions in iteritems(user_id_actions)
))
return self.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
+ @defer.inlineCallbacks
def remove_push_actions_from_staging(self, event_id):
"""Called if we failed to persist the event to ensure that stale push
actions don't build up in the DB
@@ -456,13 +438,22 @@ class EventPushActionsWorkerStore(SQLBaseStore):
event_id (str)
"""
- return self._simple_delete(
- table="event_push_actions_staging",
- keyvalues={
- "event_id": event_id,
- },
- desc="remove_push_actions_from_staging",
- )
+ try:
+ res = yield self._simple_delete(
+ table="event_push_actions_staging",
+ keyvalues={
+ "event_id": event_id,
+ },
+ desc="remove_push_actions_from_staging",
+ )
+ defer.returnValue(res)
+ except Exception:
+ # this method is called from an exception handler, so propagating
+ # another exception here really isn't helpful - there's nothing
+ # the caller can do about it. Just log the exception and move on.
+ logger.exception(
+ "Error removing push actions after event persistence failure",
+ )
@defer.inlineCallbacks
def _find_stream_orderings_for_times(self):
@@ -752,10 +743,10 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
- topological_ordering, stream_ordering):
+ stream_ordering):
"""
Purges old push actions for a user and room before a given
- topological_ordering.
+ stream_ordering.
We however keep a months worth of highlighted notifications, so that
users can still get a list of recent highlights.
@@ -764,7 +755,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn: The transcation
room_id: Room ID to delete from
user_id: user ID to delete for
- topological_ordering: The lowest topological ordering which will
+ stream_ordering: The lowest stream ordering which will
not be deleted.
"""
txn.call_after(
@@ -783,9 +774,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute(
"DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND "
- " topological_ordering <= ?"
+ " stream_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
+ (user_id, room_id, stream_ordering, self.stream_ordering_month_ago)
)
txn.execute("""
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 9fc65229fd..b96104ccae 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -14,52 +14,57 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.storage.events_worker import EventsWorkerStore
+from collections import OrderedDict, deque, namedtuple
+from functools import wraps
+import itertools
+import logging
+import simplejson as json
from twisted.internet import defer
-from synapse.events import USE_FROZEN_DICTS
-
+from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.async import ObservableDeferred
+from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.logcontext import (
- PreserveLoggingContext, make_deferred_yieldable
+ PreserveLoggingContext, make_deferred_yieldable,
)
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.types import get_domain_from_id
-
-from canonicaljson import encode_canonical_json
-from collections import deque, namedtuple, OrderedDict
-from functools import wraps
-
+from synapse.types import get_domain_from_id, RoomStreamToken
import synapse.metrics
-import logging
-import simplejson as json
-
# these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
+from prometheus_client import Counter
+
logger = logging.getLogger(__name__)
+persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
+event_counter = Counter("synapse_storage_events_persisted_events_sep", "",
+ ["type", "origin_type", "origin_entity"])
-metrics = synapse.metrics.get_metrics_for(__name__)
-persist_event_counter = metrics.register_counter("persisted_events")
-event_counter = metrics.register_counter(
- "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
-)
+# The number of times we are recalculating the current state
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+
+# The number of times we are recalculating state when there is only a
+# single forward extremity
+state_delta_single_event_counter = Counter(
+ "synapse_storage_events_state_delta_single_event", "")
+
+# The number of times we are reculating state when we could have resonably
+# calculated the delta when we calculated the state for an event we were
+# persisting.
+state_delta_reuse_delta_counter = Counter(
+ "synapse_storage_events_state_delta_reuse_delta", "")
def encode_json(json_object):
- if USE_FROZEN_DICTS:
- # ujson doesn't like frozen_dicts
- return encode_canonical_json(json_object)
- else:
- return json.dumps(json_object, ensure_ascii=False)
+ return frozendict_json_encoder.encode(json_object)
class _EventPeristenceQueue(object):
@@ -369,7 +374,8 @@ class EventsStore(EventsWorkerStore):
room_id, ev_ctx_rm, latest_event_ids
)
- if new_latest_event_ids == set(latest_event_ids):
+ latest_event_ids = set(latest_event_ids)
+ if new_latest_event_ids == latest_event_ids:
# No change in extremities, so no change in state
continue
@@ -390,12 +396,34 @@ class EventsStore(EventsWorkerStore):
if all_single_prev_not_state:
continue
+ state_delta_counter.inc()
+ if len(new_latest_event_ids) == 1:
+ state_delta_single_event_counter.inc()
+
+ # This is a fairly handwavey check to see if we could
+ # have guessed what the delta would have been when
+ # processing one of these events.
+ # What we're interested in is if the latest extremities
+ # were the same when we created the event as they are
+ # now. When this server creates a new event (as opposed
+ # to receiving it over federation) it will use the
+ # forward extremities as the prev_events, so we can
+ # guess this by looking at the prev_events and checking
+ # if they match the current forward extremities.
+ for ev, _ in ev_ctx_rm:
+ prev_event_ids = set(e for e, _ in ev.prev_events)
+ if latest_event_ids == prev_event_ids:
+ state_delta_reuse_delta_counter.inc()
+ break
+
logger.info(
"Calculating state delta for room %s", room_id,
)
current_state = yield self._get_new_state_after_events(
room_id,
- ev_ctx_rm, new_latest_event_ids,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
)
if current_state is not None:
current_state_for_room[room_id] = current_state
@@ -414,7 +442,10 @@ class EventsStore(EventsWorkerStore):
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
- persist_event_counter.inc_by(len(chunk))
+ persist_event_counter.inc(len(chunk))
+ synapse.metrics.event_persisted_position.set(
+ chunk[-1][0].internal_metadata.stream_ordering,
+ )
for event, context in chunk:
if context.app_service:
origin_type = "local"
@@ -426,7 +457,7 @@ class EventsStore(EventsWorkerStore):
origin_type = "remote"
origin_entity = get_domain_from_id(event.sender)
- event_counter.inc(event.type, origin_type, origin_entity)
+ event_counter.labels(event.type, origin_type, origin_entity).inc()
for room_id, new_state in current_state_for_room.iteritems():
self.get_current_state_ids.prefill(
@@ -480,7 +511,8 @@ class EventsStore(EventsWorkerStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
- def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
+ def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+ new_latest_event_ids):
"""Calculate the current state dict after adding some new events to
a room
@@ -491,6 +523,9 @@ class EventsStore(EventsWorkerStore):
events_context (list[(EventBase, EventContext)]):
events and contexts which are being added to the room
+ old_latest_event_ids (iterable[str]):
+ the old forward extremities for the room.
+
new_latest_event_ids (iterable[str]):
the new forward extremities for the room.
@@ -501,64 +536,89 @@ class EventsStore(EventsWorkerStore):
"""
if not new_latest_event_ids:
- defer.returnValue({})
+ return
# map from state_group to ((type, key) -> event_id) state map
- state_groups = {}
- missing_event_ids = []
- was_updated = False
+ state_groups_map = {}
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # I don't think this can happen, but let's double-check
+ raise Exception(
+ "Context for new extremity event %s has no state "
+ "group" % (ev.event_id, ),
+ )
+
+ if ctx.state_group in state_groups_map:
+ continue
+
+ state_groups_map[ctx.state_group] = ctx.current_state_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding,
- # and then use the current state from that
+ # First search in the list of new events we're adding.
for ev, ctx in events_context:
if event_id == ev.event_id:
- if ctx.current_state_ids is None:
- raise Exception("Unknown current state")
-
- if ctx.state_group is None:
- # I don't think this can happen, but let's double-check
- raise Exception(
- "Context for new extremity event %s has no state "
- "group" % (event_id, ),
- )
-
- # If we've already seen the state group don't bother adding
- # it to the state sets again
- if ctx.state_group not in state_groups:
- state_groups[ctx.state_group] = ctx.current_state_ids
- if ctx.delta_ids or hasattr(ev, "state_key"):
- was_updated = True
+ event_id_to_state_group[event_id] = ctx.state_group
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
- was_updated = True
- missing_event_ids.append(event_id)
-
- if not was_updated:
- return
+ missing_event_ids.add(event_id)
if missing_event_ids:
- # Now pull out the state for any missing events from DB
+ # Now pull out the state groups for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
+ event_id_to_state_group.update(event_to_groups)
- groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
+ # State groups of old_latest_event_ids
+ old_state_groups = set(
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ )
- if groups:
- group_to_state = yield self._get_state_for_groups(groups)
- state_groups.update(group_to_state)
+ # State groups of new_latest_event_ids
+ new_state_groups = set(
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ )
- if len(state_groups) == 1:
+ # If they old and new groups are the same then we don't need to do
+ # anything.
+ if old_state_groups == new_state_groups:
+ return
+
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ missing_state = new_state_groups - set(state_groups_map)
+ if missing_state:
+ group_to_state = yield self._get_state_for_groups(missing_state)
+ state_groups_map.update(group_to_state)
+
+ if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- defer.returnValue(state_groups.values()[0])
+ defer.returnValue(state_groups_map[new_state_groups.pop()])
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
def get_events(ev_ids):
return self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
+
+ state_groups = {
+ sg: state_groups_map[sg] for sg in new_state_groups
+ }
+
events_map = {ev.event_id: ev for ev, _ in events_context}
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
@@ -1288,13 +1348,49 @@ class EventsStore(EventsWorkerStore):
defer.returnValue(set(r["event_id"] for r in rows))
- def have_events(self, event_ids):
+ @defer.inlineCallbacks
+ def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
+ Args:
+ event_ids (iterable[str]):
+
+ Returns:
+ Deferred[set[str]]: The events we have already seen.
+ """
+ results = set()
+
+ def have_seen_events_txn(txn, chunk):
+ sql = (
+ "SELECT event_id FROM events as e WHERE e.event_id IN (%s)"
+ % (",".join("?" * len(chunk)), )
+ )
+ txn.execute(sql, chunk)
+ for (event_id, ) in txn:
+ results.add(event_id)
+
+ # break the input up into chunks of 100
+ input_iterator = iter(event_ids)
+ for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)),
+ []):
+ yield self.runInteraction(
+ "have_seen_events",
+ have_seen_events_txn,
+ chunk,
+ )
+ defer.returnValue(results)
+
+ def get_seen_events_with_rejections(self, event_ids):
+ """Given a list of event ids, check if we rejected them.
+
+ Args:
+ event_ids (list[str])
+
Returns:
- dict: Has an entry for each event id we already have seen. Maps to
- the rejected reason string if we rejected the event, else maps to
- None.
+ Deferred[dict[str, str|None):
+ Has an entry for each event id we already have seen. Maps to
+ the rejected reason string if we rejected the event, else maps
+ to None.
"""
if not event_ids:
return defer.succeed({})
@@ -1316,9 +1412,7 @@ class EventsStore(EventsWorkerStore):
return res
- return self.runInteraction(
- "have_events", f,
- )
+ return self.runInteraction("get_rejection_reasons", f)
@defer.inlineCallbacks
def count_daily_messages(self):
@@ -1706,15 +1800,14 @@ class EventsStore(EventsWorkerStore):
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
def purge_history(
- self, room_id, topological_ordering, delete_local_events,
+ self, room_id, token, delete_local_events,
):
"""Deletes room history before a certain point
Args:
room_id (str):
- topological_ordering (int):
- minimum topo ordering to preserve
+ token (str): A topological token to delete events before
delete_local_events (bool):
if True, we will delete local events as well as remote ones
@@ -1724,13 +1817,15 @@ class EventsStore(EventsWorkerStore):
return self.runInteraction(
"purge_history",
- self._purge_history_txn, room_id, topological_ordering,
+ self._purge_history_txn, room_id, token,
delete_local_events,
)
def _purge_history_txn(
- self, txn, room_id, topological_ordering, delete_local_events,
+ self, txn, room_id, token_str, delete_local_events,
):
+ token = RoomStreamToken.parse(token_str)
+
# Tables that should be pruned:
# event_auth
# event_backward_extremities
@@ -1775,6 +1870,13 @@ class EventsStore(EventsWorkerStore):
" ON events_to_purge(should_delete)",
)
+ # We do joins against events_to_purge for e.g. calculating state
+ # groups to purge, etc., so lets make an index.
+ txn.execute(
+ "CREATE INDEX events_to_purge_id"
+ " ON events_to_purge(event_id)",
+ )
+
# First ensure that we're not about to delete all the forward extremeties
txn.execute(
"SELECT e.event_id, e.depth FROM events as e "
@@ -1787,7 +1889,7 @@ class EventsStore(EventsWorkerStore):
rows = txn.fetchall()
max_depth = max(row[0] for row in rows)
- if max_depth <= topological_ordering:
+ if max_depth <= token.topological:
# We need to ensure we don't delete all the events from the datanase
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
@@ -1803,7 +1905,7 @@ class EventsStore(EventsWorkerStore):
should_delete_expr += " AND event_id NOT LIKE ?"
should_delete_params += ("%:" + self.hs.hostname, )
- should_delete_params += (room_id, topological_ordering)
+ should_delete_params += (room_id, token.topological)
txn.execute(
"INSERT INTO events_to_purge"
@@ -1826,13 +1928,13 @@ class EventsStore(EventsWorkerStore):
logger.info("[purge] Finding new backward extremities")
# We calculate the new entries for the backward extremeties by finding
- # all events that point to events that are to be purged
+ # events to be purged that are pointed to by events we're not going to
+ # purge.
txn.execute(
"SELECT DISTINCT e.event_id FROM events_to_purge AS e"
" INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
- " INNER JOIN events AS e2 ON e2.event_id = ed.event_id"
- " WHERE e2.topological_ordering >= ?",
- (topological_ordering, )
+ " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
+ " WHERE ep2.event_id IS NULL",
)
new_backwards_extrems = txn.fetchall()
@@ -1856,16 +1958,22 @@ class EventsStore(EventsWorkerStore):
# Get all state groups that are only referenced by events that are
# to be deleted.
- txn.execute(
- "SELECT state_group FROM event_to_state_groups"
- " INNER JOIN events USING (event_id)"
- " WHERE state_group IN ("
- " SELECT DISTINCT state_group FROM events_to_purge"
- " INNER JOIN event_to_state_groups USING (event_id)"
- " )"
- " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
- (topological_ordering, )
- )
+ # This works by first getting state groups that we may want to delete,
+ # joining against event_to_state_groups to get events that use that
+ # state group, then left joining against events_to_purge again. Any
+ # state group where the left join produce *no nulls* are referenced
+ # only by events that are going to be purged.
+ txn.execute("""
+ SELECT state_group FROM
+ (
+ SELECT DISTINCT state_group FROM events_to_purge
+ INNER JOIN event_to_state_groups USING (event_id)
+ ) AS sp
+ INNER JOIN event_to_state_groups USING (state_group)
+ LEFT JOIN events_to_purge AS ep USING (event_id)
+ GROUP BY state_group
+ HAVING SUM(CASE WHEN ep.event_id IS NULL THEN 1 ELSE 0 END) = 0
+ """)
state_rows = txn.fetchall()
logger.info("[purge] found %i redundant state groups", len(state_rows))
@@ -2012,10 +2120,25 @@ class EventsStore(EventsWorkerStore):
#
# So, let's stick it at the end so that we don't block event
# persistence.
- logger.info("[purge] updating room_depth")
+ #
+ # We do this by calculating the minimum depth of the backwards
+ # extremities. However, the events in event_backward_extremities
+ # are ones we don't have yet so we need to look at the events that
+ # point to it via event_edges table.
+ txn.execute("""
+ SELECT COALESCE(MIN(depth), 0)
+ FROM event_backward_extremities AS eb
+ INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
+ INNER JOIN events AS e ON e.event_id = eg.event_id
+ WHERE eb.room_id = ?
+ """, (room_id,))
+ min_depth, = txn.fetchone()
+
+ logger.info("[purge] updating room_depth to %d", min_depth)
+
txn.execute(
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
- (topological_ordering, room_id,)
+ (min_depth, room_id,)
)
# finally, drop the temp table. this will commit the txn in sqlite,
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 2e23dd78ba..32d9d00ffb 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -20,7 +20,7 @@ from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logcontext import (
- preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+ PreserveLoggingContext, make_deferred_yieldable, run_in_background,
)
from synapse.util.metrics import Measure
from synapse.api.errors import SynapseError
@@ -51,6 +51,26 @@ _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
class EventsWorkerStore(SQLBaseStore):
+ def get_received_ts(self, event_id):
+ """Get received_ts (when it was persisted) for the event.
+
+ Raises an exception for unknown events.
+
+ Args:
+ event_id (str)
+
+ Returns:
+ Deferred[int|None]: Timestamp in milliseconds, or None for events
+ that were persisted before received_ts was implemented.
+ """
+ return self._simple_select_one_onecol(
+ table="events",
+ keyvalues={
+ "event_id": event_id,
+ },
+ retcol="received_ts",
+ desc="get_received_ts",
+ )
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
@@ -299,7 +319,8 @@ class EventsWorkerStore(SQLBaseStore):
res = yield make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self._get_event_from_row)(
+ run_in_background(
+ self._get_event_from_row,
row["internal_metadata"], row["json"], row["redacts"],
rejected_reason=row["rejects"],
)
@@ -316,7 +337,7 @@ class EventsWorkerStore(SQLBaseStore):
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
- for i in range(1 + len(events) / N):
+ for i in range(1 + len(events) // N):
evs = events[i * N:(i + 1) * N]
if not evs:
break
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 78b1e30945..2e2763126d 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(json.loads(str(def_json).decode("utf-8")))
+ defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/group_server.py b/synapse/storage/group_server.py
index 8fde1aab8e..da05ccb027 100644
--- a/synapse/storage/group_server.py
+++ b/synapse/storage/group_server.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +20,7 @@ from synapse.api.errors import SynapseError
from ._base import SQLBaseStore
-import ujson as json
+import simplejson as json
# The category ID for the "default" category. We don't store as null in the
@@ -29,6 +30,24 @@ _DEFAULT_ROLE_ID = ""
class GroupServerStore(SQLBaseStore):
+ def set_group_join_policy(self, group_id, join_policy):
+ """Set the join policy of a group.
+
+ join_policy can be one of:
+ * "invite"
+ * "open"
+ """
+ return self._simple_update_one(
+ table="groups",
+ keyvalues={
+ "group_id": group_id,
+ },
+ updatevalues={
+ "join_policy": join_policy,
+ },
+ desc="set_group_join_policy",
+ )
+
def get_group(self, group_id):
return self._simple_select_one(
table="groups",
@@ -36,10 +55,11 @@ class GroupServerStore(SQLBaseStore):
"group_id": group_id,
},
retcols=(
- "name", "short_description", "long_description", "avatar_url", "is_public"
+ "name", "short_description", "long_description",
+ "avatar_url", "is_public", "join_policy",
),
allow_none=True,
- desc="is_user_in_group",
+ desc="get_group",
)
def get_users_in_group(self, group_id, include_private=False):
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 87aeaf71d6..0540c2b0b1 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -92,7 +92,7 @@ class KeyStore(SQLBaseStore):
if verify_key_bytes:
defer.returnValue(decode_verify_key_bytes(
- key_id, str(verify_key_bytes)
+ key_id, bytes(verify_key_bytes)
))
@defer.inlineCallbacks
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index c845a0cec5..c08e9cd65a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 47
+SCHEMA_VERSION = 49
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index eac8694e0f..709c69a926 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -23,7 +23,7 @@ from twisted.internet import defer
import abc
import logging
-import ujson as json
+import simplejson as json
logger = logging.getLogger(__name__)
@@ -297,18 +297,22 @@ class ReceiptsWorkerStore(SQLBaseStore):
if receipt_type != "m.read":
return
- # Returns an ObservableDeferred
+ # Returns either an ObservableDeferred or the raw result
res = self.get_users_with_read_receipts_in_room.cache.get(
room_id, None, update_metrics=False,
)
- if res:
- if isinstance(res, defer.Deferred) and res.called:
+ # first handle the Deferred case
+ if isinstance(res, defer.Deferred):
+ if res.called:
res = res.result
- if user_id in res:
- # We'd only be adding to the set, so no point invalidating if the
- # user is already there
- return
+ else:
+ res = None
+
+ if res and user_id in res:
+ # We'd only be adding to the set, so no point invalidating if the
+ # user is already there
+ return
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
@@ -407,7 +411,6 @@ class ReceiptsStore(ReceiptsWorkerStore):
txn,
room_id=room_id,
user_id=user_id,
- topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index d809b2ba46..a530e29f43 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -22,6 +22,8 @@ from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from six.moves import range
+
class RegistrationWorkerStore(SQLBaseStore):
@cached()
@@ -31,7 +33,10 @@ class RegistrationWorkerStore(SQLBaseStore):
keyvalues={
"name": user_id,
},
- retcols=["name", "password_hash", "is_guest"],
+ retcols=[
+ "name", "password_hash", "is_guest",
+ "consent_version", "consent_server_notice_sent",
+ ],
allow_none=True,
desc="get_user_by_id",
)
@@ -284,6 +289,53 @@ class RegistrationStore(RegistrationWorkerStore,
"user_set_password_hash", user_set_password_hash_txn
)
+ def user_set_consent_version(self, user_id, consent_version):
+ """Updates the user table to record privacy policy consent
+
+ Args:
+ user_id (str): full mxid of the user to update
+ consent_version (str): version of the policy the user has consented
+ to
+
+ Raises:
+ StoreError(404) if user not found
+ """
+ def f(txn):
+ self._simple_update_one_txn(
+ txn,
+ table='users',
+ keyvalues={'name': user_id, },
+ updatevalues={'consent_version': consent_version, },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user_id,)
+ )
+ return self.runInteraction("user_set_consent_version", f)
+
+ def user_set_consent_server_notice_sent(self, user_id, consent_version):
+ """Updates the user table to record that we have sent the user a server
+ notice about privacy policy consent
+
+ Args:
+ user_id (str): full mxid of the user to update
+ consent_version (str): version of the policy we have notified the
+ user about
+
+ Raises:
+ StoreError(404) if user not found
+ """
+ def f(txn):
+ self._simple_update_one_txn(
+ txn,
+ table='users',
+ keyvalues={'name': user_id, },
+ updatevalues={'consent_server_notice_sent': consent_version, },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user_id,)
+ )
+ return self.runInteraction("user_set_consent_server_notice_sent", f)
+
def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None):
"""
@@ -460,18 +512,16 @@ class RegistrationStore(RegistrationWorkerStore,
"""
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
- rows = self.cursor_to_dict(txn)
regex = re.compile("^@(\d+):")
found = set()
- for r in rows:
- user_id = r["name"]
+ for user_id, in txn:
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
- for i in xrange(len(found) + 1):
+ for i in range(len(found) + 1):
if i not in found:
return i
@@ -526,3 +576,42 @@ class RegistrationStore(RegistrationWorkerStore,
except self.database_engine.module.IntegrityError:
ret = yield self.get_3pid_guest_access_token(medium, address)
defer.returnValue(ret)
+
+ def add_user_pending_deactivation(self, user_id):
+ """
+ Adds a user to the table of users who need to be parted from all the rooms they're
+ in
+ """
+ return self._simple_insert(
+ "users_pending_deactivation",
+ values={
+ "user_id": user_id,
+ },
+ desc="add_user_pending_deactivation",
+ )
+
+ def del_user_pending_deactivation(self, user_id):
+ """
+ Removes the given user to the table of users who need to be parted from all the
+ rooms they're in, effectively marking that user as fully deactivated.
+ """
+ return self._simple_delete_one(
+ "users_pending_deactivation",
+ keyvalues={
+ "user_id": user_id,
+ },
+ desc="del_user_pending_deactivation",
+ )
+
+ def get_user_pending_deactivation(self):
+ """
+ Gets one user from the table of users waiting to be parted from all the rooms
+ they're in.
+ """
+ return self._simple_select_one_onecol(
+ "users_pending_deactivation",
+ keyvalues={},
+ retcol="user_id",
+ allow_none=True,
+ desc="get_users_pending_deactivation",
+ )
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 34ed84ea22..ea6a189185 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -22,7 +22,7 @@ from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
import collections
import logging
-import ujson as json
+import simplejson as json
import re
logger = logging.getLogger(__name__)
@@ -530,7 +530,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
# Convert the IDs to MXC URIs
for media_id in local_mxcs:
- local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
+ local_media_mxcs.append("mxc://%s/%s" % (self.hs.hostname, media_id))
for hostname, media_id in remote_mxcs:
remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
@@ -594,7 +594,8 @@ class RoomStore(RoomWorkerStore, SearchStore):
while next_token:
sql = """
- SELECT stream_ordering, content FROM events
+ SELECT stream_ordering, json FROM events
+ JOIN event_json USING (room_id, event_id)
WHERE room_id = ?
AND stream_ordering < ?
AND contains_url = ? AND outlier = ?
@@ -606,8 +607,8 @@ class RoomStore(RoomWorkerStore, SearchStore):
next_token = None
for stream_ordering, content_json in txn:
next_token = stream_ordering
- content = json.loads(content_json)
-
+ event_json = json.loads(content_json)
+ content = event_json["content"]
content_url = content.get("url")
thumbnail_url = content.get("info", {}).get("thumbnail_url")
@@ -618,7 +619,7 @@ class RoomStore(RoomWorkerStore, SearchStore):
if matches:
hostname = matches.group(1)
media_id = matches.group(2)
- if hostname == self.hostname:
+ if hostname == self.hs.hostname:
local_media_mxcs.append(media_id)
else:
remote_media_mxcs.append((hostname, media_id))
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 52e19e16b0..7bfc3d91b5 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -28,7 +28,9 @@ from synapse.api.constants import Membership, EventTypes
from synapse.types import get_domain_from_id
import logging
-import ujson as json
+import simplejson as json
+
+from six import itervalues, iteritems
logger = logging.getLogger(__name__)
@@ -272,7 +274,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = {}
member_event_ids = [
e_id
- for key, e_id in current_state_ids.iteritems()
+ for key, e_id in iteritems(current_state_ids)
if key[0] == EventTypes.Member
]
@@ -289,7 +291,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = dict(prev_res)
member_event_ids = [
e_id
- for key, e_id in context.delta_ids.iteritems()
+ for key, e_id in iteritems(context.delta_ids)
if key[0] == EventTypes.Member
]
for etype, state_key in context.delta_ids:
@@ -645,8 +647,9 @@ class RoomMemberStore(RoomMemberWorkerStore):
def add_membership_profile_txn(txn):
sql = ("""
- SELECT stream_ordering, event_id, events.room_id, content
+ SELECT stream_ordering, event_id, events.room_id, event_json.json
FROM events
+ INNER JOIN event_json USING (event_id)
INNER JOIN room_memberships USING (event_id)
WHERE ? <= stream_ordering AND stream_ordering < ?
AND type = 'm.room.member'
@@ -667,7 +670,8 @@ class RoomMemberStore(RoomMemberWorkerStore):
event_id = row["event_id"]
room_id = row["room_id"]
try:
- content = json.loads(row["content"])
+ event_json = json.loads(row["json"])
+ content = event_json['content']
except Exception:
continue
@@ -739,7 +743,7 @@ class _JoinedHostsCache(object):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
- for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+ for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
if typ != EventTypes.Member:
continue
@@ -769,7 +773,7 @@ class _JoinedHostsCache(object):
self.state_group = state_entry.state_group
else:
self.state_group = object()
- self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+ self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
defer.returnValue(frozenset(self.hosts_to_joined_users))
def __len__(self):
diff --git a/synapse/storage/schema/delta/14/upgrade_appservice_db.py b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
index 8755bb2e49..4d725b92fe 100644
--- a/synapse/storage/schema/delta/14/upgrade_appservice_db.py
+++ b/synapse/storage/schema/delta/14/upgrade_appservice_db.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import logging
+import simplejson as json
+
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/schema/delta/25/fts.py b/synapse/storage/schema/delta/25/fts.py
index 4269ac69ad..e7351c3ae6 100644
--- a/synapse/storage/schema/delta/25/fts.py
+++ b/synapse/storage/schema/delta/25/fts.py
@@ -17,7 +17,7 @@ import logging
from synapse.storage.prepare_database import get_statements
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-import ujson
+import simplejson
logger = logging.getLogger(__name__)
@@ -66,7 +66,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/27/ts.py b/synapse/storage/schema/delta/27/ts.py
index 71b12a2731..6df57b5206 100644
--- a/synapse/storage/schema/delta/27/ts.py
+++ b/synapse/storage/schema/delta/27/ts.py
@@ -16,7 +16,7 @@ import logging
from synapse.storage.prepare_database import get_statements
-import ujson
+import simplejson
logger = logging.getLogger(__name__)
@@ -45,7 +45,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
index c53e53c94f..85bd1a2006 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -14,6 +14,8 @@
import logging
from synapse.config.appservice import load_appservices
+from six.moves import range
+
logger = logging.getLogger(__name__)
@@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
for as_id, user_ids in owned.items():
n = 100
- user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n))
+ user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
database_engine.convert_param_style(
diff --git a/synapse/storage/schema/delta/31/search_update.py b/synapse/storage/schema/delta/31/search_update.py
index 470ae0c005..fe6b7d196d 100644
--- a/synapse/storage/schema/delta/31/search_update.py
+++ b/synapse/storage/schema/delta/31/search_update.py
@@ -16,7 +16,7 @@ from synapse.storage.engines import PostgresEngine
from synapse.storage.prepare_database import get_statements
import logging
-import ujson
+import simplejson
logger = logging.getLogger(__name__)
@@ -49,7 +49,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"rows_inserted": 0,
"have_added_indexes": False,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/33/event_fields.py b/synapse/storage/schema/delta/33/event_fields.py
index 83066cccc9..1e002f9db2 100644
--- a/synapse/storage/schema/delta/33/event_fields.py
+++ b/synapse/storage/schema/delta/33/event_fields.py
@@ -15,7 +15,7 @@
from synapse.storage.prepare_database import get_statements
import logging
-import ujson
+import simplejson
logger = logging.getLogger(__name__)
@@ -44,7 +44,7 @@ def run_create(cur, database_engine, *args, **kwargs):
"max_stream_id_exclusive": max_stream_id + 1,
"rows_inserted": 0,
}
- progress_json = ujson.dumps(progress)
+ progress_json = simplejson.dumps(progress)
sql = (
"INSERT into background_updates (update_name, progress_json)"
diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/schema/delta/48/add_user_consent.sql
new file mode 100644
index 0000000000..5237491506
--- /dev/null
+++ b/synapse/storage/schema/delta/48/add_user_consent.sql
@@ -0,0 +1,18 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* record the version of the privacy policy the user has consented to
+ */
+ALTER TABLE users ADD COLUMN consent_version TEXT;
diff --git a/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
new file mode 100644
index 0000000000..9248b0b24a
--- /dev/null
+++ b/synapse/storage/schema/delta/48/add_user_ips_last_seen_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('user_ips_last_seen_index', '{}');
diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/schema/delta/48/deactivated_users.sql
new file mode 100644
index 0000000000..e9013a6969
--- /dev/null
+++ b/synapse/storage/schema/delta/48/deactivated_users.sql
@@ -0,0 +1,25 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Store any accounts that have been requested to be deactivated.
+ * We part the account from all the rooms its in when its
+ * deactivated. This can take some time and synapse may be restarted
+ * before it completes, so store the user IDs here until the process
+ * is complete.
+ */
+CREATE TABLE users_pending_deactivation (
+ user_id TEXT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py
new file mode 100644
index 0000000000..2233af87d7
--- /dev/null
+++ b/synapse/storage/schema/delta/48/group_unique_indexes.py
@@ -0,0 +1,57 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
+FIX_INDEXES = """
+-- rebuild indexes as uniques
+DROP INDEX groups_invites_g_idx;
+CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id);
+DROP INDEX groups_users_g_idx;
+CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id);
+
+-- rename other indexes to actually match their table names..
+DROP INDEX groups_users_u_idx;
+CREATE INDEX group_users_u_idx ON group_users(user_id);
+DROP INDEX groups_invites_u_idx;
+CREATE INDEX group_invites_u_idx ON group_invites(user_id);
+DROP INDEX groups_rooms_g_idx;
+CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
+DROP INDEX groups_rooms_r_idx;
+CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
+
+ # remove duplicates from group_users & group_invites tables
+ cur.execute("""
+ DELETE FROM group_users WHERE %s NOT IN (
+ SELECT min(%s) FROM group_users GROUP BY group_id, user_id
+ );
+ """ % (rowid, rowid))
+ cur.execute("""
+ DELETE FROM group_invites WHERE %s NOT IN (
+ SELECT min(%s) FROM group_invites GROUP BY group_id, user_id
+ );
+ """ % (rowid, rowid))
+
+ for statement in get_statements(FIX_INDEXES.splitlines()):
+ cur.execute(statement)
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/48/groups_joinable.sql b/synapse/storage/schema/delta/48/groups_joinable.sql
new file mode 100644
index 0000000000..ce26eaf0c9
--- /dev/null
+++ b/synapse/storage/schema/delta/48/groups_joinable.sql
@@ -0,0 +1,22 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This isn't a real ENUM because sqlite doesn't support it
+ * and we use a default of NULL for inserted rows and interpret
+ * NULL at the python store level as necessary so that existing
+ * rows are given the correct default policy.
+ */
+ALTER TABLE groups ADD COLUMN join_policy TEXT NOT NULL DEFAULT 'invite';
diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
new file mode 100644
index 0000000000..14dcf18d73
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
@@ -0,0 +1,20 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* record whether we have sent a server notice about consenting to the
+ * privacy policy. Specifically records the version of the policy we sent
+ * a message about.
+ */
+ALTER TABLE users ADD COLUMN consent_server_notice_sent TEXT;
diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/schema/delta/49/add_user_daily_visits.sql
new file mode 100644
index 0000000000..3dd478196f
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_daily_visits.sql
@@ -0,0 +1,21 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL,
+ device_id TEXT,
+ timestamp BIGINT NOT NULL );
+CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
+CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
new file mode 100644
index 0000000000..3a4ed59b5b
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('user_ips_last_seen_only_index', '{}');
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 2755acff40..6ba3e59889 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -16,7 +16,7 @@
from collections import namedtuple
import logging
import re
-import ujson as json
+import simplejson as json
from twisted.internet import defer
@@ -75,8 +75,9 @@ class SearchStore(BackgroundUpdateStore):
def reindex_search_txn(txn):
sql = (
- "SELECT stream_ordering, event_id, room_id, type, content, "
+ "SELECT stream_ordering, event_id, room_id, type, json, "
" origin_server_ts FROM events"
+ " JOIN event_json USING (room_id, event_id)"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
@@ -104,7 +105,8 @@ class SearchStore(BackgroundUpdateStore):
stream_ordering = row["stream_ordering"]
origin_server_ts = row["origin_server_ts"]
try:
- content = json.loads(row["content"])
+ event_json = json.loads(row["json"])
+ content = event_json["content"]
except Exception:
continue
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 2956c3b3e0..fb463c525a 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -38,15 +38,17 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.events import EventsWorkerStore
-from synapse.util.caches.descriptors import cached
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.storage.engines import PostgresEngine
import abc
import logging
+from six.moves import range
+from collections import namedtuple
+
logger = logging.getLogger(__name__)
@@ -58,6 +60,12 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
+# Used as return values for pagination APIs
+_EventDictReturn = namedtuple("_EventDictReturn", (
+ "event_id", "topological_ordering", "stream_ordering",
+))
+
+
def lower_bound(token, engine, inclusive=False):
inclusive = "=" if inclusive else ""
if token.topological is None:
@@ -196,13 +204,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
- for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
+ for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
res = yield make_deferred_yieldable(defer.gatherResults([
- preserve_fn(self.get_room_events_stream_for_room)(
+ run_in_background(
+ self.get_room_events_stream_for_room,
room_id, from_key, to_key, limit, order=order,
)
for room_id in rm_ids
- ]))
+ ], consumeErrors=True))
results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)
@@ -224,54 +233,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
order='DESC'):
- # Note: If from_key is None then we return in topological order. This
- # is because in that case we're using this as a "get the last few messages
- # in a room" function, rather than "get new messages since last sync"
- if from_key is not None:
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- else:
- from_id = None
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ """Get new room events in stream ordering since `from_key`.
+
+ Args:
+ room_id (str)
+ from_key (str): Token from which no events are returned before
+ to_key (str): Token from which no events are returned after. (This
+ is typically the current stream token)
+ limit (int): Maximum number of events to return
+ order (str): Either "DESC" or "ASC". Determines which events are
+ returned when the result is limited. If "DESC" then the most
+ recent `limit` events are returned, otherwise returns the
+ oldest `limit` events.
+
+ Returns:
+ Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
+ events (in ascending order) and the token from the start of
+ the chunk of events returned.
+ """
if from_key == to_key:
defer.returnValue(([], from_key))
- if from_id:
- has_changed = yield self._events_stream_cache.has_entity_changed(
- room_id, from_id
- )
-
- if not has_changed:
- defer.returnValue(([], from_key))
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ to_id = RoomStreamToken.parse_stream_token(to_key).stream
- def f(txn):
- if from_id is not None:
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering > ? AND stream_ordering <= ?"
- " ORDER BY stream_ordering %s LIMIT ?"
- ) % (order,)
- txn.execute(sql, (room_id, from_id, to_id, limit))
- else:
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering <= ?"
- " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?"
- ) % (order, order,)
- txn.execute(sql, (room_id, to_id, limit))
+ has_changed = yield self._events_stream_cache.has_entity_changed(
+ room_id, from_id
+ )
- rows = self.cursor_to_dict(txn)
+ if not has_changed:
+ defer.returnValue(([], from_key))
+ def f(txn):
+ sql = (
+ "SELECT event_id, stream_ordering FROM events WHERE"
+ " room_id = ?"
+ " AND not outlier"
+ " AND stream_ordering > ? AND stream_ordering <= ?"
+ " ORDER BY stream_ordering %s LIMIT ?"
+ ) % (order,)
+ txn.execute(sql, (room_id, from_id, to_id, limit))
+
+ rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
@@ -281,7 +291,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
- key = "s%d" % min(r["stream_ordering"] for r in rows)
+ key = "s%d" % min(r.stream_ordering for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -291,10 +301,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
- if from_key is not None:
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- else:
- from_id = None
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
@@ -308,34 +315,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue([])
def f(txn):
- if from_id is not None:
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- )
- txn.execute(sql, (user_id, from_id, to_id,))
- else:
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- )
- txn.execute(sql, (user_id, to_id,))
- rows = self.cursor_to_dict(txn)
+ sql = (
+ "SELECT m.event_id, stream_ordering FROM events AS e,"
+ " room_memberships AS m"
+ " WHERE e.event_id = m.event_id"
+ " AND m.user_id = ?"
+ " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ )
+ txn.execute(sql, (user_id, from_id, to_id,))
+
+ rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
@@ -344,14 +341,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
- def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
+ def get_recent_events_for_room(self, room_id, limit, end_token):
+ """Get the most recent events in the room in topological ordering.
+
+ Args:
+ room_id (str)
+ limit (int)
+ end_token (str): The stream token representing now.
+
+ Returns:
+ Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
+ events and a token pointing to the start of the returned
+ events.
+ The events returned are in ascending order.
+ """
+
rows, token = yield self.get_recent_event_ids_for_room(
- room_id, limit, end_token, from_token
+ room_id, limit, end_token,
)
logger.debug("stream before")
events = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
logger.debug("stream after")
@@ -360,60 +371,36 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue((events, token))
- @cached(num_args=4)
- def get_recent_event_ids_for_room(self, room_id, limit, end_token, from_token=None):
- end_token = RoomStreamToken.parse_stream_token(end_token)
-
- if from_token is None:
- sql = (
- "SELECT stream_ordering, topological_ordering, event_id"
- " FROM events"
- " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?"
- " ORDER BY topological_ordering DESC, stream_ordering DESC"
- " LIMIT ?"
- )
- else:
- from_token = RoomStreamToken.parse_stream_token(from_token)
- sql = (
- "SELECT stream_ordering, topological_ordering, event_id"
- " FROM events"
- " WHERE room_id = ? AND stream_ordering > ?"
- " AND stream_ordering <= ? AND outlier = ?"
- " ORDER BY topological_ordering DESC, stream_ordering DESC"
- " LIMIT ?"
- )
-
- def get_recent_events_for_room_txn(txn):
- if from_token is None:
- txn.execute(sql, (room_id, end_token.stream, False, limit,))
- else:
- txn.execute(sql, (
- room_id, from_token.stream, end_token.stream, False, limit
- ))
+ @defer.inlineCallbacks
+ def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+ """Get the most recent events in the room in topological ordering.
- rows = self.cursor_to_dict(txn)
+ Args:
+ room_id (str)
+ limit (int)
+ end_token (str): The stream token representing now.
- rows.reverse() # As we selected with reverse ordering
+ Returns:
+ Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
+ _EventDictReturn and a token pointing to the start of the returned
+ events.
+ The events returned are in ascending order.
+ """
+ # Allow a zero limit here, and no-op.
+ if limit == 0:
+ defer.returnValue(([], end_token))
- if rows:
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # since we are going backwards so we subtract one from the
- # stream part.
- topo = rows[0]["topological_ordering"]
- toke = rows[0]["stream_ordering"] - 1
- start_token = str(RoomStreamToken(topo, toke))
+ end_token = RoomStreamToken.parse(end_token)
- token = (start_token, str(end_token))
- else:
- token = (str(end_token), str(end_token))
+ rows, token = yield self.runInteraction(
+ "get_recent_event_ids_for_room", self._paginate_room_events_txn,
+ room_id, from_token=end_token, limit=limit,
+ )
- return rows, token
+ # We want to return the results in ascending order.
+ rows.reverse()
- return self.runInteraction(
- "get_recent_events_for_room", get_recent_events_for_room_txn
- )
+ defer.returnValue((rows, token))
def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
"""Gets details of the first event in a room at or after a stream ordering
@@ -517,10 +504,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@staticmethod
def _set_before_and_after(events, rows, topo_order=True):
+ """Inserts ordering information to events' internal metadata from
+ the DB rows.
+
+ Args:
+ events (list[FrozenEvent])
+ rows (list[_EventDictReturn])
+ topo_order (bool): Whether the events were ordered topologically
+ or by stream ordering. If true then all rows should have a non
+ null topological_ordering.
+ """
for event, row in zip(events, rows):
- stream = row["stream_ordering"]
- if topo_order:
- topo = event.depth
+ stream = row.stream_ordering
+ if topo_order and row.topological_ordering:
+ topo = row.topological_ordering
else:
topo = None
internal = event.internal_metadata
@@ -592,87 +589,27 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
- token = RoomStreamToken(
- results["topological_ordering"],
+ # Paginating backwards includes the event at the token, but paginating
+ # forward doesn't.
+ before_token = RoomStreamToken(
+ results["topological_ordering"] - 1,
results["stream_ordering"],
)
- if isinstance(self.database_engine, Sqlite3Engine):
- # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
- # So we give pass it to SQLite3 as the UNION ALL of the two queries.
-
- query_before = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering < ?"
- " UNION ALL"
- " SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
- " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
- )
- before_args = (
- room_id, token.topological,
- room_id, token.topological, token.stream,
- before_limit,
- )
-
- query_after = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering > ?"
- " UNION ALL"
- " SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
- " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
- )
- after_args = (
- room_id, token.topological,
- room_id, token.topological, token.stream,
- after_limit,
- )
- else:
- query_before = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND %s"
- " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
- ) % (upper_bound(token, self.database_engine, inclusive=False),)
-
- before_args = (room_id, before_limit)
-
- query_after = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND %s"
- " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
-
- after_args = (room_id, after_limit)
-
- txn.execute(query_before, before_args)
-
- rows = self.cursor_to_dict(txn)
- events_before = [r["event_id"] for r in rows]
-
- if rows:
- start_token = str(RoomStreamToken(
- rows[0]["topological_ordering"],
- rows[0]["stream_ordering"] - 1,
- ))
- else:
- start_token = str(RoomStreamToken(
- token.topological,
- token.stream - 1,
- ))
-
- txn.execute(query_after, after_args)
+ after_token = RoomStreamToken(
+ results["topological_ordering"],
+ results["stream_ordering"],
+ )
- rows = self.cursor_to_dict(txn)
- events_after = [r["event_id"] for r in rows]
+ rows, start_token = self._paginate_room_events_txn(
+ txn, room_id, before_token, direction='b', limit=before_limit,
+ )
+ events_before = [r.event_id for r in rows]
- if rows:
- end_token = str(RoomStreamToken(
- rows[-1]["topological_ordering"],
- rows[-1]["stream_ordering"],
- ))
- else:
- end_token = str(token)
+ rows, end_token = self._paginate_room_events_txn(
+ txn, room_id, after_token, direction='f', limit=after_limit,
+ )
+ events_after = [r.event_id for r in rows]
return {
"before": {
@@ -735,17 +672,30 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
+ def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None,
+ direction='b', limit=-1, event_filter=None):
+ """Returns list of events before or after a given token.
-class StreamStore(StreamWorkerStore):
- def get_room_max_stream_ordering(self):
- return self._stream_id_gen.get_current_token()
+ Args:
+ txn
+ room_id (str)
+ from_token (RoomStreamToken): The token used to stream from
+ to_token (RoomStreamToken|None): A token which if given limits the
+ results to only those before
+ direction(char): Either 'b' or 'f' to indicate whether we are
+ paginating forwards or backwards from `from_key`.
+ limit (int): The maximum number of events to return.
+ event_filter (Filter|None): If provided filters the events to
+ those that match the filter.
- def get_room_min_stream_ordering(self):
- return self._backfill_id_gen.get_current_token()
+ Returns:
+ Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
+ as a list of _EventDictReturn and a token that points to the end
+ of the result set.
+ """
+
+ assert int(limit) >= 0
- @defer.inlineCallbacks
- def paginate_room_events(self, room_id, from_key, to_key=None,
- direction='b', limit=-1, event_filter=None):
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
@@ -753,20 +703,20 @@ class StreamStore(StreamWorkerStore):
if direction == 'b':
order = "DESC"
bounds = upper_bound(
- RoomStreamToken.parse(from_key), self.database_engine
+ from_token, self.database_engine
)
- if to_key:
+ if to_token:
bounds = "%s AND %s" % (bounds, lower_bound(
- RoomStreamToken.parse(to_key), self.database_engine
+ to_token, self.database_engine
))
else:
order = "ASC"
bounds = lower_bound(
- RoomStreamToken.parse(from_key), self.database_engine
+ from_token, self.database_engine
)
- if to_key:
+ if to_token:
bounds = "%s AND %s" % (bounds, upper_bound(
- RoomStreamToken.parse(to_key), self.database_engine
+ to_token, self.database_engine
))
filter_clause, filter_args = filter_to_clause(event_filter)
@@ -775,52 +725,85 @@ class StreamStore(StreamWorkerStore):
bounds += " AND " + filter_clause
args.extend(filter_args)
- if int(limit) > 0:
- args.append(int(limit))
- limit_str = " LIMIT ?"
- else:
- limit_str = ""
+ args.append(int(limit))
sql = (
- "SELECT * FROM events"
+ "SELECT event_id, topological_ordering, stream_ordering"
+ " FROM events"
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s,"
- " stream_ordering %(order)s %(limit)s"
+ " stream_ordering %(order)s LIMIT ?"
) % {
"bounds": bounds,
"order": order,
- "limit": limit_str
}
- def f(txn):
- txn.execute(sql, args)
-
- rows = self.cursor_to_dict(txn)
-
- if rows:
- topo = rows[-1]["topological_ordering"]
- toke = rows[-1]["stream_ordering"]
- if direction == 'b':
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- toke -= 1
- next_token = str(RoomStreamToken(topo, toke))
- else:
- # TODO (erikj): We should work out what to do here instead.
- next_token = to_key if to_key else from_key
+ txn.execute(sql, args)
+
+ rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+
+ if rows:
+ topo = rows[-1].topological_ordering
+ toke = rows[-1].stream_ordering
+ if direction == 'b':
+ # Tokens are positions between events.
+ # This token points *after* the last event in the chunk.
+ # We need it to point to the event before it in the chunk
+ # when we are going backwards so we subtract one from the
+ # stream part.
+ toke -= 1
+ next_token = RoomStreamToken(topo, toke)
+ else:
+ # TODO (erikj): We should work out what to do here instead.
+ next_token = to_token if to_token else from_token
+
+ return rows, str(next_token),
+
+ @defer.inlineCallbacks
+ def paginate_room_events(self, room_id, from_key, to_key=None,
+ direction='b', limit=-1, event_filter=None):
+ """Returns list of events before or after a given token.
- return rows, next_token,
+ Args:
+ room_id (str)
+ from_key (str): The token used to stream from
+ to_key (str|None): A token which if given limits the results to
+ only those before
+ direction(char): Either 'b' or 'f' to indicate whether we are
+ paginating forwards or backwards from `from_key`.
+ limit (int): The maximum number of events to return. Zero or less
+ means no limit.
+ event_filter (Filter|None): If provided filters the events to
+ those that match the filter.
- rows, token = yield self.runInteraction("paginate_room_events", f)
+ Returns:
+ tuple[list[dict], str]: Returns the results as a list of dicts and
+ a token that points to the end of the result set. The dicts have
+ the keys "event_id", "topological_ordering" and "stream_orderign".
+ """
+
+ from_key = RoomStreamToken.parse(from_key)
+ if to_key:
+ to_key = RoomStreamToken.parse(to_key)
+
+ rows, token = yield self.runInteraction(
+ "paginate_room_events", self._paginate_room_events_txn,
+ room_id, from_key, to_key, direction, limit, event_filter,
+ )
events = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
self._set_before_and_after(events, rows)
defer.returnValue((events, token))
+
+
+class StreamStore(StreamWorkerStore):
+ def get_room_max_stream_ordering(self):
+ return self._stream_id_gen.get_current_token()
+
+ def get_room_min_stream_ordering(self):
+ return self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index fc46bf7bb3..6671d3cfca 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -19,9 +19,11 @@ from synapse.storage.account_data import AccountDataWorkerStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
-import ujson as json
+import simplejson as json
import logging
+from six.moves import range
+
logger = logging.getLogger(__name__)
@@ -98,7 +100,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
batch_size = 50
results = []
- for i in xrange(0, len(tag_ids), batch_size):
+ for i in range(0, len(tag_ids), batch_size):
tags = yield self.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index 8f61f7ffae..f825264ea9 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -23,7 +23,7 @@ from canonicaljson import encode_canonical_json
from collections import namedtuple
import logging
-import ujson as json
+import simplejson as json
logger = logging.getLogger(__name__)
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index dfdcbb3181..d6e289ffbe 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -667,7 +667,7 @@ class UserDirectoryStore(SQLBaseStore):
# The array of numbers are the weights for the various part of the
# search: (domain, _, display name, localpart)
sql = """
- SELECT d.user_id, display_name, avatar_url
+ SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id)
%s
@@ -702,7 +702,7 @@ class UserDirectoryStore(SQLBaseStore):
search_query = _parse_query_sqlite(search_term)
sql = """
- SELECT d.user_id, display_name, avatar_url
+ SELECT d.user_id AS user_id, display_name, avatar_url
FROM user_directory_search
INNER JOIN user_directory AS d USING (user_id)
%s
diff --git a/synapse/types.py b/synapse/types.py
index 7cb24cecb2..cc7c182a78 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -169,7 +169,7 @@ class DomainSpecificString(
except Exception:
return False
- __str__ = to_string
+ __repr__ = to_string
class UserID(DomainSpecificString):
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 756d8ffa32..fc11e26623 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.errors import SynapseError
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -21,12 +20,9 @@ from twisted.internet import defer, reactor, task
import time
import logging
-logger = logging.getLogger(__name__)
-
+from itertools import islice
-class DeferredTimedOutError(SynapseError):
- def __init__(self):
- super(DeferredTimedOutError, self).__init__(504, "Timed out")
+logger = logging.getLogger(__name__)
def unwrapFirstError(failure):
@@ -86,52 +82,18 @@ class Clock(object):
if not ignore_errs:
raise
- def time_bound_deferred(self, given_deferred, time_out):
- if given_deferred.called:
- return given_deferred
-
- ret_deferred = defer.Deferred()
-
- def timed_out_fn():
- e = DeferredTimedOutError()
-
- try:
- ret_deferred.errback(e)
- except Exception:
- pass
-
- try:
- given_deferred.cancel()
- except Exception:
- pass
- timer = None
+def batch_iter(iterable, size):
+ """batch an iterable up into tuples with a maximum size
- def cancel(res):
- try:
- self.cancel_call_later(timer)
- except Exception:
- pass
- return res
+ Args:
+ iterable (iterable): the iterable to slice
+ size (int): the maximum batch size
- ret_deferred.addBoth(cancel)
-
- def success(res):
- try:
- ret_deferred.callback(res)
- except Exception:
- pass
-
- return res
-
- def err(res):
- try:
- ret_deferred.errback(res)
- except Exception:
- pass
-
- given_deferred.addCallbacks(callback=success, errback=err)
-
- timer = self.call_later(time_out, timed_out_fn)
-
- return ret_deferred
+ Returns:
+ an iterator over the chunks
+ """
+ # make sure we can deal with iterables like lists too
+ sourceiter = iter(iterable)
+ # call islice until it returns an empty tuple
+ return iter(lambda: tuple(islice(sourceiter, size)), ())
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 0729bb2863..9dd4e6b5bc 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -15,9 +15,11 @@
from twisted.internet import defer, reactor
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
from .logcontext import (
- PreserveLoggingContext, make_deferred_yieldable, preserve_fn
+ PreserveLoggingContext, make_deferred_yieldable, run_in_background
)
from synapse.util import logcontext, unwrapFirstError
@@ -25,6 +27,8 @@ from contextlib import contextmanager
import logging
+from six.moves import range
+
logger = logging.getLogger(__name__)
@@ -156,13 +160,13 @@ def concurrently_execute(func, args, limit):
def _concurrently_execute_inner():
try:
while True:
- yield func(it.next())
+ yield func(next(it))
except StopIteration:
pass
return logcontext.make_deferred_yieldable(defer.gatherResults([
- preserve_fn(_concurrently_execute_inner)()
- for _ in xrange(limit)
+ run_in_background(_concurrently_execute_inner)
+ for _ in range(limit)
], consumeErrors=True)).addErrback(unwrapFirstError)
@@ -392,3 +396,68 @@ class ReadWriteLock(object):
self.key_to_current_writer.pop(key)
defer.returnValue(_ctx_manager())
+
+
+class DeferredTimeoutError(Exception):
+ """
+ This error is raised by default when a L{Deferred} times out.
+ """
+
+
+def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
+ """
+ Add a timeout to a deferred by scheduling it to be cancelled after
+ timeout seconds.
+
+ This is essentially a backport of deferred.addTimeout, which was introduced
+ in twisted 16.5.
+
+ If the deferred gets timed out, it errbacks with a DeferredTimeoutError,
+ unless a cancelable function was passed to its initialization or unless
+ a different on_timeout_cancel callable is provided.
+
+ Args:
+ deferred (defer.Deferred): deferred to be timed out
+ timeout (Number): seconds to time out after
+
+ on_timeout_cancel (callable): A callable which is called immediately
+ after the deferred times out, and not if this deferred is
+ otherwise cancelled before the timeout.
+
+ It takes an arbitrary value, which is the value of the deferred at
+ that exact point in time (probably a CancelledError Failure), and
+ the timeout.
+
+ The default callable (if none is provided) will translate a
+ CancelledError Failure into a DeferredTimeoutError.
+ """
+ timed_out = [False]
+
+ def time_it_out():
+ timed_out[0] = True
+ deferred.cancel()
+
+ delayed_call = reactor.callLater(timeout, time_it_out)
+
+ def convert_cancelled(value):
+ if timed_out[0]:
+ to_call = on_timeout_cancel or _cancelled_to_timed_out_error
+ return to_call(value, timeout)
+ return value
+
+ deferred.addBoth(convert_cancelled)
+
+ def cancel_timeout(result):
+ # stop the pending call to cancel the deferred if it's been fired
+ if delayed_call.active():
+ delayed_call.cancel()
+ return result
+
+ deferred.addBoth(cancel_timeout)
+
+
+def _cancelled_to_timed_out_error(value, timeout):
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise DeferredTimeoutError(timeout, "Deferred")
+ return value
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 4adae96681..183faf75a1 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -13,28 +13,77 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.metrics
+from prometheus_client.core import Gauge, REGISTRY, GaugeMetricFamily
+
import os
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
+from six.moves import intern
+import six
-metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
caches_by_name = {}
-# cache_counter = metrics.register_cache(
-# "cache",
-# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
-# labels=["name"],
-# )
-
-
-def register_cache(name, cache):
- caches_by_name[name] = cache
- return metrics.register_cache(
- "cache",
- lambda: len(cache),
- name,
- )
+collectors_by_name = {}
+
+cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
+cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
+cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
+cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
+
+response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
+response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
+response_cache_evicted = Gauge(
+ "synapse_util_caches_response_cache:evicted_size", "", ["name"]
+)
+response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+
+
+def register_cache(cache_type, cache_name, cache):
+
+ # Check if the metric is already registered. Unregister it, if so.
+ # This usually happens during tests, as at runtime these caches are
+ # effectively singletons.
+ metric_name = "cache_%s_%s" % (cache_type, cache_name)
+ if metric_name in collectors_by_name.keys():
+ REGISTRY.unregister(collectors_by_name[metric_name])
+
+ class CacheMetric(object):
+
+ hits = 0
+ misses = 0
+ evicted_size = 0
+
+ def inc_hits(self):
+ self.hits += 1
+
+ def inc_misses(self):
+ self.misses += 1
+
+ def inc_evictions(self, size=1):
+ self.evicted_size += size
+
+ def describe(self):
+ return []
+
+ def collect(self):
+ if cache_type == "response_cache":
+ response_cache_size.labels(cache_name).set(len(cache))
+ response_cache_hits.labels(cache_name).set(self.hits)
+ response_cache_evicted.labels(cache_name).set(self.evicted_size)
+ response_cache_total.labels(cache_name).set(self.hits + self.misses)
+ else:
+ cache_size.labels(cache_name).set(len(cache))
+ cache_hits.labels(cache_name).set(self.hits)
+ cache_evicted.labels(cache_name).set(self.evicted_size)
+ cache_total.labels(cache_name).set(self.hits + self.misses)
+
+ yield GaugeMetricFamily("__unused", "")
+
+ metric = CacheMetric()
+ REGISTRY.register(metric)
+ caches_by_name[cache_name] = cache
+ collectors_by_name[metric_name] = metric
+ return metric
KNOWN_KEYS = {
@@ -66,7 +115,9 @@ def intern_string(string):
return None
try:
- string = string.encode("ascii")
+ if six.PY2:
+ string = string.encode("ascii")
+
return intern(string)
except UnicodeEncodeError:
return string
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index bf3a66eae4..8a9dcb2fc2 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,12 +40,11 @@ _CacheSentinel = object()
class CacheEntry(object):
__slots__ = [
- "deferred", "sequence", "callbacks", "invalidated"
+ "deferred", "callbacks", "invalidated"
]
- def __init__(self, deferred, sequence, callbacks):
+ def __init__(self, deferred, callbacks):
self.deferred = deferred
- self.sequence = sequence
self.callbacks = set(callbacks)
self.invalidated = False
@@ -62,7 +62,6 @@ class Cache(object):
"max_entries",
"name",
"keylen",
- "sequence",
"thread",
"metrics",
"_pending_deferred_cache",
@@ -80,9 +79,8 @@ class Cache(object):
self.name = name
self.keylen = keylen
- self.sequence = 0
self.thread = None
- self.metrics = register_cache(name, self.cache)
+ self.metrics = register_cache("cache", name, self.cache)
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
@@ -113,11 +111,10 @@ class Cache(object):
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
- if val.sequence == self.sequence:
- val.callbacks.update(callbacks)
- if update_metrics:
- self.metrics.inc_hits()
- return val.deferred
+ val.callbacks.update(callbacks)
+ if update_metrics:
+ self.metrics.inc_hits()
+ return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
@@ -137,12 +134,9 @@ class Cache(object):
self.check_thread()
entry = CacheEntry(
deferred=value,
- sequence=self.sequence,
callbacks=callbacks,
)
- entry.callbacks.update(callbacks)
-
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
@@ -150,13 +144,25 @@ class Cache(object):
self._pending_deferred_cache[key] = entry
def shuffle(result):
- if self.sequence == entry.sequence:
- existing_entry = self._pending_deferred_cache.pop(key, None)
- if existing_entry is entry:
- self.cache.set(key, result, entry.callbacks)
- else:
- entry.invalidate()
+ existing_entry = self._pending_deferred_cache.pop(key, None)
+ if existing_entry is entry:
+ self.cache.set(key, result, entry.callbacks)
else:
+ # oops, the _pending_deferred_cache has been updated since
+ # we started our query, so we are out of date.
+ #
+ # Better put back whatever we took out. (We do it this way
+ # round, rather than peeking into the _pending_deferred_cache
+ # and then removing on a match, to make the common case faster)
+ if existing_entry is not None:
+ self._pending_deferred_cache[key] = existing_entry
+
+ # we're not going to put this entry into the cache, so need
+ # to make sure that the invalidation callbacks are called.
+ # That was probably done when _pending_deferred_cache was
+ # updated, but it's possible that `set` was called without
+ # `invalidate` being previously called, in which case it may
+ # not have been. Either way, let's double-check now.
entry.invalidate()
return result
@@ -168,25 +174,29 @@ class Cache(object):
def invalidate(self, key):
self.check_thread()
+ self.cache.pop(key, None)
- # Increment the sequence number so that any SELECT statements that
- # raced with the INSERT don't update the cache (SYN-369)
- self.sequence += 1
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, which will (a) stop it being returned
+ # for future queries and (b) stop it being persisted as a proper entry
+ # in self.cache.
entry = self._pending_deferred_cache.pop(key, None)
+
+ # run the invalidation callbacks now, rather than waiting for the
+ # deferred to resolve.
if entry:
entry.invalidate()
- self.cache.pop(key, None)
-
def invalidate_many(self, key):
self.check_thread()
if not isinstance(key, tuple):
raise TypeError(
"The cache key must be a tuple not %r" % (type(key),)
)
- self.sequence += 1
self.cache.del_multi(key)
+ # if we have a pending lookup for this key, remove it from the
+ # _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
@@ -194,8 +204,10 @@ class Cache(object):
def invalidate_all(self):
self.check_thread()
- self.sequence += 1
self.cache.clear()
+ for entry in self._pending_deferred_cache.itervalues():
+ entry.invalidate()
+ self._pending_deferred_cache.clear()
class _CacheDescriptorBase(object):
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index d4105822b3..bdc21e348f 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -55,7 +55,7 @@ class DictionaryCache(object):
__slots__ = []
self.sentinel = Sentinel()
- self.metrics = register_cache(name, self.cache)
+ self.metrics = register_cache("dictionary", name, self.cache)
def check_thread(self):
expected_thread = self.thread
@@ -132,9 +132,13 @@ class DictionaryCache(object):
self._update_or_insert(key, value, known_absent)
def _update_or_insert(self, key, value, known_absent):
- entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
+ # We pop and reinsert as we need to tell the cache the size may have
+ # changed
+
+ entry = self.cache.pop(key, DictionaryEntry(False, set(), {}))
entry.value.update(value)
entry.known_absent.update(known_absent)
+ self.cache[key] = entry
def _insert(self, key, value, known_absent):
self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 0aa103eecb..ff04c91955 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -52,12 +52,12 @@ class ExpiringCache(object):
self._cache = OrderedDict()
- self.metrics = register_cache(cache_name, self)
-
self.iterable = iterable
self._size_estimate = 0
+ self.metrics = register_cache("expiring", cache_name, self)
+
def start(self):
if not self._expiry_ms:
# Don't bother starting the loop if things never expire
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index f088dd430e..1c5a982094 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -154,14 +154,21 @@ class LruCache(object):
def cache_set(key, value, callbacks=[]):
node = cache.get(key, None)
if node is not None:
- if value != node.value:
+ # We sometimes store large objects, e.g. dicts, which cause
+ # the inequality check to take a long time. So let's only do
+ # the check if we have some callbacks to call.
+ if node.callbacks and value != node.value:
for cb in node.callbacks:
cb()
node.callbacks.clear()
- if size_callback:
- cached_cache_len[0] -= size_callback(node.value)
- cached_cache_len[0] += size_callback(value)
+ # We don't bother to protect this by value != node.value as
+ # generally size_callback will be cheap compared with equality
+ # checks. (For example, taking the size of two dicts is quicker
+ # than comparing them for equality.)
+ if size_callback:
+ cached_cache_len[0] -= size_callback(node.value)
+ cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 00af539880..a8491b42d5 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -12,8 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
+from twisted.internet import defer
from synapse.util.async import ObservableDeferred
+from synapse.util.caches import register_cache
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+
+logger = logging.getLogger(__name__)
class ResponseCache(object):
@@ -24,20 +31,69 @@ class ResponseCache(object):
used rather than trying to compute a new response.
"""
- def __init__(self, hs, timeout_ms=0):
+ def __init__(self, hs, name, timeout_ms=0):
self.pending_result_cache = {} # Requests that haven't finished yet.
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.
+ self._name = name
+ self._metrics = register_cache(
+ "response_cache", name, self
+ )
+
+ def size(self):
+ return len(self.pending_result_cache)
+
+ def __len__(self):
+ return self.size()
+
def get(self, key):
+ """Look up the given key.
+
+ Can return either a new Deferred (which also doesn't follow the synapse
+ logcontext rules), or, if the request has completed, the actual
+ result. You will probably want to make_deferred_yieldable the result.
+
+ If there is no entry for the key, returns None. It is worth noting that
+ this means there is no way to distinguish a completed result of None
+ from an absent cache entry.
+
+ Args:
+ key (hashable):
+
+ Returns:
+ twisted.internet.defer.Deferred|None|E: None if there is no entry
+ for this key; otherwise either a deferred result or the result
+ itself.
+ """
result = self.pending_result_cache.get(key)
if result is not None:
+ self._metrics.inc_hits()
return result.observe()
else:
+ self._metrics.inc_misses()
return None
def set(self, key, deferred):
+ """Set the entry for the given key to the given deferred.
+
+ *deferred* should run its callbacks in the sentinel logcontext (ie,
+ you should wrap normal synapse deferreds with
+ logcontext.run_in_background).
+
+ Can return either a new Deferred (which also doesn't follow the synapse
+ logcontext rules), or, if *deferred* was already complete, the actual
+ result. You will probably want to make_deferred_yieldable the result.
+
+ Args:
+ key (hashable):
+ deferred (twisted.internet.defer.Deferred[T):
+
+ Returns:
+ twisted.internet.defer.Deferred[T]|T: a new deferred, or the actual
+ result.
+ """
result = ObservableDeferred(deferred, consumeErrors=True)
self.pending_result_cache[key] = result
@@ -53,3 +109,52 @@ class ResponseCache(object):
result.addBoth(remove)
return result.observe()
+
+ def wrap(self, key, callback, *args, **kwargs):
+ """Wrap together a *get* and *set* call, taking care of logcontexts
+
+ First looks up the key in the cache, and if it is present makes it
+ follow the synapse logcontext rules and returns it.
+
+ Otherwise, makes a call to *callback(*args, **kwargs)*, which should
+ follow the synapse logcontext rules, and adds the result to the cache.
+
+ Example usage:
+
+ @defer.inlineCallbacks
+ def handle_request(request):
+ # etc
+ defer.returnValue(result)
+
+ result = yield response_cache.wrap(
+ key,
+ handle_request,
+ request,
+ )
+
+ Args:
+ key (hashable): key to get/set in the cache
+
+ callback (callable): function to call if the key is not found in
+ the cache
+
+ *args: positional parameters to pass to the callback, if it is used
+
+ **kwargs: named paramters to pass to the callback, if it is used
+
+ Returns:
+ twisted.internet.defer.Deferred: yieldable result
+ """
+ result = self.get(key)
+ if not result:
+ logger.info("[%s]: no cached result for [%s], calculating new one",
+ self._name, key)
+ d = run_in_background(callback, *args, **kwargs)
+ result = self.set(key, d)
+ elif not isinstance(result, defer.Deferred) or result.called:
+ logger.info("[%s]: using completed cached result for [%s]",
+ self._name, key)
+ else:
+ logger.info("[%s]: using incomplete cached result for [%s]",
+ self._name, key)
+ return make_deferred_yieldable(result)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 941d873ab8..a7fe0397fa 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -38,7 +38,7 @@ class StreamChangeCache(object):
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
- self.metrics = register_cache(self.name, self._cache)
+ self.metrics = register_cache("cache", self.name, self._cache)
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 90a2608d6f..3380970e4e 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -15,9 +15,9 @@
from twisted.internet import threads, reactor
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
-import Queue
+from six.moves import queue
class BackgroundFileConsumer(object):
@@ -49,7 +49,7 @@ class BackgroundFileConsumer(object):
# Queue of slices of bytes to be written. When producer calls
# unregister a final None is sent.
- self._bytes_queue = Queue.Queue()
+ self._bytes_queue = queue.Queue()
# Deferred that is resolved when finished writing
self._finished_deferred = None
@@ -70,7 +70,9 @@ class BackgroundFileConsumer(object):
self._producer = producer
self.streaming = streaming
- self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
+ self._finished_deferred = run_in_background(
+ threads.deferToThread, self._writer
+ )
if not streaming:
self._producer.resumeProducing()
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 6322f0f55c..f497b51f4a 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -14,6 +14,7 @@
# limitations under the License.
from frozendict import frozendict
+import simplejson as json
def freeze(o):
@@ -49,3 +50,21 @@ def unfreeze(o):
pass
return o
+
+
+def _handle_frozendict(obj):
+ """Helper for EventEncoder. Makes frozendicts serializable by returning
+ the underlying dict
+ """
+ if type(obj) is frozendict:
+ # fishing the protected dict out of the object is a bit nasty,
+ # but we don't really want the overhead of copying the dict.
+ return obj._dict
+ raise TypeError('Object of type %s is not JSON serializable' %
+ obj.__class__.__name__)
+
+
+# A JSONEncoder which is capable of encoding frozendics without barfing
+frozendict_json_encoder = json.JSONEncoder(
+ default=_handle_frozendict,
+)
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index 45be47159a..e9f0f292ee 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.web.resource import Resource
+from twisted.web.resource import NoResource
import logging
@@ -40,12 +40,15 @@ def create_resource_tree(desired_tree, root_resource):
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
+ # twisted requires all resources to be bytes
+ full_path = full_path.encode("utf-8")
+
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
- for path_seg in full_path.split('/')[1:-1]:
+ for path_seg in full_path.split(b'/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
- child_resource = Resource()
+ child_resource = NoResource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
@@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
# ===========================
# now attach the actual desired resource
- last_path_seg = full_path.split('/')[-1]
+ last_path_seg = full_path.split(b'/')[-1]
# if there is already a resource here, thieve its children and
# replace it
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index d660ec785b..a58c723403 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -59,8 +59,8 @@ class LoggingContext(object):
__slots__ = [
"previous_context", "name", "ru_stime", "ru_utime",
- "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
- "usage_start", "usage_end",
+ "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
+ "usage_start",
"main_thread", "alive",
"request", "tag",
]
@@ -84,14 +84,15 @@ class LoggingContext(object):
def stop(self):
pass
- def add_database_transaction(self, duration_ms):
+ def add_database_transaction(self, duration_sec):
pass
- def add_database_scheduled(self, sched_ms):
+ def add_database_scheduled(self, sched_sec):
pass
def __nonzero__(self):
return False
+ __bool__ = __nonzero__ # python3
sentinel = Sentinel()
@@ -102,14 +103,16 @@ class LoggingContext(object):
self.ru_utime = 0.
self.db_txn_count = 0
- # ms spent waiting for db txns, excluding scheduling time
- self.db_txn_duration_ms = 0
+ # sec spent waiting for db txns, excluding scheduling time
+ self.db_txn_duration_sec = 0
- # ms spent waiting for db txns to be scheduled
- self.db_sched_duration_ms = 0
+ # sec spent waiting for db txns to be scheduled
+ self.db_sched_duration_sec = 0
+ # If alive has the thread resource usage when the logcontext last
+ # became active.
self.usage_start = None
- self.usage_end = None
+
self.main_thread = threading.current_thread()
self.request = None
self.tag = ""
@@ -158,12 +161,12 @@ class LoggingContext(object):
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
- None to avoid suppressing any exeptions that were thrown.
+ None to avoid suppressing any exceptions that were thrown.
"""
current = self.set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
- logger.debug("Expected logging context %s has been lost", self)
+ logger.warn("Expected logging context %s has been lost", self)
else:
logger.warn(
"Current logging context %s is not expected context %s",
@@ -184,47 +187,61 @@ class LoggingContext(object):
def start(self):
if threading.current_thread() is not self.main_thread:
+ logger.warning("Started logcontext %s on different thread", self)
return
- if self.usage_start and self.usage_end:
- self.ru_utime += self.usage_end.ru_utime - self.usage_start.ru_utime
- self.ru_stime += self.usage_end.ru_stime - self.usage_start.ru_stime
- self.usage_start = None
- self.usage_end = None
-
+ # If we haven't already started record the thread resource usage so
+ # far
if not self.usage_start:
self.usage_start = get_thread_resource_usage()
def stop(self):
if threading.current_thread() is not self.main_thread:
+ logger.warning("Stopped logcontext %s on different thread", self)
return
+ # When we stop, let's record the resource used since we started
if self.usage_start:
- self.usage_end = get_thread_resource_usage()
+ usage_end = get_thread_resource_usage()
+
+ self.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime
+ self.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime
+
+ self.usage_start = None
+ else:
+ logger.warning("Called stop on logcontext %s without calling start", self)
def get_resource_usage(self):
+ """Get CPU time used by this logcontext so far.
+
+ Returns:
+ tuple[float, float]: The user and system CPU usage in seconds
+ """
ru_utime = self.ru_utime
ru_stime = self.ru_stime
- if self.usage_start and threading.current_thread() is self.main_thread:
+ # If we are on the correct thread and we're currently running then we
+ # can include resource usage so far.
+ is_main_thread = threading.current_thread() is self.main_thread
+ if self.alive and self.usage_start and is_main_thread:
current = get_thread_resource_usage()
ru_utime += current.ru_utime - self.usage_start.ru_utime
ru_stime += current.ru_stime - self.usage_start.ru_stime
return ru_utime, ru_stime
- def add_database_transaction(self, duration_ms):
+ def add_database_transaction(self, duration_sec):
self.db_txn_count += 1
- self.db_txn_duration_ms += duration_ms
+ self.db_txn_duration_sec += duration_sec
- def add_database_scheduled(self, sched_ms):
+ def add_database_scheduled(self, sched_sec):
"""Record a use of the database pool
Args:
- sched_ms (int): number of milliseconds it took us to get a
+ sched_sec (float): number of seconds it took us to get a
connection
"""
- self.db_sched_duration_ms += sched_ms
+ self.db_sched_duration_sec += sched_sec
class LoggingContextFilter(logging.Filter):
@@ -278,7 +295,7 @@ class PreserveLoggingContext(object):
context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
- logger.debug(
+ logger.warn(
"Unexpected logging context: %s is not %s",
context, self.new_context,
)
@@ -301,31 +318,49 @@ def preserve_fn(f):
def run_in_background(f, *args, **kwargs):
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
- deferred returned by the funtion completes.
+ deferred returned by the function completes.
Useful for wrapping functions that return a deferred which you don't yield
- on.
+ on (for instance because you want to pass it to deferred.gatherResults()).
+
+ Note that if you completely discard the result, you should make sure that
+ `f` doesn't raise any deferred exceptions, otherwise a scary-looking
+ CRITICAL error about an unhandled error will be logged without much
+ indication about where it came from.
"""
current = LoggingContext.current_context()
- res = f(*args, **kwargs)
- if isinstance(res, defer.Deferred) and not res.called:
- # The function will have reset the context before returning, so
- # we need to restore it now.
- LoggingContext.set_current_context(current)
-
- # The original context will be restored when the deferred
- # completes, but there is nothing waiting for it, so it will
- # get leaked into the reactor or some other function which
- # wasn't expecting it. We therefore need to reset the context
- # here.
- #
- # (If this feels asymmetric, consider it this way: we are
- # effectively forking a new thread of execution. We are
- # probably currently within a ``with LoggingContext()`` block,
- # which is supposed to have a single entry and exit point. But
- # by spawning off another deferred, we are effectively
- # adding a new exit point.)
- res.addBoth(_set_context_cb, LoggingContext.sentinel)
+ try:
+ res = f(*args, **kwargs)
+ except: # noqa: E722
+ # the assumption here is that the caller doesn't want to be disturbed
+ # by synchronous exceptions, so let's turn them into Failures.
+ return defer.fail()
+
+ if not isinstance(res, defer.Deferred):
+ return res
+
+ if res.called and not res.paused:
+ # The function should have maintained the logcontext, so we can
+ # optimise out the messing about
+ return res
+
+ # The function may have reset the context before returning, so
+ # we need to restore it now.
+ ctx = LoggingContext.set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ res.addBoth(_set_context_cb, ctx)
return res
@@ -340,11 +375,20 @@ def make_deferred_yieldable(deferred):
returning a deferred. Then, when the deferred completes, restores the
current logcontext before running callbacks/errbacks.
- (This is more-or-less the opposite operation to preserve_fn.)
+ (This is more-or-less the opposite operation to run_in_background.)
"""
- if isinstance(deferred, defer.Deferred) and not deferred.called:
- prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
- deferred.addBoth(_set_context_cb, prev_context)
+ if not isinstance(deferred, defer.Deferred):
+ return deferred
+
+ if deferred.called and not deferred.paused:
+ # it looks like this deferred is ready to run any callbacks we give it
+ # immediately. We may as well optimise out the logcontext faffery.
+ return deferred
+
+ # ok, we can't be sure that a yield won't block, so let's reset the
+ # logcontext, and add a callback to the deferred to restore it.
+ prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+ deferred.addBoth(_set_context_cb, prev_context)
return deferred
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
index cdbc4bffd7..3e42868ea9 100644
--- a/synapse/util/logformatter.py
+++ b/synapse/util/logformatter.py
@@ -14,7 +14,7 @@
# limitations under the License.
-import StringIO
+from six import StringIO
import logging
import traceback
@@ -32,7 +32,7 @@ class LogFormatter(logging.Formatter):
super(LogFormatter, self).__init__(*args, **kwargs)
def formatException(self, ei):
- sio = StringIO.StringIO()
+ sio = StringIO()
(typ, val, tb) = ei
# log the stack above the exception capture point if possible, but
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index 3a83828d25..03249c5dc8 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -96,7 +96,7 @@ def time_function(f):
id = _TIME_FUNC_ID
_TIME_FUNC_ID += 1
- start = time.clock() * 1000
+ start = time.clock()
try:
_log_debug_as_f(
@@ -107,10 +107,10 @@ def time_function(f):
r = f(*args, **kwargs)
finally:
- end = time.clock() * 1000
+ end = time.clock()
_log_debug_as_f(
f,
- "[FUNC END] {%s-%d} %f",
+ "[FUNC END] {%s-%d} %.3f sec",
(func_name, id, end - start,),
)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index e4b5687a4b..1ba7d65c7c 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -15,8 +15,8 @@
from twisted.internet import defer
+from prometheus_client import Counter
from synapse.util.logcontext import LoggingContext
-import synapse.metrics
from functools import wraps
import logging
@@ -24,66 +24,26 @@ import logging
logger = logging.getLogger(__name__)
+block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"])
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-# total number of times we have hit this block
-block_counter = metrics.register_counter(
- "block_count",
- labels=["block_name"],
- alternative_names=(
- # the following are all deprecated aliases for the same metric
- metrics.name_prefix + x for x in (
- "_block_timer:count",
- "_block_ru_utime:count",
- "_block_ru_stime:count",
- "_block_db_txn_count:count",
- "_block_db_txn_duration:count",
- )
- )
-)
-
-block_timer = metrics.register_counter(
- "block_time_seconds",
- labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_timer:total",
- ),
-)
-
-block_ru_utime = metrics.register_counter(
- "block_ru_utime_seconds", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_ru_utime:total",
- ),
-)
-
-block_ru_stime = metrics.register_counter(
- "block_ru_stime_seconds", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_ru_stime:total",
- ),
-)
-
-block_db_txn_count = metrics.register_counter(
- "block_db_txn_count", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_db_txn_count:total",
- ),
-)
+block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"])
+
+block_ru_utime = Counter(
+ "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"])
+
+block_ru_stime = Counter(
+ "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"])
+
+block_db_txn_count = Counter(
+ "synapse_util_metrics_block_db_txn_count", "", ["block_name"])
# seconds spent waiting for db txns, excluding scheduling time, in this block
-block_db_txn_duration = metrics.register_counter(
- "block_db_txn_duration_seconds", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_db_txn_duration:total",
- ),
-)
+block_db_txn_duration = Counter(
+ "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"])
# seconds spent waiting for a db connection, in this block
-block_db_sched_duration = metrics.register_counter(
- "block_db_sched_duration_seconds", labels=["block_name"],
-)
+block_db_sched_duration = Counter(
+ "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])
def measure_func(name):
@@ -102,7 +62,7 @@ class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
"ru_stime",
- "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
"created_context",
]
@@ -114,7 +74,7 @@ class Measure(object):
self.created_context = False
def __enter__(self):
- self.start = self.clock.time_msec()
+ self.start = self.clock.time()
self.start_context = LoggingContext.current_context()
if not self.start_context:
self.start_context = LoggingContext("Measure")
@@ -123,17 +83,17 @@ class Measure(object):
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
- self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
+ self.db_txn_duration_sec = self.start_context.db_txn_duration_sec
+ self.db_sched_duration_sec = self.start_context.db_sched_duration_sec
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
return
- duration = self.clock.time_msec() - self.start
+ duration = self.clock.time() - self.start
- block_counter.inc(self.name)
- block_timer.inc_by(duration, self.name)
+ block_counter.labels(self.name).inc()
+ block_timer.labels(self.name).inc(duration)
context = LoggingContext.current_context()
@@ -150,19 +110,13 @@ class Measure(object):
ru_utime, ru_stime = context.get_resource_usage()
- block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
- block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
- block_db_txn_count.inc_by(
- context.db_txn_count - self.db_txn_count, self.name
- )
- block_db_txn_duration.inc_by(
- (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
- self.name
- )
- block_db_sched_duration.inc_by(
- (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
- self.name
- )
+ block_ru_utime.labels(self.name).inc(ru_utime - self.ru_utime)
+ block_ru_stime.labels(self.name).inc(ru_stime - self.ru_stime)
+ block_db_txn_count.labels(self.name).inc(context.db_txn_count - self.db_txn_count)
+ block_db_txn_duration.labels(self.name).inc(
+ context.db_txn_duration_sec - self.db_txn_duration_sec)
+ block_db_sched_duration.labels(self.name).inc(
+ context.db_sched_duration_sec - self.db_sched_duration_sec)
if self.created_context:
self.start_context.__exit__(exc_type, exc_val, exc_tb)
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 1101881a2d..0ab63c3d7d 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,7 +18,10 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import (
+ run_in_background, make_deferred_yieldable,
+ PreserveLoggingContext,
+)
import collections
import contextlib
@@ -150,7 +153,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
- ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
+ ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)
@@ -176,6 +179,9 @@ class _PerHostRatelimiter(object):
return r
def on_err(r):
+ # XXX: why is this necessary? this is called before we start
+ # processing the request so why would the request be in
+ # current_processing?
self.current_processing.discard(request_id)
return r
@@ -187,7 +193,7 @@ class _PerHostRatelimiter(object):
ret_defer.addCallbacks(on_start, on_err)
ret_defer.addBoth(on_both)
- return ret_defer
+ return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id):
logger.debug(
@@ -197,7 +203,12 @@ class _PerHostRatelimiter(object):
self.current_processing.discard(request_id)
try:
request_id, deferred = self.ready_request_queue.popitem()
+
+ # XXX: why do we do the following? the on_start callback above will
+ # do it for us.
self.current_processing.add(request_id)
- deferred.callback(None)
+
+ with PreserveLoggingContext():
+ deferred.callback(None)
except KeyError:
pass
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 47b0bb5eb3..4e93f69d3a 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -203,8 +203,8 @@ class RetryDestinationLimiter(object):
)
except Exception:
logger.exception(
- "Failed to store set_destination_retry_timings",
+ "Failed to store destination_retry_timings",
)
# we deliberately do this in the background.
- synapse.util.logcontext.preserve_fn(store_retry_timings)()
+ synapse.util.logcontext.run_in_background(store_retry_timings)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 95a6168e16..b98b9dc6e4 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -15,6 +15,7 @@
import random
import string
+from six.moves import range
_string_with_symbols = (
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -22,12 +23,12 @@ _string_with_symbols = (
def random_string(length):
- return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
+ return ''.join(random.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
return ''.join(
- random.choice(_string_with_symbols) for _ in xrange(length)
+ random.choice(_string_with_symbols) for _ in range(length)
)
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index b70f9a6b0a..7a9e45aca9 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from six.moves import range
+
class _Entry(object):
__slots__ = ["end_key", "queue"]
@@ -68,7 +70,7 @@ class WheelTimer(object):
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
self.entries.extend(
- _Entry(key) for key in xrange(last_key, then_key + 1)
+ _Entry(key) for key in range(last_key, then_key + 1)
)
self.entries[-1].queue.append(obj)
|