diff --git a/CHANGES.rst b/CHANGES.rst
index f89542a2bb..da31af9606 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,12 @@
+Changes in synapse v0.8.1 (2015-03-18)
+======================================
+
+* Disable registration by default. New users can be added using the command
+ ``register_new_matrix_user`` or by enabling registration in the config.
+* Add metrics to synapse. To enable metrics use config options
+ ``enable_metrics`` and ``metrics_port``.
+* Fix bug where banning only kicked the user.
+
Changes in synapse v0.8.0 (2015-03-06)
======================================
diff --git a/README.rst b/README.rst
index c2af7c9332..874753762d 100644
--- a/README.rst
+++ b/README.rst
@@ -1,3 +1,5 @@
+.. contents::
+
Introduction
============
@@ -126,6 +128,17 @@ To set up your homeserver, run (in your virtualenv, as before)::
Substituting your host and domain name as appropriate.
+By default, registration of new users is disabled. You can either enable
+registration in the config (it is then recommended to also set up CAPTCHA), or
+you can use the command line to register new users::
+
+ $ source ~/.synapse/bin/activate
+ $ register_new_matrix_user -c homeserver.yaml https://localhost:8448
+ New user localpart: erikj
+ Password:
+ Confirm password:
+ Success!
+
For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details.
@@ -250,7 +263,8 @@ fix try re-installing from PyPI or directly from
ArchLinux
---------
-If running `$ synctl start` fails wit 'returned non-zero exit status 1', you will need to explicitly call Python2.7 - either running as::
+If running `$ synctl start` fails with 'returned non-zero exit status 1',
+you will need to explicitly call Python2.7 - either running as::
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml --pid-file homeserver.pid
diff --git a/contrib/vertobot/bot.pl b/contrib/vertobot/bot.pl
index 828fc48786..0430a38aa8 100755
--- a/contrib/vertobot/bot.pl
+++ b/contrib/vertobot/bot.pl
@@ -175,13 +175,12 @@ sub on_room_message
my $verto_connecting = $loop->new_future;
$bot_verto->connect(
%{ $CONFIG{"verto-bot"} },
- on_connected => sub {
- warn("[Verto] connected to websocket");
- $verto_connecting->done($bot_verto) if not $verto_connecting->is_done;
- },
on_connect_error => sub { die "Cannot connect to verto - $_[-1]" },
on_resolve_error => sub { die "Cannot resolve to verto - $_[-1]" },
-);
+)->then( sub {
+ warn("[Verto] connected to websocket");
+ $verto_connecting->done($bot_verto) if not $verto_connecting->is_done;
+});
Future->needs_all(
$bot_matrix->login( %{ $CONFIG{"matrix-bot"} } )->then( sub {
diff --git a/contrib/vertobot/bridge.pl b/contrib/vertobot/bridge.pl
index e1a07f6659..a551850f40 100755
--- a/contrib/vertobot/bridge.pl
+++ b/contrib/vertobot/bridge.pl
@@ -86,7 +86,7 @@ sub create_virtual_user
"user": "$localpart"
}
EOT
- )->get;
+ )->get;
warn $response->as_string if ($response->code != 200);
}
@@ -266,17 +266,21 @@ my $as_url = $CONFIG{"matrix-bot"}->{as_url};
Future->needs_all(
$http->do_request(
- method => "POST",
- uri => URI->new( $CONFIG{"matrix"}->{server}."/_matrix/appservice/v1/register" ),
- content_type => "application/json",
- content => <<EOT
+ method => "POST",
+ uri => URI->new( $CONFIG{"matrix"}->{server}."/_matrix/appservice/v1/register" ),
+ content_type => "application/json",
+ content => <<EOT
{
"as_token": "$as_token",
"url": "$as_url",
- "namespaces": { "users": ["\@\\\\+.*"] }
+ "namespaces": { "users": [ { "regex": "\@\\\\+.*", "exclusive": false } ] }
}
EOT
- ),
+ )->then( sub{
+ my ($response) = (@_);
+ warn $response->as_string if ($response->code != 200);
+ return Future->done;
+ }),
$verto_connecting,
)->get;
diff --git a/register_new_matrix_user b/register_new_matrix_user
new file mode 100755
index 0000000000..daddadc302
--- /dev/null
+++ b/register_new_matrix_user
@@ -0,0 +1,149 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import getpass
+import hashlib
+import hmac
+import json
+import sys
+import urllib2
+import yaml
+
+
+def request_registration(user, password, server_location, shared_secret):
+ mac = hmac.new(
+ key=shared_secret,
+ msg=user,
+ digestmod=hashlib.sha1,
+ ).hexdigest()
+
+ data = {
+ "user": user,
+ "password": password,
+ "mac": mac,
+ "type": "org.matrix.login.shared_secret",
+ }
+
+ server_location = server_location.rstrip("/")
+
+ print "Sending registration request..."
+
+ req = urllib2.Request(
+ "%s/_matrix/client/api/v1/register" % (server_location,),
+ data=json.dumps(data),
+ headers={'Content-Type': 'application/json'}
+ )
+ try:
+ f = urllib2.urlopen(req)
+ f.read()
+ f.close()
+ print "Success."
+ except urllib2.HTTPError as e:
+ print "ERROR! Received %d %s" % (e.code, e.reason,)
+ if 400 <= e.code < 500:
+ if e.info().type == "application/json":
+ resp = json.load(e)
+ if "error" in resp:
+ print resp["error"]
+ sys.exit(1)
+
+
+def register_new_user(user, password, server_location, shared_secret):
+ if not user:
+ try:
+ default_user = getpass.getuser()
+ except:
+ default_user = None
+
+ if default_user:
+ user = raw_input("New user localpart [%s]: " % (default_user,))
+ if not user:
+ user = default_user
+ else:
+ user = raw_input("New user localpart: ")
+
+ if not user:
+ print "Invalid user name"
+ sys.exit(1)
+
+ if not password:
+ password = getpass.getpass("Password: ")
+
+ if not password:
+ print "Password cannot be blank."
+ sys.exit(1)
+
+ confirm_password = getpass.getpass("Confirm password: ")
+
+ if password != confirm_password:
+ print "Passwords do not match"
+ sys.exit(1)
+
+ request_registration(user, password, server_location, shared_secret)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Used to register new users with a given home server when"
+ " registration has been disabled. The home server must be"
+ " configured with the 'registration_shared_secret' option"
+ " set.",
+ )
+ parser.add_argument(
+ "-u", "--user",
+ default=None,
+ help="Local part of the new user. Will prompt if omitted.",
+ )
+ parser.add_argument(
+ "-p", "--password",
+ default=None,
+ help="New password for user. Will prompt if omitted.",
+ )
+
+ group = parser.add_mutually_exclusive_group(required=True)
+ group.add_argument(
+ "-c", "--config",
+ type=argparse.FileType('r'),
+ help="Path to server config file. Used to read in shared secret.",
+ )
+
+ group.add_argument(
+ "-k", "--shared-secret",
+ help="Shared secret as defined in server config file.",
+ )
+
+ parser.add_argument(
+ "server_url",
+ default="https://localhost:8448",
+ nargs='?',
+ help="URL to use to talk to the home server. Defaults to "
+ " 'https://localhost:8448'.",
+ )
+
+ args = parser.parse_args()
+
+ if "config" in args and args.config:
+ config = yaml.safe_load(args.config)
+ secret = config.get("registration_shared_secret", None)
+ if not secret:
+ print "No 'registration_shared_secret' defined in config."
+ sys.exit(1)
+ else:
+ secret = args.shared_secret
+
+ register_new_user(args.user, args.password, args.server_url, secret)
diff --git a/setup.py b/setup.py
index 2d812fa389..45943adb2c 100755
--- a/setup.py
+++ b/setup.py
@@ -55,5 +55,5 @@ setup(
include_package_data=True,
zip_safe=False,
long_description=long_description,
- scripts=["synctl"],
+ scripts=["synctl", "register_new_matrix_user"],
)
diff --git a/synapse/__init__.py b/synapse/__init__.py
index f46a6df1fb..e134fb2415 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.8.0"
+__version__ = "0.8.1"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index b176db8ce1..64f605b962 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -28,6 +28,12 @@ import logging
logger = logging.getLogger(__name__)
+AuthEventTypes = (
+ EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
+ EventTypes.JoinRules,
+)
+
+
class Auth(object):
def __init__(self, hs):
@@ -166,6 +172,7 @@ class Auth(object):
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
+ target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
@@ -194,6 +201,7 @@ class Auth(object):
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
+ "target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
@@ -202,6 +210,11 @@ class Auth(object):
}
)
+ if ban_level:
+ ban_level = int(ban_level)
+ else:
+ ban_level = 50 # FIXME (erikj): What should we do here?
+
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
@@ -212,6 +225,10 @@ class Auth(object):
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
+ elif target_banned:
+ raise AuthError(
+ 403, "%s is banned from the room" % (target_user_id,)
+ )
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
@@ -221,6 +238,8 @@ class Auth(object):
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
+ elif target_banned:
+ raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
@@ -238,6 +257,10 @@ class Auth(object):
403,
"%s not in room %s." % (target_user_id, event.room_id,)
)
+ elif target_banned and user_level < ban_level:
+ raise AuthError(
+                403, "You cannot unban user %s." % (target_user_id,)
+ )
elif target_user_id != event.user_id:
if kick_level:
kick_level = int(kick_level)
@@ -249,11 +272,6 @@ class Auth(object):
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
- if ban_level:
- ban_level = int(ban_level)
- else:
- ban_level = 50 # FIXME (erikj): What should we do here?
-
if user_level < ban_level:
raise AuthError(403, "You don't have permission to ban")
else:
@@ -370,7 +388,7 @@ class Auth(object):
AuthError if no user by that token exists or the token is invalid.
"""
try:
- ret = yield self.store.get_user_by_token(token=token)
+ ret = yield self.store.get_user_by_token(token)
if not ret:
raise StoreError(400, "Unknown token")
user_info = {
@@ -412,12 +430,6 @@ class Auth(object):
builder.auth_events = auth_events_entries
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
-
def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create:
return []
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 420f963d91..b16bf4247d 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -60,6 +60,7 @@ class LoginType(object):
EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha"
APPLICATION_SERVICE = u"m.login.application_service"
+ SHARED_SECRET = u"org.matrix.login.shared_secret"
class EventTypes(object):
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index f96535a978..500cae05fb 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -26,6 +26,7 @@ from synapse.server import HomeServer
from synapse.python_dependencies import check_requirements
from twisted.internet import reactor
+from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource
from twisted.web.static import File
@@ -46,6 +47,7 @@ from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
+from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from daemonize import Daemonize
import twisted.manhole.telnet
@@ -58,7 +60,6 @@ import re
import resource
import subprocess
import sqlite3
-import syweb
logger = logging.getLogger(__name__)
@@ -81,6 +82,7 @@ class SynapseHomeServer(HomeServer):
return AppServiceRestResource(self)
def build_resource_for_web_client(self):
+ import syweb
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
return File(webclient_path) # TODO configurable?
@@ -99,6 +101,12 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_server_key(self):
return LocalKey(self)
+ def build_resource_for_metrics(self):
+ if self.get_config().enable_metrics:
+ return MetricsResource(self)
+ else:
+ return None
+
def build_db_pool(self):
return adbapi.ConnectionPool(
"sqlite3", self.get_db_name(),
@@ -109,7 +117,7 @@ class SynapseHomeServer(HomeServer):
# so that :memory: sqlite works
)
- def create_resource_tree(self, web_client, redirect_root_to_web_client):
+ def create_resource_tree(self, redirect_root_to_web_client):
"""Create the resource tree for this Home Server.
This in unduly complicated because Twisted does not support putting
@@ -121,6 +129,9 @@ class SynapseHomeServer(HomeServer):
location of the web client. This does nothing if web_client is not
True.
"""
+ config = self.get_config()
+ web_client = config.web_client
+
# list containing (path_str, Resource) e.g:
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
desired_tree = [
@@ -144,6 +155,10 @@ class SynapseHomeServer(HomeServer):
else:
self.root_resource = Resource()
+ metrics_resource = self.get_resource_for_metrics()
+ if config.metrics_port is None and metrics_resource is not None:
+ desired_tree.append((METRICS_PREFIX, metrics_resource))
+
# ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add
@@ -205,17 +220,32 @@ class SynapseHomeServer(HomeServer):
"""
return "%s-%s" % (resource, path_seg)
- def start_listening(self, secure_port, unsecure_port):
- if secure_port is not None:
+ def start_listening(self):
+ config = self.get_config()
+
+ if not config.no_tls and config.bind_port is not None:
reactor.listenSSL(
- secure_port, Site(self.root_resource), self.tls_context_factory
+ config.bind_port,
+ Site(self.root_resource),
+ self.tls_context_factory,
+ interface=config.bind_host
+ )
+ logger.info("Synapse now listening on port %d", config.bind_port)
+
+ if config.unsecure_port is not None:
+ reactor.listenTCP(
+ config.unsecure_port,
+ Site(self.root_resource),
+ interface=config.bind_host
)
- logger.info("Synapse now listening on port %d", secure_port)
- if unsecure_port is not None:
+ logger.info("Synapse now listening on port %d", config.unsecure_port)
+
+ metrics_resource = self.get_resource_for_metrics()
+ if metrics_resource and config.metrics_port is not None:
reactor.listenTCP(
- unsecure_port, Site(self.root_resource)
+ config.metrics_port, Site(metrics_resource), interface="127.0.0.1",
)
- logger.info("Synapse now listening on port %d", unsecure_port)
+ logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
def get_version_string():
@@ -295,16 +325,26 @@ def change_resource_limit(soft_file_no):
logger.warn("Failed to set file limit: %s", e)
-def setup():
+def setup(config_options):
+ """
+ Args:
+        config_options: The options passed to Synapse. Usually
+ `sys.argv[1:]`.
+ should_run (bool): Whether to start the reactor.
+
+ Returns:
+ HomeServer
+ """
config = HomeServerConfig.load_config(
"Synapse Homeserver",
- sys.argv[1:],
+ config_options,
generate_section="Homeserver"
)
config.setup_logging()
- check_requirements()
+ # check any extra requirements we have now we have a config
+ check_requirements(config)
version_string = get_version_string()
@@ -330,7 +370,6 @@ def setup():
)
hs.create_resource_tree(
- web_client=config.webclient,
redirect_root_to_web_client=True,
)
@@ -359,24 +398,47 @@ def setup():
f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
- bind_port = config.bind_port
- if config.no_tls:
- bind_port = None
-
- hs.start_listening(bind_port, config.unsecure_port)
+ hs.start_listening()
hs.get_pusherpool().start()
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_replication_layer().start_get_pdu_cache()
- if config.daemonize:
- print config.pid_file
+ return hs
+
+
+class SynapseService(service.Service):
+ """A twisted Service class that will start synapse. Used to run synapse
+ via twistd and a .tac.
+ """
+ def __init__(self, config):
+ self.config = config
+
+ def startService(self):
+ hs = setup(self.config)
+ change_resource_limit(hs.config.soft_file_limit)
+
+ def stopService(self):
+ return self._port.stopListening()
+
+
+def run(hs):
+
+ def in_thread():
+ with LoggingContext("run"):
+ change_resource_limit(hs.config.soft_file_limit)
+
+ reactor.run()
+
+ if hs.config.daemonize:
+
+ print hs.config.pid_file
daemon = Daemonize(
app="synapse-homeserver",
- pid=config.pid_file,
- action=lambda: run(config),
+ pid=hs.config.pid_file,
+ action=lambda: in_thread(),
auto_close_fds=False,
verbose=True,
logger=logger,
@@ -384,20 +446,15 @@ def setup():
daemon.start()
else:
- run(config)
-
-
-def run(config):
- with LoggingContext("run"):
- change_resource_limit(config.soft_file_limit)
-
- reactor.run()
+ in_thread()
def main():
with LoggingContext("main"):
+ # check base requirements
check_requirements()
- setup()
+ hs = setup(sys.argv[1:])
+ run(hs)
if __name__ == '__main__':
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index c024535f52..241afdf872 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -23,11 +23,13 @@ from .captcha import CaptchaConfig
from .email import EmailConfig
from .voip import VoipConfig
from .registration import RegistrationConfig
+from .metrics import MetricsConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
- EmailConfig, VoipConfig, RegistrationConfig,):
+ EmailConfig, VoipConfig, RegistrationConfig,
+ MetricsConfig,):
pass
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
new file mode 100644
index 0000000000..901a429c76
--- /dev/null
+++ b/synapse/config/metrics.py
@@ -0,0 +1,36 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+
+class MetricsConfig(Config):
+ def __init__(self, args):
+ super(MetricsConfig, self).__init__(args)
+ self.enable_metrics = args.enable_metrics
+ self.metrics_port = args.metrics_port
+
+ @classmethod
+ def add_arguments(cls, parser):
+ super(MetricsConfig, cls).add_arguments(parser)
+ metrics_group = parser.add_argument_group("metrics")
+ metrics_group.add_argument(
+ '--enable-metrics', dest="enable_metrics", action="store_true",
+ help="Enable collection and rendering of performance metrics"
+ )
+ metrics_group.add_argument(
+ '--metrics-port', metavar="PORT", type=int,
+ help="Separate port to accept metrics requests on (on localhost)"
+ )
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index cca8ab5676..4401e774d1 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -15,19 +15,46 @@
from ._base import Config
+from synapse.util.stringutils import random_string_with_symbols
+
+import distutils.util
+
class RegistrationConfig(Config):
def __init__(self, args):
super(RegistrationConfig, self).__init__(args)
- self.disable_registration = args.disable_registration
+
+ # `args.disable_registration` may either be a bool or a string depending
+ # on if the option was given a value (e.g. --disable-registration=false
+ # would set `args.disable_registration` to "false" not False.)
+ self.disable_registration = bool(
+ distutils.util.strtobool(str(args.disable_registration))
+ )
+ self.registration_shared_secret = args.registration_shared_secret
@classmethod
def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser)
reg_group = parser.add_argument_group("registration")
+
reg_group.add_argument(
"--disable-registration",
- action='store_true',
- help="Disable registration of new users."
+ const=True,
+ default=True,
+ nargs='?',
+ help="Disable registration of new users.",
)
+ reg_group.add_argument(
+ "--registration-shared-secret", type=str,
+ help="If set, allows registration by anyone who also has the shared"
+ " secret, even if registration is otherwise disabled.",
+ )
+
+ @classmethod
+ def generate_config(cls, args, config_dir_path):
+ if args.disable_registration is None:
+ args.disable_registration = True
+
+ if args.registration_shared_secret is None:
+ args.registration_shared_secret = random_string_with_symbols(50)
diff --git a/synapse/config/server.py b/synapse/config/server.py
index b042d4eed9..58a828cc4c 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -28,7 +28,7 @@ class ServerConfig(Config):
self.unsecure_port = args.unsecure_port
self.daemonize = args.daemonize
self.pid_file = self.abspath(args.pid_file)
- self.webclient = True
+ self.web_client = args.web_client
self.manhole = args.manhole
self.soft_file_limit = args.soft_file_limit
@@ -68,6 +68,8 @@ class ServerConfig(Config):
server_group.add_argument('--pid-file', default="homeserver.pid",
help="When running as a daemon, the file to"
" store the pid in")
+ server_group.add_argument('--web_client', default=True, type=bool,
+ help="Whether or not to serve a web client")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 7e98bdef28..4ecadf0879 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -16,8 +16,7 @@
class EventContext(object):
- def __init__(self, current_state=None, auth_events=None):
+ def __init__(self, current_state=None):
self.current_state = current_state
- self.auth_events = auth_events
self.state_group = None
self.rejected = False
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index f131941f45..6811a0e3d1 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -25,6 +25,7 @@ from synapse.api.errors import (
from synapse.util.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
+import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -36,9 +37,17 @@ import random
logger = logging.getLogger(__name__)
+# synapse.federation.federation_client is a silly name
+metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
+
+sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations")
+
+sent_edus_counter = metrics.register_counter("sent_edus")
+
+sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
+
+
class FederationClient(FederationBase):
- def __init__(self):
- self._get_pdu_cache = None
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
@@ -68,6 +77,8 @@ class FederationClient(FederationBase):
order = self._order
self._order += 1
+ sent_pdus_destination_dist.inc_by(len(destinations))
+
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
@@ -87,6 +98,8 @@ class FederationClient(FederationBase):
content=content,
)
+ sent_edus_counter.inc()
+
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None)
@@ -113,6 +126,8 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
+ sent_queries_counter.inc(query_type)
+
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9c7dcdba96..25c0014f97 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,6 +22,7 @@ from .units import Transaction, Edu
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
+import synapse.metrics
from synapse.api.errors import FederationError, SynapseError
@@ -32,6 +33,15 @@ import logging
logger = logging.getLogger(__name__)
+# synapse.federation.federation_server is a silly name
+metrics = synapse.metrics.get_metrics_for("synapse.federation.server")
+
+received_pdus_counter = metrics.register_counter("received_pdus")
+
+received_edus_counter = metrics.register_counter("received_edus")
+
+received_queries_counter = metrics.register_counter("received_queries", labels=["type"])
+
class FederationServer(FederationBase):
def set_handler(self, handler):
@@ -84,6 +94,8 @@ class FederationServer(FederationBase):
def on_incoming_transaction(self, transaction_data):
transaction = Transaction(**transaction_data)
+ received_pdus_counter.inc_by(len(transaction.pdus))
+
for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
@@ -153,6 +165,8 @@ class FederationServer(FederationBase):
defer.returnValue((200, response))
def received_edu(self, origin, edu_type, content):
+ received_edus_counter.inc()
+
if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content)
else:
@@ -204,6 +218,8 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
+ received_queries_counter.inc(query_type)
+
if query_type in self.query_handlers:
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 741a4e7a1a..4dccd93d0e 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -25,12 +25,15 @@ from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
+import synapse.metrics
import logging
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
@@ -54,11 +57,25 @@ class TransactionQueue(object):
# done
self.pending_transactions = {}
+ metrics.register_callback(
+ "pending_destinations",
+ lambda: len(self.pending_transactions),
+ )
+
# Is a mapping from destination -> list of
# tuple(pending pdus, deferred, order)
- self.pending_pdus_by_dest = {}
+ self.pending_pdus_by_dest = pdus = {}
# destination -> list of tuple(edu, deferred)
- self.pending_edus_by_dest = {}
+ self.pending_edus_by_dest = edus = {}
+
+ metrics.register_callback(
+ "pending_pdus",
+ lambda: sum(map(len, pdus.values())),
+ )
+ metrics.register_callback(
+ "pending_edus",
+ lambda: sum(map(len, edus.values())),
+ )
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
@@ -115,8 +132,8 @@ class TransactionQueue(object):
if not deferred.called:
deferred.errback(failure)
- def log_failure(failure):
- logger.warn("Failed to send pdu", failure.value)
+ def log_failure(f):
+ logger.warn("Failed to send pdu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
@@ -143,8 +160,8 @@ class TransactionQueue(object):
if not deferred.called:
deferred.errback(failure)
- def log_failure(failure):
- logger.warn("Failed to send pdu", failure.value)
+ def log_failure(f):
+ logger.warn("Failed to send edu to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
@@ -174,7 +191,7 @@ class TransactionQueue(object):
deferred.errback(f)
def log_failure(f):
- logger.warn("Failed to send pdu", f.value)
+ logger.warn("Failed to send failure to %s: %s", destination, f.value)
deferred.addErrback(log_failure)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index ece6dbcf62..7838a81362 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -19,6 +19,7 @@ from synapse.api.urls import FEDERATION_PREFIX as PREFIX
from synapse.api.errors import Codes, SynapseError
from synapse.util.logutils import log_function
+import functools
import logging
import simplejson as json
import re
@@ -30,8 +31,9 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(object):
"""Handles incoming federation HTTP requests"""
+ # A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
- def _authenticate_request(self, request):
+ def authenticate_request(self, request):
json_request = {
"method": request.method,
"uri": request.uri,
@@ -93,28 +95,6 @@ class TransportLayerServer(object):
defer.returnValue((origin, content))
- def _with_authentication(self, handler):
- @defer.inlineCallbacks
- def new_handler(request, *args, **kwargs):
- try:
- (origin, content) = yield self._authenticate_request(request)
- with self.ratelimiter.ratelimit(origin) as d:
- yield d
- response = yield handler(
- origin, content, request.args, *args, **kwargs
- )
- except:
- logger.exception("_authenticate_request failed")
- raise
- defer.returnValue(response)
- return new_handler
-
- def rate_limit_origin(self, handler):
- def new_handler(origin, *args, **kwargs):
- response = yield handler(origin, *args, **kwargs)
- defer.returnValue(response)
- return new_handler()
-
@log_function
def register_received_handler(self, handler):
""" Register a handler that will be fired when we receive data.
@@ -122,14 +102,12 @@ class TransportLayerServer(object):
Args:
handler (TransportReceivedHandler)
"""
- self.received_handler = handler
-
- # This is when someone is trying to send us a bunch of data.
- self.server.register_path(
- "PUT",
- re.compile("^" + PREFIX + "/send/([^/]*)/$"),
- self._with_authentication(self._on_send_request)
- )
+ FederationSendServlet(
+ handler,
+ authenticator=self,
+ ratelimiter=self.ratelimiter,
+ server_name=self.server_name,
+ ).register(self.server)
@log_function
def register_request_handler(self, handler):
@@ -138,136 +116,65 @@ class TransportLayerServer(object):
Args:
handler (TransportRequestHandler)
"""
- self.request_handler = handler
-
- # This is for when someone asks us for everything since version X
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/pull/$"),
- self._with_authentication(
- lambda origin, content, query:
- handler.on_pull_request(query["origin"][0], query["v"])
- )
- )
+ for servletclass in SERVLET_CLASSES:
+ servletclass(
+ handler,
+ authenticator=self,
+ ratelimiter=self.ratelimiter,
+ ).register(self.server)
- # This is when someone asks for a data item for a given server
- # data_id pair.
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/event/([^/]*)/$"),
- self._with_authentication(
- lambda origin, content, query, event_id:
- handler.on_pdu_request(origin, event_id)
- )
- )
- # This is when someone asks for all data for a given context.
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/state/([^/]*)/$"),
- self._with_authentication(
- lambda origin, content, query, context:
- handler.on_context_state_request(
- origin,
- context,
- query.get("event_id", [None])[0],
- )
- )
- )
+class BaseFederationServlet(object):
+ def __init__(self, handler, authenticator, ratelimiter):
+ self.handler = handler
+ self.authenticator = authenticator
+ self.ratelimiter = ratelimiter
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/backfill/([^/]*)/$"),
- self._with_authentication(
- lambda origin, content, query, context:
- self._on_backfill_request(
- origin, context, query["v"], query["limit"]
- )
- )
- )
+ def _wrap(self, code):
+ authenticator = self.authenticator
+ ratelimiter = self.ratelimiter
- # This is when we receive a server-server Query
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/query/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, query_type:
- handler.on_query_request(
- query_type,
- {k: v[0].decode("utf-8") for k, v in query.items()}
- )
- )
- )
+ @defer.inlineCallbacks
+ @functools.wraps(code)
+ def new_code(request, *args, **kwargs):
+ try:
+ (origin, content) = yield authenticator.authenticate_request(request)
+ with ratelimiter.ratelimit(origin) as d:
+ yield d
+ response = yield code(
+ origin, content, request.args, *args, **kwargs
+ )
+ except:
+ logger.exception("authenticate_request failed")
+ raise
+ defer.returnValue(response)
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/make_join/([^/]*)/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, context, user_id:
- self._on_make_join_request(
- origin, content, query, context, user_id
- )
- )
- )
+ # Extra logic that functools.wraps() doesn't finish
+ new_code.__self__ = code.__self__
- self.server.register_path(
- "GET",
- re.compile("^" + PREFIX + "/event_auth/([^/]*)/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, context, event_id:
- handler.on_event_auth(
- origin, context, event_id,
- )
- )
- )
+ return new_code
- self.server.register_path(
- "PUT",
- re.compile("^" + PREFIX + "/send_join/([^/]*)/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, context, event_id:
- self._on_send_join_request(
- origin, content, query,
- )
- )
- )
+ def register(self, server):
+ pattern = re.compile("^" + PREFIX + self.PATH + "$")
- self.server.register_path(
- "PUT",
- re.compile("^" + PREFIX + "/invite/([^/]*)/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, context, event_id:
- self._on_invite_request(
- origin, content, query,
- )
- )
- )
+ for method in ("GET", "PUT", "POST"):
+            code = getattr(self, "on_%s" % (method,), None)
+ if code is None:
+ continue
- self.server.register_path(
- "POST",
- re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
- self._with_authentication(
- lambda origin, content, query, context, event_id:
- self._on_query_auth_request(
- origin, content, event_id,
- )
- )
- )
+ server.register_path(method, pattern, self._wrap(code))
- self.server.register_path(
- "POST",
- re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"),
- self._with_authentication(
- lambda origin, content, query, room_id:
- self._get_missing_events(
- origin, content, room_id,
- )
- )
- )
+class FederationSendServlet(BaseFederationServlet):
+ PATH = "/send/([^/]*)/"
+
+ def __init__(self, handler, server_name, **kwargs):
+ super(FederationSendServlet, self).__init__(handler, **kwargs)
+ self.server_name = server_name
+
+ # This is when someone is trying to send us a bunch of data.
@defer.inlineCallbacks
- @log_function
- def _on_send_request(self, origin, content, query, transaction_id):
+ def on_PUT(self, origin, content, query, transaction_id):
""" Called on PUT /send/<transaction_id>/
Args:
@@ -305,8 +212,7 @@ class TransportLayerServer(object):
return
try:
- handler = self.received_handler
- code, response = yield handler.on_incoming_transaction(
+ code, response = yield self.handler.on_incoming_transaction(
transaction_data
)
except:
@@ -315,65 +221,123 @@ class TransportLayerServer(object):
defer.returnValue((code, response))
- @log_function
- def _on_backfill_request(self, origin, context, v_list, limits):
+
+class FederationPullServlet(BaseFederationServlet):
+ PATH = "/pull/"
+
+ # This is for when someone asks us for everything since version X
+ def on_GET(self, origin, content, query):
+ return self.handler.on_pull_request(query["origin"][0], query["v"])
+
+
+class FederationEventServlet(BaseFederationServlet):
+ PATH = "/event/([^/]*)/"
+
+ # This is when someone asks for a data item for a given server data_id pair.
+ def on_GET(self, origin, content, query, event_id):
+ return self.handler.on_pdu_request(origin, event_id)
+
+
+class FederationStateServlet(BaseFederationServlet):
+ PATH = "/state/([^/]*)/"
+
+ # This is when someone asks for all data for a given context.
+ def on_GET(self, origin, content, query, context):
+ return self.handler.on_context_state_request(
+ origin,
+ context,
+ query.get("event_id", [None])[0],
+ )
+
+
+class FederationBackfillServlet(BaseFederationServlet):
+ PATH = "/backfill/([^/]*)/"
+
+ def on_GET(self, origin, content, query, context):
+ versions = query["v"]
+ limits = query["limit"]
+
if not limits:
- return defer.succeed(
- (400, {"error": "Did not include limit param"})
- )
+ return defer.succeed((400, {"error": "Did not include limit param"}))
limit = int(limits[-1])
- versions = v_list
+ return self.handler.on_backfill_request(origin, context, versions, limit)
- return self.request_handler.on_backfill_request(
- origin, context, versions, limit
+
+class FederationQueryServlet(BaseFederationServlet):
+ PATH = "/query/([^/]*)"
+
+ # This is when we receive a server-server Query
+ def on_GET(self, origin, content, query, query_type):
+ return self.handler.on_query_request(
+ query_type,
+ {k: v[0].decode("utf-8") for k, v in query.items()}
)
+
+class FederationMakeJoinServlet(BaseFederationServlet):
+ PATH = "/make_join/([^/]*)/([^/]*)"
+
@defer.inlineCallbacks
- @log_function
- def _on_make_join_request(self, origin, content, query, context, user_id):
- content = yield self.request_handler.on_make_join_request(
- context, user_id,
- )
+ def on_GET(self, origin, content, query, context, user_id):
+ content = yield self.handler.on_make_join_request(context, user_id)
defer.returnValue((200, content))
- @defer.inlineCallbacks
- @log_function
- def _on_send_join_request(self, origin, content, query):
- content = yield self.request_handler.on_send_join_request(
- origin, content,
- )
- defer.returnValue((200, content))
+class FederationEventAuthServlet(BaseFederationServlet):
+ PATH = "/event_auth/([^/]*)/([^/]*)"
+
+ def on_GET(self, origin, content, query, context, event_id):
+ return self.handler.on_event_auth(origin, context, event_id)
+
+
+class FederationSendJoinServlet(BaseFederationServlet):
+ PATH = "/send_join/([^/]*)/([^/]*)"
@defer.inlineCallbacks
- @log_function
- def _on_invite_request(self, origin, content, query):
- content = yield self.request_handler.on_invite_request(
- origin, content,
- )
+ def on_PUT(self, origin, content, query, context, event_id):
+ # TODO(paul): assert that context/event_id parsed from path actually
+ # match those given in content
+ content = yield self.handler.on_send_join_request(origin, content)
+ defer.returnValue((200, content))
+
+
+class FederationInviteServlet(BaseFederationServlet):
+ PATH = "/invite/([^/]*)/([^/]*)"
+ @defer.inlineCallbacks
+ def on_PUT(self, origin, content, query, context, event_id):
+ # TODO(paul): assert that context/event_id parsed from path actually
+ # match those given in content
+ content = yield self.handler.on_invite_request(origin, content)
defer.returnValue((200, content))
+
+class FederationQueryAuthServlet(BaseFederationServlet):
+ PATH = "/query_auth/([^/]*)/([^/]*)"
+
@defer.inlineCallbacks
- @log_function
- def _on_query_auth_request(self, origin, content, event_id):
- new_content = yield self.request_handler.on_query_auth_request(
+ def on_POST(self, origin, content, query, context, event_id):
+ new_content = yield self.handler.on_query_auth_request(
origin, content, event_id
)
defer.returnValue((200, new_content))
+
+class FederationGetMissingEventsServlet(BaseFederationServlet):
+ # TODO(paul): Why does this path alone end with "/?" optional?
+ PATH = "/get_missing_events/([^/]*)/?"
+
@defer.inlineCallbacks
- @log_function
- def _get_missing_events(self, origin, content, room_id):
+ def on_POST(self, origin, content, query, room_id):
limit = int(content.get("limit", 10))
min_depth = int(content.get("min_depth", 0))
earliest_events = content.get("earliest_events", [])
latest_events = content.get("latest_events", [])
- content = yield self.request_handler.on_get_missing_events(
+ content = yield self.handler.on_get_missing_events(
origin,
room_id=room_id,
earliest_events=earliest_events,
@@ -383,3 +347,18 @@ class TransportLayerServer(object):
)
defer.returnValue((200, content))
+
+
+SERVLET_CLASSES = (
+ FederationPullServlet,
+ FederationEventServlet,
+ FederationStateServlet,
+ FederationBackfillServlet,
+ FederationQueryServlet,
+ FederationMakeJoinServlet,
+    FederationEventAuthServlet,
+ FederationSendJoinServlet,
+ FederationInviteServlet,
+ FederationQueryAuthServlet,
+ FederationGetMissingEventsServlet,
+)
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 1773fa20aa..48816a242d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -90,8 +90,8 @@ class BaseHandler(object):
event = builder.build()
logger.debug(
- "Created event %s with auth_events: %s, current state: %s",
- event.event_id, context.auth_events, context.current_state,
+ "Created event %s with current state: %s",
+ event.event_id, context.current_state,
)
defer.returnValue(
@@ -106,7 +106,7 @@ class BaseHandler(object):
# We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth:
- self.auth.check(event, auth_events=context.auth_events)
+ self.auth.check(event, auth_events=context.current_state)
yield self.store.persist_event(event, context=context)
@@ -142,7 +142,16 @@ class BaseHandler(object):
"Failed to get destination from event %s", s.event_id
)
- yield self.notifier.on_new_room_event(event, extra_users=extra_users)
+ # Don't block waiting on waking up all the listeners.
+ d = self.notifier.on_new_room_event(event, extra_users=extra_users)
+
+ def log_failure(f):
+ logger.warn(
+ "Failed to notify about %s: %s",
+ event.event_id, f.value
+ )
+
+ d.addErrback(log_failure)
yield federation_handler.handle_new_event(
event, destinations=destinations,
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d3297b7292..f9f855213b 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -71,7 +71,7 @@ class EventStreamHandler(BaseHandler):
self._streams_per_user[auth_user] += 1
rm_handler = self.hs.get_handlers().room_member_handler
- room_ids = yield rm_handler.get_rooms_for_user(auth_user)
+ room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
if timeout:
# If they've set a timeout set a minimum limit.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ae4e9b316d..15ba417e06 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -290,6 +290,8 @@ class FederationHandler(BaseHandler):
"""
logger.debug("Joining %s to %s", joinee, room_id)
+ yield self.store.clean_room_for_join(room_id)
+
origin, pdu = yield self.replication_layer.make_join(
target_hosts,
room_id,
@@ -464,11 +466,9 @@ class FederationHandler(BaseHandler):
builder=builder,
)
- self.auth.check(event, auth_events=context.auth_events)
-
- pdu = event
+ self.auth.check(event, auth_events=context.current_state)
- defer.returnValue(pdu)
+ defer.returnValue(event)
@defer.inlineCallbacks
@log_function
@@ -705,7 +705,7 @@ class FederationHandler(BaseHandler):
)
if not auth_events:
- auth_events = context.auth_events
+ auth_events = context.current_state
logger.debug(
"_handle_new_event: %s, auth_events: %s",
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 8ef248ecf2..731df00648 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -21,6 +21,7 @@ from synapse.api.constants import PresenceState
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
+import synapse.metrics
from ._base import BaseHandler
@@ -29,6 +30,8 @@ import logging
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
# TODO(paul): Maybe there's one of these I can steal from somewhere
def partition(l, func):
@@ -133,6 +136,11 @@ class PresenceHandler(BaseHandler):
self._user_cachemap = {}
self._user_cachemap_latest_serial = 0
+ metrics.register_callback(
+ "userCachemap:size",
+ lambda: len(self._user_cachemap),
+ )
+
def _get_or_make_usercache(self, user):
"""If the cache entry doesn't exist, initialise a new one."""
if user not in self._user_cachemap:
@@ -452,7 +460,7 @@ class PresenceHandler(BaseHandler):
# Also include people in all my rooms
rm_handler = self.homeserver.get_handlers().room_member_handler
- room_ids = yield rm_handler.get_rooms_for_user(user)
+ room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if state is None:
state = yield self.store.get_presence_state(user.localpart)
@@ -596,7 +604,7 @@ class PresenceHandler(BaseHandler):
localusers.add(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
- room_ids = yield rm_handler.get_rooms_for_user(user)
+ room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if not localusers and not room_ids:
defer.returnValue(None)
@@ -663,7 +671,7 @@ class PresenceHandler(BaseHandler):
)
rm_handler = self.homeserver.get_handlers().room_member_handler
- room_ids = yield rm_handler.get_rooms_for_user(user)
+ room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if room_ids:
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 2ddf9d5378..ee2732b848 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -197,9 +197,8 @@ class ProfileHandler(BaseHandler):
self.ratelimit(user.to_string())
- joins = yield self.store.get_rooms_for_user_where_membership_is(
+ joins = yield self.store.get_rooms_for_user(
user.to_string(),
- [Membership.JOIN],
)
for j in joins:
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index cda4a8502a..c25e321099 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -31,6 +31,7 @@ import base64
import bcrypt
import json
import logging
+import urllib
logger = logging.getLogger(__name__)
@@ -63,6 +64,13 @@ class RegistrationHandler(BaseHandler):
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart:
+ if localpart and urllib.quote(localpart) != localpart:
+ raise SynapseError(
+ 400,
+ "User ID must only contain characters which do not"
+ " require URL encoding."
+ )
+
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 80f7ee3f12..823affc380 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -507,7 +507,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue((is_remote_invite_join, room_host))
@defer.inlineCallbacks
- def get_rooms_for_user(self, user, membership_list=[Membership.JOIN]):
+ def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
@@ -517,8 +517,8 @@ class RoomMemberHandler(BaseHandler):
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
else:
- rooms = yield self.store.get_rooms_for_user_where_membership_is(
- user_id=user.to_string(), membership_list=membership_list
+ rooms = yield self.store.get_rooms_for_user(
+ user.to_string(),
)
# For some reason the list of events contains duplicates
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7883bbd834..35a62fda47 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -96,7 +96,9 @@ class SyncHandler(BaseHandler):
return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler
- room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
+ room_ids = yield rm_handler.get_joined_rooms_for_user(
+ sync_config.user
+ )
result = yield self.notifier.wait_for_events(
sync_config.user, room_ids,
sync_config.filter, timeout, current_sync_callback
@@ -227,7 +229,7 @@ class SyncHandler(BaseHandler):
logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler
- room_ids = yield rm_handler.get_rooms_for_user(sync_config.user)
+ room_ids = yield rm_handler.get_joined_rooms_for_user(sync_config.user)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
diff --git a/synapse/http/client.py b/synapse/http/client.py
index b53a07aa2d..2ae1c4d3a4 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -15,6 +15,7 @@
from synapse.api.errors import CodeMessageException
from syutil.jsonutil import encode_canonical_json
+import synapse.metrics
from twisted.internet import defer, reactor
from twisted.web.client import (
@@ -31,6 +32,17 @@ import urllib
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+outgoing_requests_counter = metrics.register_counter(
+ "requests",
+ labels=["method"],
+)
+incoming_responses_counter = metrics.register_counter(
+ "responses",
+ labels=["method", "code"],
+)
+
class SimpleHttpClient(object):
"""
@@ -45,12 +57,30 @@ class SimpleHttpClient(object):
self.agent = Agent(reactor)
self.version_string = hs.version_string
+ def request(self, method, *args, **kwargs):
+ # A small wrapper around self.agent.request() so we can easily attach
+ # counters to it
+ outgoing_requests_counter.inc(method)
+ d = self.agent.request(method, *args, **kwargs)
+
+ def _cb(response):
+ incoming_responses_counter.inc(method, response.code)
+ return response
+
+ def _eb(failure):
+ incoming_responses_counter.inc(method, "ERR")
+ return failure
+
+ d.addCallbacks(_cb, _eb)
+
+ return d
+
@defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}):
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(args, True)
- response = yield self.agent.request(
+ response = yield self.request(
"POST",
uri.encode("ascii"),
headers=Headers({
@@ -70,7 +100,7 @@ class SimpleHttpClient(object):
logger.info("HTTP POST %s -> %s", json_str, uri)
- response = yield self.agent.request(
+ response = yield self.request(
"POST",
uri.encode("ascii"),
headers=Headers({
@@ -104,7 +134,7 @@ class SimpleHttpClient(object):
query_bytes = urllib.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
- response = yield self.agent.request(
+ response = yield self.request(
"GET",
uri.encode("ascii"),
headers=Headers({
@@ -145,7 +175,7 @@ class SimpleHttpClient(object):
json_str = encode_canonical_json(json_body)
- response = yield self.agent.request(
+ response = yield self.request(
"PUT",
uri.encode("ascii"),
headers=Headers({
@@ -176,7 +206,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
def post_urlencoded_get_raw(self, url, args={}):
query_bytes = urllib.urlencode(args, True)
- response = yield self.agent.request(
+ response = yield self.request(
"POST",
url.encode("ascii"),
bodyProducer=FileBodyProducer(StringIO(query_bytes)),
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 7db001cc63..7fa295cad5 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -23,6 +23,7 @@ from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep
from synapse.util.logcontext import PreserveLoggingContext
+import synapse.metrics
from syutil.jsonutil import encode_canonical_json
@@ -40,6 +41,17 @@ import urlparse
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+outgoing_requests_counter = metrics.register_counter(
+ "requests",
+ labels=["method"],
+)
+incoming_responses_counter = metrics.register_counter(
+ "responses",
+ labels=["method", "code"],
+)
+
class MatrixFederationHttpAgent(_AgentBase):
@@ -49,6 +61,8 @@ class MatrixFederationHttpAgent(_AgentBase):
def request(self, destination, endpoint, method, path, params, query,
headers, body_producer):
+ outgoing_requests_counter.inc(method)
+
host = b""
port = 0
fragment = b""
@@ -59,9 +73,21 @@ class MatrixFederationHttpAgent(_AgentBase):
# Set the connection pool key to be the destination.
key = destination
- return self._requestWithEndpoint(key, endpoint, method, parsed_URI,
- headers, body_producer,
- parsed_URI.originForm)
+ d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
+ headers, body_producer,
+ parsed_URI.originForm)
+
+ def _cb(response):
+ incoming_responses_counter.inc(method, response.code)
+ return response
+
+ def _eb(failure):
+ incoming_responses_counter.inc(method, "ERR")
+ return failure
+
+ d.addCallbacks(_cb, _eb)
+
+ return d
class MatrixFederationHttpClient(object):
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 767c3ef79b..dee49b9e18 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -18,6 +18,7 @@ from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
)
from synapse.util.logcontext import LoggingContext
+import synapse.metrics
from syutil.jsonutil import (
encode_canonical_json, encode_pretty_printed_json
@@ -34,6 +35,22 @@ import urllib
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+incoming_requests_counter = metrics.register_counter(
+ "requests",
+ labels=["method", "servlet"],
+)
+outgoing_responses_counter = metrics.register_counter(
+ "responses",
+ labels=["method", "code"],
+)
+
+response_timer = metrics.register_distribution(
+ "response_time",
+ labels=["method", "servlet"]
+)
+
class HttpServer(object):
""" Interface for registering callbacks on a HTTP server
@@ -74,6 +91,7 @@ class JsonResource(HttpServer, resource.Resource):
self.clock = hs.get_clock()
self.path_regexs = {}
self.version_string = hs.version_string
+ self.hs = hs
def register_path(self, method, path_pattern, callback):
self.path_regexs.setdefault(method, []).append(
@@ -87,7 +105,11 @@ class JsonResource(HttpServer, resource.Resource):
port (int): The port to listen on.
"""
- reactor.listenTCP(port, server.Site(self))
+ reactor.listenTCP(
+ port,
+ server.Site(self),
+ interface=self.hs.config.bind_host
+ )
# Gets called by twisted
def render(self, request):
@@ -131,6 +153,15 @@ class JsonResource(HttpServer, resource.Resource):
# returned response. We pass both the request and any
# matched groups from the regex to the callback.
+ callback = path_entry.callback
+
+ servlet_instance = getattr(callback, "__self__", None)
+ if servlet_instance is not None:
+ servlet_classname = servlet_instance.__class__.__name__
+ else:
+ servlet_classname = "%r" % callback
+ incoming_requests_counter.inc(request.method, servlet_classname)
+
args = [
urllib.unquote(u).decode("UTF-8") for u in m.groups()
]
@@ -140,12 +171,13 @@ class JsonResource(HttpServer, resource.Resource):
request.method, request.path
)
- code, response = yield path_entry.callback(
- request,
- *args
- )
+ code, response = yield callback(request, *args)
self._send_response(request, code, response)
+ response_timer.inc_by(
+ self.clock.time_msec() - start, request.method, servlet_classname
+ )
+
return
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
@@ -190,6 +222,8 @@ class JsonResource(HttpServer, resource.Resource):
request)
return
+ outgoing_responses_counter.inc(request.method, str(code))
+
# TODO: Only enable CORS for the requests that need it.
respond_with_json(
request, code, response_json_object,
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index a4eb6c817c..265559a3ea 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -51,8 +51,8 @@ class RestServlet(object):
pattern = self.PATTERN
for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"):
- if hasattr(self, "on_%s" % (method)):
- method_handler = getattr(self, "on_%s" % (method))
+ if hasattr(self, "on_%s" % (method,)):
+ method_handler = getattr(self, "on_%s" % (method,))
http_server.register_path(method, pattern, method_handler)
else:
raise NotImplementedError("RestServlet must register something.")
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
new file mode 100644
index 0000000000..dffb8a4861
--- /dev/null
+++ b/synapse/metrics/__init__.py
@@ -0,0 +1,111 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Because otherwise 'resource' collides with synapse.metrics.resource
+from __future__ import absolute_import
+
+import logging
+from resource import getrusage, getpagesize, RUSAGE_SELF
+
+from .metric import (
+ CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+# We'll keep all the available metrics in a single toplevel dict, one shared
+# for the entire process. We don't currently support per-HomeServer instances
+# of metrics, because in practice any one python VM will host only one
+# HomeServer anyway. This makes a lot of implementation neater
+all_metrics = {}
+
+
+class Metrics(object):
+ """ A single Metrics object gives a (mutable) slice view of the all_metrics
+ dict, allowing callers to easily register new metrics that are namespaced
+ nicely."""
+
+ def __init__(self, name):
+ self.name_prefix = name
+
+ def _register(self, metric_class, name, *args, **kwargs):
+ full_name = "%s_%s" % (self.name_prefix, name)
+
+ metric = metric_class(full_name, *args, **kwargs)
+
+ all_metrics[full_name] = metric
+ return metric
+
+ def register_counter(self, *args, **kwargs):
+ return self._register(CounterMetric, *args, **kwargs)
+
+ def register_callback(self, *args, **kwargs):
+ return self._register(CallbackMetric, *args, **kwargs)
+
+ def register_distribution(self, *args, **kwargs):
+ return self._register(DistributionMetric, *args, **kwargs)
+
+ def register_cache(self, *args, **kwargs):
+ return self._register(CacheMetric, *args, **kwargs)
+
+
+def get_metrics_for(pkg_name):
+ """ Returns a Metrics instance for conveniently creating metrics
+ namespaced with the given name prefix. """
+
+ # Convert a "package.name" to "package_name" because Prometheus doesn't
+ # let us use . in metric names
+ return Metrics(pkg_name.replace(".", "_"))
+
+
+def render_all():
+ strs = []
+
+ # TODO(paul): Internal hack
+ update_resource_metrics()
+
+ for name in sorted(all_metrics.keys()):
+ try:
+ strs += all_metrics[name].render()
+ except Exception:
+ strs += ["# FAILED to render %s" % name]
+ logger.exception("Failed to render %s metric", name)
+
+ strs.append("") # to generate a final CRLF
+
+ return "\n".join(strs)
+
+
+# Now register some standard process-wide state metrics, to give indications of
+# process resource usage
+
+rusage = None
+PAGE_SIZE = getpagesize()
+
+
+def update_resource_metrics():
+ global rusage
+ rusage = getrusage(RUSAGE_SELF)
+
+resource_metrics = get_metrics_for("process.resource")
+
+# msecs
+resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000)
+resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
+
+# bytes (Linux getrusage() reports ru_maxrss in kibibytes, not pages)
+resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024)
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
new file mode 100644
index 0000000000..21b37748f6
--- /dev/null
+++ b/synapse/metrics/metric.py
@@ -0,0 +1,155 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from itertools import chain
+
+
+# TODO(paul): I can't believe Python doesn't have one of these
+def map_concat(func, items):
+ # flatten a list-of-lists
+ return list(chain.from_iterable(map(func, items)))
+
+
+class BaseMetric(object):
+
+ def __init__(self, name, labels=[]):
+ self.name = name
+ self.labels = labels # OK not to clone as we never write it
+
+ def dimension(self):
+ return len(self.labels)
+
+ def is_scalar(self):
+ return not len(self.labels)
+
+ def _render_labelvalue(self, value):
+ # TODO: some kind of value escape
+ return '"%s"' % (value)
+
+ def _render_key(self, values):
+ if self.is_scalar():
+ return ""
+ return "{%s}" % (
+ ",".join(["%s=%s" % (k, self._render_labelvalue(v))
+ for k, v in zip(self.labels, values)])
+ )
+
+ def render(self):
+ return map_concat(self.render_item, sorted(self.counts.keys()))
+
+
+class CounterMetric(BaseMetric):
+ """The simplest kind of metric; one that stores a monotonically-increasing
+ integer that counts events."""
+
+ def __init__(self, *args, **kwargs):
+ super(CounterMetric, self).__init__(*args, **kwargs)
+
+ self.counts = {}
+
+ # Scalar metrics are never empty
+ if self.is_scalar():
+ self.counts[()] = 0
+
+ def inc_by(self, incr, *values):
+ if len(values) != self.dimension():
+ raise ValueError(
+ "Expected as many values to inc() as labels (%d)" % (self.dimension())
+ )
+
+ # TODO: should assert that the tag values are all strings
+
+ if values not in self.counts:
+ self.counts[values] = incr
+ else:
+ self.counts[values] += incr
+
+ def inc(self, *values):
+ self.inc_by(1, *values)
+
+ def render_item(self, k):
+ return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
+
+
+class CallbackMetric(BaseMetric):
+ """A metric that returns the numeric value returned by a callback whenever
+ it is rendered. Typically this is used to implement gauges that yield the
+ size or other state of some in-memory object by actively querying it."""
+
+ def __init__(self, name, callback, labels=[]):
+ super(CallbackMetric, self).__init__(name, labels=labels)
+
+ self.callback = callback
+
+ def render(self):
+ value = self.callback()
+
+ if self.is_scalar():
+ return ["%s %d" % (self.name, value)]
+
+ return ["%s%s %d" % (self.name, self._render_key(k), value[k])
+ for k in sorted(value.keys())]
+
+
+class DistributionMetric(object):
+ """A combination of an event counter and an accumulator, which counts
+ both the number of events and accumulates the total value. Typically this
+ could be used to keep track of method-running times, or other distributions
+    of values that occur in discrete occurrences.
+
+ TODO(paul): Try to export some heatmap-style stats?
+ """
+
+ def __init__(self, name, *args, **kwargs):
+ self.counts = CounterMetric(name + ":count", **kwargs)
+ self.totals = CounterMetric(name + ":total", **kwargs)
+
+ def inc_by(self, inc, *values):
+ self.counts.inc(*values)
+ self.totals.inc_by(inc, *values)
+
+ def render(self):
+ return self.counts.render() + self.totals.render()
+
+
+class CacheMetric(object):
+ """A combination of two CounterMetrics, one to count cache hits and one to
+ count a total, and a callback metric to yield the current size.
+
+ This metric generates standard metric name pairs, so that monitoring rules
+ can easily be applied to measure hit ratio."""
+
+ def __init__(self, name, size_callback, labels=[]):
+ self.name = name
+
+ self.hits = CounterMetric(name + ":hits", labels=labels)
+ self.total = CounterMetric(name + ":total", labels=labels)
+
+ self.size = CallbackMetric(
+ name + ":size",
+ callback=size_callback,
+ labels=labels,
+ )
+
+ def inc_hits(self, *values):
+ self.hits.inc(*values)
+ self.total.inc(*values)
+
+ def inc_misses(self, *values):
+ self.total.inc(*values)
+
+ def render(self):
+ return self.hits.render() + self.total.render() + self.size.render()
diff --git a/synapse/metrics/resource.py b/synapse/metrics/resource.py
new file mode 100644
index 0000000000..0af4b3eb52
--- /dev/null
+++ b/synapse/metrics/resource.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.web.resource import Resource
+
+import synapse.metrics
+
+
+METRICS_PREFIX = "/_synapse/metrics"
+
+
+class MetricsResource(Resource):
+ isLeaf = True
+
+ def __init__(self, hs):
+ Resource.__init__(self) # Resource is old-style, so no super()
+
+ self.hs = hs
+
+ def render_GET(self, request):
+ response = synapse.metrics.render_all()
+
+ request.setHeader("Content-Type", "text/plain")
+ request.setHeader("Content-Length", str(len(response)))
+
+ # Encode as UTF-8 (default)
+ return response.encode()
diff --git a/synapse/notifier.py b/synapse/notifier.py
index df13e8ddb6..7121d659d0 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -19,12 +19,27 @@ from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor
from synapse.types import StreamToken
+import synapse.metrics
import logging
logger = logging.getLogger(__name__)
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+notified_events_counter = metrics.register_counter("notified_events")
+
+
+# TODO(paul): Should be shared somewhere
+def count(func, l):
+ """Return the number of items in l for which func returns true."""
+ n = 0
+ for x in l:
+ if func(x):
+ n += 1
+ return n
+
class _NotificationListener(object):
""" This represents a single client connection to the events stream.
@@ -59,6 +74,7 @@ class _NotificationListener(object):
try:
self.deferred.callback(result)
+ notified_events_counter.inc_by(len(events))
except defer.AlreadyCalledError:
pass
@@ -95,6 +111,35 @@ class Notifier(object):
"user_joined_room", self._user_joined_room
)
+ # This is not a very cheap test to perform, but it's only executed
+ # when rendering the metrics page, which is likely once per minute at
+ # most when scraping it.
+ def count_listeners():
+ all_listeners = set()
+
+ for x in self.room_to_listeners.values():
+ all_listeners |= x
+ for x in self.user_to_listeners.values():
+ all_listeners |= x
+ for x in self.appservice_to_listeners.values():
+ all_listeners |= x
+
+ return len(all_listeners)
+ metrics.register_callback("listeners", count_listeners)
+
+ metrics.register_callback(
+ "rooms",
+ lambda: count(bool, self.room_to_listeners.values()),
+ )
+ metrics.register_callback(
+ "users",
+ lambda: count(bool, self.user_to_listeners.values()),
+ )
+ metrics.register_callback(
+ "appservices",
+ lambda: count(bool, self.appservice_to_listeners.values()),
+ )
+
@log_function
@defer.inlineCallbacks
def on_new_room_event(self, event, extra_users=[]):
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 3da0ce8703..0727f772a5 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -32,7 +32,7 @@ class Pusher(object):
INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000
- DEFAULT_ACTIONS = ['dont-notify']
+ DEFAULT_ACTIONS = ['dont_notify']
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@@ -105,7 +105,11 @@ class Pusher(object):
room_member_count += 1
for r in rules:
- if r['rule_id'] in enabled_map and not enabled_map[r['rule_id']]:
+ if r['rule_id'] in enabled_map:
+ r['enabled'] = enabled_map[r['rule_id']]
+ elif 'enabled' not in r:
+ r['enabled'] = True
+ if not r['enabled']:
continue
matches = True
@@ -124,13 +128,21 @@ class Pusher(object):
# ignore rules with no actions (we have an explict 'dont_notify')
if len(actions) == 0:
logger.warn(
- "Ignoring rule id %s with no actions for user %s" %
- (r['rule_id'], r['user_name'])
+ "Ignoring rule id %s with no actions for user %s",
+ r['rule_id'], self.user_name
)
continue
if matches:
+ logger.info(
+ "%s matches for user %s, event %s",
+ r['rule_id'], self.user_name, ev['event_id']
+ )
defer.returnValue(actions)
+ logger.info(
+ "No rules match for user %s, event %s",
+ self.user_name, ev['event_id']
+ )
defer.returnValue(Pusher.DEFAULT_ACTIONS)
@staticmethod
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 6e333a3d21..60fd35fbfb 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -6,36 +6,51 @@ def list_with_base_rules(rawrules, user_name):
# shove the server default rules for each kind onto the end of each
current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
+
+ ruleslist.extend(make_base_prepend_rules(
+ user_name, PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+ ))
+
for r in rawrules:
if r['priority_class'] < current_prio_class:
while r['priority_class'] < current_prio_class:
- ruleslist.extend(make_base_rules(
+ ruleslist.extend(make_base_append_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
+ if current_prio_class > 0:
+ ruleslist.extend(make_base_prepend_rules(
+ user_name,
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+ ))
ruleslist.append(r)
while current_prio_class > 0:
- ruleslist.extend(make_base_rules(
+ ruleslist.extend(make_base_append_rules(
user_name,
PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
))
current_prio_class -= 1
+ if current_prio_class > 0:
+ ruleslist.extend(make_base_prepend_rules(
+ user_name,
+ PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+ ))
return ruleslist
-def make_base_rules(user, kind):
+def make_base_append_rules(user, kind):
rules = []
if kind == 'override':
- rules = make_base_override_rules()
+ rules = make_base_append_override_rules()
elif kind == 'underride':
- rules = make_base_underride_rules(user)
+ rules = make_base_append_underride_rules(user)
elif kind == 'content':
- rules = make_base_content_rules(user)
+ rules = make_base_append_content_rules(user)
for r in rules:
r['priority_class'] = PRIORITY_CLASS_MAP[kind]
@@ -44,7 +59,20 @@ def make_base_rules(user, kind):
return rules
-def make_base_content_rules(user):
+def make_base_prepend_rules(user, kind):
+ rules = []
+
+ if kind == 'override':
+ rules = make_base_prepend_override_rules()
+
+ for r in rules:
+ r['priority_class'] = PRIORITY_CLASS_MAP[kind]
+ r['default'] = True # Deprecated, left for backwards compat
+
+ return rules
+
+
+def make_base_append_content_rules(user):
return [
{
'rule_id': 'global/content/.m.rule.contains_user_name',
@@ -68,7 +96,20 @@ def make_base_content_rules(user):
]
-def make_base_override_rules():
+def make_base_prepend_override_rules():
+ return [
+ {
+ 'rule_id': 'global/override/.m.rule.master',
+ 'enabled': False,
+ 'conditions': [],
+ 'actions': [
+ "dont_notify"
+ ]
+ }
+ ]
+
+
+def make_base_append_override_rules():
return [
{
'rule_id': 'global/override/.m.rule.call',
@@ -86,7 +127,7 @@ def make_base_override_rules():
'value': 'ring'
}, {
'set_tweak': 'highlight',
- 'value': 'false'
+ 'value': False
}
]
},
@@ -135,14 +176,14 @@ def make_base_override_rules():
'value': 'default'
}, {
'set_tweak': 'highlight',
- 'value': 'false'
+ 'value': False
}
]
}
]
-def make_base_underride_rules(user):
+def make_base_append_underride_rules(user):
return [
{
'rule_id': 'global/underride/.m.rule.invite_for_me',
@@ -170,7 +211,7 @@ def make_base_underride_rules(user):
'value': 'default'
}, {
'set_tweak': 'highlight',
- 'value': 'false'
+ 'value': False
}
]
},
@@ -186,7 +227,7 @@ def make_base_underride_rules(user):
'actions': [
'notify', {
'set_tweak': 'highlight',
- 'value': 'false'
+ 'value': False
}
]
},
@@ -202,7 +243,7 @@ def make_base_underride_rules(user):
'actions': [
'notify', {
'set_tweak': 'highlight',
- 'value': 'false'
+ 'value': False
}
]
}
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 5fe8a825e3..6b6d5508b8 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -5,7 +5,6 @@ logger = logging.getLogger(__name__)
REQUIREMENTS = {
"syutil>=0.0.3": ["syutil"],
- "matrix_angular_sdk>=0.6.4": ["syweb>=0.6.4"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
@@ -18,6 +17,19 @@ REQUIREMENTS = {
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
}
+CONDITIONAL_REQUIREMENTS = {
+ "web_client": {
+ "matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
+ }
+}
+
+
+def requirements(config=None, include_conditional=False):
+ reqs = REQUIREMENTS.copy()
+ for key, req in CONDITIONAL_REQUIREMENTS.items():
+ if (config and getattr(config, key)) or include_conditional:
+ reqs.update(req)
+ return reqs
def github_link(project, version, egg):
@@ -36,8 +48,8 @@ DEPENDENCY_LINKS = [
),
github_link(
project="matrix-org/matrix-angular-sdk",
- version="v0.6.4",
- egg="matrix_angular_sdk-0.6.4",
+ version="v0.6.5",
+ egg="matrix_angular_sdk-0.6.5",
),
]
@@ -46,10 +58,11 @@ class MissingRequirementError(Exception):
pass
-def check_requirements():
+def check_requirements(config=None):
"""Checks that all the modules needed by synapse have been correctly
installed and are at the correct version"""
- for dependency, module_requirements in REQUIREMENTS.items():
+ for dependency, module_requirements in (
+ requirements(config, include_conditional=False).items()):
for module_requirement in module_requirements:
if ">=" in module_requirement:
module_name, required_version = module_requirement.split(">=")
@@ -110,7 +123,7 @@ def list_requirements():
egg = link.split("#egg=")[1]
linked.append(egg.split('-')[0])
result.append(link)
- for requirement in REQUIREMENTS:
+ for requirement in requirements(include_conditional=True):
is_linked = False
for link in linked:
if requirement.replace('-', '_').startswith(link):
diff --git a/synapse/rest/appservice/v1/register.py b/synapse/rest/appservice/v1/register.py
index a4f6159773..ea24d88f79 100644
--- a/synapse/rest/appservice/v1/register.py
+++ b/synapse/rest/appservice/v1/register.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
@@ -89,7 +89,8 @@ def _parse_json(request):
if type(content) != dict:
raise SynapseError(400, "Content must be a JSON object.")
return content
- except ValueError:
+ except ValueError as e:
+ logger.warn(e)
raise SynapseError(400, "Content not JSON.")
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index fef0eb6572..d4e7ab2202 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -156,9 +156,12 @@ class PushRuleRestServlet(ClientV1RestServlet):
template_rule = _rule_to_template(r)
if template_rule:
- template_rule['enabled'] = True
if r['rule_id'] in enabled_map:
template_rule['enabled'] = enabled_map[r['rule_id']]
+ elif 'enabled' in r:
+ template_rule['enabled'] = r['enabled']
+ else:
+ template_rule['enabled'] = True
rulearray.append(template_rule)
path = request.postpath[1:]
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index f5acfb945f..a56834e365 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -27,7 +27,6 @@ from hashlib import sha1
import hmac
import simplejson as json
import logging
-import urllib
logger = logging.getLogger(__name__)
@@ -110,14 +109,22 @@ class RegisterRestServlet(ClientV1RestServlet):
login_type = register_json["type"]
is_application_server = login_type == LoginType.APPLICATION_SERVICE
- if self.disable_registration and not is_application_server:
+ is_using_shared_secret = login_type == LoginType.SHARED_SECRET
+
+ can_register = (
+ not self.disable_registration
+ or is_application_server
+ or is_using_shared_secret
+ )
+ if not can_register:
raise SynapseError(403, "Registration has been disabled")
stages = {
LoginType.RECAPTCHA: self._do_recaptcha,
LoginType.PASSWORD: self._do_password,
LoginType.EMAIL_IDENTITY: self._do_email_identity,
- LoginType.APPLICATION_SERVICE: self._do_app_service
+ LoginType.APPLICATION_SERVICE: self._do_app_service,
+ LoginType.SHARED_SECRET: self._do_shared_secret,
}
session_info = self._get_session_info(request, session)
@@ -255,14 +262,11 @@ class RegisterRestServlet(ClientV1RestServlet):
)
password = register_json["password"].encode("utf-8")
- desired_user_id = (register_json["user"].encode("utf-8")
- if "user" in register_json else None)
- if (desired_user_id
- and urllib.quote(desired_user_id) != desired_user_id):
- raise SynapseError(
- 400,
- "User ID must only contain characters which do not " +
- "require URL encoding.")
+ desired_user_id = (
+ register_json["user"].encode("utf-8")
+ if "user" in register_json else None
+ )
+
handler = self.handlers.registration_handler
(user_id, token) = yield handler.register(
localpart=desired_user_id,
@@ -304,6 +308,51 @@ class RegisterRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname,
})
+ @defer.inlineCallbacks
+ def _do_shared_secret(self, request, register_json, session):
+ yield run_on_reactor()
+
+ if not isinstance(register_json.get("mac", None), basestring):
+ raise SynapseError(400, "Expected mac.")
+ if not isinstance(register_json.get("user", None), basestring):
+ raise SynapseError(400, "Expected 'user' key.")
+ if not isinstance(register_json.get("password", None), basestring):
+ raise SynapseError(400, "Expected 'password' key.")
+
+ if not self.hs.config.registration_shared_secret:
+ raise SynapseError(400, "Shared secret registration is not enabled")
+
+ user = register_json["user"].encode("utf-8")
+
+ # str() because otherwise hmac complains that 'unicode' does not
+ # have the buffer interface
+ got_mac = str(register_json["mac"])
+
+ want_mac = hmac.new(
+ key=self.hs.config.registration_shared_secret,
+ msg=user,
+ digestmod=sha1,
+ ).hexdigest()
+
+ password = register_json["password"].encode("utf-8")
+
+ if compare_digest(want_mac, got_mac):
+ handler = self.handlers.registration_handler
+ user_id, token = yield handler.register(
+ localpart=user,
+ password=password,
+ )
+ self._remove_session(session)
+ defer.returnValue({
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname,
+ })
+ else:
+ raise SynapseError(
+ 403, "HMAC incorrect",
+ )
+
def _parse_json(request):
try:
diff --git a/synapse/server.py b/synapse/server.py
index cb8610a1b4..c7772244ba 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -56,6 +56,7 @@ class BaseHomeServer(object):
"""
DEPENDENCIES = [
+ 'config',
'clock',
'http_client',
'db_name',
@@ -79,6 +80,7 @@ class BaseHomeServer(object):
'resource_for_server_key',
'resource_for_media_repository',
'resource_for_app_services',
+ 'resource_for_metrics',
'event_sources',
'ratelimiter',
'keyring',
diff --git a/synapse/state.py b/synapse/state.py
index 80cced351d..ba2500d61c 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -21,6 +21,7 @@ from synapse.util.async import run_on_reactor
from synapse.util.expiringcache import ExpiringCache
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
+from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext
from collections import namedtuple
@@ -38,12 +39,6 @@ def _get_state_key_from_event(event):
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
-AuthEventTypes = (
- EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
- EventTypes.JoinRules,
-)
-
-
SIZE_OF_CACHE = 1000
EVICTION_TIMEOUT_SECONDS = 20
@@ -139,18 +134,6 @@ class StateHandler(object):
}
context.state_group = None
- if hasattr(event, "auth_events") and event.auth_events:
- auth_ids = self.hs.get_auth().compute_auth_events(
- event, context.current_state
- )
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
- else:
- context.auth_events = {}
-
if event.is_state():
key = (event.type, event.state_key)
if key in context.current_state:
@@ -187,18 +170,6 @@ class StateHandler(object):
replaces = context.current_state[key]
event.unsigned["replaces_state"] = replaces.event_id
- if hasattr(event, "auth_events") and event.auth_events:
- auth_ids = self.hs.get_auth().compute_auth_events(
- event, context.current_state
- )
- context.auth_events = {
- k: v
- for k, v in context.current_state.items()
- if v.event_id in auth_ids
- }
- else:
- context.auth_events = {}
-
context.prev_state_events = prev_state
defer.returnValue(context)
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index a3ff995695..4b16f445d6 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -450,7 +450,7 @@ class DataStore(RoomMemberStore, RoomStore,
else:
args = (room_id, )
- results = yield self._execute_and_decode(sql, *args)
+ results = yield self._execute_and_decode("get_current_state", sql, *args)
events = yield self._parse_events(results)
defer.returnValue(events)
@@ -475,7 +475,7 @@ class DataStore(RoomMemberStore, RoomStore,
sql += " OR s.type = 'm.room.aliases')"
args = (room_id,)
- results = yield self._execute_and_decode(sql, *args)
+ results = yield self._execute_and_decode("get_current_state", sql, *args)
events = yield self._parse_events(results)
@@ -484,17 +484,18 @@ class DataStore(RoomMemberStore, RoomStore,
for e in events:
if e.type == 'm.room.name':
- name = e.content['name']
+ if 'name' in e.content:
+ name = e.content['name']
elif e.type == 'm.room.aliases':
- aliases.extend(e.content['aliases'])
+ if 'aliases' in e.content:
+ aliases.extend(e.content['aliases'])
defer.returnValue((name, aliases))
@defer.inlineCallbacks
def _get_min_token(self):
row = yield self._execute(
- None,
- "SELECT MIN(stream_ordering) FROM events"
+ "_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
)
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 3725c9795d..9125bb1198 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -20,10 +20,12 @@ from synapse.events.utils import prune_event
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
from synapse.util.lrucache import LruCache
+import synapse.metrics
from twisted.internet import defer
from collections import namedtuple, OrderedDict
+import functools
import simplejson as json
import sys
import time
@@ -35,9 +37,24 @@ sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
+metrics = synapse.metrics.get_metrics_for("synapse.storage")
+
+sql_scheduling_timer = metrics.register_distribution("schedule_time")
+
+sql_query_timer = metrics.register_distribution("query_time", labels=["verb"])
+sql_txn_timer = metrics.register_distribution("transaction_time", labels=["desc"])
+sql_getevents_timer = metrics.register_distribution("getEvents_time", labels=["desc"])
+
+caches_by_name = {}
+cache_counter = metrics.register_cache(
+ "cache",
+ lambda: {(name,): len(caches_by_name[name]) for name in caches_by_name.keys()},
+ labels=["name"],
+)
+
+
# TODO(paul):
# * more generic key management
-# * export monitoring stats
# * consider other eviction strategies - LRU?
def cached(max_entries=1000):
""" A method decorator that applies a memoizing cache around the function.
@@ -55,6 +72,9 @@ def cached(max_entries=1000):
"""
def wrap(orig):
cache = OrderedDict()
+ name = orig.__name__
+
+ caches_by_name[name] = cache
def prefill(key, value):
while len(cache) > max_entries:
@@ -62,11 +82,14 @@ def cached(max_entries=1000):
cache[key] = value
+ @functools.wraps(orig)
@defer.inlineCallbacks
def wrapped(self, key):
if key in cache:
+ cache_counter.inc_hits(name)
defer.returnValue(cache[key])
+ cache_counter.inc_misses(name)
ret = yield orig(self, key)
prefill(key, ret)
defer.returnValue(ret)
@@ -83,7 +106,8 @@ def cached(max_entries=1000):
class LoggingTransaction(object):
"""An object that almost-transparently proxies for the 'txn' object
- passed to the constructor. Adds logging to the .execute() method."""
+ passed to the constructor. Adds logging and metrics to the .execute()
+ method."""
__slots__ = ["txn", "name"]
def __init__(self, txn, name):
@@ -99,6 +123,7 @@ class LoggingTransaction(object):
def execute(self, sql, *args, **kwargs):
# TODO(paul): Maybe use 'info' and 'debug' for values?
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
+
try:
if args and args[0]:
values = args[0]
@@ -120,8 +145,9 @@ class LoggingTransaction(object):
logger.exception("[SQL FAIL] {%s}", self.name)
raise
finally:
- end = time.time() * 1000
- sql_logger.debug("[SQL time] {%s} %f", self.name, end - start)
+ msecs = (time.time() * 1000) - start
+ sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
+ sql_query_timer.inc_by(msecs, sql.split()[0])
class PerformanceCounters(object):
@@ -172,11 +198,18 @@ class SQLBaseStore(object):
self._previous_txn_total_time = 0
self._current_txn_total_time = 0
self._previous_loop_ts = 0
+
+ # TODO(paul): These can eventually be removed once the metrics code
+ # is running in mainline, and we have some nice monitoring frontends
+ # to watch it
self._txn_perf_counters = PerformanceCounters()
self._get_event_counters = PerformanceCounters()
self._get_event_cache = LruCache(hs.config.event_cache_size)
+ # Pretend the getEventCache is just another named cache
+ caches_by_name["*getEvent*"] = self._get_event_cache
+
def start_profiling(self):
self._previous_loop_ts = self._clock.time_msec()
@@ -211,6 +244,8 @@ class SQLBaseStore(object):
"""Wraps the .runInteraction() method on the underlying db_pool."""
current_context = LoggingContext.current_context()
+ start_time = time.time() * 1000
+
def inner_func(txn, *args, **kwargs):
with LoggingContext("runInteraction") as context:
current_context.copy_to(context)
@@ -223,6 +258,7 @@ class SQLBaseStore(object):
name = "%s-%x" % (desc, txn_id, )
+ sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
transaction_logger.debug("[TXN START] {%s}", name)
try:
return func(LoggingTransaction(txn, name), *args, **kwargs)
@@ -231,13 +267,13 @@ class SQLBaseStore(object):
raise
finally:
end = time.time() * 1000
- transaction_logger.debug(
- "[TXN END] {%s} %f",
- name, end - start
- )
+ duration = end - start
+
+ transaction_logger.debug("[TXN END] {%s} %f", name, duration)
- self._current_txn_total_time += end - start
+ self._current_txn_total_time += duration
self._txn_perf_counters.update(desc, start, end)
+ sql_txn_timer.inc_by(duration, desc)
with PreserveLoggingContext():
result = yield self._db_pool.runInteraction(
@@ -259,7 +295,7 @@ class SQLBaseStore(object):
)
return results
- def _execute(self, decoder, query, *args):
+ def _execute(self, desc, decoder, query, *args):
"""Runs a single query for a result set.
Args:
@@ -277,10 +313,10 @@ class SQLBaseStore(object):
else:
return cursor.fetchall()
- return self.runInteraction("_execute", interaction)
+ return self.runInteraction(desc, interaction)
- def _execute_and_decode(self, query, *args):
- return self._execute(self.cursor_to_dict, query, *args)
+ def _execute_and_decode(self, desc, query, *args):
+ return self._execute(desc, self.cursor_to_dict, query, *args)
# "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns.
@@ -638,14 +674,22 @@ class SQLBaseStore(object):
get_prev_content=False, allow_rejected=False):
start_time = time.time() * 1000
- update_counter = self._get_event_counters.update
+
+ def update_counter(desc, last_time):
+ curr_time = self._get_event_counters.update(desc, last_time)
+ sql_getevents_timer.inc_by(curr_time - last_time, desc)
+ return curr_time
cache = self._get_event_cache.setdefault(event_id, {})
try:
# Separate cache entries for each way to invoke _get_event_txn
- return cache[(check_redacted, get_prev_content, allow_rejected)]
+ ret = cache[(check_redacted, get_prev_content, allow_rejected)]
+
+ cache_counter.inc_hits("*getEvent*")
+ return ret
except KeyError:
+ cache_counter.inc_misses("*getEvent*")
pass
finally:
start_time = update_counter("event_cache", start_time)
@@ -685,7 +729,11 @@ class SQLBaseStore(object):
check_redacted=True, get_prev_content=False):
start_time = time.time() * 1000
- update_counter = self._get_event_counters.update
+
+ def update_counter(desc, last_time):
+ curr_time = self._get_event_counters.update(desc, last_time)
+ sql_getevents_timer.inc_by(curr_time - last_time, desc)
+ return curr_time
d = json.loads(js)
start_time = update_counter("decode_json", start_time)
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index e30265750a..850676ce6c 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -296,7 +296,7 @@ class ApplicationServiceStore(SQLBaseStore):
# }
# ]
services = {}
- results = yield self._execute_and_decode(sql)
+ results = yield self._execute_and_decode("_populate_cache", sql)
for res in results:
as_token = res["token"]
if as_token not in services:
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 2deda8ac50..032334bfd6 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -429,3 +429,15 @@ class EventFederationStore(SQLBaseStore):
)
return events[:limit]
+
+ def clean_room_for_join(self, room_id):
+ return self.runInteraction(
+ "clean_room_for_join",
+ self._clean_room_for_join_txn,
+ room_id,
+ )
+
+ def _clean_room_for_join_txn(self, txn, room_id):
+ query = "DELETE FROM event_forward_extremities WHERE room_id = ?"
+
+ txn.execute(query, (room_id,))
diff --git a/synapse/storage/feedback.py b/synapse/storage/feedback.py
index fcf011b234..8eab769b71 100644
--- a/synapse/storage/feedback.py
+++ b/synapse/storage/feedback.py
@@ -37,7 +37,7 @@ class FeedbackStore(SQLBaseStore):
"WHERE feedback.target_event_id = ? "
)
- rows = yield self._execute_and_decode(sql, event_id)
+ rows = yield self._execute_and_decode("get_feedback_for_event", sql, event_id)
defer.returnValue(
[
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 1f244019fc..09d1e63657 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -85,7 +85,9 @@ class KeyStore(SQLBaseStore):
" AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
)
- rows = yield self._execute_and_decode(sql, server_name, *key_ids)
+ rows = yield self._execute_and_decode(
+ "get_server_verify_keys", sql, server_name, *key_ids
+ )
keys = []
for row in rows:
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index bbf322cc84..d769db2c78 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -34,7 +34,7 @@ class PushRuleStore(SQLBaseStore):
"WHERE user_name = ? "
"ORDER BY priority_class DESC, priority DESC"
)
- rows = yield self._execute(None, sql, user_name)
+ rows = yield self._execute("get_push_rules_for_user", None, sql, user_name)
dicts = []
for r in rows:
@@ -57,17 +57,6 @@ class PushRuleStore(SQLBaseStore):
)
@defer.inlineCallbacks
- def get_push_rule_enabled_by_user_rule_id(self, user_name, rule_id):
- results = yield self._simple_select_list(
- PushRuleEnableTable.table_name,
- {'user_name': user_name, 'rule_id': rule_id},
- ['enabled']
- )
- if not results:
- defer.returnValue(True)
- defer.returnValue(results[0])
-
- @defer.inlineCallbacks
def add_push_rule(self, before, after, **kwargs):
vals = copy.copy(kwargs)
if 'conditions' in vals:
@@ -217,17 +206,11 @@ class PushRuleStore(SQLBaseStore):
@defer.inlineCallbacks
def set_push_rule_enabled(self, user_name, rule_id, enabled):
- if enabled:
- yield self._simple_delete_one(
- PushRuleEnableTable.table_name,
- {'user_name': user_name, 'rule_id': rule_id}
- )
- else:
- yield self._simple_upsert(
- PushRuleEnableTable.table_name,
- {'user_name': user_name, 'rule_id': rule_id},
- {'enabled': False}
- )
+ yield self._simple_upsert(
+ PushRuleEnableTable.table_name,
+ {'user_name': user_name, 'rule_id': rule_id},
+ {'enabled': enabled}
+ )
class RuleNotFoundException(Exception):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index 6622b4d18a..587dada68f 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -37,7 +37,8 @@ class PusherStore(SQLBaseStore):
)
rows = yield self._execute(
- None, sql, app_id_and_pushkey[0], app_id_and_pushkey[1]
+ "get_pushers_by_app_id_and_pushkey", None, sql,
+ app_id_and_pushkey[0], app_id_and_pushkey[1]
)
ret = [
@@ -70,7 +71,7 @@ class PusherStore(SQLBaseStore):
"FROM pushers"
)
- rows = yield self._execute(None, sql)
+ rows = yield self._execute("get_all_pushers", None, sql)
ret = [
{
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 029b07cc66..3c2f1d6a15 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -19,7 +19,7 @@ from sqlite3 import IntegrityError
from synapse.api.errors import StoreError, Codes
-from ._base import SQLBaseStore
+from ._base import SQLBaseStore, cached
class RegistrationStore(SQLBaseStore):
@@ -88,10 +88,14 @@ class RegistrationStore(SQLBaseStore):
query = ("SELECT users.name, users.password_hash FROM users"
" WHERE users.name = ?")
return self._execute(
- self.cursor_to_dict,
- query, user_id
+ "get_user_by_id", self.cursor_to_dict, query, user_id
)
+ @cached()
+ # TODO(paul): Currently there's no code to invalidate this cache. That
+ # means if/when we ever add internal ways to invalidate access tokens or
+ # change whether a user is a server admin, those will need to invoke
+ # store.get_user_by_token.invalidate(token)
def get_user_by_token(self, token):
"""Get a user from the given access token.
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 750b17a45f..549c9af393 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -68,7 +68,7 @@ class RoomStore(SQLBaseStore):
"""
query = RoomsTable.select_statement("room_id=?")
return self._execute(
- RoomsTable.decode_single_result, query, room_id,
+ "get_room", RoomsTable.decode_single_result, query, room_id,
)
@defer.inlineCallbacks
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 71db16d0e5..456e4bd45d 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -82,7 +82,7 @@ class StateStore(SQLBaseStore):
if context.current_state is None:
return
- state_events = context.current_state
+ state_events = dict(context.current_state)
if event.is_state():
state_events[(event.type, event.state_key)] = event
diff --git a/synapse/util/lrucache.py b/synapse/util/lrucache.py
index f115f50e50..65d5792907 100644
--- a/synapse/util/lrucache.py
+++ b/synapse/util/lrucache.py
@@ -16,7 +16,6 @@
class LruCache(object):
"""Least-recently-used cache."""
- # TODO(mjark) Add hit/miss counters
# TODO(mjark) Add mutex for linked list for thread safety.
def __init__(self, max_size):
cache = {}
diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py
index ea53a8085c..52e66beaee 100644
--- a/synapse/util/stringutils.py
+++ b/synapse/util/stringutils.py
@@ -16,6 +16,10 @@
import random
import string
+_string_with_symbols = (
+ string.digits + string.ascii_letters + ".,;:^&*-_+=#~@"
+)
+
def origin_from_ucid(ucid):
return ucid.split("@", 1)[1]
@@ -23,3 +27,9 @@ def origin_from_ucid(ucid):
def random_string(length):
return ''.join(random.choice(string.ascii_letters) for _ in xrange(length))
+
+
+def random_string_with_symbols(length):
+ return ''.join(
+ random.choice(_string_with_symbols) for _ in xrange(length)
+ )
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 6ffc3c99cc..04eba4289e 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -100,7 +100,7 @@ class PresenceTestCase(unittest.TestCase):
self.room_members = []
room_member_handler = handlers.room_member_handler = Mock(spec=[
- "get_rooms_for_user",
+ "get_joined_rooms_for_user",
"get_room_members",
"fetch_room_distributions_into",
])
@@ -111,7 +111,7 @@ class PresenceTestCase(unittest.TestCase):
return defer.succeed([self.room_id])
else:
return defer.succeed([])
- room_member_handler.get_rooms_for_user = get_rooms_for_user
+ room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
def get_room_members(room_id):
if room_id == self.room_id:
diff --git a/tests/handlers/test_presencelike.py b/tests/handlers/test_presencelike.py
index 18cac9a846..977e832da7 100644
--- a/tests/handlers/test_presencelike.py
+++ b/tests/handlers/test_presencelike.py
@@ -64,7 +64,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
"set_presence_state",
"is_presence_visible",
"set_profile_displayname",
- "get_rooms_for_user_where_membership_is",
+ "get_rooms_for_user",
]),
handlers=None,
resource_for_federation=Mock(),
@@ -124,9 +124,9 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
self.mock_update_client)
hs.handlers.room_member_handler = Mock(spec=[
- "get_rooms_for_user",
+ "get_joined_rooms_for_user",
])
- hs.handlers.room_member_handler.get_rooms_for_user = (
+ hs.handlers.room_member_handler.get_joined_rooms_for_user = (
lambda u: defer.succeed([]))
# Some local users to test with
@@ -138,7 +138,7 @@ class PresenceProfilelikeDataTestCase(unittest.TestCase):
self.u_potato = UserID.from_string("@potato:remote")
self.mock_get_joined = (
- self.datastore.get_rooms_for_user_where_membership_is
+ self.datastore.get_rooms_for_user
)
@defer.inlineCallbacks
diff --git a/tests/metrics/__init__.py b/tests/metrics/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/metrics/__init__.py
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
new file mode 100644
index 0000000000..6009014297
--- /dev/null
+++ b/tests/metrics/test_metric.py
@@ -0,0 +1,161 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from tests import unittest
+
+from synapse.metrics.metric import (
+ CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
+)
+
+
+class CounterMetricTestCase(unittest.TestCase):
+
+ def test_scalar(self):
+ counter = CounterMetric("scalar")
+
+ self.assertEquals(counter.render(), [
+ 'scalar 0',
+ ])
+
+ counter.inc()
+
+ self.assertEquals(counter.render(), [
+ 'scalar 1',
+ ])
+
+ counter.inc_by(2)
+
+ self.assertEquals(counter.render(), [
+ 'scalar 3'
+ ])
+
+ def test_vector(self):
+ counter = CounterMetric("vector", labels=["method"])
+
+ # Empty counter doesn't yet know what values it has
+ self.assertEquals(counter.render(), [])
+
+ counter.inc("GET")
+
+ self.assertEquals(counter.render(), [
+ 'vector{method="GET"} 1',
+ ])
+
+ counter.inc("GET")
+ counter.inc("PUT")
+
+ self.assertEquals(counter.render(), [
+ 'vector{method="GET"} 2',
+ 'vector{method="PUT"} 1',
+ ])
+
+
+class CallbackMetricTestCase(unittest.TestCase):
+
+ def test_scalar(self):
+ d = dict()
+
+ metric = CallbackMetric("size", lambda: len(d))
+
+ self.assertEquals(metric.render(), [
+ 'size 0',
+ ])
+
+ d["key"] = "value"
+
+ self.assertEquals(metric.render(), [
+ 'size 1',
+ ])
+
+ def test_vector(self):
+ vals = dict()
+
+ metric = CallbackMetric("values", lambda: vals, labels=["type"])
+
+ self.assertEquals(metric.render(), [])
+
+ # Keys have to be tuples, even if they're 1-element
+ vals[("foo",)] = 1
+ vals[("bar",)] = 2
+
+ self.assertEquals(metric.render(), [
+ 'values{type="bar"} 2',
+ 'values{type="foo"} 1',
+ ])
+
+
+class DistributionMetricTestCase(unittest.TestCase):
+
+ def test_scalar(self):
+ metric = DistributionMetric("thing")
+
+ self.assertEquals(metric.render(), [
+ 'thing:count 0',
+ 'thing:total 0',
+ ])
+
+ metric.inc_by(500)
+
+ self.assertEquals(metric.render(), [
+ 'thing:count 1',
+ 'thing:total 500',
+ ])
+
+ def test_vector(self):
+ metric = DistributionMetric("queries", labels=["verb"])
+
+ self.assertEquals(metric.render(), [])
+
+ metric.inc_by(300, "SELECT")
+ metric.inc_by(200, "SELECT")
+ metric.inc_by(800, "INSERT")
+
+ self.assertEquals(metric.render(), [
+ 'queries:count{verb="INSERT"} 1',
+ 'queries:count{verb="SELECT"} 2',
+ 'queries:total{verb="INSERT"} 800',
+ 'queries:total{verb="SELECT"} 500',
+ ])
+
+
+class CacheMetricTestCase(unittest.TestCase):
+
+ def test_cache(self):
+ d = dict()
+
+ metric = CacheMetric("cache", lambda: len(d))
+
+ self.assertEquals(metric.render(), [
+ 'cache:hits 0',
+ 'cache:total 0',
+ 'cache:size 0',
+ ])
+
+ metric.inc_misses()
+ d["key"] = "value"
+
+ self.assertEquals(metric.render(), [
+ 'cache:hits 0',
+ 'cache:total 1',
+ 'cache:size 1',
+ ])
+
+ metric.inc_hits()
+
+ self.assertEquals(metric.render(), [
+ 'cache:hits 1',
+ 'cache:total 2',
+ 'cache:size 1',
+ ])
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index 5f2ef64efc..b9c03383a2 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -79,13 +79,13 @@ class PresenceStateTestCase(unittest.TestCase):
room_member_handler = hs.handlers.room_member_handler = Mock(
spec=[
- "get_rooms_for_user",
+ "get_joined_rooms_for_user",
]
)
def get_rooms_for_user(user):
return defer.succeed([])
- room_member_handler.get_rooms_for_user = get_rooms_for_user
+ room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
presence.register_servlets(hs, self.mock_resource)
@@ -166,7 +166,7 @@ class PresenceListTestCase(unittest.TestCase):
hs.handlers.room_member_handler = Mock(
spec=[
- "get_rooms_for_user",
+ "get_joined_rooms_for_user",
]
)
@@ -291,7 +291,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
return ["a-room"]
else:
return []
- hs.handlers.room_member_handler.get_rooms_for_user = get_rooms_for_user
+ hs.handlers.room_member_handler.get_joined_rooms_for_user = get_rooms_for_user
self.mock_datastore = hs.get_datastore()
self.mock_datastore.get_app_service_by_token = Mock(return_value=None)
|