diff --git a/synapse/__init__.py b/synapse/__init__.py
index f31cb9a3cb..faa183a99e 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,4 +17,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.28.1"
+__version__ = "0.31.2"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index f17fda6315..66639b0089 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -15,6 +15,8 @@
import logging
+from six import itervalues
+
import pymacaroons
from twisted.internet import defer
@@ -57,7 +59,7 @@ class Auth(object):
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
- register_cache("token_cache", self.token_cache)
+ register_cache("cache", "token_cache", self.token_cache)
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
@@ -66,7 +68,7 @@ class Auth(object):
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
- (e.type, e.state_key): e for e in auth_events.values()
+ (e.type, e.state_key): e for e in itervalues(auth_events)
}
self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
@@ -653,7 +655,7 @@ class Auth(object):
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = event_auth.get_send_level(
- EventTypes.Aliases, "", auth_events
+ EventTypes.Aliases, "", power_level_event,
)
user_level = event_auth.get_user_power_level(user_id, auth_events)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index a9ff5576f3..e6ad3768f0 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -19,6 +19,7 @@ import logging
import simplejson as json
from six import iteritems
+from six.moves import http_client
logger = logging.getLogger(__name__)
@@ -51,6 +52,8 @@ class Codes(object):
THREEPID_DENIED = "M_THREEPID_DENIED"
INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
+ CONSENT_NOT_GIVEN = "M_CONSENT_NOT_GIVEN"
+ CANNOT_LEAVE_SERVER_NOTICE_ROOM = "M_CANNOT_LEAVE_SERVER_NOTICE_ROOM"
class CodeMessageException(RuntimeError):
@@ -138,6 +141,32 @@ class SynapseError(CodeMessageException):
return res
+class ConsentNotGivenError(SynapseError):
+ """The error returned to the client when the user has not consented to the
+ privacy policy.
+ """
+ def __init__(self, msg, consent_uri):
+ """Constructs a ConsentNotGivenError
+
+ Args:
+ msg (str): The human-readable error message
+ consent_url (str): The URL where the user can give their consent
+ """
+ super(ConsentNotGivenError, self).__init__(
+ code=http_client.FORBIDDEN,
+ msg=msg,
+ errcode=Codes.CONSENT_NOT_GIVEN
+ )
+ self._consent_uri = consent_uri
+
+ def error_dict(self):
+ return cs_error(
+ self.msg,
+ self.errcode,
+ consent_uri=self._consent_uri
+ )
+
+
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
pass
@@ -292,7 +321,7 @@ def cs_error(msg, code=Codes.UNKNOWN, **kwargs):
Args:
msg (str): The error message.
- code (int): The error code.
+ code (str): The error code.
kwargs : Additional keys to add to the response.
Returns:
A dict representing the error response JSON.
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index db43219d24..dbc0e7e445 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -411,7 +411,7 @@ class Filter(object):
return room_ids
def filter(self, events):
- return filter(self.check, events)
+ return list(filter(self.check, events))
def limit(self):
return self.filter_json.get("limit", 10)
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 91a33a3402..bb46b5da8a 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,12 @@
# limitations under the License.
"""Contains the URL paths to prefix various aspects of the server with. """
+from hashlib import sha256
+import hmac
+
+from six.moves.urllib.parse import urlencode
+
+from synapse.config import ConfigError
CLIENT_PREFIX = "/_matrix/client/api/v1"
CLIENT_V2_ALPHA_PREFIX = "/_matrix/client/v2_alpha"
@@ -25,3 +32,46 @@ SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
+
+
+class ConsentURIBuilder(object):
+ def __init__(self, hs_config):
+ """
+ Args:
+ hs_config (synapse.config.homeserver.HomeServerConfig):
+ """
+ if hs_config.form_secret is None:
+ raise ConfigError(
+ "form_secret not set in config",
+ )
+ if hs_config.public_baseurl is None:
+ raise ConfigError(
+ "public_baseurl not set in config",
+ )
+
+ self._hmac_secret = hs_config.form_secret.encode("utf-8")
+ self._public_baseurl = hs_config.public_baseurl
+
+ def build_user_consent_uri(self, user_id):
+ """Build a URI which we can give to the user to do their privacy
+ policy consent
+
+ Args:
+ user_id (str): mxid or username of user
+
+ Returns
+ (str) the URI where the user can do consent
+ """
+ mac = hmac.new(
+ key=self._hmac_secret,
+ msg=user_id,
+ digestmod=sha256,
+ ).hexdigest()
+ consent_uri = "%s_matrix/consent?%s" % (
+ self._public_baseurl,
+ urlencode({
+ "u": user_id,
+ "h": mac
+ }),
+ )
+ return consent_uri
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index e4318cdfc3..a6925ab139 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -124,6 +124,19 @@ def quit_with_error(error_string):
sys.exit(1)
+def listen_metrics(bind_addresses, port):
+ """
+ Start Prometheus metrics server.
+ """
+ from synapse.metrics import RegistryProxy
+ from prometheus_client import start_http_server
+
+ for host in bind_addresses:
+ reactor.callInThread(start_http_server, int(port),
+ addr=host, registry=RegistryProxy)
+ logger.info("Metrics now reporting on %s:%d", host, port)
+
+
def listen_tcp(bind_addresses, port, factory, backlog=50):
"""
Create a TCP socket for a port and several addresses
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index f2540023a7..dd114dee07 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -32,10 +32,10 @@ from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
-from twisted.internet import reactor
+from twisted.internet import reactor, defer
from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.appservice")
@@ -74,6 +74,7 @@ class AppserviceServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -93,6 +94,13 @@ class AppserviceServer(HomeServer):
globals={"hs": self},
)
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -112,9 +120,14 @@ class ASReplicationHandler(ReplicationClientHandler):
if stream_name == "events":
max_stream_id = self.store.get_room_max_stream_ordering()
- preserve_fn(
- self.appservice_handler.notify_interested_services
- )(max_stream_id)
+ run_in_background(self._notify_app_services, max_stream_id)
+
+ @defer.inlineCallbacks
+ def _notify_app_services(self, room_stream_id):
+ try:
+ yield self.appservice_handler.notify_interested_services(room_stream_id)
+ except Exception:
+ logger.exception("Error notifying application services of event")
def start(config_options):
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index 267d34c881..85dada7f9f 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -25,6 +25,7 @@ from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -77,7 +78,7 @@ class ClientReaderServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
PublicRoomListRestServlet(self).register(resource)
@@ -98,6 +99,7 @@ class ClientReaderServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -117,7 +119,13 @@ class ClientReaderServer(HomeServer):
globals={"hs": self},
)
)
-
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
diff --git a/synapse/app/event_creator.py b/synapse/app/event_creator.py
index b915d12d53..5ca77c0f1a 100644
--- a/synapse/app/event_creator.py
+++ b/synapse/app/event_creator.py
@@ -25,6 +25,7 @@ from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
@@ -90,7 +91,7 @@ class EventCreatorServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
RoomSendEventRestServlet(self).register(resource)
@@ -114,6 +115,7 @@ class EventCreatorServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -133,6 +135,13 @@ class EventCreatorServer(HomeServer):
globals={"hs": self},
)
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index c1dc66dd17..2a1995d0cd 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -26,6 +26,7 @@ from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.directory import DirectoryStore
@@ -71,7 +72,7 @@ class FederationReaderServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
@@ -87,6 +88,7 @@ class FederationReaderServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -106,6 +108,13 @@ class FederationReaderServer(HomeServer):
globals={"hs": self},
)
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 0cc3331519..81ad574043 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -25,6 +25,7 @@ from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.federation import send_queue
from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
@@ -38,7 +39,7 @@ from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.async import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import defer, reactor
@@ -89,7 +90,7 @@ class FederationSenderServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
root_resource = create_resource_tree(resources, NoResource())
@@ -101,6 +102,7 @@ class FederationSenderServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -120,6 +122,13 @@ class FederationSenderServer(HomeServer):
globals={"hs": self},
)
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -229,7 +238,7 @@ class FederationSenderHandler(object):
# presence, typing, etc.
if stream_name == "federation":
send_queue.process_rows_for_federation(self.federation_sender, rows)
- preserve_fn(self.update_token)(token)
+ run_in_background(self.update_token, token)
# We also need to poke the federation sender when new events happen
elif stream_name == "events":
@@ -237,19 +246,22 @@ class FederationSenderHandler(object):
@defer.inlineCallbacks
def update_token(self, token):
- self.federation_position = token
-
- # We linearize here to ensure we don't have races updating the token
- with (yield self._fed_position_linearizer.queue(None)):
- if self._last_ack < self.federation_position:
- yield self.store.update_federation_out_pos(
- "federation", self.federation_position
- )
+ try:
+ self.federation_position = token
+
+ # We linearize here to ensure we don't have races updating the token
+ with (yield self._fed_position_linearizer.queue(None)):
+ if self._last_ack < self.federation_position:
+ yield self.store.update_federation_out_pos(
+ "federation", self.federation_position
+ )
- # We ACK this token over replication so that the master can drop
- # its in memory queues
- self.replication_client.send_federation_ack(self.federation_position)
- self._last_ack = self.federation_position
+ # We ACK this token over replication so that the master can drop
+ # its in memory queues
+ self.replication_client.send_federation_ack(self.federation_position)
+ self._last_ack = self.federation_position
+ except Exception:
+ logger.exception("Error updating federation stream position")
if __name__ == '__main__':
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index b349e3e3ce..5a164a7a95 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -29,6 +29,7 @@ from synapse.http.servlet import (
RestServlet, parse_json_object_from_request,
)
from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -131,7 +132,7 @@ class FrontendProxyServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
KeyUploadServlet(self).register(resource)
@@ -152,6 +153,7 @@ class FrontendProxyServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -171,6 +173,13 @@ class FrontendProxyServer(HomeServer):
globals={"hs": self},
)
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index a0e465d644..714f98a3e0 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -34,7 +34,7 @@ from synapse.module_api import ModuleApi
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import RootRedirect
from synapse.http.site import SynapseSite
-from synapse.metrics import register_memory_metrics
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \
check_requirements
@@ -140,6 +140,7 @@ class SynapseHomeServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
),
self.tls_server_context_factory,
)
@@ -153,6 +154,7 @@ class SynapseHomeServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
logger.info("Synapse now listening on port %d", port)
@@ -182,6 +184,15 @@ class SynapseHomeServer(HomeServer):
"/_matrix/client/versions": client_resource,
})
+ if name == "consent":
+ from synapse.rest.consent.consent_resource import ConsentResource
+ consent_resource = ConsentResource(self)
+ if compress:
+ consent_resource = gz_wrap(consent_resource)
+ resources.update({
+ "/_matrix/consent": consent_resource,
+ })
+
if name == "federation":
resources.update({
FEDERATION_PREFIX: TransportLayerServer(self),
@@ -219,7 +230,7 @@ class SynapseHomeServer(HomeServer):
resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
if name == "metrics" and self.get_config().enable_metrics:
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
if name == "replication":
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
@@ -252,6 +263,13 @@ class SynapseHomeServer(HomeServer):
reactor.addSystemEventTrigger(
"before", "shutdown", server_listener.stopListening,
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -351,8 +369,6 @@ def setup(config_options):
hs.get_datastore().start_doing_background_updates()
hs.get_federation_client().start_get_pdu_cache()
- register_memory_metrics(hs)
-
reactor.callWhenRunning(start)
return hs
@@ -423,6 +439,10 @@ def run(hs):
total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
+ daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
+ for name, count in daily_user_type_results.iteritems():
+ stats["daily_user_type_" + name] = count
+
room_count = yield hs.get_datastore().get_room_count()
stats["total_room_count"] = room_count
@@ -473,6 +493,14 @@ def run(hs):
" changes across releases."
)
+ def generate_user_daily_visit_stats():
+ hs.get_datastore().generate_user_daily_visits()
+
+ # Rather than update on per session basis, batch up the requests.
+ # If you increase the loop period, the accuracy of user_daily_visits
+ # table will decrease
+ clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
+
if hs.config.report_stats:
logger.info("Scheduling stats reporting for 3 hour intervals")
clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000)
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index fc8282bbc1..006bba80a8 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -27,6 +27,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -73,7 +74,7 @@ class MediaRepositoryServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "media":
media_repo = self.get_media_repository_resource()
resources.update({
@@ -94,6 +95,7 @@ class MediaRepositoryServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -113,6 +115,13 @@ class MediaRepositoryServer(HomeServer):
globals={"hs": self},
)
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index d5c3a85195..64df47f9cc 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -23,6 +23,7 @@ from synapse.config._base import ConfigError
from synapse.config.homeserver import HomeServerConfig
from synapse.config.logger import setup_logging
from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
from synapse.replication.slave.storage.events import SlavedEventStore
@@ -33,7 +34,7 @@ from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
from twisted.internet import defer, reactor
@@ -92,7 +93,7 @@ class PusherServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
root_resource = create_resource_tree(resources, NoResource())
@@ -104,6 +105,7 @@ class PusherServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -123,6 +125,13 @@ class PusherServer(HomeServer):
globals={"hs": self},
)
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -140,24 +149,27 @@ class PusherReplicationHandler(ReplicationClientHandler):
def on_rdata(self, stream_name, token, rows):
super(PusherReplicationHandler, self).on_rdata(stream_name, token, rows)
- preserve_fn(self.poke_pushers)(stream_name, token, rows)
+ run_in_background(self.poke_pushers, stream_name, token, rows)
@defer.inlineCallbacks
def poke_pushers(self, stream_name, token, rows):
- if stream_name == "pushers":
- for row in rows:
- if row.deleted:
- yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
- else:
- yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
- elif stream_name == "events":
- yield self.pusher_pool.on_new_notifications(
- token, token,
- )
- elif stream_name == "receipts":
- yield self.pusher_pool.on_new_receipts(
- token, token, set(row.room_id for row in rows)
- )
+ try:
+ if stream_name == "pushers":
+ for row in rows:
+ if row.deleted:
+ yield self.stop_pusher(row.user_id, row.app_id, row.pushkey)
+ else:
+ yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
+ elif stream_name == "events":
+ yield self.pusher_pool.on_new_notifications(
+ token, token,
+ )
+ elif stream_name == "receipts":
+ yield self.pusher_pool.on_new_receipts(
+ token, token, set(row.room_id for row in rows)
+ )
+ except Exception:
+ logger.exception("Error poking pushers")
def stop_pusher(self, user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 2fddcd935a..6808d6d3e0 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -26,6 +26,7 @@ from synapse.config.logger import setup_logging
from synapse.handlers.presence import PresenceHandler, get_interested_parties
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
@@ -51,7 +52,7 @@ from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.storage.roommember import RoomMemberStore
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string
from synapse.util.versionstring import get_version_string
@@ -257,7 +258,7 @@ class SynchrotronServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
@@ -281,6 +282,7 @@ class SynchrotronServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -300,6 +302,13 @@ class SynchrotronServer(HomeServer):
globals={"hs": self},
)
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -327,8 +336,7 @@ class SyncReplicationHandler(ReplicationClientHandler):
def on_rdata(self, stream_name, token, rows):
super(SyncReplicationHandler, self).on_rdata(stream_name, token, rows)
-
- preserve_fn(self.process_and_notify)(stream_name, token, rows)
+ run_in_background(self.process_and_notify, stream_name, token, rows)
def get_streams_to_replicate(self):
args = super(SyncReplicationHandler, self).get_streams_to_replicate()
@@ -340,55 +348,58 @@ class SyncReplicationHandler(ReplicationClientHandler):
@defer.inlineCallbacks
def process_and_notify(self, stream_name, token, rows):
- if stream_name == "events":
- # We shouldn't get multiple rows per token for events stream, so
- # we don't need to optimise this for multiple rows.
- for row in rows:
- event = yield self.store.get_event(row.event_id)
- extra_users = ()
- if event.type == EventTypes.Member:
- extra_users = (event.state_key,)
- max_token = self.store.get_room_max_stream_ordering()
- self.notifier.on_new_room_event(
- event, token, max_token, extra_users
+ try:
+ if stream_name == "events":
+ # We shouldn't get multiple rows per token for events stream, so
+ # we don't need to optimise this for multiple rows.
+ for row in rows:
+ event = yield self.store.get_event(row.event_id)
+ extra_users = ()
+ if event.type == EventTypes.Member:
+ extra_users = (event.state_key,)
+ max_token = self.store.get_room_max_stream_ordering()
+ self.notifier.on_new_room_event(
+ event, token, max_token, extra_users
+ )
+ elif stream_name == "push_rules":
+ self.notifier.on_new_event(
+ "push_rules_key", token, users=[row.user_id for row in rows],
)
- elif stream_name == "push_rules":
- self.notifier.on_new_event(
- "push_rules_key", token, users=[row.user_id for row in rows],
- )
- elif stream_name in ("account_data", "tag_account_data",):
- self.notifier.on_new_event(
- "account_data_key", token, users=[row.user_id for row in rows],
- )
- elif stream_name == "receipts":
- self.notifier.on_new_event(
- "receipt_key", token, rooms=[row.room_id for row in rows],
- )
- elif stream_name == "typing":
- self.typing_handler.process_replication_rows(token, rows)
- self.notifier.on_new_event(
- "typing_key", token, rooms=[row.room_id for row in rows],
- )
- elif stream_name == "to_device":
- entities = [row.entity for row in rows if row.entity.startswith("@")]
- if entities:
+ elif stream_name in ("account_data", "tag_account_data",):
self.notifier.on_new_event(
- "to_device_key", token, users=entities,
+ "account_data_key", token, users=[row.user_id for row in rows],
)
- elif stream_name == "device_lists":
- all_room_ids = set()
- for row in rows:
- room_ids = yield self.store.get_rooms_for_user(row.user_id)
- all_room_ids.update(room_ids)
- self.notifier.on_new_event(
- "device_list_key", token, rooms=all_room_ids,
- )
- elif stream_name == "presence":
- yield self.presence_handler.process_replication_rows(token, rows)
- elif stream_name == "receipts":
- self.notifier.on_new_event(
- "groups_key", token, users=[row.user_id for row in rows],
- )
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "receipt_key", token, rooms=[row.room_id for row in rows],
+ )
+ elif stream_name == "typing":
+ self.typing_handler.process_replication_rows(token, rows)
+ self.notifier.on_new_event(
+ "typing_key", token, rooms=[row.room_id for row in rows],
+ )
+ elif stream_name == "to_device":
+ entities = [row.entity for row in rows if row.entity.startswith("@")]
+ if entities:
+ self.notifier.on_new_event(
+ "to_device_key", token, users=entities,
+ )
+ elif stream_name == "device_lists":
+ all_room_ids = set()
+ for row in rows:
+ room_ids = yield self.store.get_rooms_for_user(row.user_id)
+ all_room_ids.update(room_ids)
+ self.notifier.on_new_event(
+ "device_list_key", token, rooms=all_room_ids,
+ )
+ elif stream_name == "presence":
+ yield self.presence_handler.process_replication_rows(token, rows)
+ elif stream_name == "receipts":
+ self.notifier.on_new_event(
+ "groups_key", token, users=[row.user_id for row in rows],
+ )
+ except Exception:
+ logger.exception("Error processing replication")
def start(config_options):
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index 5f845e80d1..ada1c13cec 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -26,6 +26,7 @@ from synapse.config.logger import setup_logging
from synapse.crypto import context_factory
from synapse.http.server import JsonResource
from synapse.http.site import SynapseSite
+from synapse.metrics import RegistryProxy
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
@@ -39,10 +40,10 @@ from synapse.storage.engines import create_engine
from synapse.storage.user_directory import UserDirectoryStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.logcontext import LoggingContext, preserve_fn
+from synapse.util.logcontext import LoggingContext, run_in_background
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
-from twisted.internet import reactor
+from twisted.internet import reactor, defer
from twisted.web.resource import NoResource
logger = logging.getLogger("synapse.app.user_dir")
@@ -105,7 +106,7 @@ class UserDirectoryServer(HomeServer):
for res in listener_config["resources"]:
for name in res["names"]:
if name == "metrics":
- resources[METRICS_PREFIX] = MetricsResource(self)
+ resources[METRICS_PREFIX] = MetricsResource(RegistryProxy)
elif name == "client":
resource = JsonResource(self, canonical_json=False)
user_directory.register_servlets(self, resource)
@@ -126,6 +127,7 @@ class UserDirectoryServer(HomeServer):
site_tag,
listener_config,
root_resource,
+ self.version_string,
)
)
@@ -145,6 +147,13 @@ class UserDirectoryServer(HomeServer):
globals={"hs": self},
)
)
+ elif listener["type"] == "metrics":
+ if not self.get_config().enable_metrics:
+ logger.warn(("Metrics listener configured, but "
+ "collect_metrics is not enabled!"))
+ else:
+ _base.listen_metrics(listener["bind_addresses"],
+ listener["port"])
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -164,7 +173,14 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
stream_name, token, rows
)
if stream_name == "current_state_deltas":
- preserve_fn(self.user_directory.notify_new_event)()
+ run_in_background(self._notify_directory)
+
+ @defer.inlineCallbacks
+ def _notify_directory(self):
+ try:
+ yield self.user_directory.notify_new_event()
+ except Exception:
+ logger.exception("Error notifiying user directory of state update")
def start(config_options):
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index 6da315473d..6eddbc0828 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -51,7 +51,7 @@ components.
from twisted.internet import defer
from synapse.appservice import ApplicationServiceState
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
import logging
@@ -106,7 +106,7 @@ class _ServiceQueuer(object):
def enqueue(self, service, event):
# if this service isn't being sent something
self.queued_events.setdefault(service.id, []).append(event)
- preserve_fn(self._send_request)(service)
+ run_in_background(self._send_request, service)
@defer.inlineCallbacks
def _send_request(self, service):
@@ -152,10 +152,10 @@ class _TransactionController(object):
if sent:
yield txn.complete(self.store)
else:
- preserve_fn(self._start_recoverer)(service)
- except Exception as e:
- logger.exception(e)
- preserve_fn(self._start_recoverer)(service)
+ run_in_background(self._start_recoverer, service)
+ except Exception:
+ logger.exception("Error creating appservice transaction")
+ run_in_background(self._start_recoverer, service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):
@@ -176,17 +176,20 @@ class _TransactionController(object):
@defer.inlineCallbacks
def _start_recoverer(self, service):
- yield self.store.set_appservice_state(
- service,
- ApplicationServiceState.DOWN
- )
- logger.info(
- "Application service falling behind. Starting recoverer. AS ID %s",
- service.id
- )
- recoverer = self.recoverer_fn(service, self.on_recovered)
- self.add_recoverers([recoverer])
- recoverer.recover()
+ try:
+ yield self.store.set_appservice_state(
+ service,
+ ApplicationServiceState.DOWN
+ )
+ logger.info(
+ "Application service falling behind. Starting recoverer. AS ID %s",
+ service.id
+ )
+ recoverer = self.recoverer_fn(service, self.on_recovered)
+ self.add_recoverers([recoverer])
+ recoverer.recover()
+ except Exception:
+ logger.exception("Error starting AS recoverer")
@defer.inlineCallbacks
def _is_service_up(self, service):
diff --git a/synapse/config/__init__.py b/synapse/config/__init__.py
index bfebb0f644..f2a5a41e92 100644
--- a/synapse/config/__init__.py
+++ b/synapse/config/__init__.py
@@ -12,3 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from ._base import ConfigError
+
+# export ConfigError if somebody does import *
+# this is largely a fudge to stop PEP8 moaning about the import
+__all__ = ["ConfigError"]
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 32b439d20a..b748ed2b0a 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -281,15 +281,15 @@ class Config(object):
)
if not cls.path_exists(config_dir_path):
os.makedirs(config_dir_path)
- with open(config_path, "wb") as config_file:
- config_bytes, config = obj.generate_config(
+ with open(config_path, "w") as config_file:
+ config_str, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
is_generating_file=True
)
obj.invoke_all("generate_files", config)
- config_file.write(config_bytes)
+ config_file.write(config_str)
print((
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 9a2359b6fd..277305e184 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -17,11 +17,11 @@ from ._base import Config, ConfigError
from synapse.appservice import ApplicationService
from synapse.types import UserID
-import urllib
import yaml
import logging
from six import string_types
+from six.moves.urllib import parse as urlparse
logger = logging.getLogger(__name__)
@@ -105,7 +105,7 @@ def _load_appservice(hostname, as_info, config_filename):
)
localpart = as_info["sender_localpart"]
- if urllib.quote(localpart) != localpart:
+ if urlparse.quote(localpart) != localpart:
raise ValueError(
"sender_localpart needs characters which are not URL encoded."
)
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
new file mode 100644
index 0000000000..8f6ed73328
--- /dev/null
+++ b/synapse/config/consent_config.py
@@ -0,0 +1,85 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+DEFAULT_CONFIG = """\
+# User Consent configuration
+#
+# Parts of this section are required if enabling the 'consent' resource under
+# 'listeners', in particular 'template_dir' and 'version'.
+#
+# 'template_dir' gives the location of the templates for the HTML forms.
+# This directory should contain one subdirectory per language (eg, 'en', 'fr'),
+# and each language directory should contain the policy document (named as
+# '<version>.html') and a success page (success.html).
+#
+# 'version' specifies the 'current' version of the policy document. It defines
+# the version to be served by the consent resource if there is no 'v'
+# parameter.
+#
+# 'server_notice_content', if enabled, will send a user a "Server Notice"
+# asking them to consent to the privacy policy. The 'server_notices' section
+# must also be configured for this to work. Notices will *not* be sent to
+# guest users unless 'send_server_notice_to_guests' is set to true.
+#
+# 'block_events_error', if set, will block any attempts to send events
+# until the user consents to the privacy policy. The value of the setting is
+# used as the text of the error.
+#
+# user_consent:
+# template_dir: res/templates/privacy
+# version: 1.0
+# server_notice_content:
+# msgtype: m.text
+# body: >-
+# To continue using this homeserver you must review and agree to the
+# terms and conditions at %(consent_uri)s
+# send_server_notice_to_guests: True
+# block_events_error: >-
+# To continue using this homeserver you must review and agree to the
+# terms and conditions at %(consent_uri)s
+#
+"""
+
+
+class ConsentConfig(Config):
+ def __init__(self):
+ super(ConsentConfig, self).__init__()
+
+ self.user_consent_version = None
+ self.user_consent_template_dir = None
+ self.user_consent_server_notice_content = None
+ self.user_consent_server_notice_to_guests = False
+ self.block_events_without_consent_error = None
+
+ def read_config(self, config):
+ consent_config = config.get("user_consent")
+ if consent_config is None:
+ return
+ self.user_consent_version = str(consent_config["version"])
+ self.user_consent_template_dir = consent_config["template_dir"]
+ self.user_consent_server_notice_content = consent_config.get(
+ "server_notice_content",
+ )
+ self.block_events_without_consent_error = consent_config.get(
+ "block_events_error",
+ )
+ self.user_consent_server_notice_to_guests = bool(consent_config.get(
+ "send_server_notice_to_guests", False,
+ ))
+
+ def default_config(self, **kwargs):
+ return DEFAULT_CONFIG
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index bf19cfee29..1dea2ad024 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from .tls import TlsConfig
from .server import ServerConfig
from .logger import LoggingConfig
@@ -37,6 +37,8 @@ from .push import PushConfig
from .spam_checker import SpamCheckerConfig
from .groups import GroupsConfig
from .user_directory import UserDirectoryConfig
+from .consent_config import ConsentConfig
+from .server_notices_config import ServerNoticesConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@@ -45,12 +47,15 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
- SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,):
+ SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
+ ConsentConfig,
+ ServerNoticesConfig,
+ ):
pass
if __name__ == '__main__':
import sys
sys.stdout.write(
- HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0]
+ HomeServerConfig().generate_config(sys.argv[1], sys.argv[2], True)[0]
)
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 4b8fc063d0..d1382ad9ac 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -59,14 +59,20 @@ class KeyConfig(Config):
self.expire_access_token = config.get("expire_access_token", False)
+ # a secret which is used to calculate HMACs for form values, to stop
+ # falsification of values
+ self.form_secret = config.get("form_secret", None)
+
def default_config(self, config_dir_path, server_name, is_generating_file=False,
**kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
if is_generating_file:
macaroon_secret_key = random_string_with_symbols(50)
+ form_secret = '"%s"' % random_string_with_symbols(50)
else:
macaroon_secret_key = None
+ form_secret = 'null'
return """\
macaroon_secret_key: "%(macaroon_secret_key)s"
@@ -74,6 +80,10 @@ class KeyConfig(Config):
# Used to enable access token expiration.
expire_access_token: False
+ # a secret which is used to calculate HMACs for form values, to stop
+ # falsification of values
+ form_secret: %(form_secret)s
+
## Signing Keys ##
# Path to the signing key to sign messages with
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 3f70039acd..6a7228dc2f 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -117,7 +117,7 @@ class LoggingConfig(Config):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")
- with open(log_config, "wb") as log_config_file:
+ with open(log_config, "w") as log_config_file:
log_config_file.write(
DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
)
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 25ea77738a..81ecf9778c 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -250,6 +250,9 @@ class ContentRepositoryConfig(Config):
# - '192.168.0.0/16'
# - '100.64.0.0/10'
# - '169.254.0.0/16'
+ # - '::1/128'
+ # - 'fe80::/64'
+ # - 'fc00::/7'
#
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 8f0b6d1f28..968ecd9ea0 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -14,8 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from ._base import Config, ConfigError
+logger = logging.Logger(__name__)
+
class ServerConfig(Config):
@@ -138,6 +142,12 @@ class ServerConfig(Config):
metrics_port = config.get("metrics_port")
if metrics_port:
+ logger.warn(
+ ("The metrics_port configuration option is deprecated in Synapse 0.31 "
+ "in favour of a listener. Please see "
+ "http://github.com/matrix-org/synapse/blob/master/docs/metrics-howto.rst"
+ " on how to configure the new listener."))
+
self.listeners.append({
"port": metrics_port,
"bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
new file mode 100644
index 0000000000..be1d1f762c
--- /dev/null
+++ b/synapse/config/server_notices_config.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from ._base import Config
+from synapse.types import UserID
+
+DEFAULT_CONFIG = """\
+# Server Notices room configuration
+#
+# Uncomment this section to enable a room which can be used to send notices
+# from the server to users. It is a special room which cannot be left; notices
+# come from a special "notices" user id.
+#
+# If you uncomment this section, you *must* define the system_mxid_localpart
+# setting, which defines the id of the user which will be used to send the
+# notices.
+#
+# It's also possible to override the room name, the display name of the
+# "notices" user, and the avatar for the user.
+#
+# server_notices:
+# system_mxid_localpart: notices
+# system_mxid_display_name: "Server Notices"
+# system_mxid_avatar_url: "mxc://server.com/oumMVlgDnLYFaPVkExemNVVZ"
+# room_name: "Server Notices"
+"""
+
+
+class ServerNoticesConfig(Config):
+ """Configuration for the server notices room.
+
+ Attributes:
+ server_notices_mxid (str|None):
+ The MXID to use for server notices.
+ None if server notices are not enabled.
+
+ server_notices_mxid_display_name (str|None):
+ The display name to use for the server notices user.
+ None if server notices are not enabled.
+
+ server_notices_mxid_avatar_url (str|None):
+ The display name to use for the server notices user.
+ None if server notices are not enabled.
+
+ server_notices_room_name (str|None):
+ The name to use for the server notices room.
+ None if server notices are not enabled.
+ """
+ def __init__(self):
+ super(ServerNoticesConfig, self).__init__()
+ self.server_notices_mxid = None
+ self.server_notices_mxid_display_name = None
+ self.server_notices_mxid_avatar_url = None
+ self.server_notices_room_name = None
+
+ def read_config(self, config):
+ c = config.get("server_notices")
+ if c is None:
+ return
+
+ mxid_localpart = c['system_mxid_localpart']
+ self.server_notices_mxid = UserID(
+ mxid_localpart, self.server_name,
+ ).to_string()
+ self.server_notices_mxid_display_name = c.get(
+ 'system_mxid_display_name', None,
+ )
+ self.server_notices_mxid_avatar_url = c.get(
+ 'system_mxid_avatar_url', None,
+ )
+ # todo: i18n
+ self.server_notices_room_name = c.get('room_name', "Server Notices")
+
+ def default_config(self, **kwargs):
+ return DEFAULT_CONFIG
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 29eb012ddb..b66154bc7c 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -133,7 +133,7 @@ class TlsConfig(Config):
tls_dh_params_path = config["tls_dh_params_path"]
if not self.path_exists(tls_private_key_path):
- with open(tls_private_key_path, "w") as private_key_file:
+ with open(tls_private_key_path, "wb") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
@@ -148,7 +148,7 @@ class TlsConfig(Config):
)
if not self.path_exists(tls_certificate_path):
- with open(tls_certificate_path, "w") as certificate_file:
+ with open(tls_certificate_path, "wb") as certificate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py
index cff3ca809a..0397f73ab4 100644
--- a/synapse/crypto/context_factory.py
+++ b/synapse/crypto/context_factory.py
@@ -13,8 +13,8 @@
# limitations under the License.
from twisted.internet import ssl
-from OpenSSL import SSL
-from twisted.internet._sslverify import _OpenSSLECCurve, _defaultCurveName
+from OpenSSL import SSL, crypto
+from twisted.internet._sslverify import _defaultCurveName
import logging
@@ -32,8 +32,9 @@ class ServerContextFactory(ssl.ContextFactory):
@staticmethod
def configure_context(context, config):
try:
- _ecCurve = _OpenSSLECCurve(_defaultCurveName)
- _ecCurve.addECKeyToContext(context)
+ _ecCurve = crypto.get_elliptic_curve(_defaultCurveName)
+ context.set_tmp_ecdh(_ecCurve)
+
except Exception:
logger.exception("Failed to enable elliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index fce83d445f..22ee0fc93f 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -19,7 +19,8 @@ from synapse.api.errors import SynapseError, Codes
from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import (
PreserveLoggingContext,
- preserve_fn
+ preserve_fn,
+ run_in_background,
)
from synapse.util.metrics import Measure
@@ -127,7 +128,7 @@ class Keyring(object):
verify_requests.append(verify_request)
- preserve_fn(self._start_key_lookups)(verify_requests)
+ run_in_background(self._start_key_lookups, verify_requests)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
@@ -146,53 +147,56 @@ class Keyring(object):
verify_requests (List[VerifyKeyRequest]):
"""
- # create a deferred for each server we're going to look up the keys
- # for; we'll resolve them once we have completed our lookups.
- # These will be passed into wait_for_previous_lookups to block
- # any other lookups until we have finished.
- # The deferreds are called with no logcontext.
- server_to_deferred = {
- rq.server_name: defer.Deferred()
- for rq in verify_requests
- }
-
- # We want to wait for any previous lookups to complete before
- # proceeding.
- yield self.wait_for_previous_lookups(
- [rq.server_name for rq in verify_requests],
- server_to_deferred,
- )
-
- # Actually start fetching keys.
- self._get_server_verify_keys(verify_requests)
-
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- #
- # map from server name to a set of request ids
- server_to_request_ids = {}
-
- for verify_request in verify_requests:
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids.setdefault(server_name, set()).add(request_id)
-
- def remove_deferreds(res, verify_request):
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids[server_name].discard(request_id)
- if not server_to_request_ids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
- return res
-
- for verify_request in verify_requests:
- verify_request.deferred.addBoth(
- remove_deferreds, verify_request,
+ try:
+ # create a deferred for each server we're going to look up the keys
+ # for; we'll resolve them once we have completed our lookups.
+ # These will be passed into wait_for_previous_lookups to block
+ # any other lookups until we have finished.
+ # The deferreds are called with no logcontext.
+ server_to_deferred = {
+ rq.server_name: defer.Deferred()
+ for rq in verify_requests
+ }
+
+ # We want to wait for any previous lookups to complete before
+ # proceeding.
+ yield self.wait_for_previous_lookups(
+ [rq.server_name for rq in verify_requests],
+ server_to_deferred,
)
+ # Actually start fetching keys.
+ self._get_server_verify_keys(verify_requests)
+
+ # When we've finished fetching all the keys for a given server_name,
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
+ # any lookups waiting will proceed.
+ #
+ # map from server name to a set of request ids
+ server_to_request_ids = {}
+
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids.setdefault(server_name, set()).add(request_id)
+
+ def remove_deferreds(res, verify_request):
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids[server_name].discard(request_id)
+ if not server_to_request_ids[server_name]:
+ d = server_to_deferred.pop(server_name, None)
+ if d:
+ d.callback(None)
+ return res
+
+ for verify_request in verify_requests:
+ verify_request.deferred.addBoth(
+ remove_deferreds, verify_request,
+ )
+ except Exception:
+ logger.exception("Error starting key lookups")
+
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish.
@@ -313,7 +317,7 @@ class Keyring(object):
if not verify_request.deferred.called:
verify_request.deferred.errback(err)
- preserve_fn(do_iterations)().addErrback(on_err)
+ run_in_background(do_iterations).addErrback(on_err)
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
@@ -329,8 +333,9 @@ class Keyring(object):
"""
res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.get_server_verify_keys)(
- server_name, key_ids
+ run_in_background(
+ self.store.get_server_verify_keys,
+ server_name, key_ids,
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
@@ -358,7 +363,7 @@ class Keyring(object):
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(get_key)(p_name, p_keys)
+ run_in_background(get_key, p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
@@ -398,7 +403,7 @@ class Keyring(object):
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(get_key)(server_name, key_ids)
+ run_in_background(get_key, server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
@@ -481,7 +486,8 @@ class Keyring(object):
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store_keys)(
+ run_in_background(
+ self.store_keys,
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
@@ -539,7 +545,8 @@ class Keyring(object):
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store_keys)(
+ run_in_background(
+ self.store_keys,
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
@@ -615,7 +622,8 @@ class Keyring(object):
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.store_server_keys_json)(
+ run_in_background(
+ self.store.store_server_keys_json,
server_name=server_name,
key_id=key_id,
from_server=server_name,
@@ -716,7 +724,8 @@ class Keyring(object):
# TODO(markjh): Store whether the keys have expired.
return logcontext.make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self.store.store_server_verify_key)(
+ run_in_background(
+ self.store.store_server_verify_key,
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index cd5627e36a..f512d88145 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -34,9 +34,11 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
+ Raises:
+ AuthError if the checks fail
Returns:
- True if the auth checks pass.
+ if the auth checks pass.
"""
if do_size_check:
_check_size_limits(event)
@@ -71,7 +73,7 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
- return True
+ return
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
@@ -81,7 +83,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
"Creation event's room_id domain does not match sender's"
)
# FIXME
- return True
+ logger.debug("Allowing! %s", event)
+ return
creation_event = auth_events.get((EventTypes.Create, ""), None)
@@ -118,7 +121,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
403,
"Alias event's state_key does not match sender's domain"
)
- return True
+ logger.debug("Allowing! %s", event)
+ return
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
@@ -127,14 +131,9 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
)
if event.type == EventTypes.Member:
- allowed = _is_membership_change_allowed(
- event, auth_events
- )
- if allowed:
- logger.debug("Allowing! %s", event)
- else:
- logger.debug("Denying! %s", event)
- return allowed
+ _is_membership_change_allowed(event, auth_events)
+ logger.debug("Allowing! %s", event)
+ return
_check_event_sender_in_room(event, auth_events)
@@ -153,7 +152,8 @@ def check(event, auth_events, do_sig_check=True, do_size_check=True):
)
)
else:
- return True
+ logger.debug("Allowing! %s", event)
+ return
_can_send_event(event, auth_events)
@@ -200,7 +200,7 @@ def _is_membership_change_allowed(event, auth_events):
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
- return True
+ return
target_user_id = event.state_key
@@ -265,13 +265,13 @@ def _is_membership_change_allowed(event, auth_events):
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
- return True
+ return
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
- return True
+ return
if not caller_in_room: # caller isn't joined
raise AuthError(
@@ -334,8 +334,6 @@ def _is_membership_change_allowed(event, auth_events):
else:
raise AuthError(500, "Unknown membership %s" % membership)
- return True
-
def _check_event_sender_in_room(event, auth_events):
key = (EventTypes.Member, event.user_id, )
@@ -355,35 +353,46 @@ def _check_joined_room(member, user_id, room_id):
))
-def get_send_level(etype, state_key, auth_events):
- key = (EventTypes.PowerLevels, "", )
- send_level_event = auth_events.get(key)
- send_level = None
- if send_level_event:
- send_level = send_level_event.content.get("events", {}).get(
- etype
- )
- if send_level is None:
- if state_key is not None:
- send_level = send_level_event.content.get(
- "state_default", 50
- )
- else:
- send_level = send_level_event.content.get(
- "events_default", 0
- )
+def get_send_level(etype, state_key, power_levels_event):
+ """Get the power level required to send an event of a given type
+
+ The federation spec [1] refers to this as "Required Power Level".
+
+ https://matrix.org/docs/spec/server_server/unstable.html#definitions
- if send_level:
- send_level = int(send_level)
+ Args:
+ etype (str): type of event
+ state_key (str|None): state_key of state event, or None if it is not
+ a state event.
+ power_levels_event (synapse.events.EventBase|None): power levels event
+ in force at this point in the room
+ Returns:
+ int: power level required to send this event.
+ """
+
+ if power_levels_event:
+ power_levels_content = power_levels_event.content
else:
- send_level = 0
+ power_levels_content = {}
+
+ # see if we have a custom level for this event type
+ send_level = power_levels_content.get("events", {}).get(etype)
+
+ # otherwise, fall back to the state_default/events_default.
+ if send_level is None:
+ if state_key is not None:
+ send_level = power_levels_content.get("state_default", 50)
+ else:
+ send_level = power_levels_content.get("events_default", 0)
- return send_level
+ return int(send_level)
def _can_send_event(event, auth_events):
+ power_levels_event = _get_power_level_event(auth_events)
+
send_level = get_send_level(
- event.type, event.get("state_key", None), auth_events
+ event.type, event.get("state_key"), power_levels_event,
)
user_level = get_user_power_level(event.user_id, auth_events)
@@ -471,14 +480,14 @@ def _check_power_levels(event, auth_events):
]
old_list = current_state.content.get("users", {})
- for user in set(old_list.keys() + user_list.keys()):
+ for user in set(list(old_list) + list(user_list)):
levels_to_check.append(
(user, "users")
)
old_list = current_state.content.get("events", {})
new_list = event.content.get("events", {})
- for ev_id in set(old_list.keys() + new_list.keys()):
+ for ev_id in set(list(old_list) + list(new_list)):
levels_to_check.append(
(ev_id, "events")
)
@@ -524,13 +533,22 @@ def _check_power_levels(event, auth_events):
def _get_power_level_event(auth_events):
- key = (EventTypes.PowerLevels, "", )
- return auth_events.get(key)
+ return auth_events.get((EventTypes.PowerLevels, ""))
def get_user_power_level(user_id, auth_events):
- power_level_event = _get_power_level_event(auth_events)
+ """Get a user's power level
+
+ Args:
+ user_id (str): user's id to look up in power_levels
+ auth_events (dict[(str, str), synapse.events.EventBase]):
+ state in force at this point in the room (or rather, a subset of
+ it including at least the create event and power levels event.
+ Returns:
+ int: the user's power level in this room.
+ """
+ power_level_event = _get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
@@ -541,6 +559,11 @@ def get_user_power_level(user_id, auth_events):
else:
return int(level)
else:
+ # if there is no power levels event, the creator gets 100 and everyone
+ # else gets 0.
+
+ # some things which call this don't pass the create event: hack around
+ # that.
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index e673e96cc0..cb08da4984 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -47,14 +47,26 @@ class _EventInternalMetadata(object):
def _event_dict_property(key):
+ # We want to be able to use hasattr with the event dict properties.
+ # However, (on python3) hasattr expects AttributeError to be raised. Hence,
+ # we need to transform the KeyError into an AttributeError
def getter(self):
- return self._event_dict[key]
+ try:
+ return self._event_dict[key]
+ except KeyError:
+ raise AttributeError(key)
def setter(self, v):
- self._event_dict[key] = v
+ try:
+ self._event_dict[key] = v
+ except KeyError:
+ raise AttributeError(key)
def delete(self):
- del self._event_dict[key]
+ try:
+ del self._event_dict[key]
+ except KeyError:
+ raise AttributeError(key)
return property(
getter,
@@ -134,7 +146,7 @@ class EventBase(object):
return field in self._event_dict
def items(self):
- return self._event_dict.items()
+ return list(self._event_dict.items())
class FrozenEvent(EventBase):
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 824f4a42e3..29ae086786 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -20,6 +20,8 @@ from frozendict import frozendict
import re
+from six import string_types
+
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
@@ -277,7 +279,7 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if only_event_fields:
if (not isinstance(only_event_fields, list) or
- not all(isinstance(f, basestring) for f in only_event_fields)):
+ not all(isinstance(f, string_types) for f in only_event_fields)):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields)
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 2f4c8a1018..e0e5bf818c 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -17,6 +17,8 @@ from synapse.types import EventID, RoomID, UserID
from synapse.api.errors import SynapseError
from synapse.api.constants import EventTypes, Membership
+from six import string_types
+
class EventValidator(object):
@@ -49,7 +51,7 @@ class EventValidator(object):
strings.append("state_key")
for s in strings:
- if not isinstance(getattr(event, s), basestring):
+ if not isinstance(getattr(event, s), string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
if event.type == EventTypes.Member:
@@ -88,5 +90,5 @@ class EventValidator(object):
for s in keys:
if s not in d:
raise SynapseError(400, "'%s' not in content" % (s,))
- if not isinstance(d[s], basestring):
+ if not isinstance(d[s], string_types):
raise SynapseError(400, "Not '%s' a string type" % (s,))
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 8e2c0c4cd2..87a92f6ea9 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,6 +19,8 @@ import itertools
import logging
import random
+from six.moves import range
+
from twisted.internet import defer
from synapse.api.constants import Membership
@@ -30,20 +32,17 @@ from synapse.federation.federation_base import (
FederationBase,
event_from_pdu_json,
)
-import synapse.metrics
from synapse.util import logcontext, unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.logutils import log_function
from synapse.util.retryutils import NotRetryingDestination
-logger = logging.getLogger(__name__)
-
+from prometheus_client import Counter
-# synapse.federation.federation_client is a silly name
-metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
+logger = logging.getLogger(__name__)
-sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
+sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
PDU_RETRY_TIME_MS = 1 * 60 * 1000
@@ -106,7 +105,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc(query_type)
+ sent_queries_counter.labels(query_type).inc()
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail,
@@ -125,7 +124,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc("client_device_keys")
+ sent_queries_counter.labels("client_device_keys").inc()
return self.transport_layer.query_client_keys(
destination, content, timeout
)
@@ -135,7 +134,7 @@ class FederationClient(FederationBase):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
- sent_queries_counter.inc("user_devices")
+ sent_queries_counter.labels("user_devices").inc()
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@@ -152,7 +151,7 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
- sent_queries_counter.inc("client_one_time_keys")
+ sent_queries_counter.labels("client_one_time_keys").inc()
return self.transport_layer.claim_client_keys(
destination, content, timeout
)
@@ -392,7 +391,7 @@ class FederationClient(FederationBase):
"""
if return_local:
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
- signed_events = seen_events.values()
+ signed_events = list(seen_events.values())
else:
seen_events = yield self.store.have_seen_events(event_ids)
signed_events = []
@@ -413,11 +412,12 @@ class FederationClient(FederationBase):
batch_size = 20
missing_events = list(missing_events)
- for i in xrange(0, len(missing_events), batch_size):
+ for i in range(0, len(missing_events), batch_size):
batch = set(missing_events[i:i + batch_size])
deferreds = [
- preserve_fn(self.get_pdu)(
+ run_in_background(
+ self.get_pdu,
destinations=random_server_list(),
event_id=e_id,
)
@@ -589,7 +589,7 @@ class FederationClient(FederationBase):
}
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
- destination, pdus.values(),
+ destination, list(pdus.values()),
outlier=True,
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 247ddc89d5..2d420a58a2 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -27,12 +27,13 @@ from synapse.federation.federation_base import (
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
-import synapse.metrics
from synapse.types import get_domain_from_id
from synapse.util import async
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logutils import log_function
+from prometheus_client import Counter
+
from six import iteritems
# when processing incoming transactions, we try to handle multiple rooms in
@@ -41,17 +42,17 @@ TRANSACTION_CONCURRENCY_LIMIT = 10
logger = logging.getLogger(__name__)
-# synapse.federation.federation_server is a silly name
-metrics = synapse.metrics.get_metrics_for("synapse.federation.server")
-
-received_pdus_counter = metrics.register_counter("received_pdus")
+received_pdus_counter = Counter("synapse_federation_server_received_pdus", "")
-received_edus_counter = metrics.register_counter("received_edus")
+received_edus_counter = Counter("synapse_federation_server_received_edus", "")
-received_queries_counter = metrics.register_counter("received_queries", labels=["type"])
+received_queries_counter = Counter(
+ "synapse_federation_server_received_queries", "", ["type"]
+)
class FederationServer(FederationBase):
+
def __init__(self, hs):
super(FederationServer, self).__init__(hs)
@@ -131,7 +132,7 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
- received_pdus_counter.inc_by(len(transaction.pdus))
+ received_pdus_counter.inc(len(transaction.pdus))
pdus_by_room = {}
@@ -292,7 +293,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
- received_queries_counter.inc(query_type)
+ received_queries_counter.labels(query_type).inc()
resp = yield self.registry.on_query(query_type, args)
defer.returnValue((200, resp))
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 0f0c687b37..9f1142b5a9 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -33,7 +33,7 @@ from .units import Edu
from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
-import synapse.metrics
+from synapse.metrics import LaterGauge
from blist import sorteddict
from collections import namedtuple
@@ -45,9 +45,6 @@ from six import itervalues, iteritems
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-
class FederationRemoteSendQueue(object):
"""A drop in replacement for TransactionQueue"""
@@ -77,10 +74,8 @@ class FederationRemoteSendQueue(object):
# lambda binds to the queue rather than to the name of the queue which
# changes. ARGH.
def register(name, queue):
- metrics.register_callback(
- queue_name + "_size",
- lambda: len(queue),
- )
+ LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,),
+ "", [], lambda: len(queue))
for queue_name in [
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
@@ -202,7 +197,7 @@ class FederationRemoteSendQueue(object):
# We only want to send presence for our own users, so lets always just
# filter here just in case.
- local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
+ local_states = list(filter(lambda s: self.is_mine_id(s.user_id), states))
self.presence_map.update({state.user_id: state for state in local_states})
self.presence_changed[pos] = [state.user_id for state in local_states]
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 963d938edd..f0aeb5a0d3 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -26,23 +26,25 @@ from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.util.metrics import measure_func
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
import synapse.metrics
+from synapse.metrics import LaterGauge
+from synapse.metrics import (
+ sent_edus_counter,
+ sent_transactions_counter,
+ events_processed_counter,
+)
+
+from prometheus_client import Counter
+
+from six import itervalues
import logging
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
-sent_pdus_destination_dist = client_metrics.register_distribution(
- "sent_pdu_destinations"
+sent_pdus_destination_dist = Counter(
+ "synapse_federation_transaction_queue_sent_pdu_destinations", ""
)
-sent_edus_counter = client_metrics.register_counter("sent_edus")
-
-sent_transactions_counter = client_metrics.register_counter("sent_transactions")
-
-events_processed_counter = client_metrics.register_counter("events_processed")
class TransactionQueue(object):
@@ -69,8 +71,10 @@ class TransactionQueue(object):
# done
self.pending_transactions = {}
- metrics.register_callback(
- "pending_destinations",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_destinations",
+ "",
+ [],
lambda: len(self.pending_transactions),
)
@@ -94,12 +98,16 @@ class TransactionQueue(object):
# Map of destination -> (edu_type, key) -> Edu
self.pending_edus_keyed_by_dest = edus_keyed = {}
- metrics.register_callback(
- "pending_pdus",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_pdus",
+ "",
+ [],
lambda: sum(map(len, pdus.values())),
)
- metrics.register_callback(
- "pending_edus",
+ LaterGauge(
+ "synapse_federation_transaction_queue_pending_edus",
+ "",
+ [],
lambda: (
sum(map(len, edus.values()))
+ sum(map(len, presence.values()))
@@ -228,7 +236,7 @@ class TransactionQueue(object):
yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
logcontext.run_in_background(handle_room_events, evs)
- for evs in events_by_room.itervalues()
+ for evs in itervalues(events_by_room)
],
consumeErrors=True
))
@@ -241,18 +249,15 @@ class TransactionQueue(object):
now = self.clock.time_msec()
ts = yield self.store.get_received_ts(events[-1].event_id)
- synapse.metrics.event_processing_lag.set(
- now - ts, "federation_sender",
- )
- synapse.metrics.event_processing_last_ts.set(
- ts, "federation_sender",
- )
+ synapse.metrics.event_processing_lag.labels(
+ "federation_sender").set(now - ts)
+ synapse.metrics.event_processing_last_ts.labels(
+ "federation_sender").set(ts)
- events_processed_counter.inc_by(len(events))
+ events_processed_counter.inc(len(events))
- synapse.metrics.event_processing_positions.set(
- next_token, "federation_sender",
- )
+ synapse.metrics.event_processing_positions.labels(
+ "federation_sender").set(next_token)
finally:
self._is_processing = False
@@ -275,7 +280,7 @@ class TransactionQueue(object):
if not destinations:
return
- sent_pdus_destination_dist.inc_by(len(destinations))
+ sent_pdus_destination_dist.inc(len(destinations))
for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append(
@@ -322,7 +327,9 @@ class TransactionQueue(object):
if not states_map:
break
- yield self._process_presence_inner(states_map.values())
+ yield self._process_presence_inner(list(states_map.values()))
+ except Exception:
+ logger.exception("Error sending presence states to servers")
finally:
self._processing_pending_presence = False
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index ff0656df3e..19d09f5422 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -25,7 +25,7 @@ from synapse.http.servlet import (
)
from synapse.util.ratelimitutils import FederationRateLimiter
from synapse.util.versionstring import get_version_string
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
import functools
@@ -152,11 +152,18 @@ class Authenticator(object):
# alive
retry_timings = yield self.store.get_destination_retry_timings(origin)
if retry_timings and retry_timings["retry_last_ts"]:
- logger.info("Marking origin %r as up", origin)
- preserve_fn(self.store.set_destination_retry_timings)(origin, 0, 0)
+ run_in_background(self._reset_retry_timings, origin)
defer.returnValue(origin)
+ @defer.inlineCallbacks
+ def _reset_retry_timings(self, origin):
+ try:
+ logger.info("Marking origin %r as up", origin)
+ yield self.store.set_destination_retry_timings(origin, 0, 0)
+ except Exception:
+ logger.exception("Error resetting retry timings on %s", origin)
+
class BaseFederationServlet(object):
REQUIRE_AUTH = True
diff --git a/synapse/federation/units.py b/synapse/federation/units.py
index 3f645acc43..01c5b8fe17 100644
--- a/synapse/federation/units.py
+++ b/synapse/federation/units.py
@@ -74,8 +74,6 @@ class Transaction(JsonEncodedObject):
"previous_ids",
"pdus",
"edus",
- "transaction_id",
- "destination",
"pdu_failures",
]
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 1fb709e6c3..6f11fa374b 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -42,7 +42,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from signedjson.sign import sign_json
@@ -165,31 +165,35 @@ class GroupAttestionRenewer(object):
@defer.inlineCallbacks
def _renew_attestation(group_id, user_id):
- if not self.is_mine_id(group_id):
- destination = get_domain_from_id(group_id)
- elif not self.is_mine_id(user_id):
- destination = get_domain_from_id(user_id)
- else:
- logger.warn(
- "Incorrectly trying to do attestations for user: %r in %r",
- user_id, group_id,
+ try:
+ if not self.is_mine_id(group_id):
+ destination = get_domain_from_id(group_id)
+ elif not self.is_mine_id(user_id):
+ destination = get_domain_from_id(user_id)
+ else:
+ logger.warn(
+ "Incorrectly trying to do attestations for user: %r in %r",
+ user_id, group_id,
+ )
+ yield self.store.remove_attestation_renewal(group_id, user_id)
+ return
+
+ attestation = self.attestations.create_attestation(group_id, user_id)
+
+ yield self.transport_client.renew_group_attestation(
+ destination, group_id, user_id,
+ content={"attestation": attestation},
)
- yield self.store.remove_attestation_renewal(group_id, user_id)
- return
-
- attestation = self.attestations.create_attestation(group_id, user_id)
- yield self.transport_client.renew_group_attestation(
- destination, group_id, user_id,
- content={"attestation": attestation},
- )
-
- yield self.store.update_attestation_renewal(
- group_id, user_id, attestation
- )
+ yield self.store.update_attestation_renewal(
+ group_id, user_id, attestation
+ )
+ except Exception:
+ logger.exception("Error renewing attestation of %r in %r",
+ user_id, group_id)
for row in rows:
group_id = row["group_id"]
user_id = row["user_id"]
- preserve_fn(_renew_attestation)(group_id, user_id)
+ run_in_background(_renew_attestation, group_id, user_id)
diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py
index 2d95b04e0c..62d20ad130 100644
--- a/synapse/groups/groups_server.py
+++ b/synapse/groups/groups_server.py
@@ -20,6 +20,8 @@ from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from twisted.internet import defer
+from six import string_types
+
logger = logging.getLogger(__name__)
@@ -431,7 +433,7 @@ class GroupsServerHandler(object):
"long_description"):
if keyname in content:
value = content[keyname]
- if not isinstance(value, basestring):
+ if not isinstance(value, string_types):
raise SynapseError(400, "%r value is not a string" % (keyname,))
profile[keyname] = value
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 8f8fd82eb0..d358842b3e 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -14,9 +14,7 @@
# limitations under the License.
from .register import RegistrationHandler
-from .room import (
- RoomCreationHandler, RoomContextHandler,
-)
+from .room import RoomContextHandler
from .message import MessageHandler
from .federation import FederationHandler
from .directory import DirectoryHandler
@@ -47,7 +45,6 @@ class Handlers(object):
def __init__(self, hs):
self.registration_handler = RegistrationHandler(hs)
self.message_handler = MessageHandler(hs)
- self.room_creation_handler = RoomCreationHandler(hs)
self.federation_handler = FederationHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.admin_handler = AdminHandler(hs)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index e089e66fde..2d1db0c245 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -114,14 +114,14 @@ class BaseHandler(object):
if guest_access != "can_join":
if context:
current_state = yield self.store.get_events(
- context.current_state_ids.values()
+ list(context.current_state_ids.values())
)
else:
current_state = yield self.state_handler.get_current_state(
event.room_id
)
- current_state = current_state.values()
+ current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 0245197c02..1c29c43a83 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -15,20 +15,21 @@
from twisted.internet import defer
+from six import itervalues
+
import synapse
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
from synapse.util.logcontext import (
- make_deferred_yieldable, preserve_fn, run_in_background,
+ make_deferred_yieldable, run_in_background,
)
+from prometheus_client import Counter
import logging
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-events_processed_counter = metrics.register_counter("events_processed")
+events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
def log_failure(failure):
@@ -111,9 +112,7 @@ class ApplicationServicesHandler(object):
# Fork off pushes to these services
for service in services:
- preserve_fn(self.scheduler.submit_event_for_as)(
- service, event
- )
+ self.scheduler.submit_event_for_as(service, event)
@defer.inlineCallbacks
def handle_room_events(events):
@@ -122,7 +121,7 @@ class ApplicationServicesHandler(object):
yield make_deferred_yieldable(defer.gatherResults([
run_in_background(handle_room_events, evs)
- for evs in events_by_room.itervalues()
+ for evs in itervalues(events_by_room)
], consumeErrors=True))
yield self.store.set_appservice_last_pos(upper_bound)
@@ -130,18 +129,15 @@ class ApplicationServicesHandler(object):
now = self.clock.time_msec()
ts = yield self.store.get_received_ts(events[-1].event_id)
- synapse.metrics.event_processing_positions.set(
- upper_bound, "appservice_sender",
- )
+ synapse.metrics.event_processing_positions.labels(
+ "appservice_sender").set(upper_bound)
- events_processed_counter.inc_by(len(events))
+ events_processed_counter.inc(len(events))
- synapse.metrics.event_processing_lag.set(
- now - ts, "appservice_sender",
- )
- synapse.metrics.event_processing_last_ts.set(
- ts, "appservice_sender",
- )
+ synapse.metrics.event_processing_lag.labels(
+ "appservice_sender").set(now - ts)
+ synapse.metrics.event_processing_last_ts.labels(
+ "appservice_sender").set(ts)
finally:
self.is_processing = False
@@ -198,7 +194,10 @@ class ApplicationServicesHandler(object):
services = yield self._get_services_for_3pn(protocol)
results = yield make_deferred_yieldable(defer.DeferredList([
- preserve_fn(self.appservice_api.query_3pe)(service, kind, protocol, fields)
+ run_in_background(
+ self.appservice_api.query_3pe,
+ service, kind, protocol, fields,
+ )
for service in services
], consumeErrors=True))
@@ -259,11 +258,15 @@ class ApplicationServicesHandler(object):
event based on the service regex.
"""
services = self.store.get_app_services()
- interested_list = [
- s for s in services if (
- yield s.is_interested(event, self.store)
- )
- ]
+
+ # we can't use a list comprehension here. Since python 3, list
+ # comprehensions use a generator internally. This means you can't yield
+ # inside of a list comprehension anymore.
+ interested_list = []
+ for s in services:
+ if (yield s.is_interested(event, self.store)):
+ interested_list.append(s)
+
defer.returnValue(interested_list)
def _get_services_for_user(self, user_id):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index a5365c4fe4..3c0051586d 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -249,7 +249,7 @@ class AuthHandler(BaseHandler):
errordict = e.error_dict()
for f in flows:
- if len(set(f) - set(creds.keys())) == 0:
+ if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
# include the password in the case of registering, so only log
# the keys (confusingly, clientdict may contain a password
@@ -257,12 +257,12 @@ class AuthHandler(BaseHandler):
# and is not sensitive).
logger.info(
"Auth completed with creds: %r. Client dict has keys: %r",
- creds, clientdict.keys()
+ creds, list(clientdict)
)
defer.returnValue((creds, clientdict, session['id']))
ret = self._auth_dict_for_flows(flows, session)
- ret['completed'] = creds.keys()
+ ret['completed'] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(
ret,
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index b1d3814909..c5e92f6214 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2017 New Vector Ltd
+# Copyright 2017, 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,9 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from ._base import BaseHandler
+from synapse.types import UserID, create_requester
+from synapse.util.logcontext import run_in_background
import logging
@@ -27,6 +29,15 @@ class DeactivateAccountHandler(BaseHandler):
super(DeactivateAccountHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
+ self._room_member_handler = hs.get_room_member_handler()
+ self.user_directory_handler = hs.get_user_directory_handler()
+
+ # Flag that indicates whether the process to part users from rooms is running
+ self._user_parter_running = False
+
+ # Start the user parter loop so it can resume parting users from rooms where
+ # it left off (if it has work left to do).
+ reactor.callWhenRunning(self._start_user_parting)
@defer.inlineCallbacks
def deactivate_account(self, user_id):
@@ -50,3 +61,73 @@ class DeactivateAccountHandler(BaseHandler):
yield self.store.user_delete_threepids(user_id)
yield self.store.user_set_password_hash(user_id, None)
+
+ # Add the user to a table of users pending deactivation (ie.
+ # removal from all the rooms they're a member of)
+ yield self.store.add_user_pending_deactivation(user_id)
+
+ # delete from user directory
+ yield self.user_directory_handler.handle_user_deactivated(user_id)
+
+ # Now start the process that goes through that list and
+ # parts users from rooms (if it isn't already running)
+ self._start_user_parting()
+
+ def _start_user_parting(self):
+ """
+ Start the process that goes through the table of users
+ pending deactivation, if it isn't already running.
+
+ Returns:
+ None
+ """
+ if not self._user_parter_running:
+ run_in_background(self._user_parter_loop)
+
+ @defer.inlineCallbacks
+ def _user_parter_loop(self):
+ """Loop that parts deactivated users from rooms
+
+ Returns:
+ None
+ """
+ self._user_parter_running = True
+ logger.info("Starting user parter")
+ try:
+ while True:
+ user_id = yield self.store.get_user_pending_deactivation()
+ if user_id is None:
+ break
+ logger.info("User parter parting %r", user_id)
+ yield self._part_user(user_id)
+ yield self.store.del_user_pending_deactivation(user_id)
+ logger.info("User parter finished parting %r", user_id)
+ logger.info("User parter finished: stopping")
+ finally:
+ self._user_parter_running = False
+
+ @defer.inlineCallbacks
+ def _part_user(self, user_id):
+ """Causes the given user_id to leave all the rooms they're joined to
+
+ Returns:
+ None
+ """
+ user = UserID.from_string(user_id)
+
+ rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+ for room_id in rooms_for_user:
+ logger.info("User parter parting %r from %r", user_id, room_id)
+ try:
+ yield self._room_member_handler.update_membership(
+ create_requester(user),
+ user,
+ room_id,
+ "leave",
+ ratelimit=False,
+ )
+ except Exception:
+ logger.exception(
+ "Failed to part user %r from room %r: ignoring and continuing",
+ user_id, room_id,
+ )
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f7457a7082..11c6fb3657 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -26,6 +26,8 @@ from ._base import BaseHandler
import logging
+from six import itervalues, iteritems
+
logger = logging.getLogger(__name__)
@@ -112,7 +114,7 @@ class DeviceHandler(BaseHandler):
user_id, device_id=None
)
- devices = device_map.values()
+ devices = list(device_map.values())
for device in devices:
_update_device_from_client_ips(device, ips)
@@ -185,7 +187,7 @@ class DeviceHandler(BaseHandler):
defer.Deferred:
"""
device_map = yield self.store.get_devices_by_user(user_id)
- device_ids = device_map.keys()
+ device_ids = list(device_map)
if except_device_id is not None:
device_ids = [d for d in device_ids if d != except_device_id]
yield self.delete_devices(user_id, device_ids)
@@ -318,7 +320,7 @@ class DeviceHandler(BaseHandler):
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
if room_id not in room_ids:
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -338,7 +340,7 @@ class DeviceHandler(BaseHandler):
# special-case for an empty prev state: include all members
# in the changed list
if not event_ids:
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -354,10 +356,10 @@ class DeviceHandler(BaseHandler):
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
- for state_dict in prev_state_ids.itervalues():
+ for state_dict in itervalues(prev_state_ids):
member_event = state_dict.get((EventTypes.Member, user_id), None)
if not member_event or member_event != current_member_id:
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
@@ -367,14 +369,14 @@ class DeviceHandler(BaseHandler):
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
- for key, event_id in current_state_ids.iteritems():
+ for key, event_id in iteritems(current_state_ids):
etype, state_key = key
if etype != EventTypes.Member:
continue
# check if this member has changed since any of the extremities
# at the stream_ordering, and add them to the list if so.
- for state_dict in prev_state_ids.itervalues():
+ for state_dict in itervalues(prev_state_ids):
prev_event_id = state_dict.get(key, None)
if not prev_event_id or prev_event_id != event_id:
if state_key != user_id:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 325c0c4a9f..8a2d177539 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -19,12 +19,13 @@ import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer
+from six import iteritems
from synapse.api.errors import (
SynapseError, CodeMessageException, FederationDeniedError,
)
from synapse.types import get_domain_from_id, UserID
-from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -92,7 +93,7 @@ class E2eKeysHandler(object):
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
- for user_id, device_ids in remote_queries.iteritems():
+ for user_id, device_ids in iteritems(remote_queries):
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
@@ -103,9 +104,9 @@ class E2eKeysHandler(object):
query_list
)
)
- for user_id, devices in remote_results.iteritems():
+ for user_id, devices in iteritems(remote_results):
user_devices = results.setdefault(user_id, {})
- for device_id, device in devices.iteritems():
+ for device_id, device in iteritems(devices):
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
@@ -139,9 +140,9 @@ class E2eKeysHandler(object):
failures[destination] = _exception_to_failure(e)
yield make_deferred_yieldable(defer.gatherResults([
- preserve_fn(do_remote_query)(destination)
+ run_in_background(do_remote_query, destination)
for destination in remote_queries_not_in_cache
- ]))
+ ], consumeErrors=True))
defer.returnValue({
"device_keys": results, "failures": failures,
@@ -242,17 +243,17 @@ class E2eKeysHandler(object):
failures[destination] = _exception_to_failure(e)
yield make_deferred_yieldable(defer.gatherResults([
- preserve_fn(claim_client_keys)(destination)
+ run_in_background(claim_client_keys, destination)
for destination in remote_queries
- ]))
+ ], consumeErrors=True))
logger.info(
"Claimed one-time-keys: %s",
",".join((
"%s for %s:%s" % (key_id, user_id, device_id)
- for user_id, user_keys in json_result.iteritems()
- for device_id, device_keys in user_keys.iteritems()
- for key_id, _ in device_keys.iteritems()
+ for user_id, user_keys in iteritems(json_result)
+ for device_id, device_keys in iteritems(user_keys)
+ for key_id, _ in iteritems(device_keys)
)),
)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d3685fb12a..8bc642675f 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -48,6 +48,7 @@ class EventStreamHandler(BaseHandler):
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
+ self._server_notices_sender = hs.get_server_notices_sender()
@defer.inlineCallbacks
@log_function
@@ -58,6 +59,10 @@ class EventStreamHandler(BaseHandler):
If `only_keys` is not None, events from keys will be sent down.
"""
+
+ # send any outstanding server notices to the user.
+ yield self._server_notices_sender.on_user_syncing(auth_user_id)
+
auth_user = UserID.from_string(auth_user_id)
presence_handler = self.hs.get_presence_handler()
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ae7e0d6da2..495ac4c648 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -16,12 +16,15 @@
"""Contains handlers for federation events."""
-import httplib
import itertools
import logging
+import sys
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
+import six
+from six.moves import http_client
+from six import iteritems
from twisted.internet import defer
from unpaddedbase64 import decode_base64
@@ -49,7 +52,6 @@ from synapse.util.retryutils import NotRetryingDestination
from synapse.util.distributor import user_joined_room
-
logger = logging.getLogger(__name__)
@@ -79,6 +81,7 @@ class FederationHandler(BaseHandler):
self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
+ self._server_notices_mxid = hs.config.server_notices_mxid
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@@ -476,7 +479,7 @@ class FederationHandler(BaseHandler):
# to get all state ids that we're interested in.
event_map = yield self.store.get_events([
e_id
- for key_to_eid in event_to_state_ids.values()
+ for key_to_eid in list(event_to_state_ids.values())
for key, e_id in key_to_eid.items()
if key[0] != EventTypes.Member or check_match(key[1])
])
@@ -484,10 +487,10 @@ class FederationHandler(BaseHandler):
event_to_state = {
e_id: {
key: event_map[inner_e_id]
- for key, inner_e_id in key_to_eid.items()
+ for key, inner_e_id in key_to_eid.iteritems()
if inner_e_id in event_map
}
- for e_id, key_to_eid in event_to_state_ids.items()
+ for e_id, key_to_eid in event_to_state_ids.iteritems()
}
def redact_disallowed(event, state):
@@ -502,7 +505,7 @@ class FederationHandler(BaseHandler):
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
- for ev in state.values():
+ for ev in state.itervalues():
if ev.type != EventTypes.Member:
continue
try:
@@ -637,7 +640,8 @@ class FederationHandler(BaseHandler):
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- logcontext.preserve_fn(self.replication_layer.get_pdu)(
+ logcontext.run_in_background(
+ self.replication_layer.get_pdu,
[dest],
event_id,
outlier=True,
@@ -747,9 +751,19 @@ class FederationHandler(BaseHandler):
curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
+ """Get joined domains from state
+
+ Args:
+ state (dict[tuple, FrozenEvent]): State map from type/state
+ key to event.
+
+ Returns:
+ list[tuple[str, int]]: Returns a list of servers with the
+ lowest depth of their joins. Sorted by lowest depth first.
+ """
joined_users = [
(state_key, int(event.depth))
- for (e_type, state_key), event in state.items()
+ for (e_type, state_key), event in state.iteritems()
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
@@ -766,7 +780,7 @@ class FederationHandler(BaseHandler):
except Exception:
pass
- return sorted(joined_domains.items(), key=lambda d: d[1])
+ return sorted(joined_domains.iteritems(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state)
@@ -783,7 +797,7 @@ class FederationHandler(BaseHandler):
yield self.backfill(
dom, room_id,
limit=100,
- extremities=[e for e in extremities.keys()]
+ extremities=extremities,
)
# If this succeeded then we probably already have the
# appropriate stuff.
@@ -829,7 +843,7 @@ class FederationHandler(BaseHandler):
tried_domains = set(likely_domains)
tried_domains.add(self.server_name)
- event_ids = list(extremities.keys())
+ event_ids = list(extremities.iterkeys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn(
@@ -839,31 +853,34 @@ class FederationHandler(BaseHandler):
[resolve(room_id, [e]) for e in event_ids],
consumeErrors=True,
))
+
+ # dict[str, dict[tuple, str]], a map from event_id to state map of
+ # event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
- [e_id for ids in states.values() for e_id in ids],
+ [e_id for ids in states.itervalues() for e_id in ids.itervalues()],
get_prev_content=False
)
states = {
key: {
k: state_map[e_id]
- for k, e_id in state_dict.items()
+ for k, e_id in state_dict.iteritems()
if e_id in state_map
- } for key, state_dict in states.items()
+ } for key, state_dict in states.iteritems()
}
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([
- dom for dom in likely_domains
+ dom for dom, _ in likely_domains
if dom not in tried_domains
])
if success:
defer.returnValue(True)
- tried_domains.update(likely_domains)
+ tried_domains.update(dom for dom, _ in likely_domains)
defer.returnValue(False)
@@ -887,7 +904,7 @@ class FederationHandler(BaseHandler):
logger.warn("Rejecting event %s which has %i prev_events",
ev.event_id, len(ev.prev_events))
raise SynapseError(
- httplib.BAD_REQUEST,
+ http_client.BAD_REQUEST,
"Too many prev_events",
)
@@ -895,7 +912,7 @@ class FederationHandler(BaseHandler):
logger.warn("Rejecting event %s which has %i auth_events",
ev.event_id, len(ev.auth_events))
raise SynapseError(
- httplib.BAD_REQUEST,
+ http_client.BAD_REQUEST,
"Too many auth_events",
)
@@ -1023,7 +1040,7 @@ class FederationHandler(BaseHandler):
# lots of requests for missing prev_events which we do actually
# have. Hence we fire off the deferred, but don't wait for it.
- logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
+ logcontext.run_in_background(self._handle_queued_pdus, room_queue)
defer.returnValue(True)
@@ -1131,13 +1148,13 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
- state_ids = context.prev_state_ids.values()
+ state_ids = list(context.prev_state_ids.values())
auth_chain = yield self.store.get_auth_chain(state_ids)
- state = yield self.store.get_events(context.prev_state_ids.values())
+ state = yield self.store.get_events(list(context.prev_state_ids.values()))
defer.returnValue({
- "state": state.values(),
+ "state": list(state.values()),
"auth_chain": auth_chain,
})
@@ -1177,6 +1194,13 @@ class FederationHandler(BaseHandler):
if not self.is_mine_id(event.state_key):
raise SynapseError(400, "The invite event must be for this server")
+ # block any attempts to invite the server notices mxid
+ if event.state_key == self._server_notices_mxid:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "Cannot invite this user",
+ )
+
event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
@@ -1364,7 +1388,7 @@ class FederationHandler(BaseHandler):
)
if state_groups:
- _, state = state_groups.items().pop()
+ _, state = list(iteritems(state_groups)).pop()
results = {
(e.type, e.state_key): e for e in state
}
@@ -1380,7 +1404,7 @@ class FederationHandler(BaseHandler):
else:
del results[(event.type, event.state_key)]
- res = results.values()
+ res = list(results.values())
for event in res:
# We sign these again because there was a bug where we
# incorrectly signed things the first time round
@@ -1421,7 +1445,7 @@ class FederationHandler(BaseHandler):
else:
results.pop((event.type, event.state_key), None)
- defer.returnValue(results.values())
+ defer.returnValue(list(results.values()))
else:
defer.returnValue([])
@@ -1513,18 +1537,21 @@ class FederationHandler(BaseHandler):
backfilled=backfilled,
)
except: # noqa: E722, as we reraise the exception this is fine.
- # Ensure that we actually remove the entries in the push actions
- # staging area
- logcontext.preserve_fn(
- self.store.remove_push_actions_from_staging
- )(event.event_id)
- raise
+ tp, value, tb = sys.exc_info()
+
+ logcontext.run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
+
+ six.reraise(tp, value, tb)
if not backfilled:
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
- event_stream_id, max_stream_id
+ logcontext.run_in_background(
+ self.pusher_pool.on_new_notifications,
+ event_stream_id, max_stream_id,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1538,7 +1565,8 @@ class FederationHandler(BaseHandler):
"""
contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- logcontext.preserve_fn(self._prep_event)(
+ logcontext.run_in_background(
+ self._prep_event,
origin,
ev_info["event"],
state=ev_info.get("state"),
@@ -1766,6 +1794,10 @@ class FederationHandler(BaseHandler):
min_depth=min_depth,
)
+ missing_events = yield self._filter_events_for_server(
+ origin, room_id, missing_events,
+ )
+
defer.returnValue(missing_events)
@defer.inlineCallbacks
@@ -1867,7 +1899,8 @@ class FederationHandler(BaseHandler):
different_events = yield logcontext.make_deferred_yieldable(
defer.gatherResults([
- logcontext.preserve_fn(self.store.get_event)(
+ logcontext.run_in_background(
+ self.store.get_event,
d,
allow_none=True,
allow_rejected=False,
@@ -1885,7 +1918,7 @@ class FederationHandler(BaseHandler):
})
new_state = self.state_handler.resolve_events(
- [local_view.values(), remote_view.values()],
+ [list(local_view.values()), list(remote_view.values())],
event
)
@@ -2005,7 +2038,7 @@ class FederationHandler(BaseHandler):
this will not be included in the current_state in the context.
"""
state_updates = {
- k: a.event_id for k, a in auth_events.iteritems()
+ k: a.event_id for k, a in iteritems(auth_events)
if k != event_key
}
context.current_state_ids = dict(context.current_state_ids)
@@ -2015,7 +2048,7 @@ class FederationHandler(BaseHandler):
context.delta_ids.update(state_updates)
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({
- k: a.event_id for k, a in auth_events.iteritems()
+ k: a.event_id for k, a in iteritems(auth_events)
})
context.state_group = yield self.store.store_state_group(
event.event_id,
@@ -2067,7 +2100,7 @@ class FederationHandler(BaseHandler):
def get_next(it, opt=None):
try:
- return it.next()
+ return next(it)
except Exception:
return opt
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 977993e7d4..dcae083734 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -15,6 +15,7 @@
# limitations under the License.
from twisted.internet import defer
+from six import iteritems
from synapse.api.errors import SynapseError
from synapse.types import get_domain_from_id
@@ -449,7 +450,7 @@ class GroupsLocalHandler(object):
results = {}
failed_results = []
- for destination, dest_user_ids in destinations.iteritems():
+ for destination, dest_user_ids in iteritems(destinations):
try:
r = yield self.transport_client.bulk_get_publicised_groups(
destination, list(dest_user_ids),
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index c5267b4b84..71af86fe21 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -27,7 +27,7 @@ from synapse.types import (
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -166,7 +166,8 @@ class InitialSyncHandler(BaseHandler):
(messages, token), current_state = yield make_deferred_yieldable(
defer.gatherResults(
[
- preserve_fn(self.store.get_recent_events_for_room)(
+ run_in_background(
+ self.store.get_recent_events_for_room,
event.room_id,
limit=limit,
end_token=room_end_token,
@@ -180,8 +181,8 @@ class InitialSyncHandler(BaseHandler):
self.store, user_id, messages
)
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
+ start_token = now_token.copy_and_replace("room_key", token)
+ end_token = now_token.copy_and_replace("room_key", room_end_token)
time_now = self.clock.time_msec()
d["messages"] = {
@@ -324,8 +325,8 @@ class InitialSyncHandler(BaseHandler):
self.store, user_id, messages, is_peeking=is_peeking
)
- start_token = StreamToken.START.copy_and_replace("room_key", token[0])
- end_token = StreamToken.START.copy_and_replace("room_key", token[1])
+ start_token = StreamToken.START.copy_and_replace("room_key", token)
+ end_token = StreamToken.START.copy_and_replace("room_key", stream_token)
time_now = self.clock.time_msec()
@@ -391,9 +392,10 @@ class InitialSyncHandler(BaseHandler):
presence, receipts, (messages, token) = yield defer.gatherResults(
[
- preserve_fn(get_presence)(),
- preserve_fn(get_receipts)(),
- preserve_fn(self.store.get_recent_events_for_room)(
+ run_in_background(get_presence),
+ run_in_background(get_receipts),
+ run_in_background(
+ self.store.get_recent_events_for_room,
room_id,
limit=limit,
end_token=now_token.room_key,
@@ -406,8 +408,8 @@ class InitialSyncHandler(BaseHandler):
self.store, user_id, messages, is_peeking=is_peeking,
)
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
+ start_token = now_token.copy_and_replace("room_key", token)
+ end_token = now_token
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 53beb2b9ab..1cb81b6cf8 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -13,11 +13,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+import simplejson
+import sys
+
+from canonicaljson import encode_canonical_json
+import six
+from six import string_types, itervalues, iteritems
from twisted.internet import defer, reactor
+from twisted.internet.defer import succeed
from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership, MAX_DEPTH
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import (
+ AuthError, Codes, SynapseError,
+ ConsentNotGivenError,
+)
+from synapse.api.urls import ConsentURIBuilder
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
@@ -25,7 +37,7 @@ from synapse.types import (
UserID, RoomAlias, RoomStreamToken,
)
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
-from synapse.util.logcontext import preserve_fn, run_in_background
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import measure_func
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.stringutils import random_string
@@ -34,11 +46,6 @@ from synapse.replication.http.send_event import send_event_to_master
from ._base import BaseHandler
-from canonicaljson import encode_canonical_json
-
-import logging
-import simplejson
-
logger = logging.getLogger(__name__)
@@ -85,14 +92,14 @@ class MessageHandler(BaseHandler):
# map from purge id to PurgeStatus
self._purges_by_id = {}
- def start_purge_history(self, room_id, topological_ordering,
+ def start_purge_history(self, room_id, token,
delete_local_events=False):
"""Start off a history purge on a room.
Args:
room_id (str): The room to purge from
- topological_ordering (int): minimum topo ordering to preserve
+ token (str): topological token to delete events before
delete_local_events (bool): True to delete local events as well as
remote ones
@@ -114,19 +121,19 @@ class MessageHandler(BaseHandler):
self._purges_by_id[purge_id] = PurgeStatus()
run_in_background(
self._purge_history,
- purge_id, room_id, topological_ordering, delete_local_events,
+ purge_id, room_id, token, delete_local_events,
)
return purge_id
@defer.inlineCallbacks
- def _purge_history(self, purge_id, room_id, topological_ordering,
+ def _purge_history(self, purge_id, room_id, token,
delete_local_events):
"""Carry out a history purge on a room.
Args:
purge_id (str): The id for this purge
room_id (str): The room to purge from
- topological_ordering (int): minimum topo ordering to preserve
+ token (str): topological token to delete events before
delete_local_events (bool): True to delete local events as well as
remote ones
@@ -137,7 +144,7 @@ class MessageHandler(BaseHandler):
try:
with (yield self.pagination_lock.write(room_id)):
yield self.store.purge_history(
- room_id, topological_ordering, delete_local_events,
+ room_id, token, delete_local_events,
)
logger.info("[purge] complete")
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_COMPLETE
@@ -396,7 +403,7 @@ class MessageHandler(BaseHandler):
"avatar_url": profile.avatar_url,
"display_name": profile.display_name,
}
- for user_id, profile in users_with_profile.iteritems()
+ for user_id, profile in iteritems(users_with_profile)
})
@@ -430,6 +437,9 @@ class EventCreationHandler(object):
self.spam_checker = hs.get_spam_checker()
+ if self.config.block_events_without_consent_error is not None:
+ self._consent_uri_builder = ConsentURIBuilder(self.config)
+
@defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_events_and_hashes=None):
@@ -481,6 +491,10 @@ class EventCreationHandler(object):
target, e
)
+ is_exempt = yield self._is_exempt_from_privacy_policy(builder)
+ if not is_exempt:
+ yield self.assert_accepted_privacy_policy(requester)
+
if token_id is not None:
builder.internal_metadata.token_id = token_id
@@ -495,6 +509,86 @@ class EventCreationHandler(object):
defer.returnValue((event, context))
+ def _is_exempt_from_privacy_policy(self, builder):
+ """"Determine if an event to be sent is exempt from having to consent
+ to the privacy policy
+
+ Args:
+ builder (synapse.events.builder.EventBuilder): event being created
+
+ Returns:
+ Deferred[bool]: true if the event can be sent without the user
+ consenting
+ """
+ # the only thing the user can do is join the server notices room.
+ if builder.type == EventTypes.Member:
+ membership = builder.content.get("membership", None)
+ if membership == Membership.JOIN:
+ return self._is_server_notices_room(builder.room_id)
+ return succeed(False)
+
+ @defer.inlineCallbacks
+ def _is_server_notices_room(self, room_id):
+ if self.config.server_notices_mxid is None:
+ defer.returnValue(False)
+ user_ids = yield self.store.get_users_in_room(room_id)
+ defer.returnValue(self.config.server_notices_mxid in user_ids)
+
+ @defer.inlineCallbacks
+ def assert_accepted_privacy_policy(self, requester):
+ """Check if a user has accepted the privacy policy
+
+ Called when the given user is about to do something that requires
+ privacy consent. We see if the user is exempt and otherwise check that
+ they have given consent. If they have not, a ConsentNotGiven error is
+ raised.
+
+ Args:
+ requester (synapse.types.Requester):
+ The user making the request
+
+ Returns:
+ Deferred[None]: returns normally if the user has consented or is
+ exempt
+
+ Raises:
+ ConsentNotGivenError: if the user has not given consent yet
+ """
+ if self.config.block_events_without_consent_error is None:
+ return
+
+ # exempt AS users from needing consent
+ if requester.app_service is not None:
+ return
+
+ user_id = requester.user.to_string()
+
+ # exempt the system notices user
+ if (
+ self.config.server_notices_mxid is not None and
+ user_id == self.config.server_notices_mxid
+ ):
+ return
+
+ u = yield self.store.get_user_by_id(user_id)
+ assert u is not None
+ if u["appservice_id"] is not None:
+ # users registered by an appservice are exempt
+ return
+ if u["consent_version"] == self.config.user_consent_version:
+ return
+
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(
+ requester.user.localpart,
+ )
+ msg = self.config.block_events_without_consent_error % {
+ 'consent_uri': consent_uri,
+ }
+ raise ConsentNotGivenError(
+ msg=msg,
+ consent_uri=consent_uri,
+ )
+
@defer.inlineCallbacks
def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
@@ -577,7 +671,7 @@ class EventCreationHandler(object):
spam_error = self.spam_checker.check_event_for_spam(event)
if spam_error:
- if not isinstance(spam_error, basestring):
+ if not isinstance(spam_error, string_types):
spam_error = "Spam is not permitted here"
raise SynapseError(
403, spam_error, Codes.FORBIDDEN
@@ -733,8 +827,14 @@ class EventCreationHandler(object):
except: # noqa: E722, as we reraise the exception this is fine.
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
- preserve_fn(self.store.remove_push_actions_from_staging)(event.event_id)
- raise
+ tp, value, tb = sys.exc_info()
+
+ run_in_background(
+ self.store.remove_push_actions_from_staging,
+ event.event_id,
+ )
+
+ six.reraise(tp, value, tb)
@defer.inlineCallbacks
def persist_and_notify_client_event(
@@ -785,7 +885,7 @@ class EventCreationHandler(object):
state_to_include_ids = [
e_id
- for k, e_id in context.current_state_ids.iteritems()
+ for k, e_id in iteritems(context.current_state_ids)
if k[0] in self.hs.config.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
@@ -799,7 +899,7 @@ class EventCreationHandler(object):
"content": e.content,
"sender": e.sender,
}
- for e in state_to_include.itervalues()
+ for e in itervalues(state_to_include)
]
invitee = UserID.from_string(event.state_key)
@@ -854,22 +954,33 @@ class EventCreationHandler(object):
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
- preserve_fn(self.pusher_pool.on_new_notifications)(
+ run_in_background(
+ self.pusher_pool.on_new_notifications,
event_stream_id, max_stream_id
)
@defer.inlineCallbacks
def _notify():
yield run_on_reactor()
- self.notifier.on_new_room_event(
- event, event_stream_id, max_stream_id,
- extra_users=extra_users
- )
+ try:
+ self.notifier.on_new_room_event(
+ event, event_stream_id, max_stream_id,
+ extra_users=extra_users
+ )
+ except Exception:
+ logger.exception("Error notifying about new room event")
- preserve_fn(_notify)()
+ run_in_background(_notify)
if event.type == EventTypes.Message:
- presence = self.hs.get_presence_handler()
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
- preserve_fn(presence.bump_presence_active_time)(requester.user)
+ run_in_background(self._bump_active_time, requester.user)
+
+ @defer.inlineCallbacks
+ def _bump_active_time(self, user):
+ try:
+ presence = self.hs.get_presence_handler()
+ yield presence.bump_presence_active_time(user)
+ except Exception:
+ logger.exception("Error bumping presence active time")
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index a5e501897c..7fe568132f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,38 +25,42 @@ The methods that define policy are:
from twisted.internet import defer, reactor
from contextlib import contextmanager
+from six import itervalues, iteritems
+
from synapse.api.errors import SynapseError
from synapse.api.constants import PresenceState
from synapse.storage.presence import UserPresenceState
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.async import Linearizer
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id
-import synapse.metrics
+from synapse.metrics import LaterGauge
import logging
+from prometheus_client import Counter
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-notified_presence_counter = metrics.register_counter("notified_presence")
-federation_presence_out_counter = metrics.register_counter("federation_presence_out")
-presence_updates_counter = metrics.register_counter("presence_updates")
-timers_fired_counter = metrics.register_counter("timers_fired")
-federation_presence_counter = metrics.register_counter("federation_presence")
-bump_active_time_counter = metrics.register_counter("bump_active_time")
+notified_presence_counter = Counter("synapse_handler_presence_notified_presence", "")
+federation_presence_out_counter = Counter(
+ "synapse_handler_presence_federation_presence_out", "")
+presence_updates_counter = Counter("synapse_handler_presence_presence_updates", "")
+timers_fired_counter = Counter("synapse_handler_presence_timers_fired", "")
+federation_presence_counter = Counter("synapse_handler_presence_federation_presence", "")
+bump_active_time_counter = Counter("synapse_handler_presence_bump_active_time", "")
-get_updates_counter = metrics.register_counter("get_updates", labels=["type"])
+get_updates_counter = Counter("synapse_handler_presence_get_updates", "", ["type"])
-notify_reason_counter = metrics.register_counter("notify_reason", labels=["reason"])
-state_transition_counter = metrics.register_counter(
- "state_transition", labels=["from", "to"]
+notify_reason_counter = Counter(
+ "synapse_handler_presence_notify_reason", "", ["reason"])
+state_transition_counter = Counter(
+ "synapse_handler_presence_state_transition", "", ["from", "to"]
)
@@ -87,6 +91,11 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
class PresenceHandler(object):
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
self.is_mine = hs.is_mine
self.is_mine_id = hs.is_mine_id
self.clock = hs.get_clock()
@@ -94,7 +103,6 @@ class PresenceHandler(object):
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
self.federation = hs.get_federation_sender()
-
self.state = hs.get_state_handler()
federation_registry = hs.get_federation_registry()
@@ -137,8 +145,9 @@ class PresenceHandler(object):
for state in active_presence
}
- metrics.register_callback(
- "user_to_current_state_size", lambda: len(self.user_to_current_state)
+ LaterGauge(
+ "synapse_handlers_presence_user_to_current_state_size", "", [],
+ lambda: len(self.user_to_current_state)
)
now = self.clock.time_msec()
@@ -208,7 +217,8 @@ class PresenceHandler(object):
60 * 1000,
)
- metrics.register_callback("wheel_timer_size", lambda: len(self.wheel_timer))
+ LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
+ lambda: len(self.wheel_timer))
@defer.inlineCallbacks
def _on_shutdown(self):
@@ -255,6 +265,14 @@ class PresenceHandler(object):
logger.info("Finished _persist_unpersisted_changes")
@defer.inlineCallbacks
+ def _update_states_and_catch_exception(self, new_states):
+ try:
+ res = yield self._update_states(new_states)
+ defer.returnValue(res)
+ except Exception:
+ logger.exception("Error updating presence")
+
+ @defer.inlineCallbacks
def _update_states(self, new_states):
"""Updates presence of users. Sets the appropriate timeouts. Pokes
the notifier and federation if and only if the changed presence state
@@ -303,11 +321,11 @@ class PresenceHandler(object):
# TODO: We should probably ensure there are no races hereafter
- presence_updates_counter.inc_by(len(new_states))
+ presence_updates_counter.inc(len(new_states))
if to_notify:
- notified_presence_counter.inc_by(len(to_notify))
- yield self._persist_and_notify(to_notify.values())
+ notified_presence_counter.inc(len(to_notify))
+ yield self._persist_and_notify(list(to_notify.values()))
self.unpersisted_users_changes |= set(s.user_id for s in new_states)
self.unpersisted_users_changes -= set(to_notify.keys())
@@ -317,7 +335,7 @@ class PresenceHandler(object):
if user_id not in to_notify
}
if to_federation_ping:
- federation_presence_out_counter.inc_by(len(to_federation_ping))
+ federation_presence_out_counter.inc(len(to_federation_ping))
self._push_to_remotes(to_federation_ping.values())
@@ -355,7 +373,7 @@ class PresenceHandler(object):
for user_id in users_to_check
]
- timers_fired_counter.inc_by(len(states))
+ timers_fired_counter.inc(len(states))
changes = handle_timeouts(
states,
@@ -364,7 +382,7 @@ class PresenceHandler(object):
now=now,
)
- preserve_fn(self._update_states)(changes)
+ run_in_background(self._update_states_and_catch_exception, changes)
except Exception:
logger.exception("Exception in _handle_timeouts loop")
@@ -422,20 +440,23 @@ class PresenceHandler(object):
@defer.inlineCallbacks
def _end():
- if affect_presence:
+ try:
self.user_to_num_current_syncs[user_id] -= 1
prev_state = yield self.current_state_for_user(user_id)
yield self._update_states([prev_state.copy_and_replace(
last_user_sync_ts=self.clock.time_msec(),
)])
+ except Exception:
+ logger.exception("Error updating presence after sync")
@contextmanager
def _user_syncing():
try:
yield
finally:
- preserve_fn(_end)()
+ if affect_presence:
+ run_in_background(_end)
defer.returnValue(_user_syncing())
@@ -453,61 +474,6 @@ class PresenceHandler(object):
return syncing_user_ids
@defer.inlineCallbacks
- def update_external_syncs(self, process_id, syncing_user_ids):
- """Update the syncing users for an external process
-
- Args:
- process_id(str): An identifier for the process the users are
- syncing against. This allows synapse to process updates
- as user start and stop syncing against a given process.
- syncing_user_ids(set(str)): The set of user_ids that are
- currently syncing on that server.
- """
-
- # Grab the previous list of user_ids that were syncing on that process
- prev_syncing_user_ids = (
- self.external_process_to_current_syncs.get(process_id, set())
- )
- # Grab the current presence state for both the users that are syncing
- # now and the users that were syncing before this update.
- prev_states = yield self.current_state_for_users(
- syncing_user_ids | prev_syncing_user_ids
- )
- updates = []
- time_now_ms = self.clock.time_msec()
-
- # For each new user that is syncing check if we need to mark them as
- # being online.
- for new_user_id in syncing_user_ids - prev_syncing_user_ids:
- prev_state = prev_states[new_user_id]
- if prev_state.state == PresenceState.OFFLINE:
- updates.append(prev_state.copy_and_replace(
- state=PresenceState.ONLINE,
- last_active_ts=time_now_ms,
- last_user_sync_ts=time_now_ms,
- ))
- else:
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- ))
-
- # For each user that is still syncing or stopped syncing update the
- # last sync time so that we will correctly apply the grace period when
- # they stop syncing.
- for old_user_id in prev_syncing_user_ids:
- prev_state = prev_states[old_user_id]
- updates.append(prev_state.copy_and_replace(
- last_user_sync_ts=time_now_ms,
- ))
-
- yield self._update_states(updates)
-
- # Update the last updated time for the process. We expire the entries
- # if we don't receive an update in the given timeframe.
- self.external_process_last_updated_ms[process_id] = self.clock.time_msec()
- self.external_process_to_current_syncs[process_id] = syncing_user_ids
-
- @defer.inlineCallbacks
def update_external_syncs_row(self, process_id, user_id, is_syncing, sync_time_msec):
"""Update the syncing users for an external process as a delta.
@@ -570,7 +536,7 @@ class PresenceHandler(object):
prev_state.copy_and_replace(
last_user_sync_ts=time_now_ms,
)
- for prev_state in prev_states.itervalues()
+ for prev_state in itervalues(prev_states)
])
self.external_process_last_updated_ms.pop(process_id, None)
@@ -593,14 +559,14 @@ class PresenceHandler(object):
for user_id in user_ids
}
- missing = [user_id for user_id, state in states.iteritems() if not state]
+ missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
# There are things not in our in memory cache. Lets pull them out of
# the database.
res = yield self.store.get_presence_for_users(missing)
states.update(res)
- missing = [user_id for user_id, state in states.iteritems() if not state]
+ missing = [user_id for user_id, state in iteritems(states) if not state]
if missing:
new = {
user_id: UserPresenceState.default(user_id)
@@ -696,7 +662,7 @@ class PresenceHandler(object):
updates.append(prev_state.copy_and_replace(**new_fields))
if updates:
- federation_presence_counter.inc_by(len(updates))
+ federation_presence_counter.inc(len(updates))
yield self._update_states(updates)
@defer.inlineCallbacks
@@ -721,7 +687,7 @@ class PresenceHandler(object):
"""
updates = yield self.current_state_for_users(target_user_ids)
- updates = updates.values()
+ updates = list(updates.values())
for user_id in set(target_user_ids) - set(u.user_id for u in updates):
updates.append(UserPresenceState.default(user_id))
@@ -787,11 +753,11 @@ class PresenceHandler(object):
self._push_to_remotes([state])
else:
user_ids = yield self.store.get_users_in_room(room_id)
- user_ids = filter(self.is_mine_id, user_ids)
+ user_ids = list(filter(self.is_mine_id, user_ids))
states = yield self.current_state_for_users(user_ids)
- self._push_to_remotes(states.values())
+ self._push_to_remotes(list(states.values()))
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
@@ -971,28 +937,28 @@ def should_notify(old_state, new_state):
return False
if old_state.status_msg != new_state.status_msg:
- notify_reason_counter.inc("status_msg_change")
+ notify_reason_counter.labels("status_msg_change").inc()
return True
if old_state.state != new_state.state:
- notify_reason_counter.inc("state_change")
- state_transition_counter.inc(old_state.state, new_state.state)
+ notify_reason_counter.labels("state_change").inc()
+ state_transition_counter.labels(old_state.state, new_state.state).inc()
return True
if old_state.state == PresenceState.ONLINE:
if new_state.currently_active != old_state.currently_active:
- notify_reason_counter.inc("current_active_change")
+ notify_reason_counter.labels("current_active_change").inc()
return True
if new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Only notify about last active bumps if we're not currently acive
if not new_state.currently_active:
- notify_reason_counter.inc("last_active_change_online")
+ notify_reason_counter.labels("last_active_change_online").inc()
return True
elif new_state.last_active_ts - old_state.last_active_ts > LAST_ACTIVE_GRANULARITY:
# Always notify for a transition where last active gets bumped.
- notify_reason_counter.inc("last_active_change_not_online")
+ notify_reason_counter.labels("last_active_change_not_online").inc()
return True
return False
@@ -1066,14 +1032,14 @@ class PresenceEventSource(object):
if changed is not None and len(changed) < 500:
# For small deltas, its quicker to get all changes and then
# work out if we share a room or they're in our presence list
- get_updates_counter.inc("stream")
+ get_updates_counter.labels("stream").inc()
for other_user_id in changed:
if other_user_id in users_interested_in:
user_ids_changed.add(other_user_id)
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
- get_updates_counter.inc("full")
+ get_updates_counter.labels("full").inc()
if from_key:
user_ids_changed = stream_change_cache.get_entities_changed(
@@ -1085,10 +1051,10 @@ class PresenceEventSource(object):
updates = yield presence.current_state_for_users(user_ids_changed)
if include_offline:
- defer.returnValue((updates.values(), max_token))
+ defer.returnValue((list(updates.values()), max_token))
else:
defer.returnValue(([
- s for s in updates.itervalues()
+ s for s in itervalues(updates)
if s.state != PresenceState.OFFLINE
], max_token))
@@ -1146,7 +1112,7 @@ def handle_timeouts(user_states, is_mine_fn, syncing_user_ids, now):
if new_state:
changes[state.user_id] = new_state
- return changes.values()
+ return list(changes.values())
def handle_timeout(state, is_mine, syncing_user_ids, now):
@@ -1345,11 +1311,11 @@ def get_interested_remotes(store, states, state_handler):
# hosts in those rooms.
room_ids_to_states, users_to_states = yield get_interested_parties(store, states)
- for room_id, states in room_ids_to_states.iteritems():
+ for room_id, states in iteritems(room_ids_to_states):
hosts = yield state_handler.get_current_hosts_in_room(room_id)
hosts_and_states.append((hosts, states))
- for user_id, states in users_to_states.iteritems():
+ for user_id, states in iteritems(users_to_states):
host = get_domain_from_id(user_id)
hosts_and_states.append(([host], states))
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 3f215c2b4e..2e0672161c 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -135,37 +135,40 @@ class ReceiptsHandler(BaseHandler):
"""Given a list of receipts, works out which remote servers should be
poked and pokes them.
"""
- # TODO: Some of this stuff should be coallesced.
- for receipt in receipts:
- room_id = receipt["room_id"]
- receipt_type = receipt["receipt_type"]
- user_id = receipt["user_id"]
- event_ids = receipt["event_ids"]
- data = receipt["data"]
-
- users = yield self.state.get_current_user_in_room(room_id)
- remotedomains = set(get_domain_from_id(u) for u in users)
- remotedomains = remotedomains.copy()
- remotedomains.discard(self.server_name)
-
- logger.debug("Sending receipt to: %r", remotedomains)
-
- for domain in remotedomains:
- self.federation.send_edu(
- destination=domain,
- edu_type="m.receipt",
- content={
- room_id: {
- receipt_type: {
- user_id: {
- "event_ids": event_ids,
- "data": data,
+ try:
+ # TODO: Some of this stuff should be coallesced.
+ for receipt in receipts:
+ room_id = receipt["room_id"]
+ receipt_type = receipt["receipt_type"]
+ user_id = receipt["user_id"]
+ event_ids = receipt["event_ids"]
+ data = receipt["data"]
+
+ users = yield self.state.get_current_user_in_room(room_id)
+ remotedomains = set(get_domain_from_id(u) for u in users)
+ remotedomains = remotedomains.copy()
+ remotedomains.discard(self.server_name)
+
+ logger.debug("Sending receipt to: %r", remotedomains)
+
+ for domain in remotedomains:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.receipt",
+ content={
+ room_id: {
+ receipt_type: {
+ user_id: {
+ "event_ids": event_ids,
+ "data": data,
+ }
}
- }
+ },
},
- },
- key=(room_id, receipt_type, user_id),
- )
+ key=(room_id, receipt_type, user_id),
+ )
+ except Exception:
+ logger.exception("Error pushing receipts to remote servers")
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 523cb67536..ad9bb0a3cd 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -34,6 +34,11 @@ logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
super(RegistrationHandler, self).__init__(hs)
self.auth = hs.get_auth()
@@ -49,6 +54,7 @@ class RegistrationHandler(BaseHandler):
self._generate_user_id_linearizer = Linearizer(
name="_generate_user_id_linearizer",
)
+ self._server_notices_mxid = hs.config.server_notices_mxid
@defer.inlineCallbacks
def check_username(self, localpart, guest_access_token=None,
@@ -360,6 +366,14 @@ class RegistrationHandler(BaseHandler):
yield identity_handler.bind_threepid(c, user_id)
def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+ # don't allow people to register the server notices mxid
+ if self._server_notices_mxid is not None:
+ if user_id == self._server_notices_mxid:
+ raise SynapseError(
+ 400, "This user ID is reserved.",
+ errcode=Codes.EXCLUSIVE
+ )
+
# valid user IDs must not clash with any user ID namespaces claimed by
# application services.
services = self.store.get_app_services()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index bcb54fe41a..f80311c11f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -70,14 +70,27 @@ class RoomCreationHandler(BaseHandler):
self.event_creation_handler = hs.get_event_creation_handler()
@defer.inlineCallbacks
- def create_room(self, requester, config, ratelimit=True):
+ def create_room(self, requester, config, ratelimit=True,
+ creator_join_profile=None):
""" Creates a new room.
Args:
- requester (Requester): The user who requested the room creation.
+ requester (synapse.types.Requester):
+ The user who requested the room creation.
config (dict) : A dict of configuration options.
+ ratelimit (bool): set to False to disable the rate limiter
+
+ creator_join_profile (dict|None):
+ Set to override the displayname and avatar for the creating
+ user in this room. If unset, displayname and avatar will be
+ derived from the user's profile. If set, should contain the
+ values to go in the body of the 'join' event (typically
+ `avatar_url` and/or `displayname`.
+
Returns:
- The new room ID.
+ Deferred[dict]:
+ a dict containing the keys `room_id` and, if an alias was
+ requested, `room_alias`.
Raises:
SynapseError if the room ID couldn't be stored, or something went
horribly wrong.
@@ -115,6 +128,10 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
+ yield self.event_creation_handler.assert_accepted_privacy_policy(
+ requester,
+ )
+
invite_3pid_list = config.get("invite_3pid", [])
visibility = config.get("visibility", None)
@@ -178,7 +195,8 @@ class RoomCreationHandler(BaseHandler):
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
- power_level_content_override=config.get("power_level_content_override", {})
+ power_level_content_override=config.get("power_level_content_override", {}),
+ creator_join_profile=creator_join_profile,
)
if "name" in config:
@@ -258,6 +276,7 @@ class RoomCreationHandler(BaseHandler):
creation_content,
room_alias,
power_level_content_override,
+ creator_join_profile,
):
def create(etype, content, **kwargs):
e = {
@@ -301,6 +320,7 @@ class RoomCreationHandler(BaseHandler):
room_id,
"join",
ratelimit=False,
+ content=creator_join_profile,
)
# We treat the power levels override specially as this needs to be one
@@ -446,7 +466,7 @@ class RoomContextHandler(BaseHandler):
state = yield self.store.get_state_for_events(
[last_event_id], None
)
- results["state"] = state[last_event_id].values()
+ results["state"] = list(state[last_event_id].values())
results["start"] = now_token.copy_and_replace(
"room_key", results["start"]
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index add3f9b009..fc507cef36 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,6 +15,9 @@
from twisted.internet import defer
+from six import iteritems
+from six.moves import range
+
from ._base import BaseHandler
from synapse.api.constants import (
@@ -200,7 +203,7 @@ class RoomListHandler(BaseHandler):
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
chunk = []
- for i in xrange(0, len(rooms_to_scan), step):
+ for i in range(0, len(rooms_to_scan), step):
batch = rooms_to_scan[i:i + step]
logger.info("Processing %i rooms for result", len(batch))
yield concurrently_execute(
@@ -305,7 +308,7 @@ class RoomListHandler(BaseHandler):
)
event_map = yield self.store.get_events([
- event_id for key, event_id in current_state_ids.iteritems()
+ event_id for key, event_id in iteritems(current_state_ids)
if key[0] in (
EventTypes.JoinRules,
EventTypes.Name,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 714583f1d5..f930e939e8 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,11 +17,14 @@
import abc
import logging
+from six.moves import http_client
+
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
from twisted.internet import defer
from unpaddedbase64 import decode_base64
+import synapse.server
import synapse.types
from synapse.api.constants import (
EventTypes, Membership,
@@ -46,6 +49,11 @@ class RoomMemberHandler(object):
__metaclass__ = abc.ABCMeta
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -63,6 +71,7 @@ class RoomMemberHandler(object):
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
+ self._server_notices_mxid = self.config.server_notices_mxid
@abc.abstractmethod
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
@@ -290,11 +299,26 @@ class RoomMemberHandler(object):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- if effective_membership_state == "invite":
+ if effective_membership_state == Membership.INVITE:
+ # block any attempts to invite the server notices mxid
+ if target.to_string() == self._server_notices_mxid:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "Cannot invite this user",
+ )
+
block_invite = False
- is_requester_admin = yield self.auth.is_server_admin(
- requester.user,
- )
+
+ if (self._server_notices_mxid is not None and
+ requester.user.to_string() == self._server_notices_mxid):
+ # allow the server notices mxid to send invites
+ is_requester_admin = True
+
+ else:
+ is_requester_admin = yield self.auth.is_server_admin(
+ requester.user,
+ )
+
if not is_requester_admin:
if self.config.block_non_admin_invites:
logger.info(
@@ -349,6 +373,20 @@ class RoomMemberHandler(object):
if same_sender and same_membership and same_content:
defer.returnValue(old_state)
+ # we don't allow people to reject invites to the server notice
+ # room, but they can leave it once they are joined.
+ if (
+ old_membership == Membership.INVITE and
+ effective_membership_state == Membership.LEAVE
+ ):
+ is_blocked = yield self._is_server_notice_room(room_id)
+ if is_blocked:
+ raise SynapseError(
+ http_client.FORBIDDEN,
+ "You cannot reject this invite",
+ errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
+ )
+
is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
@@ -844,6 +882,13 @@ class RoomMemberHandler(object):
defer.returnValue(False)
+ @defer.inlineCallbacks
+ def _is_server_notice_room(self, room_id):
+ if self._server_notices_mxid is None:
+ defer.returnValue(False)
+ user_ids = yield self.store.get_users_in_room(room_id)
+ defer.returnValue(self._server_notices_mxid in user_ids)
+
class RoomMemberMasterHandler(RoomMemberHandler):
def __init__(self, hs):
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 9772ed1a0e..1eca26aa1e 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -348,7 +348,7 @@ class SearchHandler(BaseHandler):
rooms = set(e.room_id for e in allowed_events)
for room_id in rooms:
state = yield self.state_handler.get_current_state(room_id)
- state_results[room_id] = state.values()
+ state_results[room_id] = list(state.values())
state_results.values()
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b52e4c2aff..51ec727df0 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,6 +28,8 @@ import collections
import logging
import itertools
+from six import itervalues, iteritems
+
logger = logging.getLogger(__name__)
@@ -275,7 +277,7 @@ class SyncHandler(object):
# result returned by the event source is poor form (it might cache
# the object)
room_id = event["room_id"]
- event_copy = {k: v for (k, v) in event.iteritems()
+ event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@@ -294,7 +296,7 @@ class SyncHandler(object):
for event in receipts:
room_id = event["room_id"]
# exclude room id, as above
- event_copy = {k: v for (k, v) in event.iteritems()
+ event_copy = {k: v for (k, v) in iteritems(event)
if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
@@ -325,7 +327,7 @@ class SyncHandler(object):
current_state_ids = frozenset()
if any(e.is_state() for e in recents):
current_state_ids = yield self.state.get_current_state_ids(room_id)
- current_state_ids = frozenset(current_state_ids.itervalues())
+ current_state_ids = frozenset(itervalues(current_state_ids))
recents = yield filter_events_for_client(
self.store,
@@ -354,12 +356,24 @@ class SyncHandler(object):
since_key = since_token.room_key
while limited and len(recents) < timeline_limit and max_repeat:
- events, end_key = yield self.store.get_room_events_stream_for_room(
- room_id,
- limit=load_limit + 1,
- from_key=since_key,
- to_key=end_key,
- )
+ # If we have a since_key then we are trying to get any events
+ # that have happened since `since_key` up to `end_key`, so we
+ # can just use `get_room_events_stream_for_room`.
+ # Otherwise, we want to return the last N events in the room
+ # in toplogical ordering.
+ if since_key:
+ events, end_key = yield self.store.get_room_events_stream_for_room(
+ room_id,
+ limit=load_limit + 1,
+ from_key=since_key,
+ to_key=end_key,
+ )
+ else:
+ events, end_key = yield self.store.get_recent_events_for_room(
+ room_id,
+ limit=load_limit + 1,
+ end_token=end_key,
+ )
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
)
@@ -370,7 +384,7 @@ class SyncHandler(object):
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
current_state_ids = yield self.state.get_current_state_ids(room_id)
- current_state_ids = frozenset(current_state_ids.itervalues())
+ current_state_ids = frozenset(itervalues(current_state_ids))
loaded_recents = yield filter_events_for_client(
self.store,
@@ -429,7 +443,11 @@ class SyncHandler(object):
Returns:
A Deferred map from ((type, state_key)->Event)
"""
- last_events, token = yield self.store.get_recent_events_for_room(
+ # FIXME this claims to get the state at a stream position, but
+ # get_recent_events_for_room operates by topo ordering. This therefore
+ # does not reliably give you the state at the given stream position.
+ # (https://github.com/matrix-org/synapse/issues/3305)
+ last_events, _ = yield self.store.get_recent_events_for_room(
room_id, end_token=stream_position.room_key, limit=1,
)
@@ -523,11 +541,11 @@ class SyncHandler(object):
state = {}
if state_ids:
- state = yield self.store.get_events(state_ids.values())
+ state = yield self.store.get_events(list(state_ids.values()))
defer.returnValue({
(e.type, e.state_key): e
- for e in sync_config.filter_collection.filter_room_state(state.values())
+ for e in sync_config.filter_collection.filter_room_state(list(state.values()))
})
@defer.inlineCallbacks
@@ -876,7 +894,7 @@ class SyncHandler(object):
presence.extend(states)
# Deduplicate the presence entries so that there's at most one per user
- presence = {p.user_id: p for p in presence}.values()
+ presence = list({p.user_id: p for p in presence}.values())
presence = sync_config.filter_collection.filter_presence(
presence
@@ -972,7 +990,7 @@ class SyncHandler(object):
if since_token:
for joined_sync in sync_result_builder.joined:
it = itertools.chain(
- joined_sync.timeline.events, joined_sync.state.itervalues()
+ joined_sync.timeline.events, itervalues(joined_sync.state)
)
for event in it:
if event.type == EventTypes.Member:
@@ -1028,7 +1046,13 @@ class SyncHandler(object):
Returns:
Deferred(tuple): Returns a tuple of the form:
- `([RoomSyncResultBuilder], [InvitedSyncResult], newly_joined_rooms)`
+ `(room_entries, invited_rooms, newly_joined_rooms, newly_left_rooms)`
+
+ where:
+ room_entries is a list [RoomSyncResultBuilder]
+ invited_rooms is a list [InvitedSyncResult]
+ newly_joined rooms is a list[str] of room ids
+ newly_left_rooms is a list[str] of room ids
"""
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
@@ -1050,7 +1074,7 @@ class SyncHandler(object):
newly_left_rooms = []
room_entries = []
invited = []
- for room_id, events in mem_change_events_by_room_id.iteritems():
+ for room_id, events in iteritems(mem_change_events_by_room_id):
non_joins = [e for e in events if e.membership != Membership.JOIN]
has_join = len(non_joins) != len(events)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 77c0cf146f..5d9736e88f 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id
@@ -97,7 +97,8 @@ class TypingHandler(object):
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
- preserve_fn(self._push_remote)(
+ run_in_background(
+ self._push_remote,
member=member,
typing=True
)
@@ -196,7 +197,7 @@ class TypingHandler(object):
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
- preserve_fn(self._push_remote)(member, typing)
+ run_in_background(self._push_remote, member, typing)
self._push_update_local(
member=member,
@@ -205,28 +206,31 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _push_remote(self, member, typing):
- users = yield self.state.get_current_user_in_room(member.room_id)
- self._member_last_federation_poke[member] = self.clock.time_msec()
+ try:
+ users = yield self.state.get_current_user_in_room(member.room_id)
+ self._member_last_federation_poke[member] = self.clock.time_msec()
- now = self.clock.time_msec()
- self.wheel_timer.insert(
- now=now,
- obj=member,
- then=now + FEDERATION_PING_INTERVAL,
- )
+ now = self.clock.time_msec()
+ self.wheel_timer.insert(
+ now=now,
+ obj=member,
+ then=now + FEDERATION_PING_INTERVAL,
+ )
- for domain in set(get_domain_from_id(u) for u in users):
- if domain != self.server_name:
- self.federation.send_edu(
- destination=domain,
- edu_type="m.typing",
- content={
- "room_id": member.room_id,
- "user_id": member.user_id,
- "typing": typing,
- },
- key=member,
- )
+ for domain in set(get_domain_from_id(u) for u in users):
+ if domain != self.server_name:
+ self.federation.send_edu(
+ destination=domain,
+ edu_type="m.typing",
+ content={
+ "room_id": member.room_id,
+ "user_id": member.user_id,
+ "typing": typing,
+ },
+ key=member,
+ )
+ except Exception:
+ logger.exception("Error pushing typing notif to remotes")
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 714f0195c8..a39f0f7343 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -22,6 +22,7 @@ from synapse.util.metrics import Measure
from synapse.util.async import sleep
from synapse.types import get_localpart_from_id
+from six import iteritems
logger = logging.getLogger(__name__)
@@ -123,6 +124,13 @@ class UserDirectoryHandler(object):
)
@defer.inlineCallbacks
+ def handle_user_deactivated(self, user_id):
+ """Called when a user ID is deactivated
+ """
+ yield self.store.remove_from_user_dir(user_id)
+ yield self.store.remove_from_user_in_public_room(user_id)
+
+ @defer.inlineCallbacks
def _unsafe_process(self):
# If self.pos is None then means we haven't fetched it from DB
if self.pos is None:
@@ -403,7 +411,7 @@ class UserDirectoryHandler(object):
if change:
users_with_profile = yield self.state.get_current_user_in_room(room_id)
- for user_id, profile in users_with_profile.iteritems():
+ for user_id, profile in iteritems(users_with_profile):
yield self._handle_new_user(room_id, user_id, profile)
else:
users = yield self.store.get_users_in_public_due_to_room(room_id)
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index bfebb0f644..054372e179 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
+
+from synapse.api.errors import SynapseError
+
+
+class RequestTimedOutError(SynapseError):
+ """Exception representing timeout of an outbound request"""
+ def __init__(self):
+ super(RequestTimedOutError, self).__init__(504, "Timed out")
+
+
+def cancelled_to_request_timed_out_error(value, timeout):
+ """Turns CancelledErrors into RequestTimedOutErrors.
+
+ For use with async.add_timeout_to_deferred
+ """
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise RequestTimedOutError()
+ return value
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 343e932cb1..a797396ade 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import wrap_request_handler
+from synapse.http.server import wrap_json_request_handler
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -42,14 +42,13 @@ class AdditionalResource(Resource):
Resource.__init__(self)
self._handler = handler
- # these are required by the request_handler wrapper
- self.version_string = hs.version_string
+ # required by the request_handler wrapper
self.clock = hs.get_clock()
def render(self, request):
self._async_render(request)
return NOT_DONE_YET
- @wrap_request_handler
+ @wrap_json_request_handler
def _async_render(self, request):
return self._handler(request)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index f3e4973c2e..4d4eee3d64 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,10 +19,10 @@ from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
+from synapse.http import cancelled_to_request_timed_out_error
+from synapse.util.async import add_timeout_to_deferred
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
-from synapse.util import logcontext
-import synapse.metrics
from synapse.http.endpoint import SpiderEndpoint
from canonicaljson import encode_canonical_json
@@ -38,8 +39,9 @@ from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
-from StringIO import StringIO
+from six import StringIO
+from prometheus_client import Counter
import simplejson as json
import logging
import urllib
@@ -47,16 +49,9 @@ import urllib
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-outgoing_requests_counter = metrics.register_counter(
- "requests",
- labels=["method"],
-)
-incoming_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
-)
+outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
+incoming_responses_counter = Counter("synapse_http_client_responses", "",
+ ["method", "code"])
class SimpleHttpClient(object):
@@ -93,32 +88,28 @@ class SimpleHttpClient(object):
def request(self, method, uri, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
- outgoing_requests_counter.inc(method)
+ outgoing_requests_counter.labels(method).inc()
- def send_request():
+ logger.info("Sending request %s %s", method, uri)
+
+ try:
request_deferred = self.agent.request(
method, uri, *args, **kwargs
)
-
- return self.clock.time_bound_deferred(
+ add_timeout_to_deferred(
request_deferred,
- time_out=60,
+ 60, cancelled_to_request_timed_out_error,
)
+ response = yield make_deferred_yieldable(request_deferred)
- logger.info("Sending request %s %s", method, uri)
-
- try:
- with logcontext.PreserveLoggingContext():
- response = yield send_request()
-
- incoming_responses_counter.inc(method, response.code)
+ incoming_responses_counter.labels(method, response.code).inc()
logger.info(
"Received response to %s %s: %s",
method, uri, response.code
)
defer.returnValue(response)
except Exception as e:
- incoming_responses_counter.inc(method, "ERR")
+ incoming_responses_counter.labels(method, "ERR").inc()
logger.info(
"Error sending request to %s %s: %s %s",
method, uri, type(e).__name__, e.message
@@ -509,7 +500,7 @@ class SpiderHttpClient(SimpleHttpClient):
reactor,
SpiderEndpointFactory(hs)
)
- ), [('gzip', GzipDecoder)]
+ ), [(b'gzip', GzipDecoder)]
)
# We could look like Chrome:
# self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index 00572c2897..87a482650d 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -115,10 +115,15 @@ class _WrappedConnection(object):
if time.time() - self.last_request >= 2.5 * 60:
self.abort()
# Abort the underlying TLS connection. The abort() method calls
- # loseConnection() on the underlying TLS connection which tries to
+ # loseConnection() on the TLS connection which tries to
# shutdown the connection cleanly. We call abortConnection()
- # since that will promptly close the underlying TCP connection.
- self.transport.abortConnection()
+ # since that will promptly close the TLS connection.
+ #
+ # In Twisted >18.4; the TLS connection will be None if it has closed
+ # which will make abortConnection() throw. Check that the TLS connection
+ # is not None before trying to close it.
+ if self.transport.getHandle() is not None:
+ self.transport.abortConnection()
def request(self, request):
self.last_request = time.time()
@@ -286,7 +291,7 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
if (len(answers) == 1
and answers[0].type == dns.SRV
and answers[0].payload
- and answers[0].payload.target == dns.Name('.')):
+ and answers[0].payload.target == dns.Name(b'.')):
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 60a29081e8..821aed362b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,17 +13,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.util.retryutils
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
+from synapse.http import cancelled_to_request_timed_out_error
from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.async import sleep
-from synapse.util import logcontext
import synapse.metrics
+from synapse.util.async import sleep, add_timeout_to_deferred
+from synapse.util import logcontext
+from synapse.util.logcontext import make_deferred_yieldable
+import synapse.util.retryutils
from canonicaljson import encode_canonical_json
@@ -38,22 +41,19 @@ import logging
import random
import sys
import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
+from six import string_types
+
+from prometheus_client import Counter
logger = logging.getLogger(__name__)
outbound_logger = logging.getLogger("synapse.http.outbound")
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-outgoing_requests_counter = metrics.register_counter(
- "requests",
- labels=["method"],
-)
-incoming_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
-)
+outgoing_requests_counter = Counter("synapse_http_matrixfederationclient_requests",
+ "", ["method"])
+incoming_responses_counter = Counter("synapse_http_matrixfederationclient_responses",
+ "", ["method", "code"])
MAX_LONG_RETRIES = 10
@@ -184,21 +184,20 @@ class MatrixFederationHttpClient(object):
producer = body_callback(method, http_url_bytes, headers_dict)
try:
- def send_request():
- request_deferred = self.agent.request(
- method,
- url_bytes,
- Headers(headers_dict),
- producer
- )
-
- return self.clock.time_bound_deferred(
- request_deferred,
- time_out=timeout / 1000. if timeout else 60,
- )
-
- with logcontext.PreserveLoggingContext():
- response = yield send_request()
+ request_deferred = self.agent.request(
+ method,
+ url_bytes,
+ Headers(headers_dict),
+ producer
+ )
+ add_timeout_to_deferred(
+ request_deferred,
+ timeout / 1000. if timeout else 60,
+ cancelled_to_request_timed_out_error,
+ )
+ response = yield make_deferred_yieldable(
+ request_deferred,
+ )
log_result = "%d %s" % (response.code, response.phrase,)
break
@@ -552,7 +551,7 @@ class MatrixFederationHttpClient(object):
encoded_args = {}
for k, vs in args.items():
- if isinstance(vs, basestring):
+ if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
@@ -667,7 +666,7 @@ def check_content_type_is_json(headers):
RuntimeError if the
"""
- c_type = headers.getRawHeaders("Content-Type")
+ c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
raise RuntimeError(
"No Content-Type header"
@@ -684,7 +683,7 @@ def check_content_type_is_json(headers):
def encode_query_args(args):
encoded_args = {}
for k, vs in args.items():
- if isinstance(vs, basestring):
+ if isinstance(vs, string_types):
vs = [vs]
encoded_args[k] = [v.encode("UTF-8") for v in vs]
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
new file mode 100644
index 0000000000..dc06f6c443
--- /dev/null
+++ b/synapse/http/request_metrics.py
@@ -0,0 +1,273 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from prometheus_client.core import Counter, Histogram
+from synapse.metrics import LaterGauge
+
+from synapse.util.logcontext import LoggingContext
+
+logger = logging.getLogger(__name__)
+
+
+# total number of responses served, split by method/servlet/tag
+response_count = Counter(
+ "synapse_http_server_response_count", "", ["method", "servlet", "tag"]
+)
+
+requests_counter = Counter(
+ "synapse_http_server_requests_received", "", ["method", "servlet"]
+)
+
+outgoing_responses_counter = Counter(
+ "synapse_http_server_responses", "", ["method", "code"]
+)
+
+response_timer = Histogram(
+ "synapse_http_server_response_time_seconds", "sec", ["method", "servlet", "tag"]
+)
+
+response_ru_utime = Counter(
+ "synapse_http_server_response_ru_utime_seconds", "sec", ["method", "servlet", "tag"]
+)
+
+response_ru_stime = Counter(
+ "synapse_http_server_response_ru_stime_seconds", "sec", ["method", "servlet", "tag"]
+)
+
+response_db_txn_count = Counter(
+ "synapse_http_server_response_db_txn_count", "", ["method", "servlet", "tag"]
+)
+
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+response_db_txn_duration = Counter(
+ "synapse_http_server_response_db_txn_duration_seconds",
+ "",
+ ["method", "servlet", "tag"],
+)
+
+# seconds spent waiting for a db connection, when processing this request
+response_db_sched_duration = Counter(
+ "synapse_http_server_response_db_sched_duration_seconds",
+ "",
+ ["method", "servlet", "tag"],
+)
+
+# size in bytes of the response written
+response_size = Counter(
+ "synapse_http_server_response_size", "", ["method", "servlet", "tag"]
+)
+
+# In flight metrics are incremented while the requests are in flight, rather
+# than when the response was written.
+
+in_flight_requests_ru_utime = Counter(
+ "synapse_http_server_in_flight_requests_ru_utime_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+in_flight_requests_ru_stime = Counter(
+ "synapse_http_server_in_flight_requests_ru_stime_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+in_flight_requests_db_txn_count = Counter(
+ "synapse_http_server_in_flight_requests_db_txn_count", "", ["method", "servlet"]
+)
+
+# seconds spent waiting for db txns, excluding scheduling time, when processing
+# this request
+in_flight_requests_db_txn_duration = Counter(
+ "synapse_http_server_in_flight_requests_db_txn_duration_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+# seconds spent waiting for a db connection, when processing this request
+in_flight_requests_db_sched_duration = Counter(
+ "synapse_http_server_in_flight_requests_db_sched_duration_seconds",
+ "",
+ ["method", "servlet"],
+)
+
+# The set of all in flight requests, set[RequestMetrics]
+_in_flight_requests = set()
+
+
+def _get_in_flight_counts():
+ """Returns a count of all in flight requests by (method, server_name)
+
+ Returns:
+ dict[tuple[str, str], int]
+ """
+ for rm in _in_flight_requests:
+ rm.update_metrics()
+
+ # Map from (method, name) -> int, the number of in flight requests of that
+ # type
+ counts = {}
+ for rm in _in_flight_requests:
+ key = (rm.method, rm.name,)
+ counts[key] = counts.get(key, 0) + 1
+
+ return counts
+
+
+LaterGauge(
+ "synapse_http_request_metrics_in_flight_requests_count",
+ "",
+ ["method", "servlet"],
+ _get_in_flight_counts,
+)
+
+
+class RequestMetrics(object):
+ def start(self, time_sec, name, method):
+ self.start = time_sec
+ self.start_context = LoggingContext.current_context()
+ self.name = name
+ self.method = method
+
+ self._request_stats = _RequestStats.from_context(self.start_context)
+
+ _in_flight_requests.add(self)
+
+ def stop(self, time_sec, request):
+ _in_flight_requests.discard(self)
+
+ context = LoggingContext.current_context()
+
+ tag = ""
+ if context:
+ tag = context.tag
+
+ if context != self.start_context:
+ logger.warn(
+ "Context have unexpectedly changed %r, %r",
+ context, self.start_context
+ )
+ return
+
+ outgoing_responses_counter.labels(request.method, str(request.code)).inc()
+
+ response_count.labels(request.method, self.name, tag).inc()
+
+ response_timer.labels(request.method, self.name, tag).observe(
+ time_sec - self.start
+ )
+
+ ru_utime, ru_stime = context.get_resource_usage()
+
+ response_ru_utime.labels(request.method, self.name, tag).inc(ru_utime)
+ response_ru_stime.labels(request.method, self.name, tag).inc(ru_stime)
+ response_db_txn_count.labels(request.method, self.name, tag).inc(
+ context.db_txn_count
+ )
+ response_db_txn_duration.labels(request.method, self.name, tag).inc(
+ context.db_txn_duration_sec
+ )
+ response_db_sched_duration.labels(request.method, self.name, tag).inc(
+ context.db_sched_duration_sec
+ )
+
+ response_size.labels(request.method, self.name, tag).inc(request.sentLength)
+
+ # We always call this at the end to ensure that we update the metrics
+ # regardless of whether a call to /metrics while the request was in
+ # flight.
+ self.update_metrics()
+
+ def update_metrics(self):
+ """Updates the in flight metrics with values from this request.
+ """
+ diff = self._request_stats.update(self.start_context)
+
+ in_flight_requests_ru_utime.labels(self.method, self.name).inc(diff.ru_utime)
+ in_flight_requests_ru_stime.labels(self.method, self.name).inc(diff.ru_stime)
+
+ in_flight_requests_db_txn_count.labels(self.method, self.name).inc(
+ diff.db_txn_count
+ )
+
+ in_flight_requests_db_txn_duration.labels(self.method, self.name).inc(
+ diff.db_txn_duration_sec
+ )
+
+ in_flight_requests_db_sched_duration.labels(self.method, self.name).inc(
+ diff.db_sched_duration_sec
+ )
+
+
+class _RequestStats(object):
+ """Keeps tracks of various metrics for an in flight request.
+ """
+
+ __slots__ = [
+ "ru_utime",
+ "ru_stime",
+ "db_txn_count",
+ "db_txn_duration_sec",
+ "db_sched_duration_sec",
+ ]
+
+ def __init__(
+ self, ru_utime, ru_stime, db_txn_count, db_txn_duration_sec, db_sched_duration_sec
+ ):
+ self.ru_utime = ru_utime
+ self.ru_stime = ru_stime
+ self.db_txn_count = db_txn_count
+ self.db_txn_duration_sec = db_txn_duration_sec
+ self.db_sched_duration_sec = db_sched_duration_sec
+
+ @staticmethod
+ def from_context(context):
+ ru_utime, ru_stime = context.get_resource_usage()
+
+ return _RequestStats(
+ ru_utime, ru_stime,
+ context.db_txn_count,
+ context.db_txn_duration_sec,
+ context.db_sched_duration_sec,
+ )
+
+ def update(self, context):
+ """Updates the current values and returns the difference between the
+ old and new values.
+
+ Returns:
+ _RequestStats: The difference between the old and new values
+ """
+ new = _RequestStats.from_context(context)
+
+ diff = _RequestStats(
+ new.ru_utime - self.ru_utime,
+ new.ru_stime - self.ru_stime,
+ new.db_txn_count - self.db_txn_count,
+ new.db_txn_duration_sec - self.db_txn_duration_sec,
+ new.db_sched_duration_sec - self.db_sched_duration_sec,
+ )
+
+ self.ru_utime = new.ru_utime
+ self.ru_stime = new.ru_stime
+ self.db_txn_count = new.db_txn_count
+ self.db_txn_duration_sec = new.db_txn_duration_sec
+ self.db_sched_duration_sec = new.db_sched_duration_sec
+
+ return diff
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 8d632290de..bc09b8b2be 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -13,11 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import cgi
+from six.moves import http_client
from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes
)
+from synapse.http.request_metrics import (
+ requests_counter,
+)
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches import intern_dict
from synapse.util.metrics import Measure
@@ -41,178 +45,174 @@ import simplejson
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-# total number of responses served, split by method/servlet/tag
-response_count = metrics.register_counter(
- "response_count",
- labels=["method", "servlet", "tag"],
- alternative_names=(
- # the following are all deprecated aliases for the same metric
- metrics.name_prefix + x for x in (
- "_requests",
- "_response_time:count",
- "_response_ru_utime:count",
- "_response_ru_stime:count",
- "_response_db_txn_count:count",
- "_response_db_txn_duration:count",
- )
- )
-)
+HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
+<html lang=en>
+ <head>
+ <meta charset="utf-8">
+ <title>Error {code}</title>
+ </head>
+ <body>
+ <p>{msg}</p>
+ </body>
+</html>
+"""
-requests_counter = metrics.register_counter(
- "requests_received",
- labels=["method", "servlet", ],
-)
-outgoing_responses_counter = metrics.register_counter(
- "responses",
- labels=["method", "code"],
-)
+def wrap_json_request_handler(h):
+ """Wraps a request handler method with exception handling.
-response_timer = metrics.register_counter(
- "response_time_seconds",
- labels=["method", "servlet", "tag"],
- alternative_names=(
- metrics.name_prefix + "_response_time:total",
- ),
-)
+ Also adds logging as per wrap_request_handler_with_logging.
-response_ru_utime = metrics.register_counter(
- "response_ru_utime_seconds", labels=["method", "servlet", "tag"],
- alternative_names=(
- metrics.name_prefix + "_response_ru_utime:total",
- ),
-)
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
-response_ru_stime = metrics.register_counter(
- "response_ru_stime_seconds", labels=["method", "servlet", "tag"],
- alternative_names=(
- metrics.name_prefix + "_response_ru_stime:total",
- ),
-)
-
-response_db_txn_count = metrics.register_counter(
- "response_db_txn_count", labels=["method", "servlet", "tag"],
- alternative_names=(
- metrics.name_prefix + "_response_db_txn_count:total",
- ),
-)
+ The handler must return a deferred. If the deferred succeeds we assume that
+ a response has been sent. If the deferred fails with a SynapseError we use
+ it to send a JSON response with the appropriate HTTP reponse code. If the
+ deferred fails with any other type of error we send a 500 reponse.
+ """
-# seconds spent waiting for db txns, excluding scheduling time, when processing
-# this request
-response_db_txn_duration = metrics.register_counter(
- "response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
- alternative_names=(
- metrics.name_prefix + "_response_db_txn_duration:total",
- ),
-)
+ @defer.inlineCallbacks
+ def wrapped_request_handler(self, request):
+ try:
+ yield h(self, request)
+ except CodeMessageException as e:
+ code = e.code
+ if isinstance(e, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, e.msg
+ )
+ else:
+ logger.exception(e)
+ respond_with_json(
+ request, code, cs_exception(e), send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
-# seconds spent waiting for a db connection, when processing this request
-response_db_sched_duration = metrics.register_counter(
- "response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
-)
+ except Exception:
+ # failure.Failure() fishes the original Failure out
+ # of our stack, and thus gives us a sensible stack
+ # trace.
+ f = failure.Failure()
+ logger.error(
+ "Failed handle request via %r: %r: %s",
+ h,
+ request,
+ f.getTraceback().rstrip(),
+ )
+ respond_with_json(
+ request,
+ 500,
+ {
+ "error": "Internal server error",
+ "errcode": Codes.UNKNOWN,
+ },
+ send_cors=True,
+ pretty_print=_request_user_agent_is_curl(request),
+ )
-# size in bytes of the response written
-response_size = metrics.register_counter(
- "response_size", labels=["method", "servlet", "tag"]
-)
+ return wrap_request_handler_with_logging(wrapped_request_handler)
-_next_request_id = 0
+def wrap_html_request_handler(h):
+ """Wraps a request handler method with exception handling.
-def request_handler(include_metrics=False):
- """Decorator for ``wrap_request_handler``"""
- return lambda request_handler: wrap_request_handler(request_handler, include_metrics)
+ Also adds logging as per wrap_request_handler_with_logging.
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
+ """
+ def wrapped_request_handler(self, request):
+ d = defer.maybeDeferred(h, self, request)
+ d.addErrback(_return_html_error, request)
+ return d
-def wrap_request_handler(request_handler, include_metrics=False):
- """Wraps a method that acts as a request handler with the necessary logging
- and exception handling.
+ return wrap_request_handler_with_logging(wrapped_request_handler)
- The method must have a signature of "handle_foo(self, request)". The
- argument "self" must have "version_string" and "clock" attributes. The
- argument "request" must be a twisted HTTP request.
- The method must return a deferred. If the deferred succeeds we assume that
- a response has been sent. If the deferred fails with a SynapseError we use
- it to send a JSON response with the appropriate HTTP reponse code. If the
- deferred fails with any other type of error we send a 500 reponse.
+def _return_html_error(f, request):
+ """Sends an HTML error page corresponding to the given failure
- We insert a unique request-id into the logging context for this request and
- log the response and duration for this request.
+ Args:
+ f (twisted.python.failure.Failure):
+ request (twisted.web.iweb.IRequest):
"""
+ if f.check(CodeMessageException):
+ cme = f.value
+ code = cme.code
+ msg = cme.msg
+
+ if isinstance(cme, SynapseError):
+ logger.info(
+ "%s SynapseError: %s - %s", request, code, msg
+ )
+ else:
+ logger.error(
+ "Failed handle request %r: %s",
+ request,
+ f.getTraceback().rstrip(),
+ )
+ else:
+ code = http_client.INTERNAL_SERVER_ERROR
+ msg = "Internal server error"
+
+ logger.error(
+ "Failed handle request %r: %s",
+ request,
+ f.getTraceback().rstrip(),
+ )
+
+ body = HTML_ERROR_TEMPLATE.format(
+ code=code, msg=cgi.escape(msg),
+ ).encode("utf-8")
+ request.setResponseCode(code)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%i" % (len(body),))
+ request.write(body)
+ finish_request(request)
+
+
+def wrap_request_handler_with_logging(h):
+ """Wraps a request handler to provide logging and metrics
+ The handler method must have a signature of "handle_foo(self, request)",
+ where "self" must have a "clock" attribute (and "request" must be a
+ SynapseRequest).
+
+ As well as calling `request.processing` (which will log the response and
+ duration for this request), the wrapped request handler will insert the
+ request id into the logging context.
+ """
@defer.inlineCallbacks
def wrapped_request_handler(self, request):
- global _next_request_id
- request_id = "%s-%s" % (request.method, _next_request_id)
- _next_request_id += 1
+ """
+ Args:
+ self:
+ request (synapse.http.site.SynapseRequest):
+ """
+ request_id = request.get_request_id()
with LoggingContext(request_id) as request_context:
+ request_context.request = request_id
with Measure(self.clock, "wrapped_request_handler"):
- request_metrics = RequestMetrics()
# we start the request metrics timer here with an initial stab
# at the servlet name. For most requests that name will be
# JsonResource (or a subclass), and JsonResource._async_render
# will update it once it picks a servlet.
servlet_name = self.__class__.__name__
- request_metrics.start(self.clock, name=servlet_name)
-
- request_context.request = request_id
- with request.processing():
- try:
- with PreserveLoggingContext(request_context):
- if include_metrics:
- yield request_handler(self, request, request_metrics)
- else:
- requests_counter.inc(request.method, servlet_name)
- yield request_handler(self, request)
- except CodeMessageException as e:
- code = e.code
- if isinstance(e, SynapseError):
- logger.info(
- "%s SynapseError: %s - %s", request, code, e.msg
- )
- else:
- logger.exception(e)
- outgoing_responses_counter.inc(request.method, str(code))
- respond_with_json(
- request, code, cs_exception(e), send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
- )
- except Exception:
- # failure.Failure() fishes the original Failure out
- # of our stack, and thus gives us a sensible stack
- # trace.
- f = failure.Failure()
- logger.error(
- "Failed handle request %s.%s on %r: %r: %s",
- request_handler.__module__,
- request_handler.__name__,
- self,
- request,
- f.getTraceback().rstrip(),
- )
- respond_with_json(
- request,
- 500,
- {
- "error": "Internal server error",
- "errcode": Codes.UNKNOWN,
- },
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
- )
- finally:
- try:
- request_metrics.stop(
- self.clock, request
- )
- except Exception as e:
- logger.warn("Failed to stop metrics: %r", e)
+ with request.processing(servlet_name):
+ with PreserveLoggingContext(request_context):
+ d = defer.maybeDeferred(h, self, request)
+
+ # record the arrival of the request *after*
+ # dispatching to the handler, so that the handler
+ # can update the servlet name in the request
+ # metrics
+ requests_counter.labels(request.method,
+ request.request_metrics.name).inc()
+ yield d
return wrapped_request_handler
@@ -262,7 +262,6 @@ class JsonResource(HttpServer, resource.Resource):
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
- self.version_string = hs.version_string
self.hs = hs
def register_paths(self, method, path_patterns, callback):
@@ -278,13 +277,9 @@ class JsonResource(HttpServer, resource.Resource):
self._async_render(request)
return server.NOT_DONE_YET
- # Disable metric reporting because _async_render does its own metrics.
- # It does its own metric reporting because _async_render dispatches to
- # a callback and it's the class name of that callback we want to report
- # against rather than the JsonResource itself.
- @request_handler(include_metrics=True)
+ @wrap_json_request_handler
@defer.inlineCallbacks
- def _async_render(self, request, request_metrics):
+ def _async_render(self, request):
""" This gets called from render() every time someone sends us a request.
This checks if anyone has registered a callback for that method and
path.
@@ -296,9 +291,7 @@ class JsonResource(HttpServer, resource.Resource):
servlet_classname = servlet_instance.__class__.__name__
else:
servlet_classname = "%r" % callback
-
- request_metrics.name = servlet_classname
- requests_counter.inc(request.method, servlet_classname)
+ request.request_metrics.name = servlet_classname
# Now trigger the callback. If it returns a response, we send it
# here. If it throws an exception, that is handled by the wrapper
@@ -345,15 +338,12 @@ class JsonResource(HttpServer, resource.Resource):
def _send_response(self, request, code, response_json_object,
response_code_message=None):
- outgoing_responses_counter.inc(request.method, str(code))
-
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request, code, response_json_object,
send_cors=True,
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
- version_string=self.version_string,
canonical_json=self.canonical_json,
)
@@ -386,54 +376,6 @@ def _unrecognised_request_handler(request):
raise UnrecognizedRequestError()
-class RequestMetrics(object):
- def start(self, clock, name):
- self.start = clock.time_msec()
- self.start_context = LoggingContext.current_context()
- self.name = name
-
- def stop(self, clock, request):
- context = LoggingContext.current_context()
-
- tag = ""
- if context:
- tag = context.tag
-
- if context != self.start_context:
- logger.warn(
- "Context have unexpectedly changed %r, %r",
- context, self.start_context
- )
- return
-
- response_count.inc(request.method, self.name, tag)
-
- response_timer.inc_by(
- clock.time_msec() - self.start, request.method,
- self.name, tag
- )
-
- ru_utime, ru_stime = context.get_resource_usage()
-
- response_ru_utime.inc_by(
- ru_utime, request.method, self.name, tag
- )
- response_ru_stime.inc_by(
- ru_stime, request.method, self.name, tag
- )
- response_db_txn_count.inc_by(
- context.db_txn_count, request.method, self.name, tag
- )
- response_db_txn_duration.inc_by(
- context.db_txn_duration_ms / 1000., request.method, self.name, tag
- )
- response_db_sched_duration.inc_by(
- context.db_sched_duration_ms / 1000., request.method, self.name, tag
- )
-
- response_size.inc_by(request.sentLength, request.method, self.name, tag)
-
-
class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path."""
@@ -452,7 +394,7 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
- version_string="", canonical_json=True):
+ canonical_json=True):
# could alternatively use request.notifyFinish() and flip a flag when
# the Deferred fires, but since the flag is RIGHT THERE it seems like
# a waste.
@@ -474,12 +416,11 @@ def respond_with_json(request, code, json_object, send_cors=False,
request, code, json_bytes,
send_cors=send_cors,
response_code_message=response_code_message,
- version_string=version_string
)
def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
- version_string="", response_code_message=None):
+ response_code_message=None):
"""Sends encoded JSON in response to the given request.
Args:
@@ -493,7 +434,6 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
request.setResponseCode(code, message=response_code_message)
request.setHeader(b"Content-Type", b"application/json")
- request.setHeader(b"Server", version_string)
request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),))
request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
@@ -546,6 +486,6 @@ def _request_user_agent_is_curl(request):
b"User-Agent", default=[]
)
for user_agent in user_agents:
- if "curl" in user_agent:
+ if b"curl" in user_agent:
return True
return False
diff --git a/synapse/http/site.py b/synapse/http/site.py
index c8b46e1af2..60299657b9 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -12,27 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util.logcontext import LoggingContext
-from twisted.web.server import Site, Request
-
import contextlib
import logging
import re
import time
+from twisted.web.server import Site, Request
+
+from synapse.http.request_metrics import RequestMetrics
+from synapse.util.logcontext import LoggingContext
+
+logger = logging.getLogger(__name__)
+
ACCESS_TOKEN_RE = re.compile(br'(\?.*access(_|%5[Ff])token=)[^&]*(.*)$')
+_next_request_seq = 0
+
class SynapseRequest(Request):
+ """Class which encapsulates an HTTP request to synapse.
+
+ All of the requests processed in synapse are of this type.
+
+ It extends twisted's twisted.web.server.Request, and adds:
+ * Unique request ID
+ * Redaction of access_token query-params in __repr__
+ * Logging at start and end
+ * Metrics to record CPU, wallclock and DB time by endpoint.
+
+ It provides a method `processing` which should be called by the Resource
+ which is handling the request, and returns a context manager.
+
+ """
def __init__(self, site, *args, **kw):
Request.__init__(self, *args, **kw)
self.site = site
self.authenticated_entity = None
self.start_time = 0
+ global _next_request_seq
+ self.request_seq = _next_request_seq
+ _next_request_seq += 1
+
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
- return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
+ return '<%s at 0x%x method=%r uri=%r clientproto=%r site=%r>' % (
self.__class__.__name__,
id(self),
self.method,
@@ -41,6 +65,9 @@ class SynapseRequest(Request):
self.site.site_tag,
)
+ def get_request_id(self):
+ return "%s-%i" % (self.method, self.request_seq)
+
def get_redacted_uri(self):
return ACCESS_TOKEN_RE.sub(
br'\1<redacted>\3',
@@ -50,7 +77,18 @@ class SynapseRequest(Request):
def get_user_agent(self):
return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1]
- def started_processing(self):
+ def render(self, resrc):
+ # override the Server header which is set by twisted
+ self.setHeader("Server", self.site.server_version_string)
+ return Request.render(self, resrc)
+
+ def _started_processing(self, servlet_name):
+ self.start_time = time.time()
+ self.request_metrics = RequestMetrics()
+ self.request_metrics.start(
+ self.start_time, name=servlet_name, method=self.method,
+ )
+
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
@@ -58,32 +96,32 @@ class SynapseRequest(Request):
self.method,
self.get_redacted_uri()
)
- self.start_time = int(time.time() * 1000)
-
- def finished_processing(self):
+ def _finished_processing(self):
try:
context = LoggingContext.current_context()
ru_utime, ru_stime = context.get_resource_usage()
db_txn_count = context.db_txn_count
- db_txn_duration_ms = context.db_txn_duration_ms
- db_sched_duration_ms = context.db_sched_duration_ms
+ db_txn_duration_sec = context.db_txn_duration_sec
+ db_sched_duration_sec = context.db_sched_duration_sec
except Exception:
ru_utime, ru_stime = (0, 0)
- db_txn_count, db_txn_duration_ms = (0, 0)
+ db_txn_count, db_txn_duration_sec = (0, 0)
+
+ end_time = time.time()
self.site.access_logger.info(
"%s - %s - {%s}"
- " Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
+ " Processed request: %.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)"
" %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
self.authenticated_entity,
- int(time.time() * 1000) - self.start_time,
- int(ru_utime * 1000),
- int(ru_stime * 1000),
- db_sched_duration_ms,
- db_txn_duration_ms,
+ end_time - self.start_time,
+ ru_utime,
+ ru_stime,
+ db_sched_duration_sec,
+ db_txn_duration_sec,
int(db_txn_count),
self.sentLength,
self.code,
@@ -93,11 +131,38 @@ class SynapseRequest(Request):
self.get_user_agent(),
)
+ try:
+ self.request_metrics.stop(end_time, self)
+ except Exception as e:
+ logger.warn("Failed to stop metrics: %r", e)
+
@contextlib.contextmanager
- def processing(self):
- self.started_processing()
+ def processing(self, servlet_name):
+ """Record the fact that we are processing this request.
+
+ Returns a context manager; the correct way to use this is:
+
+ @defer.inlineCallbacks
+ def handle_request(request):
+ with request.processing("FooServlet"):
+ yield really_handle_the_request()
+
+ This will log the request's arrival. Once the context manager is
+ closed, the completion of the request will be logged, and the various
+ metrics will be updated.
+
+ Args:
+ servlet_name (str): the name of the servlet which will be
+ processing this request. This is used in the metrics.
+
+ It is possible to update this afterwards by updating
+ self.request_metrics.servlet_name.
+ """
+ # TODO: we should probably just move this into render() and finish(),
+ # to save having to call a separate method.
+ self._started_processing(servlet_name)
yield
- self.finished_processing()
+ self._finished_processing()
class XForwardedForRequest(SynapseRequest):
@@ -135,7 +200,8 @@ class SynapseSite(Site):
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
- def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
+ def __init__(self, logger_name, site_tag, config, resource,
+ server_version_string, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
@@ -143,6 +209,7 @@ class SynapseSite(Site):
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
+ self.server_version_string = server_version_string
def log(self, request):
pass
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index e3b831db67..429e79c472 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -17,165 +17,178 @@ import logging
import functools
import time
import gc
+import os
import platform
+import attr
-from twisted.internet import reactor
+from prometheus_client import Gauge, Histogram, Counter
+from prometheus_client.core import GaugeMetricFamily, REGISTRY
-from .metric import (
- CounterMetric, CallbackMetric, DistributionMetric, CacheMetric,
- MemoryUsageMetric, GaugeMetric,
-)
-from .process_collector import register_process_collector
+from twisted.internet import reactor
logger = logging.getLogger(__name__)
-
-running_on_pypy = platform.python_implementation() == 'PyPy'
+running_on_pypy = platform.python_implementation() == "PyPy"
all_metrics = []
all_collectors = []
+all_gauges = {}
+
+HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
+
+
+class RegistryProxy(object):
+
+ @staticmethod
+ def collect():
+ for metric in REGISTRY.collect():
+ if not metric.name.startswith("__"):
+ yield metric
-class Metrics(object):
- """ A single Metrics object gives a (mutable) slice view of the all_metrics
- dict, allowing callers to easily register new metrics that are namespaced
- nicely."""
+@attr.s(hash=True)
+class LaterGauge(object):
- def __init__(self, name):
- self.name_prefix = name
+ name = attr.ib()
+ desc = attr.ib()
+ labels = attr.ib(hash=False)
+ caller = attr.ib()
- def make_subspace(self, name):
- return Metrics("%s_%s" % (self.name_prefix, name))
+ def collect(self):
- def register_collector(self, func):
- all_collectors.append(func)
+ g = GaugeMetricFamily(self.name, self.desc, labels=self.labels)
- def _register(self, metric_class, name, *args, **kwargs):
- full_name = "%s_%s" % (self.name_prefix, name)
+ try:
+ calls = self.caller()
+ except Exception:
+ logger.exception(
+ "Exception running callback for LaterGuage(%s)",
+ self.name,
+ )
+ yield g
+ return
- metric = metric_class(full_name, *args, **kwargs)
+ if isinstance(calls, dict):
+ for k, v in calls.items():
+ g.add_metric(k, v)
+ else:
+ g.add_metric([], calls)
- all_metrics.append(metric)
- return metric
+ yield g
- def register_counter(self, *args, **kwargs):
- """
- Returns:
- CounterMetric
- """
- return self._register(CounterMetric, *args, **kwargs)
+ def __attrs_post_init__(self):
+ self._register()
- def register_gauge(self, *args, **kwargs):
- """
- Returns:
- GaugeMetric
- """
- return self._register(GaugeMetric, *args, **kwargs)
+ def _register(self):
+ if self.name in all_gauges.keys():
+ logger.warning("%s already registered, reregistering" % (self.name,))
+ REGISTRY.unregister(all_gauges.pop(self.name))
- def register_callback(self, *args, **kwargs):
- """
- Returns:
- CallbackMetric
- """
- return self._register(CallbackMetric, *args, **kwargs)
+ REGISTRY.register(self)
+ all_gauges[self.name] = self
- def register_distribution(self, *args, **kwargs):
- """
- Returns:
- DistributionMetric
- """
- return self._register(DistributionMetric, *args, **kwargs)
- def register_cache(self, *args, **kwargs):
- """
- Returns:
- CacheMetric
- """
- return self._register(CacheMetric, *args, **kwargs)
+#
+# Detailed CPU metrics
+#
+
+class CPUMetrics(object):
+ def __init__(self):
+ ticks_per_sec = 100
+ try:
+ # Try and get the system config
+ ticks_per_sec = os.sysconf('SC_CLK_TCK')
+ except (ValueError, TypeError, AttributeError):
+ pass
-def register_memory_metrics(hs):
- try:
- import psutil
- process = psutil.Process()
- process.memory_info().rss
- except (ImportError, AttributeError):
- logger.warn(
- "psutil is not installed or incorrect version."
- " Disabling memory metrics."
- )
- return
- metric = MemoryUsageMetric(hs, psutil)
- all_metrics.append(metric)
+ self.ticks_per_sec = ticks_per_sec
+ def collect(self):
+ if not HAVE_PROC_SELF_STAT:
+ return
-def get_metrics_for(pkg_name):
- """ Returns a Metrics instance for conveniently creating metrics
- namespaced with the given name prefix. """
+ with open("/proc/self/stat") as s:
+ line = s.read()
+ raw_stats = line.split(") ", 1)[1].split(" ")
- # Convert a "package.name" to "package_name" because Prometheus doesn't
- # let us use . in metric names
- return Metrics(pkg_name.replace(".", "_"))
+ user = GaugeMetricFamily("process_cpu_user_seconds_total", "")
+ user.add_metric([], float(raw_stats[11]) / self.ticks_per_sec)
+ yield user
+ sys = GaugeMetricFamily("process_cpu_system_seconds_total", "")
+ sys.add_metric([], float(raw_stats[12]) / self.ticks_per_sec)
+ yield sys
-def render_all():
- strs = []
- for collector in all_collectors:
- collector()
+REGISTRY.register(CPUMetrics())
- for metric in all_metrics:
- try:
- strs += metric.render()
- except Exception:
- strs += ["# FAILED to render"]
- logger.exception("Failed to render metric")
+#
+# Python GC metrics
+#
+
+gc_unreachable = Gauge("python_gc_unreachable_total", "Unreachable GC objects", ["gen"])
+gc_time = Histogram(
+ "python_gc_time",
+ "Time taken to GC (sec)",
+ ["gen"],
+ buckets=[0.0025, 0.005, 0.01, 0.025, 0.05, 0.10, 0.25, 0.50, 1.00, 2.50,
+ 5.00, 7.50, 15.00, 30.00, 45.00, 60.00],
+)
- strs.append("") # to generate a final CRLF
- return "\n".join(strs)
+class GCCounts(object):
+ def collect(self):
+ cm = GaugeMetricFamily("python_gc_counts", "GC cycle counts", labels=["gen"])
+ for n, m in enumerate(gc.get_count()):
+ cm.add_metric([str(n)], m)
-register_process_collector(get_metrics_for("process"))
+ yield cm
-python_metrics = get_metrics_for("python")
+REGISTRY.register(GCCounts())
-gc_time = python_metrics.register_distribution("gc_time", labels=["gen"])
-gc_unreachable = python_metrics.register_counter("gc_unreachable_total", labels=["gen"])
-python_metrics.register_callback(
- "gc_counts", lambda: {(i,): v for i, v in enumerate(gc.get_count())}, labels=["gen"]
+#
+# Twisted reactor metrics
+#
+
+tick_time = Histogram(
+ "python_twisted_reactor_tick_time",
+ "Tick time of the Twisted reactor (sec)",
+ buckets=[0.001, 0.002, 0.005, 0.01, 0.025, 0.05, 0.1, 0.2, 0.5, 1, 2, 5],
+)
+pending_calls_metric = Histogram(
+ "python_twisted_reactor_pending_calls",
+ "Pending calls",
+ buckets=[1, 2, 5, 10, 25, 50, 100, 250, 500, 1000],
)
-reactor_metrics = get_metrics_for("python.twisted.reactor")
-tick_time = reactor_metrics.register_distribution("tick_time")
-pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
+#
+# Federation Metrics
+#
+
+sent_edus_counter = Counter("synapse_federation_client_sent_edus", "")
+
+sent_transactions_counter = Counter("synapse_federation_client_sent_transactions", "")
-synapse_metrics = get_metrics_for("synapse")
+events_processed_counter = Counter("synapse_federation_client_events_processed", "")
# Used to track where various components have processed in the event stream,
# e.g. federation sending, appservice sending, etc.
-event_processing_positions = synapse_metrics.register_gauge(
- "event_processing_positions", labels=["name"],
-)
+event_processing_positions = Gauge("synapse_event_processing_positions", "", ["name"])
# Used to track the current max events stream position
-event_persisted_position = synapse_metrics.register_gauge(
- "event_persisted_position",
-)
+event_persisted_position = Gauge("synapse_event_persisted_position", "")
# Used to track the received_ts of the last event processed by various
# components
-event_processing_last_ts = synapse_metrics.register_gauge(
- "event_processing_last_ts", labels=["name"],
-)
+event_processing_last_ts = Gauge("synapse_event_processing_last_ts", "", ["name"])
# Used to track the lag processing events. This is the time difference
# between the last processed event's received_ts and the time it was
# finished being processed.
-event_processing_lag = synapse_metrics.register_gauge(
- "event_processing_lag", labels=["name"],
-)
+event_processing_lag = Gauge("synapse_event_processing_lag", "", ["name"])
def runUntilCurrentTimer(func):
@@ -197,17 +210,17 @@ def runUntilCurrentTimer(func):
num_pending += 1
num_pending += len(reactor.threadCallQueue)
- start = time.time() * 1000
+ start = time.time()
ret = func(*args, **kwargs)
- end = time.time() * 1000
+ end = time.time()
# record the amount of wallclock time spent running pending calls.
# This is a proxy for the actual amount of time between reactor polls,
# since about 25% of time is actually spent running things triggered by
# I/O events, but that is harder to capture without rewriting half the
# reactor.
- tick_time.inc_by(end - start)
- pending_calls_metric.inc_by(num_pending)
+ tick_time.observe(end - start)
+ pending_calls_metric.observe(num_pending)
if running_on_pypy:
return ret
@@ -220,12 +233,12 @@ def runUntilCurrentTimer(func):
if threshold[i] < counts[i]:
logger.info("Collecting gc %d", i)
- start = time.time() * 1000
+ start = time.time()
unreachable = gc.collect(i)
- end = time.time() * 1000
+ end = time.time()
- gc_time.inc_by(end - start, i)
- gc_unreachable.inc_by(unreachable, i)
+ gc_time.labels(i).observe(end - start)
+ gc_unreachable.labels(i).set(unreachable)
return ret
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
deleted file mode 100644
index 89bd47c3f7..0000000000
--- a/synapse/metrics/metric.py
+++ /dev/null
@@ -1,301 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from itertools import chain
-import logging
-
-logger = logging.getLogger(__name__)
-
-
-def flatten(items):
- """Flatten a list of lists
-
- Args:
- items: iterable[iterable[X]]
-
- Returns:
- list[X]: flattened list
- """
- return list(chain.from_iterable(items))
-
-
-class BaseMetric(object):
- """Base class for metrics which report a single value per label set
- """
-
- def __init__(self, name, labels=[], alternative_names=[]):
- """
- Args:
- name (str): principal name for this metric
- labels (list(str)): names of the labels which will be reported
- for this metric
- alternative_names (iterable(str)): list of alternative names for
- this metric. This can be useful to provide a migration path
- when renaming metrics.
- """
- self._names = [name] + list(alternative_names)
- self.labels = labels # OK not to clone as we never write it
-
- def dimension(self):
- return len(self.labels)
-
- def is_scalar(self):
- return not len(self.labels)
-
- def _render_labelvalue(self, value):
- # TODO: escape backslashes, quotes and newlines
- return '"%s"' % (value)
-
- def _render_key(self, values):
- if self.is_scalar():
- return ""
- return "{%s}" % (
- ",".join(["%s=%s" % (k, self._render_labelvalue(v))
- for k, v in zip(self.labels, values)])
- )
-
- def _render_for_labels(self, label_values, value):
- """Render this metric for a single set of labels
-
- Args:
- label_values (list[str]): values for each of the labels
- value: value of the metric at with these labels
-
- Returns:
- iterable[str]: rendered metric
- """
- rendered_labels = self._render_key(label_values)
- return (
- "%s%s %.12g" % (name, rendered_labels, value)
- for name in self._names
- )
-
- def render(self):
- """Render this metric
-
- Each metric is rendered as:
-
- name{label1="val1",label2="val2"} value
-
- https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
-
- Returns:
- iterable[str]: rendered metrics
- """
- raise NotImplementedError()
-
-
-class CounterMetric(BaseMetric):
- """The simplest kind of metric; one that stores a monotonically-increasing
- value that counts events or running totals.
-
- Example use cases for Counters:
- - Number of requests processed
- - Number of items that were inserted into a queue
- - Total amount of data that a system has processed
- Counters can only go up (and be reset when the process restarts).
- """
-
- def __init__(self, *args, **kwargs):
- super(CounterMetric, self).__init__(*args, **kwargs)
-
- # dict[list[str]]: value for each set of label values. the keys are the
- # label values, in the same order as the labels in self.labels.
- #
- # (if the metric is a scalar, the (single) key is the empty tuple).
- self.counts = {}
-
- # Scalar metrics are never empty
- if self.is_scalar():
- self.counts[()] = 0.
-
- def inc_by(self, incr, *values):
- if len(values) != self.dimension():
- raise ValueError(
- "Expected as many values to inc() as labels (%d)" % (self.dimension())
- )
-
- # TODO: should assert that the tag values are all strings
-
- if values not in self.counts:
- self.counts[values] = incr
- else:
- self.counts[values] += incr
-
- def inc(self, *values):
- self.inc_by(1, *values)
-
- def render(self):
- return flatten(
- self._render_for_labels(k, self.counts[k])
- for k in sorted(self.counts.keys())
- )
-
-
-class GaugeMetric(BaseMetric):
- """A metric that can go up or down
- """
-
- def __init__(self, *args, **kwargs):
- super(GaugeMetric, self).__init__(*args, **kwargs)
-
- # dict[list[str]]: value for each set of label values. the keys are the
- # label values, in the same order as the labels in self.labels.
- #
- # (if the metric is a scalar, the (single) key is the empty tuple).
- self.guages = {}
-
- def set(self, v, *values):
- if len(values) != self.dimension():
- raise ValueError(
- "Expected as many values to inc() as labels (%d)" % (self.dimension())
- )
-
- # TODO: should assert that the tag values are all strings
-
- self.guages[values] = v
-
- def render(self):
- return flatten(
- self._render_for_labels(k, self.guages[k])
- for k in sorted(self.guages.keys())
- )
-
-
-class CallbackMetric(BaseMetric):
- """A metric that returns the numeric value returned by a callback whenever
- it is rendered. Typically this is used to implement gauges that yield the
- size or other state of some in-memory object by actively querying it."""
-
- def __init__(self, name, callback, labels=[]):
- super(CallbackMetric, self).__init__(name, labels=labels)
-
- self.callback = callback
-
- def render(self):
- try:
- value = self.callback()
- except Exception:
- logger.exception("Failed to render %s", self.name)
- return ["# FAILED to render " + self.name]
-
- if self.is_scalar():
- return list(self._render_for_labels([], value))
-
- return flatten(
- self._render_for_labels(k, value[k])
- for k in sorted(value.keys())
- )
-
-
-class DistributionMetric(object):
- """A combination of an event counter and an accumulator, which counts
- both the number of events and accumulates the total value. Typically this
- could be used to keep track of method-running times, or other distributions
- of values that occur in discrete occurances.
-
- TODO(paul): Try to export some heatmap-style stats?
- """
-
- def __init__(self, name, *args, **kwargs):
- self.counts = CounterMetric(name + ":count", **kwargs)
- self.totals = CounterMetric(name + ":total", **kwargs)
-
- def inc_by(self, inc, *values):
- self.counts.inc(*values)
- self.totals.inc_by(inc, *values)
-
- def render(self):
- return self.counts.render() + self.totals.render()
-
-
-class CacheMetric(object):
- __slots__ = (
- "name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
- )
-
- def __init__(self, name, size_callback, cache_name):
- self.name = name
- self.cache_name = cache_name
-
- self.hits = 0
- self.misses = 0
- self.evicted_size = 0
-
- self.size_callback = size_callback
-
- def inc_hits(self):
- self.hits += 1
-
- def inc_misses(self):
- self.misses += 1
-
- def inc_evictions(self, size=1):
- self.evicted_size += size
-
- def render(self):
- size = self.size_callback()
- hits = self.hits
- total = self.misses + self.hits
-
- return [
- """%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
- """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
- """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
- """%s:evicted_size{name="%s"} %d""" % (
- self.name, self.cache_name, self.evicted_size
- ),
- ]
-
-
-class MemoryUsageMetric(object):
- """Keeps track of the current memory usage, using psutil.
-
- The class will keep the current min/max/sum/counts of rss over the last
- WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second
- """
-
- UPDATE_HZ = 2 # number of times to get memory per second
- WINDOW_SIZE_SEC = 30 # the size of the window in seconds
-
- def __init__(self, hs, psutil):
- clock = hs.get_clock()
- self.memory_snapshots = []
-
- self.process = psutil.Process()
-
- clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ)
-
- def _update_curr_values(self):
- max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC
- self.memory_snapshots.append(self.process.memory_info().rss)
- self.memory_snapshots[:] = self.memory_snapshots[-max_size:]
-
- def render(self):
- if not self.memory_snapshots:
- return []
-
- max_rss = max(self.memory_snapshots)
- min_rss = min(self.memory_snapshots)
- sum_rss = sum(self.memory_snapshots)
- len_rss = len(self.memory_snapshots)
-
- return [
- "process_psutil_rss:max %d" % max_rss,
- "process_psutil_rss:min %d" % min_rss,
- "process_psutil_rss:total %d" % sum_rss,
- "process_psutil_rss:count %d" % len_rss,
- ]
diff --git a/synapse/metrics/process_collector.py b/synapse/metrics/process_collector.py
deleted file mode 100644
index 6fec3de399..0000000000
--- a/synapse/metrics/process_collector.py
+++ /dev/null
@@ -1,122 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-
-
-TICKS_PER_SEC = 100
-BYTES_PER_PAGE = 4096
-
-HAVE_PROC_STAT = os.path.exists("/proc/stat")
-HAVE_PROC_SELF_STAT = os.path.exists("/proc/self/stat")
-HAVE_PROC_SELF_LIMITS = os.path.exists("/proc/self/limits")
-HAVE_PROC_SELF_FD = os.path.exists("/proc/self/fd")
-
-# Field indexes from /proc/self/stat, taken from the proc(5) manpage
-STAT_FIELDS = {
- "utime": 14,
- "stime": 15,
- "starttime": 22,
- "vsize": 23,
- "rss": 24,
-}
-
-
-stats = {}
-
-# In order to report process_start_time_seconds we need to know the
-# machine's boot time, because the value in /proc/self/stat is relative to
-# this
-boot_time = None
-if HAVE_PROC_STAT:
- with open("/proc/stat") as _procstat:
- for line in _procstat:
- if line.startswith("btime "):
- boot_time = int(line.split()[1])
-
-
-def update_resource_metrics():
- if HAVE_PROC_SELF_STAT:
- global stats
- with open("/proc/self/stat") as s:
- line = s.read()
- # line is PID (command) more stats go here ...
- raw_stats = line.split(") ", 1)[1].split(" ")
-
- for (name, index) in STAT_FIELDS.iteritems():
- # subtract 3 from the index, because proc(5) is 1-based, and
- # we've lost the first two fields in PID and COMMAND above
- stats[name] = int(raw_stats[index - 3])
-
-
-def _count_fds():
- # Not every OS will have a /proc/self/fd directory
- if not HAVE_PROC_SELF_FD:
- return 0
-
- return len(os.listdir("/proc/self/fd"))
-
-
-def register_process_collector(process_metrics):
- process_metrics.register_collector(update_resource_metrics)
-
- if HAVE_PROC_SELF_STAT:
- process_metrics.register_callback(
- "cpu_user_seconds_total",
- lambda: float(stats["utime"]) / TICKS_PER_SEC
- )
- process_metrics.register_callback(
- "cpu_system_seconds_total",
- lambda: float(stats["stime"]) / TICKS_PER_SEC
- )
- process_metrics.register_callback(
- "cpu_seconds_total",
- lambda: (float(stats["utime"] + stats["stime"])) / TICKS_PER_SEC
- )
-
- process_metrics.register_callback(
- "virtual_memory_bytes",
- lambda: int(stats["vsize"])
- )
- process_metrics.register_callback(
- "resident_memory_bytes",
- lambda: int(stats["rss"]) * BYTES_PER_PAGE
- )
-
- process_metrics.register_callback(
- "start_time_seconds",
- lambda: boot_time + int(stats["starttime"]) / TICKS_PER_SEC
- )
-
- if HAVE_PROC_SELF_FD:
- process_metrics.register_callback(
- "open_fds",
- lambda: _count_fds()
- )
-
- if HAVE_PROC_SELF_LIMITS:
- def _get_max_fds():
- with open("/proc/self/limits") as limits:
- for line in limits:
- if not line.startswith("Max open files "):
- continue
- # Line is Max open files $SOFT $HARD
- return int(line.split()[3])
- return None
-
- process_metrics.register_callback(
- "max_fds",
- lambda: _get_max_fds()
- )
diff --git a/synapse/metrics/resource.py b/synapse/metrics/resource.py
index 870f400600..9789359077 100644
--- a/synapse/metrics/resource.py
+++ b/synapse/metrics/resource.py
@@ -13,27 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.web.resource import Resource
-
-import synapse.metrics
-
+from prometheus_client.twisted import MetricsResource
METRICS_PREFIX = "/_synapse/metrics"
-
-class MetricsResource(Resource):
- isLeaf = True
-
- def __init__(self, hs):
- Resource.__init__(self) # Resource is old-style, so no super()
-
- self.hs = hs
-
- def render_GET(self, request):
- response = synapse.metrics.render_all()
-
- request.setHeader("Content-Type", "text/plain")
- request.setHeader("Content-Length", str(len(response)))
-
- # Encode as UTF-8 (default)
- return response.encode()
+__all__ = ["MetricsResource", "METRICS_PREFIX"]
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 0e40a4aad6..6dce20a284 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -14,33 +14,34 @@
# limitations under the License.
from twisted.internet import defer
+
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError
from synapse.handlers.presence import format_user_presence_state
-from synapse.util import DeferredTimedOutError
from synapse.util.logutils import log_function
-from synapse.util.async import ObservableDeferred
-from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
+from synapse.util.async import (
+ ObservableDeferred, add_timeout_to_deferred,
+ DeferredTimeoutError,
+)
+from synapse.util.logcontext import PreserveLoggingContext, run_in_background
from synapse.util.metrics import Measure
from synapse.types import StreamToken
from synapse.visibility import filter_events_for_client
-import synapse.metrics
+from synapse.metrics import LaterGauge
from collections import namedtuple
+from prometheus_client import Counter
import logging
logger = logging.getLogger(__name__)
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-notified_events_counter = metrics.register_counter("notified_events")
+notified_events_counter = Counter("synapse_notifier_notified_events", "")
-users_woken_by_stream_counter = metrics.register_counter(
- "users_woken_by_stream", labels=["stream"]
-)
+users_woken_by_stream_counter = Counter(
+ "synapse_notifier_users_woken_by_stream", "", ["stream"])
# TODO(paul): Should be shared somewhere
@@ -105,7 +106,7 @@ class _NotifierUserStream(object):
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
- users_woken_by_stream_counter.inc(stream_key)
+ users_woken_by_stream_counter.labels(stream_key).inc()
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
@@ -194,14 +195,14 @@ class Notifier(object):
all_user_streams.add(x)
return sum(stream.count_listeners() for stream in all_user_streams)
- metrics.register_callback("listeners", count_listeners)
+ LaterGauge("synapse_notifier_listeners", "", [], count_listeners)
- metrics.register_callback(
- "rooms",
+ LaterGauge(
+ "synapse_notifier_rooms", "", [],
lambda: count(bool, self.room_to_user_streams.values()),
)
- metrics.register_callback(
- "users",
+ LaterGauge(
+ "synapse_notifier_users", "", [],
lambda: len(self.user_to_user_stream),
)
@@ -251,9 +252,7 @@ class Notifier(object):
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
- preserve_fn(self.appservice_handler.notify_interested_services)(
- room_stream_id
- )
+ run_in_background(self._notify_app_services, room_stream_id)
if self.federation_sender:
self.federation_sender.notify_new_events(room_stream_id)
@@ -267,6 +266,13 @@ class Notifier(object):
rooms=[event.room_id],
)
+ @defer.inlineCallbacks
+ def _notify_app_services(self, room_stream_id):
+ try:
+ yield self.appservice_handler.notify_interested_services(room_stream_id)
+ except Exception:
+ logger.exception("Error notifying application services of event")
+
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise.
@@ -331,11 +337,12 @@ class Notifier(object):
# Now we wait for the _NotifierUserStream to be told there
# is a new token.
listener = user_stream.new_listener(prev_token)
+ add_timeout_to_deferred(
+ listener.deferred,
+ (end_time - now) / 1000.,
+ )
with PreserveLoggingContext():
- yield self.clock.time_bound_deferred(
- listener.deferred,
- time_out=(end_time - now) / 1000.
- )
+ yield listener.deferred
current_token = user_stream.current_token
@@ -346,7 +353,7 @@ class Notifier(object):
# Update the prev_token to the current_token since nothing
# has happened between the old prev_token and the current_token
prev_token = current_token
- except DeferredTimedOutError:
+ except DeferredTimeoutError:
break
except defer.CancelledError:
break
@@ -551,13 +558,14 @@ class Notifier(object):
if end_time <= now:
break
+ add_timeout_to_deferred(
+ listener.deferred.addTimeout,
+ (end_time - now) / 1000.,
+ )
try:
with PreserveLoggingContext():
- yield self.clock.time_bound_deferred(
- listener.deferred,
- time_out=(end_time - now) / 1000.
- )
- except DeferredTimedOutError:
+ yield listener.deferred
+ except DeferredTimeoutError:
break
except defer.CancelledError:
break
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 7a18afe5f9..a8ae7bcd6c 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -39,7 +39,7 @@ def list_with_base_rules(rawrules):
rawrules = [r for r in rawrules if r['priority_class'] >= 0]
# shove the server default rules for each kind onto the end of each
- current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
+ current_prio_class = list(PRIORITY_CLASS_INVERSE_MAP)[-1]
ruleslist.extend(make_base_prepend_rules(
PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 7c680659b6..a5cab1f043 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -22,35 +22,32 @@ from .push_rule_evaluator import PushRuleEvaluatorForEvent
from synapse.event_auth import get_user_power_level
from synapse.api.constants import EventTypes, Membership
-from synapse.metrics import get_metrics_for
-from synapse.util.caches import metrics as cache_metrics
+from synapse.util.caches import register_cache
from synapse.util.caches.descriptors import cached
from synapse.util.async import Linearizer
from synapse.state import POWER_KEY
from collections import namedtuple
-
+from prometheus_client import Counter
+from six import itervalues, iteritems
logger = logging.getLogger(__name__)
rules_by_room = {}
-push_metrics = get_metrics_for(__name__)
-push_rules_invalidation_counter = push_metrics.register_counter(
- "push_rules_invalidation_counter"
-)
-push_rules_state_size_counter = push_metrics.register_counter(
- "push_rules_state_size_counter"
-)
+push_rules_invalidation_counter = Counter(
+ "synapse_push_bulk_push_rule_evaluator_push_rules_invalidation_counter", "")
+push_rules_state_size_counter = Counter(
+ "synapse_push_bulk_push_rule_evaluator_push_rules_state_size_counter", "")
# Measures whether we use the fast path of using state deltas, or if we have to
# recalculate from scratch
-push_rules_delta_state_cache_metric = cache_metrics.register_cache(
+push_rules_delta_state_cache_metric = register_cache(
"cache",
- size_callback=lambda: 0, # Meaningless size, as this isn't a cache that stores values
- cache_name="push_rules_delta_state_cache_metric",
+ "push_rules_delta_state_cache_metric",
+ cache=[], # Meaningless size, as this isn't a cache that stores values
)
@@ -64,10 +61,10 @@ class BulkPushRuleEvaluator(object):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- self.room_push_rule_cache_metrics = cache_metrics.register_cache(
+ self.room_push_rule_cache_metrics = register_cache(
"cache",
- size_callback=lambda: 0, # There's not good value for this
- cache_name="room_push_rule_cache",
+ "room_push_rule_cache",
+ cache=[], # Meaningless size, as this isn't a cache that stores values
)
@defer.inlineCallbacks
@@ -126,7 +123,7 @@ class BulkPushRuleEvaluator(object):
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
- (e.type, e.state_key): e for e in auth_events.itervalues()
+ (e.type, e.state_key): e for e in itervalues(auth_events)
}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -160,7 +157,7 @@ class BulkPushRuleEvaluator(object):
condition_cache = {}
- for uid, rules in rules_by_user.iteritems():
+ for uid, rules in iteritems(rules_by_user):
if event.sender == uid:
continue
@@ -309,7 +306,7 @@ class RulesForRoom(object):
current_state_ids = context.current_state_ids
push_rules_delta_state_cache_metric.inc_misses()
- push_rules_state_size_counter.inc_by(len(current_state_ids))
+ push_rules_state_size_counter.inc(len(current_state_ids))
logger.debug(
"Looking for member changes in %r %r", state_group, current_state_ids
@@ -406,7 +403,7 @@ class RulesForRoom(object):
# If the event is a join event then it will be in current state evnts
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
- for event_id in member_event_ids.itervalues():
+ for event_id in itervalues(member_event_ids):
if event_id == event.event_id:
members[event_id] = (event.state_key, event.membership)
@@ -414,7 +411,7 @@ class RulesForRoom(object):
logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = set(
- user_id for user_id, membership in members.itervalues()
+ user_id for user_id, membership in itervalues(members)
if membership == Membership.JOIN
)
@@ -426,7 +423,7 @@ class RulesForRoom(object):
)
user_ids = set(
- uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
+ uid for uid, have_pusher in iteritems(if_users_with_pushers) if have_pusher
)
logger.debug("With pushers: %r", user_ids)
@@ -447,7 +444,7 @@ class RulesForRoom(object):
)
ret_rules_by_user.update(
- item for item in rules_by_user.iteritems() if item[0] is not None
+ item for item in iteritems(rules_by_user) if item[0] is not None
)
self.update_cache(sequence, members, ret_rules_by_user, state_group)
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 58df98a793..ba7286cb72 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -77,10 +77,13 @@ class EmailPusher(object):
@defer.inlineCallbacks
def on_started(self):
if self.mailer is not None:
- self.throttle_params = yield self.store.get_throttle_params_by_room(
- self.pusher_id
- )
- yield self._process()
+ try:
+ self.throttle_params = yield self.store.get_throttle_params_by_room(
+ self.pusher_id
+ )
+ yield self._process()
+ except Exception:
+ logger.exception("Error starting email pusher")
def on_stop(self):
if self.timed_call:
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 2cbac571b8..bf7ff74a1a 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -18,24 +18,19 @@ import logging
from twisted.internet import defer, reactor
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
-import push_rule_evaluator
-import push_tools
-import synapse
+from . import push_rule_evaluator
+from . import push_tools
from synapse.push import PusherConfigException
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
-logger = logging.getLogger(__name__)
+from prometheus_client import Counter
-metrics = synapse.metrics.get_metrics_for(__name__)
+logger = logging.getLogger(__name__)
-http_push_processed_counter = metrics.register_counter(
- "http_pushes_processed",
-)
+http_push_processed_counter = Counter("synapse_http_httppusher_http_pushes_processed", "")
-http_push_failed_counter = metrics.register_counter(
- "http_pushes_failed",
-)
+http_push_failed_counter = Counter("synapse_http_httppusher_http_pushes_failed", "")
class HttpPusher(object):
@@ -94,7 +89,10 @@ class HttpPusher(object):
@defer.inlineCallbacks
def on_started(self):
- yield self._process()
+ try:
+ yield self._process()
+ except Exception:
+ logger.exception("Error starting http pusher")
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index b5cd9b426a..d4be800e5e 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -229,7 +229,8 @@ class Mailer(object):
if room_vars['notifs'] and 'messages' in room_vars['notifs'][-1]:
prev_messages = room_vars['notifs'][-1]['messages']
for message in notifvars['messages']:
- pm = filter(lambda pm: pm['id'] == message['id'], prev_messages)
+ pm = list(filter(lambda pm: pm['id'] == message['id'],
+ prev_messages))
if pm:
if not message["is_historical"]:
pm[0]["is_historical"] = False
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index 277da3cd35..43f0c74ff3 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -113,7 +113,7 @@ def calculate_room_name(store, room_state_ids, user_id, fallback_to_members=True
# so find out who is in the room that isn't the user.
if "m.room.member" in room_state_bytype_ids:
member_events = yield store.get_events(
- room_state_bytype_ids["m.room.member"].values()
+ list(room_state_bytype_ids["m.room.member"].values())
)
all_members = [
ev for ev in member_events.values()
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 3601f2d365..cf735f7468 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -21,6 +21,8 @@ from synapse.types import UserID
from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache
from synapse.util.caches.lrucache import LruCache
+from six import string_types
+
logger = logging.getLogger(__name__)
@@ -150,7 +152,7 @@ class PushRuleEvaluatorForEvent(object):
# Caches (glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR)
-register_cache("regex_push_cache", regex_cache)
+register_cache("cache", "regex_push_cache", regex_cache)
def _glob_matches(glob, value, word_boundary=False):
@@ -238,7 +240,7 @@ def _flatten_dict(d, prefix=[], result=None):
if result is None:
result = {}
for key, value in d.items():
- if isinstance(value, basestring):
+ if isinstance(value, string_types):
result[".".join(prefix + [key])] = value.lower()
elif hasattr(value, "items"):
_flatten_dict(value, prefix=(prefix + [key]), result=result)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 71576330a9..5aa6667e91 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from httppusher import HttpPusher
+from .httppusher import HttpPusher
import logging
logger = logging.getLogger(__name__)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 134e89b371..750d11ca38 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import logging
+
from twisted.internet import defer
-from .pusher import PusherFactory
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.push.pusher import PusherFactory
from synapse.util.async import run_on_reactor
-
-import logging
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -137,12 +137,15 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
- preserve_fn(p.on_new_notifications)(
- min_stream_id, max_stream_id
+ run_in_background(
+ p.on_new_notifications,
+ min_stream_id, max_stream_id,
)
)
- yield make_deferred_yieldable(defer.gatherResults(deferreds))
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True),
+ )
except Exception:
logger.exception("Exception in pusher on_new_notifications")
@@ -164,10 +167,15 @@ class PusherPool:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
- preserve_fn(p.on_new_receipts)(min_stream_id, max_stream_id)
+ run_in_background(
+ p.on_new_receipts,
+ min_stream_id, max_stream_id,
+ )
)
- yield make_deferred_yieldable(defer.gatherResults(deferreds))
+ yield make_deferred_yieldable(
+ defer.gatherResults(deferreds, consumeErrors=True),
+ )
except Exception:
logger.exception("Exception in pusher on_new_receipts")
@@ -207,7 +215,7 @@ class PusherPool:
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
- preserve_fn(p.on_started)()
+ run_in_background(p.on_started)
logger.info("Started pushers")
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 711cbb6c50..478c497722 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -39,12 +39,11 @@ REQUIREMENTS = {
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl>=1.2.1": ["nacl>=1.2.1", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
+ "Twisted>=16.0.0": ["twisted>=16.0.0"],
- # we break under Twisted 18.4
- # (https://github.com/matrix-org/synapse/issues/3135)
- "Twisted>=16.0.0,<18.4": ["twisted>=16.0.0"],
+ # We use crypto.get_elliptic_curve which is only supported in >=0.15
+ "pyopenssl>=0.15": ["OpenSSL>=0.15"],
- "pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],
"daemonize": ["daemonize"],
@@ -57,6 +56,7 @@ REQUIREMENTS = {
"msgpack-python>=0.3.0": ["msgpack"],
"phonenumbers>=8.2.0": ["phonenumbers"],
"six": ["six"],
+ "prometheus_client": ["prometheus_client"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 0a9a290af4..c870475cd1 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -53,28 +53,28 @@ from twisted.internet import defer
from twisted.protocols.basic import LineOnlyReceiver
from twisted.python.failure import Failure
-from commands import (
+from .commands import (
COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
ErrorCommand, ServerCommand, RdataCommand, PositionCommand, PingCommand,
NameCommand, ReplicateCommand, UserSyncCommand, SyncCommand,
)
-from streams import STREAMS_MAP
+from .streams import STREAMS_MAP
+from synapse.metrics import LaterGauge
from synapse.util.stringutils import random_string
-from synapse.metrics.metric import CounterMetric
-import logging
-import synapse.metrics
-import struct
-import fcntl
+from prometheus_client import Counter
+from collections import defaultdict
-metrics = synapse.metrics.get_metrics_for(__name__)
+from six import iterkeys, iteritems
-connection_close_counter = metrics.register_counter(
- "close_reason", labels=["reason_type"],
-)
+import logging
+import struct
+import fcntl
+connection_close_counter = Counter(
+ "synapse_replication_tcp_protocol_close_reason", "", ["reason_type"])
# A list of all connected protocols. This allows us to send metrics about the
# connections.
@@ -136,12 +136,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# The LoopingCall for sending pings.
self._send_ping_loop = None
- self.inbound_commands_counter = CounterMetric(
- "inbound_commands", labels=["command"],
- )
- self.outbound_commands_counter = CounterMetric(
- "outbound_commands", labels=["command"],
- )
+ self.inbound_commands_counter = defaultdict(int)
+ self.outbound_commands_counter = defaultdict(int)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
@@ -201,7 +197,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.last_received_command = self.clock.time_msec()
- self.inbound_commands_counter.inc(cmd_name)
+ self.inbound_commands_counter[cmd_name] = (
+ self.inbound_commands_counter[cmd_name] + 1)
cmd_cls = COMMAND_MAP[cmd_name]
try:
@@ -251,8 +248,8 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self._queue_command(cmd)
return
- self.outbound_commands_counter.inc(cmd.NAME)
-
+ self.outbound_commands_counter[cmd.NAME] = (
+ self.outbound_commands_counter[cmd.NAME] + 1)
string = "%s %s" % (cmd.NAME, cmd.to_line(),)
if "\n" in string:
raise Exception("Unexpected newline in command: %r", string)
@@ -317,9 +314,9 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
def connectionLost(self, reason):
logger.info("[%s] Replication connection closed: %r", self.id(), reason)
if isinstance(reason, Failure):
- connection_close_counter.inc(reason.type.__name__)
+ connection_close_counter.labels(reason.type.__name__).inc()
else:
- connection_close_counter.inc(reason.__class__.__name__)
+ connection_close_counter.labels(reason.__class__.__name__).inc()
try:
# Remove us from list of connections to be monitored
@@ -392,7 +389,7 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
if stream_name == "ALL":
# Subscribe to all streams we're publishing to.
- for stream in self.streamer.streams_by_name.iterkeys():
+ for stream in iterkeys(self.streamer.streams_by_name):
self.subscribe_to_stream(stream, token)
else:
self.subscribe_to_stream(stream_name, token)
@@ -498,7 +495,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
BaseReplicationStreamProtocol.connectionMade(self)
# Once we've connected subscribe to the necessary streams
- for stream_name, token in self.handler.get_streams_to_replicate().iteritems():
+ for stream_name, token in iteritems(self.handler.get_streams_to_replicate()):
self.replicate(stream_name, token)
# Tell the server if we have any users currently syncing (should only
@@ -518,7 +515,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
def on_RDATA(self, cmd):
stream_name = cmd.stream_name
- inbound_rdata_count.inc(stream_name)
+ inbound_rdata_count.labels(stream_name).inc()
try:
row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
@@ -566,14 +563,12 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# The following simply registers metrics for the replication connections
-metrics.register_callback(
- "pending_commands",
+pending_commands = LaterGauge(
+ "pending_commands", "", ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): len(p.pending_commands)
for p in connected_connections
- },
- labels=["name", "conn_id"],
-)
+ })
def transport_buffer_size(protocol):
@@ -583,14 +578,12 @@ def transport_buffer_size(protocol):
return 0
-metrics.register_callback(
- "transport_send_buffer",
+transport_send_buffer = LaterGauge(
+ "synapse_replication_tcp_transport_send_buffer", "", ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_buffer_size(p)
for p in connected_connections
- },
- labels=["name", "conn_id"],
-)
+ })
def transport_kernel_read_buffer_size(protocol, read=True):
@@ -608,48 +601,38 @@ def transport_kernel_read_buffer_size(protocol, read=True):
return 0
-metrics.register_callback(
- "transport_kernel_send_buffer",
+tcp_transport_kernel_send_buffer = LaterGauge(
+ "synapse_replication_tcp_transport_kernel_send_buffer", "", ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, False)
for p in connected_connections
- },
- labels=["name", "conn_id"],
-)
+ })
-metrics.register_callback(
- "transport_kernel_read_buffer",
+tcp_transport_kernel_read_buffer = LaterGauge(
+ "synapse_replication_tcp_transport_kernel_read_buffer", "", ["name", "conn_id"],
lambda: {
(p.name, p.conn_id): transport_kernel_read_buffer_size(p, True)
for p in connected_connections
- },
- labels=["name", "conn_id"],
-)
+ })
-metrics.register_callback(
- "inbound_commands",
+tcp_inbound_commands = LaterGauge(
+ "synapse_replication_tcp_inbound_commands", "", ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
- for k, count in p.inbound_commands_counter.counts.iteritems()
- },
- labels=["command", "name", "conn_id"],
-)
+ for k, count in iteritems(p.inbound_commands_counter)
+ })
-metrics.register_callback(
- "outbound_commands",
+tcp_outbound_commands = LaterGauge(
+ "synapse_replication_tcp_outbound_commands", "", ["command", "name", "conn_id"],
lambda: {
(k[0], p.name, p.conn_id): count
for p in connected_connections
- for k, count in p.outbound_commands_counter.counts.iteritems()
- },
- labels=["command", "name", "conn_id"],
-)
+ for k, count in iteritems(p.outbound_commands_counter)
+ })
# number of updates received for each RDATA stream
-inbound_rdata_count = metrics.register_counter(
- "inbound_rdata_count",
- labels=["stream_name"],
-)
+inbound_rdata_count = Counter("synapse_replication_tcp_inbound_rdata_count", "",
+ ["stream_name"])
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 786c3fe864..63bd6d2652 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -18,24 +18,25 @@
from twisted.internet import defer, reactor
from twisted.internet.protocol import Factory
-from streams import STREAMS_MAP, FederationStream
-from protocol import ServerReplicationStreamProtocol
+from .streams import STREAMS_MAP, FederationStream
+from .protocol import ServerReplicationStreamProtocol
from synapse.util.metrics import Measure, measure_func
+from synapse.metrics import LaterGauge
import logging
-import synapse.metrics
+from prometheus_client import Counter
+from six import itervalues
-metrics = synapse.metrics.get_metrics_for(__name__)
-stream_updates_counter = metrics.register_counter(
- "stream_updates", labels=["stream_name"]
-)
-user_sync_counter = metrics.register_counter("user_sync")
-federation_ack_counter = metrics.register_counter("federation_ack")
-remove_pusher_counter = metrics.register_counter("remove_pusher")
-invalidate_cache_counter = metrics.register_counter("invalidate_cache")
-user_ip_cache_counter = metrics.register_counter("user_ip_cache")
+stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
+ "", ["stream_name"])
+user_sync_counter = Counter("synapse_replication_tcp_resource_user_sync", "")
+federation_ack_counter = Counter("synapse_replication_tcp_resource_federation_ack", "")
+remove_pusher_counter = Counter("synapse_replication_tcp_resource_remove_pusher", "")
+invalidate_cache_counter = Counter("synapse_replication_tcp_resource_invalidate_cache",
+ "")
+user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
logger = logging.getLogger(__name__)
@@ -69,33 +70,34 @@ class ReplicationStreamer(object):
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
+ self._server_notices_sender = hs.get_server_notices_sender()
# Current connections.
self.connections = []
- metrics.register_callback("total_connections", lambda: len(self.connections))
+ LaterGauge("synapse_replication_tcp_resource_total_connections", "", [],
+ lambda: len(self.connections))
# List of streams that clients can subscribe to.
# We only support federation stream if federation sending hase been
# disabled on the master.
self.streams = [
- stream(hs) for stream in STREAMS_MAP.itervalues()
+ stream(hs) for stream in itervalues(STREAMS_MAP)
if stream != FederationStream or not hs.config.send_federation
]
self.streams_by_name = {stream.NAME: stream for stream in self.streams}
- metrics.register_callback(
- "connections_per_stream",
+ LaterGauge(
+ "synapse_replication_tcp_resource_connections_per_stream", "",
+ ["stream_name"],
lambda: {
(stream_name,): len([
conn for conn in self.connections
if stream_name in conn.replication_streams
])
for stream_name in self.streams_by_name
- },
- labels=["stream_name"],
- )
+ })
self.federation_sender = None
if not hs.config.send_federation:
@@ -175,7 +177,7 @@ class ReplicationStreamer(object):
logger.info(
"Streaming: %s -> %s", stream.NAME, updates[-1][0]
)
- stream_updates_counter.inc_by(len(updates), stream.NAME)
+ stream_updates_counter.labels(stream.NAME).inc(len(updates))
# Some streams return multiple rows with the same stream IDs,
# we need to make sure they get sent out in batches. We do
@@ -253,6 +255,7 @@ class ReplicationStreamer(object):
yield self.store.insert_client_ip(
user_id, access_token, ip, user_agent, device_id, last_seen,
)
+ yield self._server_notices_sender.on_user_ip(user_id)
def send_sync_to_all_connections(self, data):
"""Sends a SYNC command to all clients.
diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index fceca2edeb..7c01b438cb 100644
--- a/synapse/rest/client/transactions.py
+++ b/synapse/rest/client/transactions.py
@@ -19,6 +19,7 @@ import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -80,31 +81,30 @@ class HttpTransactionCache(object):
Returns:
Deferred which resolves to a tuple of (response_code, response_dict).
"""
- try:
- return self.transactions[txn_key][0].observe()
- except (KeyError, IndexError):
- pass # execute the function instead.
-
- deferred = fn(*args, **kwargs)
-
- # if the request fails with a Twisted failure, remove it
- # from the transaction map. This is done to ensure that we don't
- # cache transient errors like rate-limiting errors, etc.
- def remove_from_map(err):
- self.transactions.pop(txn_key, None)
- return err
- deferred.addErrback(remove_from_map)
-
- # We don't add any other errbacks to the raw deferred, so we ask
- # ObservableDeferred to swallow the error. This is fine as the error will
- # still be reported to the observers.
- observable = ObservableDeferred(deferred, consumeErrors=True)
- self.transactions[txn_key] = (observable, self.clock.time_msec())
- return observable.observe()
+ if txn_key in self.transactions:
+ observable = self.transactions[txn_key][0]
+ else:
+ # execute the function instead.
+ deferred = run_in_background(fn, *args, **kwargs)
+
+ observable = ObservableDeferred(deferred)
+ self.transactions[txn_key] = (observable, self.clock.time_msec())
+
+ # if the request fails with an exception, remove it
+ # from the transaction map. This is done to ensure that we don't
+ # cache transient errors like rate-limiting errors, etc.
+ def remove_from_map(err):
+ self.transactions.pop(txn_key, None)
+ # we deliberately do not propagate the error any further, as we
+ # expect the observers to have reported it.
+
+ deferred.addErrback(remove_from_map)
+
+ return make_deferred_yieldable(observable.observe())
def _cleanup(self):
now = self.clock.time_msec()
- for key in self.transactions.keys():
+ for key in list(self.transactions):
ts = self.transactions[key][1]
if now > (ts + CLEANUP_PERIOD_MS): # after cleanup period
del self.transactions[key]
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 303419d281..6835a7bba2 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -151,10 +151,11 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
if event.room_id != room_id:
raise SynapseError(400, "Event is for wrong room.")
- depth = event.depth
+ token = yield self.store.get_topological_token_for_event(event_id)
+
logger.info(
- "[purge] purging up to depth %i (event_id %s)",
- depth, event_id,
+ "[purge] purging up to token %s (event_id %s)",
+ token, event_id,
)
elif 'purge_up_to_ts' in body:
ts = body['purge_up_to_ts']
@@ -168,15 +169,30 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
yield self.store.find_first_stream_ordering_after_ts(ts)
)
- (_, depth, _) = (
+ room_event_after_stream_ordering = (
yield self.store.get_room_event_after_stream_ordering(
room_id, stream_ordering,
)
)
+ if room_event_after_stream_ordering:
+ token = yield self.store.get_topological_token_for_event(
+ room_event_after_stream_ordering,
+ )
+ else:
+ logger.warn(
+ "[purge] purging events not possible: No event found "
+ "(received_ts %i => stream_ordering %i)",
+ ts, stream_ordering,
+ )
+ raise SynapseError(
+ 404,
+ "there is no event to be purged",
+ errcode=Codes.NOT_FOUND,
+ )
logger.info(
- "[purge] purging up to depth %i (received_ts %i => "
+ "[purge] purging up to token %d (received_ts %i => "
"stream_ordering %i)",
- depth, ts, stream_ordering,
+ token, ts, stream_ordering,
)
else:
raise SynapseError(
@@ -186,7 +202,7 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
)
purge_id = yield self.handlers.message_handler.start_purge_history(
- room_id, depth,
+ room_id, token,
delete_local_events=delete_local_events,
)
@@ -260,8 +276,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(ShutdownRoomRestServlet, self).__init__(hs)
self.store = hs.get_datastore()
- self.handlers = hs.get_handlers()
self.state = hs.get_state_handler()
+ self._room_creation_handler = hs.get_room_creation_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
@@ -283,7 +299,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
message = content.get("message", self.DEFAULT_MESSAGE)
room_name = content.get("room_name", "Content Violation Notification")
- info = yield self.handlers.room_creation_handler.create_room(
+ info = yield self._room_creation_handler.create_room(
room_creator_requester,
config={
"preset": "public_chat",
diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py
index c7aa0bbf59..197335d7aa 100644
--- a/synapse/rest/client/v1/base.py
+++ b/synapse/rest/client/v1/base.py
@@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet):
"""A base Synapse REST Servlet for the client version 1 API.
"""
+ # This subclass was presumably created to allow the auth for the v1
+ # protocol version to be different, however this behaviour was removed.
+ # it may no longer be necessary
+
def __init__(self, hs):
"""
Args:
@@ -59,5 +63,5 @@ class ClientV1RestServlet(RestServlet):
"""
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 45844aa2d2..34df5be4e9 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -25,7 +25,7 @@ from .base import ClientV1RestServlet, client_path_patterns
import simplejson as json
import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
import logging
from saml2 import BINDING_HTTP_POST
diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index 4a73813c58..647994bd53 100644
--- a/synapse/rest/client/v1/presence.py
+++ b/synapse/rest/client/v1/presence.py
@@ -23,6 +23,8 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.http.servlet import parse_json_object_from_request
from .base import ClientV1RestServlet, client_path_patterns
+from six import string_types
+
import logging
logger = logging.getLogger(__name__)
@@ -71,7 +73,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
if "status_msg" in content:
state["status_msg"] = content.pop("status_msg")
- if not isinstance(state["status_msg"], basestring):
+ if not isinstance(state["status_msg"], string_types):
raise SynapseError(400, "status_msg must be a string.")
if content:
@@ -129,7 +131,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "invite" in content:
for u in content["invite"]:
- if not isinstance(u, basestring):
+ if not isinstance(u, string_types):
raise SynapseError(400, "Bad invite value.")
if len(u) == 0:
continue
@@ -140,7 +142,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
if "drop" in content:
for u in content["drop"]:
- if not isinstance(u, basestring):
+ if not isinstance(u, string_types):
raise SynapseError(400, "Bad drop value.")
if len(u) == 0:
continue
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 1819a560cb..40e523cc5f 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -150,7 +150,7 @@ class PushersRemoveRestServlet(RestServlet):
super(RestServlet, self).__init__()
self.hs = hs
self.notifier = hs.get_notifier()
- self.auth = hs.get_v1auth()
+ self.auth = hs.get_auth()
self.pusher_pool = self.hs.get_pusherpool()
@defer.inlineCallbacks
@@ -176,7 +176,6 @@ class PushersRemoveRestServlet(RestServlet):
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (
len(PushersRemoveRestServlet.SUCCESS_HTML),
))
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 8a82097178..9b3022e0b0 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -30,6 +30,8 @@ from hashlib import sha1
import hmac
import logging
+from six import string_types
+
logger = logging.getLogger(__name__)
@@ -333,11 +335,11 @@ class RegisterRestServlet(ClientV1RestServlet):
def _do_shared_secret(self, request, register_json, session):
yield run_on_reactor()
- if not isinstance(register_json.get("mac", None), basestring):
+ if not isinstance(register_json.get("mac", None), string_types):
raise SynapseError(400, "Expected mac.")
- if not isinstance(register_json.get("user", None), basestring):
+ if not isinstance(register_json.get("user", None), string_types):
raise SynapseError(400, "Expected 'user' key.")
- if not isinstance(register_json.get("password", None), basestring):
+ if not isinstance(register_json.get("password", None), string_types):
raise SynapseError(400, "Expected 'password' key.")
if not self.hs.config.registration_shared_secret:
@@ -358,14 +360,14 @@ class RegisterRestServlet(ClientV1RestServlet):
got_mac = str(register_json["mac"])
want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret,
+ key=self.hs.config.registration_shared_secret.encode(),
digestmod=sha1,
)
want_mac.update(user)
- want_mac.update("\x00")
+ want_mac.update(b"\x00")
want_mac.update(password)
- want_mac.update("\x00")
- want_mac.update("admin" if admin else "notadmin")
+ want_mac.update(b"\x00")
+ want_mac.update(b"admin" if admin else b"notadmin")
want_mac = want_mac.hexdigest()
if compare_digest(want_mac, got_mac):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 2ad0e5943b..0b984987ed 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -28,8 +28,9 @@ from synapse.http.servlet import (
parse_json_object_from_request, parse_string, parse_integer
)
+from six.moves.urllib import parse as urlparse
+
import logging
-import urllib
import simplejson as json
logger = logging.getLogger(__name__)
@@ -40,7 +41,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomCreateRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self._room_creation_handler = hs.get_room_creation_handler()
def register(self, http_server):
PATTERNS = "/createRoom"
@@ -63,8 +64,7 @@ class RoomCreateRestServlet(ClientV1RestServlet):
def on_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
- handler = self.handlers.room_creation_handler
- info = yield handler.create_room(
+ info = yield self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -433,7 +433,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet):
as_client_event = "raw" not in request.args
filter_bytes = request.args.get("filter", None)
if filter_bytes:
- filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8")
+ filter_json = urlparse.unquote(filter_bytes[-1]).decode("UTF-8")
event_filter = Filter(json.loads(filter_json))
else:
event_filter = None
@@ -718,8 +718,8 @@ class RoomTypingRestServlet(ClientV1RestServlet):
def on_PUT(self, request, room_id, user_id):
requester = yield self.auth.get_user_by_req(request)
- room_id = urllib.unquote(room_id)
- target_user = UserID.from_string(urllib.unquote(user_id))
+ room_id = urlparse.unquote(room_id)
+ target_user = UserID.from_string(urlparse.unquote(user_id))
content = parse_json_object_from_request(request)
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8e5577148f..d6f3a19648 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -129,7 +129,6 @@ class AuthRestServlet(RestServlet):
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
@@ -175,7 +174,6 @@ class AuthRestServlet(RestServlet):
html_bytes = html.encode("utf8")
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Server", self.hs.version_string)
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
request.write(html_bytes)
diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index ec170109fe..66583d6778 100644
--- a/synapse/rest/client/v2_alpha/notifications.py
+++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -88,7 +88,7 @@ class NotificationsServlet(RestServlet):
pa["topological_ordering"], pa["stream_ordering"]
)
returned_push_actions.append(returned_pa)
- next_token = pa["stream_ordering"]
+ next_token = str(pa["stream_ordering"])
defer.returnValue((200, {
"notifications": returned_push_actions,
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 31554e7879..fe9c02f305 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -37,6 +37,8 @@ from hashlib import sha1
from synapse.util.async import run_on_reactor
from synapse.util.ratelimitutils import FederationRateLimiter
+from six import string_types
+
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
# exist. It's a _really minor_ security flaw to use plain string comparison
@@ -212,14 +214,14 @@ class RegisterRestServlet(RestServlet):
# in sessions. Pull out the username/password provided to us.
desired_password = None
if 'password' in body:
- if (not isinstance(body['password'], basestring) or
+ if (not isinstance(body['password'], string_types) or
len(body['password']) > 512):
raise SynapseError(400, "Invalid password")
desired_password = body["password"]
desired_username = None
if 'username' in body:
- if (not isinstance(body['username'], basestring) or
+ if (not isinstance(body['username'], string_types) or
len(body['username']) > 512):
raise SynapseError(400, "Invalid username")
desired_username = body['username']
@@ -247,7 +249,7 @@ class RegisterRestServlet(RestServlet):
access_token = get_access_token_from_request(request)
- if isinstance(desired_username, basestring):
+ if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(
desired_username, access_token, body
)
@@ -537,7 +539,7 @@ class RegisterRestServlet(RestServlet):
# includes the password and admin flag in the hashed text. Why are
# these different?
want_mac = hmac.new(
- key=self.hs.config.registration_shared_secret,
+ key=self.hs.config.registration_shared_secret.encode(),
msg=user,
digestmod=sha1,
).hexdigest()
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index eb91c0b293..a291cffbf1 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -85,6 +85,7 @@ class SyncRestServlet(RestServlet):
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
self.presence_handler = hs.get_presence_handler()
+ self._server_notices_sender = hs.get_server_notices_sender()
@defer.inlineCallbacks
def on_GET(self, request):
@@ -149,6 +150,9 @@ class SyncRestServlet(RestServlet):
else:
since_token = None
+ # send any outstanding server notices to the user.
+ yield self._server_notices_sender.on_user_syncing(user.to_string())
+
affect_presence = set_presence != PresenceState.OFFLINE
if affect_presence:
diff --git a/synapse/rest/consent/__init__.py b/synapse/rest/consent/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/rest/consent/__init__.py
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
new file mode 100644
index 0000000000..724911d1e6
--- /dev/null
+++ b/synapse/rest/consent/consent_resource.py
@@ -0,0 +1,222 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from hashlib import sha256
+import hmac
+import logging
+from os import path
+from six.moves import http_client
+
+import jinja2
+from jinja2 import TemplateNotFound
+from twisted.internet import defer
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+
+from synapse.api.errors import NotFoundError, SynapseError, StoreError
+from synapse.config import ConfigError
+from synapse.http.server import (
+ finish_request,
+ wrap_html_request_handler,
+)
+from synapse.http.servlet import parse_string
+from synapse.types import UserID
+
+
+# language to use for the templates. TODO: figure this out from Accept-Language
+TEMPLATE_LANGUAGE = "en"
+
+logger = logging.getLogger(__name__)
+
+# use hmac.compare_digest if we have it (python 2.7.7), else just use equality
+if hasattr(hmac, "compare_digest"):
+ compare_digest = hmac.compare_digest
+else:
+ def compare_digest(a, b):
+ return a == b
+
+
+class ConsentResource(Resource):
+ """A twisted Resource to display a privacy policy and gather consent to it
+
+ When accessed via GET, returns the privacy policy via a template.
+
+ When accessed via POST, records the user's consent in the database and
+ displays a success page.
+
+ The config should include a template_dir setting which contains templates
+ for the HTML. The directory should contain one subdirectory per language
+ (eg, 'en', 'fr'), and each language directory should contain the policy
+ document (named as '<version>.html') and a success page (success.html).
+
+ Both forms take a set of parameters from the browser. For the POST form,
+ these are normally sent as form parameters (but may be query-params); for
+ GET requests they must be query params. These are:
+
+ u: the complete mxid, or the localpart of the user giving their
+ consent. Required for both GET (where it is used as an input to the
+ template) and for POST (where it is used to find the row in the db
+ to update).
+
+ h: hmac_sha256(secret, u), where 'secret' is the privacy_secret in the
+ config file. If it doesn't match, the request is 403ed.
+
+ v: the version of the privacy policy being agreed to.
+
+ For GET: optional, and defaults to whatever was set in the config
+ file. Used to choose the version of the policy to pick from the
+ templates directory.
+
+ For POST: required; gives the value to be recorded in the database
+ against the user.
+ """
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ """
+ Resource.__init__(self)
+
+ self.hs = hs
+ self.store = hs.get_datastore()
+
+ # this is required by the request_handler wrapper
+ self.clock = hs.get_clock()
+
+ self._default_consent_version = hs.config.user_consent_version
+ if self._default_consent_version is None:
+ raise ConfigError(
+ "Consent resource is enabled but user_consent section is "
+ "missing in config file.",
+ )
+
+ # daemonize changes the cwd to /, so make the path absolute now.
+ consent_template_directory = path.abspath(
+ hs.config.user_consent_template_dir,
+ )
+ if not path.isdir(consent_template_directory):
+ raise ConfigError(
+ "Could not find template directory '%s'" % (
+ consent_template_directory,
+ ),
+ )
+
+ loader = jinja2.FileSystemLoader(consent_template_directory)
+ self._jinja_env = jinja2.Environment(
+ loader=loader,
+ autoescape=jinja2.select_autoescape(['html', 'htm', 'xml']),
+ )
+
+ if hs.config.form_secret is None:
+ raise ConfigError(
+ "Consent resource is enabled but form_secret is not set in "
+ "config file. It should be set to an arbitrary secret string.",
+ )
+
+ self._hmac_secret = hs.config.form_secret.encode("utf-8")
+
+ def render_GET(self, request):
+ self._async_render_GET(request)
+ return NOT_DONE_YET
+
+ @wrap_html_request_handler
+ @defer.inlineCallbacks
+ def _async_render_GET(self, request):
+ """
+ Args:
+ request (twisted.web.http.Request):
+ """
+
+ version = parse_string(request, "v",
+ default=self._default_consent_version)
+ username = parse_string(request, "u", required=True)
+ userhmac = parse_string(request, "h", required=True)
+
+ self._check_hash(username, userhmac)
+
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ u = yield self.store.get_user_by_id(qualified_user_id)
+ if u is None:
+ raise NotFoundError("Unknown user")
+
+ try:
+ self._render_template(
+ request, "%s.html" % (version,),
+ user=username, userhmac=userhmac, version=version,
+ has_consented=(u["consent_version"] == version),
+ )
+ except TemplateNotFound:
+ raise NotFoundError("Unknown policy version")
+
+ def render_POST(self, request):
+ self._async_render_POST(request)
+ return NOT_DONE_YET
+
+ @wrap_html_request_handler
+ @defer.inlineCallbacks
+ def _async_render_POST(self, request):
+ """
+ Args:
+ request (twisted.web.http.Request):
+ """
+ version = parse_string(request, "v", required=True)
+ username = parse_string(request, "u", required=True)
+ userhmac = parse_string(request, "h", required=True)
+
+ self._check_hash(username, userhmac)
+
+ if username.startswith('@'):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ try:
+ yield self.store.user_set_consent_version(qualified_user_id, version)
+ except StoreError as e:
+ if e.code != 404:
+ raise
+ raise NotFoundError("Unknown user")
+
+ try:
+ self._render_template(request, "success.html")
+ except TemplateNotFound:
+ raise NotFoundError("success.html not found")
+
+ def _render_template(self, request, template_name, **template_args):
+ # get_template checks for ".." so we don't need to worry too much
+ # about path traversal here.
+ template_html = self._jinja_env.get_template(
+ path.join(TEMPLATE_LANGUAGE, template_name)
+ )
+ html_bytes = template_html.render(**template_args).encode("utf8")
+
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%i" % len(html_bytes))
+ request.write(html_bytes)
+ finish_request(request)
+
+ def _check_hash(self, userid, userhmac):
+ want_mac = hmac.new(
+ key=self._hmac_secret,
+ msg=userid,
+ digestmod=sha256,
+ ).hexdigest()
+
+ if not compare_digest(want_mac, userhmac):
+ raise SynapseError(http_client.FORBIDDEN, "HMAC incorrect")
diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py
index bd4fea5774..1498d188c1 100644
--- a/synapse/rest/key/v1/server_key_resource.py
+++ b/synapse/rest/key/v1/server_key_resource.py
@@ -49,7 +49,6 @@ class LocalKey(Resource):
"""
def __init__(self, hs):
- self.version_string = hs.version_string
self.response_body = encode_canonical_json(
self.response_json_object(hs.config)
)
@@ -84,7 +83,6 @@ class LocalKey(Resource):
def render_GET(self, request):
return respond_with_json_bytes(
request, 200, self.response_body,
- version_string=self.version_string
)
def getChild(self, name, request):
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index be68d9a096..04775b3c45 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -63,7 +63,6 @@ class LocalKey(Resource):
isLeaf = True
def __init__(self, hs):
- self.version_string = hs.version_string
self.config = hs.config
self.clock = hs.clock
self.update_response_body(self.clock.time_msec())
@@ -115,5 +114,4 @@ class LocalKey(Resource):
self.update_response_body(time_now)
return respond_with_json_bytes(
request, 200, self.response_body,
- version_string=self.version_string
)
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 17e6079cba..21b4c1175e 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import request_handler, respond_with_json_bytes
+from synapse.http.server import (
+ respond_with_json_bytes, wrap_json_request_handler,
+)
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.api.errors import SynapseError, Codes
from synapse.crypto.keyring import KeyLookupError
@@ -91,7 +93,6 @@ class RemoteKey(Resource):
def __init__(self, hs):
self.keyring = hs.get_keyring()
self.store = hs.get_datastore()
- self.version_string = hs.version_string
self.clock = hs.get_clock()
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -99,7 +100,7 @@ class RemoteKey(Resource):
self.async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def async_render_GET(self, request):
if len(request.postpath) == 1:
@@ -124,7 +125,7 @@ class RemoteKey(Resource):
self.async_render_POST(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def async_render_POST(self, request):
content = parse_json_object_from_request(request)
@@ -240,5 +241,4 @@ class RemoteKey(Resource):
respond_with_json_bytes(
request, 200, result_io.getvalue(),
- version_string=self.version_string
)
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index e7ac01da01..c0d2f06855 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -28,7 +28,7 @@ import os
import logging
import urllib
-import urlparse
+from six.moves.urllib import parse as urlparse
logger = logging.getLogger(__name__)
@@ -143,6 +143,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
respond_404(request)
return
+ logger.debug("Responding to media request with responder %s")
add_file_headers(request, media_type, file_size, upload_name)
with responder:
yield responder.write_to_consumer(request)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index fe7e17596f..8cf8820c31 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -12,16 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.http.servlet
+import logging
-from ._base import parse_media_id, respond_404
+from twisted.internet import defer
from twisted.web.resource import Resource
-from synapse.http.server import request_handler, set_cors_headers
-
from twisted.web.server import NOT_DONE_YET
-from twisted.internet import defer
-import logging
+from synapse.http.server import (
+ set_cors_headers,
+ wrap_json_request_handler,
+)
+import synapse.http.servlet
+from ._base import parse_media_id, respond_404
logger = logging.getLogger(__name__)
@@ -35,15 +37,14 @@ class DownloadResource(Resource):
self.media_repo = media_repo
self.server_name = hs.hostname
- # Both of these are expected by @request_handler()
+ # this is expected by @wrap_json_request_handler
self.clock = hs.get_clock()
- self.version_string = hs.version_string
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
set_cors_headers(request)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index bb79599379..2ac767d2dc 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -47,7 +47,8 @@ import shutil
import cgi
import logging
-import urlparse
+from six.moves.urllib import parse as urlparse
+from six import iteritems
logger = logging.getLogger(__name__)
@@ -603,7 +604,7 @@ class MediaRepository(object):
thumbnails[(t_width, t_height, r_type)] = r_method
# Now we generate the thumbnails for each dimension, store it
- for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
+ for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 7f263db239..d23fe10b07 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -255,7 +255,9 @@ class FileResponder(Responder):
self.open_file = open_file
def write_to_consumer(self, consumer):
- return FileSender().beginFileTransfer(self.open_file, consumer)
+ return make_deferred_yieldable(
+ FileSender().beginFileTransfer(self.open_file, consumer)
+ )
def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 0fc21540c6..565cef2b8d 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -24,7 +24,9 @@ import shutil
import sys
import traceback
import simplejson as json
-import urlparse
+
+from six.moves import urllib_parse as urlparse
+from six import string_types
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
@@ -35,13 +37,14 @@ from ._base import FileInfo
from synapse.api.errors import (
SynapseError, Codes,
)
-from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from synapse.util.stringutils import random_string
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.http.client import SpiderHttpClient
from synapse.http.server import (
- request_handler, respond_with_json_bytes,
+ respond_with_json_bytes,
respond_with_json,
+ wrap_json_request_handler,
)
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
@@ -57,7 +60,6 @@ class PreviewUrlResource(Resource):
self.auth = hs.get_auth()
self.clock = hs.get_clock()
- self.version_string = hs.version_string
self.filepaths = media_repo.filepaths
self.max_spider_size = hs.config.max_spider_size
self.server_name = hs.hostname
@@ -90,7 +92,7 @@ class PreviewUrlResource(Resource):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
@@ -144,7 +146,8 @@ class PreviewUrlResource(Resource):
observable = self._cache.get(url)
if not observable:
- download = preserve_fn(self._do_preview)(
+ download = run_in_background(
+ self._do_preview,
url, requester.user, ts,
)
observable = ObservableDeferred(
@@ -589,8 +592,8 @@ def _iterate_over_text(tree, *tags_to_ignore):
# to be returned.
elements = iter([tree])
while True:
- el = elements.next()
- if isinstance(el, basestring):
+ el = next(elements)
+ if isinstance(el, string_types):
yield el
elif el is not None and el.tag not in tags_to_ignore:
# el.text is the text before the first child, so we can immediately
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index c188192f2b..0252afd9d3 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -18,7 +18,7 @@ from twisted.internet import defer, threads
from .media_storage import FileResponder
from synapse.config._base import Config
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import run_in_background
import logging
import os
@@ -87,7 +87,12 @@ class StorageProviderWrapper(StorageProvider):
return self.backend.store_file(path, file_info)
else:
# TODO: Handle errors.
- preserve_fn(self.backend.store_file)(path, file_info)
+ def store():
+ try:
+ return self.backend.store_file(path, file_info)
+ except Exception:
+ logger.exception("Error storing file")
+ run_in_background(store)
return defer.succeed(None)
def fetch(self, path, file_info):
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 58ada49711..aae6e464e8 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,18 +14,21 @@
# limitations under the License.
-from ._base import (
- parse_media_id, respond_404, respond_with_file, FileInfo,
- respond_with_responder,
-)
-from twisted.web.resource import Resource
-from synapse.http.servlet import parse_string, parse_integer
-from synapse.http.server import request_handler, set_cors_headers
+import logging
-from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
-import logging
+from synapse.http.server import (
+ set_cors_headers,
+ wrap_json_request_handler,
+)
+from synapse.http.servlet import parse_integer, parse_string
+from ._base import (
+ FileInfo, parse_media_id, respond_404, respond_with_file,
+ respond_with_responder,
+)
logger = logging.getLogger(__name__)
@@ -41,14 +44,13 @@ class ThumbnailResource(Resource):
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
- self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
set_cors_headers(request)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index f6f498cdc5..7567476fce 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -13,16 +13,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import respond_with_json, request_handler
-
-from synapse.api.errors import SynapseError
+import logging
-from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
-
from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
-import logging
+from synapse.api.errors import SynapseError
+from synapse.http.server import (
+ respond_with_json,
+ wrap_json_request_handler,
+)
logger = logging.getLogger(__name__)
@@ -40,7 +41,6 @@ class UploadResource(Resource):
self.server_name = hs.hostname
self.auth = hs.get_auth()
self.max_upload_size = hs.config.max_upload_size
- self.version_string = hs.version_string
self.clock = hs.get_clock()
def render_POST(self, request):
@@ -51,7 +51,7 @@ class UploadResource(Resource):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
- @request_handler()
+ @wrap_json_request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
requester = yield self.auth.get_user_by_req(request)
@@ -81,15 +81,15 @@ class UploadResource(Resource):
headers = request.requestHeaders
if headers.hasHeader("Content-Type"):
- media_type = headers.getRawHeaders("Content-Type")[0]
+ media_type = headers.getRawHeaders(b"Content-Type")[0]
else:
raise SynapseError(
msg="Upload request missing 'Content-Type'",
code=400,
)
- # if headers.hasHeader("Content-Disposition"):
- # disposition = headers.getRawHeaders("Content-Disposition")[0]
+ # if headers.hasHeader(b"Content-Disposition"):
+ # disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
content_uri = yield self.media_repo.create_content(
diff --git a/synapse/server.py b/synapse/server.py
index cd0c1a51be..58dbf78437 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -46,6 +46,7 @@ from synapse.handlers.devicemessage import DeviceMessageHandler
from synapse.handlers.device import DeviceHandler
from synapse.handlers.e2e_keys import E2eKeysHandler
from synapse.handlers.presence import PresenceHandler
+from synapse.handlers.room import RoomCreationHandler
from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
@@ -71,6 +72,11 @@ from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
)
+from synapse.server_notices.server_notices_manager import ServerNoticesManager
+from synapse.server_notices.server_notices_sender import ServerNoticesSender
+from synapse.server_notices.worker_server_notices_sender import (
+ WorkerServerNoticesSender,
+)
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
@@ -97,6 +103,9 @@ class HomeServer(object):
which must be implemented by the subclass. This code may call any of the
required "get" methods on the instance to obtain the sub-dependencies that
one requires.
+
+ Attributes:
+ config (synapse.config.homeserver.HomeserverConfig):
"""
DEPENDENCIES = [
@@ -105,8 +114,8 @@ class HomeServer(object):
'federation_client',
'federation_server',
'handlers',
- 'v1auth',
'auth',
+ 'room_creation_handler',
'state_handler',
'state_resolution_handler',
'presence_handler',
@@ -152,6 +161,8 @@ class HomeServer(object):
'spam_checker',
'room_member_handler',
'federation_registry',
+ 'server_notices_manager',
+ 'server_notices_sender',
]
def __init__(self, hostname, **kwargs):
@@ -225,14 +236,8 @@ class HomeServer(object):
def build_simple_http_client(self):
return SimpleHttpClient(self)
- def build_v1auth(self):
- orf = Auth(self)
- # Matrix spec makes no reference to what HTTP status code is returned,
- # but the V1 API uses 403 where it means 401, and the webclient
- # relies on this behaviour, so V1 gets its own copy of the auth
- # with backwards compat behaviour.
- orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
- return orf
+ def build_room_creation_handler(self):
+ return RoomCreationHandler(self)
def build_state_handler(self):
return StateHandler(self)
@@ -400,6 +405,16 @@ class HomeServer(object):
def build_federation_registry(self):
return FederationHandlerRegistry()
+ def build_server_notices_manager(self):
+ if self.config.worker_app:
+ raise Exception("Workers cannot send server notices")
+ return ServerNoticesManager(self)
+
+ def build_server_notices_sender(self):
+ if self.config.worker_app:
+ return WorkerServerNoticesSender(self)
+ return ServerNoticesSender(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index c3a9a3847b..ce28486233 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -1,4 +1,5 @@
import synapse.api.auth
+import synapse.config.homeserver
import synapse.federation.transaction_queue
import synapse.federation.transport.client
import synapse.handlers
@@ -8,11 +9,17 @@ import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.handlers.set_password
import synapse.rest.media.v1.media_repository
+import synapse.server_notices.server_notices_manager
+import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
class HomeServer(object):
+ @property
+ def config(self) -> synapse.config.homeserver.HomeServerConfig:
+ pass
+
def get_auth(self) -> synapse.api.auth.Auth:
pass
@@ -40,6 +47,12 @@ class HomeServer(object):
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
+ def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
+ pass
+
+ def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler:
+ pass
+
def get_set_password_handler(self) -> synapse.handlers.set_password.SetPasswordHandler:
pass
@@ -54,3 +67,9 @@ class HomeServer(object):
def get_media_repository(self) -> synapse.rest.media.v1.media_repository.MediaRepository:
pass
+
+ def get_server_notices_manager(self) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
+ pass
+
+ def get_server_notices_sender(self) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
+ pass
diff --git a/synapse/server_notices/__init__.py b/synapse/server_notices/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/server_notices/__init__.py
diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py
new file mode 100644
index 0000000000..bb74af1af5
--- /dev/null
+++ b/synapse/server_notices/consent_server_notices.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from six import (iteritems, string_types)
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.api.urls import ConsentURIBuilder
+from synapse.config import ConfigError
+from synapse.types import get_localpart_from_id
+
+logger = logging.getLogger(__name__)
+
+
+class ConsentServerNotices(object):
+ """Keeps track of whether we need to send users server_notices about
+ privacy policy consent, and sends one if we do.
+ """
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+ self._server_notices_manager = hs.get_server_notices_manager()
+ self._store = hs.get_datastore()
+
+ self._users_in_progress = set()
+
+ self._current_consent_version = hs.config.user_consent_version
+ self._server_notice_content = hs.config.user_consent_server_notice_content
+ self._send_to_guests = hs.config.user_consent_server_notice_to_guests
+
+ if self._server_notice_content is not None:
+ if not self._server_notices_manager.is_enabled():
+ raise ConfigError(
+ "user_consent configuration requires server notices, but "
+ "server notices are not enabled.",
+ )
+ if 'body' not in self._server_notice_content:
+ raise ConfigError(
+ "user_consent server_notice_consent must contain a 'body' "
+ "key.",
+ )
+
+ self._consent_uri_builder = ConsentURIBuilder(hs.config)
+
+ @defer.inlineCallbacks
+ def maybe_send_server_notice_to_user(self, user_id):
+ """Check if we need to send a notice to this user, and does so if so
+
+ Args:
+ user_id (str): user to check
+
+ Returns:
+ Deferred
+ """
+ if self._server_notice_content is None:
+ # not enabled
+ return
+
+ # make sure we don't send two messages to the same user at once
+ if user_id in self._users_in_progress:
+ return
+ self._users_in_progress.add(user_id)
+ try:
+ u = yield self._store.get_user_by_id(user_id)
+
+ if u["is_guest"] and not self._send_to_guests:
+ # don't send to guests
+ return
+
+ if u["consent_version"] == self._current_consent_version:
+ # user has already consented
+ return
+
+ if u["consent_server_notice_sent"] == self._current_consent_version:
+ # we've already sent a notice to the user
+ return
+
+ # need to send a message.
+ try:
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(
+ get_localpart_from_id(user_id),
+ )
+ content = copy_with_str_subst(
+ self._server_notice_content, {
+ 'consent_uri': consent_uri,
+ },
+ )
+ yield self._server_notices_manager.send_notice(
+ user_id, content,
+ )
+ yield self._store.user_set_consent_server_notice_sent(
+ user_id, self._current_consent_version,
+ )
+ except SynapseError as e:
+ logger.error("Error sending server notice about user consent: %s", e)
+ finally:
+ self._users_in_progress.remove(user_id)
+
+
+def copy_with_str_subst(x, substitutions):
+ """Deep-copy a structure, carrying out string substitions on any strings
+
+ Args:
+ x (object): structure to be copied
+ substitutions (object): substitutions to be made - passed into the
+ string '%' operator
+
+ Returns:
+ copy of x
+ """
+ if isinstance(x, string_types):
+ return x % substitutions
+ if isinstance(x, dict):
+ return {
+ k: copy_with_str_subst(v, substitutions) for (k, v) in iteritems(x)
+ }
+ if isinstance(x, (list, tuple)):
+ return [copy_with_str_subst(y) for y in x]
+
+ # assume it's uninterested and can be shallow-copied.
+ return x
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
new file mode 100644
index 0000000000..a26deace53
--- /dev/null
+++ b/synapse/server_notices/server_notices_manager.py
@@ -0,0 +1,146 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership, RoomCreationPreset
+from synapse.types import create_requester
+from synapse.util.caches.descriptors import cachedInlineCallbacks
+
+logger = logging.getLogger(__name__)
+
+
+class ServerNoticesManager(object):
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+
+ self._store = hs.get_datastore()
+ self._config = hs.config
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self._event_creation_handler = hs.get_event_creation_handler()
+ self._is_mine_id = hs.is_mine_id
+
+ def is_enabled(self):
+ """Checks if server notices are enabled on this server.
+
+ Returns:
+ bool
+ """
+ return self._config.server_notices_mxid is not None
+
+ @defer.inlineCallbacks
+ def send_notice(self, user_id, event_content):
+ """Send a notice to the given user
+
+ Creates the server notices room, if none exists.
+
+ Args:
+ user_id (str): mxid of user to send event to.
+ event_content (dict): content of event to send
+
+ Returns:
+ Deferred[None]
+ """
+ room_id = yield self.get_notice_room_for_user(user_id)
+
+ system_mxid = self._config.server_notices_mxid
+ requester = create_requester(system_mxid)
+
+ logger.info("Sending server notice to %s", user_id)
+
+ yield self._event_creation_handler.create_and_send_nonmember_event(
+ requester, {
+ "type": EventTypes.Message,
+ "room_id": room_id,
+ "sender": system_mxid,
+ "content": event_content,
+ },
+ ratelimit=False,
+ )
+
+ @cachedInlineCallbacks()
+ def get_notice_room_for_user(self, user_id):
+ """Get the room for notices for a given user
+
+ If we have not yet created a notice room for this user, create it
+
+ Args:
+ user_id (str): complete user id for the user we want a room for
+
+ Returns:
+ str: room id of notice room.
+ """
+ if not self.is_enabled():
+ raise Exception("Server notices not enabled")
+
+ assert self._is_mine_id(user_id), \
+ "Cannot send server notices to remote users"
+
+ rooms = yield self._store.get_rooms_for_user_where_membership_is(
+ user_id, [Membership.INVITE, Membership.JOIN],
+ )
+ system_mxid = self._config.server_notices_mxid
+ for room in rooms:
+ # it's worth noting that there is an asymmetry here in that we
+ # expect the user to be invited or joined, but the system user must
+ # be joined. This is kinda deliberate, in that if somebody somehow
+ # manages to invite the system user to a room, that doesn't make it
+ # the server notices room.
+ user_ids = yield self._store.get_users_in_room(room.room_id)
+ if system_mxid in user_ids:
+ # we found a room which our user shares with the system notice
+ # user
+ logger.info("Using room %s", room.room_id)
+ defer.returnValue(room.room_id)
+
+ # apparently no existing notice room: create a new one
+ logger.info("Creating server notices room for %s", user_id)
+
+ # see if we want to override the profile info for the server user.
+ # note that if we want to override either the display name or the
+ # avatar, we have to use both.
+ join_profile = None
+ if (
+ self._config.server_notices_mxid_display_name is not None or
+ self._config.server_notices_mxid_avatar_url is not None
+ ):
+ join_profile = {
+ "displayname": self._config.server_notices_mxid_display_name,
+ "avatar_url": self._config.server_notices_mxid_avatar_url,
+ }
+
+ requester = create_requester(system_mxid)
+ info = yield self._room_creation_handler.create_room(
+ requester,
+ config={
+ "preset": RoomCreationPreset.PRIVATE_CHAT,
+ "name": self._config.server_notices_room_name,
+ "power_level_content_override": {
+ "users_default": -10,
+ },
+ "invite": (user_id,)
+ },
+ ratelimit=False,
+ creator_join_profile=join_profile,
+ )
+ room_id = info['room_id']
+
+ logger.info("Created server notices room %s for %s", room_id, user_id)
+ defer.returnValue(room_id)
diff --git a/synapse/server_notices/server_notices_sender.py b/synapse/server_notices/server_notices_sender.py
new file mode 100644
index 0000000000..5d23965f34
--- /dev/null
+++ b/synapse/server_notices/server_notices_sender.py
@@ -0,0 +1,58 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.server_notices.consent_server_notices import ConsentServerNotices
+
+
+class ServerNoticesSender(object):
+ """A centralised place which sends server notices automatically when
+ Certain Events take place
+ """
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+ # todo: it would be nice to make this more dynamic
+ self._consent_server_notices = ConsentServerNotices(hs)
+
+ def on_user_syncing(self, user_id):
+ """Called when the user performs a sync operation.
+
+ Args:
+ user_id (str): mxid of user who synced
+
+ Returns:
+ Deferred
+ """
+ return self._consent_server_notices.maybe_send_server_notice_to_user(
+ user_id,
+ )
+
+ def on_user_ip(self, user_id):
+ """Called on the master when a worker process saw a client request.
+
+ Args:
+ user_id (str): mxid
+
+ Returns:
+ Deferred
+ """
+ # The synchrotrons use a stubbed version of ServerNoticesSender, so
+ # we check for notices to send to the user in on_user_ip as well as
+ # in on_user_syncing
+ return self._consent_server_notices.maybe_send_server_notice_to_user(
+ user_id,
+ )
diff --git a/synapse/server_notices/worker_server_notices_sender.py b/synapse/server_notices/worker_server_notices_sender.py
new file mode 100644
index 0000000000..4a133026c3
--- /dev/null
+++ b/synapse/server_notices/worker_server_notices_sender.py
@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.internet import defer
+
+
+class WorkerServerNoticesSender(object):
+ """Stub impl of ServerNoticesSender which does nothing"""
+ def __init__(self, hs):
+ """
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+
+ def on_user_syncing(self, user_id):
+ """Called when the user performs a sync operation.
+
+ Args:
+ user_id (str): mxid of user who synced
+
+ Returns:
+ Deferred
+ """
+ return defer.succeed(None)
+
+ def on_user_ip(self, user_id):
+ """Called on the master when a worker process saw a client request.
+
+ Args:
+ user_id (str): mxid
+
+ Returns:
+ Deferred
+ """
+ raise AssertionError("on_user_ip unexpectedly called on worker")
diff --git a/synapse/state.py b/synapse/state.py
index 26093c8434..216418f58d 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -32,6 +32,8 @@ from frozendict import frozendict
import logging
import hashlib
+from six import iteritems, itervalues
+
logger = logging.getLogger(__name__)
@@ -130,9 +132,10 @@ class StateHandler(object):
defer.returnValue(event)
return
- state_map = yield self.store.get_events(state.values(), get_prev_content=False)
+ state_map = yield self.store.get_events(list(state.values()),
+ get_prev_content=False)
state = {
- key: state_map[e_id] for key, e_id in state.iteritems() if e_id in state_map
+ key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
}
defer.returnValue(state)
@@ -338,7 +341,7 @@ class StateHandler(object):
)
if len(state_groups_ids) == 1:
- name, state_list = state_groups_ids.items().pop()
+ name, state_list = list(state_groups_ids.items()).pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -378,7 +381,7 @@ class StateHandler(object):
new_state = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
- key: state_map[ev_id] for key, ev_id in new_state.iteritems()
+ key: state_map[ev_id] for key, ev_id in iteritems(new_state)
}
return new_state
@@ -458,15 +461,15 @@ class StateResolutionHandler(object):
# build a map from state key to the event_ids which set that state.
# dict[(str, str), set[str])
state = {}
- for st in state_groups_ids.itervalues():
- for key, e_id in st.iteritems():
+ for st in itervalues(state_groups_ids):
+ for key, e_id in iteritems(st):
state.setdefault(key, set()).add(e_id)
# build a map from state key to the event_ids which set that state,
# including only those where there are state keys in conflict.
conflicted_state = {
k: list(v)
- for k, v in state.iteritems()
+ for k, v in iteritems(state)
if len(v) > 1
}
@@ -474,13 +477,13 @@ class StateResolutionHandler(object):
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
- state_groups_ids.values(),
+ list(state_groups_ids.values()),
event_map=event_map,
state_map_factory=state_map_factory,
)
else:
new_state = {
- key: e_ids.pop() for key, e_ids in state.iteritems()
+ key: e_ids.pop() for key, e_ids in iteritems(state)
}
with Measure(self.clock, "state.create_group_ids"):
@@ -489,8 +492,8 @@ class StateResolutionHandler(object):
# which will be used as a cache key for future resolutions, but
# not get persisted.
state_group = None
- new_state_event_ids = frozenset(new_state.itervalues())
- for sg, events in state_groups_ids.iteritems():
+ new_state_event_ids = frozenset(itervalues(new_state))
+ for sg, events in iteritems(state_groups_ids):
if new_state_event_ids == frozenset(e_id for e_id in events):
state_group = sg
break
@@ -501,11 +504,11 @@ class StateResolutionHandler(object):
prev_group = None
delta_ids = None
- for old_group, old_ids in state_groups_ids.iteritems():
+ for old_group, old_ids in iteritems(state_groups_ids):
if not set(new_state) - set(old_ids):
n_delta_ids = {
k: v
- for k, v in new_state.iteritems()
+ for k, v in iteritems(new_state)
if old_ids.get(k) != v
}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
@@ -527,7 +530,7 @@ class StateResolutionHandler(object):
def _ordered_events(events):
def key_func(e):
- return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
+ return -int(e.depth), hashlib.sha1(e.event_id.encode()).hexdigest()
return sorted(events, key=key_func)
@@ -584,7 +587,7 @@ def _seperate(state_sets):
conflicted_state = {}
for state_set in state_sets[1:]:
- for key, value in state_set.iteritems():
+ for key, value in iteritems(state_set):
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
@@ -640,7 +643,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
needed_events = set(
event_id
- for event_ids in conflicted_state.itervalues()
+ for event_ids in itervalues(conflicted_state)
for event_id in event_ids
)
if event_map is not None:
@@ -662,7 +665,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
unconflicted_state, conflicted_state, state_map
)
- new_needed_events = set(auth_events.itervalues())
+ new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(event_map.iterkeys())
@@ -679,7 +682,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
- for event_ids in conflicted_state.itervalues():
+ for event_ids in itervalues(conflicted_state):
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
@@ -694,7 +697,7 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map):
conflicted_state = {}
- for key, event_ids in conflicted_state_ds.iteritems():
+ for key, event_ids in iteritems(conflicted_state_ds):
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
@@ -703,7 +706,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
auth_events = {
key: state_map[ev_id]
- for key, ev_id in auth_event_ids.iteritems()
+ for key, ev_id in iteritems(auth_event_ids)
if ev_id in state_map
}
@@ -716,7 +719,7 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
raise
new_state = unconflicted_state_ids
- for key, event in resolved_state.iteritems():
+ for key, event in iteritems(resolved_state):
new_state[key] = event.event_id
return new_state
@@ -741,7 +744,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.iteritems():
+ for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
@@ -751,7 +754,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.iteritems():
+ for key, events in iteritems(conflicted_state):
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
@@ -761,7 +764,7 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in conflicted_state.iteritems():
+ for key, events in iteritems(conflicted_state):
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 8cdfd50f90..979fa22438 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -14,6 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import datetime
+from dateutil import tz
+import time
+import logging
+
from synapse.storage.devices import DeviceStore
from .appservice import (
ApplicationServiceStore, ApplicationServiceTransactionStore
@@ -55,10 +60,6 @@ from .engines import PostgresEngine
from synapse.api.constants import PresenceState
from synapse.util.caches.stream_change_cache import StreamChangeCache
-
-import logging
-
-
logger = logging.getLogger(__name__)
@@ -213,6 +214,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
+ # Used in _generate_user_daily_visits to keep track of progress
+ self._last_user_visit_update = self._get_start_of_day()
+
super(DataStore, self).__init__(db_conn, hs)
def take_presence_startup_info(self):
@@ -347,6 +351,69 @@ class DataStore(RoomMemberStore, RoomStore,
return self.runInteraction("count_r30_users", _count_r30_users)
+ def _get_start_of_day(self):
+ """
+ Returns millisecond unixtime for start of UTC day.
+ """
+ now = datetime.datetime.utcnow()
+ today_start = datetime.datetime(now.year, now.month,
+ now.day, tzinfo=tz.tzutc())
+ return int(time.mktime(today_start.timetuple())) * 1000
+
+ def generate_user_daily_visits(self):
+ """
+ Generates daily visit data for use in cohort/ retention analysis
+ """
+ def _generate_user_daily_visits(txn):
+ logger.info("Calling _generate_user_daily_visits")
+ today_start = self._get_start_of_day()
+ a_day_in_milliseconds = 24 * 60 * 60 * 1000
+ now = self.clock.time_msec()
+
+ sql = """
+ INSERT INTO user_daily_visits (user_id, device_id, timestamp)
+ SELECT u.user_id, u.device_id, ?
+ FROM user_ips AS u
+ LEFT JOIN (
+ SELECT user_id, device_id, timestamp FROM user_daily_visits
+ WHERE timestamp = ?
+ ) udv
+ ON u.user_id = udv.user_id AND u.device_id=udv.device_id
+ INNER JOIN users ON users.name=u.user_id
+ WHERE last_seen > ? AND last_seen <= ?
+ AND udv.timestamp IS NULL AND users.is_guest=0
+ AND users.appservice_id IS NULL
+ GROUP BY u.user_id, u.device_id
+ """
+
+ # This means that the day has rolled over but there could still
+ # be entries from the previous day. There is an edge case
+ # where if the user logs in at 23:59 and overwrites their
+ # last_seen at 00:01 then they will not be counted in the
+ # previous day's stats - it is important that the query is run
+ # often to minimise this case.
+ if today_start > self._last_user_visit_update:
+ yesterday_start = today_start - a_day_in_milliseconds
+ txn.execute(sql, (
+ yesterday_start, yesterday_start,
+ self._last_user_visit_update, today_start
+ ))
+ self._last_user_visit_update = today_start
+
+ txn.execute(sql, (
+ today_start, today_start,
+ self._last_user_visit_update,
+ now
+ ))
+ # Update _last_user_visit_update to now. The reason to do this
+ # rather just clamping to the beginning of the day is to limit
+ # the size of the join - meaning that the query can be run more
+ # frequently
+ self._last_user_visit_update = now
+
+ return self.runInteraction("generate_user_daily_visits",
+ _generate_user_daily_visits)
+
def get_users(self):
"""Function to reterive a list of users in users table.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 2262776ab2..22d6257a9f 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,8 +18,8 @@ from synapse.api.errors import StoreError
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.caches.descriptors import Cache
from synapse.storage.engines import PostgresEngine
-import synapse.metrics
+from prometheus_client import Histogram
from twisted.internet import defer
@@ -27,20 +27,25 @@ import sys
import time
import threading
+from six import itervalues, iterkeys, iteritems
+from six.moves import intern, range
logger = logging.getLogger(__name__)
+try:
+ MAX_TXN_ID = sys.maxint - 1
+except AttributeError:
+ # python 3 does not have a maximum int value
+ MAX_TXN_ID = 2**63 - 1
+
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
+sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")
-metrics = synapse.metrics.get_metrics_for("synapse.storage")
-
-sql_scheduling_timer = metrics.register_distribution("schedule_time")
-
-sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
-sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
+sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
+sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])
class LoggingTransaction(object):
@@ -105,7 +110,7 @@ class LoggingTransaction(object):
# Don't let logging failures stop SQL from working
pass
- start = time.time() * 1000
+ start = time.time()
try:
return func(
@@ -115,9 +120,9 @@ class LoggingTransaction(object):
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
finally:
- msecs = (time.time() * 1000) - start
- sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
- sql_query_timer.inc_by(msecs, sql.split()[0])
+ secs = time.time() - start
+ sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
+ sql_query_timer.labels(sql.split()[0]).observe(secs)
class PerformanceCounters(object):
@@ -127,7 +132,7 @@ class PerformanceCounters(object):
def update(self, key, start_time, end_time=None):
if end_time is None:
- end_time = time.time() * 1000
+ end_time = time.time()
duration = end_time - start_time
count, cum_time = self.current_counters.get(key, (0, 0))
count += 1
@@ -137,7 +142,7 @@ class PerformanceCounters(object):
def interval(self, interval_duration, limit=3):
counters = []
- for name, (count, cum_time) in self.current_counters.iteritems():
+ for name, (count, cum_time) in iteritems(self.current_counters):
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
counters.append((
(cum_time - prev_time) / interval_duration,
@@ -217,12 +222,12 @@ class SQLBaseStore(object):
def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
logging_context, func, *args, **kwargs):
- start = time.time() * 1000
+ start = time.time()
txn_id = self._TXN_ID
# We don't really need these to be unique, so lets stop it from
# growing really large.
- self._TXN_ID = (self._TXN_ID + 1) % (sys.maxint - 1)
+ self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
name = "%s-%x" % (desc, txn_id, )
@@ -277,17 +282,17 @@ class SQLBaseStore(object):
logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
- end = time.time() * 1000
+ end = time.time()
duration = end - start
if logging_context is not None:
logging_context.add_database_transaction(duration)
- transaction_logger.debug("[TXN END] {%s} %f", name, duration)
+ transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)
self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
- sql_txn_timer.inc_by(duration, desc)
+ sql_txn_timer.labels(desc).observe(duration)
@defer.inlineCallbacks
def runInteraction(self, desc, func, *args, **kwargs):
@@ -344,13 +349,13 @@ class SQLBaseStore(object):
"""
current_context = LoggingContext.current_context()
- start_time = time.time() * 1000
+ start_time = time.time()
def inner_func(conn, *args, **kwargs):
with LoggingContext("runWithConnection") as context:
- sched_duration_ms = time.time() * 1000 - start_time
- sql_scheduling_timer.inc_by(sched_duration_ms)
- current_context.add_database_scheduled(sched_duration_ms)
+ sched_duration_sec = time.time() - start_time
+ sql_scheduling_timer.observe(sched_duration_sec)
+ current_context.add_database_scheduled(sched_duration_sec)
if self.database_engine.is_connection_closed(conn):
logger.debug("Reconnecting closed database connection")
@@ -543,7 +548,7 @@ class SQLBaseStore(object):
", ".join("%s = ?" % (k,) for k in values),
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
- sqlargs = values.values() + keyvalues.values()
+ sqlargs = list(values.values()) + list(keyvalues.values())
txn.execute(sql, sqlargs)
if txn.rowcount > 0:
@@ -561,7 +566,7 @@ class SQLBaseStore(object):
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues)
)
- txn.execute(sql, allvalues.values())
+ txn.execute(sql, list(allvalues.values()))
# successfully inserted
return True
@@ -629,8 +634,8 @@ class SQLBaseStore(object):
}
if keyvalues:
- sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
- txn.execute(sql, keyvalues.values())
+ sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+ txn.execute(sql, list(keyvalues.values()))
else:
txn.execute(sql)
@@ -694,7 +699,7 @@ class SQLBaseStore(object):
table,
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- txn.execute(sql, keyvalues.values())
+ txn.execute(sql, list(keyvalues.values()))
else:
sql = "SELECT %s FROM %s" % (
", ".join(retcols),
@@ -725,9 +730,12 @@ class SQLBaseStore(object):
if not iterable:
defer.returnValue(results)
+ # iterables can not be sliced, so convert it to a list first
+ it_list = list(iterable)
+
chunks = [
- iterable[i:i + batch_size]
- for i in xrange(0, len(iterable), batch_size)
+ it_list[i:i + batch_size]
+ for i in range(0, len(it_list), batch_size)
]
for chunk in chunks:
rows = yield self.runInteraction(
@@ -767,7 +775,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.iteritems():
+ for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -790,7 +798,7 @@ class SQLBaseStore(object):
@staticmethod
def _simple_update_txn(txn, table, keyvalues, updatevalues):
if keyvalues:
- where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
+ where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
else:
where = ""
@@ -802,7 +810,7 @@ class SQLBaseStore(object):
txn.execute(
update_sql,
- updatevalues.values() + keyvalues.values()
+ list(updatevalues.values()) + list(keyvalues.values())
)
return txn.rowcount
@@ -850,7 +858,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues)
)
- txn.execute(select_sql, keyvalues.values())
+ txn.execute(select_sql, list(keyvalues.values()))
row = txn.fetchone()
if not row:
@@ -888,7 +896,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- txn.execute(sql, keyvalues.values())
+ txn.execute(sql, list(keyvalues.values()))
if txn.rowcount == 0:
raise StoreError(404, "No row found")
if txn.rowcount > 1:
@@ -906,7 +914,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
)
- return txn.execute(sql, keyvalues.values())
+ return txn.execute(sql, list(keyvalues.values()))
def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
return self.runInteraction(
@@ -938,7 +946,7 @@ class SQLBaseStore(object):
)
values.extend(iterable)
- for key, value in keyvalues.iteritems():
+ for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
@@ -978,7 +986,7 @@ class SQLBaseStore(object):
txn.close()
if cache:
- min_val = min(cache.itervalues())
+ min_val = min(itervalues(cache))
else:
min_val = max_value
@@ -1093,7 +1101,7 @@ class SQLBaseStore(object):
" AND ".join("%s = ?" % (k,) for k in keyvalues),
" ? ASC LIMIT ? OFFSET ?"
)
- txn.execute(sql, keyvalues.values() + pagevalues)
+ txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
else:
sql = "SELECT %s FROM %s ORDER BY %s" % (
", ".join(retcols),
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 7b44dae0fc..ce338514e8 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -22,6 +22,8 @@ from . import background_updates
from synapse.util.caches import CACHE_SIZE_FACTOR
+from six import iteritems
+
logger = logging.getLogger(__name__)
@@ -55,6 +57,13 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
columns=["user_id", "last_seen"],
)
+ self.register_background_index_update(
+ "user_ips_last_seen_only_index",
+ index_name="user_ips_last_seen_only",
+ table="user_ips",
+ columns=["last_seen"],
+ )
+
# (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
@@ -92,7 +101,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
def _update_client_ips_batch_txn(self, txn, to_update):
self.database_engine.lock_table(txn, "user_ips")
- for entry in to_update.iteritems():
+ for entry in iteritems(to_update):
(user_id, access_token, ip), (user_agent, device_id, last_seen) = entry
self._simple_upsert_txn(
@@ -224,5 +233,5 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"user_agent": user_agent,
"last_seen": last_seen,
}
- for (access_token, ip), (user_agent, last_seen) in results.iteritems()
+ for (access_token, ip), (user_agent, last_seen) in iteritems(results)
))
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 712106b83a..d149d8392e 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -21,6 +21,7 @@ from synapse.api.errors import StoreError
from ._base import SQLBaseStore, Cache
from synapse.util.caches.descriptors import cached, cachedList, cachedInlineCallbacks
+from six import itervalues, iteritems
logger = logging.getLogger(__name__)
@@ -360,7 +361,7 @@ class DeviceStore(SQLBaseStore):
return (now_stream_id, [])
if len(query_map) >= 20:
- now_stream_id = max(stream_id for stream_id in query_map.itervalues())
+ now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
@@ -373,13 +374,13 @@ class DeviceStore(SQLBaseStore):
"""
results = []
- for user_id, user_devices in devices.iteritems():
+ for user_id, user_devices in iteritems(devices):
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
- for device_id, device in user_devices.iteritems():
+ for device_id, device in iteritems(user_devices):
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
@@ -483,7 +484,7 @@ class DeviceStore(SQLBaseStore):
if devices:
user_devices = devices[user_id]
results = []
- for device_id, device in user_devices.iteritems():
+ for device_id, device in iteritems(user_devices):
result = {
"device_id": device_id,
}
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index ff8538ddf8..b146487943 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -21,6 +21,8 @@ import simplejson as json
from ._base import SQLBaseStore
+from six import iteritems
+
class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
@@ -81,8 +83,8 @@ class EndToEndKeyStore(SQLBaseStore):
query_list, include_all_devices,
)
- for user_id, device_keys in results.iteritems():
- for device_id, device_info in device_keys.iteritems():
+ for user_id, device_keys in iteritems(results):
+ for device_id, device_info in iteritems(device_keys):
device_info["keys"] = json.loads(device_info.pop("key_json"))
defer.returnValue(results)
diff --git a/synapse/storage/event_push_actions.py b/synapse/storage/event_push_actions.py
index e78f8d0114..d0350ee5fe 100644
--- a/synapse/storage/event_push_actions.py
+++ b/synapse/storage/event_push_actions.py
@@ -18,12 +18,12 @@ from synapse.storage._base import SQLBaseStore, LoggingTransaction
from twisted.internet import defer
from synapse.util.async import sleep
from synapse.util.caches.descriptors import cachedInlineCallbacks
-from synapse.types import RoomStreamToken
-from .stream import lower_bound
import logging
import simplejson as json
+from six import iteritems
+
logger = logging.getLogger(__name__)
@@ -99,7 +99,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _get_unread_counts_by_receipt_txn(self, txn, room_id, user_id,
last_read_event_id):
sql = (
- "SELECT stream_ordering, topological_ordering"
+ "SELECT stream_ordering"
" FROM events"
" WHERE room_id = ? AND event_id = ?"
)
@@ -111,17 +111,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
return {"notify_count": 0, "highlight_count": 0}
stream_ordering = results[0][0]
- topological_ordering = results[0][1]
return self._get_unread_counts_by_pos_txn(
- txn, room_id, user_id, topological_ordering, stream_ordering
+ txn, room_id, user_id, stream_ordering
)
- def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, topological_ordering,
- stream_ordering):
- token = RoomStreamToken(
- topological_ordering, stream_ordering
- )
+ def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
# First get number of notifications.
# We don't need to put a notif=1 clause as all rows always have
@@ -132,10 +127,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" WHERE"
" user_id = ?"
" AND room_id = ?"
- " AND %s"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
+ " AND stream_ordering > ?"
+ )
- txn.execute(sql, (user_id, room_id))
+ txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
notify_count = row[0] if row else 0
@@ -155,10 +150,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" highlight = 1"
" AND user_id = ?"
" AND room_id = ?"
- " AND %s"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
+ " AND stream_ordering > ?"
+ )
- txn.execute(sql, (user_id, room_id))
+ txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
highlight_count = row[0] if row else 0
@@ -209,7 +204,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" ep.highlight "
" FROM ("
" SELECT room_id,"
- " MAX(topological_ordering) as topological_ordering,"
" MAX(stream_ordering) as stream_ordering"
" FROM events"
" INNER JOIN receipts_linearized USING (room_id, event_id)"
@@ -219,13 +213,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" event_push_actions AS ep"
" WHERE"
" ep.room_id = rl.room_id"
- " AND ("
- " ep.topological_ordering > rl.topological_ordering"
- " OR ("
- " ep.topological_ordering = rl.topological_ordering"
- " AND ep.stream_ordering > rl.stream_ordering"
- " )"
- " )"
+ " AND ep.stream_ordering > rl.stream_ordering"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
@@ -318,7 +306,6 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" ep.highlight, e.received_ts"
" FROM ("
" SELECT room_id,"
- " MAX(topological_ordering) as topological_ordering,"
" MAX(stream_ordering) as stream_ordering"
" FROM events"
" INNER JOIN receipts_linearized USING (room_id, event_id)"
@@ -329,13 +316,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" INNER JOIN events AS e USING (room_id, event_id)"
" WHERE"
" ep.room_id = rl.room_id"
- " AND ("
- " ep.topological_ordering > rl.topological_ordering"
- " OR ("
- " ep.topological_ordering = rl.topological_ordering"
- " AND ep.stream_ordering > rl.stream_ordering"
- " )"
- " )"
+ " AND ep.stream_ordering > rl.stream_ordering"
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
@@ -441,13 +422,14 @@ class EventPushActionsWorkerStore(SQLBaseStore):
txn.executemany(sql, (
_gen_entry(user_id, actions)
- for user_id, actions in user_id_actions.iteritems()
+ for user_id, actions in iteritems(user_id_actions)
))
return self.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
+ @defer.inlineCallbacks
def remove_push_actions_from_staging(self, event_id):
"""Called if we failed to persist the event to ensure that stale push
actions don't build up in the DB
@@ -456,13 +438,22 @@ class EventPushActionsWorkerStore(SQLBaseStore):
event_id (str)
"""
- return self._simple_delete(
- table="event_push_actions_staging",
- keyvalues={
- "event_id": event_id,
- },
- desc="remove_push_actions_from_staging",
- )
+ try:
+ res = yield self._simple_delete(
+ table="event_push_actions_staging",
+ keyvalues={
+ "event_id": event_id,
+ },
+ desc="remove_push_actions_from_staging",
+ )
+ defer.returnValue(res)
+ except Exception:
+ # this method is called from an exception handler, so propagating
+ # another exception here really isn't helpful - there's nothing
+ # the caller can do about it. Just log the exception and move on.
+ logger.exception(
+ "Error removing push actions after event persistence failure",
+ )
@defer.inlineCallbacks
def _find_stream_orderings_for_times(self):
@@ -752,10 +743,10 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
)
def _remove_old_push_actions_before_txn(self, txn, room_id, user_id,
- topological_ordering, stream_ordering):
+ stream_ordering):
"""
Purges old push actions for a user and room before a given
- topological_ordering.
+ stream_ordering.
We however keep a months worth of highlighted notifications, so that
users can still get a list of recent highlights.
@@ -764,7 +755,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn: The transcation
room_id: Room ID to delete from
user_id: user ID to delete for
- topological_ordering: The lowest topological ordering which will
+ stream_ordering: The lowest stream ordering which will
not be deleted.
"""
txn.call_after(
@@ -783,9 +774,9 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
txn.execute(
"DELETE FROM event_push_actions "
" WHERE user_id = ? AND room_id = ? AND "
- " topological_ordering <= ?"
+ " stream_ordering <= ?"
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
- (user_id, room_id, topological_ordering, self.stream_ordering_month_ago)
+ (user_id, room_id, stream_ordering, self.stream_ordering_month_ago)
)
txn.execute("""
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 5fe4a0e56c..cb1082e864 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -22,7 +22,6 @@ import logging
import simplejson as json
from twisted.internet import defer
-
from synapse.storage.events_worker import EventsWorkerStore
from synapse.util.async import ObservableDeferred
from synapse.util.frozenutils import frozendict_json_encoder
@@ -34,37 +33,37 @@ from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from synapse.types import get_domain_from_id
+from synapse.types import get_domain_from_id, RoomStreamToken
import synapse.metrics
# these are only included to make the type annotations work
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
-logger = logging.getLogger(__name__)
+from six.moves import range
+from six import itervalues, iteritems
+from prometheus_client import Counter
-metrics = synapse.metrics.get_metrics_for(__name__)
-persist_event_counter = metrics.register_counter("persisted_events")
-event_counter = metrics.register_counter(
- "persisted_events_sep", labels=["type", "origin_type", "origin_entity"]
-)
+logger = logging.getLogger(__name__)
+
+persist_event_counter = Counter("synapse_storage_events_persisted_events", "")
+event_counter = Counter("synapse_storage_events_persisted_events_sep", "",
+ ["type", "origin_type", "origin_entity"])
# The number of times we are recalculating the current state
-state_delta_counter = metrics.register_counter(
- "state_delta",
-)
+state_delta_counter = Counter("synapse_storage_events_state_delta", "")
+
# The number of times we are recalculating state when there is only a
# single forward extremity
-state_delta_single_event_counter = metrics.register_counter(
- "state_delta_single_event",
-)
+state_delta_single_event_counter = Counter(
+ "synapse_storage_events_state_delta_single_event", "")
+
# The number of times we are reculating state when we could have resonably
# calculated the delta when we calculated the state for an event we were
# persisting.
-state_delta_reuse_delta_counter = metrics.register_counter(
- "state_delta_reuse_delta",
-)
+state_delta_reuse_delta_counter = Counter(
+ "synapse_storage_events_state_delta_reuse_delta", "")
def encode_json(json_object):
@@ -249,7 +248,7 @@ class EventsStore(EventsWorkerStore):
partitioned.setdefault(event.room_id, []).append((event, ctx))
deferreds = []
- for room_id, evs_ctxs in partitioned.iteritems():
+ for room_id, evs_ctxs in iteritems(partitioned):
d = self._event_persist_queue.add_to_queue(
room_id, evs_ctxs,
backfilled=backfilled,
@@ -334,7 +333,7 @@ class EventsStore(EventsWorkerStore):
chunks = [
events_and_contexts[x:x + 100]
- for x in xrange(0, len(events_and_contexts), 100)
+ for x in range(0, len(events_and_contexts), 100)
]
for chunk in chunks:
@@ -368,7 +367,7 @@ class EventsStore(EventsWorkerStore):
(event, context)
)
- for room_id, ev_ctx_rm in events_by_room.iteritems():
+ for room_id, ev_ctx_rm in iteritems(events_by_room):
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
@@ -425,7 +424,9 @@ class EventsStore(EventsWorkerStore):
)
current_state = yield self._get_new_state_after_events(
room_id,
- ev_ctx_rm, new_latest_event_ids,
+ ev_ctx_rm,
+ latest_event_ids,
+ new_latest_event_ids,
)
if current_state is not None:
current_state_for_room[room_id] = current_state
@@ -444,7 +445,7 @@ class EventsStore(EventsWorkerStore):
state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
- persist_event_counter.inc_by(len(chunk))
+ persist_event_counter.inc(len(chunk))
synapse.metrics.event_persisted_position.set(
chunk[-1][0].internal_metadata.stream_ordering,
)
@@ -459,14 +460,14 @@ class EventsStore(EventsWorkerStore):
origin_type = "remote"
origin_entity = get_domain_from_id(event.sender)
- event_counter.inc(event.type, origin_type, origin_entity)
+ event_counter.labels(event.type, origin_type, origin_entity).inc()
- for room_id, new_state in current_state_for_room.iteritems():
+ for room_id, new_state in iteritems(current_state_for_room):
self.get_current_state_ids.prefill(
(room_id, ), new_state
)
- for room_id, latest_event_ids in new_forward_extremeties.iteritems():
+ for room_id, latest_event_ids in iteritems(new_forward_extremeties):
self.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
)
@@ -513,7 +514,8 @@ class EventsStore(EventsWorkerStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
- def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
+ def _get_new_state_after_events(self, room_id, events_context, old_latest_event_ids,
+ new_latest_event_ids):
"""Calculate the current state dict after adding some new events to
a room
@@ -524,6 +526,9 @@ class EventsStore(EventsWorkerStore):
events_context (list[(EventBase, EventContext)]):
events and contexts which are being added to the room
+ old_latest_event_ids (iterable[str]):
+ the old forward extremities for the room.
+
new_latest_event_ids (iterable[str]):
the new forward extremities for the room.
@@ -534,64 +539,89 @@ class EventsStore(EventsWorkerStore):
"""
if not new_latest_event_ids:
- defer.returnValue({})
+ return
# map from state_group to ((type, key) -> event_id) state map
- state_groups = {}
- missing_event_ids = []
- was_updated = False
+ state_groups_map = {}
+ for ev, ctx in events_context:
+ if ctx.state_group is None:
+ # I don't think this can happen, but let's double-check
+ raise Exception(
+ "Context for new extremity event %s has no state "
+ "group" % (ev.event_id, ),
+ )
+
+ if ctx.state_group in state_groups_map:
+ continue
+
+ state_groups_map[ctx.state_group] = ctx.current_state_ids
+
+ # We need to map the event_ids to their state groups. First, let's
+ # check if the event is one we're persisting, in which case we can
+ # pull the state group from its context.
+ # Otherwise we need to pull the state group from the database.
+
+ # Set of events we need to fetch groups for. (We know none of the old
+ # extremities are going to be in events_context).
+ missing_event_ids = set(old_latest_event_ids)
+
+ event_id_to_state_group = {}
for event_id in new_latest_event_ids:
- # First search in the list of new events we're adding,
- # and then use the current state from that
+ # First search in the list of new events we're adding.
for ev, ctx in events_context:
if event_id == ev.event_id:
- if ctx.current_state_ids is None:
- raise Exception("Unknown current state")
-
- if ctx.state_group is None:
- # I don't think this can happen, but let's double-check
- raise Exception(
- "Context for new extremity event %s has no state "
- "group" % (event_id, ),
- )
-
- # If we've already seen the state group don't bother adding
- # it to the state sets again
- if ctx.state_group not in state_groups:
- state_groups[ctx.state_group] = ctx.current_state_ids
- if ctx.delta_ids or hasattr(ev, "state_key"):
- was_updated = True
+ event_id_to_state_group[event_id] = ctx.state_group
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
- was_updated = True
- missing_event_ids.append(event_id)
-
- if not was_updated:
- return
+ missing_event_ids.add(event_id)
if missing_event_ids:
- # Now pull out the state for any missing events from DB
+ # Now pull out the state groups for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
+ event_id_to_state_group.update(event_to_groups)
+
+ # State groups of old_latest_event_ids
+ old_state_groups = set(
+ event_id_to_state_group[evid] for evid in old_latest_event_ids
+ )
+
+ # State groups of new_latest_event_ids
+ new_state_groups = set(
+ event_id_to_state_group[evid] for evid in new_latest_event_ids
+ )
- groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
+ # If they old and new groups are the same then we don't need to do
+ # anything.
+ if old_state_groups == new_state_groups:
+ return
- if groups:
- group_to_state = yield self._get_state_for_groups(groups)
- state_groups.update(group_to_state)
+ # Now that we have calculated new_state_groups we need to get
+ # their state IDs so we can resolve to a single state set.
+ missing_state = new_state_groups - set(state_groups_map)
+ if missing_state:
+ group_to_state = yield self._get_state_for_groups(missing_state)
+ state_groups_map.update(group_to_state)
- if len(state_groups) == 1:
+ if len(new_state_groups) == 1:
# If there is only one state group, then we know what the current
# state is.
- defer.returnValue(state_groups.values()[0])
+ defer.returnValue(state_groups_map[new_state_groups.pop()])
+
+ # Ok, we need to defer to the state handler to resolve our state sets.
def get_events(ev_ids):
return self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
)
+
+ state_groups = {
+ sg: state_groups_map[sg] for sg in new_state_groups
+ }
+
events_map = {ev.event_id: ev for ev, _ in events_context}
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
@@ -614,20 +644,20 @@ class EventsStore(EventsWorkerStore):
"""
existing_state = yield self.get_current_state_ids(room_id)
- existing_events = set(existing_state.itervalues())
- new_events = set(ev_id for ev_id in current_state.itervalues())
+ existing_events = set(itervalues(existing_state))
+ new_events = set(ev_id for ev_id in itervalues(current_state))
changed_events = existing_events ^ new_events
if not changed_events:
return
to_delete = {
- key: ev_id for key, ev_id in existing_state.iteritems()
+ key: ev_id for key, ev_id in iteritems(existing_state)
if ev_id in changed_events
}
events_to_insert = (new_events - existing_events)
to_insert = {
- key: ev_id for key, ev_id in current_state.iteritems()
+ key: ev_id for key, ev_id in iteritems(current_state)
if ev_id in events_to_insert
}
@@ -730,11 +760,11 @@ class EventsStore(EventsWorkerStore):
)
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
- for room_id, current_state_tuple in state_delta_by_room.iteritems():
+ for room_id, current_state_tuple in iteritems(state_delta_by_room):
to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
- [(ev_id,) for ev_id in to_delete.itervalues()],
+ [(ev_id,) for ev_id in itervalues(to_delete)],
)
self._simple_insert_many_txn(
@@ -747,7 +777,7 @@ class EventsStore(EventsWorkerStore):
"type": key[0],
"state_key": key[1],
}
- for key, ev_id in to_insert.iteritems()
+ for key, ev_id in iteritems(to_insert)
],
)
@@ -766,7 +796,7 @@ class EventsStore(EventsWorkerStore):
"event_id": ev_id,
"prev_event_id": to_delete.get(key, None),
}
- for key, ev_id in state_deltas.iteritems()
+ for key, ev_id in iteritems(state_deltas)
]
)
@@ -809,7 +839,7 @@ class EventsStore(EventsWorkerStore):
def _update_forward_extremities_txn(self, txn, new_forward_extremities,
max_stream_order):
- for room_id, new_extrem in new_forward_extremities.iteritems():
+ for room_id, new_extrem in iteritems(new_forward_extremities):
self._simple_delete_txn(
txn,
table="event_forward_extremities",
@@ -827,7 +857,7 @@ class EventsStore(EventsWorkerStore):
"event_id": ev_id,
"room_id": room_id,
}
- for room_id, new_extrem in new_forward_extremities.iteritems()
+ for room_id, new_extrem in iteritems(new_forward_extremities)
for ev_id in new_extrem
],
)
@@ -844,7 +874,7 @@ class EventsStore(EventsWorkerStore):
"event_id": event_id,
"stream_ordering": max_stream_order,
}
- for room_id, new_extrem in new_forward_extremities.iteritems()
+ for room_id, new_extrem in iteritems(new_forward_extremities)
for event_id in new_extrem
]
)
@@ -872,7 +902,7 @@ class EventsStore(EventsWorkerStore):
new_events_and_contexts[event.event_id] = (event, context)
else:
new_events_and_contexts[event.event_id] = (event, context)
- return new_events_and_contexts.values()
+ return list(new_events_and_contexts.values())
def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
"""Update min_depth for each room
@@ -898,7 +928,7 @@ class EventsStore(EventsWorkerStore):
event.depth, depth_updates.get(event.room_id, event.depth)
)
- for room_id, depth in depth_updates.iteritems():
+ for room_id, depth in iteritems(depth_updates):
self._update_min_depth_for_room_txn(txn, room_id, depth)
def _update_outliers_txn(self, txn, events_and_contexts):
@@ -1282,7 +1312,7 @@ class EventsStore(EventsWorkerStore):
" WHERE e.event_id IN (%s)"
) % (",".join(["?"] * len(ev_map)),)
- txn.execute(sql, ev_map.keys())
+ txn.execute(sql, list(ev_map))
rows = self.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
@@ -1545,7 +1575,7 @@ class EventsStore(EventsWorkerStore):
chunks = [
event_ids[i:i + 100]
- for i in xrange(0, len(event_ids), 100)
+ for i in range(0, len(event_ids), 100)
]
for chunk in chunks:
ev_rows = self._simple_select_many_txn(
@@ -1773,15 +1803,14 @@ class EventsStore(EventsWorkerStore):
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
def purge_history(
- self, room_id, topological_ordering, delete_local_events,
+ self, room_id, token, delete_local_events,
):
"""Deletes room history before a certain point
Args:
room_id (str):
- topological_ordering (int):
- minimum topo ordering to preserve
+ token (str): A topological token to delete events before
delete_local_events (bool):
if True, we will delete local events as well as remote ones
@@ -1791,13 +1820,15 @@ class EventsStore(EventsWorkerStore):
return self.runInteraction(
"purge_history",
- self._purge_history_txn, room_id, topological_ordering,
+ self._purge_history_txn, room_id, token,
delete_local_events,
)
def _purge_history_txn(
- self, txn, room_id, topological_ordering, delete_local_events,
+ self, txn, room_id, token_str, delete_local_events,
):
+ token = RoomStreamToken.parse(token_str)
+
# Tables that should be pruned:
# event_auth
# event_backward_extremities
@@ -1842,6 +1873,13 @@ class EventsStore(EventsWorkerStore):
" ON events_to_purge(should_delete)",
)
+ # We do joins against events_to_purge for e.g. calculating state
+ # groups to purge, etc., so lets make an index.
+ txn.execute(
+ "CREATE INDEX events_to_purge_id"
+ " ON events_to_purge(event_id)",
+ )
+
# First ensure that we're not about to delete all the forward extremeties
txn.execute(
"SELECT e.event_id, e.depth FROM events as e "
@@ -1854,7 +1892,7 @@ class EventsStore(EventsWorkerStore):
rows = txn.fetchall()
max_depth = max(row[0] for row in rows)
- if max_depth <= topological_ordering:
+ if max_depth <= token.topological:
# We need to ensure we don't delete all the events from the datanase
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
@@ -1870,7 +1908,7 @@ class EventsStore(EventsWorkerStore):
should_delete_expr += " AND event_id NOT LIKE ?"
should_delete_params += ("%:" + self.hs.hostname, )
- should_delete_params += (room_id, topological_ordering)
+ should_delete_params += (room_id, token.topological)
txn.execute(
"INSERT INTO events_to_purge"
@@ -1893,13 +1931,13 @@ class EventsStore(EventsWorkerStore):
logger.info("[purge] Finding new backward extremities")
# We calculate the new entries for the backward extremeties by finding
- # all events that point to events that are to be purged
+ # events to be purged that are pointed to by events we're not going to
+ # purge.
txn.execute(
"SELECT DISTINCT e.event_id FROM events_to_purge AS e"
" INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
- " INNER JOIN events AS e2 ON e2.event_id = ed.event_id"
- " WHERE e2.topological_ordering >= ?",
- (topological_ordering, )
+ " LEFT JOIN events_to_purge AS ep2 ON ed.event_id = ep2.event_id"
+ " WHERE ep2.event_id IS NULL",
)
new_backwards_extrems = txn.fetchall()
@@ -1923,16 +1961,22 @@ class EventsStore(EventsWorkerStore):
# Get all state groups that are only referenced by events that are
# to be deleted.
- txn.execute(
- "SELECT state_group FROM event_to_state_groups"
- " INNER JOIN events USING (event_id)"
- " WHERE state_group IN ("
- " SELECT DISTINCT state_group FROM events_to_purge"
- " INNER JOIN event_to_state_groups USING (event_id)"
- " )"
- " GROUP BY state_group HAVING MAX(topological_ordering) < ?",
- (topological_ordering, )
- )
+ # This works by first getting state groups that we may want to delete,
+ # joining against event_to_state_groups to get events that use that
+ # state group, then left joining against events_to_purge again. Any
+ # state group where the left join produce *no nulls* are referenced
+ # only by events that are going to be purged.
+ txn.execute("""
+ SELECT state_group FROM
+ (
+ SELECT DISTINCT state_group FROM events_to_purge
+ INNER JOIN event_to_state_groups USING (event_id)
+ ) AS sp
+ INNER JOIN event_to_state_groups USING (state_group)
+ LEFT JOIN events_to_purge AS ep USING (event_id)
+ GROUP BY state_group
+ HAVING SUM(CASE WHEN ep.event_id IS NULL THEN 1 ELSE 0 END) = 0
+ """)
state_rows = txn.fetchall()
logger.info("[purge] found %i redundant state groups", len(state_rows))
@@ -1945,7 +1989,7 @@ class EventsStore(EventsWorkerStore):
logger.info("[purge] finding state groups which depend on redundant"
" state groups")
remaining_state_groups = []
- for i in xrange(0, len(state_rows), 100):
+ for i in range(0, len(state_rows), 100):
chunk = [sg for sg, in state_rows[i:i + 100]]
# look for state groups whose prev_state_group is one we are about
# to delete
@@ -2001,7 +2045,7 @@ class EventsStore(EventsWorkerStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in curr_state.iteritems()
+ for key, state_id in iteritems(curr_state)
],
)
@@ -2079,10 +2123,25 @@ class EventsStore(EventsWorkerStore):
#
# So, let's stick it at the end so that we don't block event
# persistence.
- logger.info("[purge] updating room_depth")
+ #
+ # We do this by calculating the minimum depth of the backwards
+ # extremities. However, the events in event_backward_extremities
+ # are ones we don't have yet so we need to look at the events that
+ # point to it via event_edges table.
+ txn.execute("""
+ SELECT COALESCE(MIN(depth), 0)
+ FROM event_backward_extremities AS eb
+ INNER JOIN event_edges AS eg ON eg.prev_event_id = eb.event_id
+ INNER JOIN events AS e ON e.event_id = eg.event_id
+ WHERE eb.room_id = ?
+ """, (room_id,))
+ min_depth, = txn.fetchone()
+
+ logger.info("[purge] updating room_depth to %d", min_depth)
+
txn.execute(
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
- (topological_ordering, room_id,)
+ (min_depth, room_id,)
)
# finally, drop the temp table. this will commit the txn in sqlite,
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index a937b9bceb..32d9d00ffb 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -20,7 +20,7 @@ from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util.logcontext import (
- preserve_fn, PreserveLoggingContext, make_deferred_yieldable
+ PreserveLoggingContext, make_deferred_yieldable, run_in_background,
)
from synapse.util.metrics import Measure
from synapse.api.errors import SynapseError
@@ -319,7 +319,8 @@ class EventsWorkerStore(SQLBaseStore):
res = yield make_deferred_yieldable(defer.gatherResults(
[
- preserve_fn(self._get_event_from_row)(
+ run_in_background(
+ self._get_event_from_row,
row["internal_metadata"], row["json"], row["redacts"],
rejected_reason=row["rejects"],
)
@@ -336,7 +337,7 @@ class EventsWorkerStore(SQLBaseStore):
def _fetch_event_rows(self, txn, events):
rows = []
N = 200
- for i in range(1 + len(events) / N):
+ for i in range(1 + len(events) // N):
evs = events[i * N:(i + 1) * N]
if not evs:
break
diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py
index 78b1e30945..2e2763126d 100644
--- a/synapse/storage/filtering.py
+++ b/synapse/storage/filtering.py
@@ -44,7 +44,7 @@ class FilteringStore(SQLBaseStore):
desc="get_user_filter",
)
- defer.returnValue(json.loads(str(def_json).decode("utf-8")))
+ defer.returnValue(json.loads(bytes(def_json).decode("utf-8")))
def add_user_filter(self, user_localpart, user_filter):
def_json = encode_canonical_json(user_filter)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 87aeaf71d6..0f13b61da8 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -17,6 +17,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks
from twisted.internet import defer
+import six
import OpenSSL
from signedjson.key import decode_verify_key_bytes
@@ -26,6 +27,13 @@ import logging
logger = logging.getLogger(__name__)
+# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
+# despite being deprecated and removed in favor of memoryview
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
+
class KeyStore(SQLBaseStore):
"""Persistence for signature verification keys and tls X.509 certificates
@@ -72,7 +80,7 @@ class KeyStore(SQLBaseStore):
values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
- "tls_certificate": buffer(tls_certificate_bytes),
+ "tls_certificate": db_binary_type(tls_certificate_bytes),
},
desc="store_server_certificate",
)
@@ -92,7 +100,7 @@ class KeyStore(SQLBaseStore):
if verify_key_bytes:
defer.returnValue(decode_verify_key_bytes(
- key_id, str(verify_key_bytes)
+ key_id, bytes(verify_key_bytes)
))
@defer.inlineCallbacks
@@ -135,7 +143,7 @@ class KeyStore(SQLBaseStore):
values={
"from_server": from_server,
"ts_added_ms": time_now_ms,
- "verify_key": buffer(verify_key.encode()),
+ "verify_key": db_binary_type(verify_key.encode()),
},
)
txn.call_after(
@@ -172,7 +180,7 @@ class KeyStore(SQLBaseStore):
"from_server": from_server,
"ts_added_ms": ts_now_ms,
"ts_valid_until_ms": ts_expires_ms,
- "key_json": buffer(key_json_bytes),
+ "key_json": db_binary_type(key_json_bytes),
},
desc="store_server_keys_json",
)
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 04411a665f..cf2aae0468 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -26,7 +26,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 48
+SCHEMA_VERSION = 50
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/presence.py b/synapse/storage/presence.py
index 9e9d3c2591..f05d91cc58 100644
--- a/synapse/storage/presence.py
+++ b/synapse/storage/presence.py
@@ -16,6 +16,7 @@
from ._base import SQLBaseStore
from synapse.api.constants import PresenceState
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
+from synapse.util import batch_iter
from collections import namedtuple
from twisted.internet import defer
@@ -115,11 +116,7 @@ class PresenceStore(SQLBaseStore):
" AND user_id IN (%s)"
)
- batches = (
- presence_states[i:i + 50]
- for i in xrange(0, len(presence_states), 50)
- )
- for states in batches:
+ for states in batch_iter(presence_states, 50):
args = [stream_id]
args.extend(s.user_id for s in states)
txn.execute(
diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py
index 63997ed449..c93c228f6e 100644
--- a/synapse/storage/receipts.py
+++ b/synapse/storage/receipts.py
@@ -297,18 +297,22 @@ class ReceiptsWorkerStore(SQLBaseStore):
if receipt_type != "m.read":
return
- # Returns an ObservableDeferred
+ # Returns either an ObservableDeferred or the raw result
res = self.get_users_with_read_receipts_in_room.cache.get(
room_id, None, update_metrics=False,
)
- if res:
- if isinstance(res, defer.Deferred) and res.called:
+ # first handle the Deferred case
+ if isinstance(res, defer.Deferred):
+ if res.called:
res = res.result
- if user_id in res:
- # We'd only be adding to the set, so no point invalidating if the
- # user is already there
- return
+ else:
+ res = None
+
+ if res and user_id in res:
+ # We'd only be adding to the set, so no point invalidating if the
+ # user is already there
+ return
self.get_users_with_read_receipts_in_room.invalidate((room_id,))
@@ -328,6 +332,35 @@ class ReceiptsStore(ReceiptsWorkerStore):
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
+ res = self._simple_select_one_txn(
+ txn,
+ table="events",
+ retcols=["topological_ordering", "stream_ordering"],
+ keyvalues={"event_id": event_id},
+ allow_none=True
+ )
+
+ stream_ordering = int(res["stream_ordering"]) if res else None
+
+ # We don't want to clobber receipts for more recent events, so we
+ # have to compare orderings of existing receipts
+ if stream_ordering is not None:
+ sql = (
+ "SELECT stream_ordering, event_id FROM events"
+ " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+ " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+ )
+ txn.execute(sql, (room_id, receipt_type, user_id))
+
+ for so, eid in txn:
+ if int(so) >= stream_ordering:
+ logger.debug(
+ "Ignoring new receipt for %s in favour of existing "
+ "one for later event %s",
+ event_id, eid,
+ )
+ return False
+
txn.call_after(
self.get_receipts_for_room.invalidate, (room_id, receipt_type)
)
@@ -351,34 +384,6 @@ class ReceiptsStore(ReceiptsWorkerStore):
(user_id, room_id, receipt_type)
)
- res = self._simple_select_one_txn(
- txn,
- table="events",
- retcols=["topological_ordering", "stream_ordering"],
- keyvalues={"event_id": event_id},
- allow_none=True
- )
-
- topological_ordering = int(res["topological_ordering"]) if res else None
- stream_ordering = int(res["stream_ordering"]) if res else None
-
- # We don't want to clobber receipts for more recent events, so we
- # have to compare orderings of existing receipts
- sql = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
- " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
- )
-
- txn.execute(sql, (room_id, receipt_type, user_id))
-
- if topological_ordering:
- for to, so, _ in txn:
- if int(to) > topological_ordering:
- return False
- elif int(to) == topological_ordering and int(so) >= stream_ordering:
- return False
-
self._simple_delete_txn(
txn,
table="receipts_linearized",
@@ -402,12 +407,11 @@ class ReceiptsStore(ReceiptsWorkerStore):
}
)
- if receipt_type == "m.read" and topological_ordering:
+ if receipt_type == "m.read" and stream_ordering is not None:
self._remove_old_push_actions_before_txn(
txn,
room_id=room_id,
user_id=user_id,
- topological_ordering=topological_ordering,
stream_ordering=stream_ordering,
)
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 384c9977c1..d10cb3d165 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -22,6 +22,8 @@ from synapse.storage import background_updates
from synapse.storage._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from six.moves import range
+
class RegistrationWorkerStore(SQLBaseStore):
@cached()
@@ -31,7 +33,11 @@ class RegistrationWorkerStore(SQLBaseStore):
keyvalues={
"name": user_id,
},
- retcols=["name", "password_hash", "is_guest"],
+ retcols=[
+ "name", "password_hash", "is_guest",
+ "consent_version", "consent_server_notice_sent",
+ "appservice_id",
+ ],
allow_none=True,
desc="get_user_by_id",
)
@@ -96,6 +102,13 @@ class RegistrationStore(RegistrationWorkerStore,
columns=["user_id", "device_id"],
)
+ self.register_background_index_update(
+ "users_creation_ts",
+ index_name="users_creation_ts",
+ table="users",
+ columns=["creation_ts"],
+ )
+
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
@@ -272,6 +285,53 @@ class RegistrationStore(RegistrationWorkerStore,
"user_set_password_hash", user_set_password_hash_txn
)
+ def user_set_consent_version(self, user_id, consent_version):
+ """Updates the user table to record privacy policy consent
+
+ Args:
+ user_id (str): full mxid of the user to update
+ consent_version (str): version of the policy the user has consented
+ to
+
+ Raises:
+ StoreError(404) if user not found
+ """
+ def f(txn):
+ self._simple_update_one_txn(
+ txn,
+ table='users',
+ keyvalues={'name': user_id, },
+ updatevalues={'consent_version': consent_version, },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user_id,)
+ )
+ return self.runInteraction("user_set_consent_version", f)
+
+ def user_set_consent_server_notice_sent(self, user_id, consent_version):
+ """Updates the user table to record that we have sent the user a server
+ notice about privacy policy consent
+
+ Args:
+ user_id (str): full mxid of the user to update
+ consent_version (str): version of the policy we have notified the
+ user about
+
+ Raises:
+ StoreError(404) if user not found
+ """
+ def f(txn):
+ self._simple_update_one_txn(
+ txn,
+ table='users',
+ keyvalues={'name': user_id, },
+ updatevalues={'consent_server_notice_sent': consent_version, },
+ )
+ self._invalidate_cache_and_stream(
+ txn, self.get_user_by_id, (user_id,)
+ )
+ return self.runInteraction("user_set_consent_server_notice_sent", f)
+
def user_delete_access_tokens(self, user_id, except_token_id=None,
device_id=None):
"""
@@ -421,6 +481,35 @@ class RegistrationStore(RegistrationWorkerStore,
ret = yield self.runInteraction("count_users", _count_users)
defer.returnValue(ret)
+ def count_daily_user_type(self):
+ """
+ Counts 1) native non guest users
+ 2) native guests users
+ 3) bridged users
+ who registered on the homeserver in the past 24 hours
+ """
+ def _count_daily_user_type(txn):
+ yesterday = int(self._clock.time()) - (60 * 60 * 24)
+
+ sql = """
+ SELECT user_type, COALESCE(count(*), 0) AS count FROM (
+ SELECT
+ CASE
+ WHEN is_guest=0 AND appservice_id IS NULL THEN 'native'
+ WHEN is_guest=1 AND appservice_id IS NULL THEN 'guest'
+ WHEN is_guest=0 AND appservice_id IS NOT NULL THEN 'bridged'
+ END AS user_type
+ FROM users
+ WHERE creation_ts > ?
+ ) AS t GROUP BY user_type
+ """
+ results = {'native': 0, 'guest': 0, 'bridged': 0}
+ txn.execute(sql, (yesterday,))
+ for row in txn:
+ results[row[0]] = row[1]
+ return results
+ return self.runInteraction("count_daily_user_type", _count_daily_user_type)
+
@defer.inlineCallbacks
def count_nonbridged_users(self):
def _count_users(txn):
@@ -457,7 +546,7 @@ class RegistrationStore(RegistrationWorkerStore,
match = regex.search(user_id)
if match:
found.add(int(match.group(1)))
- for i in xrange(len(found) + 1):
+ for i in range(len(found) + 1):
if i not in found:
return i
@@ -512,3 +601,42 @@ class RegistrationStore(RegistrationWorkerStore,
except self.database_engine.module.IntegrityError:
ret = yield self.get_3pid_guest_access_token(medium, address)
defer.returnValue(ret)
+
+ def add_user_pending_deactivation(self, user_id):
+ """
+ Adds a user to the table of users who need to be parted from all the rooms they're
+ in
+ """
+ return self._simple_insert(
+ "users_pending_deactivation",
+ values={
+ "user_id": user_id,
+ },
+ desc="add_user_pending_deactivation",
+ )
+
+ def del_user_pending_deactivation(self, user_id):
+ """
+ Removes the given user to the table of users who need to be parted from all the
+ rooms they're in, effectively marking that user as fully deactivated.
+ """
+ return self._simple_delete_one(
+ "users_pending_deactivation",
+ keyvalues={
+ "user_id": user_id,
+ },
+ desc="del_user_pending_deactivation",
+ )
+
+ def get_user_pending_deactivation(self):
+ """
+ Gets one user from the table of users waiting to be parted from all the rooms
+ they're in.
+ """
+ return self._simple_select_one_onecol(
+ "users_pending_deactivation",
+ keyvalues={},
+ retcol="user_id",
+ allow_none=True,
+ desc="get_users_pending_deactivation",
+ )
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 6a861943a2..7bfc3d91b5 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -30,6 +30,8 @@ from synapse.types import get_domain_from_id
import logging
import simplejson as json
+from six import itervalues, iteritems
+
logger = logging.getLogger(__name__)
@@ -272,7 +274,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = {}
member_event_ids = [
e_id
- for key, e_id in current_state_ids.iteritems()
+ for key, e_id in iteritems(current_state_ids)
if key[0] == EventTypes.Member
]
@@ -289,7 +291,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
users_in_room = dict(prev_res)
member_event_ids = [
e_id
- for key, e_id in context.delta_ids.iteritems()
+ for key, e_id in iteritems(context.delta_ids)
if key[0] == EventTypes.Member
]
for etype, state_key in context.delta_ids:
@@ -741,7 +743,7 @@ class _JoinedHostsCache(object):
if state_entry.state_group == self.state_group:
pass
elif state_entry.prev_group == self.state_group:
- for (typ, state_key), event_id in state_entry.delta_ids.iteritems():
+ for (typ, state_key), event_id in iteritems(state_entry.delta_ids):
if typ != EventTypes.Member:
continue
@@ -771,7 +773,7 @@ class _JoinedHostsCache(object):
self.state_group = state_entry.state_group
else:
self.state_group = object()
- self._len = sum(len(v) for v in self.hosts_to_joined_users.itervalues())
+ self._len = sum(len(v) for v in itervalues(self.hosts_to_joined_users))
defer.returnValue(frozenset(self.hosts_to_joined_users))
def __len__(self):
diff --git a/synapse/storage/schema/delta/30/as_users.py b/synapse/storage/schema/delta/30/as_users.py
index c53e53c94f..85bd1a2006 100644
--- a/synapse/storage/schema/delta/30/as_users.py
+++ b/synapse/storage/schema/delta/30/as_users.py
@@ -14,6 +14,8 @@
import logging
from synapse.config.appservice import load_appservices
+from six.moves import range
+
logger = logging.getLogger(__name__)
@@ -58,7 +60,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
for as_id, user_ids in owned.items():
n = 100
- user_chunks = (user_ids[i:i + 100] for i in xrange(0, len(user_ids), n))
+ user_chunks = (user_ids[i:i + 100] for i in range(0, len(user_ids), n))
for chunk in user_chunks:
cur.execute(
database_engine.convert_param_style(
diff --git a/synapse/storage/schema/delta/48/add_user_consent.sql b/synapse/storage/schema/delta/48/add_user_consent.sql
new file mode 100644
index 0000000000..5237491506
--- /dev/null
+++ b/synapse/storage/schema/delta/48/add_user_consent.sql
@@ -0,0 +1,18 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* record the version of the privacy policy the user has consented to
+ */
+ALTER TABLE users ADD COLUMN consent_version TEXT;
diff --git a/synapse/storage/schema/delta/48/deactivated_users.sql b/synapse/storage/schema/delta/48/deactivated_users.sql
new file mode 100644
index 0000000000..e9013a6969
--- /dev/null
+++ b/synapse/storage/schema/delta/48/deactivated_users.sql
@@ -0,0 +1,25 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Store any accounts that have been requested to be deactivated.
+ * We part the account from all the rooms its in when its
+ * deactivated. This can take some time and synapse may be restarted
+ * before it completes, so store the user IDs here until the process
+ * is complete.
+ */
+CREATE TABLE users_pending_deactivation (
+ user_id TEXT NOT NULL
+);
diff --git a/synapse/storage/schema/delta/48/group_unique_indexes.py b/synapse/storage/schema/delta/48/group_unique_indexes.py
new file mode 100644
index 0000000000..2233af87d7
--- /dev/null
+++ b/synapse/storage/schema/delta/48/group_unique_indexes.py
@@ -0,0 +1,57 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.prepare_database import get_statements
+
+FIX_INDEXES = """
+-- rebuild indexes as uniques
+DROP INDEX groups_invites_g_idx;
+CREATE UNIQUE INDEX group_invites_g_idx ON group_invites(group_id, user_id);
+DROP INDEX groups_users_g_idx;
+CREATE UNIQUE INDEX group_users_g_idx ON group_users(group_id, user_id);
+
+-- rename other indexes to actually match their table names..
+DROP INDEX groups_users_u_idx;
+CREATE INDEX group_users_u_idx ON group_users(user_id);
+DROP INDEX groups_invites_u_idx;
+CREATE INDEX group_invites_u_idx ON group_invites(user_id);
+DROP INDEX groups_rooms_g_idx;
+CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
+DROP INDEX groups_rooms_r_idx;
+CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+"""
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
+
+ # remove duplicates from group_users & group_invites tables
+ cur.execute("""
+ DELETE FROM group_users WHERE %s NOT IN (
+ SELECT min(%s) FROM group_users GROUP BY group_id, user_id
+ );
+ """ % (rowid, rowid))
+ cur.execute("""
+ DELETE FROM group_invites WHERE %s NOT IN (
+ SELECT min(%s) FROM group_invites GROUP BY group_id, user_id
+ );
+ """ % (rowid, rowid))
+
+ for statement in get_statements(FIX_INDEXES.splitlines()):
+ cur.execute(statement)
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
new file mode 100644
index 0000000000..14dcf18d73
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_consent_server_notice_sent.sql
@@ -0,0 +1,20 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* record whether we have sent a server notice about consenting to the
+ * privacy policy. Specifically records the version of the policy we sent
+ * a message about.
+ */
+ALTER TABLE users ADD COLUMN consent_server_notice_sent TEXT;
diff --git a/synapse/storage/schema/delta/49/add_user_daily_visits.sql b/synapse/storage/schema/delta/49/add_user_daily_visits.sql
new file mode 100644
index 0000000000..3dd478196f
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_daily_visits.sql
@@ -0,0 +1,21 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL,
+ device_id TEXT,
+ timestamp BIGINT NOT NULL );
+CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
+CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
new file mode 100644
index 0000000000..3a4ed59b5b
--- /dev/null
+++ b/synapse/storage/schema/delta/49/add_user_ips_last_seen_only_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('user_ips_last_seen_only_index', '{}');
diff --git a/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql b/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql
new file mode 100644
index 0000000000..c93ae47532
--- /dev/null
+++ b/synapse/storage/schema/delta/50/add_creation_ts_users_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+
+INSERT into background_updates (update_name, progress_json)
+ VALUES ('users_creation_ts', '{}');
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 6ba3e59889..f0fa5d7631 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -18,13 +18,14 @@ import logging
import re
import simplejson as json
+from six import string_types
+
from twisted.internet import defer
from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-
logger = logging.getLogger(__name__)
SearchEntry = namedtuple('SearchEntry', [
@@ -126,7 +127,7 @@ class SearchStore(BackgroundUpdateStore):
# skip over it.
continue
- if not isinstance(value, basestring):
+ if not isinstance(value, string_types):
# If the event body, name or topic isn't a string
# then skip over it
continue
@@ -447,7 +448,7 @@ class SearchStore(BackgroundUpdateStore):
"search_msgs", self.cursor_to_dict, sql, *args
)
- results = filter(lambda row: row["room_id"] in room_ids, results)
+ results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results])
@@ -602,7 +603,7 @@ class SearchStore(BackgroundUpdateStore):
"search_rooms", self.cursor_to_dict, sql, *args
)
- results = filter(lambda row: row["room_id"] in room_ids, results)
+ results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self._get_events([r["event_id"] for r in results])
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 9e6eaaa532..25922e5a9c 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -14,6 +14,7 @@
# limitations under the License.
from twisted.internet import defer
+import six
from ._base import SQLBaseStore
@@ -21,6 +22,13 @@ from unpaddedbase64 import encode_base64
from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.util.caches.descriptors import cached, cachedList
+# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
+# despite being deprecated and removed in favor of memoryview
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
+
class SignatureWorkerStore(SQLBaseStore):
@cached()
@@ -56,7 +64,7 @@ class SignatureWorkerStore(SQLBaseStore):
for e_id, h in hashes.items()
}
- defer.returnValue(hashes.items())
+ defer.returnValue(list(hashes.items()))
def _get_event_reference_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given PDU.
@@ -91,7 +99,7 @@ class SignatureStore(SignatureWorkerStore):
vals.append({
"event_id": event.event_id,
"algorithm": ref_alg,
- "hash": buffer(ref_hash_bytes),
+ "hash": db_binary_type(ref_hash_bytes),
})
self._simple_insert_many_txn(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index ffa4246031..bdee14a8eb 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -16,6 +16,9 @@
from collections import namedtuple
import logging
+from six import iteritems, itervalues
+from six.moves import range
+
from twisted.internet import defer
from synapse.storage.background_updates import BackgroundUpdateStore
@@ -134,7 +137,7 @@ class StateGroupWorkerStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.itervalues())
+ groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups)
defer.returnValue(group_to_state)
@@ -166,18 +169,18 @@ class StateGroupWorkerStore(SQLBaseStore):
state_event_map = yield self.get_events(
[
- ev_id for group_ids in group_to_ids.itervalues()
- for ev_id in group_ids.itervalues()
+ ev_id for group_ids in itervalues(group_to_ids)
+ for ev_id in itervalues(group_ids)
],
get_prev_content=False
)
defer.returnValue({
group: [
- state_event_map[v] for v in event_id_map.itervalues()
+ state_event_map[v] for v in itervalues(event_id_map)
if v in state_event_map
]
- for group, event_id_map in group_to_ids.iteritems()
+ for group, event_id_map in iteritems(group_to_ids)
})
@defer.inlineCallbacks
@@ -186,7 +189,7 @@ class StateGroupWorkerStore(SQLBaseStore):
"""
results = {}
- chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
+ chunks = [groups[i:i + 100] for i in range(0, len(groups), 100)]
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
@@ -347,21 +350,21 @@ class StateGroupWorkerStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.itervalues())
+ groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, types)
state_event_map = yield self.get_events(
- [ev_id for sd in group_to_state.itervalues() for ev_id in sd.itervalues()],
+ [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
get_prev_content=False
)
event_to_state = {
event_id: {
k: state_event_map[v]
- for k, v in group_to_state[group].iteritems()
+ for k, v in iteritems(group_to_state[group])
if v in state_event_map
}
- for event_id, group in event_to_groups.iteritems()
+ for event_id, group in iteritems(event_to_groups)
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -384,12 +387,12 @@ class StateGroupWorkerStore(SQLBaseStore):
event_ids,
)
- groups = set(event_to_groups.itervalues())
+ groups = set(itervalues(event_to_groups))
group_to_state = yield self._get_state_for_groups(groups, types)
event_to_state = {
event_id: group_to_state[group]
- for event_id, group in event_to_groups.iteritems()
+ for event_id, group in iteritems(event_to_groups)
}
defer.returnValue({event: event_to_state[event] for event in event_ids})
@@ -503,7 +506,7 @@ class StateGroupWorkerStore(SQLBaseStore):
got_all = is_all or not missing_types
return {
- k: v for k, v in state_dict_ids.iteritems()
+ k: v for k, v in iteritems(state_dict_ids)
if include(k[0], k[1])
}, missing_types, got_all
@@ -562,12 +565,12 @@ class StateGroupWorkerStore(SQLBaseStore):
# Now we want to update the cache with all the things we fetched
# from the database.
- for group, group_state_dict in group_to_state_dict.iteritems():
+ for group, group_state_dict in iteritems(group_to_state_dict):
state_dict = results[group]
state_dict.update(
((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
- for k, v in group_state_dict.iteritems()
+ for k, v in iteritems(group_state_dict)
)
self._state_group_cache.update(
@@ -654,7 +657,7 @@ class StateGroupWorkerStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in delta_ids.iteritems()
+ for key, state_id in iteritems(delta_ids)
],
)
else:
@@ -669,7 +672,7 @@ class StateGroupWorkerStore(SQLBaseStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in current_state_ids.iteritems()
+ for key, state_id in iteritems(current_state_ids)
],
)
@@ -794,11 +797,11 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"state_group": state_group_id,
"event_id": event_id,
}
- for event_id, state_group_id in state_groups.iteritems()
+ for event_id, state_group_id in iteritems(state_groups)
],
)
- for event_id, state_group_id in state_groups.iteritems():
+ for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill,
(event_id,), state_group_id
@@ -826,7 +829,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
def reindex_txn(txn):
new_last_state_group = last_state_group
- for count in xrange(batch_size):
+ for count in range(batch_size):
txn.execute(
"SELECT id, room_id FROM state_groups"
" WHERE ? < id AND id <= ?"
@@ -884,7 +887,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
# of keys
delta_state = {
- key: value for key, value in curr_state.iteritems()
+ key: value for key, value in iteritems(curr_state)
if prev_state.get(key, None) != value
}
@@ -924,7 +927,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in delta_state.iteritems()
+ for key, state_id in iteritems(delta_state)
],
)
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index 2956c3b3e0..fb463c525a 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -38,15 +38,17 @@ from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage.events import EventsWorkerStore
-from synapse.util.caches.descriptors import cached
from synapse.types import RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
+from synapse.storage.engines import PostgresEngine
import abc
import logging
+from six.moves import range
+from collections import namedtuple
+
logger = logging.getLogger(__name__)
@@ -58,6 +60,12 @@ _STREAM_TOKEN = "stream"
_TOPOLOGICAL_TOKEN = "topological"
+# Used as return values for pagination APIs
+_EventDictReturn = namedtuple("_EventDictReturn", (
+ "event_id", "topological_ordering", "stream_ordering",
+))
+
+
def lower_bound(token, engine, inclusive=False):
inclusive = "=" if inclusive else ""
if token.topological is None:
@@ -196,13 +204,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
results = {}
room_ids = list(room_ids)
- for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
+ for rm_ids in (room_ids[i:i + 20] for i in range(0, len(room_ids), 20)):
res = yield make_deferred_yieldable(defer.gatherResults([
- preserve_fn(self.get_room_events_stream_for_room)(
+ run_in_background(
+ self.get_room_events_stream_for_room,
room_id, from_key, to_key, limit, order=order,
)
for room_id in rm_ids
- ]))
+ ], consumeErrors=True))
results.update(dict(zip(rm_ids, res)))
defer.returnValue(results)
@@ -224,54 +233,55 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
order='DESC'):
- # Note: If from_key is None then we return in topological order. This
- # is because in that case we're using this as a "get the last few messages
- # in a room" function, rather than "get new messages since last sync"
- if from_key is not None:
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- else:
- from_id = None
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ """Get new room events in stream ordering since `from_key`.
+
+ Args:
+ room_id (str)
+ from_key (str): Token from which no events are returned before
+ to_key (str): Token from which no events are returned after. (This
+ is typically the current stream token)
+ limit (int): Maximum number of events to return
+ order (str): Either "DESC" or "ASC". Determines which events are
+ returned when the result is limited. If "DESC" then the most
+ recent `limit` events are returned, otherwise returns the
+ oldest `limit` events.
+
+ Returns:
+ Deferred[tuple[list[FrozenEvent], str]]: Returns the list of
+ events (in ascending order) and the token from the start of
+ the chunk of events returned.
+ """
if from_key == to_key:
defer.returnValue(([], from_key))
- if from_id:
- has_changed = yield self._events_stream_cache.has_entity_changed(
- room_id, from_id
- )
-
- if not has_changed:
- defer.returnValue(([], from_key))
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ to_id = RoomStreamToken.parse_stream_token(to_key).stream
- def f(txn):
- if from_id is not None:
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering > ? AND stream_ordering <= ?"
- " ORDER BY stream_ordering %s LIMIT ?"
- ) % (order,)
- txn.execute(sql, (room_id, from_id, to_id, limit))
- else:
- sql = (
- "SELECT event_id, stream_ordering FROM events WHERE"
- " room_id = ?"
- " AND not outlier"
- " AND stream_ordering <= ?"
- " ORDER BY topological_ordering %s, stream_ordering %s LIMIT ?"
- ) % (order, order,)
- txn.execute(sql, (room_id, to_id, limit))
+ has_changed = yield self._events_stream_cache.has_entity_changed(
+ room_id, from_id
+ )
- rows = self.cursor_to_dict(txn)
+ if not has_changed:
+ defer.returnValue(([], from_key))
+ def f(txn):
+ sql = (
+ "SELECT event_id, stream_ordering FROM events WHERE"
+ " room_id = ?"
+ " AND not outlier"
+ " AND stream_ordering > ? AND stream_ordering <= ?"
+ " ORDER BY stream_ordering %s LIMIT ?"
+ ) % (order,)
+ txn.execute(sql, (room_id, from_id, to_id, limit))
+
+ rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
rows = yield self.runInteraction("get_room_events_stream_for_room", f)
ret = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
@@ -281,7 +291,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
- key = "s%d" % min(r["stream_ordering"] for r in rows)
+ key = "s%d" % min(r.stream_ordering for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -291,10 +301,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@defer.inlineCallbacks
def get_membership_changes_for_user(self, user_id, from_key, to_key):
- if from_key is not None:
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- else:
- from_id = None
+ from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
if from_key == to_key:
@@ -308,34 +315,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue([])
def f(txn):
- if from_id is not None:
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
- " ORDER BY e.stream_ordering ASC"
- )
- txn.execute(sql, (user_id, from_id, to_id,))
- else:
- sql = (
- "SELECT m.event_id, stream_ordering FROM events AS e,"
- " room_memberships AS m"
- " WHERE e.event_id = m.event_id"
- " AND m.user_id = ?"
- " AND stream_ordering <= ?"
- " ORDER BY stream_ordering ASC"
- )
- txn.execute(sql, (user_id, to_id,))
- rows = self.cursor_to_dict(txn)
+ sql = (
+ "SELECT m.event_id, stream_ordering FROM events AS e,"
+ " room_memberships AS m"
+ " WHERE e.event_id = m.event_id"
+ " AND m.user_id = ?"
+ " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
+ " ORDER BY e.stream_ordering ASC"
+ )
+ txn.execute(sql, (user_id, from_id, to_id,))
+
+ rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
return rows
rows = yield self.runInteraction("get_membership_changes_for_user", f)
ret = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
@@ -344,14 +341,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue(ret)
@defer.inlineCallbacks
- def get_recent_events_for_room(self, room_id, limit, end_token, from_token=None):
+ def get_recent_events_for_room(self, room_id, limit, end_token):
+ """Get the most recent events in the room in topological ordering.
+
+ Args:
+ room_id (str)
+ limit (int)
+ end_token (str): The stream token representing now.
+
+ Returns:
+ Deferred[tuple[list[FrozenEvent], str]]: Returns a list of
+ events and a token pointing to the start of the returned
+ events.
+ The events returned are in ascending order.
+ """
+
rows, token = yield self.get_recent_event_ids_for_room(
- room_id, limit, end_token, from_token
+ room_id, limit, end_token,
)
logger.debug("stream before")
events = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
logger.debug("stream after")
@@ -360,60 +371,36 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue((events, token))
- @cached(num_args=4)
- def get_recent_event_ids_for_room(self, room_id, limit, end_token, from_token=None):
- end_token = RoomStreamToken.parse_stream_token(end_token)
-
- if from_token is None:
- sql = (
- "SELECT stream_ordering, topological_ordering, event_id"
- " FROM events"
- " WHERE room_id = ? AND stream_ordering <= ? AND outlier = ?"
- " ORDER BY topological_ordering DESC, stream_ordering DESC"
- " LIMIT ?"
- )
- else:
- from_token = RoomStreamToken.parse_stream_token(from_token)
- sql = (
- "SELECT stream_ordering, topological_ordering, event_id"
- " FROM events"
- " WHERE room_id = ? AND stream_ordering > ?"
- " AND stream_ordering <= ? AND outlier = ?"
- " ORDER BY topological_ordering DESC, stream_ordering DESC"
- " LIMIT ?"
- )
-
- def get_recent_events_for_room_txn(txn):
- if from_token is None:
- txn.execute(sql, (room_id, end_token.stream, False, limit,))
- else:
- txn.execute(sql, (
- room_id, from_token.stream, end_token.stream, False, limit
- ))
+ @defer.inlineCallbacks
+ def get_recent_event_ids_for_room(self, room_id, limit, end_token):
+ """Get the most recent events in the room in topological ordering.
- rows = self.cursor_to_dict(txn)
+ Args:
+ room_id (str)
+ limit (int)
+ end_token (str): The stream token representing now.
- rows.reverse() # As we selected with reverse ordering
+ Returns:
+ Deferred[tuple[list[_EventDictReturn], str]]: Returns a list of
+ _EventDictReturn and a token pointing to the start of the returned
+ events.
+ The events returned are in ascending order.
+ """
+ # Allow a zero limit here, and no-op.
+ if limit == 0:
+ defer.returnValue(([], end_token))
- if rows:
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # since we are going backwards so we subtract one from the
- # stream part.
- topo = rows[0]["topological_ordering"]
- toke = rows[0]["stream_ordering"] - 1
- start_token = str(RoomStreamToken(topo, toke))
+ end_token = RoomStreamToken.parse(end_token)
- token = (start_token, str(end_token))
- else:
- token = (str(end_token), str(end_token))
+ rows, token = yield self.runInteraction(
+ "get_recent_event_ids_for_room", self._paginate_room_events_txn,
+ room_id, from_token=end_token, limit=limit,
+ )
- return rows, token
+ # We want to return the results in ascending order.
+ rows.reverse()
- return self.runInteraction(
- "get_recent_events_for_room", get_recent_events_for_room_txn
- )
+ defer.returnValue((rows, token))
def get_room_event_after_stream_ordering(self, room_id, stream_ordering):
"""Gets details of the first event in a room at or after a stream ordering
@@ -517,10 +504,20 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
@staticmethod
def _set_before_and_after(events, rows, topo_order=True):
+ """Inserts ordering information to events' internal metadata from
+ the DB rows.
+
+ Args:
+ events (list[FrozenEvent])
+ rows (list[_EventDictReturn])
+ topo_order (bool): Whether the events were ordered topologically
+ or by stream ordering. If true then all rows should have a non
+ null topological_ordering.
+ """
for event, row in zip(events, rows):
- stream = row["stream_ordering"]
- if topo_order:
- topo = event.depth
+ stream = row.stream_ordering
+ if topo_order and row.topological_ordering:
+ topo = row.topological_ordering
else:
topo = None
internal = event.internal_metadata
@@ -592,87 +589,27 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
- token = RoomStreamToken(
- results["topological_ordering"],
+ # Paginating backwards includes the event at the token, but paginating
+ # forward doesn't.
+ before_token = RoomStreamToken(
+ results["topological_ordering"] - 1,
results["stream_ordering"],
)
- if isinstance(self.database_engine, Sqlite3Engine):
- # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)``
- # So we give pass it to SQLite3 as the UNION ALL of the two queries.
-
- query_before = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering < ?"
- " UNION ALL"
- " SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?"
- " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
- )
- before_args = (
- room_id, token.topological,
- room_id, token.topological, token.stream,
- before_limit,
- )
-
- query_after = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering > ?"
- " UNION ALL"
- " SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?"
- " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
- )
- after_args = (
- room_id, token.topological,
- room_id, token.topological, token.stream,
- after_limit,
- )
- else:
- query_before = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND %s"
- " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?"
- ) % (upper_bound(token, self.database_engine, inclusive=False),)
-
- before_args = (room_id, before_limit)
-
- query_after = (
- "SELECT topological_ordering, stream_ordering, event_id FROM events"
- " WHERE room_id = ? AND %s"
- " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?"
- ) % (lower_bound(token, self.database_engine, inclusive=False),)
-
- after_args = (room_id, after_limit)
-
- txn.execute(query_before, before_args)
-
- rows = self.cursor_to_dict(txn)
- events_before = [r["event_id"] for r in rows]
-
- if rows:
- start_token = str(RoomStreamToken(
- rows[0]["topological_ordering"],
- rows[0]["stream_ordering"] - 1,
- ))
- else:
- start_token = str(RoomStreamToken(
- token.topological,
- token.stream - 1,
- ))
-
- txn.execute(query_after, after_args)
+ after_token = RoomStreamToken(
+ results["topological_ordering"],
+ results["stream_ordering"],
+ )
- rows = self.cursor_to_dict(txn)
- events_after = [r["event_id"] for r in rows]
+ rows, start_token = self._paginate_room_events_txn(
+ txn, room_id, before_token, direction='b', limit=before_limit,
+ )
+ events_before = [r.event_id for r in rows]
- if rows:
- end_token = str(RoomStreamToken(
- rows[-1]["topological_ordering"],
- rows[-1]["stream_ordering"],
- ))
- else:
- end_token = str(token)
+ rows, end_token = self._paginate_room_events_txn(
+ txn, room_id, after_token, direction='f', limit=after_limit,
+ )
+ events_after = [r.event_id for r in rows]
return {
"before": {
@@ -735,17 +672,30 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def has_room_changed_since(self, room_id, stream_id):
return self._events_stream_cache.has_entity_changed(room_id, stream_id)
+ def _paginate_room_events_txn(self, txn, room_id, from_token, to_token=None,
+ direction='b', limit=-1, event_filter=None):
+ """Returns list of events before or after a given token.
-class StreamStore(StreamWorkerStore):
- def get_room_max_stream_ordering(self):
- return self._stream_id_gen.get_current_token()
+ Args:
+ txn
+ room_id (str)
+ from_token (RoomStreamToken): The token used to stream from
+ to_token (RoomStreamToken|None): A token which if given limits the
+ results to only those before
+ direction(char): Either 'b' or 'f' to indicate whether we are
+ paginating forwards or backwards from `from_key`.
+ limit (int): The maximum number of events to return.
+ event_filter (Filter|None): If provided filters the events to
+ those that match the filter.
- def get_room_min_stream_ordering(self):
- return self._backfill_id_gen.get_current_token()
+ Returns:
+ Deferred[tuple[list[_EventDictReturn], str]]: Returns the results
+ as a list of _EventDictReturn and a token that points to the end
+ of the result set.
+ """
+
+ assert int(limit) >= 0
- @defer.inlineCallbacks
- def paginate_room_events(self, room_id, from_key, to_key=None,
- direction='b', limit=-1, event_filter=None):
# Tokens really represent positions between elements, but we use
# the convention of pointing to the event before the gap. Hence
# we have a bit of asymmetry when it comes to equalities.
@@ -753,20 +703,20 @@ class StreamStore(StreamWorkerStore):
if direction == 'b':
order = "DESC"
bounds = upper_bound(
- RoomStreamToken.parse(from_key), self.database_engine
+ from_token, self.database_engine
)
- if to_key:
+ if to_token:
bounds = "%s AND %s" % (bounds, lower_bound(
- RoomStreamToken.parse(to_key), self.database_engine
+ to_token, self.database_engine
))
else:
order = "ASC"
bounds = lower_bound(
- RoomStreamToken.parse(from_key), self.database_engine
+ from_token, self.database_engine
)
- if to_key:
+ if to_token:
bounds = "%s AND %s" % (bounds, upper_bound(
- RoomStreamToken.parse(to_key), self.database_engine
+ to_token, self.database_engine
))
filter_clause, filter_args = filter_to_clause(event_filter)
@@ -775,52 +725,85 @@ class StreamStore(StreamWorkerStore):
bounds += " AND " + filter_clause
args.extend(filter_args)
- if int(limit) > 0:
- args.append(int(limit))
- limit_str = " LIMIT ?"
- else:
- limit_str = ""
+ args.append(int(limit))
sql = (
- "SELECT * FROM events"
+ "SELECT event_id, topological_ordering, stream_ordering"
+ " FROM events"
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
" ORDER BY topological_ordering %(order)s,"
- " stream_ordering %(order)s %(limit)s"
+ " stream_ordering %(order)s LIMIT ?"
) % {
"bounds": bounds,
"order": order,
- "limit": limit_str
}
- def f(txn):
- txn.execute(sql, args)
-
- rows = self.cursor_to_dict(txn)
-
- if rows:
- topo = rows[-1]["topological_ordering"]
- toke = rows[-1]["stream_ordering"]
- if direction == 'b':
- # Tokens are positions between events.
- # This token points *after* the last event in the chunk.
- # We need it to point to the event before it in the chunk
- # when we are going backwards so we subtract one from the
- # stream part.
- toke -= 1
- next_token = str(RoomStreamToken(topo, toke))
- else:
- # TODO (erikj): We should work out what to do here instead.
- next_token = to_key if to_key else from_key
+ txn.execute(sql, args)
+
+ rows = [_EventDictReturn(row[0], row[1], row[2]) for row in txn]
+
+ if rows:
+ topo = rows[-1].topological_ordering
+ toke = rows[-1].stream_ordering
+ if direction == 'b':
+ # Tokens are positions between events.
+ # This token points *after* the last event in the chunk.
+ # We need it to point to the event before it in the chunk
+ # when we are going backwards so we subtract one from the
+ # stream part.
+ toke -= 1
+ next_token = RoomStreamToken(topo, toke)
+ else:
+ # TODO (erikj): We should work out what to do here instead.
+ next_token = to_token if to_token else from_token
+
+ return rows, str(next_token),
+
+ @defer.inlineCallbacks
+ def paginate_room_events(self, room_id, from_key, to_key=None,
+ direction='b', limit=-1, event_filter=None):
+ """Returns list of events before or after a given token.
- return rows, next_token,
+ Args:
+ room_id (str)
+ from_key (str): The token used to stream from
+ to_key (str|None): A token which if given limits the results to
+ only those before
+ direction(char): Either 'b' or 'f' to indicate whether we are
+ paginating forwards or backwards from `from_key`.
+ limit (int): The maximum number of events to return. Zero or less
+ means no limit.
+ event_filter (Filter|None): If provided filters the events to
+ those that match the filter.
- rows, token = yield self.runInteraction("paginate_room_events", f)
+ Returns:
+ tuple[list[dict], str]: Returns the results as a list of dicts and
+ a token that points to the end of the result set. The dicts have
+ the keys "event_id", "topological_ordering" and "stream_orderign".
+ """
+
+ from_key = RoomStreamToken.parse(from_key)
+ if to_key:
+ to_key = RoomStreamToken.parse(to_key)
+
+ rows, token = yield self.runInteraction(
+ "paginate_room_events", self._paginate_room_events_txn,
+ room_id, from_key, to_key, direction, limit, event_filter,
+ )
events = yield self._get_events(
- [r["event_id"] for r in rows],
+ [r.event_id for r in rows],
get_prev_content=True
)
self._set_before_and_after(events, rows)
defer.returnValue((events, token))
+
+
+class StreamStore(StreamWorkerStore):
+ def get_room_max_stream_ordering(self):
+ return self._stream_id_gen.get_current_token()
+
+ def get_room_min_stream_ordering(self):
+ return self._backfill_id_gen.get_current_token()
diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py
index 13bff9f055..6671d3cfca 100644
--- a/synapse/storage/tags.py
+++ b/synapse/storage/tags.py
@@ -22,6 +22,8 @@ from twisted.internet import defer
import simplejson as json
import logging
+from six.moves import range
+
logger = logging.getLogger(__name__)
@@ -98,7 +100,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
batch_size = 50
results = []
- for i in xrange(0, len(tag_ids), batch_size):
+ for i in range(0, len(tag_ids), batch_size):
tags = yield self.runInteraction(
"get_all_updated_tag_content",
get_tag_content,
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index f825264ea9..e485d19b84 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -17,6 +17,7 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
+import six
from canonicaljson import encode_canonical_json
@@ -25,6 +26,13 @@ from collections import namedtuple
import logging
import simplejson as json
+# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
+# despite being deprecated and removed in favor of memoryview
+if six.PY2:
+ db_binary_type = buffer
+else:
+ db_binary_type = memoryview
+
logger = logging.getLogger(__name__)
@@ -110,7 +118,7 @@ class TransactionStore(SQLBaseStore):
"transaction_id": transaction_id,
"origin": origin,
"response_code": code,
- "response_json": buffer(encode_canonical_json(response_dict)),
+ "response_json": db_binary_type(encode_canonical_json(response_dict)),
"ts": self._clock.time_msec(),
},
or_ignore=True,
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index d6e289ffbe..275c299998 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -22,6 +22,8 @@ from synapse.api.constants import EventTypes, JoinRules
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id, get_localpart_from_id
+from six import iteritems
+
import re
import logging
@@ -100,7 +102,7 @@ class UserDirectoryStore(SQLBaseStore):
user_id, get_localpart_from_id(user_id), get_domain_from_id(user_id),
profile.display_name,
)
- for user_id, profile in users_with_profile.iteritems()
+ for user_id, profile in iteritems(users_with_profile)
)
elif isinstance(self.database_engine, Sqlite3Engine):
sql = """
@@ -112,7 +114,7 @@ class UserDirectoryStore(SQLBaseStore):
user_id,
"%s %s" % (user_id, p.display_name,) if p.display_name else user_id
)
- for user_id, p in users_with_profile.iteritems()
+ for user_id, p in iteritems(users_with_profile)
)
else:
# This should be unreachable.
@@ -130,7 +132,7 @@ class UserDirectoryStore(SQLBaseStore):
"display_name": profile.display_name,
"avatar_url": profile.avatar_url,
}
- for user_id, profile in users_with_profile.iteritems()
+ for user_id, profile in iteritems(users_with_profile)
]
)
for user_id in users_with_profile:
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 756d8ffa32..fc11e26623 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.api.errors import SynapseError
from synapse.util.logcontext import PreserveLoggingContext
from twisted.internet import defer, reactor, task
@@ -21,12 +20,9 @@ from twisted.internet import defer, reactor, task
import time
import logging
-logger = logging.getLogger(__name__)
-
+from itertools import islice
-class DeferredTimedOutError(SynapseError):
- def __init__(self):
- super(DeferredTimedOutError, self).__init__(504, "Timed out")
+logger = logging.getLogger(__name__)
def unwrapFirstError(failure):
@@ -86,52 +82,18 @@ class Clock(object):
if not ignore_errs:
raise
- def time_bound_deferred(self, given_deferred, time_out):
- if given_deferred.called:
- return given_deferred
-
- ret_deferred = defer.Deferred()
-
- def timed_out_fn():
- e = DeferredTimedOutError()
-
- try:
- ret_deferred.errback(e)
- except Exception:
- pass
-
- try:
- given_deferred.cancel()
- except Exception:
- pass
- timer = None
+def batch_iter(iterable, size):
+ """batch an iterable up into tuples with a maximum size
- def cancel(res):
- try:
- self.cancel_call_later(timer)
- except Exception:
- pass
- return res
+ Args:
+ iterable (iterable): the iterable to slice
+ size (int): the maximum batch size
- ret_deferred.addBoth(cancel)
-
- def success(res):
- try:
- ret_deferred.callback(res)
- except Exception:
- pass
-
- return res
-
- def err(res):
- try:
- ret_deferred.errback(res)
- except Exception:
- pass
-
- given_deferred.addCallbacks(callback=success, errback=err)
-
- timer = self.call_later(time_out, timed_out_fn)
-
- return ret_deferred
+ Returns:
+ an iterator over the chunks
+ """
+ # make sure we can deal with iterables like lists too
+ sourceiter = iter(iterable)
+ # call islice until it returns an empty tuple
+ return iter(lambda: tuple(islice(sourceiter, size)), ())
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 0729bb2863..9dd4e6b5bc 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -15,9 +15,11 @@
from twisted.internet import defer, reactor
+from twisted.internet.defer import CancelledError
+from twisted.python import failure
from .logcontext import (
- PreserveLoggingContext, make_deferred_yieldable, preserve_fn
+ PreserveLoggingContext, make_deferred_yieldable, run_in_background
)
from synapse.util import logcontext, unwrapFirstError
@@ -25,6 +27,8 @@ from contextlib import contextmanager
import logging
+from six.moves import range
+
logger = logging.getLogger(__name__)
@@ -156,13 +160,13 @@ def concurrently_execute(func, args, limit):
def _concurrently_execute_inner():
try:
while True:
- yield func(it.next())
+ yield func(next(it))
except StopIteration:
pass
return logcontext.make_deferred_yieldable(defer.gatherResults([
- preserve_fn(_concurrently_execute_inner)()
- for _ in xrange(limit)
+ run_in_background(_concurrently_execute_inner)
+ for _ in range(limit)
], consumeErrors=True)).addErrback(unwrapFirstError)
@@ -392,3 +396,68 @@ class ReadWriteLock(object):
self.key_to_current_writer.pop(key)
defer.returnValue(_ctx_manager())
+
+
+class DeferredTimeoutError(Exception):
+ """
+ This error is raised by default when a L{Deferred} times out.
+ """
+
+
+def add_timeout_to_deferred(deferred, timeout, on_timeout_cancel=None):
+ """
+ Add a timeout to a deferred by scheduling it to be cancelled after
+ timeout seconds.
+
+ This is essentially a backport of deferred.addTimeout, which was introduced
+ in twisted 16.5.
+
+ If the deferred gets timed out, it errbacks with a DeferredTimeoutError,
+ unless a cancelable function was passed to its initialization or unless
+ a different on_timeout_cancel callable is provided.
+
+ Args:
+ deferred (defer.Deferred): deferred to be timed out
+ timeout (Number): seconds to time out after
+
+ on_timeout_cancel (callable): A callable which is called immediately
+ after the deferred times out, and not if this deferred is
+ otherwise cancelled before the timeout.
+
+ It takes an arbitrary value, which is the value of the deferred at
+ that exact point in time (probably a CancelledError Failure), and
+ the timeout.
+
+ The default callable (if none is provided) will translate a
+ CancelledError Failure into a DeferredTimeoutError.
+ """
+ timed_out = [False]
+
+ def time_it_out():
+ timed_out[0] = True
+ deferred.cancel()
+
+ delayed_call = reactor.callLater(timeout, time_it_out)
+
+ def convert_cancelled(value):
+ if timed_out[0]:
+ to_call = on_timeout_cancel or _cancelled_to_timed_out_error
+ return to_call(value, timeout)
+ return value
+
+ deferred.addBoth(convert_cancelled)
+
+ def cancel_timeout(result):
+ # stop the pending call to cancel the deferred if it's been fired
+ if delayed_call.active():
+ delayed_call.cancel()
+ return result
+
+ deferred.addBoth(cancel_timeout)
+
+
+def _cancelled_to_timed_out_error(value, timeout):
+ if isinstance(value, failure.Failure):
+ value.trap(CancelledError)
+ raise DeferredTimeoutError(timeout, "Deferred")
+ return value
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 4adae96681..183faf75a1 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -13,28 +13,77 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import synapse.metrics
+from prometheus_client.core import Gauge, REGISTRY, GaugeMetricFamily
+
import os
-CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
+from six.moves import intern
+import six
-metrics = synapse.metrics.get_metrics_for("synapse.util.caches")
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5))
caches_by_name = {}
-# cache_counter = metrics.register_cache(
-# "cache",
-# lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
-# labels=["name"],
-# )
-
-
-def register_cache(name, cache):
- caches_by_name[name] = cache
- return metrics.register_cache(
- "cache",
- lambda: len(cache),
- name,
- )
+collectors_by_name = {}
+
+cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
+cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])
+cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"])
+cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"])
+
+response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"])
+response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"])
+response_cache_evicted = Gauge(
+ "synapse_util_caches_response_cache:evicted_size", "", ["name"]
+)
+response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"])
+
+
+def register_cache(cache_type, cache_name, cache):
+
+ # Check if the metric is already registered. Unregister it, if so.
+ # This usually happens during tests, as at runtime these caches are
+ # effectively singletons.
+ metric_name = "cache_%s_%s" % (cache_type, cache_name)
+ if metric_name in collectors_by_name.keys():
+ REGISTRY.unregister(collectors_by_name[metric_name])
+
+ class CacheMetric(object):
+
+ hits = 0
+ misses = 0
+ evicted_size = 0
+
+ def inc_hits(self):
+ self.hits += 1
+
+ def inc_misses(self):
+ self.misses += 1
+
+ def inc_evictions(self, size=1):
+ self.evicted_size += size
+
+ def describe(self):
+ return []
+
+ def collect(self):
+ if cache_type == "response_cache":
+ response_cache_size.labels(cache_name).set(len(cache))
+ response_cache_hits.labels(cache_name).set(self.hits)
+ response_cache_evicted.labels(cache_name).set(self.evicted_size)
+ response_cache_total.labels(cache_name).set(self.hits + self.misses)
+ else:
+ cache_size.labels(cache_name).set(len(cache))
+ cache_hits.labels(cache_name).set(self.hits)
+ cache_evicted.labels(cache_name).set(self.evicted_size)
+ cache_total.labels(cache_name).set(self.hits + self.misses)
+
+ yield GaugeMetricFamily("__unused", "")
+
+ metric = CacheMetric()
+ REGISTRY.register(metric)
+ caches_by_name[cache_name] = cache
+ collectors_by_name[metric_name] = metric
+ return metric
KNOWN_KEYS = {
@@ -66,7 +115,9 @@ def intern_string(string):
return None
try:
- string = string.encode("ascii")
+ if six.PY2:
+ string = string.encode("ascii")
+
return intern(string)
except UnicodeEncodeError:
return string
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 68285a7594..fc1874b65b 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -31,6 +31,9 @@ import functools
import inspect
import threading
+from six import string_types, itervalues
+import six
+
logger = logging.getLogger(__name__)
@@ -80,7 +83,7 @@ class Cache(object):
self.name = name
self.keylen = keylen
self.thread = None
- self.metrics = register_cache(name, self.cache)
+ self.metrics = register_cache("cache", name, self.cache)
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
@@ -205,7 +208,7 @@ class Cache(object):
def invalidate_all(self):
self.check_thread()
self.cache.clear()
- for entry in self._pending_deferred_cache.itervalues():
+ for entry in itervalues(self._pending_deferred_cache):
entry.invalidate()
self._pending_deferred_cache.clear()
@@ -392,9 +395,10 @@ class CacheDescriptor(_CacheDescriptorBase):
ret.addErrback(onErr)
- # If our cache_key is a string, try to convert to ascii to save
- # a bit of space in large caches
- if isinstance(cache_key, basestring):
+ # If our cache_key is a string on py2, try to convert to ascii
+ # to save a bit of space in large caches. Py3 does this
+ # internally automatically.
+ if six.PY2 and isinstance(cache_key, string_types):
cache_key = to_ascii(cache_key)
result_d = ObservableDeferred(ret, consumeErrors=True)
@@ -565,7 +569,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
return results
return logcontext.make_deferred_yieldable(defer.gatherResults(
- cached_defers.values(),
+ list(cached_defers.values()),
consumeErrors=True,
).addCallback(update_results_dict).addErrback(
unwrapFirstError
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index 1709e8b429..bdc21e348f 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -55,7 +55,7 @@ class DictionaryCache(object):
__slots__ = []
self.sentinel = Sentinel()
- self.metrics = register_cache(name, self.cache)
+ self.metrics = register_cache("dictionary", name, self.cache)
def check_thread(self):
expected_thread = self.thread
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 0aa103eecb..ff04c91955 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -52,12 +52,12 @@ class ExpiringCache(object):
self._cache = OrderedDict()
- self.metrics = register_cache(cache_name, self)
-
self.iterable = iterable
self._size_estimate = 0
+ self.metrics = register_cache("expiring", cache_name, self)
+
def start(self):
if not self._expiry_ms:
# Don't bother starting the loop if things never expire
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 7f79333e96..a8491b42d5 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
from synapse.util.async import ObservableDeferred
-from synapse.util.caches import metrics as cache_metrics
+from synapse.util.caches import register_cache
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
@@ -38,15 +38,16 @@ class ResponseCache(object):
self.timeout_sec = timeout_ms / 1000.
self._name = name
- self._metrics = cache_metrics.register_cache(
- "response_cache",
- size_callback=lambda: self.size(),
- cache_name=name,
+ self._metrics = register_cache(
+ "response_cache", name, self
)
def size(self):
return len(self.pending_result_cache)
+ def __len__(self):
+ return self.size()
+
def get(self, key):
"""Look up the given key.
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 941d873ab8..a7fe0397fa 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -38,7 +38,7 @@ class StreamChangeCache(object):
self._cache = sorteddict()
self._earliest_known_stream_pos = current_stream_pos
self.name = name
- self.metrics = register_cache(self.name, self._cache)
+ self.metrics = register_cache("cache", self.name, self._cache)
for entity, stream_pos in prefilled_cache.items():
self.entity_has_changed(entity, stream_pos)
diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py
index fcc341a6b7..dd4c9e6067 100644
--- a/synapse/util/caches/treecache.py
+++ b/synapse/util/caches/treecache.py
@@ -1,3 +1,5 @@
+from six import itervalues
+
SENTINEL = object()
@@ -49,7 +51,7 @@ class TreeCache(object):
if popped is SENTINEL:
return default
- node_and_keys = zip(nodes, key)
+ node_and_keys = list(zip(nodes, key))
node_and_keys.reverse()
node_and_keys.append((self.root, None))
@@ -76,7 +78,7 @@ def iterate_tree_cache_entry(d):
can contain dicts.
"""
if isinstance(d, dict):
- for value_d in d.itervalues():
+ for value_d in itervalues(d):
for value in iterate_tree_cache_entry(value_d):
yield value
else:
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
index 3c8a165331..3380970e4e 100644
--- a/synapse/util/file_consumer.py
+++ b/synapse/util/file_consumer.py
@@ -15,7 +15,7 @@
from twisted.internet import threads, reactor
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable, run_in_background
from six.moves import queue
@@ -70,7 +70,9 @@ class BackgroundFileConsumer(object):
self._producer = producer
self.streaming = streaming
- self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
+ self._finished_deferred = run_in_background(
+ threads.deferToThread, self._writer
+ )
if not streaming:
self._producer.resumeProducing()
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index f497b51f4a..15f0a7ba9e 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -16,16 +16,17 @@
from frozendict import frozendict
import simplejson as json
+from six import string_types
+
def freeze(o):
- t = type(o)
- if t is dict:
+ if isinstance(o, dict):
return frozendict({k: freeze(v) for k, v in o.items()})
- if t is frozendict:
+ if isinstance(o, frozendict):
return o
- if t is str or t is unicode:
+ if isinstance(o, string_types):
return o
try:
@@ -37,11 +38,10 @@ def freeze(o):
def unfreeze(o):
- t = type(o)
- if t is dict or t is frozendict:
+ if isinstance(o, (dict, frozendict)):
return dict({k: unfreeze(v) for k, v in o.items()})
- if t is str or t is unicode:
+ if isinstance(o, string_types):
return o
try:
diff --git a/synapse/util/httpresourcetree.py b/synapse/util/httpresourcetree.py
index d747849553..e9f0f292ee 100644
--- a/synapse/util/httpresourcetree.py
+++ b/synapse/util/httpresourcetree.py
@@ -40,9 +40,12 @@ def create_resource_tree(desired_tree, root_resource):
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
+ # twisted requires all resources to be bytes
+ full_path = full_path.encode("utf-8")
+
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
- for path_seg in full_path.split('/')[1:-1]:
+ for path_seg in full_path.split(b'/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = NoResource()
@@ -57,7 +60,7 @@ def create_resource_tree(desired_tree, root_resource):
# ===========================
# now attach the actual desired resource
- last_path_seg = full_path.split('/')[-1]
+ last_path_seg = full_path.split(b'/')[-1]
# if there is already a resource here, thieve its children and
# replace it
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index d59adc236e..a58c723403 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -59,8 +59,8 @@ class LoggingContext(object):
__slots__ = [
"previous_context", "name", "ru_stime", "ru_utime",
- "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
- "usage_start", "usage_end",
+ "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
+ "usage_start",
"main_thread", "alive",
"request", "tag",
]
@@ -84,10 +84,10 @@ class LoggingContext(object):
def stop(self):
pass
- def add_database_transaction(self, duration_ms):
+ def add_database_transaction(self, duration_sec):
pass
- def add_database_scheduled(self, sched_ms):
+ def add_database_scheduled(self, sched_sec):
pass
def __nonzero__(self):
@@ -103,14 +103,16 @@ class LoggingContext(object):
self.ru_utime = 0.
self.db_txn_count = 0
- # ms spent waiting for db txns, excluding scheduling time
- self.db_txn_duration_ms = 0
+ # sec spent waiting for db txns, excluding scheduling time
+ self.db_txn_duration_sec = 0
- # ms spent waiting for db txns to be scheduled
- self.db_sched_duration_ms = 0
+ # sec spent waiting for db txns to be scheduled
+ self.db_sched_duration_sec = 0
+ # If alive has the thread resource usage when the logcontext last
+ # became active.
self.usage_start = None
- self.usage_end = None
+
self.main_thread = threading.current_thread()
self.request = None
self.tag = ""
@@ -159,12 +161,12 @@ class LoggingContext(object):
"""Restore the logging context in thread local storage to the state it
was before this context was entered.
Returns:
- None to avoid suppressing any exeptions that were thrown.
+ None to avoid suppressing any exceptions that were thrown.
"""
current = self.set_current_context(self.previous_context)
if current is not self:
if current is self.sentinel:
- logger.debug("Expected logging context %s has been lost", self)
+ logger.warn("Expected logging context %s has been lost", self)
else:
logger.warn(
"Current logging context %s is not expected context %s",
@@ -185,47 +187,61 @@ class LoggingContext(object):
def start(self):
if threading.current_thread() is not self.main_thread:
+ logger.warning("Started logcontext %s on different thread", self)
return
- if self.usage_start and self.usage_end:
- self.ru_utime += self.usage_end.ru_utime - self.usage_start.ru_utime
- self.ru_stime += self.usage_end.ru_stime - self.usage_start.ru_stime
- self.usage_start = None
- self.usage_end = None
-
+ # If we haven't already started record the thread resource usage so
+ # far
if not self.usage_start:
self.usage_start = get_thread_resource_usage()
def stop(self):
if threading.current_thread() is not self.main_thread:
+ logger.warning("Stopped logcontext %s on different thread", self)
return
+ # When we stop, let's record the resource used since we started
if self.usage_start:
- self.usage_end = get_thread_resource_usage()
+ usage_end = get_thread_resource_usage()
+
+ self.ru_utime += usage_end.ru_utime - self.usage_start.ru_utime
+ self.ru_stime += usage_end.ru_stime - self.usage_start.ru_stime
+
+ self.usage_start = None
+ else:
+ logger.warning("Called stop on logcontext %s without calling start", self)
def get_resource_usage(self):
+ """Get CPU time used by this logcontext so far.
+
+ Returns:
+ tuple[float, float]: The user and system CPU usage in seconds
+ """
ru_utime = self.ru_utime
ru_stime = self.ru_stime
- if self.usage_start and threading.current_thread() is self.main_thread:
+ # If we are on the correct thread and we're currently running then we
+ # can include resource usage so far.
+ is_main_thread = threading.current_thread() is self.main_thread
+ if self.alive and self.usage_start and is_main_thread:
current = get_thread_resource_usage()
ru_utime += current.ru_utime - self.usage_start.ru_utime
ru_stime += current.ru_stime - self.usage_start.ru_stime
return ru_utime, ru_stime
- def add_database_transaction(self, duration_ms):
+ def add_database_transaction(self, duration_sec):
self.db_txn_count += 1
- self.db_txn_duration_ms += duration_ms
+ self.db_txn_duration_sec += duration_sec
- def add_database_scheduled(self, sched_ms):
+ def add_database_scheduled(self, sched_sec):
"""Record a use of the database pool
Args:
- sched_ms (int): number of milliseconds it took us to get a
+ sched_sec (float): number of seconds it took us to get a
connection
"""
- self.db_sched_duration_ms += sched_ms
+ self.db_sched_duration_sec += sched_sec
class LoggingContextFilter(logging.Filter):
@@ -279,7 +295,7 @@ class PreserveLoggingContext(object):
context = LoggingContext.set_current_context(self.current_context)
if context != self.new_context:
- logger.debug(
+ logger.warn(
"Unexpected logging context: %s is not %s",
context, self.new_context,
)
@@ -302,31 +318,49 @@ def preserve_fn(f):
def run_in_background(f, *args, **kwargs):
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
- deferred returned by the funtion completes.
+ deferred returned by the function completes.
Useful for wrapping functions that return a deferred which you don't yield
- on.
+ on (for instance because you want to pass it to deferred.gatherResults()).
+
+ Note that if you completely discard the result, you should make sure that
+ `f` doesn't raise any deferred exceptions, otherwise a scary-looking
+ CRITICAL error about an unhandled error will be logged without much
+ indication about where it came from.
"""
current = LoggingContext.current_context()
- res = f(*args, **kwargs)
- if isinstance(res, defer.Deferred) and not res.called:
- # The function will have reset the context before returning, so
- # we need to restore it now.
- LoggingContext.set_current_context(current)
-
- # The original context will be restored when the deferred
- # completes, but there is nothing waiting for it, so it will
- # get leaked into the reactor or some other function which
- # wasn't expecting it. We therefore need to reset the context
- # here.
- #
- # (If this feels asymmetric, consider it this way: we are
- # effectively forking a new thread of execution. We are
- # probably currently within a ``with LoggingContext()`` block,
- # which is supposed to have a single entry and exit point. But
- # by spawning off another deferred, we are effectively
- # adding a new exit point.)
- res.addBoth(_set_context_cb, LoggingContext.sentinel)
+ try:
+ res = f(*args, **kwargs)
+ except: # noqa: E722
+ # the assumption here is that the caller doesn't want to be disturbed
+ # by synchronous exceptions, so let's turn them into Failures.
+ return defer.fail()
+
+ if not isinstance(res, defer.Deferred):
+ return res
+
+ if res.called and not res.paused:
+ # The function should have maintained the logcontext, so we can
+ # optimise out the messing about
+ return res
+
+ # The function may have reset the context before returning, so
+ # we need to restore it now.
+ ctx = LoggingContext.set_current_context(current)
+
+ # The original context will be restored when the deferred
+ # completes, but there is nothing waiting for it, so it will
+ # get leaked into the reactor or some other function which
+ # wasn't expecting it. We therefore need to reset the context
+ # here.
+ #
+ # (If this feels asymmetric, consider it this way: we are
+ # effectively forking a new thread of execution. We are
+ # probably currently within a ``with LoggingContext()`` block,
+ # which is supposed to have a single entry and exit point. But
+ # by spawning off another deferred, we are effectively
+ # adding a new exit point.)
+ res.addBoth(_set_context_cb, ctx)
return res
@@ -341,11 +375,20 @@ def make_deferred_yieldable(deferred):
returning a deferred. Then, when the deferred completes, restores the
current logcontext before running callbacks/errbacks.
- (This is more-or-less the opposite operation to preserve_fn.)
+ (This is more-or-less the opposite operation to run_in_background.)
"""
- if isinstance(deferred, defer.Deferred) and not deferred.called:
- prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
- deferred.addBoth(_set_context_cb, prev_context)
+ if not isinstance(deferred, defer.Deferred):
+ return deferred
+
+ if deferred.called and not deferred.paused:
+ # it looks like this deferred is ready to run any callbacks we give it
+ # immediately. We may as well optimise out the logcontext faffery.
+ return deferred
+
+ # ok, we can't be sure that a yield won't block, so let's reset the
+ # logcontext, and add a callback to the deferred to restore it.
+ prev_context = LoggingContext.set_current_context(LoggingContext.sentinel)
+ deferred.addBoth(_set_context_cb, prev_context)
return deferred
diff --git a/synapse/util/logformatter.py b/synapse/util/logformatter.py
index cdbc4bffd7..3e42868ea9 100644
--- a/synapse/util/logformatter.py
+++ b/synapse/util/logformatter.py
@@ -14,7 +14,7 @@
# limitations under the License.
-import StringIO
+from six import StringIO
import logging
import traceback
@@ -32,7 +32,7 @@ class LogFormatter(logging.Formatter):
super(LogFormatter, self).__init__(*args, **kwargs)
def formatException(self, ei):
- sio = StringIO.StringIO()
+ sio = StringIO()
(typ, val, tb) = ei
# log the stack above the exception capture point if possible, but
diff --git a/synapse/util/logutils.py b/synapse/util/logutils.py
index 3a83828d25..03249c5dc8 100644
--- a/synapse/util/logutils.py
+++ b/synapse/util/logutils.py
@@ -96,7 +96,7 @@ def time_function(f):
id = _TIME_FUNC_ID
_TIME_FUNC_ID += 1
- start = time.clock() * 1000
+ start = time.clock()
try:
_log_debug_as_f(
@@ -107,10 +107,10 @@ def time_function(f):
r = f(*args, **kwargs)
finally:
- end = time.clock() * 1000
+ end = time.clock()
_log_debug_as_f(
f,
- "[FUNC END] {%s-%d} %f",
+ "[FUNC END] {%s-%d} %.3f sec",
(func_name, id, end - start,),
)
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index e4b5687a4b..1ba7d65c7c 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -15,8 +15,8 @@
from twisted.internet import defer
+from prometheus_client import Counter
from synapse.util.logcontext import LoggingContext
-import synapse.metrics
from functools import wraps
import logging
@@ -24,66 +24,26 @@ import logging
logger = logging.getLogger(__name__)
+block_counter = Counter("synapse_util_metrics_block_count", "", ["block_name"])
-metrics = synapse.metrics.get_metrics_for(__name__)
-
-# total number of times we have hit this block
-block_counter = metrics.register_counter(
- "block_count",
- labels=["block_name"],
- alternative_names=(
- # the following are all deprecated aliases for the same metric
- metrics.name_prefix + x for x in (
- "_block_timer:count",
- "_block_ru_utime:count",
- "_block_ru_stime:count",
- "_block_db_txn_count:count",
- "_block_db_txn_duration:count",
- )
- )
-)
-
-block_timer = metrics.register_counter(
- "block_time_seconds",
- labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_timer:total",
- ),
-)
-
-block_ru_utime = metrics.register_counter(
- "block_ru_utime_seconds", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_ru_utime:total",
- ),
-)
-
-block_ru_stime = metrics.register_counter(
- "block_ru_stime_seconds", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_ru_stime:total",
- ),
-)
-
-block_db_txn_count = metrics.register_counter(
- "block_db_txn_count", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_db_txn_count:total",
- ),
-)
+block_timer = Counter("synapse_util_metrics_block_time_seconds", "", ["block_name"])
+
+block_ru_utime = Counter(
+ "synapse_util_metrics_block_ru_utime_seconds", "", ["block_name"])
+
+block_ru_stime = Counter(
+ "synapse_util_metrics_block_ru_stime_seconds", "", ["block_name"])
+
+block_db_txn_count = Counter(
+ "synapse_util_metrics_block_db_txn_count", "", ["block_name"])
# seconds spent waiting for db txns, excluding scheduling time, in this block
-block_db_txn_duration = metrics.register_counter(
- "block_db_txn_duration_seconds", labels=["block_name"],
- alternative_names=(
- metrics.name_prefix + "_block_db_txn_duration:total",
- ),
-)
+block_db_txn_duration = Counter(
+ "synapse_util_metrics_block_db_txn_duration_seconds", "", ["block_name"])
# seconds spent waiting for a db connection, in this block
-block_db_sched_duration = metrics.register_counter(
- "block_db_sched_duration_seconds", labels=["block_name"],
-)
+block_db_sched_duration = Counter(
+ "synapse_util_metrics_block_db_sched_duration_seconds", "", ["block_name"])
def measure_func(name):
@@ -102,7 +62,7 @@ class Measure(object):
__slots__ = [
"clock", "name", "start_context", "start", "new_context", "ru_utime",
"ru_stime",
- "db_txn_count", "db_txn_duration_ms", "db_sched_duration_ms",
+ "db_txn_count", "db_txn_duration_sec", "db_sched_duration_sec",
"created_context",
]
@@ -114,7 +74,7 @@ class Measure(object):
self.created_context = False
def __enter__(self):
- self.start = self.clock.time_msec()
+ self.start = self.clock.time()
self.start_context = LoggingContext.current_context()
if not self.start_context:
self.start_context = LoggingContext("Measure")
@@ -123,17 +83,17 @@ class Measure(object):
self.ru_utime, self.ru_stime = self.start_context.get_resource_usage()
self.db_txn_count = self.start_context.db_txn_count
- self.db_txn_duration_ms = self.start_context.db_txn_duration_ms
- self.db_sched_duration_ms = self.start_context.db_sched_duration_ms
+ self.db_txn_duration_sec = self.start_context.db_txn_duration_sec
+ self.db_sched_duration_sec = self.start_context.db_sched_duration_sec
def __exit__(self, exc_type, exc_val, exc_tb):
if isinstance(exc_type, Exception) or not self.start_context:
return
- duration = self.clock.time_msec() - self.start
+ duration = self.clock.time() - self.start
- block_counter.inc(self.name)
- block_timer.inc_by(duration, self.name)
+ block_counter.labels(self.name).inc()
+ block_timer.labels(self.name).inc(duration)
context = LoggingContext.current_context()
@@ -150,19 +110,13 @@ class Measure(object):
ru_utime, ru_stime = context.get_resource_usage()
- block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name)
- block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name)
- block_db_txn_count.inc_by(
- context.db_txn_count - self.db_txn_count, self.name
- )
- block_db_txn_duration.inc_by(
- (context.db_txn_duration_ms - self.db_txn_duration_ms) / 1000.,
- self.name
- )
- block_db_sched_duration.inc_by(
- (context.db_sched_duration_ms - self.db_sched_duration_ms) / 1000.,
- self.name
- )
+ block_ru_utime.labels(self.name).inc(ru_utime - self.ru_utime)
+ block_ru_stime.labels(self.name).inc(ru_stime - self.ru_stime)
+ block_db_txn_count.labels(self.name).inc(context.db_txn_count - self.db_txn_count)
+ block_db_txn_duration.labels(self.name).inc(
+ context.db_txn_duration_sec - self.db_txn_duration_sec)
+ block_db_sched_duration.labels(self.name).inc(
+ context.db_sched_duration_sec - self.db_sched_duration_sec)
if self.created_context:
self.start_context.__exit__(exc_type, exc_val, exc_tb)
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 1101881a2d..0ab63c3d7d 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -18,7 +18,10 @@ from twisted.internet import defer
from synapse.api.errors import LimitExceededError
from synapse.util.async import sleep
-from synapse.util.logcontext import preserve_fn
+from synapse.util.logcontext import (
+ run_in_background, make_deferred_yieldable,
+ PreserveLoggingContext,
+)
import collections
import contextlib
@@ -150,7 +153,7 @@ class _PerHostRatelimiter(object):
"Ratelimit [%s]: sleeping req",
id(request_id),
)
- ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
+ ret_defer = run_in_background(sleep, self.sleep_msec / 1000.0)
self.sleeping_requests.add(request_id)
@@ -176,6 +179,9 @@ class _PerHostRatelimiter(object):
return r
def on_err(r):
+ # XXX: why is this necessary? this is called before we start
+ # processing the request so why would the request be in
+ # current_processing?
self.current_processing.discard(request_id)
return r
@@ -187,7 +193,7 @@ class _PerHostRatelimiter(object):
ret_defer.addCallbacks(on_start, on_err)
ret_defer.addBoth(on_both)
- return ret_defer
+ return make_deferred_yieldable(ret_defer)
def _on_exit(self, request_id):
logger.debug(
@@ -197,7 +203,12 @@ class _PerHostRatelimiter(object):
self.current_processing.discard(request_id)
try:
request_id, deferred = self.ready_request_queue.popitem()
+
+ # XXX: why do we do the following? the on_start callback above will
+ # do it for us.
self.current_processing.add(request_id)
- deferred.callback(None)
+
+ with PreserveLoggingContext():
+ deferred.callback(None)
except KeyError:
pass
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 47b0bb5eb3..4e93f69d3a 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -203,8 +203,8 @@ class RetryDestinationLimiter(object):
)
except Exception:
logger.exception(
- "Failed to store set_destination_retry_timings",
+ "Failed to store destination_retry_timings",
)
# we deliberately do this in the background.
- synapse.util.logcontext.preserve_fn(store_retry_timings)()
+ synapse.util.logcontext.run_in_background(store_retry_timings)
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index 95a6168e16..b98b9dc6e4 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -15,6 +15,7 @@
import random
import string
+from six.moves import range
_string_with_symbols = (
string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
@@ -22,12 +23,12 @@ _string_with_symbols = (
def random_string(length):
- return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
+ return ''.join(random.choice(string.ascii_letters) for _ in range(length))
def random_string_with_symbols(length):
return ''.join(
- random.choice(_string_with_symbols) for _ in xrange(length)
+ random.choice(_string_with_symbols) for _ in range(length)
)
diff --git a/synapse/util/wheel_timer.py b/synapse/util/wheel_timer.py
index b70f9a6b0a..7a9e45aca9 100644
--- a/synapse/util/wheel_timer.py
+++ b/synapse/util/wheel_timer.py
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from six.moves import range
+
class _Entry(object):
__slots__ = ["end_key", "queue"]
@@ -68,7 +70,7 @@ class WheelTimer(object):
# Add empty entries between the end of the current list and when we want
# to insert. This ensures there are no gaps.
self.entries.extend(
- _Entry(key) for key in xrange(last_key, then_key + 1)
+ _Entry(key) for key in range(last_key, then_key + 1)
)
self.entries[-1].queue.append(obj)
|