diff --git a/synapse/_scripts/__init__.py b/synapse/_scripts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/_scripts/__init__.py
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
new file mode 100644
index 0000000000..70cecde486
--- /dev/null
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -0,0 +1,215 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import getpass
+import hashlib
+import hmac
+import logging
+import sys
+
+from six.moves import input
+
+import requests as _requests
+import yaml
+
+
+def request_registration(
+ user,
+ password,
+ server_location,
+ shared_secret,
+ admin=False,
+ requests=_requests,
+ _print=print,
+ exit=sys.exit,
+):
+
+ url = "%s/_matrix/client/r0/admin/register" % (server_location,)
+
+ # Get the nonce
+ r = requests.get(url, verify=False)
+
+    if r.status_code != 200:
+ _print("ERROR! Received %d %s" % (r.status_code, r.reason))
+ if 400 <= r.status_code < 500:
+ try:
+ _print(r.json()["error"])
+ except Exception:
+ pass
+ return exit(1)
+
+ nonce = r.json()["nonce"]
+
+ mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1)
+
+ mac.update(nonce.encode('utf8'))
+ mac.update(b"\x00")
+ mac.update(user.encode('utf8'))
+ mac.update(b"\x00")
+ mac.update(password.encode('utf8'))
+ mac.update(b"\x00")
+ mac.update(b"admin" if admin else b"notadmin")
+
+ mac = mac.hexdigest()
+
+ data = {
+ "nonce": nonce,
+ "username": user,
+ "password": password,
+ "mac": mac,
+ "admin": admin,
+ }
+
+ _print("Sending registration request...")
+ r = requests.post(url, json=data, verify=False)
+
+    if r.status_code != 200:
+ _print("ERROR! Received %d %s" % (r.status_code, r.reason))
+ if 400 <= r.status_code < 500:
+ try:
+ _print(r.json()["error"])
+ except Exception:
+ pass
+ return exit(1)
+
+ _print("Success!")
+
+
+def register_new_user(user, password, server_location, shared_secret, admin):
+ if not user:
+ try:
+ default_user = getpass.getuser()
+ except Exception:
+ default_user = None
+
+ if default_user:
+ user = input("New user localpart [%s]: " % (default_user,))
+ if not user:
+ user = default_user
+ else:
+ user = input("New user localpart: ")
+
+ if not user:
+ print("Invalid user name")
+ sys.exit(1)
+
+ if not password:
+ password = getpass.getpass("Password: ")
+
+ if not password:
+ print("Password cannot be blank.")
+ sys.exit(1)
+
+ confirm_password = getpass.getpass("Confirm password: ")
+
+ if password != confirm_password:
+ print("Passwords do not match")
+ sys.exit(1)
+
+ if admin is None:
+ admin = input("Make admin [no]: ")
+ if admin in ("y", "yes", "true"):
+ admin = True
+ else:
+ admin = False
+
+ request_registration(user, password, server_location, shared_secret, bool(admin))
+
+
+def main():
+
+ logging.captureWarnings(True)
+
+ parser = argparse.ArgumentParser(
+ description="Used to register new users with a given home server when"
+ " registration has been disabled. The home server must be"
+ " configured with the 'registration_shared_secret' option"
+ " set."
+ )
+ parser.add_argument(
+ "-u",
+ "--user",
+ default=None,
+ help="Local part of the new user. Will prompt if omitted.",
+ )
+ parser.add_argument(
+ "-p",
+ "--password",
+ default=None,
+ help="New password for user. Will prompt if omitted.",
+ )
+ admin_group = parser.add_mutually_exclusive_group()
+ admin_group.add_argument(
+ "-a",
+ "--admin",
+ action="store_true",
+ help=(
+ "Register new user as an admin. "
+ "Will prompt if --no-admin is not set either."
+ ),
+ )
+ admin_group.add_argument(
+ "--no-admin",
+ action="store_true",
+ help=(
+ "Register new user as a regular user. "
+ "Will prompt if --admin is not set either."
+ ),
+ )
+
+ group = parser.add_mutually_exclusive_group(required=True)
+ group.add_argument(
+ "-c",
+ "--config",
+ type=argparse.FileType('r'),
+ help="Path to server config file. Used to read in shared secret.",
+ )
+
+ group.add_argument(
+ "-k", "--shared-secret", help="Shared secret as defined in server config file."
+ )
+
+ parser.add_argument(
+ "server_url",
+ default="https://localhost:8448",
+ nargs='?',
+ help="URL to use to talk to the home server. Defaults to "
+ " 'https://localhost:8448'.",
+ )
+
+ args = parser.parse_args()
+
+ if "config" in args and args.config:
+ config = yaml.safe_load(args.config)
+ secret = config.get("registration_shared_secret", None)
+ if not secret:
+ print("No 'registration_shared_secret' defined in config.")
+ sys.exit(1)
+ else:
+ secret = args.shared_secret
+
+ admin = None
+ if args.admin or args.no_admin:
+ admin = args.admin
+
+ register_new_user(args.user, args.password, args.server_url, secret, admin)
+
+
+if __name__ == "__main__":
+ main()
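As a rough illustration of driving the new module from Python rather than via the console script, with placeholder values for the server URL and shared secret:

    from synapse._scripts.register_new_matrix_user import request_registration

    # Placeholder values for illustration only; the shared secret must match
    # the homeserver's registration_shared_secret config option.
    request_registration(
        user="alice",
        password="correct horse battery staple",
        server_location="https://localhost:8448",
        shared_secret="<registration_shared_secret>",
        admin=False,
    )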
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index eed8c67e6a..677c0bdd4c 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -172,7 +172,10 @@ USER_FILTER_SCHEMA = {
# events a lot easier as we can then use a negative lookbehind
# assertion to split '\.' If we allowed \\ then it would
# incorrectly split '\\.' See synapse.events.utils.serialize_event
- "pattern": "^((?!\\\).)*$"
+ #
+ # Note that because this is a regular expression, we have to escape
+ # each backslash in the pattern.
+ "pattern": r"^((?!\\\\).)*$"
}
}
},
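A standalone sketch (not part of the patch) of what the corrected pattern accepts and rejects:

    import re

    # Anchored pattern from USER_FILTER_SCHEMA: each character is consumed
    # only if it does not start a literal '\\' sequence.
    pattern = re.compile(r"^((?!\\\\).)*$")

    assert pattern.match("content.body")          # ordinary field path: allowed
    assert pattern.match(r"content\.body")        # escaped dot: allowed
    assert not pattern.match(r"content\\.body")   # double backslash: rejected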
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e3f0d99a3f..593e1e75db 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -20,6 +20,7 @@ import sys
from six import iteritems
+import psutil
from prometheus_client import Gauge
from twisted.application import service
@@ -502,7 +503,6 @@ def run(hs):
def performance_stats_init():
try:
- import psutil
process = psutil.Process()
# Ensure we can fetch both, and make the initial request for cpu_percent
# so the next request will use this as the initial point.
@@ -510,12 +510,9 @@ def run(hs):
process.cpu_percent(interval=None)
logger.info("report_stats can use psutil")
stats_process.append(process)
- except (ImportError, AttributeError):
- logger.warn(
- "report_stats enabled but psutil is not installed or incorrect version."
- " Disabling reporting of memory/cpu stats."
- " Ensuring psutil is available will help matrix.org track performance"
- " changes across releases."
+    except AttributeError:
+ logger.warning(
+ "Unable to read memory/cpu stats. Disabling reporting."
)
def generate_user_daily_visit_stats():
@@ -530,10 +527,13 @@ def run(hs):
clock.looping_call(generate_user_daily_visit_stats, 5 * 60 * 1000)
# monthly active user limiting functionality
- clock.looping_call(
- hs.get_datastore().reap_monthly_active_users, 1000 * 60 * 60
- )
- hs.get_datastore().reap_monthly_active_users()
+ def reap_monthly_active_users():
+ return run_as_background_process(
+ "reap_monthly_active_users",
+ hs.get_datastore().reap_monthly_active_users,
+ )
+ clock.looping_call(reap_monthly_active_users, 1000 * 60 * 60)
+ reap_monthly_active_users()
@defer.inlineCallbacks
def generate_monthly_active_users():
@@ -547,12 +547,15 @@ def run(hs):
registered_reserved_users_mau_gauge.set(float(reserved_count))
max_mau_gauge.set(float(hs.config.max_mau_value))
- hs.get_datastore().initialise_reserved_users(
- hs.config.mau_limits_reserved_threepids
- )
- generate_monthly_active_users()
+ def start_generate_monthly_active_users():
+ return run_as_background_process(
+ "generate_monthly_active_users",
+ generate_monthly_active_users,
+ )
+
+ start_generate_monthly_active_users()
if hs.config.limit_usage_by_mau:
- clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000)
+ clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000)
# End of monthly active user settings
if hs.config.report_stats:
@@ -568,7 +571,7 @@ def run(hs):
clock.call_later(5 * 60, start_phone_stats_home)
if hs.config.daemonize and hs.config.print_pidfile:
- print (hs.config.pid_file)
+ print(hs.config.pid_file)
_base.start_reactor(
"synapse-homeserver",
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 0f9f8e19f6..83b0863f00 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -161,11 +161,11 @@ class PusherReplicationHandler(ReplicationClientHandler):
else:
yield self.start_pusher(row.user_id, row.app_id, row.pushkey)
elif stream_name == "events":
- self.pusher_pool.on_new_notifications(
+ yield self.pusher_pool.on_new_notifications(
token, token,
)
elif stream_name == "receipts":
- self.pusher_pool.on_new_receipts(
+ yield self.pusher_pool.on_new_receipts(
token, token, set(row.room_id for row in rows)
)
except Exception:
@@ -183,7 +183,7 @@ class PusherReplicationHandler(ReplicationClientHandler):
def start_pusher(self, user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key)
- return self.pusher_pool._refresh_pusher(app_id, pushkey, user_id)
+ return self.pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
def start(config_options):
diff --git a/synapse/config/__main__.py b/synapse/config/__main__.py
index 8fccf573ee..79fe9c3dac 100644
--- a/synapse/config/__main__.py
+++ b/synapse/config/__main__.py
@@ -28,7 +28,7 @@ if __name__ == "__main__":
sys.stderr.write("\n" + str(e) + "\n")
sys.exit(1)
- print (getattr(config, key))
+ print(getattr(config, key))
sys.exit(0)
else:
sys.stderr.write("Unknown command %r\n" % (action,))
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 3d2e90dd5b..14dae65ea0 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -106,10 +106,7 @@ class Config(object):
@classmethod
def check_file(cls, file_path, config_name):
if file_path is None:
- raise ConfigError(
- "Missing config for %s."
- % (config_name,)
- )
+ raise ConfigError("Missing config for %s." % (config_name,))
try:
os.stat(file_path)
except OSError as e:
@@ -128,9 +125,7 @@ class Config(object):
if e.errno != errno.EEXIST:
raise
if not os.path.isdir(dir_path):
- raise ConfigError(
- "%s is not a directory" % (dir_path,)
- )
+ raise ConfigError("%s is not a directory" % (dir_path,))
return dir_path
@classmethod
@@ -156,21 +151,20 @@ class Config(object):
return results
def generate_config(
- self,
- config_dir_path,
- server_name,
- is_generating_file,
- report_stats=None,
+ self, config_dir_path, server_name, is_generating_file, report_stats=None
):
default_config = "# vim:ft=yaml\n"
- default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
- "default_config",
- config_dir_path=config_dir_path,
- server_name=server_name,
- is_generating_file=is_generating_file,
- report_stats=report_stats,
- ))
+ default_config += "\n\n".join(
+ dedent(conf)
+ for conf in self.invoke_all(
+ "default_config",
+ config_dir_path=config_dir_path,
+ server_name=server_name,
+ is_generating_file=is_generating_file,
+ report_stats=report_stats,
+ )
+ )
config = yaml.load(default_config)
@@ -178,23 +172,22 @@ class Config(object):
@classmethod
def load_config(cls, description, argv):
- config_parser = argparse.ArgumentParser(
- description=description,
- )
+ config_parser = argparse.ArgumentParser(description=description)
config_parser.add_argument(
- "-c", "--config-path",
+ "-c",
+ "--config-path",
action="append",
metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and"
- " may specify directories containing *.yaml files."
+ " may specify directories containing *.yaml files.",
)
config_parser.add_argument(
"--keys-directory",
metavar="DIRECTORY",
help="Where files such as certs and signing keys are stored when"
- " their location is given explicitly in the config."
- " Defaults to the directory containing the last config file",
+ " their location is given explicitly in the config."
+ " Defaults to the directory containing the last config file",
)
config_args = config_parser.parse_args(argv)
@@ -203,9 +196,7 @@ class Config(object):
obj = cls()
obj.read_config_files(
- config_files,
- keys_directory=config_args.keys_directory,
- generate_keys=False,
+ config_files, keys_directory=config_args.keys_directory, generate_keys=False
)
return obj
@@ -213,38 +204,38 @@ class Config(object):
def load_or_generate_config(cls, description, argv):
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument(
- "-c", "--config-path",
+ "-c",
+ "--config-path",
action="append",
metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and"
- " may specify directories containing *.yaml files."
+ " may specify directories containing *.yaml files.",
)
config_parser.add_argument(
"--generate-config",
action="store_true",
- help="Generate a config file for the server name"
+ help="Generate a config file for the server name",
)
config_parser.add_argument(
"--report-stats",
action="store",
help="Whether the generated config reports anonymized usage statistics",
- choices=["yes", "no"]
+ choices=["yes", "no"],
)
config_parser.add_argument(
"--generate-keys",
action="store_true",
- help="Generate any missing key files then exit"
+ help="Generate any missing key files then exit",
)
config_parser.add_argument(
"--keys-directory",
metavar="DIRECTORY",
help="Used with 'generate-*' options to specify where files such as"
- " certs and signing keys should be stored in, unless explicitly"
- " specified in the config."
+ " certs and signing keys should be stored in, unless explicitly"
+ " specified in the config.",
)
config_parser.add_argument(
- "-H", "--server-name",
- help="The server name to generate a config file for"
+ "-H", "--server-name", help="The server name to generate a config file for"
)
config_args, remaining_args = config_parser.parse_known_args(argv)
@@ -257,8 +248,8 @@ class Config(object):
if config_args.generate_config:
if config_args.report_stats is None:
config_parser.error(
- "Please specify either --report-stats=yes or --report-stats=no\n\n" +
- MISSING_REPORT_STATS_SPIEL
+ "Please specify either --report-stats=yes or --report-stats=no\n\n"
+ + MISSING_REPORT_STATS_SPIEL
)
if not config_files:
config_parser.error(
@@ -287,26 +278,32 @@ class Config(object):
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
- is_generating_file=True
+ is_generating_file=True,
)
obj.invoke_all("generate_files", config)
config_file.write(config_str)
- print((
- "A config file has been generated in %r for server name"
- " %r with corresponding SSL keys and self-signed"
- " certificates. Please review this file and customise it"
- " to your needs."
- ) % (config_path, server_name))
+ print(
+ (
+ "A config file has been generated in %r for server name"
+ " %r with corresponding SSL keys and self-signed"
+ " certificates. Please review this file and customise it"
+ " to your needs."
+ )
+ % (config_path, server_name)
+ )
print(
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
return
else:
- print((
- "Config file %r already exists. Generating any missing key"
- " files."
- ) % (config_path,))
+ print(
+ (
+ "Config file %r already exists. Generating any missing key"
+ " files."
+ )
+ % (config_path,)
+ )
generate_keys = True
parser = argparse.ArgumentParser(
@@ -338,8 +335,7 @@ class Config(object):
return obj
- def read_config_files(self, config_files, keys_directory=None,
- generate_keys=False):
+ def read_config_files(self, config_files, keys_directory=None, generate_keys=False):
if not keys_directory:
keys_directory = os.path.dirname(config_files[-1])
@@ -364,8 +360,9 @@ class Config(object):
if "report_stats" not in config:
raise ConfigError(
- MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS + "\n" +
- MISSING_REPORT_STATS_SPIEL
+ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS
+ + "\n"
+ + MISSING_REPORT_STATS_SPIEL
)
if generate_keys:
@@ -399,16 +396,16 @@ def find_config_files(search_paths):
for entry in os.listdir(config_path):
entry_path = os.path.join(config_path, entry)
if not os.path.isfile(entry_path):
- print (
- "Found subdirectory in config directory: %r. IGNORING."
- ) % (entry_path, )
+ err = "Found subdirectory in config directory: %r. IGNORING."
+ print(err % (entry_path,))
continue
if not entry.endswith(".yaml"):
- print (
- "Found file in config directory that does not"
- " end in '.yaml': %r. IGNORING."
- ) % (entry_path, )
+ err = (
+ "Found file in config directory that does not end in "
+ "'.yaml': %r. IGNORING."
+ )
+ print(err % (entry_path,))
continue
files.append(entry_path)
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index e2582cfecc..93d70cff14 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -19,18 +19,12 @@ from __future__ import print_function
import email.utils
import logging
import os
-import sys
-import textwrap
-from ._base import Config
+import pkg_resources
-logger = logging.getLogger(__name__)
+from ._base import Config, ConfigError
-TEMPLATE_DIR_WARNING = """\
-WARNING: The email notifier is configured to look for templates in '%(template_dir)s',
-but no templates could be found there. We will fall back to using the example templates;
-to get rid of this warning, leave 'email.template_dir' unset.
-"""
+logger = logging.getLogger(__name__)
class EmailConfig(Config):
@@ -78,20 +72,22 @@ class EmailConfig(Config):
self.email_notif_template_html = email_config["notif_template_html"]
self.email_notif_template_text = email_config["notif_template_text"]
- self.email_template_dir = email_config.get("template_dir")
-
- # backwards-compatibility hack
- if (
- self.email_template_dir == "res/templates"
- and not os.path.isfile(
- os.path.join(self.email_template_dir, self.email_notif_template_text)
+ template_dir = email_config.get("template_dir")
+ # we need an absolute path, because we change directory after starting (and
+        # we don't yet know what auxiliary templates like mail.css we will need).
+ # (Note that loading as package_resources with jinja.PackageLoader doesn't
+ # work for the same reason.)
+ if not template_dir:
+ template_dir = pkg_resources.resource_filename(
+ 'synapse', 'res/templates'
)
- ):
- t = TEMPLATE_DIR_WARNING % {
- "template_dir": self.email_template_dir,
- }
- print(textwrap.fill(t, width=80) + "\n", file=sys.stderr)
- self.email_template_dir = None
+ template_dir = os.path.abspath(template_dir)
+
+ for f in self.email_notif_template_text, self.email_notif_template_html:
+ p = os.path.join(template_dir, f)
+ if not os.path.isfile(p):
+ raise ConfigError("Unable to find email template file %s" % (p, ))
+ self.email_template_dir = template_dir
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index b8d5690f2b..10dd40159f 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -31,6 +31,7 @@ from .push import PushConfig
from .ratelimiting import RatelimitConfig
from .registration import RegistrationConfig
from .repository import ContentRepositoryConfig
+from .room_directory import RoomDirectoryConfig
from .saml2 import SAML2Config
from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
@@ -49,7 +50,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
ConsentConfig,
- ServerNoticesConfig,
+ ServerNoticesConfig, RoomDirectoryConfig,
):
pass
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 0fb964eb67..7480ed5145 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -15,10 +15,10 @@
from distutils.util import strtobool
+from synapse.config._base import Config, ConfigError
+from synapse.types import RoomAlias
from synapse.util.stringutils import random_string_with_symbols
-from ._base import Config
-
class RegistrationConfig(Config):
@@ -44,6 +44,10 @@ class RegistrationConfig(Config):
)
self.auto_join_rooms = config.get("auto_join_rooms", [])
+ for room_alias in self.auto_join_rooms:
+ if not RoomAlias.is_valid(room_alias):
+ raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
+ self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
@@ -98,6 +102,13 @@ class RegistrationConfig(Config):
# to these rooms
#auto_join_rooms:
# - "#example:example.com"
+
+        # Where auto_join_rooms is specified, setting this flag ensures that
+        # the rooms exist by creating them when the first user on the
+ # homeserver registers.
+ # Setting to false means that if the rooms are not manually created,
+ # users cannot be auto-joined since they do not exist.
+ autocreate_auto_join_rooms: true
""" % locals()
def add_arguments(self, parser):
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index fc909c1fac..06c62ab62c 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -178,7 +178,7 @@ class ContentRepositoryConfig(Config):
def default_config(self, **kwargs):
media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads")
- return """
+ return r"""
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
new file mode 100644
index 0000000000..9da13ab11b
--- /dev/null
+++ b/synapse/config/room_directory.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util import glob_to_regex
+
+from ._base import Config, ConfigError
+
+
+class RoomDirectoryConfig(Config):
+ def read_config(self, config):
+ alias_creation_rules = config["alias_creation_rules"]
+
+ self._alias_creation_rules = [
+ _AliasRule(rule)
+ for rule in alias_creation_rules
+ ]
+
+ def default_config(self, config_dir_path, server_name, **kwargs):
+ return """
+        # The `alias_creation_rules` option controls who's allowed to create aliases
+ # on this server.
+ #
+ # The format of this option is a list of rules that contain globs that
+ # match against user_id and the new alias (fully qualified with server
+ # name). The action in the first rule that matches is taken, which can
+ # currently either be "allow" or "deny".
+ #
+        # If no rules match, the request is denied.
+ alias_creation_rules:
+ - user_id: "*"
+ alias: "*"
+ action: allow
+ """
+
+ def is_alias_creation_allowed(self, user_id, alias):
+ """Checks if the given user is allowed to create the given alias
+
+ Args:
+ user_id (str)
+ alias (str)
+
+ Returns:
+            boolean: True if user is allowed to create the alias
+ """
+ for rule in self._alias_creation_rules:
+ if rule.matches(user_id, alias):
+ return rule.action == "allow"
+
+ return False
+
+
+class _AliasRule(object):
+ def __init__(self, rule):
+ action = rule["action"]
+ user_id = rule["user_id"]
+ alias = rule["alias"]
+
+ if action in ("allow", "deny"):
+ self.action = action
+ else:
+ raise ConfigError(
+ "alias_creation_rules rules can only have action of 'allow'"
+ " or 'deny'"
+ )
+
+ try:
+ self._user_id_regex = glob_to_regex(user_id)
+ self._alias_regex = glob_to_regex(alias)
+ except Exception as e:
+ raise ConfigError("Failed to parse glob into regex: %s", e)
+
+ def matches(self, user_id, alias):
+ """Tests if this rule matches the given user_id and alias.
+
+ Args:
+ user_id (str)
+ alias (str)
+
+ Returns:
+ boolean
+ """
+
+ # Note: The regexes are anchored at both ends
+ if not self._user_id_regex.match(user_id):
+ return False
+
+ if not self._alias_regex.match(alias):
+ return False
+
+ return True
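A sketch of how the rules are evaluated, along the lines of a unit test for this config (rule order matters, and anything that matches no rule is denied):

    from synapse.config.room_directory import RoomDirectoryConfig

    config = RoomDirectoryConfig()
    config.read_config({
        "alias_creation_rules": [
            {"user_id": "@bad:example.com", "alias": "*", "action": "deny"},
            {"user_id": "*", "alias": "*", "action": "allow"},
        ],
    })

    # The first matching rule wins.
    assert not config.is_alias_creation_allowed("@bad:example.com", "#foo:example.com")
    assert config.is_alias_creation_allowed("@alice:example.com", "#foo:example.com")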
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 57d4665e84..080c81f14b 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -55,7 +55,7 @@ def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
raise IOError("Cannot get key for %r" % server_name)
except (ConnectError, DomainError) as e:
logger.warn("Error getting key for %r: %s", server_name, e)
- except Exception as e:
+ except Exception:
logger.exception("Error getting key for %r", server_name)
raise IOError("Cannot get key for %r" % server_name)
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index af3eee95b9..d4d4474847 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -690,7 +690,7 @@ def auth_types_for_event(event):
auth_types = []
auth_types.append((EventTypes.PowerLevels, "", ))
- auth_types.append((EventTypes.Member, event.user_id, ))
+ auth_types.append((EventTypes.Member, event.sender, ))
auth_types.append((EventTypes.Create, "", ))
if event.type == EventTypes.Member:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 4efe95faa4..0f9302a6a8 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import re
import six
from six import iteritems
@@ -44,6 +43,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.types import get_domain_from_id
+from synapse.util import glob_to_regex
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import nested_logging_context
@@ -729,22 +729,10 @@ def _acl_entry_matches(server_name, acl_entry):
if not isinstance(acl_entry, six.string_types):
logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
return False
- regex = _glob_to_regex(acl_entry)
+ regex = glob_to_regex(acl_entry)
return regex.match(server_name)
-def _glob_to_regex(glob):
- res = ''
- for c in glob:
- if c == '*':
- res = res + '.*'
- elif c == '?':
- res = res + '.'
- else:
- res = res + re.escape(c)
- return re.compile(res + "\\Z", re.IGNORECASE)
-
-
class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
@@ -800,7 +788,7 @@ class FederationHandlerRegistry(object):
yield handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
- except Exception as e:
+ except Exception:
logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type, args):
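The helper now lives in synapse.util so the room directory config can reuse it; roughly, it compiles a '*'/'?' glob into an anchored, case-insensitive regex:

    from synapse.util import glob_to_regex

    regex = glob_to_regex("*.example.com")
    assert regex.match("matrix.example.com")
    assert regex.match("Matrix.EXAMPLE.com")   # matching is case-insensitive
    assert not regex.match("example.org")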
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 2a5eab124f..329e3c7d71 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -22,7 +22,7 @@ import bcrypt
import pymacaroons
from canonicaljson import json
-from twisted.internet import defer, threads
+from twisted.internet import defer
from twisted.web.client import PartialDownloadError
import synapse.util.stringutils as stringutils
@@ -37,8 +37,8 @@ from synapse.api.errors import (
)
from synapse.module_api import ModuleApi
from synapse.types import UserID
+from synapse.util import logcontext
from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logcontext import make_deferred_yieldable
from ._base import BaseHandler
@@ -884,11 +884,7 @@ class AuthHandler(BaseHandler):
bcrypt.gensalt(self.bcrypt_rounds),
).decode('ascii')
- return make_deferred_yieldable(
- threads.deferToThreadPool(
- self.hs.get_reactor(), self.hs.get_reactor().getThreadPool(), _do_hash
- ),
- )
+ return logcontext.defer_to_thread(self.hs.get_reactor(), _do_hash)
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
@@ -913,13 +909,7 @@ class AuthHandler(BaseHandler):
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode('ascii')
- return make_deferred_yieldable(
- threads.deferToThreadPool(
- self.hs.get_reactor(),
- self.hs.get_reactor().getThreadPool(),
- _do_validate_hash,
- ),
- )
+ return logcontext.defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else:
return defer.succeed(False)
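A hedged sketch of the helper now used for password hashing: it runs a blocking callable on the reactor's thread pool while preserving the caller's logcontext, replacing the deferToThreadPool boilerplate:

    from twisted.internet import defer

    from synapse.util import logcontext

    @defer.inlineCallbacks
    def run_blocking(reactor, blocking_fn, *args):
        # blocking_fn runs on the reactor's default thread pool; the caller's
        # logcontext is preserved around the yield.
        result = yield logcontext.defer_to_thread(reactor, blocking_fn, *args)
        defer.returnValue(result)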
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index b078df4a76..75fe50c42c 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -17,8 +17,8 @@ import logging
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID, create_requester
-from synapse.util.logcontext import run_in_background
from ._base import BaseHandler
@@ -121,7 +121,7 @@ class DeactivateAccountHandler(BaseHandler):
None
"""
if not self._user_parter_running:
- run_in_background(self._user_parter_loop)
+ run_as_background_process("user_parter_loop", self._user_parter_loop)
@defer.inlineCallbacks
def _user_parter_loop(self):
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 18741c5fac..7d67bf803a 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -43,6 +43,7 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.config = hs.config
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -80,42 +81,68 @@ class DirectoryHandler(BaseHandler):
)
@defer.inlineCallbacks
- def create_association(self, user_id, room_alias, room_id, servers=None):
- # association creation for human users
- # TODO(erikj): Do user auth.
+ def create_association(self, requester, room_alias, room_id, servers=None,
+ send_event=True):
+ """Attempt to create a new alias
- if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
- raise SynapseError(
- 403, "This user is not permitted to create this alias",
- )
+ Args:
+ requester (Requester)
+ room_alias (RoomAlias)
+ room_id (str)
+            servers (list[str]|None): List of servers that other servers
+ should try and join via
+ send_event (bool): Whether to send an updated m.room.aliases event
- can_create = yield self.can_modify_alias(
- room_alias,
- user_id=user_id
- )
- if not can_create:
- raise SynapseError(
- 400, "This alias is reserved by an application service.",
- errcode=Codes.EXCLUSIVE
- )
- yield self._create_association(room_alias, room_id, servers, creator=user_id)
+ Returns:
+ Deferred
+ """
- @defer.inlineCallbacks
- def create_appservice_association(self, service, room_alias, room_id,
- servers=None):
- if not service.is_interested_in_alias(room_alias.to_string()):
- raise SynapseError(
- 400, "This application service has not reserved"
- " this kind of alias.", errcode=Codes.EXCLUSIVE
+ user_id = requester.user.to_string()
+
+ service = requester.app_service
+ if service:
+ if not service.is_interested_in_alias(room_alias.to_string()):
+ raise SynapseError(
+ 400, "This application service has not reserved"
+ " this kind of alias.", errcode=Codes.EXCLUSIVE
+ )
+ else:
+ if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ raise AuthError(
+ 403, "This user is not permitted to create this alias",
+ )
+
+ if not self.config.is_alias_creation_allowed(user_id, room_alias.to_string()):
+            # Let's just return a generic message, as there may be all sorts of
+ # reasons why we said no. TODO: Allow configurable error messages
+ # per alias creation rule?
+ raise SynapseError(
+ 403, "Not allowed to create alias",
+ )
+
+ can_create = yield self.can_modify_alias(
+ room_alias,
+ user_id=user_id
)
+ if not can_create:
+ raise AuthError(
+ 400, "This alias is reserved by an application service.",
+ errcode=Codes.EXCLUSIVE
+ )
- # association creation for app services
- yield self._create_association(room_alias, room_id, servers)
+ yield self._create_association(room_alias, room_id, servers, creator=user_id)
+ if send_event:
+ yield self.send_room_alias_update_event(
+ requester,
+ room_id
+ )
@defer.inlineCallbacks
- def delete_association(self, requester, user_id, room_alias):
+ def delete_association(self, requester, room_alias):
# association deletion for human users
+ user_id = requester.user.to_string()
+
try:
can_delete = yield self._user_can_delete_alias(room_alias, user_id)
except StoreError as e:
@@ -143,7 +170,6 @@ class DirectoryHandler(BaseHandler):
try:
yield self.send_room_alias_update_event(
requester,
- requester.user.to_string(),
room_id
)
@@ -261,7 +287,7 @@ class DirectoryHandler(BaseHandler):
)
@defer.inlineCallbacks
- def send_room_alias_update_event(self, requester, user_id, room_id):
+ def send_room_alias_update_event(self, requester, room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
yield self.event_creation_handler.create_and_send_nonmember_event(
@@ -270,7 +296,7 @@ class DirectoryHandler(BaseHandler):
"type": EventTypes.Aliases,
"state_key": self.hs.hostname,
"room_id": room_id,
- "sender": user_id,
+ "sender": requester.user.to_string(),
"content": {"aliases": aliases},
},
ratelimit=False
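A sketch of the unified call path: application services and ordinary users now go through the same method, passing a Requester rather than a bare user_id (the handler, requester and room_id are assumed to come from the surrounding code):

    from twisted.internet import defer

    from synapse.types import RoomAlias

    @defer.inlineCallbacks
    def add_alias(directory_handler, requester, room_id, hostname):
        yield directory_handler.create_association(
            requester=requester,
            room_alias=RoomAlias.from_string("#general:" + hostname),
            room_id=room_id,
            servers=[hostname],
            send_event=True,   # also send the updated m.room.aliases event
        )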
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index cab57a8849..cd5b9bbb19 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -53,7 +53,7 @@ from synapse.replication.http.federation import (
ReplicationFederationSendEventsRestServlet,
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
-from synapse.state import resolve_events_with_factory
+from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.types import UserID, get_domain_from_id
from synapse.util import logcontext, unwrapFirstError
from synapse.util.async_helpers import Linearizer
@@ -384,24 +384,24 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
- # Resolve any conflicting state
- @defer.inlineCallbacks
- def fetch(ev_ids):
- fetched = yield self.store.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- )
- # add any events we fetch here to the `event_map` so that we
- # can use them to build the state event list below.
- event_map.update(fetched)
- defer.returnValue(fetched)
-
room_version = yield self.store.get_room_version(room_id)
- state_map = yield resolve_events_with_factory(
- room_version, state_maps, event_map, fetch,
+ state_map = yield resolve_events_with_store(
+ room_version, state_maps, event_map,
+ state_res_store=StateResolutionStore(self.store),
)
- # we need to give _process_received_pdu the actual state events
+ # We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now.
+
+ # First though we need to fetch all the events that are in
+ # state_map, so we can build up the state below.
+ evs = yield self.store.get_events(
+ list(state_map.values()),
+ get_prev_content=False,
+ check_redacted=False,
+ )
+ event_map.update(evs)
+
state = [
event_map[e] for e in six.itervalues(state_map)
]
@@ -2520,7 +2520,7 @@ class FederationHandler(BaseHandler):
if not backfilled: # Never notify for backfilled events
for event, _ in event_and_contexts:
- self._notify_persisted_event(event, max_stream_id)
+ yield self._notify_persisted_event(event, max_stream_id)
def _notify_persisted_event(self, event, max_stream_id):
"""Checks to see if notifier/pushers should be notified about the
@@ -2553,7 +2553,7 @@ class FederationHandler(BaseHandler):
extra_users=extra_users
)
- self.pusher_pool.on_new_notifications(
+ return self.pusher_pool.on_new_notifications(
event_stream_id, max_stream_id,
)
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 53e5e2648b..173315af6c 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -20,7 +20,7 @@ from six import iteritems
from twisted.internet import defer
-from synapse.api.errors import SynapseError
+from synapse.api.errors import HttpResponseException, SynapseError
from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__)
@@ -37,9 +37,23 @@ def _create_rerouter(func_name):
)
else:
destination = get_domain_from_id(group_id)
- return getattr(self.transport_client, func_name)(
+ d = getattr(self.transport_client, func_name)(
destination, group_id, *args, **kwargs
)
+
+ # Capture errors returned by the remote homeserver and
+ # re-throw specific errors as SynapseErrors. This is so
+ # when the remote end responds with things like 403 Not
+ # In Group, we can communicate that to the client instead
+ # of a 500.
+ def h(failure):
+ failure.trap(HttpResponseException)
+ e = failure.value
+ if e.code == 403:
+ raise e.to_synapse_error()
+ return failure
+ d.addErrback(h)
+ return d
return f
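A self-contained sketch of the trap-and-convert pattern above; failure.trap() re-raises anything that is not an HttpResponseException, so only remote HTTP errors reach the 403 check:

    from twisted.internet import defer

    from synapse.api.errors import HttpResponseException

    def handle_remote_error(failure):
        failure.trap(HttpResponseException)
        e = failure.value
        if e.code == 403:
            raise e.to_synapse_error()
        return failure

    # Simulate a remote 403 and convert it into a SynapseError for the client.
    d = defer.fail(HttpResponseException(
        403, "Forbidden", b'{"errcode": "M_FORBIDDEN", "error": "Not in group"}'
    ))
    d.addErrback(handle_remote_error)
    d.addErrback(lambda f: print("converted to:", f.value))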
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e009395207..563bb3cea3 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -156,7 +156,7 @@ class InitialSyncHandler(BaseHandler):
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background(
self.store.get_state_for_events,
- [event.event_id], None,
+ [event.event_id],
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
@@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler):
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events(
- [member_event_id], None
+ [member_event_id],
)
room_state = room_state[member_event_id]
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4954b23a0d..969e588e73 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -35,6 +35,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
@@ -80,7 +81,7 @@ class MessageHandler(object):
elif membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
- [membership_event_id], [key]
+ [membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
@@ -88,7 +89,7 @@ class MessageHandler(object):
@defer.inlineCallbacks
def get_state_events(
- self, user_id, room_id, types=None, filtered_types=None,
+ self, user_id, room_id, state_filter=StateFilter.all(),
at_token=None, is_guest=False,
):
"""Retrieve all state events for a given room. If the user is
@@ -100,13 +101,8 @@ class MessageHandler(object):
Args:
user_id(str): The user requesting state events.
room_id(str): The room ID to get all state events from.
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
at_token(StreamToken|None): the stream token of the point at which we are requesting
the state. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current
@@ -139,7 +135,7 @@ class MessageHandler(object):
event = last_events[0]
if visible_events:
room_state = yield self.store.get_state_for_events(
- [event.event_id], types, filtered_types=filtered_types,
+ [event.event_id], state_filter=state_filter,
)
room_state = room_state[event.event_id]
else:
@@ -158,12 +154,12 @@ class MessageHandler(object):
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
- room_id, types, filtered_types=filtered_types,
+ room_id, state_filter=state_filter,
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
- [membership_event_id], types, filtered_types=filtered_types,
+ [membership_event_id], state_filter=state_filter,
)
room_state = room_state[membership_event_id]
@@ -779,7 +775,7 @@ class EventCreationHandler(object):
event, context=context
)
- self.pusher_pool.on_new_notifications(
+ yield self.pusher_pool.on_new_notifications(
event_stream_id, max_stream_id,
)
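The new StateFilter object replaces the old (types, filtered_types) pair throughout these call sites; a rough sketch of the constructors used in this patch:

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    everything = StateFilter.all()

    # only the named (type, state_key) pairs
    name_and_alias = StateFilter.from_types([
        (EventTypes.Name, ''),
        (EventTypes.CanonicalAlias, ''),
    ])

    # all non-member state, plus membership events for just these users
    lazy_members = StateFilter.from_lazy_load_member_list([
        "@alice:example.com", "@bob:example.com",
    ])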
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index a155b6e938..43f81bd607 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -21,6 +21,7 @@ from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
+from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.logcontext import run_in_background
@@ -255,16 +256,14 @@ class PaginationHandler(object):
if event_filter and event_filter.lazy_load_members():
# TODO: remove redundant members
- types = [
- (EventTypes.Member, state_key)
- for state_key in set(
- event.sender # FIXME: we also care about invite targets etc.
- for event in events
- )
- ]
+ # FIXME: we also care about invite targets etc.
+ state_filter = StateFilter.from_types(
+ (EventTypes.Member, event.sender)
+ for event in events
+ )
state_ids = yield self.store.get_state_ids_for_event(
- events[0].event_id, types=types,
+ events[0].event_id, state_filter=state_filter,
)
if state_ids:
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a6f3181f09..4c2690ba26 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -119,7 +119,7 @@ class ReceiptsHandler(BaseHandler):
"receipt_key", max_batch_id, rooms=affected_room_ids
)
# Note that the min here shouldn't be relied upon to be accurate.
- self.hs.get_pusherpool().on_new_receipts(
+ yield self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids,
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index da914c46ff..e9d7b25a36 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -50,6 +50,7 @@ class RegistrationHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self.room_creation_handler = self.hs.get_room_creation_handler()
self.captcha_client = CaptchaServerHttpClient(hs)
self._next_generated_user_id = None
@@ -220,9 +221,36 @@ class RegistrationHandler(BaseHandler):
# auto-join the user to any rooms we're supposed to dump them into
fake_requester = create_requester(user_id)
+
+ # try to create the room if we're the first user on the server
+ should_auto_create_rooms = False
+ if self.hs.config.autocreate_auto_join_rooms:
+ count = yield self.store.count_all_users()
+ should_auto_create_rooms = count == 1
+
for r in self.hs.config.auto_join_rooms:
try:
- yield self._join_user_to_room(fake_requester, r)
+ if should_auto_create_rooms:
+ room_alias = RoomAlias.from_string(r)
+ if self.hs.hostname != room_alias.domain:
+ logger.warning(
+ 'Cannot create room alias %s, '
+ 'it does not match server domain',
+ r,
+ )
+ else:
+ # create room expects the localpart of the room alias
+ room_alias_localpart = room_alias.localpart
+ yield self.room_creation_handler.create_room(
+ fake_requester,
+ config={
+ "preset": "public_chat",
+ "room_alias_name": room_alias_localpart
+ },
+ ratelimit=False,
+ )
+ else:
+ yield self._join_user_to_room(fake_requester, r)
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index c3f820b975..3ba92bdb4c 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -33,6 +33,7 @@ from synapse.api.constants import (
RoomCreationPreset,
)
from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
+from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
from synapse.visibility import filter_events_for_client
@@ -190,10 +191,11 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
directory_handler = self.hs.get_handlers().directory_handler
yield directory_handler.create_association(
- user_id=user_id,
+ requester=requester,
room_id=room_id,
room_alias=room_alias,
servers=[self.hs.hostname],
+ send_event=False,
)
preset_config = config.get(
@@ -289,7 +291,7 @@ class RoomCreationHandler(BaseHandler):
if room_alias:
result["room_alias"] = room_alias.to_string()
yield directory_handler.send_room_alias_update_event(
- requester, user_id, room_id
+ requester, room_id
)
defer.returnValue(result)
@@ -488,23 +490,24 @@ class RoomContextHandler(object):
else:
last_event_id = event_id
- types = None
- filtered_types = None
if event_filter and event_filter.lazy_load_members():
- members = set(ev.sender for ev in itertools.chain(
- results["events_before"],
- (results["event"],),
- results["events_after"],
- ))
- filtered_types = [EventTypes.Member]
- types = [(EventTypes.Member, member) for member in members]
+ state_filter = StateFilter.from_lazy_load_member_list(
+ ev.sender
+ for ev in itertools.chain(
+ results["events_before"],
+ (results["event"],),
+ results["events_after"],
+ )
+ )
+ else:
+ state_filter = StateFilter.all()
# XXX: why do we return the state as of the last event rather than the
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events(
- [last_event_id], types, filtered_types=filtered_types,
+ [last_event_id], state_filter=state_filter,
)
results["state"] = list(state[last_event_id].values())
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 351892a94f..09739f2862 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -27,6 +27,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
+from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
@@ -469,25 +470,20 @@ class SyncHandler(object):
))
@defer.inlineCallbacks
- def get_state_after_event(self, event, types=None, filtered_types=None):
+ def get_state_after_event(self, event, state_filter=StateFilter.all()):
"""
Get the room state after the given event
Args:
event(synapse.events.EventBase): event of interest
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
A Deferred map from ((type, state_key)->Event)
"""
state_ids = yield self.store.get_state_ids_for_event(
- event.event_id, types, filtered_types=filtered_types,
+ event.event_id, state_filter=state_filter,
)
if event.is_state():
state_ids = state_ids.copy()
@@ -495,18 +491,14 @@ class SyncHandler(object):
defer.returnValue(state_ids)
@defer.inlineCallbacks
- def get_state_at(self, room_id, stream_position, types=None, filtered_types=None):
+ def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
""" Get the room state at a particular stream position
Args:
room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
A Deferred map from ((type, state_key)->Event)
@@ -522,7 +514,7 @@ class SyncHandler(object):
if last_events:
last_event = last_events[-1]
state = yield self.get_state_after_event(
- last_event, types, filtered_types=filtered_types,
+ last_event, state_filter=state_filter,
)
else:
@@ -563,10 +555,11 @@ class SyncHandler(object):
last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event(
- last_event.event_id, [
+ last_event.event_id,
+ state_filter=StateFilter.from_types([
(EventTypes.Name, ''),
(EventTypes.CanonicalAlias, ''),
- ]
+ ]),
)
# this is heavily cached, thus: fast.
@@ -717,8 +710,7 @@ class SyncHandler(object):
with Measure(self.clock, "compute_state_delta"):
- types = None
- filtered_types = None
+ members_to_fetch = None
lazy_load_members = sync_config.filter_collection.lazy_load_members()
include_redundant_members = (
@@ -729,16 +721,21 @@ class SyncHandler(object):
# We only request state for the members needed to display the
# timeline:
- types = [
- (EventTypes.Member, state_key)
- for state_key in set(
- event.sender # FIXME: we also care about invite targets etc.
- for event in batch.events
- )
- ]
+ members_to_fetch = set(
+ event.sender # FIXME: we also care about invite targets etc.
+ for event in batch.events
+ )
+
+ if full_state:
+ # always make sure we LL ourselves so we know we're in the room
+ # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
+                # We only need to apply this on full state syncs given we disabled
+ # LL for incr syncs in #3840.
+ members_to_fetch.add(sync_config.user.to_string())
- # only apply the filtering to room members
- filtered_types = [EventTypes.Member]
+ state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
+ else:
+ state_filter = StateFilter.all()
timeline_state = {
(event.type, event.state_key): event.event_id
@@ -746,28 +743,19 @@ class SyncHandler(object):
}
if full_state:
- if lazy_load_members:
- # always make sure we LL ourselves so we know we're in the room
- # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
- # We only need apply this on full state syncs given we disabled
- # LL for incr syncs in #3840.
- types.append((EventTypes.Member, sync_config.user.to_string()))
-
if batch:
current_state_ids = yield self.store.get_state_ids_for_event(
- batch.events[-1].event_id, types=types,
- filtered_types=filtered_types,
+ batch.events[-1].event_id, state_filter=state_filter,
)
state_ids = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, types=types,
- filtered_types=filtered_types,
+ batch.events[0].event_id, state_filter=state_filter,
)
else:
current_state_ids = yield self.get_state_at(
- room_id, stream_position=now_token, types=types,
- filtered_types=filtered_types,
+ room_id, stream_position=now_token,
+ state_filter=state_filter,
)
state_ids = current_state_ids
@@ -781,8 +769,7 @@ class SyncHandler(object):
)
elif batch.limited:
state_at_timeline_start = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, types=types,
- filtered_types=filtered_types,
+ batch.events[0].event_id, state_filter=state_filter,
)
# for now, we disable LL for gappy syncs - see
@@ -797,17 +784,15 @@ class SyncHandler(object):
# members to just be ones which were timeline senders, which then ensures
# all of the rest get included in the state block (if we need to know
# about them).
- types = None
- filtered_types = None
+ state_filter = StateFilter.all()
state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token, types=types,
- filtered_types=filtered_types,
+ room_id, stream_position=since_token,
+ state_filter=state_filter,
)
current_state_ids = yield self.store.get_state_ids_for_event(
- batch.events[-1].event_id, types=types,
- filtered_types=filtered_types,
+ batch.events[-1].event_id, state_filter=state_filter,
)
state_ids = _calculate_state(
@@ -821,7 +806,7 @@ class SyncHandler(object):
else:
state_ids = {}
if lazy_load_members:
- if types and batch.events:
+ if members_to_fetch and batch.events:
# We're returning an incremental sync, with no
# "gap" since the previous sync, so normally there would be
# no state to return.
@@ -831,8 +816,12 @@ class SyncHandler(object):
# timeline here, and then dedupe any redundant ones below.
state_ids = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, types=types,
- filtered_types=None, # we only want members!
+ batch.events[0].event_id,
+ # we only want members!
+ state_filter=StateFilter.from_types(
+ (EventTypes.Member, member)
+ for member in members_to_fetch
+ ),
)
if lazy_load_members and not include_redundant_members:
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d8413d6aa7..f11b430126 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -20,6 +20,7 @@ from six import iteritems
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.roommember import ProfileInfo
from synapse.types import get_localpart_from_id
from synapse.util.metrics import Measure
@@ -98,7 +99,6 @@ class UserDirectoryHandler(object):
"""
return self.store.search_user_dir(user_id, search_term, limit)
- @defer.inlineCallbacks
def notify_new_event(self):
"""Called when there may be more deltas to process
"""
@@ -108,11 +108,15 @@ class UserDirectoryHandler(object):
if self._is_processing:
return
+ @defer.inlineCallbacks
+ def process():
+ try:
+ yield self._unsafe_process()
+ finally:
+ self._is_processing = False
+
self._is_processing = True
- try:
- yield self._unsafe_process()
- finally:
- self._is_processing = False
+ run_as_background_process("user_directory.notify_new_event", process)
@defer.inlineCallbacks
def handle_local_profile_change(self, user_id, profile):
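The same wake-up pattern now appears in several places (here, the email pusher, the account deactivation handler): a cheap synchronous notify method guards with a flag and hands the real work to run_as_background_process, which gives the loop its own logcontext and metrics. A generic hedged sketch:

    from twisted.internet import defer

    from synapse.metrics.background_process_metrics import run_as_background_process

    class Worker(object):
        def __init__(self):
            self._is_processing = False

        def notify_new_event(self):
            # cheap and synchronous: just kick off processing if idle
            if self._is_processing:
                return

            @defer.inlineCallbacks
            def process():
                try:
                    yield self._unsafe_process()
                finally:
                    self._is_processing = False

            self._is_processing = True
            run_as_background_process("worker.notify_new_event", process)

        @defer.inlineCallbacks
        def _unsafe_process(self):
            yield defer.succeed(None)  # real work goes here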
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index fcc02fc77d..24b6110c20 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -230,7 +230,7 @@ class MatrixFederationHttpClient(object):
Returns:
Deferred: resolves with the http response object on success.
- Fails with ``HTTPRequestException``: if we get an HTTP response
+ Fails with ``HttpResponseException``: if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
@@ -480,7 +480,7 @@ class MatrixFederationHttpClient(object):
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
- Fails with ``HTTPRequestException`` if we get an HTTP response
+ Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
@@ -534,7 +534,7 @@ class MatrixFederationHttpClient(object):
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
- Fails with ``HTTPRequestException`` if we get an HTTP response
+ Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
@@ -589,7 +589,7 @@ class MatrixFederationHttpClient(object):
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
- Fails with ``HTTPRequestException`` if we get an HTTP response
+ Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
@@ -640,7 +640,7 @@ class MatrixFederationHttpClient(object):
Deferred: Succeeds when we get a 2xx HTTP response. The result
will be the decoded JSON body.
- Fails with ``HTTPRequestException`` if we get an HTTP response
+ Fails with ``HttpResponseException`` if we get an HTTP response
code >= 300.
Fails with ``NotRetryingDestination`` if we are not yet ready
@@ -684,7 +684,7 @@ class MatrixFederationHttpClient(object):
Deferred: resolves with an (int,dict) tuple of the file length and
a dict of the response headers.
- Fails with ``HTTPRequestException`` if we get an HTTP response code
+ Fails with ``HttpResponseException`` if we get an HTTP response code
>= 300
Fails with ``NotRetryingDestination`` if we are not yet ready
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index d746371420..f369124258 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -18,8 +18,7 @@ import logging
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
-from synapse.util.logcontext import LoggingContext
-from synapse.util.metrics import Measure
+from synapse.metrics.background_process_metrics import run_as_background_process
logger = logging.getLogger(__name__)
@@ -71,18 +70,11 @@ class EmailPusher(object):
# See httppusher
self.max_stream_ordering = None
- self.processing = False
+ self._is_processing = False
- @defer.inlineCallbacks
def on_started(self):
if self.mailer is not None:
- try:
- self.throttle_params = yield self.store.get_throttle_params_by_room(
- self.pusher_id
- )
- yield self._process()
- except Exception:
- logger.exception("Error starting email pusher")
+ self._start_processing()
def on_stop(self):
if self.timed_call:
@@ -92,43 +84,52 @@ class EmailPusher(object):
pass
self.timed_call = None
- @defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
- yield self._process()
+ self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id):
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire
- return defer.succeed(None)
+ pass
- @defer.inlineCallbacks
def on_timer(self):
self.timed_call = None
- yield self._process()
+ self._start_processing()
+
+ def _start_processing(self):
+ if self._is_processing:
+ return
+
+ run_as_background_process("emailpush.process", self._process)
@defer.inlineCallbacks
def _process(self):
- if self.processing:
- return
+ # we should never get here if we are already processing
+ assert not self._is_processing
+
+ try:
+ self._is_processing = True
+
+ if self.throttle_params is None:
+ # this is our first loop: load up the throttle params
+ self.throttle_params = yield self.store.get_throttle_params_by_room(
+ self.pusher_id
+ )
- with LoggingContext("emailpush._process"):
- with Measure(self.clock, "emailpush._process"):
+ # if the max ordering changes while we're running _unsafe_process,
+ # call it again, and so on until we've caught up.
+ while True:
+ starting_max_ordering = self.max_stream_ordering
try:
- self.processing = True
- # if the max ordering changes while we're running _unsafe_process,
- # call it again, and so on until we've caught up.
- while True:
- starting_max_ordering = self.max_stream_ordering
- try:
- yield self._unsafe_process()
- except Exception:
- logger.exception("Exception processing notifs")
- if self.max_stream_ordering == starting_max_ordering:
- break
- finally:
- self.processing = False
+ yield self._unsafe_process()
+ except Exception:
+ logger.exception("Exception processing notifs")
+ if self.max_stream_ordering == starting_max_ordering:
+ break
+ finally:
+ self._is_processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
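Both the email pusher above and the HTTP pusher below now share the same shape: a cheap, synchronous _start_processing that refuses to start a second loop, and a background _process task that re-runs _unsafe_process until no new work has arrived in the meantime. A stripped-down sketch of that pattern (class and method bodies simplified; not the actual implementation):

    from twisted.internet import defer

    from synapse.metrics.background_process_metrics import run_as_background_process


    class SketchPusher(object):
        def __init__(self):
            self._is_processing = False
            self.max_stream_ordering = 0

        def _start_processing(self):
            # Synchronous entry point: never start a second processing loop.
            if self._is_processing:
                return
            run_as_background_process("sketchpush.process", self._process)

        @defer.inlineCallbacks
        def _process(self):
            assert not self._is_processing
            try:
                self._is_processing = True
                # Keep looping until nothing new arrived while we were busy.
                while True:
                    starting_max_ordering = self.max_stream_ordering
                    yield self._unsafe_process()
                    if self.max_stream_ordering == starting_max_ordering:
                        break
            finally:
                self._is_processing = False

        @defer.inlineCallbacks
        def _unsafe_process(self):
            # Placeholder for the real work (sending out notifications).
            yield defer.succeed(None)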
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 48abd5e4d6..6bd703632d 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -22,9 +22,8 @@ from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
-from synapse.util.logcontext import LoggingContext
-from synapse.util.metrics import Measure
from . import push_rule_evaluator, push_tools
@@ -61,7 +60,7 @@ class HttpPusher(object):
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusherdict['failing_since']
self.timed_call = None
- self.processing = False
+ self._is_processing = False
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
@@ -92,34 +91,27 @@ class HttpPusher(object):
self.data_minus_url.update(self.data)
del self.data_minus_url['url']
- @defer.inlineCallbacks
def on_started(self):
- try:
- yield self._process()
- except Exception:
- logger.exception("Error starting http pusher")
+ self._start_processing()
- @defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0)
- yield self._process()
+ self._start_processing()
- @defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id):
# Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway...
- with LoggingContext("push.on_new_receipts"):
- with Measure(self.clock, "push.on_new_receipts"):
- badge = yield push_tools.get_badge_count(
- self.hs.get_datastore(), self.user_id
- )
- yield self._send_badge(badge)
+ run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
@defer.inlineCallbacks
+ def _update_badge(self):
+ badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
+ yield self._send_badge(badge)
+
def on_timer(self):
- yield self._process()
+ self._start_processing()
def on_stop(self):
if self.timed_call:
@@ -129,27 +121,31 @@ class HttpPusher(object):
pass
self.timed_call = None
+ def _start_processing(self):
+ if self._is_processing:
+ return
+
+ run_as_background_process("httppush.process", self._process)
+
@defer.inlineCallbacks
def _process(self):
- if self.processing:
- return
+ # we should never get here if we are already processing
+ assert not self._is_processing
- with LoggingContext("push._process"):
- with Measure(self.clock, "push._process"):
+ try:
+ self._is_processing = True
+ # if the max ordering changes while we're running _unsafe_process,
+ # call it again, and so on until we've caught up.
+ while True:
+ starting_max_ordering = self.max_stream_ordering
try:
- self.processing = True
- # if the max ordering changes while we're running _unsafe_process,
- # call it again, and so on until we've caught up.
- while True:
- starting_max_ordering = self.max_stream_ordering
- try:
- yield self._unsafe_process()
- except Exception:
- logger.exception("Exception processing notifs")
- if self.max_stream_ordering == starting_max_ordering:
- break
- finally:
- self.processing = False
+ yield self._unsafe_process()
+ except Exception:
+ logger.exception("Exception processing notifs")
+ if self.max_stream_ordering == starting_max_ordering:
+ break
+ finally:
+ self._is_processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index b9dcfee740..16fb5e8471 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -526,12 +526,8 @@ def load_jinja2_templates(config):
Returns:
(notif_template_html, notif_template_text)
"""
- logger.info("loading jinja2")
-
- if config.email_template_dir:
- loader = jinja2.FileSystemLoader(config.email_template_dir)
- else:
- loader = jinja2.PackageLoader('synapse', 'res/templates')
+ logger.info("loading email templates from '%s'", config.email_template_dir)
+ loader = jinja2.FileSystemLoader(config.email_template_dir)
env = jinja2.Environment(loader=loader)
env.filters["format_ts"] = format_ts_filter
env.filters["mxc_to_http"] = _create_mxc_to_http_filter(config)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 9f7d5ef217..5a4e73ccd6 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -20,24 +20,39 @@ from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push.pusher import PusherFactory
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
class PusherPool:
+ """
+ The pusher pool. This is responsible for dispatching notifications of new events to
+ the http and email pushers.
+
+ It provides three methods which are designed to be called by the rest of the
+ application: `start`, `on_new_notifications`, and `on_new_receipts`; each of these
+ delegates to the relevant pushers.
+
+ Note that it is expected that each pusher will have its own 'processing' loop which
+ will send out the notifications in the background, rather than blocking until the
+ notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and
+ Pusher.on_new_receipts are not expected to return deferreds.
+ """
def __init__(self, _hs):
self.hs = _hs
self.pusher_factory = PusherFactory(_hs)
- self.start_pushers = _hs.config.start_pushers
+ self._should_start_pushers = _hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pushers = {}
- @defer.inlineCallbacks
def start(self):
- pushers = yield self.store.get_all_pushers()
- self._start_pushers(pushers)
+ """Starts the pushers off in a background process.
+ """
+ if not self._should_start_pushers:
+ logger.info("Not starting pushers because they are disabled in the config")
+ return
+ run_as_background_process("start_pushers", self._start_pushers)
@defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id,
@@ -86,7 +101,7 @@ class PusherPool:
last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag,
)
- yield self._refresh_pusher(app_id, pushkey, user_id)
+ yield self.start_pusher_by_id(app_id, pushkey, user_id)
@defer.inlineCallbacks
def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
@@ -123,45 +138,23 @@ class PusherPool:
p['app_id'], p['pushkey'], p['user_name'],
)
- def on_new_notifications(self, min_stream_id, max_stream_id):
- run_as_background_process(
- "on_new_notifications",
- self._on_new_notifications, min_stream_id, max_stream_id,
- )
-
@defer.inlineCallbacks
- def _on_new_notifications(self, min_stream_id, max_stream_id):
+ def on_new_notifications(self, min_stream_id, max_stream_id):
try:
users_affected = yield self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id
)
- deferreds = []
-
for u in users_affected:
if u in self.pushers:
for p in self.pushers[u].values():
- deferreds.append(
- run_in_background(
- p.on_new_notifications,
- min_stream_id, max_stream_id,
- )
- )
-
- yield make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True),
- )
+ p.on_new_notifications(min_stream_id, max_stream_id)
+
except Exception:
logger.exception("Exception in pusher on_new_notifications")
- def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
- run_as_background_process(
- "on_new_receipts",
- self._on_new_receipts, min_stream_id, max_stream_id, affected_room_ids,
- )
-
@defer.inlineCallbacks
- def _on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+ def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
@@ -171,26 +164,20 @@ class PusherPool:
# This returns a tuple, user_id is at index 3
users_affected = set([r[3] for r in updated_receipts])
- deferreds = []
-
for u in users_affected:
if u in self.pushers:
for p in self.pushers[u].values():
- deferreds.append(
- run_in_background(
- p.on_new_receipts,
- min_stream_id, max_stream_id,
- )
- )
-
- yield make_deferred_yieldable(
- defer.gatherResults(deferreds, consumeErrors=True),
- )
+ p.on_new_receipts(min_stream_id, max_stream_id)
+
except Exception:
logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks
- def _refresh_pusher(self, app_id, pushkey, user_id):
+ def start_pusher_by_id(self, app_id, pushkey, user_id):
+ """Look up the details for the given pusher, and start it"""
+ if not self._should_start_pushers:
+ return
+
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
app_id, pushkey
)
@@ -201,33 +188,49 @@ class PusherPool:
p = r
if p:
+ self._start_pusher(p)
- self._start_pushers([p])
+ @defer.inlineCallbacks
+ def _start_pushers(self):
+ """Start all the pushers
- def _start_pushers(self, pushers):
- if not self.start_pushers:
- logger.info("Not starting pushers because they are disabled in the config")
- return
+ Returns:
+ Deferred
+ """
+ pushers = yield self.store.get_all_pushers()
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
- try:
- p = self.pusher_factory.create_pusher(pusherdict)
- except Exception:
- logger.exception("Couldn't start a pusher: caught Exception")
- continue
- if p:
- appid_pushkey = "%s:%s" % (
- pusherdict['app_id'],
- pusherdict['pushkey'],
- )
- byuser = self.pushers.setdefault(pusherdict['user_name'], {})
+ self._start_pusher(pusherdict)
+ logger.info("Started pushers")
- if appid_pushkey in byuser:
- byuser[appid_pushkey].on_stop()
- byuser[appid_pushkey] = p
- run_in_background(p.on_started)
+ def _start_pusher(self, pusherdict):
+ """Start the given pusher
- logger.info("Started pushers")
+ Args:
+ pusherdict (dict):
+
+ Returns:
+ None
+ """
+ try:
+ p = self.pusher_factory.create_pusher(pusherdict)
+ except Exception:
+ logger.exception("Couldn't start a pusher: caught Exception")
+ return
+
+ if not p:
+ return
+
+ appid_pushkey = "%s:%s" % (
+ pusherdict['app_id'],
+ pusherdict['pushkey'],
+ )
+ byuser = self.pushers.setdefault(pusherdict['user_name'], {})
+
+ if appid_pushkey in byuser:
+ byuser[appid_pushkey].on_stop()
+ byuser[appid_pushkey] = p
+ p.on_started()
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id):
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index f51184b50d..943876456b 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -53,6 +53,7 @@ REQUIREMENTS = {
"pillow>=3.1.2": ["PIL"],
"pydenticon>=0.2": ["pydenticon"],
"sortedcontainers>=1.4.4": ["sortedcontainers"],
+ "psutil>=2.0.0": ["psutil>=2.0.0"],
"pysaml2>=3.0.0": ["saml2"],
"pymacaroons-pynacl>=0.9.3": ["pymacaroons"],
"msgpack-python>=0.4.2": ["msgpack"],
@@ -79,9 +80,6 @@ CONDITIONAL_REQUIREMENTS = {
"matrix-synapse-ldap3": {
"matrix-synapse-ldap3>=0.1": ["ldap_auth_provider"],
},
- "psutil": {
- "psutil>=2.0.0": ["psutil>=2.0.0"],
- },
"postgres": {
"psycopg2>=2.6": ["psycopg2"]
}
diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 97733f3026..0220acf644 100644
--- a/synapse/rest/client/v1/directory.py
+++ b/synapse/rest/client/v1/directory.py
@@ -74,38 +74,11 @@ class ClientDirectoryServer(ClientV1RestServlet):
if room is None:
raise SynapseError(400, "Room does not exist")
- dir_handler = self.handlers.directory_handler
+ requester = yield self.auth.get_user_by_req(request)
- try:
- # try to auth as a user
- requester = yield self.auth.get_user_by_req(request)
- try:
- user_id = requester.user.to_string()
- yield dir_handler.create_association(
- user_id, room_alias, room_id, servers
- )
- yield dir_handler.send_room_alias_update_event(
- requester,
- user_id,
- room_id
- )
- except SynapseError as e:
- raise e
- except Exception:
- logger.exception("Failed to create association")
- raise
- except AuthError:
- # try to auth as an application service
- service = yield self.auth.get_appservice_by_req(request)
- yield dir_handler.create_appservice_association(
- service, room_alias, room_id, servers
- )
- logger.info(
- "Application service at %s created alias %s pointing to %s",
- service.url,
- room_alias.to_string(),
- room_id
- )
+ yield self.handlers.directory_handler.create_association(
+ requester, room_alias, room_id, servers
+ )
defer.returnValue((200, {}))
@@ -135,7 +108,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
room_alias = RoomAlias.from_string(room_alias)
yield dir_handler.delete_association(
- requester, user.to_string(), room_alias
+ requester, room_alias
)
logger.info(
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 663934efd0..fcfe7857f6 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -33,6 +33,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
@@ -409,7 +410,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
room_id=room_id,
user_id=requester.user.to_string(),
at_token=at_token,
- types=[(EventTypes.Member, None)],
+ state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
)
chunk = []
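Note that here the state_key is None, which per the StateFilter docstring later in this patch means "all m.room.member events", in contrast to the per-user member filters built in the sync handler. A one-line sketch, purely illustrative:

    # All membership events in the room, regardless of user.
    members_only = StateFilter.from_types([(EventTypes.Member, None)])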
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index bd8b5f4afa..693b303881 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -99,7 +99,7 @@ class AuthRestServlet(RestServlet):
cannot be handled in the normal flow (with requests to the same endpoint).
Current use is for web fallback auth.
"""
- PATTERNS = client_v2_patterns("/auth/(?P<stagetype>[\w\.]*)/fallback/web")
+ PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
def __init__(self, hs):
super(AuthRestServlet, self).__init__()
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index a828ff4438..08b1867fab 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -25,7 +25,7 @@ from six.moves.urllib import parse as urlparse
import twisted.internet.error
import twisted.web.http
-from twisted.internet import defer, threads
+from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -36,8 +36,8 @@ from synapse.api.errors import (
)
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util import logcontext
from synapse.util.async_helpers import Linearizer
-from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
from synapse.util.stringutils import is_ascii, random_string
@@ -492,10 +492,11 @@ class MediaRepository(object):
))
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+ t_byte_source = yield logcontext.defer_to_thread(
+ self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer, t_width, t_height, t_method, t_type
- ))
+ )
if t_byte_source:
try:
@@ -534,10 +535,11 @@ class MediaRepository(object):
))
thumbnailer = Thumbnailer(input_path)
- t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+ t_byte_source = yield logcontext.defer_to_thread(
+ self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer, t_width, t_height, t_method, t_type
- ))
+ )
if t_byte_source:
try:
@@ -620,15 +622,17 @@ class MediaRepository(object):
for (t_width, t_height, t_type), t_method in iteritems(thumbnails):
# Generate the thumbnail
if t_method == "crop":
- t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+ t_byte_source = yield logcontext.defer_to_thread(
+ self.hs.get_reactor(),
thumbnailer.crop,
t_width, t_height, t_type,
- ))
+ )
elif t_method == "scale":
- t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
+ t_byte_source = yield logcontext.defer_to_thread(
+ self.hs.get_reactor(),
thumbnailer.scale,
t_width, t_height, t_type,
- ))
+ )
else:
logger.error("Unrecognized method: %r", t_method)
continue
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index a6189224ee..896078fe76 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -21,9 +21,10 @@ import sys
import six
-from twisted.internet import defer, threads
+from twisted.internet import defer
from twisted.protocols.basic import FileSender
+from synapse.util import logcontext
from synapse.util.file_consumer import BackgroundFileConsumer
from synapse.util.logcontext import make_deferred_yieldable
@@ -64,9 +65,10 @@ class MediaStorage(object):
with self.store_into_file(file_info) as (f, fname, finish_cb):
# Write to the main repository
- yield make_deferred_yieldable(threads.deferToThread(
+ yield logcontext.defer_to_thread(
+ self.hs.get_reactor(),
_write_file_synchronously, source, f,
- ))
+ )
yield finish_cb()
defer.returnValue(fname)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 8c892ff187..1a7bfd6b56 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -674,7 +674,7 @@ def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
# This splits the paragraph into words, but keeping the
# (preceeding) whitespace intact so we can easily concat
# words back together.
- for match in re.finditer("\s*\S+", description):
+ for match in re.finditer(r"\s*\S+", description):
word = match.group()
# Keep adding words while the total length is less than
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 7b9f8b4d79..5aa03031f6 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -17,9 +17,10 @@ import logging
import os
import shutil
-from twisted.internet import defer, threads
+from twisted.internet import defer
from synapse.config._base import Config
+from synapse.util import logcontext
from synapse.util.logcontext import run_in_background
from .media_storage import FileResponder
@@ -120,7 +121,8 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return threads.deferToThread(
+ return logcontext.defer_to_thread(
+ self.hs.get_reactor(),
shutil.copyfile, primary_fname, backup_fname,
)
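The media changes above all apply the same substitution: make_deferred_yieldable(threads.deferToThread(f, ...)) becomes logcontext.defer_to_thread(reactor, f, ...), which still runs f on the reactor's thread pool but keeps the logcontext handling in one place. A hedged sketch of the call shape, with an invented blocking function standing in for the real work:

    from twisted.internet import defer

    from synapse.util import logcontext


    def do_expensive_work(path):
        # Stand-in for blocking work such as resizing an image.
        return path


    @defer.inlineCallbacks
    def generate_thumbnail(reactor, input_path):
        # Before: yield make_deferred_yieldable(
        #     threads.deferToThread(do_expensive_work, input_path))
        result = yield logcontext.defer_to_thread(
            reactor, do_expensive_work, input_path,
        )
        defer.returnValue(result)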
diff --git a/synapse/server.py b/synapse/server.py
index 3e9d3d8256..cf6b872cbd 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -207,6 +207,7 @@ class HomeServer(object):
logger.info("Setting up.")
with self.get_db_conn() as conn:
self.datastore = self.DATASTORE_CLASS(conn, self)
+ conn.commit()
logger.info("Finished setting up.")
def get_reactor(self):
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index b22495c1f9..9b40b18d5b 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -19,13 +19,14 @@ from collections import namedtuple
from six import iteritems, itervalues
+import attr
from frozendict import frozendict
from twisted.internet import defer
from synapse.api.constants import EventTypes, RoomVersions
from synapse.events.snapshot import EventContext
-from synapse.state import v1
+from synapse.state import v1, v2
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
@@ -372,15 +373,10 @@ class StateHandler(object):
result = yield self._state_resolution_handler.resolve_state_groups(
room_id, room_version, state_groups_ids, None,
- self._state_map_factory,
+ state_res_store=StateResolutionStore(self.store),
)
defer.returnValue(result)
- def _state_map_factory(self, ev_ids):
- return self.store.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- )
-
@defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
logger.info(
@@ -398,10 +394,10 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events_with_factory(
+ new_state = yield resolve_events_with_store(
room_version, state_set_ids,
event_map=state_map,
- state_map_factory=self._state_map_factory
+ state_res_store=StateResolutionStore(self.store),
)
new_state = {
@@ -436,7 +432,7 @@ class StateResolutionHandler(object):
@defer.inlineCallbacks
@log_function
def resolve_state_groups(
- self, room_id, room_version, state_groups_ids, event_map, state_map_factory,
+ self, room_id, room_version, state_groups_ids, event_map, state_res_store,
):
"""Resolves conflicts between a set of state groups
@@ -454,9 +450,11 @@ class StateResolutionHandler(object):
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
- events will be requested via state_map_factory.
+ events will be requested via state_res_store.
+
+ If None, all events will be fetched via state_res_store.
- If None, all events will be fetched via state_map_factory.
+ state_res_store (StateResolutionStore)
Returns:
Deferred[_StateCacheEntry]: resolved state
@@ -480,10 +478,10 @@ class StateResolutionHandler(object):
# start by assuming we won't have any conflicted state, and build up the new
# state map by iterating through the state groups. If we discover a conflict,
- # we give up and instead use `resolve_events_with_factory`.
+ # we give up and instead use `resolve_events_with_store`.
#
# XXX: is this actually worthwhile, or should we just let
- # resolve_events_with_factory do it?
+ # resolve_events_with_store do it?
new_state = {}
conflicted_state = False
for st in itervalues(state_groups_ids):
@@ -498,11 +496,11 @@ class StateResolutionHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events_with_factory(
+ new_state = yield resolve_events_with_store(
room_version,
list(itervalues(state_groups_ids)),
event_map=event_map,
- state_map_factory=state_map_factory,
+ state_res_store=state_res_store,
)
# if the new state matches any of the input state groups, we can
@@ -583,7 +581,7 @@ def _make_state_cache_entry(
)
-def resolve_events_with_factory(room_version, state_sets, event_map, state_map_factory):
+def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
"""
Args:
room_version(str): Version of the room
@@ -599,17 +597,19 @@ def resolve_events_with_factory(room_version, state_sets, event_map, state_map_f
If None, all events will be fetched via state_map_factory.
- state_map_factory(func): will be called
- with a list of event_ids that are needed, and should return with
- a Deferred of dict of event_id to event.
+ state_res_store (StateResolutionStore)
Returns
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
- if room_version in (RoomVersions.V1, RoomVersions.VDH_TEST,):
- return v1.resolve_events_with_factory(
- state_sets, event_map, state_map_factory,
+ if room_version == RoomVersions.V1:
+ return v1.resolve_events_with_store(
+ state_sets, event_map, state_res_store.get_events,
+ )
+ elif room_version == RoomVersions.VDH_TEST:
+ return v2.resolve_events_with_store(
+ state_sets, event_map, state_res_store,
)
else:
# This should only happen if we added a version but forgot to add it to
@@ -617,3 +617,54 @@ def resolve_events_with_factory(room_version, state_sets, event_map, state_map_f
raise Exception(
"No state resolution algorithm defined for version %r" % (room_version,)
)
+
+
+@attr.s
+class StateResolutionStore(object):
+ """Interface that allows state resolution algorithms to access the database
+ in a well-defined way.
+
+ Args:
+ store (DataStore)
+ """
+
+ store = attr.ib()
+
+ def get_events(self, event_ids, allow_rejected=False):
+ """Get events from the database
+
+ Args:
+ event_ids (list): The event_ids of the events to fetch
+ allow_rejected (bool): If True return rejected events.
+
+ Returns:
+ Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+ """
+
+ return self.store.get_events(
+ event_ids,
+ check_redacted=False,
+ get_prev_content=False,
+ allow_rejected=allow_rejected,
+ )
+
+ def get_auth_chain(self, event_ids):
+ """Gets the full auth chain for a set of events (including rejected
+ events).
+
+ Includes the given event IDs in the result.
+
+ Note that:
+ 1. All events must be state events.
+ 2. For v1 rooms this may not have the full auth chain in the
+ presence of rejected events
+
+ Args:
+ event_ids (list): The event IDs of the events to fetch the auth
+ chain for. Must be state events.
+
+ Returns:
+ Deferred[list[str]]: List of event IDs of the auth chain.
+ """
+
+ return self.store.get_auth_chain_ids(event_ids, include_given=True)
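To show how this interface is meant to be consumed, here is a hedged sketch of a resolution step pulling missing events and an auth chain through the store; the event IDs are invented, and the real call sites are in synapse/state/v2.py below:

    from twisted.internet import defer


    @defer.inlineCallbacks
    def fetch_missing(event_map, state_res_store):
        # Hypothetical event IDs, for illustration only.
        missing = ["$event1:example.org", "$event2:example.org"]

        # Fetch the events we don't already have, including rejected ones,
        # and add them to the shared event_map.
        events = yield state_res_store.get_events(missing, allow_rejected=True)
        event_map.update(events)

        # The auth chain lookup includes the given IDs in its result.
        auth_chain = yield state_res_store.get_auth_chain(missing)
        defer.returnValue(auth_chain)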
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 7a7157b352..70a981f4a2 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -31,7 +31,7 @@ POWER_KEY = (EventTypes.PowerLevels, "")
@defer.inlineCallbacks
-def resolve_events_with_factory(state_sets, event_map, state_map_factory):
+def resolve_events_with_store(state_sets, event_map, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
new file mode 100644
index 0000000000..5d06f7e928
--- /dev/null
+++ b/synapse/state/v2.py
@@ -0,0 +1,544 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import heapq
+import itertools
+import logging
+
+from six import iteritems, itervalues
+
+from twisted.internet import defer
+
+from synapse import event_auth
+from synapse.api.constants import EventTypes
+from synapse.api.errors import AuthError
+
+logger = logging.getLogger(__name__)
+
+
+@defer.inlineCallbacks
+def resolve_events_with_store(state_sets, event_map, state_res_store):
+ """Resolves the state using the v2 state resolution algorithm
+
+ Args:
+ state_sets(list): List of dicts of (type, state_key) -> event_id,
+ which are the different state groups to resolve.
+
+ event_map(dict[str,FrozenEvent]|None):
+ a dict from event_id to event, for any events that we happen to
+ have in flight (eg, those currently being persisted). This will be
+ used as a starting point for finding the state we need; any missing
+ events will be requested via state_res_store.
+
+ If None, all events will be fetched via state_res_store.
+
+ state_res_store (StateResolutionStore)
+
+ Returns
+ Deferred[dict[(str, str), str]]:
+ a map from (type, state_key) to event_id.
+ """
+
+ logger.debug("Computing conflicted state")
+
+ # First split up the un/conflicted state
+ unconflicted_state, conflicted_state = _seperate(state_sets)
+
+ if not conflicted_state:
+ defer.returnValue(unconflicted_state)
+
+ logger.debug("%d conflicted state entries", len(conflicted_state))
+ logger.debug("Calculating auth chain difference")
+
+ # Also fetch all auth events that appear in only some of the state sets'
+ # auth chains.
+ auth_diff = yield _get_auth_chain_difference(
+ state_sets, event_map, state_res_store,
+ )
+
+ full_conflicted_set = set(itertools.chain(
+ itertools.chain.from_iterable(itervalues(conflicted_state)),
+ auth_diff,
+ ))
+
+ events = yield state_res_store.get_events([
+ eid for eid in full_conflicted_set
+ if eid not in event_map
+ ], allow_rejected=True)
+ event_map.update(events)
+
+ full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
+
+ logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
+
+ # Get and sort all the power events (kicks/bans/etc)
+ power_events = (
+ eid for eid in full_conflicted_set
+ if _is_power_event(event_map[eid])
+ )
+
+ sorted_power_events = yield _reverse_topological_power_sort(
+ power_events,
+ event_map,
+ state_res_store,
+ full_conflicted_set,
+ )
+
+ logger.debug("sorted %d power events", len(sorted_power_events))
+
+ # Now sequentially auth each one
+ resolved_state = yield _iterative_auth_checks(
+ sorted_power_events, unconflicted_state, event_map,
+ state_res_store,
+ )
+
+ logger.debug("resolved power events")
+
+ # OK, so we've now resolved the power events. Now sort the remaining
+ # events using the mainline of the resolved power level.
+
+ leftover_events = [
+ ev_id
+ for ev_id in full_conflicted_set
+ if ev_id not in sorted_power_events
+ ]
+
+ logger.debug("sorting %d remaining events", len(leftover_events))
+
+ pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
+ leftover_events = yield _mainline_sort(
+ leftover_events, pl, event_map, state_res_store,
+ )
+
+ logger.debug("resolving remaining events")
+
+ resolved_state = yield _iterative_auth_checks(
+ leftover_events, resolved_state, event_map,
+ state_res_store,
+ )
+
+ logger.debug("resolved")
+
+ # We make sure that unconflicted state always still applies.
+ resolved_state.update(unconflicted_state)
+
+ logger.debug("done")
+
+ defer.returnValue(resolved_state)
+
+
+@defer.inlineCallbacks
+def _get_power_level_for_sender(event_id, event_map, state_res_store):
+ """Return the power level of the sender of the given event according to
+ their auth events.
+
+ Args:
+ event_id (str)
+ event_map (dict[str,FrozenEvent])
+ state_res_store (StateResolutionStore)
+
+ Returns:
+ Deferred[int]
+ """
+ event = yield _get_event(event_id, event_map, state_res_store)
+
+ pl = None
+ for aid, _ in event.auth_events:
+ aev = yield _get_event(aid, event_map, state_res_store)
+ if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
+ pl = aev
+ break
+
+ if pl is None:
+ # Couldn't find power level. Check if they're the creator of the room
+ for aid, _ in event.auth_events:
+ aev = yield _get_event(aid, event_map, state_res_store)
+ if (aev.type, aev.state_key) == (EventTypes.Create, ""):
+ if aev.content.get("creator") == event.sender:
+ defer.returnValue(100)
+ break
+ defer.returnValue(0)
+
+ level = pl.content.get("users", {}).get(event.sender)
+ if level is None:
+ level = pl.content.get("users_default", 0)
+
+ if level is None:
+ defer.returnValue(0)
+ else:
+ defer.returnValue(int(level))
+
+
+@defer.inlineCallbacks
+def _get_auth_chain_difference(state_sets, event_map, state_res_store):
+ """Compare the auth chains of each state set and return the set of events
+ that only appear in some but not all of the auth chains.
+
+ Args:
+ state_sets (list)
+ event_map (dict[str,FrozenEvent])
+ state_res_store (StateResolutionStore)
+
+ Returns:
+ Deferred[set[str]]: Set of event IDs
+ """
+ common = set(itervalues(state_sets[0])).intersection(
+ *(itervalues(s) for s in state_sets[1:])
+ )
+
+ auth_sets = []
+ for state_set in state_sets:
+ auth_ids = set(
+ eid
+ for key, eid in iteritems(state_set)
+ if (key[0] in (
+ EventTypes.Member,
+ EventTypes.ThirdPartyInvite,
+ ) or key in (
+ (EventTypes.PowerLevels, ''),
+ (EventTypes.Create, ''),
+ (EventTypes.JoinRules, ''),
+ )) and eid not in common
+ )
+
+ auth_chain = yield state_res_store.get_auth_chain(auth_ids)
+ auth_ids.update(auth_chain)
+
+ auth_sets.append(auth_ids)
+
+ intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
+ union = set().union(*auth_sets)
+
+ defer.returnValue(union - intersection)
+
+
+def _seperate(state_sets):
+ """Return the unconflicted and conflicted state. This is different than in
+ the original algorithm, as this defines a key to be conflicted if one of
+ the state sets doesn't have that key.
+
+ Args:
+ state_sets (list)
+
+ Returns:
+ tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
+ conflicted state dict is a map from type/state_key to set of event IDs
+ """
+ unconflicted_state = {}
+ conflicted_state = {}
+
+ for key in set(itertools.chain.from_iterable(state_sets)):
+ event_ids = set(state_set.get(key) for state_set in state_sets)
+ if len(event_ids) == 1:
+ unconflicted_state[key] = event_ids.pop()
+ else:
+ event_ids.discard(None)
+ conflicted_state[key] = event_ids
+
+ return unconflicted_state, conflicted_state
+
+
+def _is_power_event(event):
+ """Return whether or not the event is a "power event", as defined by the
+ v2 state resolution algorithm
+
+ Args:
+ event (FrozenEvent)
+
+ Returns:
+ boolean
+ """
+ if (event.type, event.state_key) in (
+ (EventTypes.PowerLevels, ""),
+ (EventTypes.JoinRules, ""),
+ (EventTypes.Create, ""),
+ ):
+ return True
+
+ if event.type == EventTypes.Member:
+ if event.membership in ('leave', 'ban'):
+ return event.sender != event.state_key
+
+ return False
+
+
+@defer.inlineCallbacks
+def _add_event_and_auth_chain_to_graph(graph, event_id, event_map,
+ state_res_store, auth_diff):
+ """Helper function for _reverse_topological_power_sort that add the event
+ and its auth chain (that is in the auth diff) to the graph
+
+ Args:
+ graph (dict[str, set[str]]): A map from event ID to the event's auth
+ event IDs
+ event_id (str): Event to add to the graph
+ event_map (dict[str,FrozenEvent])
+ state_res_store (StateResolutionStore)
+ auth_diff (set[str]): Set of event IDs that are in the auth difference.
+ """
+
+ state = [event_id]
+ while state:
+ eid = state.pop()
+ graph.setdefault(eid, set())
+
+ event = yield _get_event(eid, event_map, state_res_store)
+ for aid, _ in event.auth_events:
+ if aid in auth_diff:
+ if aid not in graph:
+ state.append(aid)
+
+ graph.setdefault(eid, set()).add(aid)
+
+
+@defer.inlineCallbacks
+def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff):
+ """Returns a list of the event_ids sorted by reverse topological ordering,
+ and then by power level and origin_server_ts
+
+ Args:
+ event_ids (list[str]): The events to sort
+ event_map (dict[str,FrozenEvent])
+ state_res_store (StateResolutionStore)
+ auth_diff (set[str]): Set of event IDs that are in the auth difference.
+
+ Returns:
+ Deferred[list[str]]: The sorted list
+ """
+
+ graph = {}
+ for event_id in event_ids:
+ yield _add_event_and_auth_chain_to_graph(
+ graph, event_id, event_map, state_res_store, auth_diff,
+ )
+
+ event_to_pl = {}
+ for event_id in graph:
+ pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store)
+ event_to_pl[event_id] = pl
+
+ def _get_power_order(event_id):
+ ev = event_map[event_id]
+ pl = event_to_pl[event_id]
+
+ return -pl, ev.origin_server_ts, event_id
+
+ # Note: graph is modified during the sort
+ it = lexicographical_topological_sort(
+ graph,
+ key=_get_power_order,
+ )
+ sorted_events = list(it)
+
+ defer.returnValue(sorted_events)
+
+
+@defer.inlineCallbacks
+def _iterative_auth_checks(event_ids, base_state, event_map, state_res_store):
+ """Sequentially apply auth checks to each event in given list, updating the
+ state as it goes along.
+
+ Args:
+ event_ids (list[str]): Ordered list of events to apply auth checks to
+ base_state (dict[tuple[str, str], str]): The set of state to start with
+ event_map (dict[str,FrozenEvent])
+ state_res_store (StateResolutionStore)
+
+ Returns:
+ Deferred[dict[tuple[str, str], str]]: Returns the final updated state
+ """
+ resolved_state = base_state.copy()
+
+ for event_id in event_ids:
+ event = event_map[event_id]
+
+ auth_events = {}
+ for aid, _ in event.auth_events:
+ ev = yield _get_event(aid, event_map, state_res_store)
+
+ if ev.rejected_reason is None:
+ auth_events[(ev.type, ev.state_key)] = ev
+
+ for key in event_auth.auth_types_for_event(event):
+ if key in resolved_state:
+ ev_id = resolved_state[key]
+ ev = yield _get_event(ev_id, event_map, state_res_store)
+
+ if ev.rejected_reason is None:
+ auth_events[key] = event_map[ev_id]
+
+ try:
+ event_auth.check(
+ event, auth_events,
+ do_sig_check=False,
+ do_size_check=False
+ )
+
+ resolved_state[(event.type, event.state_key)] = event_id
+ except AuthError:
+ pass
+
+ defer.returnValue(resolved_state)
+
+
+@defer.inlineCallbacks
+def _mainline_sort(event_ids, resolved_power_event_id, event_map,
+ state_res_store):
+ """Returns a sorted list of event_ids sorted by mainline ordering based on
+ the given event resolved_power_event_id
+
+ Args:
+ event_ids (list[str]): Events to sort
+ resolved_power_event_id (str): The final resolved power level event ID
+ event_map (dict[str,FrozenEvent])
+ state_res_store (StateResolutionStore)
+
+ Returns:
+ Deferred[list[str]]: The sorted list
+ """
+ mainline = []
+ pl = resolved_power_event_id
+ while pl:
+ mainline.append(pl)
+ pl_ev = yield _get_event(pl, event_map, state_res_store)
+ auth_events = pl_ev.auth_events
+ pl = None
+ for aid, _ in auth_events:
+ ev = yield _get_event(aid, event_map, state_res_store)
+ if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
+ pl = aid
+ break
+
+ mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
+
+ event_ids = list(event_ids)
+
+ order_map = {}
+ for ev_id in event_ids:
+ depth = yield _get_mainline_depth_for_event(
+ event_map[ev_id], mainline_map,
+ event_map, state_res_store,
+ )
+ order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
+
+ event_ids.sort(key=lambda ev_id: order_map[ev_id])
+
+ defer.returnValue(event_ids)
+
+
+@defer.inlineCallbacks
+def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
+ """Get the mainline depths for the given event based on the mainline map
+
+ Args:
+ event (FrozenEvent)
+ mainline_map (dict[str, int]): Map from event_id to mainline depth for
+ events in the mainline.
+ event_map (dict[str,FrozenEvent])
+ state_res_store (StateResolutionStore)
+
+ Returns:
+ Deferred[int]
+ """
+
+ # We do an iterative search, replacing `event` with the power level in its
+ # auth events (if any)
+ while event:
+ depth = mainline_map.get(event.event_id)
+ if depth is not None:
+ defer.returnValue(depth)
+
+ auth_events = event.auth_events
+ event = None
+
+ for aid, _ in auth_events:
+ aev = yield _get_event(aid, event_map, state_res_store)
+ if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
+ event = aev
+ break
+
+ # Didn't find a power level auth event, so we just return 0
+ defer.returnValue(0)
+
+
+@defer.inlineCallbacks
+def _get_event(event_id, event_map, state_res_store):
+ """Helper function to look up event in event_map, falling back to looking
+ it up in the store
+
+ Args:
+ event_id (str)
+ event_map (dict[str,FrozenEvent])
+ state_res_store (StateResolutionStore)
+
+ Returns:
+ Deferred[FrozenEvent]
+ """
+ if event_id not in event_map:
+ events = yield state_res_store.get_events([event_id], allow_rejected=True)
+ event_map.update(events)
+ defer.returnValue(event_map[event_id])
+
+
+def lexicographical_topological_sort(graph, key):
+ """Performs a lexicographic reverse topological sort on the graph.
+
+ This returns a reverse topological sort (i.e. if node A references B then B
+ appears before A in the sort), with ties broken lexicographically based on
+ return value of the `key` function.
+
+ NOTE: `graph` is modified during the sort.
+
+ Args:
+ graph (dict[str, set[str]]): A representation of the graph where each
+ node is a key in the dict and its value is the set of the node's edges.
+ key (func): A function that takes a node and returns a value that is
+ comparable and used to order nodes
+
+ Yields:
+ str: The next node in the topological sort
+ """
+
+ # Note, this is basically Kahn's algorithm except we look at nodes with no
+ # outgoing edges, c.f.
+ # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
+ outdegree_map = graph
+ reverse_graph = {}
+
+ # List of nodes with zero out degree. Each entry is actually a tuple of
+ # `(key(node), node)` so that sorting does the right thing
+ zero_outdegree = []
+
+ for node, edges in iteritems(graph):
+ if len(edges) == 0:
+ zero_outdegree.append((key(node), node))
+
+ reverse_graph.setdefault(node, set())
+ for edge in edges:
+ reverse_graph.setdefault(edge, set()).add(node)
+
+ # heapq is a built-in implementation of a sorted queue.
+ heapq.heapify(zero_outdegree)
+
+ while zero_outdegree:
+ _, node = heapq.heappop(zero_outdegree)
+
+ for parent in reverse_graph[node]:
+ out = outdegree_map[parent]
+ out.discard(node)
+ if len(out) == 0:
+ heapq.heappush(zero_outdegree, (key(parent), parent))
+
+ yield node
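Because the sort is a generator over a graph that it mutates as it goes, a tiny worked example may help; the node names are invented:

    # Edges point at dependencies, so dependencies are yielded first.
    graph = {
        "A": {"B", "C"},
        "B": {"C"},
        "C": set(),
    }

    print(list(lexicographical_topological_sort(graph, key=lambda node: node)))
    # ['C', 'B', 'A'] -- and the edge sets in `graph` have been drained.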
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index be61147b9b..d9d0255d0b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -18,7 +18,7 @@ import threading
import time
from six import PY2, iteritems, iterkeys, itervalues
-from six.moves import intern, range
+from six.moves import builtins, intern, range
from canonicaljson import json
from prometheus_client import Histogram
@@ -1233,7 +1233,7 @@ def db_to_json(db_content):
# psycopg2 on Python 2 returns buffer objects, which we need to cast to
# bytes to decode
- if PY2 and isinstance(db_content, buffer):
+ if PY2 and isinstance(db_content, builtins.buffer):
db_content = bytes(db_content)
# Decode it to a Unicode string before feeding it to json.loads, so we
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index cfb687cb53..61a029a53c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -90,7 +90,7 @@ class DirectoryWorkerStore(SQLBaseStore):
class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
- """ Creates an associatin between a room alias and room_id/servers
+ """ Creates an association between a room alias and room_id/servers
Args:
room_alias (RoomAlias)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 379b5c514f..e791922733 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -34,6 +34,7 @@ from synapse.api.errors import SynapseError
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.state import StateResolutionStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
@@ -733,11 +734,6 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# Ok, we need to defer to the state handler to resolve our state sets.
- def get_events(ev_ids):
- return self.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- )
-
state_groups = {
sg: state_groups_map[sg] for sg in new_state_groups
}
@@ -747,7 +743,8 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
- room_id, room_version, state_groups, events_map, get_events
+ room_id, room_version, state_groups, events_map,
+ state_res_store=StateResolutionStore(self)
)
defer.returnValue((res.state, None))
@@ -856,6 +853,27 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
# Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts)
+ # We want to store event_auth mappings for rejected events, as they're
+ # used in state res v2.
+ # This is only necessary if the rejected event appears in an accepted
+ # event's auth chain, but it's easier for now just to store them (and
+ # it doesn't take much storage compared to storing the entire event
+ # anyway).
+ self._simple_insert_many_txn(
+ txn,
+ table="event_auth",
+ values=[
+ {
+ "event_id": event.event_id,
+ "room_id": event.room_id,
+ "auth_id": auth_id,
+ }
+ for event, _ in events_and_contexts
+ for auth_id, _ in event.auth_events
+ if event.is_state()
+ ],
+ )
+
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
events_and_contexts = self._store_rejected_events_txn(
@@ -1331,21 +1349,6 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
txn, event.room_id, event.redacts
)
- self._simple_insert_many_txn(
- txn,
- table="event_auth",
- values=[
- {
- "event_id": event.event_id,
- "room_id": event.room_id,
- "auth_id": auth_id,
- }
- for event, _ in events_and_contexts
- for auth_id, _ in event.auth_events
- if event.is_state()
- ],
- )
-
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
self._handle_mult_prev_events(
@@ -2068,7 +2071,7 @@ class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore
for sg in remaining_state_groups:
logger.info("[purge] de-delta-ing remaining state group %s", sg)
curr_state = self._get_state_groups_from_groups_txn(
- txn, [sg], types=None
+ txn, [sg],
)
curr_state = curr_state[sg]
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index a1331c1a61..8af17921e3 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
- db_binary_type = buffer
+ db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 0fe8c8e24c..cf4104dc2e 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -33,19 +33,29 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
self.reserved_users = ()
+ # Do not add more reserved users than the total allowable number
+ self._initialise_reserved_users(
+ dbconn.cursor(),
+ hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value],
+ )
- @defer.inlineCallbacks
- def initialise_reserved_users(self, threepids):
- store = self.hs.get_datastore()
+ def _initialise_reserved_users(self, txn, threepids):
+ """Ensures that reserved threepids are accounted for in the MAU table, should
+ be called on start up.
+
+ Args:
+ txn (cursor):
+ threepids (list[dict]): List of threepid dicts to reserve
+ """
reserved_user_list = []
- # Do not add more reserved users than the total allowable number
- for tp in threepids[:self.hs.config.max_mau_value]:
- user_id = yield store.get_user_id_by_threepid(
+ for tp in threepids:
+ user_id = self.get_user_id_by_threepid_txn(
+ txn,
tp["medium"], tp["address"]
)
if user_id:
- yield self.upsert_monthly_active_user(user_id)
+ self.upsert_monthly_active_user_txn(txn, user_id)
reserved_user_list.append(user_id)
else:
logger.warning(
@@ -55,8 +65,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks
def reap_monthly_active_users(self):
- """
- Cleans out monthly active user table to ensure that no stale
+ """Cleans out monthly active user table to ensure that no stale
entries exist.
Returns:
@@ -165,19 +174,44 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
+ """Updates or inserts the user into the monthly active user table, which
+ is used to track the current MAU usage of the server
+
+ Args:
+ user_id (str): user to add/update
"""
- Updates or inserts monthly active user member
- Arguments:
- user_id (str): user to add/update
- Deferred[bool]: True if a new entry was created, False if an
- existing one was updated.
+ is_insert = yield self.runInteraction(
+ "upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
+ user_id
+ )
+
+ if is_insert:
+ self.user_last_seen_monthly_active.invalidate((user_id,))
+ self.get_monthly_active_count.invalidate(())
+
+ def upsert_monthly_active_user_txn(self, txn, user_id):
+ """Updates or inserts monthly active user member
+
+ Note that, after calling this method, it will generally be necessary
+ to invalidate the caches on user_last_seen_monthly_active and
+ get_monthly_active_count. We can't do that here, because we are running
+ in a database thread rather than the main thread, and we can't call
+ txn.call_after because txn may not be a LoggingTransaction.
+
+ Args:
+ txn (cursor):
+ user_id (str): user to add/update
+
+ Returns:
+ bool: True if a new entry was created, False if an
+ existing one was updated.
"""
# Am consciously deciding to lock the table on the basis that is ought
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
- is_insert = yield self._simple_upsert(
- desc="upsert_monthly_active_user",
+ is_insert = self._simple_upsert_txn(
+ txn,
table="monthly_active_users",
keyvalues={
"user_id": user_id,
@@ -186,9 +220,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
"timestamp": int(self._clock.time_msec()),
},
)
- if is_insert:
- self.user_last_seen_monthly_active.invalidate((user_id,))
- self.get_monthly_active_count.invalidate(())
+
+ return is_insert
@cached(num_args=1)
def user_last_seen_monthly_active(self, user_id):
diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py
index c7987bfcdd..2743b52bad 100644
--- a/synapse/storage/pusher.py
+++ b/synapse/storage/pusher.py
@@ -29,7 +29,7 @@ from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
if six.PY2:
- db_binary_type = buffer
+ db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 26b429e307..80d76bf9d7 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -474,17 +474,44 @@ class RegistrationStore(RegistrationWorkerStore,
@defer.inlineCallbacks
def get_user_id_by_threepid(self, medium, address):
- ret = yield self._simple_select_one(
+ """Returns user id from threepid
+
+ Args:
+ medium (str): threepid medium e.g. email
+ address (str): threepid address e.g. me@example.com
+
+ Returns:
+ Deferred[str|None]: user id or None if no user id/threepid mapping exists
+ """
+ user_id = yield self.runInteraction(
+ "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
+ medium, address
+ )
+ defer.returnValue(user_id)
+
+ def get_user_id_by_threepid_txn(self, txn, medium, address):
+ """Returns user id from threepid
+
+ Args:
+ txn (cursor):
+ medium (str): threepid medium e.g. email
+ address (str): threepid address e.g. me@example.com
+
+ Returns:
+ str|None: user id or None if no user id/threepid mapping exists
+ """
+ ret = self._simple_select_one_txn(
+ txn,
"user_threepids",
{
"medium": medium,
"address": address
},
- ['user_id'], True, 'get_user_id_by_threepid'
+ ['user_id'], True
)
if ret:
- defer.returnValue(ret['user_id'])
- defer.returnValue(None)
+ return ret['user_id']
+ return None
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
@@ -567,7 +594,7 @@ class RegistrationStore(RegistrationWorkerStore,
def _find_next_generated_user_id(txn):
txn.execute("SELECT name FROM users")
- regex = re.compile("^@(\d+):")
+ regex = re.compile(r"^@(\d+):")
found = set()
diff --git a/synapse/storage/signatures.py b/synapse/storage/signatures.py
index 5623391f6e..158e9dbe7b 100644
--- a/synapse/storage/signatures.py
+++ b/synapse/storage/signatures.py
@@ -27,7 +27,7 @@ from ._base import SQLBaseStore
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
- db_binary_type = buffer
+ db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 0f86311ed4..59a50a5df9 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -19,6 +19,8 @@ from collections import namedtuple
from six import iteritems, itervalues
from six.moves import range
+import attr
+
from twisted.internet import defer
from synapse.api.constants import EventTypes
@@ -48,6 +50,318 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
+@attr.s(slots=True)
+class StateFilter(object):
+ """A filter used when querying for state.
+
+ Attributes:
+ types (dict[str, set[str]|None]): Map from type to set of state keys (or
+ None). This specifies which state_keys for the given type to fetch
+ from the DB. If None then all events with that type are fetched. If
+ the set is empty then no events with that type are fetched.
+ include_others (bool): Whether to fetch events with types that do not
+ appear in `types`.
+ """
+
+ types = attr.ib()
+ include_others = attr.ib(default=False)
+
+ def __attrs_post_init__(self):
+ # If `include_others` is set we canonicalise the filter by removing
+ # wildcards from the types dictionary
+ if self.include_others:
+ self.types = {
+ k: v for k, v in iteritems(self.types)
+ if v is not None
+ }
+
+ @staticmethod
+ def all():
+ """Creates a filter that fetches everything.
+
+ Returns:
+ StateFilter
+ """
+ return StateFilter(types={}, include_others=True)
+
+ @staticmethod
+ def none():
+ """Creates a filter that fetches nothing.
+
+ Returns:
+ StateFilter
+ """
+ return StateFilter(types={}, include_others=False)
+
+ @staticmethod
+ def from_types(types):
+ """Creates a filter that only fetches the given types
+
+ Args:
+ types (Iterable[tuple[str, str|None]]): A list of type and state
+ keys to fetch. A state_key of None fetches everything for
+ that type
+
+ Returns:
+ StateFilter
+ """
+ type_dict = {}
+ for typ, s in types:
+ if typ in type_dict:
+ if type_dict[typ] is None:
+ continue
+
+ if s is None:
+ type_dict[typ] = None
+ continue
+
+ type_dict.setdefault(typ, set()).add(s)
+
+ return StateFilter(types=type_dict)
+
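As a quick illustration of the deduplication described above (a sketch, not part of the patch):

    from synapse.api.constants import EventTypes
    from synapse.storage.state import StateFilter

    f = StateFilter.from_types([
        (EventTypes.Create, ""),
        (EventTypes.Member, "@alice:example.com"),
        (EventTypes.Member, None),   # wildcard: fetch all member events
    ])
    # f.types == {"m.room.create": {""}, "m.room.member": None}
    # f.include_others is False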
+ @staticmethod
+ def from_lazy_load_member_list(members):
+ """Creates a filter that returns all non-member events, plus the member
+ events for the given users
+
+ Args:
+ members (iterable[str]): Set of user IDs
+
+ Returns:
+ StateFilter
+ """
+ return StateFilter(
+ types={EventTypes.Member: set(members)},
+ include_others=True,
+ )
+
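For example, in a lazy-loading-members scenario (sketch):

    f = StateFilter.from_lazy_load_member_list(
        ["@alice:example.com", "@bob:example.com"],
    )
    # Matches every non-member event, plus those two users' m.room.member events:
    # f.types == {"m.room.member": {"@alice:example.com", "@bob:example.com"}}
    # f.include_others is True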
+ def return_expanded(self):
+ """Creates a new StateFilter where type wild cards have been removed
+ (except for memberships). The returned filter is a superset of the
+ current one, i.e. anything that passes the current filter will pass
+ the returned filter.
+
+ This helps the caching as the DictionaryCache knows if it has *all* the
+ state, but does not know if it has all of the keys of a particular type,
+ which makes wildcard lookups expensive unless we have a complete cache.
+ Hence, if we are doing a wildcard lookup, populate the cache fully so
+ that we can do an efficient lookup next time.
+
+ Note that since we have two caches, one for membership events and one for
+ other events, we can be a bit more clever than simply returning
+ `StateFilter.all()` if `has_wildcards()` is True.
+
+ We return a StateFilter where:
+ 1. the list of membership events to return is the same
+ 2. if there is a wildcard that matches non-member events we
+ return all non-member events
+
+ Returns:
+ StateFilter
+ """
+
+ if self.is_full():
+ # If we're going to return everything then there's nothing to do
+ return self
+
+ if not self.has_wildcards():
+ # If there are no wild cards, there's nothing to do
+ return self
+
+ if EventTypes.Member in self.types:
+ get_all_members = self.types[EventTypes.Member] is None
+ else:
+ get_all_members = self.include_others
+
+ has_non_member_wildcard = self.include_others or any(
+ state_keys is None
+ for t, state_keys in iteritems(self.types)
+ if t != EventTypes.Member
+ )
+
+ if not has_non_member_wildcard:
+ # If there are no non-member wild cards we can just return ourselves
+ return self
+
+ if get_all_members:
+ # We want to return everything.
+ return StateFilter.all()
+ else:
+ # We want to return all non-members, but only particular
+ # memberships
+ return StateFilter(
+ types={EventTypes.Member: self.types[EventTypes.Member]},
+ include_others=True,
+ )
+
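Concretely, a filter with a non-member wildcard expands like this (sketch; the event types are chosen for illustration):

    f = StateFilter(
        types={EventTypes.Name: None, EventTypes.Member: {"@alice:example.com"}},
        include_others=False,
    )
    expanded = f.return_expanded()
    # The member part is unchanged, but all non-member state is now requested,
    # so the result can fully populate the non-member cache:
    # expanded.types == {"m.room.member": {"@alice:example.com"}}
    # expanded.include_others is True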
+ def make_sql_filter_clause(self):
+ """Converts the filter to an SQL clause.
+
+ For example:
+
+ f = StateFilter.from_types([("m.room.create", "")])
+ clause, args = f.make_sql_filter_clause()
+ clause == "(type = ? AND state_key = ?)"
+ args == ['m.room.create', '']
+
+
+ Returns:
+ tuple[str, list]: The SQL string (may be empty) and arguments. An
+ empty SQL string is returned when the filter matches everything
+ (i.e. is "full").
+ """
+
+ where_clause = ""
+ where_args = []
+
+ if self.is_full():
+ return where_clause, where_args
+
+ if not self.include_others and not self.types:
+ # i.e. this is an empty filter, so we need to return a clause that
+ # will match nothing
+ return "1 = 2", []
+
+ # First we build up a list of clauses for each type/state_key combo
+ clauses = []
+ for etype, state_keys in iteritems(self.types):
+ if state_keys is None:
+ clauses.append("(type = ?)")
+ where_args.append(etype)
+ continue
+
+ for state_key in state_keys:
+ clauses.append("(type = ? AND state_key = ?)")
+ where_args.extend((etype, state_key))
+
+ # This will match anything that appears in `self.types`
+ where_clause = " OR ".join(clauses)
+
+ # If we want to include stuff that's not in the types dict then we add
+ # an `OR type NOT IN (...)` clause to the end.
+ if self.include_others:
+ if where_clause:
+ where_clause += " OR "
+
+ where_clause += "type NOT IN (%s)" % (
+ ",".join(["?"] * len(self.types)),
+ )
+ where_args.extend(self.types)
+
+ return where_clause, where_args
+
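When `include_others` is set, the generated clause also excludes the listed types from the wildcard part (sketch):

    f = StateFilter(
        types={EventTypes.Member: {"@alice:example.com"}},
        include_others=True,
    )
    clause, args = f.make_sql_filter_clause()
    # clause == "(type = ? AND state_key = ?) OR type NOT IN (?)"
    # args == ["m.room.member", "@alice:example.com", "m.room.member"]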
+ def max_entries_returned(self):
+ """Returns the maximum number of entries this filter will return if
+ known, otherwise returns None.
+
+ For example a simple state filter asking for `("m.room.create", "")`
+ will return 1, whereas the default state filter will return None.
+
+ This is used to bail out early if the right number of entries have been
+ fetched.
+ """
+ if self.has_wildcards():
+ return None
+
+ return len(self.concrete_types())
+
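For instance (sketch):

    StateFilter.from_types([("m.room.create", ""), ("m.room.name", "")]).max_entries_returned()
    # -> 2
    StateFilter.all().max_entries_returned()
    # -> None: a wildcard filter has no known upper bound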
+ def filter_state(self, state_dict):
+ """Returns the state filtered with by this StateFilter
+
+ Args:
+ state_dict (dict[tuple[str, str], Any]): The state map to filter
+
+ Returns:
+ dict[tuple[str, str], Any]: The filtered state map
+ """
+ if self.is_full():
+ return dict(state_dict)
+
+ filtered_state = {}
+ for k, v in iteritems(state_dict):
+ typ, state_key = k
+ if typ in self.types:
+ state_keys = self.types[typ]
+ if state_keys is None or state_key in state_keys:
+ filtered_state[k] = v
+ elif self.include_others:
+ filtered_state[k] = v
+
+ return filtered_state
+
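For example (sketch; the event IDs are illustrative):

    f = StateFilter.from_types([(EventTypes.Member, "@alice:example.com")])
    state = {
        (EventTypes.Member, "@alice:example.com"): "$alice_member",
        (EventTypes.Member, "@bob:example.com"): "$bob_member",
        (EventTypes.Create, ""): "$create",
    }
    f.filter_state(state)
    # -> {("m.room.member", "@alice:example.com"): "$alice_member"}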
+ def is_full(self):
+ """Whether this filter fetches everything or not
+
+ Returns:
+ bool
+ """
+ return self.include_others and not self.types
+
+ def has_wildcards(self):
+ """Whether the filter includes wildcards or is attempting to fetch
+ specific state.
+
+ Returns:
+ bool
+ """
+
+ return (
+ self.include_others
+ or any(
+ state_keys is None
+ for state_keys in itervalues(self.types)
+ )
+ )
+
+ def concrete_types(self):
+ """Returns a list of concrete type/state_keys (i.e. not None) that
+ will be fetched. This will be a complete list if `has_wildcards`
+ returns False, but otherwise will be a subset (or even empty).
+
+ Returns:
+ list[tuple[str,str]]
+ """
+ return [
+ (t, s)
+ for t, state_keys in iteritems(self.types)
+ if state_keys is not None
+ for s in state_keys
+ ]
+
+ def get_member_split(self):
+ """Return the filter split into two: one which assumes it's exclusively
+ matching against member state, and one which assumes it's matching
+ against non-member state.
+
+ This is useful because the returned filters give correct results for
+ `is_full()`, `has_wildcards()`, etc, when operating against maps that
+ either exclusively contain member events or only contain non-member
+ events. (Which is the case when dealing with the member vs non-member
+ state caches).
+
+ Returns:
+ tuple[StateFilter, StateFilter]: The member and non-member filters
+ """
+
+ if EventTypes.Member in self.types:
+ state_keys = self.types[EventTypes.Member]
+ if state_keys is None:
+ member_filter = StateFilter.all()
+ else:
+ member_filter = StateFilter({EventTypes.Member: state_keys})
+ elif self.include_others:
+ member_filter = StateFilter.all()
+ else:
+ member_filter = StateFilter.none()
+
+ non_member_filter = StateFilter(
+ types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member},
+ include_others=self.include_others,
+ )
+
+ return member_filter, non_member_filter
+
+
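A lazy-load style filter, for instance, splits as follows (sketch):

    f = StateFilter(
        types={EventTypes.Member: {"@alice:example.com"}},
        include_others=True,
    )
    member_filter, non_member_filter = f.get_member_split()
    # member_filter: only @alice's m.room.member event, include_others=False
    # non_member_filter: equivalent to StateFilter.all() for non-member state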
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
@@ -152,61 +466,41 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
# FIXME: how should this be cached?
- def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
+ def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
+
Args:
room_id (str)
- types (list[(Str, (Str|None))]): List of (type, state_key) tuples
- which are used to filter the state fetched. `state_key` may be
- None, which matches any `state_key`
- filtered_types (list[Str]|None): List of types to apply the above filter to.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
Returns:
- deferred: dict of (type, state_key) -> event
+ Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
+ event ID.
"""
- include_other_types = False if filtered_types is None else True
-
def _get_filtered_current_state_ids_txn(txn):
results = {}
- sql = """SELECT type, state_key, event_id FROM current_state_events
- WHERE room_id = ? %s"""
- # Turns out that postgres doesn't like doing a list of OR's and
- # is about 1000x slower, so we just issue a query for each specific
- # type seperately.
- if types:
- clause_to_args = [
- (
- "AND type = ? AND state_key = ?",
- (etype, state_key)
- ) if state_key is not None else (
- "AND type = ?",
- (etype,)
- )
- for etype, state_key in types
- ]
-
- if include_other_types:
- unique_types = set(filtered_types)
- clause_to_args.append(
- (
- "AND type <> ? " * len(unique_types),
- list(unique_types)
- )
- )
- else:
- # If types is None we fetch all the state, and so just use an
- # empty where clause with no extra args.
- clause_to_args = [("", [])]
- for where_clause, where_args in clause_to_args:
- args = [room_id]
- args.extend(where_args)
- txn.execute(sql % (where_clause,), args)
- for row in txn:
- typ, state_key, event_id = row
- key = (intern_string(typ), intern_string(state_key))
- results[key] = event_id
+ sql = """
+ SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ?
+ """
+
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ if where_clause:
+ sql += " AND (%s)" % (where_clause,)
+
+ args = [room_id]
+ args.extend(where_args)
+ txn.execute(sql, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (intern_string(typ), intern_string(state_key))
+ results[key] = event_id
+
return results
return self.runInteraction(
@@ -322,20 +616,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
})
@defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, types, members=None):
+ def _get_state_groups_from_groups(self, groups, state_filter):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups(list[int]): list of state group IDs to query
- types (Iterable[str, str|None]|None): list of 2-tuples of the form
- (`type`, `state_key`), where a `state_key` of `None` matches all
- state_keys for the `type`. If None, all types are returned.
- members (bool|None): If not None, then, in addition to any filtering
- implied by types, the results are also filtered to only include
- member events (if True), or to exclude member events (if False)
-
- Returns:
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -346,19 +634,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn, chunk, types, members,
+ self._get_state_groups_from_groups_txn, chunk, state_filter,
)
results.update(res)
defer.returnValue(results)
def _get_state_groups_from_groups_txn(
- self, txn, groups, types=None, members=None,
+ self, txn, groups, state_filter=StateFilter.all(),
):
results = {group: {} for group in groups}
- if types is not None:
- types = list(set(types)) # deduplicate types list
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
@@ -374,79 +666,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
- sql = ("""
+ sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
- SELECT type, state_key, last_value(event_id) OVER (
+ SELECT DISTINCT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
- %s
- """)
-
- if members is True:
- sql += " AND type = '%s'" % (EventTypes.Member,)
- elif members is False:
- sql += " AND type <> '%s'" % (EventTypes.Member,)
-
- # Turns out that postgres doesn't like doing a list of OR's and
- # is about 1000x slower, so we just issue a query for each specific
- # type seperately.
- if types is not None:
- clause_to_args = [
- (
- "AND type = ? AND state_key = ?",
- (etype, state_key)
- ) if state_key is not None else (
- "AND type = ?",
- (etype,)
- )
- for etype, state_key in types
- ]
- else:
- # If types is None we fetch all the state, and so just use an
- # empty where clause with no extra args.
- clause_to_args = [("", [])]
+ """
- for where_clause, where_args in clause_to_args:
- for group in groups:
- args = [group]
- args.extend(where_args)
+ for group in groups:
+ args = [group]
+ args.extend(where_args)
- txn.execute(sql % (where_clause,), args)
- for row in txn:
- typ, state_key, event_id = row
- key = (typ, state_key)
- results[group][key] = event_id
+ txn.execute(sql + where_clause, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (typ, state_key)
+ results[group][key] = event_id
else:
- where_args = []
- where_clauses = []
- wildcard_types = False
- if types is not None:
- for typ in types:
- if typ[1] is None:
- where_clauses.append("(type = ?)")
- where_args.append(typ[0])
- wildcard_types = True
- else:
- where_clauses.append("(type = ? AND state_key = ?)")
- where_args.extend([typ[0], typ[1]])
-
- where_clause = "AND (%s)" % (" OR ".join(where_clauses))
- else:
- where_clause = ""
-
- if members is True:
- where_clause += " AND type = '%s'" % EventTypes.Member
- elif members is False:
- where_clause += " AND type <> '%s'" % EventTypes.Member
+ max_entries_returned = state_filter.max_entries_returned()
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
@@ -460,12 +706,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
- if types:
- args.extend(where_args)
+ args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
- " WHERE state_group = ? %s" % (where_clause,),
+ " WHERE state_group = ? " + where_clause,
args
)
results[group].update(
@@ -481,9 +726,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
- types is not None and
- not wildcard_types and
- len(results[group]) == len(types)
+ max_entries_returned is not None and
+ len(results[group]) == max_entries_returned
):
break
@@ -498,20 +742,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
@defer.inlineCallbacks
- def get_state_for_events(self, event_ids, types, filtered_types=None):
+ def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state
- dicts for each event. The state dicts will only have the type/state_keys
- that are in the `types` list.
+ dicts for each event.
Args:
event_ids (list[string])
- types (list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
@@ -521,7 +759,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
+ group_to_state = yield self._get_state_for_groups(groups, state_filter)
state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
@@ -540,20 +778,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
- def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
+ def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
@@ -563,7 +796,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
+ group_to_state = yield self._get_state_for_groups(groups, state_filter)
event_to_state = {
event_id: group_to_state[group]
@@ -573,45 +806,35 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
- def get_state_for_event(self, event_id, types=None, filtered_types=None):
+ def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
- state_map = yield self.get_state_for_events([event_id], types, filtered_types)
+ state_map = yield self.get_state_for_events([event_id], state_filter)
defer.returnValue(state_map[event_id])
@defer.inlineCallbacks
- def get_state_ids_for_event(self, event_id, types=None, filtered_types=None):
+ def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
- state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types)
+ state_map = yield self.get_state_ids_for_events([event_id], state_filter)
defer.returnValue(state_map[event_id])
@cached(max_entries=50000)
@@ -642,18 +865,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
- def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
+ def _get_state_for_group_using_cache(self, cache, group, state_filter):
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
cache(DictionaryCache): the state group cache to use
group(int): The state group to lookup
- types(list[str, str|None]): List of 2-tuples of the form
- (`type`, `state_key`), where a `state_key` of `None` matches all
- state_keys for the `type`.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns 2-tuple (`state_dict`, `got_all`).
`got_all` is a bool indicating if we successfully retrieved all
@@ -662,124 +881,102 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
is_all, known_absent, state_dict_ids = cache.get(group)
- type_to_key = {}
+ if is_all or state_filter.is_full():
+ # Either we have everything or want everything, either way
+ # `is_all` tells us whether we've gotten everything.
+ return state_filter.filter_state(state_dict_ids), is_all
# tracks whether any of our requested types are missing from the cache
missing_types = False
- for typ, state_key in types:
- key = (typ, state_key)
-
- if (
- state_key is None or
- (filtered_types is not None and typ not in filtered_types)
- ):
- type_to_key[typ] = None
- # we mark the type as missing from the cache because
- # when the cache was populated it might have been done with a
- # restricted set of state_keys, so the wildcard will not work
- # and the cache may be incomplete.
- missing_types = True
- else:
- if type_to_key.get(typ, object()) is not None:
- type_to_key.setdefault(typ, set()).add(state_key)
-
+ if state_filter.has_wildcards():
+ # We don't know if we fetched all the state keys for the types in
+ # the filter that are wildcards, so we have to assume that we may
+ # have missed some.
+ missing_types = True
+ else:
+ # There aren't any wild cards, so `concrete_types()` returns the
+ # complete list of event types we're wanting.
+ for key in state_filter.concrete_types():
if key not in state_dict_ids and key not in known_absent:
missing_types = True
+ break
- sentinel = object()
-
- def include(typ, state_key):
- valid_state_keys = type_to_key.get(typ, sentinel)
- if valid_state_keys is sentinel:
- return filtered_types is not None and typ not in filtered_types
- if valid_state_keys is None:
- return True
- if state_key in valid_state_keys:
- return True
- return False
-
- got_all = is_all
- if not got_all:
- # the cache is incomplete. We may still have got all the results we need, if
- # we don't have any wildcards in the match list.
- if not missing_types and filtered_types is None:
- got_all = True
-
- return {
- k: v for k, v in iteritems(state_dict_ids)
- if include(k[0], k[1])
- }, got_all
-
- def _get_all_state_from_cache(self, cache, group):
- """Checks if group is in cache. See `_get_state_for_groups`
-
- Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
- indicating if we successfully retrieved all requests state from the
- cache, if False we need to query the DB for the missing state.
-
- Args:
- cache(DictionaryCache): the state group cache to use
- group: The state group to lookup
- """
- is_all, _, state_dict_ids = cache.get(group)
-
- return state_dict_ids, is_all
+ return state_filter.filter_state(state_dict_ids), not missing_types
@defer.inlineCallbacks
- def _get_state_for_groups(self, groups, types=None, filtered_types=None):
+ def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
- types (None|iterable[(str, None|str)]):
- indicates the state type/keys required. If None, the whole
- state is fetched and returned.
-
- Otherwise, each entry should be a `(type, state_key)` tuple to
- include in the response. A `state_key` of None is a wildcard
- meaning that we require all state with that type.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
-
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
- if types is not None:
- non_member_types = [t for t in types if t[0] != EventTypes.Member]
-
- if filtered_types is not None and EventTypes.Member not in filtered_types:
- # we want all of the membership events
- member_types = None
- else:
- member_types = [t for t in types if t[0] == EventTypes.Member]
- else:
- non_member_types = None
- member_types = None
+ member_filter, non_member_filter = state_filter.get_member_split()
- non_member_state = yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache, non_member_types, filtered_types,
+ # Now we look them up in the member and non-member caches
+ non_member_state, incomplete_groups_nm = (
+ yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache,
+ state_filter=non_member_filter,
+ )
)
- # XXX: we could skip this entirely if member_types is []
- member_state = yield self._get_state_for_groups_using_cache(
- # we set filtered_types=None as member_state only ever contain members.
- groups, self._state_group_members_cache, member_types, None,
+
+ member_state, incomplete_groups_m = (
+ yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_members_cache,
+ state_filter=member_filter,
+ )
)
- state = non_member_state
+ state = dict(non_member_state)
for group in groups:
state[group].update(member_state[group])
+ # Now fetch any missing groups from the database
+
+ incomplete_groups = incomplete_groups_m | incomplete_groups_nm
+
+ if not incomplete_groups:
+ defer.returnValue(state)
+
+ cache_sequence_nm = self._state_group_cache.sequence
+ cache_sequence_m = self._state_group_members_cache.sequence
+
+ # Help the cache hit ratio by expanding the filter a bit
+ db_state_filter = state_filter.return_expanded()
+
+ group_to_state_dict = yield self._get_state_groups_from_groups(
+ list(incomplete_groups),
+ state_filter=db_state_filter,
+ )
+
+ # Now let's update the caches
+ self._insert_into_cache(
+ group_to_state_dict,
+ db_state_filter,
+ cache_seq_num_members=cache_sequence_m,
+ cache_seq_num_non_members=cache_sequence_nm,
+ )
+
+ # And finally update the result dict, by filtering out any extra
+ # stuff we pulled out of the database.
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ # We just replace any existing entries, as we will have loaded
+ # everything we need from the database anyway.
+ state[group] = state_filter.filter_state(group_state_dict)
+
defer.returnValue(state)
- @defer.inlineCallbacks
def _get_state_for_groups_using_cache(
- self, groups, cache, types=None, filtered_types=None
+ self, groups, cache, state_filter,
):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
@@ -790,89 +987,85 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cache (DictionaryCache): the cache of group ids to state dicts which
we will pass through - either the normal state cache or the specific
members state cache.
- types (None|iterable[(str, None|str)]):
- indicates the state type/keys required. If None, the whole
- state is fetched and returned.
-
- Otherwise, each entry should be a `(type, state_key)` tuple to
- include in the response. A `state_key` of None is a wildcard
- meaning that we require all state with that type.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
+ dict of state_group_id -> (dict of (type, state_key) -> event id)
+ of entries in the cache, and the state group ids either missing
+ from the cache or incomplete.
"""
- if types:
- types = frozenset(types)
results = {}
- missing_groups = []
- if types is not None:
- for group in set(groups):
- state_dict_ids, got_all = self._get_some_state_from_cache(
- cache, group, types, filtered_types
- )
- results[group] = state_dict_ids
+ incomplete_groups = set()
+ for group in set(groups):
+ state_dict_ids, got_all = self._get_state_for_group_using_cache(
+ cache, group, state_filter
+ )
+ results[group] = state_dict_ids
- if not got_all:
- missing_groups.append(group)
- else:
- for group in set(groups):
- state_dict_ids, got_all = self._get_all_state_from_cache(
- cache, group
- )
+ if not got_all:
+ incomplete_groups.add(group)
- results[group] = state_dict_ids
+ return results, incomplete_groups
- if not got_all:
- missing_groups.append(group)
+ def _insert_into_cache(self, group_to_state_dict, state_filter,
+ cache_seq_num_members, cache_seq_num_non_members):
+ """Inserts results from querying the database into the relevant cache.
- if missing_groups:
- # Okay, so we have some missing_types, let's fetch them.
- cache_seq_num = cache.sequence
+ Args:
+ group_to_state_dict (dict): The new entries pulled from database.
+ Map from state group to state dict
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ cache_seq_num_members (int): Sequence number of member cache since
+ last lookup in cache
+ cache_seq_num_non_members (int): Sequence number of non-member cache since
+ last lookup in cache
+ """
- # the DictionaryCache knows if it has *all* the state, but
- # does not know if it has all of the keys of a particular type,
- # which makes wildcard lookups expensive unless we have a complete
- # cache. Hence, if we are doing a wildcard lookup, populate the
- # cache fully so that we can do an efficient lookup next time.
+ # We need to work out which types we've fetched from the DB for the
+ # member vs non-member caches. This should be as accurate as possible,
+ # but can be an underestimate (e.g. when we have wild cards)
- if filtered_types or (types and any(k is None for (t, k) in types)):
- types_to_fetch = None
- else:
- types_to_fetch = types
+ member_filter, non_member_filter = state_filter.get_member_split()
+ if member_filter.is_full():
+ # We fetched all member events
+ member_types = None
+ else:
+ # `concrete_types()` will only return a subset when there are wild
+ # cards in the filter, but that's fine.
+ member_types = member_filter.concrete_types()
- group_to_state_dict = yield self._get_state_groups_from_groups(
- missing_groups, types_to_fetch, cache == self._state_group_members_cache,
- )
+ if non_member_filter.is_full():
+ # We fetched all non member events
+ non_member_types = None
+ else:
+ non_member_types = non_member_filter.concrete_types()
+
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ state_dict_members = {}
+ state_dict_non_members = {}
- for group, group_state_dict in iteritems(group_to_state_dict):
- state_dict = results[group]
-
- # update the result, filtering by `types`.
- if types:
- for k, v in iteritems(group_state_dict):
- (typ, _) = k
- if (
- (k in types or (typ, None) in types) or
- (filtered_types and typ not in filtered_types)
- ):
- state_dict[k] = v
+ for k, v in iteritems(group_state_dict):
+ if k[0] == EventTypes.Member:
+ state_dict_members[k] = v
else:
- state_dict.update(group_state_dict)
-
- # update the cache with all the things we fetched from the
- # database.
- cache.update(
- cache_seq_num,
- key=group,
- value=group_state_dict,
- fetched_keys=types_to_fetch,
- )
+ state_dict_non_members[k] = v
- defer.returnValue(results)
+ self._state_group_members_cache.update(
+ cache_seq_num_members,
+ key=group,
+ value=state_dict_members,
+ fetched_keys=member_types,
+ )
+
+ self._state_group_cache.update(
+ cache_seq_num_non_members,
+ key=group,
+ value=state_dict_non_members,
+ fetched_keys=non_member_types,
+ )
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
current_state_ids):
@@ -1268,12 +1461,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
continue
prev_state = self._get_state_groups_from_groups_txn(
- txn, [prev_group], types=None
+ txn, [prev_group],
)
prev_state = prev_state[prev_group]
curr_state = self._get_state_groups_from_groups_txn(
- txn, [state_group], types=None
+ txn, [state_group],
)
curr_state = curr_state[state_group]
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index a3032cdce9..d8bf953ec0 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -30,7 +30,7 @@ from ._base import SQLBaseStore, db_to_json
# py2 sqlite has buffer hardcoded as only binary type, so we must use it,
# despite being deprecated and removed in favor of memoryview
if six.PY2:
- db_binary_type = buffer
+ db_binary_type = six.moves.builtins.buffer
else:
db_binary_type = memoryview
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 9a8fae0497..0ae7e2ef3b 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+import re
from itertools import islice
import attr
@@ -138,3 +139,27 @@ def log_failure(failure, msg, consumeErrors=True):
if not consumeErrors:
return failure
+
+
+def glob_to_regex(glob):
+ """Converts a glob to a compiled regex object.
+
+ The regex is anchored at the beginning and end of the string.
+
+ Args:
+ glob (str)
+
+ Returns:
+ re.RegexObject
+ """
+ res = ''
+ for c in glob:
+ if c == '*':
+ res = res + '.*'
+ elif c == '?':
+ res = res + '.'
+ else:
+ res = res + re.escape(c)
+
+ # \A anchors at start of string, \Z at end of string
+ return re.compile(r"\A" + res + r"\Z", re.IGNORECASE)
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index f2bde74dc5..625aedc940 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -15,6 +15,8 @@
import logging
+from six import integer_types
+
from sortedcontainers import SortedDict
from synapse.util import caches
@@ -47,7 +49,7 @@ class StreamChangeCache(object):
def has_entity_changed(self, entity, stream_pos):
"""Returns True if the entity may have been updated since stream_pos
"""
- assert type(stream_pos) is int or type(stream_pos) is long
+ assert type(stream_pos) in integer_types
if stream_pos < self._earliest_known_stream_pos:
self.metrics.inc_misses()
diff --git a/synapse/util/logcontext.py b/synapse/util/logcontext.py
index 89224b26cc..4c6e92beb8 100644
--- a/synapse/util/logcontext.py
+++ b/synapse/util/logcontext.py
@@ -25,7 +25,7 @@ See doc/log_contexts.rst for details on how this works.
import logging
import threading
-from twisted.internet import defer
+from twisted.internet import defer, threads
logger = logging.getLogger(__name__)
@@ -562,58 +562,76 @@ def _set_context_cb(result, context):
return result
-# modules to ignore in `logcontext_tracer`
-_to_ignore = [
- "synapse.util.logcontext",
- "synapse.http.server",
- "synapse.storage._base",
- "synapse.util.async_helpers",
-]
+def defer_to_thread(reactor, f, *args, **kwargs):
+ """
+ Calls the function `f` using a thread from the reactor's default threadpool and
+ returns the result as a Deferred.
+
+ Creates a new logcontext for `f`, which is created as a child of the current
+ logcontext (so its CPU usage metrics will get attributed to the current
+ logcontext). `f` should preserve the logcontext it is given.
+
+ The result deferred follows the Synapse logcontext rules: you should `yield`
+ on it.
+
+ Args:
+ reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
+ the Deferred will be invoked, and whose threadpool we should use for the
+ function.
+
+ Normally this will be hs.get_reactor().
+
+ f (callable): The function to call.
+ args: positional arguments to pass to f.
-def logcontext_tracer(frame, event, arg):
- """A tracer that logs whenever a logcontext "unexpectedly" changes within
- a function. Probably inaccurate.
+ kwargs: keyword arguments to pass to f.
- Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
+ Returns:
+ Deferred: A Deferred which fires a callback with the result of `f`, or an
+ errback if `f` throws an exception.
"""
- if event == 'call':
- name = frame.f_globals["__name__"]
- if name.startswith("synapse"):
- if name == "synapse.util.logcontext":
- if frame.f_code.co_name in ["__enter__", "__exit__"]:
- tracer = frame.f_back.f_trace
- if tracer:
- tracer.just_changed = True
-
- tracer = frame.f_trace
- if tracer:
- return tracer
-
- if not any(name.startswith(ig) for ig in _to_ignore):
- return LineTracer()
-
-
-class LineTracer(object):
- __slots__ = ["context", "just_changed"]
-
- def __init__(self):
- self.context = LoggingContext.current_context()
- self.just_changed = False
-
- def __call__(self, frame, event, arg):
- if event in 'line':
- if self.just_changed:
- self.context = LoggingContext.current_context()
- self.just_changed = False
- else:
- c = LoggingContext.current_context()
- if c != self.context:
- logger.info(
- "Context changed! %s -> %s, %s, %s",
- self.context, c,
- frame.f_code.co_filename, frame.f_lineno
- )
- self.context = c
+ return defer_to_threadpool(reactor, reactor.getThreadPool(), f, *args, **kwargs)
- return self
+
+def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
+ """
+ A wrapper for twisted.internet.threads.deferToThreadpool, which handles
+ logcontexts correctly.
+
+ Calls the function `f` using a thread from the given threadpool and returns
+ the result as a Deferred.
+
+ Creates a new logcontext for `f`, which is created as a child of the current
+ logcontext (so its CPU usage metrics will get attributed to the current
+ logcontext). `f` should preserve the logcontext it is given.
+
+ The result deferred follows the Synapse logcontext rules: you should `yield`
+ on it.
+
+ Args:
+ reactor (twisted.internet.base.ReactorBase): The reactor in whose main thread
+ the Deferred will be invoked. Normally this will be hs.get_reactor().
+
+ threadpool (twisted.python.threadpool.ThreadPool): The threadpool to use for
+ running `f`. Normally this will be hs.get_reactor().getThreadPool().
+
+ f (callable): The function to call.
+
+ args: positional arguments to pass to f.
+
+ kwargs: keyword arguments to pass to f.
+
+ Returns:
+ Deferred: A Deferred which fires a callback with the result of `f`, or an
+ errback if `f` throws an exception.
+ """
+ logcontext = LoggingContext.current_context()
+
+ def g():
+ with LoggingContext(parent_context=logcontext):
+ return f(*args, **kwargs)
+
+ return make_deferred_yieldable(
+ threads.deferToThreadPool(reactor, threadpool, g)
+ )
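A minimal sketch of calling blocking code through the new helper (`hs` and `_write_file_sync` are illustrative, not part of this patch):

    from twisted.internet import defer

    from synapse.util.logcontext import defer_to_thread

    @defer.inlineCallbacks
    def store_file(hs, path, data):
        # Run the (hypothetical) blocking helper on the reactor's threadpool;
        # its CPU time stays attributed to the current logcontext, and the
        # returned Deferred follows the usual logcontext rules, so we yield it.
        yield defer_to_thread(hs.get_reactor(), _write_file_sync, path, data)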
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 43f48196be..0281a7c919 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event
+from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__)
@@ -72,7 +73,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
)
event_id_to_state = yield store.get_state_for_events(
frozenset(e.event_id for e in events),
- types=types,
+ state_filter=StateFilter.from_types(types),
)
ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
@@ -273,8 +274,8 @@ def filter_events_for_server(store, server_name, events):
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
+ state_filter=StateFilter.from_types(
+ types=((EventTypes.RoomHistoryVisibility, ""),),
)
)
@@ -314,9 +315,11 @@ def filter_events_for_server(store, server_name, events):
# of the history vis and membership state at those events.
event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, None),
+ state_filter=StateFilter.from_types(
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.Member, None),
+ ),
)
)