diff --git a/synapse/_scripts/__init__.py b/synapse/_scripts/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/synapse/_scripts/__init__.py
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
new file mode 100644
index 0000000000..70cecde486
--- /dev/null
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -0,0 +1,215 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import argparse
+import getpass
+import hashlib
+import hmac
+import logging
+import sys
+
+from six.moves import input
+
+import requests as _requests
+import yaml
+
+
+def request_registration(
+ user,
+ password,
+ server_location,
+ shared_secret,
+ admin=False,
+ requests=_requests,
+ _print=print,
+ exit=sys.exit,
+):
+
+ url = "%s/_matrix/client/r0/admin/register" % (server_location,)
+
+ # Get the nonce
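+    # (TLS certificate verification is disabled here, since homeservers
+    # frequently run with self-signed certificates)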
+ r = requests.get(url, verify=False)
+
+    if r.status_code != 200:
+ _print("ERROR! Received %d %s" % (r.status_code, r.reason))
+ if 400 <= r.status_code < 500:
+ try:
+ _print(r.json()["error"])
+ except Exception:
+ pass
+ return exit(1)
+
+ nonce = r.json()["nonce"]
+
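+    # The MAC is an HMAC-SHA1 keyed with the registration shared secret,
+    # computed over the nonce, username, password and admin flag joined by
+    # NUL bytes: this is the format the shared-secret registration endpoint
+    # expects.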
+ mac = hmac.new(key=shared_secret.encode('utf8'), digestmod=hashlib.sha1)
+
+ mac.update(nonce.encode('utf8'))
+ mac.update(b"\x00")
+ mac.update(user.encode('utf8'))
+ mac.update(b"\x00")
+ mac.update(password.encode('utf8'))
+ mac.update(b"\x00")
+ mac.update(b"admin" if admin else b"notadmin")
+
+ mac = mac.hexdigest()
+
+ data = {
+ "nonce": nonce,
+ "username": user,
+ "password": password,
+ "mac": mac,
+ "admin": admin,
+ }
+
+ _print("Sending registration request...")
+ r = requests.post(url, json=data, verify=False)
+
+    if r.status_code != 200:
+ _print("ERROR! Received %d %s" % (r.status_code, r.reason))
+ if 400 <= r.status_code < 500:
+ try:
+ _print(r.json()["error"])
+ except Exception:
+ pass
+ return exit(1)
+
+ _print("Success!")
+
+
+def register_new_user(user, password, server_location, shared_secret, admin):
+ if not user:
+ try:
+ default_user = getpass.getuser()
+ except Exception:
+ default_user = None
+
+ if default_user:
+ user = input("New user localpart [%s]: " % (default_user,))
+ if not user:
+ user = default_user
+ else:
+ user = input("New user localpart: ")
+
+ if not user:
+ print("Invalid user name")
+ sys.exit(1)
+
+ if not password:
+ password = getpass.getpass("Password: ")
+
+ if not password:
+ print("Password cannot be blank.")
+ sys.exit(1)
+
+ confirm_password = getpass.getpass("Confirm password: ")
+
+ if password != confirm_password:
+ print("Passwords do not match")
+ sys.exit(1)
+
+ if admin is None:
+ admin = input("Make admin [no]: ")
+ if admin in ("y", "yes", "true"):
+ admin = True
+ else:
+ admin = False
+
+ request_registration(user, password, server_location, shared_secret, bool(admin))
+
+
+def main():
+
+ logging.captureWarnings(True)
+
+ parser = argparse.ArgumentParser(
+ description="Used to register new users with a given home server when"
+ " registration has been disabled. The home server must be"
+ " configured with the 'registration_shared_secret' option"
+ " set."
+ )
+ parser.add_argument(
+ "-u",
+ "--user",
+ default=None,
+ help="Local part of the new user. Will prompt if omitted.",
+ )
+ parser.add_argument(
+ "-p",
+ "--password",
+ default=None,
+ help="New password for user. Will prompt if omitted.",
+ )
+ admin_group = parser.add_mutually_exclusive_group()
+ admin_group.add_argument(
+ "-a",
+ "--admin",
+ action="store_true",
+ help=(
+ "Register new user as an admin. "
+ "Will prompt if --no-admin is not set either."
+ ),
+ )
+ admin_group.add_argument(
+ "--no-admin",
+ action="store_true",
+ help=(
+ "Register new user as a regular user. "
+ "Will prompt if --admin is not set either."
+ ),
+ )
+
+ group = parser.add_mutually_exclusive_group(required=True)
+ group.add_argument(
+ "-c",
+ "--config",
+ type=argparse.FileType('r'),
+ help="Path to server config file. Used to read in shared secret.",
+ )
+
+ group.add_argument(
+ "-k", "--shared-secret", help="Shared secret as defined in server config file."
+ )
+
+ parser.add_argument(
+ "server_url",
+ default="https://localhost:8448",
+ nargs='?',
+ help="URL to use to talk to the home server. Defaults to "
+ " 'https://localhost:8448'.",
+ )
+
+ args = parser.parse_args()
+
+ if "config" in args and args.config:
+ config = yaml.safe_load(args.config)
+ secret = config.get("registration_shared_secret", None)
+ if not secret:
+ print("No 'registration_shared_secret' defined in config.")
+ sys.exit(1)
+ else:
+ secret = args.shared_secret
+
+ admin = None
+ if args.admin or args.no_admin:
+ admin = args.admin
+
+ register_new_user(args.user, args.password, args.server_url, secret, admin)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index b2815da0ab..f5928633d2 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -62,6 +62,7 @@ class LoginType(object):
class EventTypes(object):
Member = "m.room.member"
Create = "m.room.create"
+ Tombstone = "m.room.tombstone"
JoinRules = "m.room.join_rules"
PowerLevels = "m.room.power_levels"
Aliases = "m.room.aliases"
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index 6d9f1ca0ef..f78695b657 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -28,7 +28,6 @@ FEDERATION_PREFIX = "/_matrix/federation/v1"
STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
-SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/r0"
LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 0b85b377e3..415374a2ce 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -37,7 +37,6 @@ from synapse.api.urls import (
FEDERATION_PREFIX,
LEGACY_MEDIA_PREFIX,
MEDIA_PREFIX,
- SERVER_KEY_PREFIX,
SERVER_KEY_V2_PREFIX,
STATIC_PREFIX,
WEB_CLIENT_PREFIX,
@@ -59,7 +58,6 @@ from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, check_requirem
from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.rest import ClientRestResource
-from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.server import HomeServer
@@ -236,10 +234,7 @@ class SynapseHomeServer(HomeServer):
)
if name in ["keys", "federation"]:
- resources.update({
- SERVER_KEY_PREFIX: LocalKey(self),
- SERVER_KEY_V2_PREFIX: KeyApiV2Resource(self),
- })
+ resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
if name == "webclient":
resources[WEB_CLIENT_PREFIX] = build_resource_for_web_client(self)
@@ -553,14 +548,6 @@ def run(hs):
generate_monthly_active_users,
)
- # XXX is this really supposed to be a background process? it looks
- # like it needs to complete before some of the other stuff runs.
- run_as_background_process(
- "initialise_reserved_users",
- hs.get_datastore().initialise_reserved_users,
- hs.config.mau_limits_reserved_threepids,
- )
-
start_generate_monthly_active_users()
if hs.config.limit_usage_by_mau:
clock.looping_call(start_generate_monthly_active_users, 5 * 60 * 1000)
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index b8d5690f2b..10dd40159f 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -31,6 +31,7 @@ from .push import PushConfig
from .ratelimiting import RatelimitConfig
from .registration import RegistrationConfig
from .repository import ContentRepositoryConfig
+from .room_directory import RoomDirectoryConfig
from .saml2 import SAML2Config
from .server import ServerConfig
from .server_notices_config import ServerNoticesConfig
@@ -49,7 +50,7 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
SpamCheckerConfig, GroupsConfig, UserDirectoryConfig,
ConsentConfig,
- ServerNoticesConfig,
+ ServerNoticesConfig, RoomDirectoryConfig,
):
pass
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 0fb964eb67..7480ed5145 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -15,10 +15,10 @@
from distutils.util import strtobool
+from synapse.config._base import Config, ConfigError
+from synapse.types import RoomAlias
from synapse.util.stringutils import random_string_with_symbols
-from ._base import Config
-
class RegistrationConfig(Config):
@@ -44,6 +44,10 @@ class RegistrationConfig(Config):
)
self.auto_join_rooms = config.get("auto_join_rooms", [])
+ for room_alias in self.auto_join_rooms:
+ if not RoomAlias.is_valid(room_alias):
+ raise ConfigError('Invalid auto_join_rooms entry %s' % (room_alias,))
+ self.autocreate_auto_join_rooms = config.get("autocreate_auto_join_rooms", True)
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
@@ -98,6 +102,13 @@ class RegistrationConfig(Config):
# to these rooms
#auto_join_rooms:
# - "#example:example.com"
+
+        # Where auto_join_rooms is specified, setting this flag ensures that
+        # the rooms exist by creating them when the first user registers on
+        # the homeserver.
+        # Setting this to false means that if the rooms are not manually
+        # created, users cannot be auto-joined, since the rooms do not exist.
+ autocreate_auto_join_rooms: true
""" % locals()
def add_arguments(self, parser):
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
new file mode 100644
index 0000000000..9da13ab11b
--- /dev/null
+++ b/synapse/config/room_directory.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util import glob_to_regex
+
+from ._base import Config, ConfigError
+
+
+class RoomDirectoryConfig(Config):
+ def read_config(self, config):
+ alias_creation_rules = config["alias_creation_rules"]
+
+ self._alias_creation_rules = [
+ _AliasRule(rule)
+ for rule in alias_creation_rules
+ ]
+
+ def default_config(self, config_dir_path, server_name, **kwargs):
+ return """
+        # The `alias_creation_rules` option controls who is allowed to create aliases
+ # on this server.
+ #
+ # The format of this option is a list of rules that contain globs that
+ # match against user_id and the new alias (fully qualified with server
+ # name). The action in the first rule that matches is taken, which can
+ # currently either be "allow" or "deny".
+ #
+        # If no rules match, the request is denied.
+ alias_creation_rules:
+ - user_id: "*"
+ alias: "*"
+ action: allow
+ """
+
+ def is_alias_creation_allowed(self, user_id, alias):
+ """Checks if the given user is allowed to create the given alias
+
+ Args:
+ user_id (str)
+ alias (str)
+
+ Returns:
+            boolean: True if the user is allowed to create the alias
+ """
+ for rule in self._alias_creation_rules:
+ if rule.matches(user_id, alias):
+ return rule.action == "allow"
+
+ return False
+
+
+class _AliasRule(object):
+ def __init__(self, rule):
+ action = rule["action"]
+ user_id = rule["user_id"]
+ alias = rule["alias"]
+
+ if action in ("allow", "deny"):
+ self.action = action
+ else:
+ raise ConfigError(
+ "alias_creation_rules rules can only have action of 'allow'"
+ " or 'deny'"
+ )
+
+ try:
+ self._user_id_regex = glob_to_regex(user_id)
+ self._alias_regex = glob_to_regex(alias)
+ except Exception as e:
+ raise ConfigError("Failed to parse glob into regex: %s", e)
+
+ def matches(self, user_id, alias):
+ """Tests if this rule matches the given user_id and alias.
+
+ Args:
+ user_id (str)
+ alias (str)
+
+ Returns:
+ boolean
+ """
+
+ # Note: The regexes are anchored at both ends
+ if not self._user_id_regex.match(user_id):
+ return False
+
+ if not self._alias_regex.match(alias):
+ return False
+
+ return True
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 080c81f14b..d40e4b8591 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -15,6 +15,8 @@
import logging
+from six.moves import urllib
+
from canonicaljson import json
from twisted.internet import defer, reactor
@@ -28,15 +30,15 @@ from synapse.util import logcontext
logger = logging.getLogger(__name__)
-KEY_API_V1 = b"/_matrix/key/v1/"
+KEY_API_V2 = "/_matrix/key/v2/server/%s"
@defer.inlineCallbacks
-def fetch_server_key(server_name, tls_client_options_factory, path=KEY_API_V1):
+def fetch_server_key(server_name, tls_client_options_factory, key_id):
"""Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory()
- factory.path = path
+ factory.path = KEY_API_V2 % (urllib.parse.quote(key_id), )
factory.host = server_name
endpoint = matrix_federation_endpoint(
reactor, server_name, tls_client_options_factory, timeout=30
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d89f94c219..515ebbc148 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd.
+# Copyright 2017, 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,8 +18,6 @@ import hashlib
import logging
from collections import namedtuple
-from six.moves import urllib
-
from signedjson.key import (
decode_verify_key_bytes,
encode_verify_key_base64,
@@ -395,32 +393,13 @@ class Keyring(object):
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
- @defer.inlineCallbacks
- def get_key(server_name, key_ids):
- keys = None
- try:
- keys = yield self.get_server_verify_key_v2_direct(
- server_name, key_ids
- )
- except Exception as e:
- logger.info(
- "Unable to get key %r for %r directly: %s %s",
- key_ids, server_name,
- type(e).__name__, str(e),
- )
-
- if not keys:
- keys = yield self.get_server_verify_key_v1_direct(
- server_name, key_ids
- )
-
- keys = {server_name: keys}
-
- defer.returnValue(keys)
-
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
[
- run_in_background(get_key, server_name, key_ids)
+ run_in_background(
+ self.get_server_verify_key_v2_direct,
+ server_name,
+ key_ids,
+ )
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
@@ -525,10 +504,7 @@ class Keyring(object):
continue
(response, tls_certificate) = yield fetch_server_key(
- server_name, self.hs.tls_client_options_factory,
- path=("/_matrix/key/v2/server/%s" % (
- urllib.parse.quote(requested_key_id),
- )).encode("ascii"),
+ server_name, self.hs.tls_client_options_factory, requested_key_id
)
if (u"signatures" not in response
@@ -657,78 +633,6 @@ class Keyring(object):
defer.returnValue(results)
- @defer.inlineCallbacks
- def get_server_verify_key_v1_direct(self, server_name, key_ids):
- """Finds a verification key for the server with one of the key ids.
- Args:
- server_name (str): The name of the server to fetch a key for.
- keys_ids (list of str): The key_ids to check for.
- """
-
- # Try to fetch the key from the remote server.
-
- (response, tls_certificate) = yield fetch_server_key(
- server_name, self.hs.tls_client_options_factory
- )
-
- # Check the response.
-
- x509_certificate_bytes = crypto.dump_certificate(
- crypto.FILETYPE_ASN1, tls_certificate
- )
-
- if ("signatures" not in response
- or server_name not in response["signatures"]):
- raise KeyLookupError("Key response not signed by remote server")
-
- if "tls_certificate" not in response:
- raise KeyLookupError("Key response missing TLS certificate")
-
- tls_certificate_b64 = response["tls_certificate"]
-
- if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
- raise KeyLookupError("TLS certificate doesn't match")
-
- # Cache the result in the datastore.
-
- time_now_ms = self.clock.time_msec()
-
- verify_keys = {}
- for key_id, key_base64 in response["verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_key.time_added = time_now_ms
- verify_keys[key_id] = verify_key
-
- for key_id in response["signatures"][server_name]:
- if key_id not in response["verify_keys"]:
- raise KeyLookupError(
- "Key response must include verification keys for all"
- " signatures"
- )
- if key_id in verify_keys:
- verify_signed_json(
- response,
- server_name,
- verify_keys[key_id]
- )
-
- yield self.store.store_server_certificate(
- server_name,
- server_name,
- time_now_ms,
- tls_certificate,
- )
-
- yield self.store_keys(
- server_name=server_name,
- from_server=server_name,
- verify_keys=verify_keys,
- )
-
- defer.returnValue(verify_keys)
-
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index af0107a46e..fa2cc550e2 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import re
import six
from six import iteritems
@@ -44,6 +43,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.types import get_domain_from_id
+from synapse.util import glob_to_regex
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.logcontext import nested_logging_context
@@ -324,11 +324,6 @@ class FederationServer(FederationBase):
defer.returnValue((404, ""))
@defer.inlineCallbacks
- @log_function
- def on_pull_request(self, origin, versions):
- raise NotImplementedError("Pull transactions not implemented")
-
- @defer.inlineCallbacks
def on_query_request(self, query_type, args):
received_queries_counter.labels(query_type).inc()
resp = yield self.registry.on_query(query_type, args)
@@ -729,22 +724,10 @@ def _acl_entry_matches(server_name, acl_entry):
if not isinstance(acl_entry, six.string_types):
logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry))
return False
- regex = _glob_to_regex(acl_entry)
+ regex = glob_to_regex(acl_entry)
return regex.match(server_name)
-def _glob_to_regex(glob):
- res = ''
- for c in glob:
- if c == '*':
- res = res + '.*'
- elif c == '?':
- res = res + '.'
- else:
- res = res + re.escape(c)
- return re.compile(res + "\\Z", re.IGNORECASE)
-
-
class FederationHandlerRegistry(object):
"""Allows classes to register themselves as handlers for a given EDU or
query type for incoming federation traffic.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 7288d49074..3553f418f1 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -362,14 +362,6 @@ class FederationSendServlet(BaseFederationServlet):
defer.returnValue((code, response))
-class FederationPullServlet(BaseFederationServlet):
- PATH = "/pull/"
-
- # This is for when someone asks us for everything since version X
- def on_GET(self, origin, content, query):
- return self.handler.on_pull_request(query["origin"][0], query["v"])
-
-
class FederationEventServlet(BaseFederationServlet):
PATH = "/event/(?P<event_id>[^/]*)/"
@@ -1261,7 +1253,6 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet):
FEDERATION_SERVLET_CLASSES = (
FederationSendServlet,
- FederationPullServlet,
FederationEventServlet,
FederationStateServlet,
FederationStateIdsServlet,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 02f12f6645..0699731c13 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -43,6 +43,7 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.config = hs.config
self.federation = hs.get_federation_client()
hs.get_federation_registry().register_query_handler(
@@ -111,6 +112,14 @@ class DirectoryHandler(BaseHandler):
403, "This user is not permitted to create this alias",
)
+ if not self.config.is_alias_creation_allowed(user_id, room_alias.to_string()):
+            # Let's just return a generic message, as there may be all sorts of
+ # reasons why we said no. TODO: Allow configurable error messages
+ # per alias creation rule?
+ raise SynapseError(
+ 403, "Not allowed to create alias",
+ )
+
can_create = yield self.can_modify_alias(
room_alias,
user_id=user_id
@@ -129,9 +138,30 @@ class DirectoryHandler(BaseHandler):
)
@defer.inlineCallbacks
- def delete_association(self, requester, room_alias):
- # association deletion for human users
+ def delete_association(self, requester, room_alias, send_event=True):
+ """Remove an alias from the directory
+
+ (this is only meant for human users; AS users should call
+ delete_appservice_association)
+
+ Args:
+ requester (Requester):
+ room_alias (RoomAlias):
+ send_event (bool): Whether to send an updated m.room.aliases event.
+ Note that, if we delete the canonical alias, we will always attempt
+                to send an m.room.canonical_alias event.
+
+ Returns:
+ Deferred[unicode]: room id that the alias used to point to
+ Raises:
+ NotFoundError: if the alias doesn't exist
+
+ AuthError: if the user doesn't have perms to delete the alias (ie, the user
+                is neither the creator of the alias, nor a server admin).
+
+ SynapseError: if the alias belongs to an AS
+ """
user_id = requester.user.to_string()
try:
@@ -159,10 +189,11 @@ class DirectoryHandler(BaseHandler):
room_id = yield self._delete_association(room_alias)
try:
- yield self.send_room_alias_update_event(
- requester,
- room_id
- )
+ if send_event:
+ yield self.send_room_alias_update_event(
+ requester,
+ room_id
+ )
yield self._update_canonical_alias(
requester,
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index e009395207..563bb3cea3 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -156,7 +156,7 @@ class InitialSyncHandler(BaseHandler):
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = run_in_background(
self.store.get_state_for_events,
- [event.event_id], None,
+ [event.event_id],
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
@@ -301,7 +301,7 @@ class InitialSyncHandler(BaseHandler):
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
membership, member_event_id, is_peeking):
room_state = yield self.store.get_state_for_events(
- [member_event_id], None
+ [member_event_id],
)
room_state = room_state[member_event_id]
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 6c4fcfb10a..a7cd779b02 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -35,6 +35,7 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
@@ -80,7 +81,7 @@ class MessageHandler(object):
elif membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
- [membership_event_id], [key]
+ [membership_event_id], StateFilter.from_types([key])
)
data = room_state[membership_event_id].get(key)
@@ -88,7 +89,7 @@ class MessageHandler(object):
@defer.inlineCallbacks
def get_state_events(
- self, user_id, room_id, types=None, filtered_types=None,
+ self, user_id, room_id, state_filter=StateFilter.all(),
at_token=None, is_guest=False,
):
"""Retrieve all state events for a given room. If the user is
@@ -100,13 +101,8 @@ class MessageHandler(object):
Args:
user_id(str): The user requesting state events.
room_id(str): The room ID to get all state events from.
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
             at_token(StreamToken|None): the stream token at which we are requesting
                 the state. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current
@@ -139,7 +135,7 @@ class MessageHandler(object):
event = last_events[0]
if visible_events:
room_state = yield self.store.get_state_for_events(
- [event.event_id], types, filtered_types=filtered_types,
+ [event.event_id], state_filter=state_filter,
)
room_state = room_state[event.event_id]
else:
@@ -158,12 +154,12 @@ class MessageHandler(object):
if membership == Membership.JOIN:
state_ids = yield self.store.get_filtered_current_state_ids(
- room_id, types, filtered_types=filtered_types,
+ room_id, state_filter=state_filter,
)
room_state = yield self.store.get_events(state_ids.values())
elif membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
- [membership_event_id], types, filtered_types=filtered_types,
+ [membership_event_id], state_filter=state_filter,
)
room_state = room_state[membership_event_id]
@@ -431,6 +427,9 @@ class EventCreationHandler(object):
if event.is_state():
prev_state = yield self.deduplicate_state_event(event, context)
             if prev_state is not None:
+                logger.info(
+                    "Not bothering to persist duplicate state event %s",
+                    event.event_id,
+                )
                 defer.returnValue(prev_state)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index a155b6e938..43f81bd607 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -21,6 +21,7 @@ from twisted.python.failure import Failure
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.events.utils import serialize_event
+from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import ReadWriteLock
from synapse.util.logcontext import run_in_background
@@ -255,16 +256,14 @@ class PaginationHandler(object):
if event_filter and event_filter.lazy_load_members():
# TODO: remove redundant members
- types = [
- (EventTypes.Member, state_key)
- for state_key in set(
- event.sender # FIXME: we also care about invite targets etc.
- for event in events
- )
- ]
+ # FIXME: we also care about invite targets etc.
+ state_filter = StateFilter.from_types(
+ (EventTypes.Member, event.sender)
+ for event in events
+ )
state_ids = yield self.store.get_state_ids_for_event(
- events[0].event_id, types=types,
+ events[0].event_id, state_filter=state_filter,
)
if state_ids:
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index da914c46ff..d2beb275cf 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -220,15 +220,42 @@ class RegistrationHandler(BaseHandler):
# auto-join the user to any rooms we're supposed to dump them into
fake_requester = create_requester(user_id)
+
+ # try to create the room if we're the first user on the server
+ should_auto_create_rooms = False
+ if self.hs.config.autocreate_auto_join_rooms:
+ count = yield self.store.count_all_users()
+ should_auto_create_rooms = count == 1
+
for r in self.hs.config.auto_join_rooms:
try:
- yield self._join_user_to_room(fake_requester, r)
+ if should_auto_create_rooms:
+ room_alias = RoomAlias.from_string(r)
+ if self.hs.hostname != room_alias.domain:
+ logger.warning(
+ 'Cannot create room alias %s, '
+ 'it does not match server domain',
+ r,
+ )
+ else:
+ # create room expects the localpart of the room alias
+ room_alias_localpart = room_alias.localpart
+
+ # getting the RoomCreationHandler during init gives a dependency
+ # loop
+ yield self.hs.get_room_creation_handler().create_room(
+ fake_requester,
+ config={
+ "preset": "public_chat",
+ "room_alias_name": room_alias_localpart
+ },
+ ratelimit=False,
+ )
+ else:
+ yield self._join_user_to_room(fake_requester, r)
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- # We used to generate default identicons here, but nowadays
- # we want clients to generate their own as part of their branding
- # rather than there being consistent matrix-wide ones, so we don't.
defer.returnValue((user_id, token))
@defer.inlineCallbacks
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index ab1571b27b..3928faa6e7 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -21,7 +21,7 @@ import math
import string
from collections import OrderedDict
-from six import string_types
+from six import iteritems, string_types
from twisted.internet import defer
@@ -32,9 +32,11 @@ from synapse.api.constants import (
JoinRules,
RoomCreationPreset,
)
-from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
+from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
+from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
+from synapse.util.async_helpers import Linearizer
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -72,6 +74,334 @@ class RoomCreationHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
+
+ # linearizer to stop two upgrades happening at once
+ self._upgrade_linearizer = Linearizer("room_upgrade_linearizer")
+
+ @defer.inlineCallbacks
+ def upgrade_room(self, requester, old_room_id, new_version):
+ """Replace a room with a new room with a different version
+
+ Args:
+ requester (synapse.types.Requester): the user requesting the upgrade
+ old_room_id (unicode): the id of the room to be replaced
+ new_version (unicode): the new room version to use
+
+ Returns:
+ Deferred[unicode]: the new room id
+ """
+ yield self.ratelimit(requester)
+
+ user_id = requester.user.to_string()
+
+ with (yield self._upgrade_linearizer.queue(old_room_id)):
+ # start by allocating a new room id
+ r = yield self.store.get_room(old_room_id)
+ if r is None:
+ raise NotFoundError("Unknown room id %s" % (old_room_id,))
+ new_room_id = yield self._generate_room_id(
+ creator_id=user_id, is_public=r["is_public"],
+ )
+
+ logger.info("Creating new room %s to replace %s", new_room_id, old_room_id)
+
+ # we create and auth the tombstone event before properly creating the new
+ # room, to check our user has perms in the old room.
+ tombstone_event, tombstone_context = (
+ yield self.event_creation_handler.create_event(
+ requester, {
+ "type": EventTypes.Tombstone,
+ "state_key": "",
+ "room_id": old_room_id,
+ "sender": user_id,
+ "content": {
+ "body": "This room has been replaced",
+ "replacement_room": new_room_id,
+ }
+ },
+ token_id=requester.access_token_id,
+ )
+ )
+ yield self.auth.check_from_context(tombstone_event, tombstone_context)
+
+        yield self.clone_existing_room(
+ requester,
+ old_room_id=old_room_id,
+ new_room_id=new_room_id,
+ new_room_version=new_version,
+ tombstone_event_id=tombstone_event.event_id,
+ )
+
+ # now send the tombstone
+ yield self.event_creation_handler.send_nonmember_event(
+ requester, tombstone_event, tombstone_context,
+ )
+
+ old_room_state = yield tombstone_context.get_current_state_ids(self.store)
+
+ # update any aliases
+ yield self._move_aliases_to_new_room(
+ requester, old_room_id, new_room_id, old_room_state,
+ )
+
+ # and finally, shut down the PLs in the old room, and update them in the new
+ # room.
+ yield self._update_upgraded_room_pls(
+ requester, old_room_id, new_room_id, old_room_state,
+ )
+
+ defer.returnValue(new_room_id)
+
+ @defer.inlineCallbacks
+ def _update_upgraded_room_pls(
+ self, requester, old_room_id, new_room_id, old_room_state,
+ ):
+ """Send updated power levels in both rooms after an upgrade
+
+ Args:
+ requester (synapse.types.Requester): the user requesting the upgrade
+ old_room_id (unicode): the id of the room to be replaced
+ new_room_id (unicode): the id of the replacement room
+ old_room_state (dict[tuple[str, str], str]): the state map for the old room
+
+ Returns:
+ Deferred
+ """
+ old_room_pl_event_id = old_room_state.get((EventTypes.PowerLevels, ""))
+
+ if old_room_pl_event_id is None:
+ logger.warning(
+ "Not supported: upgrading a room with no PL event. Not setting PLs "
+ "in old room.",
+ )
+ return
+
+ old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)
+
+ # we try to stop regular users from speaking by setting the PL required
+ # to send regular events and invites to 'Moderator' level. That's normally
+ # 50, but if the default PL in a room is 50 or more, then we set the
+ # required PL above that.
+
+ pl_content = dict(old_room_pl_state.content)
+ users_default = int(pl_content.get("users_default", 0))
+ restricted_level = max(users_default + 1, 50)
+
+ updated = False
+ for v in ("invite", "events_default"):
+ current = int(pl_content.get(v, 0))
+ if current < restricted_level:
+ logger.info(
+ "Setting level for %s in %s to %i (was %i)",
+ v, old_room_id, restricted_level, current,
+ )
+ pl_content[v] = restricted_level
+ updated = True
+ else:
+ logger.info(
+ "Not setting level for %s (already %i)",
+ v, current,
+ )
+
+ if updated:
+ try:
+ yield self.event_creation_handler.create_and_send_nonmember_event(
+ requester, {
+ "type": EventTypes.PowerLevels,
+ "state_key": '',
+ "room_id": old_room_id,
+ "sender": requester.user.to_string(),
+ "content": pl_content,
+ }, ratelimit=False,
+ )
+ except AuthError as e:
+ logger.warning("Unable to update PLs in old room: %s", e)
+
+ logger.info("Setting correct PLs in new room")
+ yield self.event_creation_handler.create_and_send_nonmember_event(
+ requester, {
+ "type": EventTypes.PowerLevels,
+ "state_key": '',
+ "room_id": new_room_id,
+ "sender": requester.user.to_string(),
+ "content": old_room_pl_state.content,
+ }, ratelimit=False,
+ )
+
+ @defer.inlineCallbacks
+    def clone_existing_room(
+ self, requester, old_room_id, new_room_id, new_room_version,
+ tombstone_event_id,
+ ):
+ """Populate a new room based on an old room
+
+ Args:
+ requester (synapse.types.Requester): the user requesting the upgrade
+ old_room_id (unicode): the id of the room to be replaced
+ new_room_id (unicode): the id to give the new room (should already have been
+                created with _generate_room_id())
+ new_room_version (unicode): the new room version to use
+ tombstone_event_id (unicode|str): the ID of the tombstone event in the old
+ room.
+ Returns:
+ Deferred[None]
+ """
+ user_id = requester.user.to_string()
+
+ if not self.spam_checker.user_may_create_room(user_id):
+ raise SynapseError(403, "You are not permitted to create rooms")
+
+ creation_content = {
+ "room_version": new_room_version,
+ "predecessor": {
+ "room_id": old_room_id,
+ "event_id": tombstone_event_id,
+ }
+ }
+
+ initial_state = dict()
+
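+        # the state event types which are copied from the old room into the
+        # initial state of the new one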
+ types_to_copy = (
+ (EventTypes.JoinRules, ""),
+ (EventTypes.Name, ""),
+ (EventTypes.Topic, ""),
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.GuestAccess, ""),
+ (EventTypes.RoomAvatar, ""),
+ )
+
+ old_room_state_ids = yield self.store.get_filtered_current_state_ids(
+ old_room_id, StateFilter.from_types(types_to_copy),
+ )
+ # map from event_id to BaseEvent
+ old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
+
+ for k, old_event_id in iteritems(old_room_state_ids):
+ old_event = old_room_state_events.get(old_event_id)
+ if old_event:
+ initial_state[k] = old_event.content
+
+ yield self._send_events_for_new_room(
+ requester,
+ new_room_id,
+
+ # we expect to override all the presets with initial_state, so this is
+ # somewhat arbitrary.
+ preset_config=RoomCreationPreset.PRIVATE_CHAT,
+
+ invite_list=[],
+ initial_state=initial_state,
+ creation_content=creation_content,
+ )
+
+ # XXX invites/joins
+ # XXX 3pid invites
+
+ @defer.inlineCallbacks
+ def _move_aliases_to_new_room(
+ self, requester, old_room_id, new_room_id, old_room_state,
+ ):
+ directory_handler = self.hs.get_handlers().directory_handler
+
+ aliases = yield self.store.get_aliases_for_room(old_room_id)
+
+ # check to see if we have a canonical alias.
+ canonical_alias = None
+ canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
+ if canonical_alias_event_id:
+ canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
+ if canonical_alias_event:
+ canonical_alias = canonical_alias_event.content.get("alias", "")
+
+ # first we try to remove the aliases from the old room (we suppress sending
+ # the room_aliases event until the end).
+ #
+        # Note that we'll only be able to remove aliases that (a) aren't owned
+        # by an AS, and (b), unless the user is a server admin, were created by
+        # the user themselves.
+ #
+ # This is probably correct - given we don't allow such aliases to be deleted
+ # normally, it would be odd to allow it in the case of doing a room upgrade -
+ # but it makes the upgrade less effective, and you have to wonder why a room
+ # admin can't remove aliases that point to that room anyway.
+ # (cf https://github.com/matrix-org/synapse/issues/2360)
+ #
+ removed_aliases = []
+ for alias_str in aliases:
+ alias = RoomAlias.from_string(alias_str)
+ try:
+ yield directory_handler.delete_association(
+ requester, alias, send_event=False,
+ )
+ removed_aliases.append(alias_str)
+ except SynapseError as e:
+ logger.warning(
+ "Unable to remove alias %s from old room: %s",
+ alias, e,
+ )
+
+        # if we didn't find any aliases, or couldn't remove any, we can skip the rest
+ # of this.
+ if not removed_aliases:
+ return
+
+ try:
+ # this can fail if, for some reason, our user doesn't have perms to send
+ # m.room.aliases events in the old room (note that we've already checked that
+ # they have perms to send a tombstone event, so that's not terribly likely).
+ #
+ # If that happens, it's regrettable, but we should carry on: it's the same
+ # as when you remove an alias from the directory normally - it just means that
+ # the aliases event gets out of sync with the directory
+ # (cf https://github.com/vector-im/riot-web/issues/2369)
+ yield directory_handler.send_room_alias_update_event(
+ requester, old_room_id,
+ )
+ except AuthError as e:
+ logger.warning(
+ "Failed to send updated alias event on old room: %s", e,
+ )
+
+ # we can now add any aliases we successfully removed to the new room.
+ for alias in removed_aliases:
+ try:
+ yield directory_handler.create_association(
+ requester, RoomAlias.from_string(alias),
+ new_room_id, servers=(self.hs.hostname, ),
+ send_event=False,
+ )
+ logger.info("Moved alias %s to new room", alias)
+ except SynapseError as e:
+ # I'm not really expecting this to happen, but it could if the spam
+ # checking module decides it shouldn't, or similar.
+ logger.error(
+ "Error adding alias %s to new room: %s",
+ alias, e,
+ )
+
+ try:
+ if canonical_alias and (canonical_alias in removed_aliases):
+ yield self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.CanonicalAlias,
+ "state_key": "",
+ "room_id": new_room_id,
+ "sender": requester.user.to_string(),
+ "content": {"alias": canonical_alias, },
+ },
+ ratelimit=False
+ )
+
+ yield directory_handler.send_room_alias_update_event(
+ requester, new_room_id,
+ )
+ except SynapseError as e:
+ # again I'm not really expecting this to fail, but if it does, I'd rather
+ # we returned the new room to the client at this point.
+ logger.error(
+ "Unable to send updated alias events in new room: %s", e,
+ )
@defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True,
@@ -164,28 +494,7 @@ class RoomCreationHandler(BaseHandler):
visibility = config.get("visibility", None)
is_public = visibility == "public"
- # autogen room IDs and try to create it. We may clash, so just
- # try a few times till one goes through, giving up eventually.
- attempts = 0
- room_id = None
- while attempts < 5:
- try:
- random_string = stringutils.random_string(18)
- gen_room_id = RoomID(
- random_string,
- self.hs.hostname,
- )
- yield self.store.store_room(
- room_id=gen_room_id.to_string(),
- room_creator_user_id=user_id,
- is_public=is_public
- )
- room_id = gen_room_id.to_string()
- break
- except StoreError:
- attempts += 1
- if not room_id:
- raise StoreError(500, "Couldn't generate a room ID.")
+ room_id = yield self._generate_room_id(creator_id=user_id, is_public=is_public)
if room_alias:
directory_handler = self.hs.get_handlers().directory_handler
@@ -215,18 +524,15 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version
- room_member_handler = self.hs.get_room_member_handler()
-
yield self._send_events_for_new_room(
requester,
room_id,
- room_member_handler,
preset_config=preset_config,
invite_list=invite_list,
initial_state=initial_state,
creation_content=creation_content,
room_alias=room_alias,
- power_level_content_override=config.get("power_level_content_override", {}),
+ power_level_content_override=config.get("power_level_content_override"),
creator_join_profile=creator_join_profile,
)
@@ -262,7 +568,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
- yield room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
@@ -300,14 +606,13 @@ class RoomCreationHandler(BaseHandler):
self,
creator, # A Requester object.
room_id,
- room_member_handler,
preset_config,
invite_list,
initial_state,
creation_content,
- room_alias,
- power_level_content_override,
- creator_join_profile,
+ room_alias=None,
+ power_level_content_override=None,
+ creator_join_profile=None,
):
def create(etype, content, **kwargs):
e = {
@@ -323,6 +628,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
+ logger.info("Sending %s in new room", etype)
yield self.event_creation_handler.create_and_send_nonmember_event(
creator,
event,
@@ -345,7 +651,8 @@ class RoomCreationHandler(BaseHandler):
content=creation_content,
)
- yield room_member_handler.update_membership(
+ logger.info("Sending %s in new room", EventTypes.Member)
+ yield self.room_member_handler.update_membership(
creator,
creator.user,
room_id,
@@ -387,7 +694,8 @@ class RoomCreationHandler(BaseHandler):
for invitee in invite_list:
power_level_content["users"][invitee] = 100
- power_level_content.update(power_level_content_override)
+ if power_level_content_override:
+ power_level_content.update(power_level_content_override)
yield send(
etype=EventTypes.PowerLevels,
@@ -426,6 +734,30 @@ class RoomCreationHandler(BaseHandler):
content=content,
)
+ @defer.inlineCallbacks
+ def _generate_room_id(self, creator_id, is_public):
+ # autogen room IDs and try to create it. We may clash, so just
+ # try a few times till one goes through, giving up eventually.
+ attempts = 0
+ while attempts < 5:
+ try:
+ random_string = stringutils.random_string(18)
+ gen_room_id = RoomID(
+ random_string,
+ self.hs.hostname,
+ ).to_string()
+ if isinstance(gen_room_id, bytes):
+ gen_room_id = gen_room_id.decode('utf-8')
+ yield self.store.store_room(
+ room_id=gen_room_id,
+ room_creator_user_id=creator_id,
+ is_public=is_public,
+ )
+ defer.returnValue(gen_room_id)
+ except StoreError:
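+                # assume the room ID clashed with an existing room; try again
+                # with a fresh one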
+ attempts += 1
+ raise StoreError(500, "Couldn't generate a room ID.")
+
class RoomContextHandler(object):
def __init__(self, hs):
@@ -489,23 +821,24 @@ class RoomContextHandler(object):
else:
last_event_id = event_id
- types = None
- filtered_types = None
if event_filter and event_filter.lazy_load_members():
- members = set(ev.sender for ev in itertools.chain(
- results["events_before"],
- (results["event"],),
- results["events_after"],
- ))
- filtered_types = [EventTypes.Member]
- types = [(EventTypes.Member, member) for member in members]
+ state_filter = StateFilter.from_lazy_load_member_list(
+ ev.sender
+ for ev in itertools.chain(
+ results["events_before"],
+ (results["event"],),
+ results["events_after"],
+ )
+ )
+ else:
+ state_filter = StateFilter.all()
# XXX: why do we return the state as of the last event rather than the
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
state = yield self.store.get_state_for_events(
- [last_event_id], types, filtered_types=filtered_types,
+ [last_event_id], state_filter=state_filter,
)
results["state"] = list(state[last_event_id].values())
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 0c1d52fd11..80e7b15de8 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.events.utils import serialize_event
+from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -324,9 +325,12 @@ class SearchHandler(BaseHandler):
else:
last_event_id = event.event_id
+ state_filter = StateFilter.from_types(
+ [(EventTypes.Member, sender) for sender in senders]
+ )
+
state = yield self.store.get_state_for_event(
- last_event_id,
- types=[(EventTypes.Member, sender) for sender in senders]
+ last_event_id, state_filter
)
res["profile_info"] = {
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 351892a94f..09739f2862 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -27,6 +27,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
+from synapse.storage.state import StateFilter
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
@@ -469,25 +470,20 @@ class SyncHandler(object):
))
@defer.inlineCallbacks
- def get_state_after_event(self, event, types=None, filtered_types=None):
+ def get_state_after_event(self, event, state_filter=StateFilter.all()):
"""
Get the room state after the given event
Args:
event(synapse.events.EventBase): event of interest
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
A Deferred map from ((type, state_key)->Event)
"""
state_ids = yield self.store.get_state_ids_for_event(
- event.event_id, types, filtered_types=filtered_types,
+ event.event_id, state_filter=state_filter,
)
if event.is_state():
state_ids = state_ids.copy()
@@ -495,18 +491,14 @@ class SyncHandler(object):
defer.returnValue(state_ids)
@defer.inlineCallbacks
- def get_state_at(self, room_id, stream_position, types=None, filtered_types=None):
+ def get_state_at(self, room_id, stream_position, state_filter=StateFilter.all()):
""" Get the room state at a particular stream position
Args:
room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
A Deferred map from ((type, state_key)->Event)
@@ -522,7 +514,7 @@ class SyncHandler(object):
if last_events:
last_event = last_events[-1]
state = yield self.get_state_after_event(
- last_event, types, filtered_types=filtered_types,
+ last_event, state_filter=state_filter,
)
else:
@@ -563,10 +555,11 @@ class SyncHandler(object):
last_event = last_events[-1]
state_ids = yield self.store.get_state_ids_for_event(
- last_event.event_id, [
+ last_event.event_id,
+ state_filter=StateFilter.from_types([
(EventTypes.Name, ''),
(EventTypes.CanonicalAlias, ''),
- ]
+ ]),
)
# this is heavily cached, thus: fast.
@@ -717,8 +710,7 @@ class SyncHandler(object):
with Measure(self.clock, "compute_state_delta"):
- types = None
- filtered_types = None
+ members_to_fetch = None
lazy_load_members = sync_config.filter_collection.lazy_load_members()
include_redundant_members = (
@@ -729,16 +721,21 @@ class SyncHandler(object):
# We only request state for the members needed to display the
# timeline:
- types = [
- (EventTypes.Member, state_key)
- for state_key in set(
- event.sender # FIXME: we also care about invite targets etc.
- for event in batch.events
- )
- ]
+ members_to_fetch = set(
+ event.sender # FIXME: we also care about invite targets etc.
+ for event in batch.events
+ )
+
+ if full_state:
+ # always make sure we LL ourselves so we know we're in the room
+ # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
+                    # We only need to apply this on full state syncs given we disabled
+ # LL for incr syncs in #3840.
+ members_to_fetch.add(sync_config.user.to_string())
- # only apply the filtering to room members
- filtered_types = [EventTypes.Member]
+ state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
+ else:
+ state_filter = StateFilter.all()
timeline_state = {
(event.type, event.state_key): event.event_id
@@ -746,28 +743,19 @@ class SyncHandler(object):
}
if full_state:
- if lazy_load_members:
- # always make sure we LL ourselves so we know we're in the room
- # (if we are) to fix https://github.com/vector-im/riot-web/issues/7209
- # We only need apply this on full state syncs given we disabled
- # LL for incr syncs in #3840.
- types.append((EventTypes.Member, sync_config.user.to_string()))
-
if batch:
current_state_ids = yield self.store.get_state_ids_for_event(
- batch.events[-1].event_id, types=types,
- filtered_types=filtered_types,
+ batch.events[-1].event_id, state_filter=state_filter,
)
state_ids = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, types=types,
- filtered_types=filtered_types,
+ batch.events[0].event_id, state_filter=state_filter,
)
else:
current_state_ids = yield self.get_state_at(
- room_id, stream_position=now_token, types=types,
- filtered_types=filtered_types,
+ room_id, stream_position=now_token,
+ state_filter=state_filter,
)
state_ids = current_state_ids
@@ -781,8 +769,7 @@ class SyncHandler(object):
)
elif batch.limited:
state_at_timeline_start = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, types=types,
- filtered_types=filtered_types,
+ batch.events[0].event_id, state_filter=state_filter,
)
# for now, we disable LL for gappy syncs - see
@@ -797,17 +784,15 @@ class SyncHandler(object):
# members to just be ones which were timeline senders, which then ensures
# all of the rest get included in the state block (if we need to know
# about them).
- types = None
- filtered_types = None
+ state_filter = StateFilter.all()
state_at_previous_sync = yield self.get_state_at(
- room_id, stream_position=since_token, types=types,
- filtered_types=filtered_types,
+ room_id, stream_position=since_token,
+ state_filter=state_filter,
)
current_state_ids = yield self.store.get_state_ids_for_event(
- batch.events[-1].event_id, types=types,
- filtered_types=filtered_types,
+ batch.events[-1].event_id, state_filter=state_filter,
)
state_ids = _calculate_state(
@@ -821,7 +806,7 @@ class SyncHandler(object):
else:
state_ids = {}
if lazy_load_members:
- if types and batch.events:
+ if members_to_fetch and batch.events:
# We're returning an incremental sync, with no
# "gap" since the previous sync, so normally there would be
# no state to return.
@@ -831,8 +816,12 @@ class SyncHandler(object):
# timeline here, and then dedupe any redundant ones below.
state_ids = yield self.store.get_state_ids_for_event(
- batch.events[0].event_id, types=types,
- filtered_types=None, # we only want members!
+ batch.events[0].event_id,
+ # we only want members!
+ state_filter=StateFilter.from_types(
+ (EventTypes.Member, member)
+ for member in members_to_fetch
+ ),
)
if lazy_load_members and not include_redundant_members:
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index f369124258..50e1007d84 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -85,7 +85,10 @@ class EmailPusher(object):
self.timed_call = None
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
- self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
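+        # self.max_stream_ordering may still be None if no notifications have
+        # been processed yet, and max() against None fails on python 3.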
+ if self.max_stream_ordering:
+ self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+ else:
+ self.max_stream_ordering = max_stream_ordering
self._start_processing()
def on_new_receipts(self, min_stream_id, max_stream_id):
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 16fb5e8471..ebcb93bfc7 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -26,7 +26,6 @@ import bleach
import jinja2
from twisted.internet import defer
-from twisted.mail.smtp import sendmail
from synapse.api.constants import EventTypes
from synapse.api.errors import StoreError
@@ -85,6 +84,7 @@ class Mailer(object):
self.notif_template_html = notif_template_html
self.notif_template_text = notif_template_text
+ self.sendmail = self.hs.get_sendmail()
self.store = self.hs.get_datastore()
self.macaroon_gen = self.hs.get_macaroon_generator()
self.state_handler = self.hs.get_state_handler()
@@ -191,11 +191,11 @@ class Mailer(object):
multipart_msg.attach(html_part)
logger.info("Sending email push notification to %s" % email_address)
- # logger.debug(html_text)
- yield sendmail(
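+    # the message body is rendered to bytes, since sendmail expects bytes
+    # (at least on python 3)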
+ yield self.sendmail(
self.hs.config.email_smtp_host,
- raw_from, raw_to, multipart_msg.as_string(),
+ raw_from, raw_to, multipart_msg.as_string().encode('utf8'),
+ reactor=self.hs.get_reactor(),
port=self.hs.config.email_smtp_port,
requireAuthentication=self.hs.config.email_smtp_user is not None,
username=self.hs.config.email_smtp_user,
@@ -333,7 +333,7 @@ class Mailer(object):
notif_events, user_id, reason):
if len(notifs_by_room) == 1:
# Only one room has new stuff
- room_id = notifs_by_room.keys()[0]
+ room_id = list(notifs_by_room.keys())[0]
# If the room has some kind of name, use it, but we don't
# want the generated-from-names one here otherwise we'll
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 943876456b..ca62ee7637 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -51,7 +51,6 @@ REQUIREMENTS = {
"daemonize>=2.3.1": ["daemonize"],
"bcrypt>=3.1.0": ["bcrypt>=3.1.0"],
"pillow>=3.1.2": ["PIL"],
- "pydenticon>=0.2": ["pydenticon"],
"sortedcontainers>=1.4.4": ["sortedcontainers"],
"psutil>=2.0.0": ["psutil>=2.0.0"],
"pysaml2>=3.0.0": ["saml2"],
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index cbe9645817..586dddb40b 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -106,7 +106,7 @@ class ReplicationClientHandler(object):
Can be overriden in subclasses to handle more.
"""
- logger.info("Received rdata %s -> %s", stream_name, token)
+ logger.debug("Received rdata %s -> %s", stream_name, token)
return self.store.process_replication_rows(stream_name, token, rows)
def on_position(self, stream_name, token):
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 5dc7b3fffc..0b3fe6cbf5 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -656,7 +656,7 @@ tcp_inbound_commands = LaterGauge(
"",
["command", "name"],
lambda: {
- (k[0], p.name,): count
+ (k, p.name,): count
for p in connected_connections
for k, count in iteritems(p.inbound_commands_counter)
},
@@ -667,7 +667,7 @@ tcp_outbound_commands = LaterGauge(
"",
["command", "name"],
lambda: {
- (k[0], p.name,): count
+ (k, p.name,): count
for p in connected_connections
for k, count in iteritems(p.outbound_commands_counter)
},
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index 4856822a5d..5f35c2d1be 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -47,6 +47,7 @@ from synapse.rest.client.v2_alpha import (
register,
report_event,
room_keys,
+ room_upgrade_rest_servlet,
sendtodevice,
sync,
tags,
@@ -116,3 +117,4 @@ class ClientRestResource(JsonResource):
sendtodevice.register_servlets(hs, client_resource)
user_directory.register_servlets(hs, client_resource)
groups.register_servlets(hs, client_resource)
+ room_upgrade_rest_servlet.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 663934efd0..fcfe7857f6 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -33,6 +33,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
@@ -409,7 +410,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
room_id=room_id,
user_id=requester.user.to_string(),
at_token=at_token,
- types=[(EventTypes.Member, None)],
+ state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
)
chunk = []
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
new file mode 100644
index 0000000000..e6356101fd
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+
+from twisted.internet import defer
+
+from synapse.api.constants import KNOWN_ROOM_VERSIONS
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+
+from ._base import client_v2_patterns
+
+logger = logging.getLogger(__name__)
+
+
+class RoomUpgradeRestServlet(RestServlet):
+ """Handler for room uprade requests.
+
+ Handles requests of the form:
+
+ POST /_matrix/client/r0/rooms/$roomid/upgrade HTTP/1.1
+ Content-Type: application/json
+
+ {
+ "new_version": "2",
+ }
+
+ Creates a new room and shuts down the old one. Returns the ID of the new room.
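+
+    The response takes the form (the room ID shown is illustrative):
+
+        {
+            "replacement_room": "!newroom:example.com"
+        }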
+
+ Args:
+ hs (synapse.server.HomeServer):
+ """
+ PATTERNS = client_v2_patterns(
+ # /rooms/$roomid/upgrade
+ "/rooms/(?P<room_id>[^/]*)/upgrade$",
+ v2_alpha=False,
+ )
+
+ def __init__(self, hs):
+ super(RoomUpgradeRestServlet, self).__init__()
+ self._hs = hs
+ self._room_creation_handler = hs.get_room_creation_handler()
+ self._auth = hs.get_auth()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, room_id):
+ requester = yield self._auth.get_user_by_req(request)
+
+ content = parse_json_object_from_request(request)
+ assert_params_in_dict(content, ("new_version", ))
+ new_version = content["new_version"]
+
+ if new_version not in KNOWN_ROOM_VERSIONS:
+ raise SynapseError(
+ 400,
+ "Your homeserver does not support this room version",
+ Codes.UNSUPPORTED_ROOM_VERSION,
+ )
+
+ new_room_id = yield self._room_creation_handler.upgrade_room(
+ requester, room_id, new_version
+ )
+
+ ret = {
+ "replacement_room": new_room_id,
+ }
+
+ defer.returnValue((200, ret))
+
+
+def register_servlets(hs, http_server):
+ RoomUpgradeRestServlet(hs).register(http_server)
diff --git a/synapse/rest/key/v1/__init__.py b/synapse/rest/key/v1/__init__.py
deleted file mode 100644
index fe0ac3f8e9..0000000000
--- a/synapse/rest/key/v1/__init__.py
+++ /dev/null
@@ -1,14 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/synapse/rest/key/v1/server_key_resource.py b/synapse/rest/key/v1/server_key_resource.py
deleted file mode 100644
index 38eb2ee23f..0000000000
--- a/synapse/rest/key/v1/server_key_resource.py
+++ /dev/null
@@ -1,92 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import logging
-
-from canonicaljson import encode_canonical_json
-from signedjson.sign import sign_json
-from unpaddedbase64 import encode_base64
-
-from OpenSSL import crypto
-from twisted.web.resource import Resource
-
-from synapse.http.server import respond_with_json_bytes
-
-logger = logging.getLogger(__name__)
-
-
-class LocalKey(Resource):
- """HTTP resource containing encoding the TLS X.509 certificate and NACL
- signature verification keys for this server::
-
- GET /key HTTP/1.1
-
- HTTP/1.1 200 OK
- Content-Type: application/json
- {
- "server_name": "this.server.example.com"
- "verify_keys": {
- "algorithm:version": # base64 encoded NACL verification key.
- },
- "tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
- "signatures": {
- "this.server.example.com": {
- "algorithm:version": # NACL signature for this server.
- }
- }
- }
- """
-
- def __init__(self, hs):
- self.response_body = encode_canonical_json(
- self.response_json_object(hs.config)
- )
- Resource.__init__(self)
-
- @staticmethod
- def response_json_object(server_config):
- verify_keys = {}
- for key in server_config.signing_key:
- verify_key_bytes = key.verify_key.encode()
- key_id = "%s:%s" % (key.alg, key.version)
- verify_keys[key_id] = encode_base64(verify_key_bytes)
-
- x509_certificate_bytes = crypto.dump_certificate(
- crypto.FILETYPE_ASN1,
- server_config.tls_certificate
- )
- json_object = {
- u"server_name": server_config.server_name,
- u"verify_keys": verify_keys,
- u"tls_certificate": encode_base64(x509_certificate_bytes)
- }
- for key in server_config.signing_key:
- json_object = sign_json(
- json_object,
- server_config.server_name,
- key,
- )
-
- return json_object
-
- def render_GET(self, request):
- return respond_with_json_bytes(
- request, 200, self.response_body,
- )
-
- def getChild(self, name, request):
- if name == b'':
- return self
diff --git a/synapse/rest/media/v1/identicon_resource.py b/synapse/rest/media/v1/identicon_resource.py
deleted file mode 100644
index bdbd8d50dd..0000000000
--- a/synapse/rest/media/v1/identicon_resource.py
+++ /dev/null
@@ -1,68 +0,0 @@
-# Copyright 2015, 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from pydenticon import Generator
-
-from twisted.web.resource import Resource
-
-from synapse.http.servlet import parse_integer
-
-FOREGROUND = [
- "rgb(45,79,255)",
- "rgb(254,180,44)",
- "rgb(226,121,234)",
- "rgb(30,179,253)",
- "rgb(232,77,65)",
- "rgb(49,203,115)",
- "rgb(141,69,170)"
-]
-
-BACKGROUND = "rgb(224,224,224)"
-SIZE = 5
-
-
-class IdenticonResource(Resource):
- isLeaf = True
-
- def __init__(self):
- Resource.__init__(self)
- self.generator = Generator(
- SIZE, SIZE, foreground=FOREGROUND, background=BACKGROUND,
- )
-
- def generate_identicon(self, name, width, height):
- v_padding = width % SIZE
- h_padding = height % SIZE
- top_padding = v_padding // 2
- left_padding = h_padding // 2
- bottom_padding = v_padding - top_padding
- right_padding = h_padding - left_padding
- width -= v_padding
- height -= h_padding
- padding = (top_padding, bottom_padding, left_padding, right_padding)
- identicon = self.generator.generate(
- name, width, height, padding=padding
- )
- return identicon
-
- def render_GET(self, request):
- name = "/".join(request.postpath)
- width = parse_integer(request, "width", default=96)
- height = parse_integer(request, "height", default=96)
- identicon_bytes = self.generate_identicon(name, width, height)
- request.setHeader(b"Content-Type", b"image/png")
- request.setHeader(
- b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
- )
- return identicon_bytes
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 08b1867fab..d6c5f07af0 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -45,7 +45,6 @@ from ._base import FileInfo, respond_404, respond_with_responder
from .config_resource import MediaConfigResource
from .download_resource import DownloadResource
from .filepath import MediaFilePaths
-from .identicon_resource import IdenticonResource
from .media_storage import MediaStorage
from .preview_url_resource import PreviewUrlResource
from .storage_provider import StorageProviderWrapper
@@ -769,7 +768,6 @@ class MediaRepositoryResource(Resource):
self.putChild(b"thumbnail", ThumbnailResource(
hs, media_repo, media_repo.media_storage,
))
- self.putChild(b"identicon", IdenticonResource())
if hs.config.url_preview_enabled:
self.putChild(b"preview_url", PreviewUrlResource(
hs, media_repo, media_repo.media_storage,
diff --git a/synapse/server.py b/synapse/server.py
index 3e9d3d8256..9985687b95 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -23,6 +23,7 @@ import abc
import logging
from twisted.enterprise import adbapi
+from twisted.mail.smtp import sendmail
from twisted.web.client import BrowserLikePolicyForHTTPS
from synapse.api.auth import Auth
@@ -174,6 +175,7 @@ class HomeServer(object):
'message_handler',
'pagination_handler',
'room_context_handler',
+ 'sendmail',
]
# This is overridden in derived application classes
@@ -207,6 +209,7 @@ class HomeServer(object):
logger.info("Setting up.")
with self.get_db_conn() as conn:
self.datastore = self.DATASTORE_CLASS(conn, self)
+ conn.commit()
logger.info("Finished setting up.")
def get_reactor(self):
@@ -268,6 +271,9 @@ class HomeServer(object):
def build_room_creation_handler(self):
return RoomCreationHandler(self)
+ def build_sendmail(self):
+ return sendmail
+
def build_state_handler(self):
return StateHandler(self)
diff --git a/synapse/server.pyi b/synapse/server.pyi
index ce28486233..06cd083a74 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -7,6 +7,9 @@ import synapse.handlers.auth
import synapse.handlers.deactivate_account
import synapse.handlers.device
import synapse.handlers.e2e_keys
+import synapse.handlers.room
+import synapse.handlers.room_member
+import synapse.handlers.message
import synapse.handlers.set_password
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
@@ -50,6 +53,9 @@ class HomeServer(object):
def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
pass
+ def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
+ pass
+
def get_event_creation_handler(self) -> synapse.handlers.message.EventCreationHandler:
pass
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index d10ff9e4b9..62497ab63f 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -589,10 +589,14 @@ class DeviceStore(SQLBaseStore):
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
+ # We do a group by here as there can be a large number of duplicate
+ # entries, since we throw away device IDs.
sql = """
- SELECT stream_id, user_id, destination FROM device_lists_stream
+ SELECT MAX(stream_id) AS stream_id, user_id, destination
+ FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE ? < stream_id AND stream_id <= ?
+ GROUP BY user_id, destination
"""
return self._execute(
"get_all_device_list_changes_for_remotes", None,
diff --git a/synapse/storage/directory.py b/synapse/storage/directory.py
index cfb687cb53..61a029a53c 100644
--- a/synapse/storage/directory.py
+++ b/synapse/storage/directory.py
@@ -90,7 +90,7 @@ class DirectoryWorkerStore(SQLBaseStore):
class DirectoryStore(DirectoryWorkerStore):
@defer.inlineCallbacks
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
- """ Creates an associatin between a room alias and room_id/servers
+ """ Creates an association between a room alias and room_id/servers
Args:
room_alias (RoomAlias)
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index c780f55277..919e855f3b 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -38,6 +38,7 @@ from synapse.state import StateResolutionStore
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
+from synapse.storage.state import StateGroupWorkerStore
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.util import batch_iter
from synapse.util.async_helpers import ObservableDeferred
@@ -205,7 +206,8 @@ def _retry_on_integrity_error(func):
# inherits from EventFederationStore so that we can call _update_backward_extremities
# and _handle_mult_prev_events (though arguably those could both be moved in here)
-class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore):
+class EventsStore(StateGroupWorkerStore, EventFederationStore, EventsWorkerStore,
+ BackgroundUpdateStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
@@ -2034,62 +2036,44 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
logger.info("[purge] finding redundant state groups")
- # Get all state groups that are only referenced by events that are
- # to be deleted.
- # This works by first getting state groups that we may want to delete,
- # joining against event_to_state_groups to get events that use that
- # state group, then left joining against events_to_purge again. Any
- # state group where the left join produce *no nulls* are referenced
- # only by events that are going to be purged.
+ # Get all state groups that are referenced by events that are to be
+ # deleted. We then go and check if they are referenced by other events
+ # or state groups, and if not we delete them.
txn.execute("""
- SELECT state_group FROM
- (
- SELECT DISTINCT state_group FROM events_to_purge
- INNER JOIN event_to_state_groups USING (event_id)
- ) AS sp
- INNER JOIN event_to_state_groups USING (state_group)
- LEFT JOIN events_to_purge AS ep USING (event_id)
- GROUP BY state_group
- HAVING SUM(CASE WHEN ep.event_id IS NULL THEN 1 ELSE 0 END) = 0
+ SELECT DISTINCT state_group FROM events_to_purge
+ INNER JOIN event_to_state_groups USING (event_id)
""")
- state_rows = txn.fetchall()
- logger.info("[purge] found %i redundant state groups", len(state_rows))
-
- # make a set of the redundant state groups, so that we can look them up
- # efficiently
- state_groups_to_delete = set([sg for sg, in state_rows])
-
- # Now we get all the state groups that rely on these state groups
- logger.info("[purge] finding state groups which depend on redundant"
- " state groups")
- remaining_state_groups = []
- for i in range(0, len(state_rows), 100):
- chunk = [sg for sg, in state_rows[i:i + 100]]
- # look for state groups whose prev_state_group is one we are about
- # to delete
- rows = self._simple_select_many_txn(
- txn,
- table="state_group_edges",
- column="prev_state_group",
- iterable=chunk,
- retcols=["state_group"],
- keyvalues={},
- )
- remaining_state_groups.extend(
- row["state_group"] for row in rows
+ referenced_state_groups = set(sg for sg, in txn)
+ logger.info(
+ "[purge] found %i referenced state groups",
+ len(referenced_state_groups),
+ )
- # exclude state groups we are about to delete: no point in
- # updating them
- if row["state_group"] not in state_groups_to_delete
+ logger.info("[purge] finding state groups that can be deleted")
+
+ state_groups_to_delete, remaining_state_groups = (
+ self._find_unreferenced_groups_during_purge(
+ txn, referenced_state_groups,
)
+ )
+
+ logger.info(
+ "[purge] found %i state groups to delete",
+ len(state_groups_to_delete),
+ )
+
+ logger.info(
+ "[purge] de-delta-ing %i remaining state groups",
+ len(remaining_state_groups),
+ )
# Now we turn the state groups that reference to-be-deleted state
# groups to non delta versions.
for sg in remaining_state_groups:
logger.info("[purge] de-delta-ing remaining state group %s", sg)
curr_state = self._get_state_groups_from_groups_txn(
- txn, [sg], types=None
+ txn, [sg],
)
curr_state = curr_state[sg]
@@ -2127,11 +2111,11 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
logger.info("[purge] removing redundant state groups")
txn.executemany(
"DELETE FROM state_groups_state WHERE state_group = ?",
- state_rows
+ ((sg,) for sg in state_groups_to_delete),
)
txn.executemany(
"DELETE FROM state_groups WHERE id = ?",
- state_rows
+ ((sg,) for sg in state_groups_to_delete),
)
logger.info("[purge] removing events from event_to_state_groups")
@@ -2227,6 +2211,85 @@ class EventsStore(EventFederationStore, EventsWorkerStore, BackgroundUpdateStore
logger.info("[purge] done")
+ def _find_unreferenced_groups_during_purge(self, txn, state_groups):
+ """Used when purging history to figure out which state groups can be
+        deleted and which need to be de-delta'ed (due to one of their prev
+        groups being scheduled for deletion).
+
+ Args:
+ txn
+ state_groups (set[int]): Set of state groups referenced by events
+ that are going to be deleted.
+
+ Returns:
+ tuple[set[int], set[int]]: The set of state groups that can be
+ deleted and the set of state groups that need to be de-delta'ed
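+
+        For example (an illustrative scenario): if group B is a delta over its
+        prev group A, every event referencing A is being purged, but some
+        surviving event still references B, then A can be deleted while B
+        must first be de-delta'ed against it.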
+ """
+ # Graph of state group -> previous group
+ graph = {}
+
+        # Set of state groups that we have found to be referenced by events
+ referenced_groups = set()
+
+ # Set of state groups we've already seen
+ state_groups_seen = set(state_groups)
+
+ # Set of state groups to handle next.
+ next_to_search = set(state_groups)
+ while next_to_search:
+            # We bound the number of groups we look up at once, to stop the
+            # SQL query getting too big
+ if len(next_to_search) < 100:
+ current_search = next_to_search
+ next_to_search = set()
+ else:
+ current_search = set(itertools.islice(next_to_search, 100))
+ next_to_search -= current_search
+
+ # Check if state groups are referenced
+ sql = """
+ SELECT DISTINCT state_group FROM event_to_state_groups
+ LEFT JOIN events_to_purge AS ep USING (event_id)
+ WHERE state_group IN (%s) AND ep.event_id IS NULL
+ """ % (",".join("?" for _ in current_search),)
+ txn.execute(sql, list(current_search))
+
+ referenced = set(sg for sg, in txn)
+ referenced_groups |= referenced
+
+ # We don't continue iterating up the state group graphs for state
+ # groups that are referenced.
+ current_search -= referenced
+
+ rows = self._simple_select_many_txn(
+ txn,
+ table="state_group_edges",
+ column="prev_state_group",
+ iterable=current_search,
+ keyvalues={},
+ retcols=("prev_state_group", "state_group",),
+ )
+
+ prevs = set(row["state_group"] for row in rows)
+ # We don't bother re-handling groups we've already seen
+ prevs -= state_groups_seen
+ next_to_search |= prevs
+ state_groups_seen |= prevs
+
+ for row in rows:
+ # Note: Each state group can have at most one prev group
+ graph[row["state_group"]] = row["prev_state_group"]
+
+ to_delete = state_groups_seen - referenced_groups
+
+ to_dedelta = set()
+ for sg in referenced_groups:
+ prev_sg = graph.get(sg)
+ if prev_sg and prev_sg in to_delete:
+ to_dedelta.add(sg)
+
+ return to_delete, to_dedelta
+
@defer.inlineCallbacks
def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
diff --git a/synapse/storage/monthly_active_users.py b/synapse/storage/monthly_active_users.py
index 0fe8c8e24c..cf4104dc2e 100644
--- a/synapse/storage/monthly_active_users.py
+++ b/synapse/storage/monthly_active_users.py
@@ -33,19 +33,29 @@ class MonthlyActiveUsersStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
self.reserved_users = ()
+ # Do not add more reserved users than the total allowable number
+ self._initialise_reserved_users(
+ dbconn.cursor(),
+ hs.config.mau_limits_reserved_threepids[:self.hs.config.max_mau_value],
+ )
- @defer.inlineCallbacks
- def initialise_reserved_users(self, threepids):
- store = self.hs.get_datastore()
+ def _initialise_reserved_users(self, txn, threepids):
+ """Ensures that reserved threepids are accounted for in the MAU table, should
+ be called on start up.
+
+ Args:
+ txn (cursor):
+ threepids (list[dict]): List of threepid dicts to reserve
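+                e.g. (an illustrative entry):
+                [{"medium": "email", "address": "reserved@example.com"}]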
+ """
reserved_user_list = []
- # Do not add more reserved users than the total allowable number
- for tp in threepids[:self.hs.config.max_mau_value]:
- user_id = yield store.get_user_id_by_threepid(
+ for tp in threepids:
+ user_id = self.get_user_id_by_threepid_txn(
+ txn,
tp["medium"], tp["address"]
)
if user_id:
- yield self.upsert_monthly_active_user(user_id)
+ self.upsert_monthly_active_user_txn(txn, user_id)
reserved_user_list.append(user_id)
else:
logger.warning(
@@ -55,8 +65,7 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks
def reap_monthly_active_users(self):
- """
- Cleans out monthly active user table to ensure that no stale
+ """Cleans out monthly active user table to ensure that no stale
entries exist.
Returns:
@@ -165,19 +174,44 @@ class MonthlyActiveUsersStore(SQLBaseStore):
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):
+ """Updates or inserts the user into the monthly active user table, which
+ is used to track the current MAU usage of the server
+
+ Args:
+ user_id (str): user to add/update
"""
- Updates or inserts monthly active user member
- Arguments:
- user_id (str): user to add/update
- Deferred[bool]: True if a new entry was created, False if an
- existing one was updated.
+ is_insert = yield self.runInteraction(
+ "upsert_monthly_active_user", self.upsert_monthly_active_user_txn,
+ user_id
+ )
+
+ if is_insert:
+ self.user_last_seen_monthly_active.invalidate((user_id,))
+ self.get_monthly_active_count.invalidate(())
+
+ def upsert_monthly_active_user_txn(self, txn, user_id):
+ """Updates or inserts monthly active user member
+
+ Note that, after calling this method, it will generally be necessary
+ to invalidate the caches on user_last_seen_monthly_active and
+ get_monthly_active_count. We can't do that here, because we are running
+ in a database thread rather than the main thread, and we can't call
+ txn.call_after because txn may not be a LoggingTransaction.
+
+ Args:
+ txn (cursor):
+ user_id (str): user to add/update
+
+ Returns:
+ bool: True if a new entry was created, False if an
+ existing one was updated.
"""
# Am consciously deciding to lock the table on the basis that is ought
# never be a big table and alternative approaches (batching multiple
# upserts into a single txn) introduced a lot of extra complexity.
# See https://github.com/matrix-org/synapse/issues/3854 for more
- is_insert = yield self._simple_upsert(
- desc="upsert_monthly_active_user",
+ is_insert = self._simple_upsert_txn(
+ txn,
table="monthly_active_users",
keyvalues={
"user_id": user_id,
@@ -186,9 +220,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
"timestamp": int(self._clock.time_msec()),
},
)
- if is_insert:
- self.user_last_seen_monthly_active.invalidate((user_id,))
- self.get_monthly_active_count.invalidate(())
+
+ return is_insert
@cached(num_args=1)
def user_last_seen_monthly_active(self, user_id):
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index b364719312..bd740e1e45 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 51
+SCHEMA_VERSION = 52
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 2dd14aba1c..80d76bf9d7 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -474,17 +474,44 @@ class RegistrationStore(RegistrationWorkerStore,
@defer.inlineCallbacks
def get_user_id_by_threepid(self, medium, address):
- ret = yield self._simple_select_one(
+ """Returns user id from threepid
+
+ Args:
+ medium (str): threepid medium e.g. email
+ address (str): threepid address e.g. me@example.com
+
+ Returns:
+ Deferred[str|None]: user id or None if no user id/threepid mapping exists
+ """
+ user_id = yield self.runInteraction(
+ "get_user_id_by_threepid", self.get_user_id_by_threepid_txn,
+ medium, address
+ )
+ defer.returnValue(user_id)
+
+ def get_user_id_by_threepid_txn(self, txn, medium, address):
+ """Returns user id from threepid
+
+ Args:
+ txn (cursor):
+ medium (str): threepid medium e.g. email
+ address (str): threepid address e.g. me@example.com
+
+ Returns:
+ str|None: user id or None if no user id/threepid mapping exists
+ """
+ ret = self._simple_select_one_txn(
+ txn,
"user_threepids",
{
"medium": medium,
"address": address
},
- ['user_id'], True, 'get_user_id_by_threepid'
+ ['user_id'], True
)
if ret:
- defer.returnValue(ret['user_id'])
- defer.returnValue(None)
+ return ret['user_id']
+ return None
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 61013b8919..41c65e112a 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -47,7 +47,7 @@ class RoomWorkerStore(SQLBaseStore):
Args:
room_id (str): The ID of the room to retrieve.
Returns:
- A namedtuple containing the room information, or an empty list.
+ A dict containing the room information, or None if the room is unknown.
"""
return self._simple_select_one(
table="rooms",
diff --git a/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql b/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql
new file mode 100644
index 0000000000..91e03d13e1
--- /dev/null
+++ b/synapse/storage/schema/delta/52/add_event_to_state_group_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This is needed to efficiently check for unreferenced state groups during
+-- purge. Adds an index on event_to_state_groups(state_group).
+INSERT INTO background_updates (update_name, progress_json)
+    VALUES ('event_to_state_groups_sg_index', '{}');
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 3f4cbd61c4..d737bd6778 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -19,6 +19,8 @@ from collections import namedtuple
from six import iteritems, itervalues
from six.moves import range
+import attr
+
from twisted.internet import defer
from synapse.api.constants import EventTypes
@@ -48,6 +50,318 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
+@attr.s(slots=True)
+class StateFilter(object):
+ """A filter used when querying for state.
+
+ Attributes:
+ types (dict[str, set[str]|None]): Map from type to set of state keys (or
+ None). This specifies which state_keys for the given type to fetch
+ from the DB. If None then all events with that type are fetched. If
+ the set is empty then no events with that type are fetched.
+ include_others (bool): Whether to fetch events with types that do not
+ appear in `types`.
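+
+    For example (an illustrative filter; the types and keys are arbitrary):
+
+        # Fetch the room name event, the member event for @user:test, and
+        # all state whose type is not listed in `types`:
+        StateFilter(
+            types={"m.room.name": {""}, "m.room.member": {"@user:test"}},
+            include_others=True,
+        )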
+ """
+
+ types = attr.ib()
+ include_others = attr.ib(default=False)
+
+ def __attrs_post_init__(self):
+ # If `include_others` is set we canonicalise the filter by removing
+ # wildcards from the types dictionary
+ if self.include_others:
+ self.types = {
+ k: v for k, v in iteritems(self.types)
+ if v is not None
+ }
+
+ @staticmethod
+ def all():
+ """Creates a filter that fetches everything.
+
+ Returns:
+ StateFilter
+ """
+ return StateFilter(types={}, include_others=True)
+
+ @staticmethod
+ def none():
+ """Creates a filter that fetches nothing.
+
+ Returns:
+ StateFilter
+ """
+ return StateFilter(types={}, include_others=False)
+
+ @staticmethod
+ def from_types(types):
+ """Creates a filter that only fetches the given types
+
+ Args:
+ types (Iterable[tuple[str, str|None]]): A list of type and state
+ keys to fetch. A state_key of None fetches everything for
+ that type
+
+ Returns:
+ StateFilter
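+
+        For example (illustrative): a `None` state_key acts as a wildcard
+        and subsumes any specific state_keys for the same type:
+
+            f = StateFilter.from_types([
+                ("m.room.member", "@user:test"),
+                ("m.room.member", None),
+            ])
+            f.types == {"m.room.member": None}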
+ """
+ type_dict = {}
+ for typ, s in types:
+ if typ in type_dict:
+ if type_dict[typ] is None:
+ continue
+
+ if s is None:
+ type_dict[typ] = None
+ continue
+
+ type_dict.setdefault(typ, set()).add(s)
+
+ return StateFilter(types=type_dict)
+
+ @staticmethod
+ def from_lazy_load_member_list(members):
+ """Creates a filter that returns all non-member events, plus the member
+ events for the given users
+
+ Args:
+ members (iterable[str]): Set of user IDs
+
+ Returns:
+ StateFilter
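+
+        For example (illustrative):
+
+            f = StateFilter.from_lazy_load_member_list(["@user:test"])
+            f.types == {"m.room.member": {"@user:test"}}
+            f.include_others == True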
+ """
+ return StateFilter(
+ types={EventTypes.Member: set(members)},
+ include_others=True,
+ )
+
+ def return_expanded(self):
+ """Creates a new StateFilter where type wild cards have been removed
+ (except for memberships). The returned filter is a superset of the
+ current one, i.e. anything that passes the current filter will pass
+ the returned filter.
+
+ This helps the caching as the DictionaryCache knows if it has *all* the
+ state, but does not know if it has all of the keys of a particular type,
+ which makes wildcard lookups expensive unless we have a complete cache.
+ Hence, if we are doing a wildcard lookup, populate the cache fully so
+ that we can do an efficient lookup next time.
+
+ Note that since we have two caches, one for membership events and one for
+ other events, we can be a bit more clever than simply returning
+ `StateFilter.all()` if `has_wildcards()` is True.
+
+ We return a StateFilter where:
+ 1. the list of membership events to return is the same
+ 2. if there is a wildcard that matches non-member events we
+ return all non-member events
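+
+        For example (illustrative): the `m.room.name` wildcard below forces
+        all non-member state to be fetched, while the member list is kept
+        as-is:
+
+            f = StateFilter(
+                types={"m.room.member": {"@user:test"}, "m.room.name": None},
+                include_others=False,
+            )
+            f.return_expanded() == StateFilter(
+                types={"m.room.member": {"@user:test"}},
+                include_others=True,
+            )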
+
+ Returns:
+ StateFilter
+ """
+
+ if self.is_full():
+ # If we're going to return everything then there's nothing to do
+ return self
+
+ if not self.has_wildcards():
+ # If there are no wild cards, there's nothing to do
+ return self
+
+ if EventTypes.Member in self.types:
+ get_all_members = self.types[EventTypes.Member] is None
+ else:
+ get_all_members = self.include_others
+
+ has_non_member_wildcard = self.include_others or any(
+ state_keys is None
+ for t, state_keys in iteritems(self.types)
+ if t != EventTypes.Member
+ )
+
+ if not has_non_member_wildcard:
+ # If there are no non-member wild cards we can just return ourselves
+ return self
+
+ if get_all_members:
+ # We want to return everything.
+ return StateFilter.all()
+ else:
+ # We want to return all non-members, but only particular
+ # memberships
+ return StateFilter(
+ types={EventTypes.Member: self.types[EventTypes.Member]},
+ include_others=True,
+ )
+
+ def make_sql_filter_clause(self):
+ """Converts the filter to an SQL clause.
+
+ For example:
+
+ f = StateFilter.from_types([("m.room.create", "")])
+ clause, args = f.make_sql_filter_clause()
+ clause == "(type = ? AND state_key = ?)"
+ args == ['m.room.create', '']
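+
+            # With include_others=True (an illustrative sketch), types not in
+            # the dict are matched by an extra NOT IN clause:
+            f = StateFilter(types={"m.room.create": {""}}, include_others=True)
+            clause, args = f.make_sql_filter_clause()
+            clause == "(type = ? AND state_key = ?) OR type NOT IN (?)"
+            args == ['m.room.create', '', 'm.room.create']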
+
+ Returns:
+ tuple[str, list]: The SQL string (may be empty) and arguments. An
+ empty SQL string is returned when the filter matches everything
+ (i.e. is "full").
+ """
+
+ where_clause = ""
+ where_args = []
+
+ if self.is_full():
+ return where_clause, where_args
+
+ if not self.include_others and not self.types:
+ # i.e. this is an empty filter, so we need to return a clause that
+ # will match nothing
+ return "1 = 2", []
+
+        # First we build up a list of clauses for each type/state_key combo
+ clauses = []
+ for etype, state_keys in iteritems(self.types):
+ if state_keys is None:
+ clauses.append("(type = ?)")
+ where_args.append(etype)
+ continue
+
+ for state_key in state_keys:
+ clauses.append("(type = ? AND state_key = ?)")
+ where_args.extend((etype, state_key))
+
+ # This will match anything that appears in `self.types`
+ where_clause = " OR ".join(clauses)
+
+ # If we want to include stuff that's not in the types dict then we add
+        # an `OR type NOT IN (...)` clause to the end.
+ if self.include_others:
+ if where_clause:
+ where_clause += " OR "
+
+ where_clause += "type NOT IN (%s)" % (
+ ",".join(["?"] * len(self.types)),
+ )
+ where_args.extend(self.types)
+
+ return where_clause, where_args
+
+ def max_entries_returned(self):
+ """Returns the maximum number of entries this filter will return if
+ known, otherwise returns None.
+
+ For example a simple state filter asking for `("m.room.create", "")`
+ will return 1, whereas the default state filter will return None.
+
+ This is used to bail out early if the right number of entries have been
+ fetched.
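+
+        For example (illustrative):
+
+            f = StateFilter.from_types([("m.room.create", "")])
+            f.max_entries_returned() == 1
+            StateFilter.all().max_entries_returned() is None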
+ """
+ if self.has_wildcards():
+ return None
+
+ return len(self.concrete_types())
+
+ def filter_state(self, state_dict):
+ """Returns the state filtered with by this StateFilter
+
+ Args:
+ state (dict[tuple[str, str], Any]): The state map to filter
+
+ Returns:
+ dict[tuple[str, str], Any]: The filtered state map
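+
+        For example (illustrative):
+
+            f = StateFilter.from_types([("m.room.name", "")])
+            f.filter_state({
+                ("m.room.name", ""): "$name_event",
+                ("m.room.member", "@user:test"): "$member_event",
+            }) == {("m.room.name", ""): "$name_event"}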
+ """
+ if self.is_full():
+ return dict(state_dict)
+
+ filtered_state = {}
+ for k, v in iteritems(state_dict):
+ typ, state_key = k
+ if typ in self.types:
+ state_keys = self.types[typ]
+ if state_keys is None or state_key in state_keys:
+ filtered_state[k] = v
+ elif self.include_others:
+ filtered_state[k] = v
+
+ return filtered_state
+
+ def is_full(self):
+ """Whether this filter fetches everything or not
+
+ Returns:
+ bool
+ """
+ return self.include_others and not self.types
+
+ def has_wildcards(self):
+ """Whether the filter includes wildcards or is attempting to fetch
+ specific state.
+
+ Returns:
+ bool
+ """
+
+ return (
+ self.include_others
+ or any(
+ state_keys is None
+ for state_keys in itervalues(self.types)
+ )
+ )
+
+ def concrete_types(self):
+ """Returns a list of concrete type/state_keys (i.e. not None) that
+ will be fetched. This will be a complete list if `has_wildcards`
+ returns False, but otherwise will be a subset (or even empty).
+
+ Returns:
+ list[tuple[str,str]]
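+
+        For example (illustrative):
+
+            StateFilter.from_types([
+                ("m.room.name", ""),
+                ("m.room.member", None),
+            ]).concrete_types() == [("m.room.name", "")]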
+ """
+ return [
+ (t, s)
+ for t, state_keys in iteritems(self.types)
+ if state_keys is not None
+ for s in state_keys
+ ]
+
+ def get_member_split(self):
+ """Return the filter split into two: one which assumes it's exclusively
+ matching against member state, and one which assumes it's matching
+        against non-member state.
+
+ This is useful due to the returned filters giving correct results for
+ `is_full()`, `has_wildcards()`, etc, when operating against maps that
+ either exclusively contain member events or only contain non-member
+ events. (Which is the case when dealing with the member vs non-member
+ state caches).
+
+ Returns:
+            tuple[StateFilter, StateFilter]: The member and non-member filters
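+
+        For example (illustrative):
+
+            f = StateFilter(
+                types={"m.room.member": {"@user:test"}, "m.room.name": {""}},
+                include_others=False,
+            )
+            member_filter, non_member_filter = f.get_member_split()
+            member_filter == StateFilter({"m.room.member": {"@user:test"}})
+            non_member_filter == StateFilter(types={"m.room.name": {""}})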
+ """
+
+ if EventTypes.Member in self.types:
+ state_keys = self.types[EventTypes.Member]
+ if state_keys is None:
+ member_filter = StateFilter.all()
+ else:
+ member_filter = StateFilter({EventTypes.Member: state_keys})
+ elif self.include_others:
+ member_filter = StateFilter.all()
+ else:
+ member_filter = StateFilter.none()
+
+ non_member_filter = StateFilter(
+ types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member},
+ include_others=self.include_others,
+ )
+
+ return member_filter, non_member_filter
+
+
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers.
@@ -152,61 +466,41 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
# FIXME: how should this be cached?
- def get_filtered_current_state_ids(self, room_id, types, filtered_types=None):
+ def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
+
Args:
room_id (str)
- types (list[(Str, (Str|None))]): List of (type, state_key) tuples
- which are used to filter the state fetched. `state_key` may be
- None, which matches any `state_key`
- filtered_types (list[Str]|None): List of types to apply the above filter to.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+
Returns:
- deferred: dict of (type, state_key) -> event
+ Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
+ event ID.
"""
- include_other_types = False if filtered_types is None else True
-
def _get_filtered_current_state_ids_txn(txn):
results = {}
- sql = """SELECT type, state_key, event_id FROM current_state_events
- WHERE room_id = ? %s"""
- # Turns out that postgres doesn't like doing a list of OR's and
- # is about 1000x slower, so we just issue a query for each specific
- # type seperately.
- if types:
- clause_to_args = [
- (
- "AND type = ? AND state_key = ?",
- (etype, state_key)
- ) if state_key is not None else (
- "AND type = ?",
- (etype,)
- )
- for etype, state_key in types
- ]
-
- if include_other_types:
- unique_types = set(filtered_types)
- clause_to_args.append(
- (
- "AND type <> ? " * len(unique_types),
- list(unique_types)
- )
- )
- else:
- # If types is None we fetch all the state, and so just use an
- # empty where clause with no extra args.
- clause_to_args = [("", [])]
- for where_clause, where_args in clause_to_args:
- args = [room_id]
- args.extend(where_args)
- txn.execute(sql % (where_clause,), args)
- for row in txn:
- typ, state_key, event_id = row
- key = (intern_string(typ), intern_string(state_key))
- results[key] = event_id
+ sql = """
+ SELECT type, state_key, event_id FROM current_state_events
+ WHERE room_id = ?
+ """
+
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ if where_clause:
+ sql += " AND (%s)" % (where_clause,)
+
+ args = [room_id]
+ args.extend(where_args)
+ txn.execute(sql, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (intern_string(typ), intern_string(state_key))
+ results[key] = event_id
+
return results
return self.runInteraction(
@@ -322,20 +616,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
})
@defer.inlineCallbacks
- def _get_state_groups_from_groups(self, groups, types, members=None):
+ def _get_state_groups_from_groups(self, groups, state_filter):
"""Returns the state groups for a given set of groups, filtering on
types of state events.
Args:
groups(list[int]): list of state group IDs to query
- types (Iterable[str, str|None]|None): list of 2-tuples of the form
- (`type`, `state_key`), where a `state_key` of `None` matches all
- state_keys for the `type`. If None, all types are returned.
- members (bool|None): If not None, then, in addition to any filtering
- implied by types, the results are also filtered to only include
- member events (if True), or to exclude member events (if False)
-
- Returns:
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -346,19 +634,23 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
for chunk in chunks:
res = yield self.runInteraction(
"_get_state_groups_from_groups",
- self._get_state_groups_from_groups_txn, chunk, types, members,
+ self._get_state_groups_from_groups_txn, chunk, state_filter,
)
results.update(res)
defer.returnValue(results)
def _get_state_groups_from_groups_txn(
- self, txn, groups, types=None, members=None,
+ self, txn, groups, state_filter=StateFilter.all(),
):
results = {group: {} for group in groups}
- if types is not None:
- types = list(set(types)) # deduplicate types list
+ where_clause, where_args = state_filter.make_sql_filter_clause()
+
+ # Unless the filter clause is empty, we're going to append it after an
+ # existing where clause
+ if where_clause:
+ where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
@@ -374,79 +666,33 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
- sql = ("""
+ sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
- SELECT type, state_key, last_value(event_id) OVER (
+ SELECT DISTINCT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
- %s
- """)
-
- if members is True:
- sql += " AND type = '%s'" % (EventTypes.Member,)
- elif members is False:
- sql += " AND type <> '%s'" % (EventTypes.Member,)
-
- # Turns out that postgres doesn't like doing a list of OR's and
- # is about 1000x slower, so we just issue a query for each specific
- # type seperately.
- if types is not None:
- clause_to_args = [
- (
- "AND type = ? AND state_key = ?",
- (etype, state_key)
- ) if state_key is not None else (
- "AND type = ?",
- (etype,)
- )
- for etype, state_key in types
- ]
- else:
- # If types is None we fetch all the state, and so just use an
- # empty where clause with no extra args.
- clause_to_args = [("", [])]
+ """
- for where_clause, where_args in clause_to_args:
- for group in groups:
- args = [group]
- args.extend(where_args)
+ for group in groups:
+ args = [group]
+ args.extend(where_args)
- txn.execute(sql % (where_clause,), args)
- for row in txn:
- typ, state_key, event_id = row
- key = (typ, state_key)
- results[group][key] = event_id
+ txn.execute(sql + where_clause, args)
+ for row in txn:
+ typ, state_key, event_id = row
+ key = (typ, state_key)
+ results[group][key] = event_id
else:
- where_args = []
- where_clauses = []
- wildcard_types = False
- if types is not None:
- for typ in types:
- if typ[1] is None:
- where_clauses.append("(type = ?)")
- where_args.append(typ[0])
- wildcard_types = True
- else:
- where_clauses.append("(type = ? AND state_key = ?)")
- where_args.extend([typ[0], typ[1]])
-
- where_clause = "AND (%s)" % (" OR ".join(where_clauses))
- else:
- where_clause = ""
-
- if members is True:
- where_clause += " AND type = '%s'" % EventTypes.Member
- elif members is False:
- where_clause += " AND type <> '%s'" % EventTypes.Member
+ max_entries_returned = state_filter.max_entries_returned()
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
@@ -460,12 +706,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
- if types:
- args.extend(where_args)
+ args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
- " WHERE state_group = ? %s" % (where_clause,),
+ " WHERE state_group = ? " + where_clause,
args
)
results[group].update(
@@ -481,9 +726,8 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
- types is not None and
- not wildcard_types and
- len(results[group]) == len(types)
+ max_entries_returned is not None and
+ len(results[group]) == max_entries_returned
):
break
@@ -498,20 +742,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
@defer.inlineCallbacks
- def get_state_for_events(self, event_ids, types, filtered_types=None):
+ def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state
- dicts for each event. The state dicts will only have the type/state_keys
- that are in the `types` list.
+ dicts for each event.
Args:
event_ids (list[string])
- types (list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
@@ -521,7 +759,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
+ group_to_state = yield self._get_state_for_groups(groups, state_filter)
state_event_map = yield self.get_events(
[ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
@@ -540,20 +778,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
- def get_state_ids_for_events(self, event_ids, types=None, filtered_types=None):
+ def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
@@ -563,7 +796,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
groups = set(itervalues(event_to_groups))
- group_to_state = yield self._get_state_for_groups(groups, types, filtered_types)
+ group_to_state = yield self._get_state_for_groups(groups, state_filter)
event_to_state = {
event_id: group_to_state[group]
@@ -573,45 +806,35 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue({event: event_to_state[event] for event in event_ids})
@defer.inlineCallbacks
- def get_state_for_event(self, event_id, types=None, filtered_types=None):
+ def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
- state_map = yield self.get_state_for_events([event_id], types, filtered_types)
+ state_map = yield self.get_state_for_events([event_id], state_filter)
defer.returnValue(state_map[event_id])
@defer.inlineCallbacks
- def get_state_ids_for_event(self, event_id, types=None, filtered_types=None):
+ def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
- types(list[(str, str|None)]|None): List of (type, state_key) tuples
- which are used to filter the state fetched. If `state_key` is None,
- all events are returned of the given type.
- May be None, which matches any key.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
- state_map = yield self.get_state_ids_for_events([event_id], types, filtered_types)
+ state_map = yield self.get_state_ids_for_events([event_id], state_filter)
defer.returnValue(state_map[event_id])
@cached(max_entries=50000)
@@ -642,18 +865,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
defer.returnValue({row["event_id"]: row["state_group"] for row in rows})
- def _get_some_state_from_cache(self, cache, group, types, filtered_types=None):
+ def _get_state_for_group_using_cache(self, cache, group, state_filter):
"""Checks if group is in cache. See `_get_state_for_groups`
Args:
cache(DictionaryCache): the state group cache to use
group(int): The state group to lookup
- types(list[str, str|None]): List of 2-tuples of the form
- (`type`, `state_key`), where a `state_key` of `None` matches all
- state_keys for the `type`.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns 2-tuple (`state_dict`, `got_all`).
`got_all` is a bool indicating if we successfully retrieved all
@@ -662,124 +881,102 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
is_all, known_absent, state_dict_ids = cache.get(group)
- type_to_key = {}
+ if is_all or state_filter.is_full():
+            # Either the cache has everything, or the filter wants
+            # everything; either way the filtered cache contents are all we
+            # can return, and `is_all` tells us whether they are complete.
+ return state_filter.filter_state(state_dict_ids), is_all
# tracks whether any of our requested types are missing from the cache
missing_types = False
- for typ, state_key in types:
- key = (typ, state_key)
-
- if (
- state_key is None or
- (filtered_types is not None and typ not in filtered_types)
- ):
- type_to_key[typ] = None
- # we mark the type as missing from the cache because
- # when the cache was populated it might have been done with a
- # restricted set of state_keys, so the wildcard will not work
- # and the cache may be incomplete.
- missing_types = True
- else:
- if type_to_key.get(typ, object()) is not None:
- type_to_key.setdefault(typ, set()).add(state_key)
-
+ if state_filter.has_wildcards():
+ # We don't know if we fetched all the state keys for the types in
+ # the filter that are wildcards, so we have to assume that we may
+ # have missed some.
+ missing_types = True
+ else:
+ # There aren't any wild cards, so `concrete_types()` returns the
+ # complete list of event types we're wanting.
+ for key in state_filter.concrete_types():
if key not in state_dict_ids and key not in known_absent:
missing_types = True
+ break
- sentinel = object()
-
- def include(typ, state_key):
- valid_state_keys = type_to_key.get(typ, sentinel)
- if valid_state_keys is sentinel:
- return filtered_types is not None and typ not in filtered_types
- if valid_state_keys is None:
- return True
- if state_key in valid_state_keys:
- return True
- return False
-
- got_all = is_all
- if not got_all:
- # the cache is incomplete. We may still have got all the results we need, if
- # we don't have any wildcards in the match list.
- if not missing_types and filtered_types is None:
- got_all = True
-
- return {
- k: v for k, v in iteritems(state_dict_ids)
- if include(k[0], k[1])
- }, got_all
-
- def _get_all_state_from_cache(self, cache, group):
- """Checks if group is in cache. See `_get_state_for_groups`
-
- Returns 2-tuple (`state_dict`, `got_all`). `got_all` is a bool
- indicating if we successfully retrieved all requests state from the
- cache, if False we need to query the DB for the missing state.
-
- Args:
- cache(DictionaryCache): the state group cache to use
- group: The state group to lookup
- """
- is_all, _, state_dict_ids = cache.get(group)
-
- return state_dict_ids, is_all
+ return state_filter.filter_state(state_dict_ids), not missing_types
@defer.inlineCallbacks
- def _get_state_for_groups(self, groups, types=None, filtered_types=None):
+ def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
- types (None|iterable[(str, None|str)]):
- indicates the state type/keys required. If None, the whole
- state is fetched and returned.
-
- Otherwise, each entry should be a `(type, state_key)` tuple to
- include in the response. A `state_key` of None is a wildcard
- meaning that we require all state with that type.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
-
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
Deferred[dict[int, dict[tuple[str, str], str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
- if types is not None:
- non_member_types = [t for t in types if t[0] != EventTypes.Member]
-
- if filtered_types is not None and EventTypes.Member not in filtered_types:
- # we want all of the membership events
- member_types = None
- else:
- member_types = [t for t in types if t[0] == EventTypes.Member]
- else:
- non_member_types = None
- member_types = None
+ member_filter, non_member_filter = state_filter.get_member_split()
- non_member_state = yield self._get_state_for_groups_using_cache(
- groups, self._state_group_cache, non_member_types, filtered_types,
+ # Now we look them up in the member and non-member caches
+ non_member_state, incomplete_groups_nm = (
+ yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_cache,
+ state_filter=non_member_filter,
+ )
)
- # XXX: we could skip this entirely if member_types is []
- member_state = yield self._get_state_for_groups_using_cache(
- # we set filtered_types=None as member_state only ever contain members.
- groups, self._state_group_members_cache, member_types, None,
+
+ member_state, incomplete_groups_m = (
+ yield self._get_state_for_groups_using_cache(
+ groups, self._state_group_members_cache,
+ state_filter=member_filter,
+ )
)
- state = non_member_state
+ state = dict(non_member_state)
for group in groups:
state[group].update(member_state[group])
+ # Now fetch any missing groups from the database
+
+ incomplete_groups = incomplete_groups_m | incomplete_groups_nm
+
+ if not incomplete_groups:
+ defer.returnValue(state)
+
+ cache_sequence_nm = self._state_group_cache.sequence
+ cache_sequence_m = self._state_group_members_cache.sequence
+
+ # Help the cache hit ratio by expanding the filter a bit (see the
+ # illustrative note after this method)
+ db_state_filter = state_filter.return_expanded()
+
+ group_to_state_dict = yield self._get_state_groups_from_groups(
+ list(incomplete_groups),
+ state_filter=db_state_filter,
+ )
+
+ # Now let's update the caches
+ self._insert_into_cache(
+ group_to_state_dict,
+ db_state_filter,
+ cache_seq_num_members=cache_sequence_m,
+ cache_seq_num_non_members=cache_sequence_nm,
+ )
+
+ # And finally update the result dict, filtering out any extra
+ # state we pulled from the database.
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ # We just replace any existing entries, as we will have loaded
+ # everything we need from the database anyway.
+ state[group] = state_filter.filter_state(group_state_dict)
+
defer.returnValue(state)
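+ # (Illustrative note, not from this change: `return_expanded()` may
+ # widen a narrow filter, e.g. one asking only for ("m.room.name", ""),
+ # into "all non-member state", so the caches end up holding a reusable
+ # superset rather than a one-off slice; the exact rules live in
+ # StateFilter.return_expanded.)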
- @defer.inlineCallbacks
def _get_state_for_groups_using_cache(
- self, groups, cache, types=None, filtered_types=None
+ self, groups, cache, state_filter,
):
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key, querying from a specific cache.
@@ -790,89 +987,85 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cache (DictionaryCache): the cache of group ids to state dicts which
we will pass through - either the normal state cache or the specific
members state cache.
- types (None|iterable[(str, None|str)]):
- indicates the state type/keys required. If None, the whole
- state is fetched and returned.
-
- Otherwise, each entry should be a `(type, state_key)` tuple to
- include in the response. A `state_key` of None is a wildcard
- meaning that we require all state with that type.
- filtered_types(list[str]|None): Only apply filtering via `types` to this
- list of event types. Other types of events are returned unfiltered.
- If None, `types` filtering is applied to all events.
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
Returns:
- Deferred[dict[int, dict[tuple[str, str], str]]]:
- dict of state_group_id -> (dict of (type, state_key) -> event id)
+ tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
+ a dict of state_group_id -> (dict of (type, state_key) -> event id)
+ for the entries found in the cache, and the set of state group ids
+ that are either missing from the cache or incomplete.
"""
- if types:
- types = frozenset(types)
results = {}
- missing_groups = []
- if types is not None:
- for group in set(groups):
- state_dict_ids, got_all = self._get_some_state_from_cache(
- cache, group, types, filtered_types
- )
- results[group] = state_dict_ids
+ incomplete_groups = set()
+ for group in set(groups):
+ state_dict_ids, got_all = self._get_state_for_group_using_cache(
+ cache, group, state_filter
+ )
+ results[group] = state_dict_ids
- if not got_all:
- missing_groups.append(group)
- else:
- for group in set(groups):
- state_dict_ids, got_all = self._get_all_state_from_cache(
- cache, group
- )
+ if not got_all:
+ incomplete_groups.add(group)
- results[group] = state_dict_ids
+ return results, incomplete_groups
- if not got_all:
- missing_groups.append(group)
+ def _insert_into_cache(self, group_to_state_dict, state_filter,
+ cache_seq_num_members, cache_seq_num_non_members):
+ """Inserts results from querying the database into the relevant caches.
- if missing_groups:
- # Okay, so we have some missing_types, let's fetch them.
- cache_seq_num = cache.sequence
+ Args:
+ group_to_state_dict (dict): The new entries pulled from the
+ database. Map from state group to state dict
+ state_filter (StateFilter): The state filter used to fetch state
+ from the database.
+ cache_seq_num_members (int): Sequence number of the member cache at
+ the time of the last cache lookup
+ cache_seq_num_non_members (int): Sequence number of the non-member
+ cache at the time of the last cache lookup
+ """
- # the DictionaryCache knows if it has *all* the state, but
- # does not know if it has all of the keys of a particular type,
- # which makes wildcard lookups expensive unless we have a complete
- # cache. Hence, if we are doing a wildcard lookup, populate the
- # cache fully so that we can do an efficient lookup next time.
+ # We need to work out which types we've fetched from the DB for the
+ # member vs non-member caches. This should be as accurate as possible,
+ # but can be an underestimate (e.g. when we have wildcards)
- if filtered_types or (types and any(k is None for (t, k) in types)):
- types_to_fetch = None
- else:
- types_to_fetch = types
+ member_filter, non_member_filter = state_filter.get_member_split()
+ if member_filter.is_full():
+ # We fetched all member events
+ member_types = None
+ else:
+ # `concrete_types()` will only return a subset when there are
+ # wildcards in the filter, but that's fine.
+ member_types = member_filter.concrete_types()
- group_to_state_dict = yield self._get_state_groups_from_groups(
- missing_groups, types_to_fetch, cache == self._state_group_members_cache,
- )
+ if non_member_filter.is_full():
+ # We fetched all non member events
+ non_member_types = None
+ else:
+ non_member_types = non_member_filter.concrete_types()
+
+ for group, group_state_dict in iteritems(group_to_state_dict):
+ state_dict_members = {}
+ state_dict_non_members = {}
- for group, group_state_dict in iteritems(group_to_state_dict):
- state_dict = results[group]
-
- # update the result, filtering by `types`.
- if types:
- for k, v in iteritems(group_state_dict):
- (typ, _) = k
- if (
- (k in types or (typ, None) in types) or
- (filtered_types and typ not in filtered_types)
- ):
- state_dict[k] = v
+ for k, v in iteritems(group_state_dict):
+ if k[0] == EventTypes.Member:
+ state_dict_members[k] = v
else:
- state_dict.update(group_state_dict)
-
- # update the cache with all the things we fetched from the
- # database.
- cache.update(
- cache_seq_num,
- key=group,
- value=group_state_dict,
- fetched_keys=types_to_fetch,
- )
+ state_dict_non_members[k] = v
- defer.returnValue(results)
+ self._state_group_members_cache.update(
+ cache_seq_num_members,
+ key=group,
+ value=state_dict_members,
+ fetched_keys=member_types,
+ )
+
+ self._state_group_cache.update(
+ cache_seq_num_non_members,
+ key=group,
+ value=state_dict_non_members,
+ fetched_keys=non_member_types,
+ )
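+ # For illustration (not part of this change): a fetched
+ # group_state_dict such as
+ #
+ #     {
+ #         ("m.room.member", "@alice:hs"): "$event_a",
+ #         ("m.room.name", ""): "$event_b",
+ #     }
+ #
+ # is split so that "$event_a" goes into _state_group_members_cache and
+ # "$event_b" into _state_group_cache, each guarded by the sequence
+ # number captured before the database query.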
def store_state_group(self, event_id, room_id, prev_group, delta_ids,
current_state_ids):
@@ -1064,6 +1257,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+ EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
@@ -1082,6 +1276,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
columns=["state_key"],
where_clause="type='m.room.member'",
)
+ self.register_background_index_update(
+ self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
+ index_name="event_to_state_groups_sg_index",
+ table="event_to_state_groups",
+ columns=["state_group"],
+ )
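+ # (The index on event_to_state_groups(state_group) is created as a
+ # background update, so it is built incrementally rather than
+ # blocking startup with a long CREATE INDEX.)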
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
@@ -1181,12 +1381,12 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
continue
prev_state = self._get_state_groups_from_groups_txn(
- txn, [prev_group], types=None
+ txn, [prev_group],
)
prev_state = prev_state[prev_group]
curr_state = self._get_state_groups_from_groups_txn(
- txn, [state_group], types=None
+ txn, [state_group],
)
curr_state = curr_state[state_group]
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 9a8fae0497..0ae7e2ef3b 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+import re
from itertools import islice
import attr
@@ -138,3 +139,27 @@ def log_failure(failure, msg, consumeErrors=True):
if not consumeErrors:
return failure
+
+
+def glob_to_regex(glob):
+ """Converts a glob to a compiled regex object.
+
+ The regex is anchored at the beginning and end of the string.
+
+ Args:
+ glob (str)
+
+ Returns:
+ re.RegexObject
+ """
+ res = ''
+ for c in glob:
+ if c == '*':
+ res = res + '.*'
+ elif c == '?':
+ res = res + '.'
+ else:
+ res = res + re.escape(c)
+
+ # \A anchors at start of string, \Z at end of string
+ return re.compile(r"\A" + res + r"\Z", re.IGNORECASE)
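+
+
+# A usage sketch (illustrative only, not part of this change):
+#
+#     pattern = glob_to_regex("*.example.com")
+#     assert pattern.match("Host.EXAMPLE.com")  # case-insensitive; '*' spans
+#                                               # the whole subdomain
+#     assert not pattern.match("example.org")   # anchored at both ends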
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 43f48196be..0281a7c919 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -23,6 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event
+from synapse.storage.state import StateFilter
from synapse.types import get_domain_from_id
logger = logging.getLogger(__name__)
@@ -72,7 +73,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
)
event_id_to_state = yield store.get_state_for_events(
frozenset(e.event_id for e in events),
- types=types,
+ state_filter=StateFilter.from_types(types),
)
ignore_dict_content = yield store.get_global_account_data_by_type_for_user(
@@ -273,8 +274,8 @@ def filter_events_for_server(store, server_name, events):
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
+ state_filter=StateFilter.from_types(
+ types=((EventTypes.RoomHistoryVisibility, ""),),
)
)
@@ -314,9 +315,11 @@ def filter_events_for_server(store, server_name, events):
# of the history vis and membership state at those events.
event_to_state_ids = yield store.get_state_ids_for_events(
frozenset(e.event_id for e in events),
- types=(
- (EventTypes.RoomHistoryVisibility, ""),
- (EventTypes.Member, None),
+ state_filter=StateFilter.from_types(
+ types=(
+ (EventTypes.RoomHistoryVisibility, ""),
+ (EventTypes.Member, None),
+ ),
)
)
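+
+# A hedged sketch (illustrative, not part of this change) of how such a
+# filter behaves: a None state_key is a wildcard for that event type,
+# while a concrete state_key keeps only the exact pair.
+#
+#     f = StateFilter.from_types([(EventTypes.Member, None)])
+#     f.filter_state({
+#         (EventTypes.Member, "@alice:hs"): "$a",        # kept (wildcard)
+#         (EventTypes.RoomHistoryVisibility, ""): "$h",  # dropped
+#     })
+#     # -> {(EventTypes.Member, "@alice:hs"): "$a"}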