diff --git a/.gitignore b/.gitignore
index 491047c352..c8901eb206 100644
--- a/.gitignore
+++ b/.gitignore
@@ -46,3 +46,5 @@ static/client/register/register_config.js
env/
*.config
+
+.vscode/
diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md
new file mode 100644
index 0000000000..abdbc1ea86
--- /dev/null
+++ b/docs/admin_api/media_admin_api.md
@@ -0,0 +1,23 @@
+# List all media in a room
+
+This API gets a list of known media in a room.
+
+The API is:
+```
+GET /_matrix/client/r0/admin/room/<room_id>/media
+```
+including an `access_token` of a server admin.
+
+It returns a JSON body like the following:
+```
+{
+ "local": [
+ "mxc://localhost/xwvutsrqponmlkjihgfedcba",
+ "mxc://localhost/abcdefghijklmnopqrstuvwx"
+ ],
+ "remote": [
+ "mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
+ "mxc://matrix.org/abcdefghijklmnopqrstuvwx"
+ ]
+}
+```
diff --git a/scripts/move_remote_media_to_new_store.py b/scripts/move_remote_media_to_new_store.py
new file mode 100755
index 0000000000..7914ead889
--- /dev/null
+++ b/scripts/move_remote_media_to_new_store.py
@@ -0,0 +1,133 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Moves a list of remote media from one media store to another.
+
+The input should be a list of media files to be moved, one per line. Each line
+should be formatted::
+
+ <origin server>|<file id>
+
+This can be extracted from postgres with::
+
+ psql --tuples-only -A -c "select media_origin, filesystem_id from
+ matrix.remote_media_cache where ..."
+
+To use, pipe the above into::
+
+ PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
+"""
+
+from __future__ import print_function
+
+import argparse
+import logging
+
+import sys
+
+import os
+
+import shutil
+
+from synapse.rest.media.v1.filepath import MediaFilePaths
+
+logger = logging.getLogger()
+
+
+def main(src_repo, dest_repo):
+ src_paths = MediaFilePaths(src_repo)
+ dest_paths = MediaFilePaths(dest_repo)
+ for line in sys.stdin:
+ line = line.strip()
+ parts = line.split('|')
+ if len(parts) != 2:
+ print("Unable to parse input line %s" % line, file=sys.stderr)
+ exit(1)
+
+ move_media(parts[0], parts[1], src_paths, dest_paths)
+
+
+def move_media(origin_server, file_id, src_paths, dest_paths):
+ """Move the given file, and any thumbnails, to the dest repo
+
+ Args:
+ origin_server (str):
+ file_id (str):
+ src_paths (MediaFilePaths):
+ dest_paths (MediaFilePaths):
+ """
+ logger.info("%s/%s", origin_server, file_id)
+
+ # check that the original exists
+ original_file = src_paths.remote_media_filepath(origin_server, file_id)
+ if not os.path.exists(original_file):
+ logger.warn(
+ "Original for %s/%s (%s) does not exist",
+ origin_server, file_id, original_file,
+ )
+ else:
+ mkdir_and_move(
+ original_file,
+ dest_paths.remote_media_filepath(origin_server, file_id),
+ )
+
+ # now look for thumbnails
+ original_thumb_dir = src_paths.remote_media_thumbnail_dir(
+ origin_server, file_id,
+ )
+ if not os.path.exists(original_thumb_dir):
+ return
+
+ mkdir_and_move(
+ original_thumb_dir,
+ dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
+ )
+
+
+def mkdir_and_move(original_file, dest_file):
+ dirname = os.path.dirname(dest_file)
+ if not os.path.exists(dirname):
+ logger.debug("mkdir %s", dirname)
+ os.makedirs(dirname)
+ logger.debug("mv %s %s", original_file, dest_file)
+ shutil.move(original_file, dest_file)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description=__doc__,
+ formatter_class = argparse.RawDescriptionHelpFormatter,
+ )
+ parser.add_argument(
+ "-v", action='store_true', help='enable debug logging')
+ parser.add_argument(
+ "src_repo",
+ help="Path to source content repo",
+ )
+ parser.add_argument(
+ "dest_repo",
+ help="Path to source content repo",
+ )
+ args = parser.parse_args()
+
+ logging_config = {
+ "level": logging.DEBUG if args.v else logging.INFO,
+ "format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
+ }
+ logging.basicConfig(**logging_config)
+
+ main(args.src_repo, args.dest_repo)
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 79b35b3e7c..aa15f73f36 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -46,6 +46,7 @@ class Codes(object):
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "M_THREEPID_IN_USE"
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
+ THREEPID_DENIED = "M_THREEPID_DENIED"
INVALID_USERNAME = "M_INVALID_USERNAME"
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
@@ -140,6 +141,32 @@ class RegistrationError(SynapseError):
pass
+class FederationDeniedError(SynapseError):
+ """An error raised when the server tries to federate with a server which
+ is not on its federation whitelist.
+
+ Attributes:
+ destination (str): The destination which has been denied
+ """
+
+ def __init__(self, destination):
+ """Raised by federation client or server to indicate that we are
+ are deliberately not attempting to contact a given server because it is
+ not on our federation whitelist.
+
+ Args:
+ destination (str): the domain in question
+ """
+
+ self.destination = destination
+
+ super(FederationDeniedError, self).__init__(
+ code=403,
+ msg="Federation denied with %s." % (self.destination,),
+ errcode=Codes.FORBIDDEN,
+ )
+
+
class InteractiveAuthIncompleteError(Exception):
"""An error raised when UI auth is not yet complete
diff --git a/synapse/app/appservice.py b/synapse/app/appservice.py
index 7d0c2879ae..c6fe4516d1 100644
--- a/synapse/app/appservice.py
+++ b/synapse/app/appservice.py
@@ -49,19 +49,6 @@ class AppserviceSlaveStore(
class AppserviceServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
diff --git a/synapse/app/client_reader.py b/synapse/app/client_reader.py
index dc3f6efd43..3b3352798d 100644
--- a/synapse/app/client_reader.py
+++ b/synapse/app/client_reader.py
@@ -64,19 +64,6 @@ class ClientReaderSlavedStore(
class ClientReaderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
diff --git a/synapse/app/federation_reader.py b/synapse/app/federation_reader.py
index a072291e1f..4de43c41f0 100644
--- a/synapse/app/federation_reader.py
+++ b/synapse/app/federation_reader.py
@@ -58,19 +58,6 @@ class FederationReaderSlavedStore(
class FederationReaderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 09e9488f06..f760826d27 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -76,19 +76,6 @@ class FederationSenderSlaveStore(
class FederationSenderServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
diff --git a/synapse/app/frontend_proxy.py b/synapse/app/frontend_proxy.py
index ae531c0aa4..e32ee8fe93 100644
--- a/synapse/app/frontend_proxy.py
+++ b/synapse/app/frontend_proxy.py
@@ -118,19 +118,6 @@ class FrontendProxySlavedStore(
class FrontendProxyServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 92ab3b311b..cb82a415a6 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -266,19 +266,6 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(config_options):
"""
diff --git a/synapse/app/media_repository.py b/synapse/app/media_repository.py
index eab1597aaa..1ed1ca8772 100644
--- a/synapse/app/media_repository.py
+++ b/synapse/app/media_repository.py
@@ -60,19 +60,6 @@ class MediaRepositorySlavedStore(
class MediaRepositoryServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
diff --git a/synapse/app/pusher.py b/synapse/app/pusher.py
index 7fbbb0b0e1..32ccea3f13 100644
--- a/synapse/app/pusher.py
+++ b/synapse/app/pusher.py
@@ -81,19 +81,6 @@ class PusherSlaveStore(
class PusherServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 0abba3016e..f87531f1b6 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -246,19 +246,6 @@ class SynchrotronApplicationService(object):
class SynchrotronServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
diff --git a/synapse/app/user_dir.py b/synapse/app/user_dir.py
index a48c4a2ae6..494ccb702c 100644
--- a/synapse/app/user_dir.py
+++ b/synapse/app/user_dir.py
@@ -92,19 +92,6 @@ class UserDirectorySlaveStore(
class UserDirectoryServer(HomeServer):
- def get_db_conn(self, run_new_connection=True):
- # Any param beginning with cp_ is a parameter for adbapi, and should
- # not be passed to the database engine.
- db_params = {
- k: v for k, v in self.db_config.get("args", {}).items()
- if not k.startswith("cp_")
- }
- db_conn = self.database_engine.module.connect(**db_params)
-
- if run_new_connection:
- self.database_engine.on_new_connection(db_conn)
- return db_conn
-
def setup(self):
logger.info("Setting up.")
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index ef917fc9f2..336959094b 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -31,6 +31,8 @@ class RegistrationConfig(Config):
strtobool(str(config["disable_registration"]))
)
+ self.registrations_require_3pid = config.get("registrations_require_3pid", [])
+ self.allowed_local_3pids = config.get("allowed_local_3pids", [])
self.registration_shared_secret = config.get("registration_shared_secret")
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
@@ -52,6 +54,23 @@ class RegistrationConfig(Config):
# Enable registration for new users.
enable_registration: False
+ # The user must provide all of the below types of 3PID when registering.
+ #
+ # registrations_require_3pid:
+ # - email
+ # - msisdn
+
+ # Mandate that users are only allowed to associate certain formats of
+ # 3PIDs with accounts on this server.
+ #
+ # allowed_local_3pids:
+ # - medium: email
+ # pattern: ".*@matrix\\.org"
+ # - medium: email
+ # pattern: ".*@vector\\.im"
+ # - medium: msisdn
+ # pattern: "\\+44"
+
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 6baa474931..25ea77738a 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -16,6 +16,8 @@
from ._base import Config, ConfigError
from collections import namedtuple
+from synapse.util.module_loader import load_module
+
MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API."
@@ -36,6 +38,14 @@ ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
+MediaStorageProviderConfig = namedtuple(
+ "MediaStorageProviderConfig", (
+ "store_local", # Whether to store newly uploaded local files
+ "store_remote", # Whether to store newly downloaded remote files
+ "store_synchronous", # Whether to wait for successful storage for local uploads
+ ),
+)
+
def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys
@@ -73,16 +83,61 @@ class ContentRepositoryConfig(Config):
self.media_store_path = self.ensure_directory(config["media_store_path"])
- self.backup_media_store_path = config.get("backup_media_store_path")
- if self.backup_media_store_path:
- self.backup_media_store_path = self.ensure_directory(
- self.backup_media_store_path
- )
+ backup_media_store_path = config.get("backup_media_store_path")
- self.synchronous_backup_media_store = config.get(
+ synchronous_backup_media_store = config.get(
"synchronous_backup_media_store", False
)
+ storage_providers = config.get("media_storage_providers", [])
+
+ if backup_media_store_path:
+ if storage_providers:
+ raise ConfigError(
+ "Cannot use both 'backup_media_store_path' and 'storage_providers'"
+ )
+
+ storage_providers = [{
+ "module": "file_system",
+ "store_local": True,
+ "store_synchronous": synchronous_backup_media_store,
+ "store_remote": True,
+ "config": {
+ "directory": backup_media_store_path,
+ }
+ }]
+
+ # This is a list of config that can be used to create the storage
+ # providers. The entries are tuples of (Class, class_config,
+ # MediaStorageProviderConfig), where Class is the class of the provider,
+ # the class_config the config to pass to it, and
+ # MediaStorageProviderConfig are options for StorageProviderWrapper.
+ #
+ # We don't create the storage providers here as not all workers need
+ # them to be started.
+ self.media_storage_providers = []
+
+ for provider_config in storage_providers:
+ # We special case the module "file_system" so as not to need to
+ # expose FileStorageProviderBackend
+ if provider_config["module"] == "file_system":
+ provider_config["module"] = (
+ "synapse.rest.media.v1.storage_provider"
+ ".FileStorageProviderBackend"
+ )
+
+ provider_class, parsed_config = load_module(provider_config)
+
+ wrapper_config = MediaStorageProviderConfig(
+ provider_config.get("store_local", False),
+ provider_config.get("store_remote", False),
+ provider_config.get("store_synchronous", False),
+ )
+
+ self.media_storage_providers.append(
+ (provider_class, parsed_config, wrapper_config,)
+ )
+
self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements(
@@ -127,13 +182,19 @@ class ContentRepositoryConfig(Config):
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
- # A secondary directory where uploaded images and attachments are
- # stored as a backup.
- # backup_media_store_path: "%(media_store)s"
-
- # Whether to wait for successful write to backup media store before
- # returning successfully.
- # synchronous_backup_media_store: false
+ # Media storage providers allow media to be stored in different
+ # locations.
+ # media_storage_providers:
+ # - module: file_system
+ # # Whether to write new local files.
+ # store_local: false
+ # # Whether to write new remote media
+ # store_remote: false
+ # # Whether to block upload requests waiting for write to this
+ # # provider to complete
+ # store_synchronous: false
+ # config:
+ # directory: /mnt/some/other/directory
# Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s"
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 436dd8a6fe..8f0b6d1f28 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -55,6 +55,17 @@ class ServerConfig(Config):
"block_non_admin_invites", False,
)
+ # FIXME: federation_domain_whitelist needs sytests
+ self.federation_domain_whitelist = None
+ federation_domain_whitelist = config.get(
+ "federation_domain_whitelist", None
+ )
+ # turn the whitelist into a hash for speed of lookup
+ if federation_domain_whitelist is not None:
+ self.federation_domain_whitelist = {}
+ for domain in federation_domain_whitelist:
+ self.federation_domain_whitelist[domain] = True
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
@@ -210,6 +221,17 @@ class ServerConfig(Config):
# (except those sent by local server admins). The default is False.
# block_non_admin_invites: True
+ # Restrict federation to the following whitelist of domains.
+ # N.B. we recommend also firewalling your federation listener to limit
+ # inbound federation traffic as early as possible, rather than relying
+ # purely on this application-layer restriction. If not specified, the
+ # default is to whitelist everything.
+ #
+ # federation_domain_whitelist:
+ # - lon.example.com
+ # - nyc.example.com
+ # - syd.example.com
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e9a732ff03..87e3fe7b97 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -25,7 +25,9 @@ class EventContext(object):
The current state map excluding the current event.
(type, state_key) -> event_id
- state_group (int): state group id
+ state_group (int|None): state group id, if the state has been stored
+ as a state group. This is usually only None if e.g. the event is
+ an outlier.
rejected (bool|str): A rejection reason if the event was rejected, else
False
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b1fe03f702..813907f7f2 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -23,7 +23,7 @@ from twisted.internet import defer
from synapse.api.constants import Membership
from synapse.api.errors import (
- CodeMessageException, HttpResponseException, SynapseError,
+ CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
)
from synapse.events import builder
from synapse.federation.federation_base import (
@@ -266,6 +266,9 @@ class FederationClient(FederationBase):
except NotRetryingDestination as e:
logger.info(e.message)
continue
+ except FederationDeniedError as e:
+ logger.info(e.message)
+ continue
except Exception as e:
pdu_attempts[destination] = now
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 9d39f46583..a141ec9953 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction, Edu
-from synapse.api.errors import HttpResponseException
+from synapse.api.errors import HttpResponseException, FederationDeniedError
from synapse.util import logcontext, PreserveLoggingContext
from synapse.util.async import run_on_reactor
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@@ -490,6 +490,8 @@ class TransactionQueue(object):
(e.retry_last_ts + e.retry_interval) / 1000.0
),
)
+ except FederationDeniedError as e:
+ logger.info(e)
except Exception as e:
logger.warn(
"TX [%s] Failed to send transaction: %s",
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1f3ce238f6..5488e82985 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -212,6 +212,9 @@ class TransportLayerClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if the remote destination
+ is not in our federation whitelist
"""
valid_memberships = {Membership.JOIN, Membership.LEAVE}
if membership not in valid_memberships:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2b02b021ec..06c16ba4fa 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -16,7 +16,7 @@
from twisted.internet import defer
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, SynapseError, FederationDeniedError
from synapse.http.server import JsonResource
from synapse.http.servlet import (
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -81,6 +81,7 @@ class Authenticator(object):
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
@@ -92,6 +93,12 @@ class Authenticator(object):
"signatures": {},
}
+ if (
+ self.federation_domain_whitelist is not None and
+ self.server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(self.server_name)
+
if content is not None:
json_request["content"] = content
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 2152efc692..0e83453851 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -14,6 +14,7 @@
# limitations under the License.
from synapse.api import errors
from synapse.api.constants import EventTypes
+from synapse.api.errors import FederationDeniedError
from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -513,6 +514,9 @@ class DeviceListEduUpdater(object):
# This makes it more likely that the device lists will
# eventually become consistent.
return
+ except FederationDeniedError as e:
+ logger.info(e)
+ return
except Exception:
# TODO: Remember that we are now out of sync and try again
# later
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a0464ae5c0..8580ada60a 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -34,6 +34,7 @@ class DirectoryHandler(BaseHandler):
self.state = hs.get_state_handler()
self.appservice_handler = hs.get_application_service_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
self.federation = hs.get_replication_layer()
self.federation.register_query_handler(
@@ -249,8 +250,7 @@ class DirectoryHandler(BaseHandler):
def send_room_alias_update_event(self, requester, user_id, room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Aliases,
@@ -272,8 +272,7 @@ class DirectoryHandler(BaseHandler):
if not alias_event or alias_event.content.get("alias", "") != alias_str:
return
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5af8abf66b..9aa95f89e6 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -19,7 +19,9 @@ import logging
from canonicaljson import encode_canonical_json
from twisted.internet import defer
-from synapse.api.errors import SynapseError, CodeMessageException
+from synapse.api.errors import (
+ SynapseError, CodeMessageException, FederationDeniedError,
+)
from synapse.types import get_domain_from_id, UserID
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
@@ -140,6 +142,10 @@ class E2eKeysHandler(object):
failures[destination] = {
"status": 503, "message": "Not ready for retry",
}
+ except FederationDeniedError as e:
+ failures[destination] = {
+ "status": 403, "message": "Federation Denied",
+ }
except Exception as e:
# include ConnectionRefused and other errors
failures[destination] = {
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ac70730885..46bcf8b081 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -22,6 +23,7 @@ from ._base import BaseHandler
from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
+ FederationDeniedError,
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
@@ -74,6 +76,7 @@ class FederationHandler(BaseHandler):
self.is_mine_id = hs.is_mine_id
self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
+ self.event_creation_handler = hs.get_event_creation_handler()
self.replication_layer.set_handler(self)
@@ -782,6 +785,9 @@ class FederationHandler(BaseHandler):
except NotRetryingDestination as e:
logger.info(e.message)
continue
+ except FederationDeniedError as e:
+ logger.info(e)
+ continue
except Exception as e:
logger.exception(
"Failed to backfill from %s because %s",
@@ -804,13 +810,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
+ resolve = logcontext.preserve_fn(
+ self.state_handler.resolve_state_groups_for_events
+ )
states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
- [
- logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
- room_id, [e]
- )
- for e in event_ids
- ], consumeErrors=True,
+ [resolve(room_id, [e]) for e in event_ids],
+ consumeErrors=True,
))
states = dict(zip(event_ids, [s.state for s in states]))
@@ -1004,8 +1009,7 @@ class FederationHandler(BaseHandler):
})
try:
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
except AuthError as e:
@@ -1245,8 +1249,7 @@ class FederationHandler(BaseHandler):
"state_key": user_id,
})
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
@@ -1828,8 +1831,8 @@ class FederationHandler(BaseHandler):
current_state = set(e.event_id for e in auth_events.values())
different_auth = event_auth_events - current_state
- self._update_context_for_auth_events(
- context, auth_events, event_key,
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
)
if different_auth and not event.internal_metadata.is_outlier():
@@ -1910,8 +1913,8 @@ class FederationHandler(BaseHandler):
# 4. Look at rejects and their proofs.
# TODO.
- self._update_context_for_auth_events(
- context, auth_events, event_key,
+ yield self._update_context_for_auth_events(
+ event, context, auth_events, event_key,
)
try:
@@ -1920,11 +1923,15 @@ class FederationHandler(BaseHandler):
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
- def _update_context_for_auth_events(self, context, auth_events,
+ @defer.inlineCallbacks
+ def _update_context_for_auth_events(self, event, context, auth_events,
event_key):
- """Update the state_ids in an event context after auth event resolution
+ """Update the state_ids in an event context after auth event resolution,
+ storing the changes as a new state group.
Args:
+ event (Event): The event we're handling the context for
+
context (synapse.events.snapshot.EventContext): event context
to be updated
@@ -1947,7 +1954,13 @@ class FederationHandler(BaseHandler):
context.prev_state_ids.update({
k: a.event_id for k, a in auth_events.iteritems()
})
- context.state_group = self.store.get_next_state_group()
+ context.state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=context.prev_group,
+ delta_ids=context.delta_ids,
+ current_state_ids=context.current_state_ids,
+ )
@defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
@@ -2117,8 +2130,7 @@ class FederationHandler(BaseHandler):
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder
)
@@ -2156,8 +2168,7 @@ class FederationHandler(BaseHandler):
"""
builder = self.event_builder_factory.new(event_dict)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder=builder,
)
@@ -2207,8 +2218,9 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
- message_handler = self.hs.get_handlers().message_handler
- event, context = yield message_handler._create_new_client_event(builder=builder)
+ event, context = yield self.event_creation_handler.create_new_client_event(
+ builder=builder,
+ )
defer.returnValue((event, context))
@defer.inlineCallbacks
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 276d1a7722..4e9752ccbd 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
-# Copyright 2017 New Vector Ltd
+# Copyright 2017 - 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -47,21 +47,9 @@ class MessageHandler(BaseHandler):
self.hs = hs
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
- self.validator = EventValidator()
- self.profile_handler = hs.get_profile_handler()
self.pagination_lock = ReadWriteLock()
- self.pusher_pool = hs.get_pusherpool()
-
- # We arbitrarily limit concurrent event creation for a room to 5.
- # This is to stop us from diverging history *too* much.
- self.limiter = Limiter(max_count=5)
-
- self.action_generator = hs.get_action_generator()
-
- self.spam_checker = hs.get_spam_checker()
-
@defer.inlineCallbacks
def purge_history(self, room_id, event_id, delete_local_events=False):
event = yield self.store.get_event(event_id)
@@ -183,6 +171,162 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
+ def get_room_data(self, user_id=None, room_id=None,
+ event_type=None, state_key="", is_guest=False):
+ """ Get data from a room.
+
+ Args:
+ event : The room path event
+ Returns:
+ The path data content.
+ Raises:
+ SynapseError if something went wrong.
+ """
+ membership, membership_event_id = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
+
+ if membership == Membership.JOIN:
+ data = yield self.state_handler.get_current_state(
+ room_id, event_type, state_key
+ )
+ elif membership == Membership.LEAVE:
+ key = (event_type, state_key)
+ room_state = yield self.store.get_state_for_events(
+ [membership_event_id], [key]
+ )
+ data = room_state[membership_event_id].get(key)
+
+ defer.returnValue(data)
+
+ @defer.inlineCallbacks
+ def _check_in_room_or_world_readable(self, room_id, user_id):
+ try:
+ # check_user_was_in_room will return the most recent membership
+ # event for the user if:
+ # * The user is a non-guest user, and was ever in the room
+ # * The user is a guest user, and has joined the room
+ # else it will throw.
+ member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+ defer.returnValue((member_event.membership, member_event.event_id))
+ return
+ except AuthError:
+ visibility = yield self.state_handler.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility, ""
+ )
+ if (
+ visibility and
+ visibility.content["history_visibility"] == "world_readable"
+ ):
+ defer.returnValue((Membership.JOIN, None))
+ return
+ raise AuthError(
+ 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+ )
+
+ @defer.inlineCallbacks
+ def get_state_events(self, user_id, room_id, is_guest=False):
+ """Retrieve all state events for a given room. If the user is
+ joined to the room then return the current state. If the user has
+ left the room return the state events from when they left.
+
+ Args:
+ user_id(str): The user requesting state events.
+ room_id(str): The room ID to get all state events from.
+ Returns:
+ A list of dicts representing state events. [{}, {}, {}]
+ """
+ membership, membership_event_id = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
+
+ if membership == Membership.JOIN:
+ room_state = yield self.state_handler.get_current_state(room_id)
+ elif membership == Membership.LEAVE:
+ room_state = yield self.store.get_state_for_events(
+ [membership_event_id], None
+ )
+ room_state = room_state[membership_event_id]
+
+ now = self.clock.time_msec()
+ defer.returnValue(
+ [serialize_event(c, now) for c in room_state.values()]
+ )
+
+ @defer.inlineCallbacks
+ def get_joined_members(self, requester, room_id):
+ """Get all the joined members in the room and their profile information.
+
+ If the user has left the room return the state events from when they left.
+
+ Args:
+ requester(Requester): The user requesting state events.
+ room_id(str): The room ID to get all state events from.
+ Returns:
+ A dict of user_id to profile info
+ """
+ user_id = requester.user.to_string()
+ if not requester.app_service:
+ # We check AS auth after fetching the room membership, as it
+ # requires us to pull out all joined members anyway.
+ membership, _ = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
+ if membership != Membership.JOIN:
+ raise NotImplementedError(
+ "Getting joined members after leaving is not implemented"
+ )
+
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+ # If this is an AS, double check that they are allowed to see the members.
+ # This can either be because the AS user is in the room or becuase there
+ # is a user in the room that the AS is "interested in"
+ if requester.app_service and user_id not in users_with_profile:
+ for uid in users_with_profile:
+ if requester.app_service.is_interested_in_user(uid):
+ break
+ else:
+ # Loop fell through, AS has no interested users in room
+ raise AuthError(403, "Appservice not in room")
+
+ defer.returnValue({
+ user_id: {
+ "avatar_url": profile.avatar_url,
+ "display_name": profile.display_name,
+ }
+ for user_id, profile in users_with_profile.iteritems()
+ })
+
+
+class EventCreationHandler(object):
+ def __init__(self, hs):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
+ self.clock = hs.get_clock()
+ self.validator = EventValidator()
+ self.profile_handler = hs.get_profile_handler()
+ self.event_builder_factory = hs.get_event_builder_factory()
+ self.server_name = hs.hostname
+ self.ratelimiter = hs.get_ratelimiter()
+ self.notifier = hs.get_notifier()
+
+ # This is only used to get at ratelimit function, and maybe_kick_guest_users
+ self.base_handler = BaseHandler(hs)
+
+ self.pusher_pool = hs.get_pusherpool()
+
+ # We arbitrarily limit concurrent event creation for a room to 5.
+ # This is to stop us from diverging history *too* much.
+ self.limiter = Limiter(max_count=5)
+
+ self.action_generator = hs.get_action_generator()
+
+ self.spam_checker = hs.get_spam_checker()
+
+ @defer.inlineCallbacks
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
prev_event_ids=None):
"""
@@ -234,7 +378,7 @@ class MessageHandler(BaseHandler):
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
- event, context = yield self._create_new_client_event(
+ event, context = yield self.create_new_client_event(
builder=builder,
requester=requester,
prev_event_ids=prev_event_ids,
@@ -259,11 +403,6 @@ class MessageHandler(BaseHandler):
"Tried to send member event through non-member codepath"
)
- # We check here if we are currently being rate limited, so that we
- # don't do unnecessary work. We check again just before we actually
- # send the event.
- yield self.ratelimit(requester, update=False)
-
user = UserID.from_string(event.sender)
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
@@ -342,137 +481,9 @@ class MessageHandler(BaseHandler):
)
defer.returnValue(event)
+ @measure_func("create_new_client_event")
@defer.inlineCallbacks
- def get_room_data(self, user_id=None, room_id=None,
- event_type=None, state_key="", is_guest=False):
- """ Get data from a room.
-
- Args:
- event : The room path event
- Returns:
- The path data content.
- Raises:
- SynapseError if something went wrong.
- """
- membership, membership_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
-
- if membership == Membership.JOIN:
- data = yield self.state_handler.get_current_state(
- room_id, event_type, state_key
- )
- elif membership == Membership.LEAVE:
- key = (event_type, state_key)
- room_state = yield self.store.get_state_for_events(
- [membership_event_id], [key]
- )
- data = room_state[membership_event_id].get(key)
-
- defer.returnValue(data)
-
- @defer.inlineCallbacks
- def _check_in_room_or_world_readable(self, room_id, user_id):
- try:
- # check_user_was_in_room will return the most recent membership
- # event for the user if:
- # * The user is a non-guest user, and was ever in the room
- # * The user is a guest user, and has joined the room
- # else it will throw.
- member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
- defer.returnValue((member_event.membership, member_event.event_id))
- return
- except AuthError:
- visibility = yield self.state_handler.get_current_state(
- room_id, EventTypes.RoomHistoryVisibility, ""
- )
- if (
- visibility and
- visibility.content["history_visibility"] == "world_readable"
- ):
- defer.returnValue((Membership.JOIN, None))
- return
- raise AuthError(
- 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
- )
-
- @defer.inlineCallbacks
- def get_state_events(self, user_id, room_id, is_guest=False):
- """Retrieve all state events for a given room. If the user is
- joined to the room then return the current state. If the user has
- left the room return the state events from when they left.
-
- Args:
- user_id(str): The user requesting state events.
- room_id(str): The room ID to get all state events from.
- Returns:
- A list of dicts representing state events. [{}, {}, {}]
- """
- membership, membership_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
-
- if membership == Membership.JOIN:
- room_state = yield self.state_handler.get_current_state(room_id)
- elif membership == Membership.LEAVE:
- room_state = yield self.store.get_state_for_events(
- [membership_event_id], None
- )
- room_state = room_state[membership_event_id]
-
- now = self.clock.time_msec()
- defer.returnValue(
- [serialize_event(c, now) for c in room_state.values()]
- )
-
- @defer.inlineCallbacks
- def get_joined_members(self, requester, room_id):
- """Get all the joined members in the room and their profile information.
-
- If the user has left the room return the state events from when they left.
-
- Args:
- requester(Requester): The user requesting state events.
- room_id(str): The room ID to get all state events from.
- Returns:
- A dict of user_id to profile info
- """
- user_id = requester.user.to_string()
- if not requester.app_service:
- # We check AS auth after fetching the room membership, as it
- # requires us to pull out all joined members anyway.
- membership, _ = yield self._check_in_room_or_world_readable(
- room_id, user_id
- )
- if membership != Membership.JOIN:
- raise NotImplementedError(
- "Getting joined members after leaving is not implemented"
- )
-
- users_with_profile = yield self.state.get_current_user_in_room(room_id)
-
- # If this is an AS, double check that they are allowed to see the members.
- # This can either be because the AS user is in the room or becuase there
- # is a user in the room that the AS is "interested in"
- if requester.app_service and user_id not in users_with_profile:
- for uid in users_with_profile:
- if requester.app_service.is_interested_in_user(uid):
- break
- else:
- # Loop fell through, AS has no interested users in room
- raise AuthError(403, "Appservice not in room")
-
- defer.returnValue({
- user_id: {
- "avatar_url": profile.avatar_url,
- "display_name": profile.display_name,
- }
- for user_id, profile in users_with_profile.iteritems()
- })
-
- @measure_func("_create_new_client_event")
- @defer.inlineCallbacks
- def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
+ def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
@@ -509,9 +520,7 @@ class MessageHandler(BaseHandler):
builder.prev_events = prev_events
builder.depth = depth
- state_handler = self.state_handler
-
- context = yield state_handler.compute_event_context(builder)
+ context = yield self.state.compute_event_context(builder)
if requester:
context.app_service = requester.app_service
@@ -551,7 +560,7 @@ class MessageHandler(BaseHandler):
# We now need to go and hit out to wherever we need to hit out to.
if ratelimit:
- yield self.ratelimit(requester)
+ yield self.base_handler.ratelimit(requester)
try:
yield self.auth.check_from_context(event, context)
@@ -567,7 +576,7 @@ class MessageHandler(BaseHandler):
logger.exception("Failed to encode content: %r", event.content)
raise
- yield self.maybe_kick_guest_users(event, context)
+ yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 5b808beac1..9021d4d57f 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -25,6 +25,7 @@ from synapse.http.client import CaptchaServerHttpClient
from synapse import types
from synapse.types import UserID
from synapse.util.async import run_on_reactor
+from synapse.util.threepids import check_3pid_allowed
from ._base import BaseHandler
logger = logging.getLogger(__name__)
@@ -293,7 +294,7 @@ class RegistrationHandler(BaseHandler):
"""
for c in threepidCreds:
- logger.info("validating theeepidcred sid %s on id server %s",
+ logger.info("validating threepidcred sid %s on id server %s",
c['sid'], c['idServer'])
try:
identity_handler = self.hs.get_handlers().identity_handler
@@ -307,6 +308,11 @@ class RegistrationHandler(BaseHandler):
logger.info("got threepid with medium '%s' and address '%s'",
threepid['medium'], threepid['address'])
+ if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
+ raise RegistrationError(
+ 403, "Third party identifier is not allowed"
+ )
+
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server.
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d1cc87a016..6ab020bf41 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -64,6 +65,7 @@ class RoomCreationHandler(BaseHandler):
super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker()
+ self.event_creation_handler = hs.get_event_creation_handler()
@defer.inlineCallbacks
def create_room(self, requester, config, ratelimit=True):
@@ -163,13 +165,11 @@ class RoomCreationHandler(BaseHandler):
creation_content = config.get("creation_content", {})
- msg_handler = self.hs.get_handlers().message_handler
room_member_handler = self.hs.get_handlers().room_member_handler
yield self._send_events_for_new_room(
requester,
room_id,
- msg_handler,
room_member_handler,
preset_config=preset_config,
invite_list=invite_list,
@@ -181,7 +181,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config:
name = config["name"]
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@@ -194,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@@ -249,7 +249,6 @@ class RoomCreationHandler(BaseHandler):
self,
creator, # A Requester object.
room_id,
- msg_handler,
room_member_handler,
preset_config,
invite_list,
@@ -272,7 +271,7 @@ class RoomCreationHandler(BaseHandler):
@defer.inlineCallbacks
def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
creator,
event,
ratelimit=False
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index bb40075387..dfa09141ed 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -203,7 +203,8 @@ class RoomListHandler(BaseHandler):
if limit:
step = limit + 1
else:
- step = len(rooms_to_scan)
+ # step cannot be zero
+ step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
chunk = []
for i in xrange(0, len(rooms_to_scan), step):
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7e6467cd1d..37dc5e99ab 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -46,6 +47,7 @@ class RoomMemberHandler(BaseHandler):
super(RoomMemberHandler, self).__init__(hs)
self.profile_handler = hs.get_profile_handler()
+ self.event_creation_hander = hs.get_event_creation_handler()
self.member_linearizer = Linearizer(name="member")
@@ -66,13 +68,12 @@ class RoomMemberHandler(BaseHandler):
):
if content is None:
content = {}
- msg_handler = self.hs.get_handlers().message_handler
content["membership"] = membership
if requester.is_guest:
content["kind"] = "guest"
- event, context = yield msg_handler.create_event(
+ event, context = yield self.event_creation_hander.create_event(
requester,
{
"type": EventTypes.Member,
@@ -90,12 +91,14 @@ class RoomMemberHandler(BaseHandler):
)
# Check if this event matches the previous membership event for the user.
- duplicate = yield msg_handler.deduplicate_state_event(event, context)
+ duplicate = yield self.event_creation_hander.deduplicate_state_event(
+ event, context,
+ )
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
defer.returnValue(duplicate)
- yield msg_handler.handle_new_client_event(
+ yield self.event_creation_hander.handle_new_client_event(
requester,
event,
context,
@@ -394,8 +397,9 @@ class RoomMemberHandler(BaseHandler):
else:
requester = synapse.types.create_requester(target_user)
- message_handler = self.hs.get_handlers().message_handler
- prev_event = yield message_handler.deduplicate_state_event(event, context)
+ prev_event = yield self.event_creation_hander.deduplicate_state_event(
+ event, context,
+ )
if prev_event is not None:
return
@@ -412,7 +416,7 @@ class RoomMemberHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- yield message_handler.handle_new_client_event(
+ yield self.event_creation_hander.handle_new_client_event(
requester,
event,
context,
@@ -644,8 +648,7 @@ class RoomMemberHandler(BaseHandler):
)
)
- msg_handler = self.hs.get_handlers().message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_hander.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 930d713013..f3e4973c2e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -18,6 +18,7 @@ from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import (
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
)
+from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.logcontext import make_deferred_yieldable
from synapse.util import logcontext
import synapse.metrics
@@ -67,7 +68,11 @@ class SimpleHttpClient(object):
self.hs = hs
pool = HTTPConnectionPool(reactor)
- pool.maxPersistentPerHost = 5
+
+ # the pusher makes lots of concurrent SSL connections to sygnal, and
+ # tends to do so in batches, so we need to allow the pool to keep lots
+ # of idle connections around.
+ pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
pool.cachedConnectionTimeout = 2 * 60
# The default context factory in Twisted 14.0.0 (which we require) is
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index e2b99ef3bd..87639b9151 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -357,8 +357,7 @@ def _get_hosts_for_srv_record(dns_client, host):
def eb(res, record_type):
if res.check(DNSNameError):
return []
- logger.warn("Error looking up %s for %s: %s",
- record_type, host, res, res.value)
+ logger.warn("Error looking up %s for %s: %s", record_type, host, res)
return res
# no logcontexts here, so we can safely fire these off and gatherResults
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 833496b72d..9145405cb0 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -27,7 +27,7 @@ import synapse.metrics
from canonicaljson import encode_canonical_json
from synapse.api.errors import (
- SynapseError, Codes, HttpResponseException,
+ SynapseError, Codes, HttpResponseException, FederationDeniedError,
)
from signedjson.sign import sign_json
@@ -123,11 +123,22 @@ class MatrixFederationHttpClient(object):
Fails with ``HTTPRequestException``: if we get an HTTP response
code >= 300.
+
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
+
(May also fail with plenty of other Exceptions for things like DNS
failures, connection failures, SSL failures.)
"""
+ if (
+ self.hs.config.federation_domain_whitelist and
+ destination not in self.hs.config.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(destination)
+
limiter = yield synapse.util.retryutils.get_retry_limiter(
destination,
self.clock,
@@ -308,6 +319,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
if not json_data_callback:
@@ -368,6 +382,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
def body_callback(method, url_bytes, headers_dict):
@@ -422,6 +439,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
logger.debug("get_json args: %s", args)
@@ -475,6 +495,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
response = yield self._request(
@@ -518,6 +541,9 @@ class MatrixFederationHttpClient(object):
Fails with ``NotRetryingDestination`` if we are not yet ready
to retry this server.
+
+ Fails with ``FederationDeniedError`` if this destination
+ is not on our federation whitelist
"""
encoded_args = {}
diff --git a/synapse/metrics/__init__.py b/synapse/metrics/__init__.py
index 2265e6e8d6..e0cfb7d08f 100644
--- a/synapse/metrics/__init__.py
+++ b/synapse/metrics/__init__.py
@@ -146,10 +146,15 @@ def runUntilCurrentTimer(func):
num_pending += 1
num_pending += len(reactor.threadCallQueue)
-
start = time.time() * 1000
ret = func(*args, **kwargs)
end = time.time() * 1000
+
+ # record the amount of wallclock time spent running pending calls.
+ # This is a proxy for the actual amount of time between reactor polls,
+ # since about 25% of time is actually spent running things triggered by
+ # I/O events, but that is harder to capture without rewriting half the
+ # reactor.
tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending)
diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index 1e783e5ff4..ff5aa8c0e1 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -193,7 +193,9 @@ class DistributionMetric(object):
class CacheMetric(object):
- __slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
+ __slots__ = (
+ "name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
+ )
def __init__(self, name, size_callback, cache_name):
self.name = name
@@ -201,6 +203,7 @@ class CacheMetric(object):
self.hits = 0
self.misses = 0
+ self.evicted_size = 0
self.size_callback = size_callback
@@ -210,6 +213,9 @@ class CacheMetric(object):
def inc_misses(self):
self.misses += 1
+ def inc_evictions(self, size=1):
+ self.evicted_size += size
+
def render(self):
size = self.size_callback()
hits = self.hits
@@ -219,6 +225,9 @@ class CacheMetric(object):
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
+ """%s:evicted_size{name="%s"} %d""" % (
+ self.name, self.cache_name, self.evicted_size
+ ),
]
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 29d7296b43..8acb5df0f3 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -19,7 +19,7 @@ from synapse.storage import DataStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.roommember import RoomMemberStore
-from synapse.storage.state import StateGroupReadStore
+from synapse.storage.state import StateGroupWorkerStore
from synapse.storage.stream import StreamStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import BaseSlavedStore
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
# the method descriptor on the DataStore and chuck them into our class.
-class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
+class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)
diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index f954d2ea65..2ad486c67d 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -180,6 +181,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.state = hs.get_state_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
@defer.inlineCallbacks
def on_POST(self, request, room_id):
@@ -212,8 +214,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
)
new_room_id = info["room_id"]
- msg_handler = self.handlers.message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
room_creator_requester,
{
"type": "m.room.message",
@@ -298,6 +299,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
defer.returnValue((200, {"num_quarantined": num_quarantined}))
+class ListMediaInRoom(ClientV1RestServlet):
+ """Lists all of the media in a given room.
+ """
+ PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+ def __init__(self, hs):
+ super(ListMediaInRoom, self).__init__(hs)
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+ defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
@@ -496,3 +518,4 @@ def register_servlets(hs, http_server):
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
+ ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 32ed1d3ab2..5c5fa8f7ab 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
self.handlers = hs.get_handlers()
def on_GET(self, request):
+
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [
@@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD
]
},
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
- ]}
- )
+ ])
else:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if require_email or not require_msisdn:
+ flows.extend([
{
"type": LoginType.EMAIL_IDENTITY,
"stages": [
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
]
- },
+ }
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.PASSWORD
}
- ]}
- )
+ ])
+ return (200, {"flows": flows})
@defer.inlineCallbacks
def on_POST(self, request):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 682a0af9fc..fbb2fc36e4 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -82,6 +83,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_hander = hs.get_event_creation_handler()
def register(self, http_server):
# /room/$roomid/state/$eventtype
@@ -162,15 +164,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content,
)
else:
- msg_handler = self.handlers.message_handler
- event, context = yield msg_handler.create_event(
+ event, context = yield self.event_creation_hander.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
)
- yield msg_handler.send_nonmember_event(requester, event, context)
+ yield self.event_creation_hander.send_nonmember_event(
+ requester, event, context,
+ )
ret = {}
if event:
@@ -184,6 +187,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_hander = hs.get_event_creation_handler()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -195,15 +199,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
- msg_handler = self.handlers.message_handler
- event = yield msg_handler.create_and_send_nonmember_event(
+ event_dict = {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+
+ if 'ts' in request.args and requester.app_service:
+ event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+
+ event = yield self.event_creation_hander.create_and_send_nonmember_event(
requester,
- {
- "type": event_type,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- },
+ event_dict,
txn_id=txn_id,
)
@@ -665,6 +673,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -675,8 +684,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- msg_handler = self.handlers.message_handler
- event = yield msg_handler.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 385a3ad2ec..30523995af 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
)
from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
'id_server', 'client_secret', 'email', 'send_attempt'
])
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'email', body['email']
)
@@ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index e9d88a8895..c6f4680a76 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
)
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
@@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
'id_server', 'client_secret', 'email', 'send_attempt'
])
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet):
if 'x_show_msisdn' in body and body['x_show_msisdn']:
show_msisdn = True
+ # FIXME: need a better error than "no auth flow found" for scenarios
+ # where we required 3PID for registration but the user didn't give one
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- flows = [
- [LoginType.RECAPTCHA],
- [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.RECAPTCHA]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+
if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email:
+ flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+ # always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN, LoginType.RECAPTCHA],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
])
else:
- flows = [
- [LoginType.DUMMY],
- [LoginType.EMAIL_IDENTITY],
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.DUMMY]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY]])
+
if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email or require_msisdn:
+ flows.extend([[LoginType.MSISDN]])
+ # always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN],
- [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
+ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
])
auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
+ # Check that we're not trying to register a denied 3pid.
+ #
+ # the user-facing checks will probably already have happened in
+ # /register/email/requestToken when we requested a 3pid, but that's not
+ # guaranteed.
+
+ if auth_result:
+ for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+ if login_type in auth_result:
+ medium = auth_result[login_type]['medium']
+ address = auth_result[login_type]['address']
+
+ if not check_3pid_allowed(self.hs, medium, address):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed",
+ Codes.THREEPID_DENIED,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index cc2842aa72..17e6079cba 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -93,6 +93,7 @@ class RemoteKey(Resource):
self.store = hs.get_datastore()
self.version_string = hs.version_string
self.clock = hs.get_clock()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
def render_GET(self, request):
self.async_render_GET(request)
@@ -137,6 +138,13 @@ class RemoteKey(Resource):
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ logger.debug("Federation denied with %s", server_name)
+ continue
+
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index b2c76440b7..bb79599379 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -27,15 +27,14 @@ from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths
from .thumbnailer import Thumbnailer
-from .storage_provider import (
- StorageProviderWrapper, FileStorageProviderBackend,
-)
+from .storage_provider import StorageProviderWrapper
from .media_storage import MediaStorage
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
-from synapse.api.errors import SynapseError, HttpResponseException, \
- NotFoundError
+from synapse.api.errors import (
+ SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
+)
from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
@@ -77,21 +76,19 @@ class MediaRepository(object):
self.recently_accessed_remotes = set()
self.recently_accessed_locals = set()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
+
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
- # TODO: Move this into config and allow other storage providers to be
- # defined.
- if hs.config.backup_media_store_path:
- backend = FileStorageProviderBackend(
- self.primary_base_path, hs.config.backup_media_store_path,
- )
+ for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
+ backend = clz(hs, provider_config)
provider = StorageProviderWrapper(
backend,
- store=True,
- store_synchronous=hs.config.synchronous_backup_media_store,
- store_remote=True,
+ store_local=wrapper_config.store_local,
+ store_remote=wrapper_config.store_remote,
+ store_synchronous=wrapper_config.store_synchronous,
)
storage_providers.append(provider)
@@ -222,6 +219,12 @@ class MediaRepository(object):
Deferred: Resolves once a response has successfully been written
to request
"""
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(server_name)
+
self.mark_recently_accessed(server_name, media_id)
# We linearize here to ensure that we don't try and download remote
@@ -256,6 +259,12 @@ class MediaRepository(object):
Returns:
Deferred[dict]: The media_info of the file
"""
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(server_name)
+
# We linearize here to ensure that we don't try and download remote
# media multiple times concurrently
key = (server_name, media_id)
@@ -463,8 +472,10 @@ class MediaRepository(object):
@defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
- t_method, t_type):
- input_path = self.filepaths.local_media_filepath(media_id)
+ t_method, t_type, url_cache):
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ None, media_id, url_cache=url_cache,
+ ))
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -477,6 +488,7 @@ class MediaRepository(object):
file_info = FileInfo(
server_name=None,
file_id=media_id,
+ url_cache=url_cache,
thumbnail=True,
thumbnail_width=t_width,
thumbnail_height=t_height,
@@ -503,7 +515,9 @@ class MediaRepository(object):
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
- input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ server_name, file_id, url_cache=False,
+ ))
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -561,12 +575,9 @@ class MediaRepository(object):
if not requirements:
return
- if server_name:
- input_path = self.filepaths.remote_media_filepath(server_name, file_id)
- elif url_cache:
- input_path = self.filepaths.url_cache_filepath(media_id)
- else:
- input_path = self.filepaths.local_media_filepath(media_id)
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ server_name, file_id, url_cache=url_cache,
+ ))
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 001e84578e..e8e8b3986d 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -18,6 +18,7 @@ from twisted.protocols.basic import FileSender
from ._base import Responder
+from synapse.util.file_consumer import BackgroundFileConsumer
from synapse.util.logcontext import make_deferred_yieldable
import contextlib
@@ -26,6 +27,7 @@ import logging
import shutil
import sys
+
logger = logging.getLogger(__name__)
@@ -151,6 +153,37 @@ class MediaStorage(object):
defer.returnValue(None)
+ @defer.inlineCallbacks
+ def ensure_media_is_in_local_cache(self, file_info):
+ """Ensures that the given file is in the local cache. Attempts to
+ download it from storage providers if it isn't.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ Deferred[str]: Full path to local file
+ """
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ defer.returnValue(local_path)
+
+ dirname = os.path.dirname(local_path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ for provider in self.storage_providers:
+ res = yield provider.fetch(path, file_info)
+ if res:
+ with res:
+ consumer = BackgroundFileConsumer(open(local_path, "w"))
+ yield res.write_to_consumer(consumer)
+ yield consumer.wait()
+ defer.returnValue(local_path)
+
+ raise Exception("file could not be found")
+
def _file_info_to_path(self, file_info):
"""Converts file_info into a relative path.
@@ -164,6 +197,14 @@ class MediaStorage(object):
str
"""
if file_info.url_cache:
+ if file_info.thumbnail:
+ return self.filepaths.url_cache_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method,
+ )
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
if file_info.server_name:
@@ -220,9 +261,8 @@ class FileResponder(Responder):
def __init__(self, open_file):
self.open_file = open_file
- @defer.inlineCallbacks
def write_to_consumer(self, consumer):
- yield FileSender().beginFileTransfer(self.open_file, consumer)
+ return FileSender().beginFileTransfer(self.open_file, consumer)
def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 981f01e417..31fe7aa75c 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,6 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import cgi
+import datetime
+import errno
+import fnmatch
+import itertools
+import logging
+import os
+import re
+import shutil
+import sys
+import traceback
+import ujson as json
+import urlparse
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
@@ -33,18 +46,6 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
-import os
-import re
-import fnmatch
-import cgi
-import ujson as json
-import urlparse
-import itertools
-import datetime
-import errno
-import shutil
-
-import logging
logger = logging.getLogger(__name__)
@@ -286,17 +287,28 @@ class PreviewUrlResource(Resource):
url_cache=True,
)
- try:
- with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size,
)
+ except Exception as e:
# FIXME: pass through 404s and other error messages nicely
+ logger.warn("Error downloading %s: %r", url, e)
+ raise SynapseError(
+ 500, "Failed to download content: %s" % (
+ traceback.format_exception_only(sys.exc_type, e),
+ ),
+ Codes.UNKNOWN,
+ )
+ yield finish()
- yield finish()
-
- media_type = headers["Content-Type"][0]
+ try:
+ if "Content-Type" in headers:
+ media_type = headers["Content-Type"][0]
+ else:
+ media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
@@ -336,10 +348,11 @@ class PreviewUrlResource(Resource):
)
except Exception as e:
- raise SynapseError(
- 500, ("Failed to download content: %s" % e),
- Codes.UNKNOWN
- )
+ logger.error("Error handling downloaded %s: %r", url, e)
+ # TODO: we really ought to delete the downloaded file in this
+ # case, since we won't have recorded it in the db, and will
+ # therefore not expire it.
+ raise
defer.returnValue({
"media_type": media_type,
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 2ad602e101..c188192f2b 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -17,6 +17,7 @@ from twisted.internet import defer, threads
from .media_storage import FileResponder
+from synapse.config._base import Config
from synapse.util.logcontext import preserve_fn
import logging
@@ -64,19 +65,19 @@ class StorageProviderWrapper(StorageProvider):
Args:
backend (StorageProvider)
- store (bool): Whether to store new files or not.
+ store_local (bool): Whether to store new local files or not.
store_synchronous (bool): Whether to wait for file to be successfully
uploaded, or todo the upload in the backgroud.
store_remote (bool): Whether remote media should be uploaded
"""
- def __init__(self, backend, store, store_synchronous, store_remote):
+ def __init__(self, backend, store_local, store_synchronous, store_remote):
self.backend = backend
- self.store = store
+ self.store_local = store_local
self.store_synchronous = store_synchronous
self.store_remote = store_remote
def store_file(self, path, file_info):
- if not self.store:
+ if not file_info.server_name and not self.store_local:
return defer.succeed(None)
if file_info.server_name and not self.store_remote:
@@ -97,13 +98,13 @@ class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.
Args:
- cache_directory (str): Base path of the local media repository
- base_directory (str): Base path to store new files
+ hs (HomeServer)
+ config: The config returned by `parse_config`.
"""
- def __init__(self, cache_directory, base_directory):
- self.cache_directory = cache_directory
- self.base_directory = base_directory
+ def __init__(self, hs, config):
+ self.cache_directory = hs.config.media_store_path
+ self.base_directory = config
def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
@@ -125,3 +126,15 @@ class FileStorageProviderBackend(StorageProvider):
backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))
+
+ @staticmethod
+ def parse_config(config):
+ """Called on startup to parse config supplied. This should parse
+ the config and raise if there is a problem.
+
+ The returned value is passed into the constructor.
+
+ In this case we only care about a single param, the directory, so let's
+ just pull that out.
+ """
+ return Config.ensure_directory(config["directory"])
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index c09f2dec41..58ada49711 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -67,7 +67,7 @@ class ThumbnailResource(Resource):
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
- self.media_repo.mark_recently_accessed(server_name, media_id)
+ self.media_repo.mark_recently_accessed(None, media_id)
else:
if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail(
@@ -79,7 +79,7 @@ class ThumbnailResource(Resource):
request, server_name, media_id,
width, height, method, m_type
)
- self.media_repo.mark_recently_accessed(None, media_id)
+ self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
@@ -164,7 +164,8 @@ class ThumbnailResource(Resource):
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
- media_id, desired_width, desired_height, desired_method, desired_type
+ media_id, desired_width, desired_height, desired_method, desired_type,
+ url_cache=media_info["url_cache"],
)
if file_path:
diff --git a/synapse/server.py b/synapse/server.py
index 99693071b6..fbd602d40e 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -55,6 +55,7 @@ from synapse.handlers.read_marker import ReadMarkerHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.handlers.groups_local import GroupsLocalHandler
from synapse.handlers.profile import ProfileHandler
+from synapse.handlers.message import EventCreationHandler
from synapse.groups.groups_server import GroupsServerHandler
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
@@ -66,7 +67,7 @@ from synapse.rest.media.v1.media_repository import (
MediaRepository,
MediaRepositoryResource,
)
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStore
from synapse.streams.events import EventSources
from synapse.util import Clock
@@ -102,6 +103,7 @@ class HomeServer(object):
'v1auth',
'auth',
'state_handler',
+ 'state_resolution_handler',
'presence_handler',
'sync_handler',
'typing_handler',
@@ -117,6 +119,7 @@ class HomeServer(object):
'application_service_handler',
'device_message_handler',
'profile_handler',
+ 'event_creation_handler',
'deactivate_account_handler',
'set_password_handler',
'notifier',
@@ -224,6 +227,9 @@ class HomeServer(object):
def build_state_handler(self):
return StateHandler(self)
+ def build_state_resolution_handler(self):
+ return StateResolutionHandler(self)
+
def build_presence_handler(self):
return PresenceHandler(self)
@@ -272,6 +278,9 @@ class HomeServer(object):
def build_profile_handler(self):
return ProfileHandler(self)
+ def build_event_creation_handler(self):
+ return EventCreationHandler(self)
+
def build_deactivate_account_handler(self):
return DeactivateAccountHandler(self)
@@ -307,6 +316,23 @@ class HomeServer(object):
**self.db_config.get("args", {})
)
+ def get_db_conn(self, run_new_connection=True):
+ """Makes a new connection to the database, skipping the db pool
+
+ Returns:
+ Connection: a connection object implementing the PEP-249 spec
+ """
+ # Any param beginning with cp_ is a parameter for adbapi, and should
+ # not be passed to the database engine.
+ db_params = {
+ k: v for k, v in self.db_config.get("args", {}).items()
+ if not k.startswith("cp_")
+ }
+ db_conn = self.database_engine.module.connect(**db_params)
+ if run_new_connection:
+ self.database_engine.on_new_connection(db_conn)
+ return db_conn
+
def build_media_repository_resource(self):
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
diff --git a/synapse/server.pyi b/synapse/server.pyi
index 41416ef252..c3a9a3847b 100644
--- a/synapse/server.pyi
+++ b/synapse/server.pyi
@@ -34,6 +34,9 @@ class HomeServer(object):
def get_state_handler(self) -> synapse.state.StateHandler:
pass
+ def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
+ pass
+
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
diff --git a/synapse/state.py b/synapse/state.py
index 1f9abf9d3d..cc93bbcb6b 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -58,7 +58,11 @@ class _StateCacheEntry(object):
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+ # dict[(str, str), str] map from (type, state_key) to event_id
self.state = frozendict(state)
+
+ # the ID of a state group if one and only one is involved.
+ # otherwise, None otherwise?
self.state_group = state_group
self.prev_group = prev_group
@@ -81,31 +85,19 @@ class _StateCacheEntry(object):
class StateHandler(object):
- """ Responsible for doing state conflict resolution.
+ """Fetches bits of state from the stores, and does state resolution
+ where necessary
"""
def __init__(self, hs):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.hs = hs
-
- # dict of set of event_ids -> _StateCacheEntry.
- self._state_cache = None
- self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+ self._state_resolution_handler = hs.get_state_resolution_handler()
def start_caching(self):
- logger.debug("start_caching")
-
- self._state_cache = ExpiringCache(
- cache_name="state_cache",
- clock=self.clock,
- max_len=SIZE_OF_CACHE,
- expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
- iterable=True,
- reset_expiry_on_get=True,
- )
-
- self._state_cache.start()
+ # TODO: remove this shim
+ self._state_resolution_handler.start_caching()
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key="",
@@ -127,7 +119,7 @@ class StateHandler(object):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
- ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
@@ -146,19 +138,27 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
- def get_current_state_ids(self, room_id, event_type=None, state_key="",
- latest_event_ids=None):
+ def get_current_state_ids(self, room_id, latest_event_ids=None):
+ """Get the current state, or the state at a set of events, for a room
+
+ Args:
+ room_id (str):
+
+ latest_event_ids (iterable[str]|None): if given, the forward
+ extremities to resolve. If None, we look them up from the
+ database (via a cache)
+
+ Returns:
+ Deferred[dict[(str, str), str)]]: the state dict, mapping from
+ (event_type, state_key) -> event_id
+ """
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
- ret = yield self.resolve_state_groups(room_id, latest_event_ids)
+ ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
- if event_type:
- defer.returnValue(state.get((event_type, state_key)))
- return
-
defer.returnValue(state)
@defer.inlineCallbacks
@@ -166,7 +166,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room")
- entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+ entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
defer.returnValue(joined_users)
@@ -175,7 +175,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
- entry = yield self.resolve_state_groups(room_id, latest_event_ids)
+ entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
defer.returnValue(joined_hosts)
@@ -183,8 +183,15 @@ class StateHandler(object):
def compute_event_context(self, event, old_state=None):
"""Build an EventContext structure for the event.
+ This works out what the current state should be for the event, and
+ generates a new state group if necessary.
+
Args:
event (synapse.events.EventBase):
+ old_state (dict|None): The state at the event if it can't be
+ calculated from existing events. This is normally only specified
+ when receiving an event from federation where we don't have the
+ prev events for, e.g. when backfilling.
Returns:
synapse.events.snapshot.EventContext:
"""
@@ -208,15 +215,22 @@ class StateHandler(object):
context.current_state_ids = {}
context.prev_state_ids = {}
context.prev_state_events = []
- context.state_group = self.store.get_next_state_group()
+
+ # We don't store state for outliers, so we don't generate a state
+ # froup for it.
+ context.state_group = None
+
defer.returnValue(context)
if old_state:
+ # We already have the state, so we don't need to calculate it.
+ # Let's just correctly fill out the context and create a
+ # new state group for it.
+
context = EventContext()
context.prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
- context.state_group = self.store.get_next_state_group()
if event.is_state():
key = (event.type, event.state_key)
@@ -229,11 +243,19 @@ class StateHandler(object):
else:
context.current_state_ids = context.prev_state_ids
+ context.state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=None,
+ delta_ids=None,
+ current_state_ids=context.current_state_ids,
+ )
+
context.prev_state_events = []
defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context")
- entry = yield self.resolve_state_groups(
+ entry = yield self.resolve_state_groups_for_events(
event.room_id, [e for e, _ in event.prev_events],
)
@@ -242,7 +264,8 @@ class StateHandler(object):
context = EventContext()
context.prev_state_ids = curr_state
if event.is_state():
- context.state_group = self.store.get_next_state_group()
+ # If this is a state event then we need to create a new state
+ # group for the state after this event.
key = (event.type, event.state_key)
if key in context.prev_state_ids:
@@ -253,38 +276,57 @@ class StateHandler(object):
context.current_state_ids[key] = event.event_id
if entry.state_group:
+ # If the state at the event has a state group assigned then
+ # we can use that as the prev group
context.prev_group = entry.state_group
context.delta_ids = {
key: event.event_id
}
elif entry.prev_group:
+ # If the state at the event only has a prev group, then we can
+ # use that as a prev group too.
context.prev_group = entry.prev_group
context.delta_ids = dict(entry.delta_ids)
context.delta_ids[key] = event.event_id
+
+ context.state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=context.prev_group,
+ delta_ids=context.delta_ids,
+ current_state_ids=context.current_state_ids,
+ )
else:
+ context.current_state_ids = context.prev_state_ids
+ context.prev_group = entry.prev_group
+ context.delta_ids = entry.delta_ids
+
if entry.state_group is None:
- entry.state_group = self.store.get_next_state_group()
+ entry.state_group = yield self.store.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=entry.prev_group,
+ delta_ids=entry.delta_ids,
+ current_state_ids=context.current_state_ids,
+ )
entry.state_id = entry.state_group
context.state_group = entry.state_group
- context.current_state_ids = context.prev_state_ids
- context.prev_group = entry.prev_group
- context.delta_ids = entry.delta_ids
context.prev_state_events = []
defer.returnValue(context)
@defer.inlineCallbacks
- @log_function
- def resolve_state_groups(self, room_id, event_ids):
+ def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
+ Args:
+ room_id (str):
+ event_ids (list[str]):
+
Returns:
- a Deferred tuple of (`state_group`, `state`, `prev_state`).
- `state_group` is the name of a state group if one and only one is
- involved. `state` is a map from (type, state_key) to event, and
- `prev_state` is a list of event ids.
+ Deferred[_StateCacheEntry]: resolved state
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -295,13 +337,7 @@ class StateHandler(object):
room_id, event_ids
)
- logger.debug(
- "resolve_state_groups state_groups %s",
- state_groups_ids.keys()
- )
-
- group_names = frozenset(state_groups_ids.keys())
- if len(group_names) == 1:
+ if len(state_groups_ids) == 1:
name, state_list = state_groups_ids.items().pop()
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
@@ -313,6 +349,92 @@ class StateHandler(object):
delta_ids=delta_ids,
))
+ result = yield self._state_resolution_handler.resolve_state_groups(
+ room_id, state_groups_ids, self._state_map_factory,
+ )
+ defer.returnValue(result)
+
+ def _state_map_factory(self, ev_ids):
+ return self.store.get_events(
+ ev_ids, get_prev_content=False, check_redacted=False,
+ )
+
+ def resolve_events(self, state_sets, event):
+ logger.info(
+ "Resolving state for %s with %d groups", event.room_id, len(state_sets)
+ )
+ state_set_ids = [{
+ (ev.type, ev.state_key): ev.event_id
+ for ev in st
+ } for st in state_sets]
+
+ state_map = {
+ ev.event_id: ev
+ for st in state_sets
+ for ev in st
+ }
+
+ with Measure(self.clock, "state._resolve_events"):
+ new_state = resolve_events_with_state_map(state_set_ids, state_map)
+
+ new_state = {
+ key: state_map[ev_id] for key, ev_id in new_state.items()
+ }
+
+ return new_state
+
+
+class StateResolutionHandler(object):
+ """Responsible for doing state conflict resolution.
+
+ Note that the storage layer depends on this handler, so all functions must
+ be storage-independent.
+ """
+ def __init__(self, hs):
+ self.clock = hs.get_clock()
+
+ # dict of set of event_ids -> _StateCacheEntry.
+ self._state_cache = None
+ self.resolve_linearizer = Linearizer(name="state_resolve_lock")
+
+ def start_caching(self):
+ logger.debug("start_caching")
+
+ self._state_cache = ExpiringCache(
+ cache_name="state_cache",
+ clock=self.clock,
+ max_len=SIZE_OF_CACHE,
+ expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
+ iterable=True,
+ reset_expiry_on_get=True,
+ )
+
+ self._state_cache.start()
+
+ @defer.inlineCallbacks
+ @log_function
+ def resolve_state_groups(self, room_id, state_groups_ids, state_map_factory):
+ """Resolves conflicts between a set of state groups
+
+ Always generates a new state group (unless we hit the cache), so should
+ not be called for a single state group
+
+ Args:
+ room_id (str): room we are resolving for (used for logging)
+ state_groups_ids (dict[int, dict[(str, str), str]]):
+ map from state group id to the state in that state group
+ (where 'state' is a map from state key to event id)
+
+ Returns:
+ Deferred[_StateCacheEntry]: resolved state
+ """
+ logger.debug(
+ "resolve_state_groups state_groups %s",
+ state_groups_ids.keys()
+ )
+
+ group_names = frozenset(state_groups_ids.keys())
+
with (yield self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
@@ -343,15 +465,17 @@ class StateHandler(object):
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_factory(
state_groups_ids.values(),
- state_map_factory=lambda ev_ids: self.store.get_events(
- ev_ids, get_prev_content=False, check_redacted=False,
- ),
+ state_map_factory=state_map_factory,
)
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
}
+ # if the new state matches any of the input state groups, we can
+ # use that state group again. Otherwise we will generate a state_id
+ # which will be used as a cache key for future resolutions, but
+ # not get persisted.
state_group = None
new_state_event_ids = frozenset(new_state.values())
for sg, events in state_groups_ids.items():
@@ -388,30 +512,6 @@ class StateHandler(object):
defer.returnValue(cache)
- def resolve_events(self, state_sets, event):
- logger.info(
- "Resolving state for %s with %d groups", event.room_id, len(state_sets)
- )
- state_set_ids = [{
- (ev.type, ev.state_key): ev.event_id
- for ev in st
- } for st in state_sets]
-
- state_map = {
- ev.event_id: ev
- for st in state_sets
- for ev in st
- }
-
- with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events_with_state_map(state_set_ids, state_map)
-
- new_state = {
- key: state_map[ev_id] for key, ev_id in new_state.items()
- }
-
- return new_state
-
def _ordered_events(events):
def key_func(e):
@@ -429,8 +529,8 @@ def resolve_events_with_state_map(state_sets, state_map):
state_sets.
Returns
- dict[(str, str), synapse.events.FrozenEvent]:
- a map from (type, state_key) to event.
+ dict[(str, str), str]:
+ a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -452,6 +552,21 @@ def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
+
+ Args:
+ state_sets(list[dict[(str, str), str]]):
+ List of dicts of (type, state_key) -> event_id, which are the
+ different state groups to resolve.
+
+ Returns:
+ (dict[(str, str), str], dict[(str, str), set[str]]):
+ A tuple of (unconflicted_state, conflicted_state), where:
+
+ unconflicted_state is a dict mapping (type, state_key)->event_id
+ for unconflicted state keys.
+
+ conflicted_state is a dict mapping (type, state_key) to a set of
+ event ids for conflicted state keys.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
@@ -492,8 +607,8 @@ def resolve_events_with_factory(state_sets, state_map_factory):
a Deferred of dict of event_id to event.
Returns
- Deferred[dict[(str, str), synapse.events.FrozenEvent]]:
- a map from (type, state_key) to event.
+ Deferred[dict[(str, str), str]]:
+ a map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
defer.returnValue(state_sets[0])
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d01d46338a..f8fbd02ceb 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -124,7 +124,6 @@ class DataStore(RoomMemberStore, RoomStore,
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
- self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a6ae79dfad..8a0386c1a4 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -62,3 +62,9 @@ class PostgresEngine(object):
def lock_table(self, txn, table):
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
+
+ def get_next_state_group_id(self, txn):
+ """Returns an int that can be used as a new state_group ID
+ """
+ txn.execute("SELECT nextval('state_group_id_seq')")
+ return txn.fetchone()[0]
diff --git a/synapse/storage/engines/sqlite3.py b/synapse/storage/engines/sqlite3.py
index 755c9a1f07..60f0fa7fb3 100644
--- a/synapse/storage/engines/sqlite3.py
+++ b/synapse/storage/engines/sqlite3.py
@@ -16,6 +16,7 @@
from synapse.storage.prepare_database import prepare_database
import struct
+import threading
class Sqlite3Engine(object):
@@ -24,6 +25,11 @@ class Sqlite3Engine(object):
def __init__(self, database_module, database_config):
self.module = database_module
+ # The current max state_group, or None if we haven't looked
+ # in the DB yet.
+ self._current_state_group_id = None
+ self._current_state_group_id_lock = threading.Lock()
+
def check_database(self, txn):
pass
@@ -43,6 +49,19 @@ class Sqlite3Engine(object):
def lock_table(self, txn, table):
return
+ def get_next_state_group_id(self, txn):
+ """Returns an int that can be used as a new state_group ID
+ """
+ # We do application locking here since if we're using sqlite then
+ # we are a single process synapse.
+ with self._current_state_group_id_lock:
+ if self._current_state_group_id is None:
+ txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
+ self._current_state_group_id = txn.fetchone()[0]
+
+ self._current_state_group_id += 1
+ return self._current_state_group_id
+
# Following functions taken from: https://github.com/coleifer/peewee
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 238a2006b8..86a7c5920d 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -110,7 +110,7 @@ class _EventPeristenceQueue(object):
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
- deferred = ObservableDeferred(defer.Deferred())
+ deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
@@ -152,8 +152,8 @@ class _EventPeristenceQueue(object):
try:
ret = yield per_item_callback(item)
item.deferred.callback(ret)
- except Exception as e:
- item.deferred.errback(e)
+ except Exception:
+ item.deferred.errback()
finally:
queue = self._event_persist_queues.pop(room_id, None)
if queue:
@@ -342,8 +342,20 @@ class EventsStore(SQLBaseStore):
# NB: Assumes that we are only persisting events for one room
# at a time.
+
+ # map room_id->list[event_ids] giving the new forward
+ # extremities in each room
new_forward_extremeties = {}
+
+ # map room_id->(type,state_key)->event_id tracking the full
+ # state in each room after adding these events
current_state_for_room = {}
+
+ # map room_id->(to_delete, to_insert) where each entry is
+ # a map (type,key)->event_id giving the state delta in each
+ # room
+ state_delta_for_room = {}
+
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
@@ -386,11 +398,19 @@ class EventsStore(SQLBaseStore):
if all_single_prev_not_state:
continue
- state = yield self._calculate_state_delta(
- room_id, ev_ctx_rm, new_latest_event_ids
+ logger.info(
+ "Calculating state delta for room %s", room_id,
+ )
+ current_state = yield self._get_new_state_after_events(
+ ev_ctx_rm, new_latest_event_ids,
)
- if state:
- current_state_for_room[room_id] = state
+ if current_state is not None:
+ current_state_for_room[room_id] = current_state
+ delta = yield self._calculate_state_delta(
+ room_id, current_state,
+ )
+ if delta is not None:
+ state_delta_for_room[room_id] = delta
yield self.runInteraction(
"persist_events",
@@ -398,7 +418,7 @@ class EventsStore(SQLBaseStore):
events_and_contexts=chunk,
backfilled=backfilled,
delete_existing=delete_existing,
- current_state_for_room=current_state_for_room,
+ state_delta_for_room=state_delta_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc_by(len(chunk))
@@ -415,7 +435,7 @@ class EventsStore(SQLBaseStore):
event_counter.inc(event.type, origin_type, origin_entity)
- for room_id, (_, _, new_state) in current_state_for_room.iteritems():
+ for room_id, new_state in current_state_for_room.iteritems():
self.get_current_state_ids.prefill(
(room_id, ), new_state
)
@@ -467,20 +487,22 @@ class EventsStore(SQLBaseStore):
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
- def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
- """Calculate the new state deltas for a room.
+ def _get_new_state_after_events(self, events_context, new_latest_event_ids):
+ """Calculate the current state dict after adding some new events to
+ a room
- Assumes that we are only persisting events for one room at a time.
+ Args:
+ events_context (list[(EventBase, EventContext)]):
+ events and contexts which are being added to the room
+
+ new_latest_event_ids (iterable[str]):
+ the new forward extremities for the room.
Returns:
- 3-tuple (to_delete, to_insert, new_state) where both are state dicts,
- i.e. (type, state_key) -> event_id. `to_delete` are the entries to
- first be deleted from current_state_events, `to_insert` are entries
- to insert. `new_state` is the full set of state.
- May return None if there are no changes to be applied.
+ Deferred[dict[(str,str), str]|None]:
+ None if there are no changes to the room state, or
+ a dict of (type, state_key) -> event_id].
"""
- # Now we need to work out the different state sets for
- # each state extremities
state_sets = []
state_groups = set()
missing_event_ids = []
@@ -523,12 +545,12 @@ class EventsStore(SQLBaseStore):
state_sets.extend(group_to_state.itervalues())
if not new_latest_event_ids:
- current_state = {}
+ defer.returnValue({})
elif was_updated:
if len(state_sets) == 1:
# If there is only one state set, then we know what the current
# state is.
- current_state = state_sets[0]
+ defer.returnValue(state_sets[0])
else:
# We work out the current state by passing the state sets to the
# state resolution algorithm. It may ask for some events, including
@@ -537,8 +559,7 @@ class EventsStore(SQLBaseStore):
# up in the db.
logger.info(
- "Resolving state for %s with %i state sets",
- room_id, len(state_sets),
+ "Resolving state with %i state sets", len(state_sets),
)
events_map = {ev.event_id: ev for ev, _ in events_context}
@@ -567,9 +588,22 @@ class EventsStore(SQLBaseStore):
state_sets,
state_map_factory=get_events,
)
+ defer.returnValue(current_state)
else:
return
+ @defer.inlineCallbacks
+ def _calculate_state_delta(self, room_id, current_state):
+ """Calculate the new state deltas for a room.
+
+ Assumes that we are only persisting events for one room at a time.
+
+ Returns:
+ 2-tuple (to_delete, to_insert) where both are state dicts,
+ i.e. (type, state_key) -> event_id. `to_delete` are the entries to
+ first be deleted from current_state_events, `to_insert` are entries
+ to insert.
+ """
existing_state = yield self.get_current_state_ids(room_id)
existing_events = set(existing_state.itervalues())
@@ -589,7 +623,7 @@ class EventsStore(SQLBaseStore):
if ev_id in events_to_insert
}
- defer.returnValue((to_delete, to_insert, current_state))
+ defer.returnValue((to_delete, to_insert))
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
@@ -649,7 +683,7 @@ class EventsStore(SQLBaseStore):
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
- delete_existing=False, current_state_for_room={},
+ delete_existing=False, state_delta_for_room={},
new_forward_extremeties={}):
"""Insert some number of room events into the necessary database tables.
@@ -665,7 +699,7 @@ class EventsStore(SQLBaseStore):
delete_existing (bool): True to purge existing table rows for the
events from the database. This is useful when retrying due to
IntegrityError.
- current_state_for_room (dict[str, (list[str], list[str])]):
+ state_delta_for_room (dict[str, (list[str], list[str])]):
The current-state delta for each room. For each room, a tuple
(to_delete, to_insert), being a list of event ids to be removed
from the current state, and a list of event ids to be added to
@@ -677,7 +711,7 @@ class EventsStore(SQLBaseStore):
"""
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
- self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
+ self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
self._update_forward_extremities_txn(
txn,
@@ -721,9 +755,8 @@ class EventsStore(SQLBaseStore):
events_and_contexts=events_and_contexts,
)
- # Insert into the state_groups, state_groups_state, and
- # event_to_state_groups tables.
- self._store_mult_state_groups_txn(txn, events_and_contexts)
+ # Insert into event_to_state_groups.
+ self._store_event_state_mappings_txn(txn, events_and_contexts)
# _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list.
@@ -743,7 +776,7 @@ class EventsStore(SQLBaseStore):
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
for room_id, current_state_tuple in state_delta_by_room.iteritems():
- to_delete, to_insert, _ = current_state_tuple
+ to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
[(ev_id,) for ev_id in to_delete.itervalues()],
@@ -958,10 +991,9 @@ class EventsStore(SQLBaseStore):
# an outlier in the database. We now have some state at that
# so we need to update the state_groups table with that state.
- # insert into the state_group, state_groups_state and
- # event_to_state_groups tables.
+ # insert into event_to_state_groups.
try:
- self._store_mult_state_groups_txn(txn, ((event, context),))
+ self._store_event_state_mappings_txn(txn, ((event, context),))
except Exception:
logger.exception("")
raise
diff --git a/synapse/storage/room.py b/synapse/storage/room.py
index 23688430b7..fff6652e05 100644
--- a/synapse/storage/room.py
+++ b/synapse/storage/room.py
@@ -16,11 +16,9 @@
from twisted.internet import defer
from synapse.api.errors import StoreError
+from synapse.storage.search import SearchStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-from ._base import SQLBaseStore
-from .engines import PostgresEngine, Sqlite3Engine
-
import collections
import logging
import ujson as json
@@ -40,7 +38,7 @@ RatelimitOverride = collections.namedtuple(
)
-class RoomStore(SQLBaseStore):
+class RoomStore(SearchStore):
@defer.inlineCallbacks
def store_room(self, room_id, room_creator_user_id, is_public):
@@ -263,8 +261,8 @@ class RoomStore(SQLBaseStore):
},
)
- self._store_event_search_txn(
- txn, event, "content.topic", event.content["topic"]
+ self.store_event_search_txn(
+ txn, event, "content.topic", event.content["topic"],
)
def _store_room_name_txn(self, txn, event):
@@ -279,14 +277,14 @@ class RoomStore(SQLBaseStore):
}
)
- self._store_event_search_txn(
- txn, event, "content.name", event.content["name"]
+ self.store_event_search_txn(
+ txn, event, "content.name", event.content["name"],
)
def _store_room_message_txn(self, txn, event):
if hasattr(event, "content") and "body" in event.content:
- self._store_event_search_txn(
- txn, event, "content.body", event.content["body"]
+ self.store_event_search_txn(
+ txn, event, "content.body", event.content["body"],
)
def _store_history_visibility_txn(self, txn, event):
@@ -308,31 +306,6 @@ class RoomStore(SQLBaseStore):
event.content[key]
))
- def _store_event_search_txn(self, txn, event, key, value):
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search"
- " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
- " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
- )
- txn.execute(
- sql,
- (
- event.event_id, event.room_id, key, value,
- event.internal_metadata.stream_ordering,
- event.origin_server_ts,
- )
- )
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- txn.execute(sql, (event.event_id, event.room_id, key, value,))
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
def add_event_report(self, room_id, event_id, user_id, reason, content,
received_ts):
next_id = self._event_reports_id_gen.get_next()
@@ -533,73 +506,114 @@ class RoomStore(SQLBaseStore):
)
self.is_room_blocked.invalidate((room_id,))
+ def get_media_mxcs_in_room(self, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ room_id (str)
+
+ Returns:
+ The local and remote media as a lists of tuples where the key is
+ the hostname and the value is the media ID.
+ """
+ def _get_media_mxcs_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ # Convert the IDs to MXC URIs
+ for media_id in local_mxcs:
+ local_media_mxcs.append("mxc://%s/%s" % (self.hostname, media_id))
+ for hostname, media_id in remote_mxcs:
+ remote_media_mxcs.append("mxc://%s/%s" % (hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
+ return self.runInteraction("get_media_ids_in_room", _get_media_mxcs_in_room_txn)
+
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
"""For a room loops through all events with media and quarantines
the associated media
"""
- def _get_media_ids_in_room(txn):
- mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+ def _quarantine_media_in_room_txn(txn):
+ local_mxcs, remote_mxcs = self._get_media_mxcs_in_room_txn(txn, room_id)
+ total_media_quarantined = 0
- next_token = self.get_current_events_token() + 1
+ # Now update all the tables to set the quarantined_by flag
- total_media_quarantined = 0
+ txn.executemany("""
+ UPDATE local_media_repository
+ SET quarantined_by = ?
+ WHERE media_id = ?
+ """, ((quarantined_by, media_id) for media_id in local_mxcs))
- while next_token:
- sql = """
- SELECT stream_ordering, content FROM events
- WHERE room_id = ?
- AND stream_ordering < ?
- AND contains_url = ? AND outlier = ?
- ORDER BY stream_ordering DESC
- LIMIT ?
+ txn.executemany(
"""
- txn.execute(sql, (room_id, next_token, True, False, 100))
-
- next_token = None
- local_media_mxcs = []
- remote_media_mxcs = []
- for stream_ordering, content_json in txn:
- next_token = stream_ordering
- content = json.loads(content_json)
-
- content_url = content.get("url")
- thumbnail_url = content.get("info", {}).get("thumbnail_url")
-
- for url in (content_url, thumbnail_url):
- if not url:
- continue
- matches = mxc_re.match(url)
- if matches:
- hostname = matches.group(1)
- media_id = matches.group(2)
- if hostname == self.hostname:
- local_media_mxcs.append(media_id)
- else:
- remote_media_mxcs.append((hostname, media_id))
-
- # Now update all the tables to set the quarantined_by flag
-
- txn.executemany("""
- UPDATE local_media_repository
+ UPDATE remote_media_cache
SET quarantined_by = ?
- WHERE media_id = ?
- """, ((quarantined_by, media_id) for media_id in local_media_mxcs))
-
- txn.executemany(
- """
- UPDATE remote_media_cache
- SET quarantined_by = ?
- WHERE media_origin AND media_id = ?
- """,
- (
- (quarantined_by, origin, media_id)
- for origin, media_id in remote_media_mxcs
- )
+ WHERE media_origin = ? AND media_id = ?
+ """,
+ (
+ (quarantined_by, origin, media_id)
+ for origin, media_id in remote_mxcs
)
+ )
- total_media_quarantined += len(local_media_mxcs)
- total_media_quarantined += len(remote_media_mxcs)
+ total_media_quarantined += len(local_mxcs)
+ total_media_quarantined += len(remote_mxcs)
return total_media_quarantined
- return self.runInteraction("get_media_ids_in_room", _get_media_ids_in_room)
+ return self.runInteraction(
+ "quarantine_media_in_room",
+ _quarantine_media_in_room_txn,
+ )
+
+ def _get_media_mxcs_in_room_txn(self, txn, room_id):
+ """Retrieves all the local and remote media MXC URIs in a given room
+
+ Args:
+ txn (cursor)
+ room_id (str)
+
+ Returns:
+ The local and remote media as a lists of tuples where the key is
+ the hostname and the value is the media ID.
+ """
+ mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)")
+
+ next_token = self.get_current_events_token() + 1
+ local_media_mxcs = []
+ remote_media_mxcs = []
+
+ while next_token:
+ sql = """
+ SELECT stream_ordering, content FROM events
+ WHERE room_id = ?
+ AND stream_ordering < ?
+ AND contains_url = ? AND outlier = ?
+ ORDER BY stream_ordering DESC
+ LIMIT ?
+ """
+ txn.execute(sql, (room_id, next_token, True, False, 100))
+
+ next_token = None
+ for stream_ordering, content_json in txn:
+ next_token = stream_ordering
+ content = json.loads(content_json)
+
+ content_url = content.get("url")
+ thumbnail_url = content.get("info", {}).get("thumbnail_url")
+
+ for url in (content_url, thumbnail_url):
+ if not url:
+ continue
+ matches = mxc_re.match(url)
+ if matches:
+ hostname = matches.group(1)
+ media_id = matches.group(2)
+ if hostname == self.hostname:
+ local_media_mxcs.append(media_id)
+ else:
+ remote_media_mxcs.append((hostname, media_id))
+
+ return local_media_mxcs, remote_media_mxcs
diff --git a/synapse/storage/schema/delta/47/state_group_seq.py b/synapse/storage/schema/delta/47/state_group_seq.py
new file mode 100644
index 0000000000..f6766501d2
--- /dev/null
+++ b/synapse/storage/schema/delta/47/state_group_seq.py
@@ -0,0 +1,37 @@
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ # if we already have some state groups, we want to start making new
+ # ones with a higher id.
+ cur.execute("SELECT max(id) FROM state_groups")
+ row = cur.fetchone()
+
+ if row[0] is None:
+ start_val = 1
+ else:
+ start_val = row[0] + 1
+
+ cur.execute(
+ "CREATE SEQUENCE state_group_id_seq START WITH %s",
+ (start_val, ),
+ )
+
+
+def run_upgrade(*args, **kwargs):
+ pass
diff --git a/synapse/storage/search.py b/synapse/storage/search.py
index 479b04c636..f1ac9ba0fd 100644
--- a/synapse/storage/search.py
+++ b/synapse/storage/search.py
@@ -13,19 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from collections import namedtuple
+import logging
+import re
+import ujson as json
+
from twisted.internet import defer
from .background_updates import BackgroundUpdateStore
from synapse.api.errors import SynapseError
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
-import logging
-import re
-import ujson as json
-
logger = logging.getLogger(__name__)
+SearchEntry = namedtuple('SearchEntry', [
+ 'key', 'value', 'event_id', 'room_id', 'stream_ordering',
+ 'origin_server_ts',
+])
+
class SearchStore(BackgroundUpdateStore):
@@ -49,16 +55,17 @@ class SearchStore(BackgroundUpdateStore):
@defer.inlineCallbacks
def _background_reindex_search(self, progress, batch_size):
+ # we work through the events table from highest stream id to lowest
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
rows_inserted = progress.get("rows_inserted", 0)
- INSERT_CLUMP_SIZE = 1000
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
def reindex_search_txn(txn):
sql = (
- "SELECT stream_ordering, event_id, room_id, type, content FROM events"
+ "SELECT stream_ordering, event_id, room_id, type, content, "
+ " origin_server_ts FROM events"
" WHERE ? <= stream_ordering AND stream_ordering < ?"
" AND (%s)"
" ORDER BY stream_ordering DESC"
@@ -67,6 +74,10 @@ class SearchStore(BackgroundUpdateStore):
txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
+ # we could stream straight from the results into
+ # store_search_entries_txn with a generator function, but that
+ # would mean having two cursors open on the database at once.
+ # Instead we just build a list of results.
rows = self.cursor_to_dict(txn)
if not rows:
return 0
@@ -79,6 +90,8 @@ class SearchStore(BackgroundUpdateStore):
event_id = row["event_id"]
room_id = row["room_id"]
etype = row["type"]
+ stream_ordering = row["stream_ordering"]
+ origin_server_ts = row["origin_server_ts"]
try:
content = json.loads(row["content"])
except Exception:
@@ -93,6 +106,8 @@ class SearchStore(BackgroundUpdateStore):
elif etype == "m.room.name":
key = "content.name"
value = content["name"]
+ else:
+ raise Exception("unexpected event type %s" % etype)
except (KeyError, AttributeError):
# If the event is missing a necessary field then
# skip over it.
@@ -103,25 +118,16 @@ class SearchStore(BackgroundUpdateStore):
# then skip over it
continue
- event_search_rows.append((event_id, room_id, key, value))
+ event_search_rows.append(SearchEntry(
+ key=key,
+ value=value,
+ event_id=event_id,
+ room_id=room_id,
+ stream_ordering=stream_ordering,
+ origin_server_ts=origin_server_ts,
+ ))
- if isinstance(self.database_engine, PostgresEngine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, vector)"
- " VALUES (?,?,?,to_tsvector('english', ?))"
- )
- elif isinstance(self.database_engine, Sqlite3Engine):
- sql = (
- "INSERT INTO event_search (event_id, room_id, key, value)"
- " VALUES (?,?,?,?)"
- )
- else:
- # This should be unreachable.
- raise Exception("Unrecognized database engine")
-
- for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
- clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
- txn.executemany(sql, clump)
+ self.store_search_entries_txn(txn, event_search_rows)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
@@ -242,6 +248,62 @@ class SearchStore(BackgroundUpdateStore):
defer.returnValue(num_rows)
+ def store_event_search_txn(self, txn, event, key, value):
+ """Add event to the search table
+
+ Args:
+ txn (cursor):
+ event (EventBase):
+ key (str):
+ value (str):
+ """
+ self.store_search_entries_txn(
+ txn,
+ (SearchEntry(
+ key=key,
+ value=value,
+ event_id=event.event_id,
+ room_id=event.room_id,
+ stream_ordering=event.internal_metadata.stream_ordering,
+ origin_server_ts=event.origin_server_ts,
+ ),),
+ )
+
+ def store_search_entries_txn(self, txn, entries):
+ """Add entries to the search table
+
+ Args:
+ txn (cursor):
+ entries (iterable[SearchEntry]):
+ entries to be added to the table
+ """
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = (
+ "INSERT INTO event_search"
+ " (event_id, room_id, key, vector, stream_ordering, origin_server_ts)"
+ " VALUES (?,?,?,to_tsvector('english', ?),?,?)"
+ )
+
+ args = ((
+ entry.event_id, entry.room_id, entry.key, entry.value,
+ entry.stream_ordering, entry.origin_server_ts,
+ ) for entry in entries)
+
+ txn.executemany(sql, args)
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = (
+ "INSERT INTO event_search (event_id, room_id, key, value)"
+ " VALUES (?,?,?,?)"
+ )
+ args = ((
+ entry.event_id, entry.room_id, entry.key, entry.value,
+ ) for entry in entries)
+
+ txn.executemany(sql, args)
+ else:
+ # This should be unreachable.
+ raise Exception("Unrecognized database engine")
+
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 360e3e4355..adb48df73e 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -42,11 +42,8 @@ class _GetStateGroupDelta(namedtuple("_GetStateGroupDelta", ("prev_group", "delt
return len(self.delta_ids) if self.delta_ids else 0
-class StateGroupReadStore(SQLBaseStore):
- """The read-only parts of StateGroupStore
-
- None of these functions write to the state tables, so are suitable for
- including in the SlavedStores.
+class StateGroupWorkerStore(SQLBaseStore):
+ """The parts of StateGroupStore that can be called from workers.
"""
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
@@ -54,7 +51,7 @@ class StateGroupReadStore(SQLBaseStore):
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, db_conn, hs):
- super(StateGroupReadStore, self).__init__(db_conn, hs)
+ super(StateGroupWorkerStore, self).__init__(db_conn, hs)
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
@@ -549,116 +546,66 @@ class StateGroupReadStore(SQLBaseStore):
defer.returnValue(results)
+ def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+ current_state_ids):
+ """Store a new set of state, returning a newly assigned state group.
-class StateStore(StateGroupReadStore, BackgroundUpdateStore):
- """ Keeps track of the state at a given event.
-
- This is done by the concept of `state groups`. Every event is a assigned
- a state group (identified by an arbitrary string), which references a
- collection of state events. The current state of an event is then the
- collection of state events referenced by the event's state group.
-
- Hence, every change in the current state causes a new state group to be
- generated. However, if no change happens (e.g., if we get a message event
- with only one parent it inherits the state group from its parent.)
-
- There are three tables:
- * `state_groups`: Stores group name, first event with in the group and
- room id.
- * `event_to_state_groups`: Maps events to state groups.
- * `state_groups_state`: Maps state group to state events.
- """
-
- STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
- STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
- CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
-
- def __init__(self, db_conn, hs):
- super(StateStore, self).__init__(db_conn, hs)
- self.register_background_update_handler(
- self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
- self._background_deduplicate_state,
- )
- self.register_background_update_handler(
- self.STATE_GROUP_INDEX_UPDATE_NAME,
- self._background_index_state,
- )
- self.register_background_index_update(
- self.CURRENT_STATE_INDEX_UPDATE_NAME,
- index_name="current_state_events_member_index",
- table="current_state_events",
- columns=["state_key"],
- where_clause="type='m.room.member'",
- )
-
- def _have_persisted_state_group_txn(self, txn, state_group):
- txn.execute(
- "SELECT count(*) FROM state_groups WHERE id = ?",
- (state_group,)
- )
- row = txn.fetchone()
- return row and row[0]
-
- def _store_mult_state_groups_txn(self, txn, events_and_contexts):
- state_groups = {}
- for event, context in events_and_contexts:
- if event.internal_metadata.is_outlier():
- continue
+ Args:
+ event_id (str): The event ID for which the state was calculated
+ room_id (str)
+ prev_group (int|None): A previous state group for the room, optional.
+ delta_ids (dict|None): The delta between state at `prev_group` and
+ `current_state_ids`, if `prev_group` was given. Same format as
+ `current_state_ids`.
+ current_state_ids (dict): The state to store. Map of (type, state_key)
+ to event_id.
- if context.current_state_ids is None:
+ Returns:
+ Deferred[int]: The state group ID
+ """
+ def _store_state_group_txn(txn):
+ if current_state_ids is None:
# AFAIK, this can never happen
- logger.error(
- "Non-outlier event %s had current_state_ids==None",
- event.event_id)
- continue
+ raise Exception("current_state_ids cannot be None")
- # if the event was rejected, just give it the same state as its
- # predecessor.
- if context.rejected:
- state_groups[event.event_id] = context.prev_group
- continue
-
- state_groups[event.event_id] = context.state_group
-
- if self._have_persisted_state_group_txn(txn, context.state_group):
- continue
+ state_group = self.database_engine.get_next_state_group_id(txn)
self._simple_insert_txn(
txn,
table="state_groups",
values={
- "id": context.state_group,
- "room_id": event.room_id,
- "event_id": event.event_id,
+ "id": state_group,
+ "room_id": room_id,
+ "event_id": event_id,
},
)
# We persist as a delta if we can, while also ensuring the chain
# of deltas isn't tooo long, as otherwise read performance degrades.
- if context.prev_group:
+ if prev_group:
is_in_db = self._simple_select_one_onecol_txn(
txn,
table="state_groups",
- keyvalues={"id": context.prev_group},
+ keyvalues={"id": prev_group},
retcol="id",
allow_none=True,
)
if not is_in_db:
raise Exception(
"Trying to persist state with unpersisted prev_group: %r"
- % (context.prev_group,)
+ % (prev_group,)
)
potential_hops = self._count_state_group_hops_txn(
- txn, context.prev_group
+ txn, prev_group
)
- if context.prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
+ if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
self._simple_insert_txn(
txn,
table="state_group_edges",
values={
- "state_group": context.state_group,
- "prev_state_group": context.prev_group,
+ "state_group": state_group,
+ "prev_state_group": prev_group,
},
)
@@ -667,13 +614,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
table="state_groups_state",
values=[
{
- "state_group": context.state_group,
- "room_id": event.room_id,
+ "state_group": state_group,
+ "room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in context.delta_ids.iteritems()
+ for key, state_id in delta_ids.iteritems()
],
)
else:
@@ -682,13 +629,13 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
table="state_groups_state",
values=[
{
- "state_group": context.state_group,
- "room_id": event.room_id,
+ "state_group": state_group,
+ "room_id": room_id,
"type": key[0],
"state_key": key[1],
"event_id": state_id,
}
- for key, state_id in context.current_state_ids.iteritems()
+ for key, state_id in current_state_ids.iteritems()
],
)
@@ -699,11 +646,71 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
txn.call_after(
self._state_group_cache.update,
self._state_group_cache.sequence,
- key=context.state_group,
- value=dict(context.current_state_ids),
+ key=state_group,
+ value=dict(current_state_ids),
full=True,
)
+ return state_group
+
+ return self.runInteraction("store_state_group", _store_state_group_txn)
+
+
+class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
+ """ Keeps track of the state at a given event.
+
+ This is done by the concept of `state groups`. Every event is a assigned
+ a state group (identified by an arbitrary string), which references a
+ collection of state events. The current state of an event is then the
+ collection of state events referenced by the event's state group.
+
+ Hence, every change in the current state causes a new state group to be
+ generated. However, if no change happens (e.g., if we get a message event
+ with only one parent it inherits the state group from its parent.)
+
+ There are three tables:
+ * `state_groups`: Stores group name, first event with in the group and
+ room id.
+ * `event_to_state_groups`: Maps events to state groups.
+ * `state_groups_state`: Maps state group to state events.
+ """
+
+ STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
+ STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
+ CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
+
+ def __init__(self, db_conn, hs):
+ super(StateStore, self).__init__(db_conn, hs)
+ self.register_background_update_handler(
+ self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
+ self._background_deduplicate_state,
+ )
+ self.register_background_update_handler(
+ self.STATE_GROUP_INDEX_UPDATE_NAME,
+ self._background_index_state,
+ )
+ self.register_background_index_update(
+ self.CURRENT_STATE_INDEX_UPDATE_NAME,
+ index_name="current_state_events_member_index",
+ table="current_state_events",
+ columns=["state_key"],
+ where_clause="type='m.room.member'",
+ )
+
+ def _store_event_state_mappings_txn(self, txn, events_and_contexts):
+ state_groups = {}
+ for event, context in events_and_contexts:
+ if event.internal_metadata.is_outlier():
+ continue
+
+ # if the event was rejected, just give it the same state as its
+ # predecessor.
+ if context.rejected:
+ state_groups[event.event_id] = context.prev_group
+ continue
+
+ state_groups[event.event_id] = context.state_group
+
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
@@ -763,9 +770,6 @@ class StateStore(StateGroupReadStore, BackgroundUpdateStore):
return count
- def get_next_state_group(self):
- return self._state_groups_id_gen.get_next()
-
@defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index f150ef0103..dfdcbb3181 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -641,13 +641,12 @@ class UserDirectoryStore(SQLBaseStore):
"""
if self.hs.config.user_directory_search_all_users:
- # dummy to keep the number of binds & aliases the same
+ # make s.user_id null to keep the ordering algorithm happy
join_clause = """
- LEFT JOIN (
- SELECT NULL as user_id WHERE NULL = ?
- ) AS s USING (user_id)"
+ CROSS JOIN (SELECT NULL as user_id) AS s
"""
- where_clause = ""
+ join_args = ()
+ where_clause = "1=1"
else:
join_clause = """
LEFT JOIN users_in_public_rooms AS p USING (user_id)
@@ -656,6 +655,7 @@ class UserDirectoryStore(SQLBaseStore):
WHERE user_id = ? AND share_private
) AS s USING (user_id)
"""
+ join_args = (user_id,)
where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
if isinstance(self.database_engine, PostgresEngine):
@@ -697,7 +697,7 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (join_clause, where_clause)
- args = (user_id, full_query, exact_query, prefix_query, limit + 1,)
+ args = join_args + (full_query, exact_query, prefix_query, limit + 1,)
elif isinstance(self.database_engine, Sqlite3Engine):
search_query = _parse_query_sqlite(search_term)
@@ -715,7 +715,7 @@ class UserDirectoryStore(SQLBaseStore):
avatar_url IS NULL
LIMIT ?
""" % (join_clause, where_clause)
- args = (user_id, search_query, limit + 1)
+ args = join_args + (search_query, limit + 1)
else:
# This should be unreachable.
raise Exception("Unrecognized database engine")
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index af65bfe7b8..bf3a66eae4 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -75,6 +75,7 @@ class Cache(object):
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None,
+ evicted_callback=self._on_evicted,
)
self.name = name
@@ -83,6 +84,9 @@ class Cache(object):
self.thread = None
self.metrics = register_cache(name, self.cache)
+ def _on_evicted(self, evicted_count):
+ self.metrics.inc_evictions(evicted_count)
+
def check_thread(self):
expected_thread = self.thread
if expected_thread is None:
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 6ad53a6390..0aa103eecb 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -79,7 +79,11 @@ class ExpiringCache(object):
while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False)
if self.iterable:
- self._size_estimate -= len(value.value)
+ removed_len = len(value.value)
+ self.metrics.inc_evictions(removed_len)
+ self._size_estimate -= removed_len
+ else:
+ self.metrics.inc_evictions()
def __getitem__(self, key):
try:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index cf5fbb679c..f088dd430e 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -49,7 +49,24 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
- def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
+ def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None,
+ evicted_callback=None):
+ """
+ Args:
+ max_size (int):
+
+ keylen (int):
+
+ cache_type (type):
+ type of underlying cache to be used. Typically one of dict
+ or TreeCache.
+
+ size_callback (func(V) -> int | None):
+
+ evicted_callback (func(int)|None):
+ if not None, called on eviction with the size of the evicted
+ entry
+ """
cache = cache_type()
self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None)
@@ -61,8 +78,10 @@ class LruCache(object):
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
- delete_node(todelete)
+ evicted_len = delete_node(todelete)
cache.pop(todelete.key, None)
+ if evicted_callback:
+ evicted_callback(evicted_len)
def synchronized(f):
@wraps(f)
@@ -111,12 +130,15 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
+ deleted_len = 1
if size_callback:
- cached_cache_len[0] -= size_callback(node.value)
+ deleted_len = size_callback(node.value)
+ cached_cache_len[0] -= deleted_len
for cb in node.callbacks:
cb()
node.callbacks.clear()
+ return deleted_len
@synchronized
def cache_get(key, default=None, callbacks=[]):
diff --git a/synapse/util/file_consumer.py b/synapse/util/file_consumer.py
new file mode 100644
index 0000000000..90a2608d6f
--- /dev/null
+++ b/synapse/util/file_consumer.py
@@ -0,0 +1,139 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import threads, reactor
+
+from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+
+import Queue
+
+
+class BackgroundFileConsumer(object):
+ """A consumer that writes to a file like object. Supports both push
+ and pull producers
+
+ Args:
+ file_obj (file): The file like object to write to. Closed when
+ finished.
+ """
+
+ # For PushProducers pause if we have this many unwritten slices
+ _PAUSE_ON_QUEUE_SIZE = 5
+ # And resume once the size of the queue is less than this
+ _RESUME_ON_QUEUE_SIZE = 2
+
+ def __init__(self, file_obj):
+ self._file_obj = file_obj
+
+ # Producer we're registered with
+ self._producer = None
+
+ # True if PushProducer, false if PullProducer
+ self.streaming = False
+
+ # For PushProducers, indicates whether we've paused the producer and
+ # need to call resumeProducing before we get more data.
+ self._paused_producer = False
+
+ # Queue of slices of bytes to be written. When producer calls
+ # unregister a final None is sent.
+ self._bytes_queue = Queue.Queue()
+
+ # Deferred that is resolved when finished writing
+ self._finished_deferred = None
+
+ # If the _writer thread throws an exception it gets stored here.
+ self._write_exception = None
+
+ def registerProducer(self, producer, streaming):
+ """Part of IConsumer interface
+
+ Args:
+ producer (IProducer)
+ streaming (bool): True if push based producer, False if pull
+ based.
+ """
+ if self._producer:
+ raise Exception("registerProducer called twice")
+
+ self._producer = producer
+ self.streaming = streaming
+ self._finished_deferred = preserve_fn(threads.deferToThread)(self._writer)
+ if not streaming:
+ self._producer.resumeProducing()
+
+ def unregisterProducer(self):
+ """Part of IProducer interface
+ """
+ self._producer = None
+ if not self._finished_deferred.called:
+ self._bytes_queue.put_nowait(None)
+
+ def write(self, bytes):
+ """Part of IProducer interface
+ """
+ if self._write_exception:
+ raise self._write_exception
+
+ if self._finished_deferred.called:
+ raise Exception("consumer has closed")
+
+ self._bytes_queue.put_nowait(bytes)
+
+ # If this is a PushProducer and the queue is getting behind
+ # then we pause the producer.
+ if self.streaming and self._bytes_queue.qsize() >= self._PAUSE_ON_QUEUE_SIZE:
+ self._paused_producer = True
+ self._producer.pauseProducing()
+
+ def _writer(self):
+ """This is run in a background thread to write to the file.
+ """
+ try:
+ while self._producer or not self._bytes_queue.empty():
+ # If we've paused the producer check if we should resume the
+ # producer.
+ if self._producer and self._paused_producer:
+ if self._bytes_queue.qsize() <= self._RESUME_ON_QUEUE_SIZE:
+ reactor.callFromThread(self._resume_paused_producer)
+
+ bytes = self._bytes_queue.get()
+
+ # If we get a None (or empty list) then that's a signal used
+ # to indicate we should check if we should stop.
+ if bytes:
+ self._file_obj.write(bytes)
+
+ # If its a pull producer then we need to explicitly ask for
+ # more stuff.
+ if not self.streaming and self._producer:
+ reactor.callFromThread(self._producer.resumeProducing)
+ except Exception as e:
+ self._write_exception = e
+ raise
+ finally:
+ self._file_obj.close()
+
+ def wait(self):
+ """Returns a deferred that resolves when finished writing to file
+ """
+ return make_deferred_yieldable(self._finished_deferred)
+
+ def _resume_paused_producer(self):
+ """Gets called if we should resume producing after being paused
+ """
+ if self._paused_producer and self._producer:
+ self._paused_producer = False
+ self._producer.resumeProducing()
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 1adedbb361..47b0bb5eb3 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -26,6 +26,18 @@ logger = logging.getLogger(__name__)
class NotRetryingDestination(Exception):
def __init__(self, retry_last_ts, retry_interval, destination):
+ """Raised by the limiter (and federation client) to indicate that we are
+ are deliberately not attempting to contact a given server.
+
+ Args:
+ retry_last_ts (int): the unix ts in milliseconds of our last attempt
+ to contact the server. 0 indicates that the last attempt was
+ successful or that we've never actually attempted to connect.
+ retry_interval (int): the time in milliseconds to wait until the next
+ attempt.
+ destination (str): the domain in question
+ """
+
msg = "Not retrying server %s." % (destination,)
super(NotRetryingDestination, self).__init__(msg)
diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py
new file mode 100644
index 0000000000..75efa0117b
--- /dev/null
+++ b/synapse/util/threepids.py
@@ -0,0 +1,48 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+logger = logging.getLogger(__name__)
+
+
+def check_3pid_allowed(hs, medium, address):
+ """Checks whether a given format of 3PID is allowed to be used on this HS
+
+ Args:
+ hs (synapse.server.HomeServer): server
+ medium (str): 3pid medium - e.g. email, msisdn
+ address (str): address within that medium (e.g. "wotan@matrix.org")
+ msisdns need to first have been canonicalised
+ Returns:
+ bool: whether the 3PID medium/address is allowed to be added to this HS
+ """
+
+ if hs.config.allowed_local_3pids:
+ for constraint in hs.config.allowed_local_3pids:
+ logger.debug(
+ "Checking 3PID %s (%s) against %s (%s)",
+ address, medium, constraint['pattern'], constraint['medium'],
+ )
+ if (
+ medium == constraint['medium'] and
+ re.match(constraint['pattern'], address)
+ ):
+ return True
+ else:
+ return True
+
+ return False
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index c899fecf5d..d4ec02ffc2 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -167,7 +167,7 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
- yield async.sleep(0.005)
+ yield async.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11)
@@ -183,7 +183,7 @@ class KeyringTestCase(unittest.TestCase):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)],
)
- yield async.sleep(0.005)
+ yield async.sleep(01)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 19f5ed6bce..d92bf240b1 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -143,7 +143,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
except errors.SynapseError:
pass
- @unittest.DEBUG
@defer.inlineCallbacks
def test_claim_one_time_key(self):
local_user = "@boris:" + self.hs.hostname
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index f85455a5af..39bde6e3f8 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -141,6 +141,7 @@ class CacheMetricTestCase(unittest.TestCase):
'cache:hits{name="cache_name"} 0',
'cache:total{name="cache_name"} 0',
'cache:size{name="cache_name"} 0',
+ 'cache:evicted_size{name="cache_name"} 0',
])
metric.inc_misses()
@@ -150,6 +151,7 @@ class CacheMetricTestCase(unittest.TestCase):
'cache:hits{name="cache_name"} 0',
'cache:total{name="cache_name"} 1',
'cache:size{name="cache_name"} 1',
+ 'cache:evicted_size{name="cache_name"} 0',
])
metric.inc_hits()
@@ -158,4 +160,14 @@ class CacheMetricTestCase(unittest.TestCase):
'cache:hits{name="cache_name"} 1',
'cache:total{name="cache_name"} 2',
'cache:size{name="cache_name"} 1',
+ 'cache:evicted_size{name="cache_name"} 0',
+ ])
+
+ metric.inc_evictions(2)
+
+ self.assertEquals(metric.render(), [
+ 'cache:hits{name="cache_name"} 1',
+ 'cache:total{name="cache_name"} 2',
+ 'cache:size{name="cache_name"} 1',
+ 'cache:evicted_size{name="cache_name"} 2',
])
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 81063f19a1..74f104e3b8 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -15,6 +15,8 @@
from twisted.internet import defer, reactor
from tests import unittest
+import tempfile
+
from mock import Mock, NonCallableMock
from tests.utils import setup_test_homeserver
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -41,7 +43,9 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
- listener = reactor.listenUNIX("\0xxx", server_factory)
+ # XXX: mktemp is unsafe and should never be used. but we're just a test.
+ path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
+ listener = reactor.listenUNIX(path, server_factory)
self.addCleanup(listener.stopListening)
self.streamer = server_factory.streamer
@@ -49,7 +53,7 @@ class BaseSlavedStoreTestCase(unittest.TestCase):
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
- client_connector = reactor.connectUNIX("\0xxx", client_factory)
+ client_connector = reactor.connectUNIX(path, client_factory)
self.addCleanup(client_factory.stopTrying)
self.addCleanup(client_connector.disconnect)
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 105e1228bb..f430cce931 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -226,11 +226,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
- elif not backfill:
+ else:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)
- else:
- context = EventContext()
context.push_actions = push_actions
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 096f771bea..8aba456510 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -49,6 +49,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
self.hs.get_device_handler = Mock(return_value=self.device_handler)
self.hs.config.enable_registration = True
+ self.hs.config.registrations_require_3pid = []
self.hs.config.auto_join_rooms = []
# init the thing we're testing
diff --git a/tests/rest/media/__init__.py b/tests/rest/media/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rest/media/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rest/media/v1/__init__.py b/tests/rest/media/v1/__init__.py
new file mode 100644
index 0000000000..a354d38ca8
--- /dev/null
+++ b/tests/rest/media/v1/__init__.py
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
new file mode 100644
index 0000000000..eef38b6781
--- /dev/null
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -0,0 +1,86 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from synapse.rest.media.v1._base import FileInfo
+from synapse.rest.media.v1.media_storage import MediaStorage
+from synapse.rest.media.v1.filepath import MediaFilePaths
+from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
+
+from mock import Mock
+
+from tests import unittest
+
+import os
+import shutil
+import tempfile
+
+
+class MediaStorageTests(unittest.TestCase):
+ def setUp(self):
+ self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
+
+ self.primary_base_path = os.path.join(self.test_dir, "primary")
+ self.secondary_base_path = os.path.join(self.test_dir, "secondary")
+
+ hs = Mock()
+ hs.config.media_store_path = self.primary_base_path
+
+ storage_providers = [FileStorageProviderBackend(
+ hs, self.secondary_base_path
+ )]
+
+ self.filepaths = MediaFilePaths(self.primary_base_path)
+ self.media_storage = MediaStorage(
+ self.primary_base_path, self.filepaths, storage_providers,
+ )
+
+ def tearDown(self):
+ shutil.rmtree(self.test_dir)
+
+ @defer.inlineCallbacks
+ def test_ensure_media_is_in_local_cache(self):
+ media_id = "some_media_id"
+ test_body = "Test\n"
+
+ # First we create a file that is in a storage provider but not in the
+ # local primary media store
+ rel_path = self.filepaths.local_media_filepath_rel(media_id)
+ secondary_path = os.path.join(self.secondary_base_path, rel_path)
+
+ os.makedirs(os.path.dirname(secondary_path))
+
+ with open(secondary_path, "w") as f:
+ f.write(test_body)
+
+ # Now we run ensure_media_is_in_local_cache, which should copy the file
+ # to the local cache.
+ file_info = FileInfo(None, media_id)
+ local_path = yield self.media_storage.ensure_media_is_in_local_cache(file_info)
+
+ self.assertTrue(os.path.exists(local_path))
+
+ # Asserts the file is under the expected local cache directory
+ self.assertEquals(
+ os.path.commonprefix([self.primary_base_path, local_path]),
+ self.primary_base_path,
+ )
+
+ with open(local_path) as f:
+ body = f.read()
+
+ self.assertEqual(test_body, body)
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 6afaca3a61..888ddfaddd 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -36,8 +36,7 @@ class RedactionTestCase(unittest.TestCase):
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
- self.handlers = hs.get_handlers()
- self.message_handler = self.handlers.message_handler
+ self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test")
@@ -59,7 +58,7 @@ class RedactionTestCase(unittest.TestCase):
"content": content,
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
@@ -79,7 +78,7 @@ class RedactionTestCase(unittest.TestCase):
"content": {"body": body, "msgtype": u"message"},
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
@@ -98,7 +97,7 @@ class RedactionTestCase(unittest.TestCase):
"redacts": event_id,
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py
index 1be7d932f6..657b279e5d 100644
--- a/tests/storage/test_roommember.py
+++ b/tests/storage/test_roommember.py
@@ -37,8 +37,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
# storage logic
self.store = hs.get_datastore()
self.event_builder_factory = hs.get_event_builder_factory()
- self.handlers = hs.get_handlers()
- self.message_handler = self.handlers.message_handler
+ self.event_creation_handler = hs.get_event_creation_handler()
self.u_alice = UserID.from_string("@alice:test")
self.u_bob = UserID.from_string("@bob:test")
@@ -58,7 +57,7 @@ class RoomMemberStoreTestCase(unittest.TestCase):
"content": {"membership": membership},
})
- event, context = yield self.message_handler._create_new_client_event(
+ event, context = yield self.event_creation_handler.create_new_client_event(
builder
)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
new file mode 100644
index 0000000000..0891308f25
--- /dev/null
+++ b/tests/storage/test_user_directory.py
@@ -0,0 +1,88 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.storage import UserDirectoryStore
+from synapse.storage.roommember import ProfileInfo
+from tests import unittest
+from tests.utils import setup_test_homeserver
+
+ALICE = "@alice:a"
+BOB = "@bob:b"
+BOBBY = "@bobby:a"
+
+
+class UserDirectoryStoreTestCase(unittest.TestCase):
+ @defer.inlineCallbacks
+ def setUp(self):
+ self.hs = yield setup_test_homeserver()
+ self.store = UserDirectoryStore(None, self.hs)
+
+ # alice and bob are both in !room_id. bobby is not but shares
+ # a homeserver with alice.
+ yield self.store.add_profiles_to_user_dir(
+ "!room:id",
+ {
+ ALICE: ProfileInfo(None, "alice"),
+ BOB: ProfileInfo(None, "bob"),
+ BOBBY: ProfileInfo(None, "bobby")
+ },
+ )
+ yield self.store.add_users_to_public_room(
+ "!room:id",
+ [ALICE, BOB],
+ )
+ yield self.store.add_users_who_share_room(
+ "!room:id",
+ False,
+ (
+ (ALICE, BOB),
+ (BOB, ALICE),
+ ),
+ )
+
+ @defer.inlineCallbacks
+ def test_search_user_dir(self):
+ # normally when alice searches the directory she should just find
+ # bob because bobby doesn't share a room with her.
+ r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ self.assertFalse(r["limited"])
+ self.assertEqual(1, len(r["results"]))
+ self.assertDictEqual(r["results"][0], {
+ "user_id": BOB,
+ "display_name": "bob",
+ "avatar_url": None,
+ })
+
+ @defer.inlineCallbacks
+ def test_search_user_dir_all_users(self):
+ self.hs.config.user_directory_search_all_users = True
+ try:
+ r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ self.assertFalse(r["limited"])
+ self.assertEqual(2, len(r["results"]))
+ self.assertDictEqual(r["results"][0], {
+ "user_id": BOB,
+ "display_name": "bob",
+ "avatar_url": None,
+ })
+ self.assertDictEqual(r["results"][1], {
+ "user_id": BOBBY,
+ "display_name": "bobby",
+ "avatar_url": None,
+ })
+ finally:
+ self.hs.config.user_directory_search_all_users = False
diff --git a/tests/test_state.py b/tests/test_state.py
index feb84f3d48..a5c5e55951 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
-from synapse.state import StateHandler
+from synapse.state import StateHandler, StateResolutionHandler
from .utils import MockClock
@@ -80,14 +80,14 @@ class StateGroupStore(object):
return defer.succeed(groups)
- def store_state_groups(self, event, context):
- if context.current_state_ids is None:
- return
+ def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+ current_state_ids):
+ state_group = self._next_group
+ self._next_group += 1
- state_events = dict(context.current_state_ids)
+ self._group_to_state[state_group] = dict(current_state_ids)
- self._group_to_state[context.state_group] = state_events
- self._event_to_state_group[event.event_id] = context.state_group
+ return state_group
def get_events(self, event_ids, **kwargs):
return {
@@ -95,10 +95,19 @@ class StateGroupStore(object):
if e_id in self._event_id_to_event
}
+ def get_state_group_delta(self, name):
+ return (None, None)
+
def register_events(self, events):
for e in events:
self._event_id_to_event[e.event_id] = e
+ def register_event_context(self, event, context):
+ self._event_to_state_group[event.event_id] = context.state_group
+
+ def register_event_id_state_group(self, event_id, state_group):
+ self._event_to_state_group[event_id] = state_group
+
class DictObj(dict):
def __init__(self, **kwargs):
@@ -137,25 +146,16 @@ class Graph(object):
class StateTestCase(unittest.TestCase):
def setUp(self):
- self.store = Mock(
- spec_set=[
- "get_state_groups_ids",
- "add_event_hashes",
- "get_events",
- "get_next_state_group",
- "get_state_group_delta",
- ]
- )
+ self.store = StateGroupStore()
hs = Mock(spec_set=[
"get_datastore", "get_auth", "get_state_handler", "get_clock",
+ "get_state_resolution_handler",
])
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
-
- self.store.get_next_state_group.side_effect = Mock
- self.store.get_state_group_delta.return_value = (None, None)
+ hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
self.state = StateHandler(hs)
self.event_id = 0
@@ -195,14 +195,13 @@ class StateTestCase(unittest.TestCase):
}
)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].prev_state_ids))
@@ -247,16 +246,13 @@ class StateTestCase(unittest.TestCase):
}
)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
- self.store.get_events = store.get_events
- store.register_events(graph.walk())
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
@@ -313,16 +309,13 @@ class StateTestCase(unittest.TestCase):
}
)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
- self.store.get_events = store.get_events
- store.register_events(graph.walk())
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
@@ -396,16 +389,13 @@ class StateTestCase(unittest.TestCase):
self._add_depths(nodes, edges)
graph = Graph(nodes, edges)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
- self.store.get_events = store.get_events
- store.register_events(graph.walk())
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
self.assertSetEqual(
@@ -465,7 +455,11 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
- event = create_event(type="test_message", name="event")
+ prev_event_id = "prev_event_id"
+ event = create_event(
+ type="test_message", name="event2",
+ prev_events=[(prev_event_id, {})],
+ )
old_state = [
create_event(type="test1", state_key="1"),
@@ -473,11 +467,11 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = "group_name_1"
-
- self.store.get_state_groups_ids.return_value = {
- group_name: {(e.type, e.state_key): e.event_id for e in old_state},
- }
+ group_name = self.store.store_state_group(
+ prev_event_id, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
+ self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event)
@@ -490,7 +484,11 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_trivial_annotate_state(self):
- event = create_event(type="state", state_key="", name="event")
+ prev_event_id = "prev_event_id"
+ event = create_event(
+ type="state", state_key="", name="event2",
+ prev_events=[(prev_event_id, {})],
+ )
old_state = [
create_event(type="test1", state_key="1"),
@@ -498,11 +496,11 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = "group_name_1"
-
- self.store.get_state_groups_ids.return_value = {
- group_name: {(e.type, e.state_key): e.event_id for e in old_state},
- }
+ group_name = self.store.store_state_group(
+ prev_event_id, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
+ self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event)
@@ -515,7 +513,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve_message_conflict(self):
- event = create_event(type="test_message", name="event")
+ prev_event_id1 = "event_id1"
+ prev_event_id2 = "event_id2"
+ event = create_event(
+ type="test_message", name="event3",
+ prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+ )
creation = create_event(
type=EventTypes.Create, state_key=""
@@ -535,12 +538,12 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""),
]
- store = StateGroupStore()
- store.register_events(old_state_1)
- store.register_events(old_state_2)
- self.store.get_events = store.get_events
+ self.store.register_events(old_state_1)
+ self.store.register_events(old_state_2)
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
self.assertEqual(len(context.current_state_ids), 6)
@@ -548,7 +551,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve_state_conflict(self):
- event = create_event(type="test4", state_key="", name="event")
+ prev_event_id1 = "event_id1"
+ prev_event_id2 = "event_id2"
+ event = create_event(
+ type="test4", state_key="", name="event",
+ prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+ )
creation = create_event(
type=EventTypes.Create, state_key=""
@@ -573,7 +581,9 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_2)
self.store.get_events = store.get_events
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
self.assertEqual(len(context.current_state_ids), 6)
@@ -581,7 +591,12 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_standard_depth_conflict(self):
- event = create_event(type="test4", name="event")
+ prev_event_id1 = "event_id1"
+ prev_event_id2 = "event_id2"
+ event = create_event(
+ type="test4", name="event",
+ prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+ )
member_event = create_event(
type=EventTypes.Member,
@@ -613,7 +628,9 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_2)
self.store.get_events = store.get_events
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
self.assertEqual(
old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
@@ -637,19 +654,26 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_1)
store.register_events(old_state_2)
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
self.assertEqual(
old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
)
- def _get_context(self, event, old_state_1, old_state_2):
- group_name_1 = "group_name_1"
- group_name_2 = "group_name_2"
+ def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
+ old_state_2):
+ sg1 = self.store.store_state_group(
+ prev_event_id_1, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state_1},
+ )
+ self.store.register_event_id_state_group(prev_event_id_1, sg1)
- self.store.get_state_groups_ids.return_value = {
- group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
- group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
- }
+ sg2 = self.store.store_state_group(
+ prev_event_id_2, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state_2},
+ )
+ self.store.register_event_id_state_group(prev_event_id_2, sg2)
return self.state.compute_event_context(event)
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
new file mode 100644
index 0000000000..76e2234255
--- /dev/null
+++ b/tests/util/test_file_consumer.py
@@ -0,0 +1,176 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer, reactor
+from mock import NonCallableMock
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+
+from tests import unittest
+from StringIO import StringIO
+
+import threading
+
+
+class FileConsumerTests(unittest.TestCase):
+
+ @defer.inlineCallbacks
+ def test_pull_consumer(self):
+ string_file = StringIO()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = DummyPullProducer()
+
+ yield producer.register_with_consumer(consumer)
+
+ yield producer.write_and_wait("Foo")
+
+ self.assertEqual(string_file.getvalue(), "Foo")
+
+ yield producer.write_and_wait("Bar")
+
+ self.assertEqual(string_file.getvalue(), "FooBar")
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+ @defer.inlineCallbacks
+ def test_push_consumer(self):
+ string_file = BlockingStringWrite()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = NonCallableMock(spec_set=[])
+
+ consumer.registerProducer(producer, True)
+
+ consumer.write("Foo")
+ yield string_file.wait_for_n_writes(1)
+
+ self.assertEqual(string_file.buffer, "Foo")
+
+ consumer.write("Bar")
+ yield string_file.wait_for_n_writes(2)
+
+ self.assertEqual(string_file.buffer, "FooBar")
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+ @defer.inlineCallbacks
+ def test_push_producer_feedback(self):
+ string_file = BlockingStringWrite()
+ consumer = BackgroundFileConsumer(string_file)
+
+ try:
+ producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
+
+ resume_deferred = defer.Deferred()
+ producer.resumeProducing.side_effect = lambda: resume_deferred.callback(None)
+
+ consumer.registerProducer(producer, True)
+
+ number_writes = 0
+ with string_file.write_lock:
+ for _ in range(consumer._PAUSE_ON_QUEUE_SIZE):
+ consumer.write("Foo")
+ number_writes += 1
+
+ producer.pauseProducing.assert_called_once()
+
+ yield string_file.wait_for_n_writes(number_writes)
+
+ yield resume_deferred
+ producer.resumeProducing.assert_called_once()
+ finally:
+ consumer.unregisterProducer()
+
+ yield consumer.wait()
+
+ self.assertTrue(string_file.closed)
+
+
+class DummyPullProducer(object):
+ def __init__(self):
+ self.consumer = None
+ self.deferred = defer.Deferred()
+
+ def resumeProducing(self):
+ d = self.deferred
+ self.deferred = defer.Deferred()
+ d.callback(None)
+
+ def write_and_wait(self, bytes):
+ d = self.deferred
+ self.consumer.write(bytes)
+ return d
+
+ def register_with_consumer(self, consumer):
+ d = self.deferred
+ self.consumer = consumer
+ self.consumer.registerProducer(self, False)
+ return d
+
+
+class BlockingStringWrite(object):
+ def __init__(self):
+ self.buffer = ""
+ self.closed = False
+ self.write_lock = threading.Lock()
+
+ self._notify_write_deferred = None
+ self._number_of_writes = 0
+
+ def write(self, bytes):
+ with self.write_lock:
+ self.buffer += bytes
+ self._number_of_writes += 1
+
+ reactor.callFromThread(self._notify_write)
+
+ def close(self):
+ self.closed = True
+
+ def _notify_write(self):
+ "Called by write to indicate a write happened"
+ with self.write_lock:
+ if not self._notify_write_deferred:
+ return
+ d = self._notify_write_deferred
+ self._notify_write_deferred = None
+ d.callback(None)
+
+ @defer.inlineCallbacks
+ def wait_for_n_writes(self, n):
+ "Wait for n writes to have happened"
+ while True:
+ with self.write_lock:
+ if n <= self._number_of_writes:
+ return
+
+ if not self._notify_write_deferred:
+ self._notify_write_deferred = defer.Deferred()
+
+ d = self._notify_write_deferred
+
+ yield d
diff --git a/tests/utils.py b/tests/utils.py
index 44e5f75093..8efd3a3475 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -13,27 +13,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import HttpServer
-from synapse.api.errors import cs_error, CodeMessageException, StoreError
-from synapse.api.constants import EventTypes
-from synapse.storage.prepare_database import prepare_database
-from synapse.storage.engines import create_engine
-from synapse.server import HomeServer
-from synapse.federation.transport import server
-from synapse.util.ratelimitutils import FederationRateLimiter
-
-from synapse.util.logcontext import LoggingContext
-
-from twisted.internet import defer, reactor
-from twisted.enterprise.adbapi import ConnectionPool
-
-from collections import namedtuple
-from mock import patch, Mock
import hashlib
+from inspect import getcallargs
import urllib
import urlparse
-from inspect import getcallargs
+from mock import Mock, patch
+from twisted.internet import defer, reactor
+
+from synapse.api.errors import CodeMessageException, cs_error
+from synapse.federation.transport import server
+from synapse.http.server import HttpServer
+from synapse.server import HomeServer
+from synapse.storage import PostgresEngine
+from synapse.storage.engines import create_engine
+from synapse.storage.prepare_database import prepare_database
+from synapse.util.logcontext import LoggingContext
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+# set this to True to run the tests against postgres instead of sqlite.
+# It requires you to have a local postgres database called synapse_test, within
+# which ALL TABLES WILL BE DROPPED
+USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks
@@ -57,36 +58,70 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.worker_app = None
config.email_enable_notifs = False
config.block_non_admin_invites = False
+ config.federation_domain_whitelist = None
+ config.user_directory_search_all_users = False
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
config.update_user_directory = False
config.use_frozen_dicts = True
- config.database_config = {"name": "sqlite3"}
config.ldap_enabled = False
if "clock" not in kargs:
kargs["clock"] = MockClock()
+ if USE_POSTGRES_FOR_TESTS:
+ config.database_config = {
+ "name": "psycopg2",
+ "args": {
+ "database": "synapse_test",
+ "cp_min": 1,
+ "cp_max": 5,
+ },
+ }
+ else:
+ config.database_config = {
+ "name": "sqlite3",
+ "args": {
+ "database": ":memory:",
+ "cp_min": 1,
+ "cp_max": 1,
+ },
+ }
+
+ db_engine = create_engine(config.database_config)
+
+ # we need to configure the connection pool to run the on_new_connection
+ # function, so that we can test code that uses custom sqlite functions
+ # (like rank).
+ config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
+
if datastore is None:
- db_pool = SQLiteMemoryDbPool()
- yield db_pool.prepare()
hs = HomeServer(
- name, db_pool=db_pool, config=config,
+ name, config=config,
+ db_config=config.database_config,
version_string="Synapse/tests",
- database_engine=create_engine(config.database_config),
- get_db_conn=db_pool.get_db_conn,
+ database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
**kargs
)
+ db_conn = hs.get_db_conn()
+ # make sure that the database is empty
+ if isinstance(db_engine, PostgresEngine):
+ cur = db_conn.cursor()
+ cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
+ rows = cur.fetchall()
+ for r in rows:
+ cur.execute("DROP TABLE %s CASCADE" % r[0])
+ yield prepare_database(db_conn, db_engine, config)
hs.setup()
else:
hs = HomeServer(
name, db_pool=None, datastore=datastore, config=config,
version_string="Synapse/tests",
- database_engine=create_engine(config.database_config),
+ database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
**kargs
@@ -305,168 +340,6 @@ class MockClock(object):
return d
-class SQLiteMemoryDbPool(ConnectionPool, object):
- def __init__(self):
- super(SQLiteMemoryDbPool, self).__init__(
- "sqlite3", ":memory:",
- cp_min=1,
- cp_max=1,
- )
-
- self.config = Mock()
- self.config.password_providers = []
- self.config.database_config = {"name": "sqlite3"}
-
- def prepare(self):
- engine = self.create_engine()
- return self.runWithConnection(
- lambda conn: prepare_database(conn, engine, self.config)
- )
-
- def get_db_conn(self):
- conn = self.connect()
- engine = self.create_engine()
- prepare_database(conn, engine, self.config)
- return conn
-
- def create_engine(self):
- return create_engine(self.config.database_config)
-
-
-class MemoryDataStore(object):
-
- Room = namedtuple(
- "Room",
- ["room_id", "is_public", "creator"]
- )
-
- def __init__(self):
- self.tokens_to_users = {}
- self.paths_to_content = {}
-
- self.members = {}
- self.rooms = {}
-
- self.current_state = {}
- self.events = []
-
- class Snapshot(namedtuple("Snapshot", "room_id user_id membership_state")):
- def fill_out_prev_events(self, event):
- pass
-
- def snapshot_room(self, room_id, user_id, state_type=None, state_key=None):
- return self.Snapshot(
- room_id, user_id, self.get_room_member(user_id, room_id)
- )
-
- def register(self, user_id, token, password_hash):
- if user_id in self.tokens_to_users.values():
- raise StoreError(400, "User in use.")
- self.tokens_to_users[token] = user_id
-
- def get_user_by_access_token(self, token):
- try:
- return {
- "name": self.tokens_to_users[token],
- }
- except Exception:
- raise StoreError(400, "User does not exist.")
-
- def get_room(self, room_id):
- try:
- return self.rooms[room_id]
- except Exception:
- return None
-
- def store_room(self, room_id, room_creator_user_id, is_public):
- if room_id in self.rooms:
- raise StoreError(409, "Conflicting room!")
-
- room = MemoryDataStore.Room(
- room_id=room_id,
- is_public=is_public,
- creator=room_creator_user_id
- )
- self.rooms[room_id] = room
-
- def get_room_member(self, user_id, room_id):
- return self.members.get(room_id, {}).get(user_id)
-
- def get_room_members(self, room_id, membership=None):
- if membership:
- return [
- v for k, v in self.members.get(room_id, {}).items()
- if v.membership == membership
- ]
- else:
- return self.members.get(room_id, {}).values()
-
- def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
- return [
- m[user_id] for m in self.members.values()
- if user_id in m and m[user_id].membership in membership_list
- ]
-
- def get_room_events_stream(self, user_id=None, from_key=None, to_key=None,
- limit=0, with_feedback=False):
- return ([], from_key) # TODO
-
- def get_joined_hosts_for_room(self, room_id):
- return defer.succeed([])
-
- def persist_event(self, event):
- if event.type == EventTypes.Member:
- room_id = event.room_id
- user = event.state_key
- self.members.setdefault(room_id, {})[user] = event
-
- if hasattr(event, "state_key"):
- key = (event.room_id, event.type, event.state_key)
- self.current_state[key] = event
-
- self.events.append(event)
-
- def get_current_state(self, room_id, event_type=None, state_key=""):
- if event_type:
- key = (room_id, event_type, state_key)
- if self.current_state.get(key):
- return [self.current_state.get(key)]
- return None
- else:
- return [
- e for e in self.current_state
- if e[0] == room_id
- ]
-
- def set_presence_state(self, user_localpart, state):
- return defer.succeed({"state": 0})
-
- def get_presence_list(self, user_localpart, accepted):
- return []
-
- def get_room_events_max_id(self):
- return "s0" # TODO (erikj)
-
- def get_send_event_level(self, room_id):
- return defer.succeed(0)
-
- def get_power_level(self, room_id, user_id):
- return defer.succeed(0)
-
- def get_add_state_level(self, room_id):
- return defer.succeed(0)
-
- def get_room_join_rule(self, room_id):
- # TODO (erikj): This should be configurable
- return defer.succeed("invite")
-
- def get_ops_levels(self, room_id):
- return defer.succeed((5, 5, 5))
-
- def insert_client_ip(self, user, access_token, ip, user_agent):
- return defer.succeed(None)
-
-
def _format_call(args, kwargs):
return ", ".join(
["%r" % (a) for a in args] +
|