diff --git a/CHANGES.rst b/CHANGES.rst
index a415944756..6291fedb9a 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,43 @@
+Changes in synapse v0.23.0-rc2 (2017-09-26)
+===========================================
+
+Bug fixes:
+
+* Fix regression in performance of syncs (PR #2470)
+
+
+Changes in synapse v0.23.0-rc1 (2017-09-25)
+===========================================
+
+Features:
+
+* Add a frontend proxy worker (PR #2344)
+* Add support for event_id_only push format (PR #2450)
+* Add a PoC for filtering spammy events (PR #2456)
+* Add a config option to block all room invites (PR #2457)
+
+
+Changes:
+
+* Use bcrypt module instead of py-bcrypt (PR #2288) Thanks to @kyrias!
+* Improve performance of generating push notifications (PR #2343, #2357, #2365,
+ #2366, #2371)
+* Improve DB performance for device list handling in sync (PR #2362)
+* Include a sample prometheus config (PR #2416)
+* Document which postgres versions are known to work (PR #2433) Thanks to @ptman!
+
+
+Bug fixes:
+
+* Fix caching error in the push evaluator (PR #2332)
+* Fix bug where pusherpool didn't start and broke some rooms (PR #2342)
+* Fix port script for user directory tables (PR #2375)
+* Fix device lists notifications when user rejoins a room (PR #2443, #2449)
+* Fix sync to always send down current state events in timeline (PR #2451)
+* Fix bug where guest users were incorrectly kicked (PR #2453)
+* Fix bug when talking to IPv6-only servers using SRV records (PR #2462)
+
+
Changes in synapse v0.22.1 (2017-07-06)
=======================================
diff --git a/README.rst b/README.rst
index 4491b45181..9da8c7f7a8 100644
--- a/README.rst
+++ b/README.rst
@@ -200,11 +200,11 @@ different. See `the spec`__ for more information on key management.)
.. __: `key_management`_
The default configuration exposes two HTTP ports: 8008 and 8448. Port 8008 is
-configured without TLS; it is not recommended this be exposed outside your
-local network. Port 8448 is configured to use TLS with a self-signed
-certificate. This is fine for testing with but, to avoid your clients
-complaining about the certificate, you will almost certainly want to use
-another certificate for production purposes. (Note that a self-signed
+configured without TLS; it should sit behind a reverse proxy that performs
+TLS/SSL termination on port 443, which clients should then use. Port 8448
+is configured to use TLS with a self-signed certificate. If you would like
+to do some initial testing with a client without setting up a reverse proxy,
+you can temporarily use another certificate. (Note that a self-signed
certificate is fine for `Federation`_). You can do so by changing
``tls_certificate_path``, ``tls_private_key_path`` and ``tls_dh_params_path``
in ``homeserver.yaml``; alternatively, you can use a reverse-proxy, but be sure
@@ -283,10 +283,16 @@ Connecting to Synapse from a client
The easiest way to try out your new Synapse installation is by connecting to it
from a web client. The easiest option is probably the one at
http://riot.im/app. You will need to specify a "Custom server" when you log on
-or register: set this to ``https://localhost:8448`` - remember to specify the
-port (``:8448``) unless you changed the configuration. (Leave the identity
+or register: set this to ``https://domain.tld`` if you set up a reverse proxy
+following the recommended setup, or ``https://localhost:8448`` otherwise -
+remember to specify the port (``:8448``, unless you changed it) when it is
+not ``:443``. (Leave the identity
server as the default - see `Identity servers`_.)
+If using port 8448 you will run into errors until you accept the self-signed
+certificate. You can do this by going to ``https://localhost:8448`` directly
+in your browser and accepting the presented certificate; you can then go back
+to your web client and proceed.
+
If all goes well you should at least be able to log in, create a room, and
start sending messages.
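
As a quick smoke test before pointing a client at the server, you can query the
client API's versions endpoint. This is a minimal sketch, not part of synapse's
docs; it assumes the Python ``requests`` library, and ``verify=False`` is only
acceptable here because port 8448 is deliberately serving the self-signed
certificate discussed above:

```python
import requests

# Smoke-test a fresh synapse install. verify=False skips certificate
# validation, which is only OK for the self-signed cert on port 8448.
resp = requests.get(
    "https://localhost:8448/_matrix/client/versions", verify=False,
)
print(resp.status_code, resp.json())
```
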
@@ -593,8 +599,9 @@ you to run your server on a machine that might not have the same name as your
domain name. For example, you might want to run your server at
``synapse.example.com``, but have your Matrix user-ids look like
``@user:example.com``. (A SRV record also allows you to change the port from
-the default 8448. However, if you are thinking of using a reverse-proxy, be
-sure to read `Reverse-proxying the federation port`_ first.)
+the default 8448. However, if you are thinking of using a reverse-proxy on the
+federation port, which is not recommended, be sure to read
+`Reverse-proxying the federation port`_ first.)
To use a SRV record, first create your SRV record and publish it in DNS. This
should have the format ``_matrix._tcp.<yourdomain.com> <ttl> IN SRV 10 0 <port>
@@ -674,7 +681,7 @@ For information on how to install and use PostgreSQL, please see
Using a reverse proxy with Synapse
==================================
-It is possible to put a reverse proxy such as
+It is recommended to put a reverse proxy such as
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_ or
`HAProxy <http://www.haproxy.org/>`_ in front of Synapse. One advantage of
@@ -692,9 +699,9 @@ federation port has a number of pitfalls. It is possible, but be sure to read
`Reverse-proxying the federation port`_.
The recommended setup is therefore to configure your reverse-proxy on port 443
-for client connections, but to also expose port 8448 for server-server
-connections. All the Matrix endpoints begin ``/_matrix``, so an example nginx
-configuration might look like::
+to proxy client connections to port 8008 of synapse, but to also expose port
+8448 directly for server-server connections. All the Matrix endpoints begin
+``/_matrix``, so an example nginx configuration might look like::
server {
listen 443 ssl;
diff --git a/synapse/__init__.py b/synapse/__init__.py
index dbf22eca00..ec83e6adb7 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.22.1"
+__version__ = "0.23.0-rc2"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index e3da45b416..72858cca1f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -519,6 +519,14 @@ class Auth(object):
)
def is_server_admin(self, user):
+ """ Check if the given user is a local server admin.
+
+ Args:
+ user (str): mxid of user to check
+
+ Returns:
+ bool: True if the user is an admin
+ """
return self.store.is_server_admin(user)
@defer.inlineCallbacks
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index b22cacf8dc..3f9d9d5f8b 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -34,6 +34,7 @@ from .password_auth_providers import PasswordAuthProviderConfig
from .emailconfig import EmailConfig
from .workers import WorkerConfig
from .push import PushConfig
+from .spam_checker import SpamCheckerConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
@@ -41,7 +42,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
JWTConfig, PasswordConfig, EmailConfig,
- WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
+ WorkerConfig, PasswordAuthProviderConfig, PushConfig,
+ SpamCheckerConfig,):
pass
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 83762d089a..90824cab7f 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -15,13 +15,15 @@
from ._base import Config, ConfigError
-import importlib
+from synapse.util.module_loader import load_module
class PasswordAuthProviderConfig(Config):
def read_config(self, config):
self.password_providers = []
+ provider_config = None
+
# We want to be backwards compatible with the old `ldap_config`
# param.
ldap_config = config.get("ldap_config", {})
@@ -38,19 +40,15 @@ class PasswordAuthProviderConfig(Config):
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
from ldap_auth_provider import LdapAuthProvider
provider_class = LdapAuthProvider
+ try:
+ provider_config = provider_class.parse_config(provider["config"])
+ except Exception as e:
+ raise ConfigError(
+ "Failed to parse config for %r: %r" % (provider['module'], e)
+ )
else:
- # We need to import the module, and then pick the class out of
- # that, so we split based on the last dot.
- module, clz = provider['module'].rsplit(".", 1)
- module = importlib.import_module(module)
- provider_class = getattr(module, clz)
+ (provider_class, provider_config) = load_module(provider)
- try:
- provider_config = provider_class.parse_config(provider["config"])
- except Exception as e:
- raise ConfigError(
- "Failed to parse config for %r: %r" % (provider['module'], e)
- )
self.password_providers.append((provider_class, provider_config))
def default_config(self, **kwargs):
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 89d61a0503..c9a1715f1f 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -43,6 +43,12 @@ class ServerConfig(Config):
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
+ # Whether we should block invites sent to users on this server
+ # (other than those sent by local server admins)
+ self.block_non_admin_invites = config.get(
+ "block_non_admin_invites", False,
+ )
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != '/':
self.public_baseurl += '/'
@@ -194,6 +200,10 @@ class ServerConfig(Config):
# and sync operations. The default value is -1, means no upper limit.
# filter_timeline_limit: 5000
+ # Whether room invites to users on this server should be blocked
+ # (except those sent by local server admins). The default is False.
+ # block_non_admin_invites: True
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
new file mode 100644
index 0000000000..3fec42bdb0
--- /dev/null
+++ b/synapse/config/spam_checker.py
@@ -0,0 +1,35 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.module_loader import load_module
+
+from ._base import Config
+
+
+class SpamCheckerConfig(Config):
+ def read_config(self, config):
+ self.spam_checker = None
+
+ provider = config.get("spam_checker", None)
+ if provider is not None:
+ self.spam_checker = load_module(provider)
+
+ def default_config(self, **kwargs):
+ return """\
+ # spam_checker:
+ # module: "my_custom_project.SuperSpamChecker"
+ # config:
+ # example_option: 'things'
+ """
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index c2bd64d6c2..f1fd488b90 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -13,14 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
+from synapse.util import logcontext
from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint
-from synapse.util.logcontext import (
- preserve_context_over_fn, preserve_context_over_deferred
-)
import simplejson as json
import logging
@@ -43,14 +40,10 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5):
try:
- protocol = yield preserve_context_over_fn(
- endpoint.connect, factory
- )
- server_response, server_certificate = yield preserve_context_over_deferred(
- protocol.remote_key
- )
- defer.returnValue((server_response, server_certificate))
- return
+ with logcontext.PreserveLoggingContext():
+ protocol = yield endpoint.connect(factory)
+ server_response, server_certificate = yield protocol.remote_key
+ defer.returnValue((server_response, server_certificate))
except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 51851d04e5..054bac456d 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -18,7 +18,7 @@ from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes
from synapse.util import unwrapFirstError, logcontext
from synapse.util.logcontext import (
- preserve_context_over_fn, PreserveLoggingContext,
+ PreserveLoggingContext,
preserve_fn
)
from synapse.util.metrics import Measure
@@ -57,7 +57,8 @@ Attributes:
json_object(dict): The JSON object to verify.
deferred(twisted.internet.defer.Deferred):
A deferred (server_name, key_id, verify_key) tuple that resolves when
- a verify key has been fetched
+ a verify key has been fetched. The deferreds' callbacks are run with no
+ logcontext.
"""
@@ -82,9 +83,11 @@ class Keyring(object):
self.key_downloads = {}
def verify_json_for_server(self, server_name, json_object):
- return self.verify_json_objects_for_server(
- [(server_name, json_object)]
- )[0]
+ return logcontext.make_deferred_yieldable(
+ self.verify_json_objects_for_server(
+ [(server_name, json_object)]
+ )[0]
+ )
def verify_json_objects_for_server(self, server_and_json):
"""Bulk verifies signatures of json objects, bulk fetching keys as
@@ -94,8 +97,10 @@ class Keyring(object):
server_and_json (list): List of pairs of (server_name, json_object)
Returns:
- list of deferreds indicating success or failure to verify each
- json object's signature for the given server_name.
+ List<Deferred>: for each input pair, a deferred indicating success
+ or failure to verify each json object's signature for the given
+ server_name. The deferreds run their callbacks in the sentinel
+ logcontext.
"""
verify_requests = []
@@ -122,93 +127,71 @@ class Keyring(object):
verify_requests.append(verify_request)
- @defer.inlineCallbacks
- def handle_key_deferred(verify_request):
- server_name = verify_request.server_name
- try:
- _, key_id, verify_key = yield verify_request.deferred
- except IOError as e:
- logger.warn(
- "Got IOError when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.exception(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, key_ids),
- Codes.UNAUTHORIZED,
- )
+ preserve_fn(self._start_key_lookups)(verify_requests)
- json_object = verify_request.json_object
+ # Pass those keys to handle_key_deferred so that the json object
+ # signatures can be verified
+ handle = preserve_fn(_handle_key_deferred)
+ return [
+ handle(rq) for rq in verify_requests
+ ]
- logger.debug("Got key %s %s:%s for server %s, verifying" % (
- key_id, verify_key.alg, verify_key.version, server_name,
- ))
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except:
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s" % (
- server_name, verify_key.alg, verify_key.version
- ),
- Codes.UNAUTHORIZED,
- )
+ @defer.inlineCallbacks
+ def _start_key_lookups(self, verify_requests):
+ """Sets off the key fetches for each verify request
+
+ Once each fetch completes, verify_request.deferred will be resolved.
+
+ Args:
+ verify_requests (List[VerifyKeyRequest]):
+ """
+ # create a deferred for each server we're going to look up the keys
+ # for; we'll resolve them once we have completed our lookups.
+ # These will be passed into wait_for_previous_lookups to block
+ # any other lookups until we have finished.
+ # The deferreds are called with no logcontext.
server_to_deferred = {
- server_name: defer.Deferred()
- for server_name, _ in server_and_json
+ rq.server_name: defer.Deferred()
+ for rq in verify_requests
}
- with PreserveLoggingContext():
+ # We want to wait for any previous lookups to complete before
+ # proceeding.
+ yield self.wait_for_previous_lookups(
+ [rq.server_name for rq in verify_requests],
+ server_to_deferred,
+ )
- # We want to wait for any previous lookups to complete before
- # proceeding.
- wait_on_deferred = self.wait_for_previous_lookups(
- [server_name for server_name, _ in server_and_json],
- server_to_deferred,
- )
+ # Actually start fetching keys.
+ self._get_server_verify_keys(verify_requests)
- # Actually start fetching keys.
- wait_on_deferred.addBoth(
- lambda _: self.get_server_verify_keys(verify_requests)
- )
+ # When we've finished fetching all the keys for a given server_name,
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
+ # any lookups waiting will proceed.
+ #
+ # map from server name to a set of request ids
+ server_to_request_ids = {}
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- server_to_request_ids = {}
-
- def remove_deferreds(res, server_name, verify_request):
- request_id = id(verify_request)
- server_to_request_ids[server_name].discard(request_id)
- if not server_to_request_ids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
- return res
-
- for verify_request in verify_requests:
- server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids.setdefault(server_name, set()).add(request_id)
- deferred.addBoth(remove_deferreds, server_name, verify_request)
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids.setdefault(server_name, set()).add(request_id)
- # Pass those keys to handle_key_deferred so that the json object
- # signatures can be verified
- return [
- preserve_context_over_fn(handle_key_deferred, verify_request)
- for verify_request in verify_requests
- ]
+ def remove_deferreds(res, verify_request):
+ server_name = verify_request.server_name
+ request_id = id(verify_request)
+ server_to_request_ids[server_name].discard(request_id)
+ if not server_to_request_ids[server_name]:
+ d = server_to_deferred.pop(server_name, None)
+ if d:
+ d.callback(None)
+ return res
+
+ for verify_request in verify_requests:
+ verify_request.deferred.addBoth(
+ remove_deferreds, verify_request,
+ )
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred):
@@ -245,7 +228,7 @@ class Keyring(object):
self.key_downloads[server_name] = deferred
deferred.addBoth(rm, server_name)
- def get_server_verify_keys(self, verify_requests):
+ def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
For each verify_request, verify_request.deferred is called back with
@@ -314,21 +297,23 @@ class Keyring(object):
if not missing_keys:
break
- for verify_request in requests_missing_keys.values():
- verify_request.deferred.errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- verify_request.server_name, verify_request.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ with PreserveLoggingContext():
+ for verify_request in requests_missing_keys:
+ verify_request.deferred.errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ verify_request.server_name, verify_request.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
def on_err(err):
- for verify_request in verify_requests:
- if not verify_request.deferred.called:
- verify_request.deferred.errback(err)
+ with PreserveLoggingContext():
+ for verify_request in verify_requests:
+ if not verify_request.deferred.called:
+ verify_request.deferred.errback(err)
- do_iterations().addErrback(on_err)
+ preserve_fn(do_iterations)().addErrback(on_err)
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
@@ -738,3 +723,47 @@ class Keyring(object):
],
consumeErrors=True,
).addErrback(unwrapFirstError))
+
+
+@defer.inlineCallbacks
+def _handle_key_deferred(verify_request):
+ server_name = verify_request.server_name
+ try:
+ with PreserveLoggingContext():
+ _, key_id, verify_key = yield verify_request.deferred
+ except IOError as e:
+ logger.warn(
+ "Got IOError when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 502,
+ "Error downloading keys for %s" % (server_name,),
+ Codes.UNAUTHORIZED,
+ )
+ except Exception as e:
+ logger.exception(
+ "Got Exception when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 401,
+ "No key for %s with id %s" % (server_name, verify_request.key_ids),
+ Codes.UNAUTHORIZED,
+ )
+
+ json_object = verify_request.json_object
+
+ logger.debug("Got key %s %s:%s for server %s, verifying" % (
+ key_id, verify_key.alg, verify_key.version, server_name,
+ ))
+ try:
+ verify_signed_json(json_object, server_name, verify_key)
+ except:
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s" % (
+ server_name, verify_key.alg, verify_key.version
+ ),
+ Codes.UNAUTHORIZED,
+ )
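
A note on the pattern these keyring changes standardise on: the deferreds
returned by `verify_json_objects_for_server` run their callbacks in the
sentinel logcontext, so callers must wrap them before yielding. A minimal
sketch of the caller side, using the real `synapse.util.logcontext` helpers
(`check_sig` itself is a hypothetical example function):

```python
from twisted.internet import defer
from synapse.util import logcontext

@defer.inlineCallbacks
def check_sig(keyring, server_name, json_object):
    # verify_json_objects_for_server returns deferreds which run their
    # callbacks with no logcontext; make_deferred_yieldable restores
    # ours when the deferred fires, so it is safe to yield on.
    d = keyring.verify_json_objects_for_server(
        [(server_name, json_object)]
    )[0]
    yield logcontext.make_deferred_yieldable(d)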
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
new file mode 100644
index 0000000000..e739f105b2
--- /dev/null
+++ b/synapse/events/spamcheck.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class SpamChecker(object):
+ def __init__(self, hs):
+ self.spam_checker = None
+
+ module = None
+ config = None
+ try:
+ module, config = hs.config.spam_checker
+ except:
+ pass
+
+ if module is not None:
+ self.spam_checker = module(config=config)
+
+ def check_event_for_spam(self, event):
+ """Checks if a given event is considered "spammy" by this server.
+
+ If the server considers an event spammy, then it will be rejected if
+ sent by a local user. If it is sent by a user on another server, then
+ users receive a blank event.
+
+ Args:
+ event (synapse.events.EventBase): the event to be checked
+
+ Returns:
+ bool: True if the event is spammy.
+ """
+ if self.spam_checker is None:
+ return False
+
+ return self.spam_checker.check_event_for_spam(event)
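
For anyone writing a module to plug in via the new `spam_checker` config
option: judging by `SpamChecker` above, and by `load_module`'s use of
`parse_config` (see the password-provider refactor in this PR), a checker
needs a constructor taking `config=`, a `parse_config` method, and
`check_event_for_spam`. A hedged sketch, with invented names:

```python
class ExampleSpamChecker(object):
    """A hypothetical spam-checker module; not part of synapse itself."""

    def __init__(self, config):
        self.blocked_word = config.get("blocked_word", "")

    @staticmethod
    def parse_config(config):
        # Called by load_module at startup: validate/normalise the
        # "config" section of the spam_checker option here.
        return config

    def check_event_for_spam(self, event):
        # Returning True makes synapse reject the event if it was sent
        # by a local user, or blank it if received over federation.
        body = event.content.get("body", "")
        return bool(self.blocked_word) and self.blocked_word in body
```
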
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2339cc9034..a0f5d40eb3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -12,28 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.events.utils import prune_event
-
-from synapse.crypto.event_signing import check_event_content_hash
-
-from synapse.api.errors import SynapseError
-
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-
import logging
+from synapse.api.errors import SynapseError
+from synapse.crypto.event_signing import check_event_content_hash
+from synapse.events.utils import prune_event
+from synapse.util import unwrapFirstError, logcontext
+from twisted.internet import defer
logger = logging.getLogger(__name__)
class FederationBase(object):
def __init__(self, hs):
- pass
+ self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -57,56 +49,52 @@ class FederationBase(object):
"""
deferreds = self._check_sigs_and_hashes(pdus)
- def callback(pdu):
- return pdu
+ @defer.inlineCallbacks
+ def handle_check_result(pdu, deferred):
+ try:
+ res = yield logcontext.make_deferred_yieldable(deferred)
+ except SynapseError:
+ res = None
- def errback(failure, pdu):
- failure.trap(SynapseError)
- return None
-
- def try_local_db(res, pdu):
if not res:
# Check local db.
- return self.store.get_event(
+ res = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
- return res
- def try_remote(res, pdu):
if not res and pdu.origin != origin:
- return self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- outlier=outlier,
- timeout=10000,
- ).addErrback(lambda e: None)
- return res
-
- def warn(res, pdu):
+ try:
+ res = yield self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ outlier=outlier,
+ timeout=10000,
+ )
+ except SynapseError:
+ pass
+
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
- return res
- for pdu, deferred in zip(pdus, deferreds):
- deferred.addCallbacks(
- callback, errback, errbackArgs=[pdu]
- ).addCallback(
- try_local_db, pdu
- ).addCallback(
- try_remote, pdu
- ).addCallback(
- warn, pdu
- )
+ defer.returnValue(res)
- valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
- deferreds,
- consumeErrors=True
- )).addErrback(unwrapFirstError)
+ handle = logcontext.preserve_fn(handle_check_result)
+ deferreds2 = [
+ handle(pdu, deferred)
+ for pdu, deferred in zip(pdus, deferreds)
+ ]
+
+ valid_pdus = yield logcontext.make_deferred_yieldable(
+ defer.gatherResults(
+ deferreds2,
+ consumeErrors=True,
+ )
+ ).addErrback(unwrapFirstError)
if include_none:
defer.returnValue(valid_pdus)
@@ -114,15 +102,24 @@ class FederationBase(object):
defer.returnValue([p for p in valid_pdus if p])
def _check_sigs_and_hash(self, pdu):
- return self._check_sigs_and_hashes([pdu])[0]
+ return logcontext.make_deferred_yieldable(
+ self._check_sigs_and_hashes([pdu])[0],
+ )
def _check_sigs_and_hashes(self, pdus):
- """Throws a SynapseError if a PDU does not have the correct
- signatures.
+ """Checks that each of the received events is correctly signed by the
+ sending server.
+
+ Args:
+ pdus (list[FrozenEvent]): the events to be checked
Returns:
- FrozenEvent: Either the given event or it redacted if it failed the
- content hash check.
+ list[Deferred]: for each input event, a deferred which:
+ * returns the original event if the checks pass
+ * returns a redacted version of the event (if the signature
+ matched but the hash did not)
+ * throws a SynapseError if the signature check failed.
+ The deferreds run their callbacks in the sentinel logcontext.
"""
redacted_pdus = [
@@ -130,26 +127,38 @@ class FederationBase(object):
for pdu in pdus
]
- deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+ deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
+ ctx = logcontext.LoggingContext.current_context()
+
def callback(_, pdu, redacted):
- if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting %s: %s",
- pdu.event_id, pdu.get_pdu_json()
- )
- return redacted
- return pdu
+ with logcontext.PreserveLoggingContext(ctx):
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ if self.spam_checker.check_event_for_spam(pdu):
+ logger.warn(
+ "Event contains spam, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+
+ return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
- logger.warn(
- "Signature check failed for %s",
- pdu.event_id,
- )
+ with logcontext.PreserveLoggingContext(ctx):
+ logger.warn(
+ "Signature check failed for %s",
+ pdu.event_id,
+ )
return failure
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 861441708b..7c5e5d957f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -22,7 +22,7 @@ from synapse.api.constants import Membership
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@@ -189,10 +189,10 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+ pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
- )).addErrback(unwrapFirstError)
+ ).addErrback(unwrapFirstError))
defer.returnValue(pdus)
@@ -252,7 +252,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+ signed_pdu = yield self._check_sigs_and_hash(pdu)
break
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4669199b2d..18f87cad67 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1074,6 +1074,9 @@ class FederationHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
+ if self.hs.config.block_non_admin_invites:
+ raise SynapseError(403, "This server does not accept room invites")
+
membership = event.content.get("membership")
if event.type != EventTypes.Member or membership != Membership.INVITE:
raise SynapseError(400, "The event was not an m.room.member invite event")
@@ -2090,6 +2093,14 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
+ """Handle an exchange_third_party_invite request from a remote server
+
+ The remote server will call this when it wants to turn a 3pid invite
+ into a normal m.room.member invite.
+
+ Returns:
+ Deferred: resolves (to None)
+ """
builder = self.event_builder_factory.new(event_dict)
message_handler = self.hs.get_handlers().message_handler
@@ -2108,9 +2119,12 @@ class FederationHandler(BaseHandler):
raise e
yield self._check_signature(event, context)
+ # XXX we send the invite here, but send_membership_event also sends it,
+ # so we end up making two requests. I think this is redundant.
returned_invite = yield self.send_invite(origin, event)
# TODO: Make sure the signatures actually are correct.
event.signatures.update(returned_invite.signatures)
+
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 5b8f20b73c..e22d4803b9 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
@@ -59,6 +58,8 @@ class MessageHandler(BaseHandler):
self.action_generator = hs.get_action_generator()
+ self.spam_checker = hs.get_spam_checker()
+
@defer.inlineCallbacks
def purge_history(self, room_id, event_id):
event = yield self.store.get_event(event_id)
@@ -322,6 +323,12 @@ class MessageHandler(BaseHandler):
token_id=requester.access_token_id,
txn_id=txn_id
)
+
+ if self.spam_checker.check_event_for_spam(event):
+ raise SynapseError(
+ 403, "Spam is not permitted here", Codes.FORBIDDEN
+ )
+
yield self.send_nonmember_event(
requester,
event,
@@ -413,6 +420,51 @@ class MessageHandler(BaseHandler):
[serialize_event(c, now) for c in room_state.values()]
)
+ @defer.inlineCallbacks
+ def get_joined_members(self, requester, room_id):
+ """Get all the joined members in the room and their profile information.
+
+        If the requester has left the room, a NotImplementedError is raised,
+        rather than returning their membership from the point they left.
+
+        Args:
+            requester(Requester): The user requesting the list of members.
+            room_id(str): The room ID to get the joined members of.
+ Returns:
+ A dict of user_id to profile info
+ """
+ user_id = requester.user.to_string()
+ if not requester.app_service:
+ # We check AS auth after fetching the room membership, as it
+ # requires us to pull out all joined members anyway.
+ membership, _ = yield self._check_in_room_or_world_readable(
+ room_id, user_id
+ )
+ if membership != Membership.JOIN:
+ raise NotImplementedError(
+ "Getting joined members after leaving is not implemented"
+ )
+
+ users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+ # If this is an AS, double check that they are allowed to see the members.
+        # This can either be because the AS user is in the room or because there
+        # is a user in the room that the AS is "interested in".
+ if requester.app_service and user_id not in users_with_profile:
+ for uid in users_with_profile:
+ if requester.app_service.is_interested_in_user(uid):
+ break
+ else:
+ # Loop fell through, AS has no interested users in room
+ raise AuthError(403, "Appservice not in room")
+
+ defer.returnValue({
+ user_id: {
+ "avatar_url": profile.avatar_url,
+ "display_name": profile.display_name,
+ }
+ for user_id, profile in users_with_profile.iteritems()
+ })
+
@measure_func("_create_new_client_event")
@defer.inlineCallbacks
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index dadc19d45b..d6ad57171c 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -193,6 +193,8 @@ class RoomMemberHandler(BaseHandler):
if action in ["kick", "unban"]:
effective_membership_state = "leave"
+ # if this is a join with a 3pid signature, we may need to turn a 3pid
+ # invite into a normal invite before we can handle the join.
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
@@ -210,6 +212,16 @@ class RoomMemberHandler(BaseHandler):
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
+ if (effective_membership_state == "invite" and
+ self.hs.config.block_non_admin_invites):
+ is_requester_admin = yield self.auth.is_server_admin(
+ requester.user,
+ )
+ if not is_requester_admin:
+ raise SynapseError(
+ 403, "Invites have been disabled on this server",
+ )
+
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state_ids = yield self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids,
@@ -473,6 +485,16 @@ class RoomMemberHandler(BaseHandler):
requester,
txn_id
):
+ if self.hs.config.block_non_admin_invites:
+ is_requester_admin = yield self.auth.is_server_admin(
+ requester.user,
+ )
+ if not is_requester_admin:
+ raise SynapseError(
+ 403, "Invites have been disabled on this server",
+ Codes.FORBIDDEN,
+ )
+
invitee = yield self._lookup_3pid(
id_server, medium, address
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 69c1bc189e..219529936f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -306,11 +306,6 @@ class SyncHandler(object):
timeline_limit = sync_config.filter_collection.timeline_limit()
block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline()
- # Pull out the current state, as we always want to include those events
- # in the timeline if they're there.
- current_state_ids = yield self.state.get_current_state_ids(room_id)
- current_state_ids = frozenset(current_state_ids.itervalues())
-
if recents is None or newly_joined_room or timeline_limit < len(recents):
limited = True
else:
@@ -318,6 +313,15 @@ class SyncHandler(object):
if recents:
recents = sync_config.filter_collection.filter_room_timeline(recents)
+
+            # If the timeline contains any state events, pass the full set of
+            # current state event ids to the filter_events function, to ensure
+            # that the current state is always included in the timeline.
+ current_state_ids = frozenset()
+ if any(e.is_state() for e in recents):
+ current_state_ids = yield self.state.get_current_state_ids(room_id)
+ current_state_ids = frozenset(current_state_ids.itervalues())
+
recents = yield filter_events_for_client(
self.store,
sync_config.user.to_string(),
@@ -354,6 +358,15 @@ class SyncHandler(object):
loaded_recents = sync_config.filter_collection.filter_room_timeline(
events
)
+
+            # If the timeline contains any state events, pass the full set of
+            # current state event ids to the filter_events function, to ensure
+            # that the current state is always included in the timeline.
+ current_state_ids = frozenset()
+ if any(e.is_state() for e in loaded_recents):
+ current_state_ids = yield self.state.get_current_state_ids(room_id)
+ current_state_ids = frozenset(current_state_ids.itervalues())
+
loaded_recents = yield filter_events_for_client(
self.store,
sync_config.user.to_string(),
@@ -1042,7 +1055,18 @@ class SyncHandler(object):
# We want to figure out if we joined the room at some point since
# the last sync (even if we have since left). This is to make sure
# we do send down the room, and with full state, where necessary
+
old_state_ids = None
+ if room_id in joined_room_ids and non_joins:
+                # Always include the room if the user (re)joined it: this is
+                # especially important so that device list changes are calculated
+                # correctly. If there are non-join member events but we are still
+                # in the room, the user must have left and rejoined.
+ newly_joined_rooms.append(room_id)
+
+ # User is in the room so we don't need to do the invite/leave checks
+ continue
+
if room_id in joined_room_ids or has_join:
old_state_ids = yield self.get_state_at(room_id, since_token)
old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
@@ -1054,8 +1078,9 @@ class SyncHandler(object):
if not old_mem_ev or old_mem_ev.membership != Membership.JOIN:
newly_joined_rooms.append(room_id)
- if room_id in joined_room_ids:
- continue
+ # If user is in the room then we don't need to do the invite/leave checks
+ if room_id in joined_room_ids:
+ continue
if not non_joins:
continue
diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py
index d8923c9abb..a97532162f 100644
--- a/synapse/http/endpoint.py
+++ b/synapse/http/endpoint.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import socket
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor
@@ -30,7 +31,10 @@ logger = logging.getLogger(__name__)
SERVER_CACHE = {}
-
+# our record of an individual server which can be tried to reach a destination.
+#
+# "host" is actually a dotted-quad or ipv6 address string. Except when there's
+# no SRV record, in which case it is the original hostname.
_Server = collections.namedtuple(
"_Server", "priority weight host port expires"
)
@@ -219,9 +223,10 @@ class SRVClientEndpoint(object):
return self.default_server
else:
raise ConnectError(
- "Not server available for %s" % self.service_name
+ "No server available for %s" % self.service_name
)
+ # look for all servers with the same priority
min_priority = self.servers[0].priority
weight_indexes = list(
(index, server.weight + 1)
@@ -231,11 +236,22 @@ class SRVClientEndpoint(object):
total_weight = sum(weight for index, weight in weight_indexes)
target_weight = random.randint(0, total_weight)
-
for index, weight in weight_indexes:
target_weight -= weight
if target_weight <= 0:
server = self.servers[index]
+ # XXX: this looks totally dubious:
+ #
+ # (a) we never reuse a server until we have been through
+ # all of the servers at the same priority, so if the
+ # weights are A: 100, B:1, we always do ABABAB instead of
+ # AAAA...AAAB (approximately).
+ #
+ # (b) After using all the servers at the lowest priority,
+ # we move onto the next priority. We should only use the
+ # second priority if servers at the top priority are
+ # unreachable.
+ #
del self.servers[index]
self.used_servers.append(server)
return server
@@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
continue
payload = answer.payload
- host = str(payload.target)
- srv_ttl = answer.ttl
- try:
- answers, _, _ = yield dns_client.lookupAddress(host)
- except DNSNameError:
- continue
+ hosts = yield _get_hosts_for_srv_record(
+ dns_client, str(payload.target)
+ )
- for answer in answers:
- if answer.type == dns.A and answer.payload:
- ip = answer.payload.dottedQuad()
- host_ttl = min(srv_ttl, answer.ttl)
+ for (ip, ttl) in hosts:
+ host_ttl = min(answer.ttl, ttl)
- servers.append(_Server(
- host=ip,
- port=int(payload.port),
- priority=int(payload.priority),
- weight=int(payload.weight),
- expires=int(clock.time()) + host_ttl,
- ))
+ servers.append(_Server(
+ host=ip,
+ port=int(payload.port),
+ priority=int(payload.priority),
+ weight=int(payload.weight),
+ expires=int(clock.time()) + host_ttl,
+ ))
servers.sort()
cache[service_name] = list(servers)
@@ -317,3 +328,80 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t
raise e
defer.returnValue(servers)
+
+
+@defer.inlineCallbacks
+def _get_hosts_for_srv_record(dns_client, host):
+ """Look up each of the hosts in a SRV record
+
+ Args:
+ dns_client (twisted.names.dns.IResolver):
+ host (basestring): host to look up
+
+ Returns:
+ Deferred[list[(str, int)]]: a list of (host, ttl) pairs
+
+ """
+ ip4_servers = []
+ ip6_servers = []
+
+ def cb(res):
+ # lookupAddress and lookupIP6Address return a three-tuple
+ # giving the answer, authority, and additional sections of the
+ # response.
+ #
+ # we only care about the answers.
+
+ return res[0]
+
+ def eb(res, record_type):
+ if res.check(DNSNameError):
+ return []
+        logger.warn("Error looking up %s for %s: %s",
+                    record_type, host, res.value)
+ return res
+
+ # no logcontexts here, so we can safely fire these off and gatherResults
+ d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
+ d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
+ results = yield defer.DeferredList(
+ [d1, d2], consumeErrors=True)
+
+ # if all of the lookups failed, raise an exception rather than blowing out
+ # the cache with an empty result.
+ if results and all(s == defer.FAILURE for (s, _) in results):
+ defer.returnValue(results[0][1])
+
+ for (success, result) in results:
+ if success == defer.FAILURE:
+ continue
+
+ for answer in result:
+ if not answer.payload:
+ continue
+
+ try:
+ if answer.type == dns.A:
+ ip = answer.payload.dottedQuad()
+ ip4_servers.append((ip, answer.ttl))
+ elif answer.type == dns.AAAA:
+ ip = socket.inet_ntop(
+ socket.AF_INET6, answer.payload.address,
+ )
+ ip6_servers.append((ip, answer.ttl))
+ else:
+ # the most likely candidate here is a CNAME record.
+ # rfc2782 says srvs may not point to aliases.
+ logger.warn(
+ "Ignoring unexpected DNS record type %s for %s",
+ answer.type, host,
+ )
+ continue
+ except Exception as e:
+ logger.warn("Ignoring invalid DNS response for %s: %s",
+ host, e)
+ continue
+
+ # keep the ipv4 results before the ipv6 results, mostly to match historical
+ # behaviour.
+ defer.returnValue(ip4_servers + ip6_servers)
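
For contrast with the XXX comment above: RFC 2782 selection would, on every
pick, choose among the servers at the lowest priority with probability
proportional to weight, rather than working through the whole priority band
before reusing anything. A rough sketch of that behaviour (a hypothetical
helper, not proposed code):

```python
import random

def pick_server_rfc2782(servers):
    # Only servers at the lowest (most preferred) priority are eligible.
    lowest = min(s.priority for s in servers)
    candidates = [s for s in servers if s.priority == lowest]

    # Choose with probability proportional to weight; the +1 keeps
    # zero-weight servers selectable, mirroring the code above.
    total = sum(s.weight + 1 for s in candidates)
    target = random.randint(1, total)
    for s in candidates:
        target -= s.weight + 1
        if target <= 0:
            return s
```
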
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 8b94e6f29f..8c8b7fa656 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -204,18 +204,15 @@ class MatrixFederationHttpClient(object):
raise
logger.warn(
- "{%s} Sending request failed to %s: %s %s: %s - %s",
+ "{%s} Sending request failed to %s: %s %s: %s",
txn_id,
destination,
method,
url_bytes,
- type(e).__name__,
_flatten_response_never_received(e),
)
- log_result = "%s - %s" % (
- type(e).__name__, _flatten_response_never_received(e),
- )
+ log_result = _flatten_response_never_received(e)
if retries_left and not timeout:
if long_retries:
@@ -618,12 +615,14 @@ class _JsonProducer(object):
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
- return ", ".join(
+ reasons = ", ".join(
_flatten_response_never_received(f.value)
for f in e.reasons
)
+
+ return "%s:[%s]" % (type(e).__name__, reasons)
else:
- return "%s: %s" % (type(e).__name__, e.message,)
+ return repr(e)
def check_content_type_is_json(headers):
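
To illustrate the new flattening (hypothetical values): an exception carrying
a `reasons` attribute is rendered as `OuterType:[...inner...]`, while a plain
exception is simply `repr()`ed:

```python
class Wrapper(Exception):
    pass

class FakeFailure(object):
    # Stands in for a twisted Failure; only .value is consulted.
    def __init__(self, value):
        self.value = value

e = Wrapper()
e.reasons = [FakeFailure(IOError("connection refused"))]
print(_flatten_response_never_received(e))
# e.g. Wrapper:[IOError('connection refused',)]
```
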
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index cd388770c8..6c379d53ac 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -398,22 +398,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
- self.state = hs.get_state_handler()
+ self.message_handler = hs.get_handlers().message_handler
@defer.inlineCallbacks
def on_GET(self, request, room_id):
- yield self.auth.get_user_by_req(request)
+ requester = yield self.auth.get_user_by_req(request)
- users_with_profile = yield self.state.get_current_user_in_room(room_id)
+ users_with_profile = yield self.message_handler.get_joined_members(
+ requester, room_id,
+ )
defer.returnValue((200, {
- "joined": {
- user_id: {
- "avatar_url": profile.avatar_url,
- "display_name": profile.display_name,
- }
- for user_id, profile in users_with_profile.iteritems()
- }
+ "joined": users_with_profile,
}))
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index d92b7ff337..d5cec10127 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -14,6 +14,9 @@
# limitations under the License.
import os
+import re
+
+NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
class MediaFilePaths(object):
@@ -73,19 +76,105 @@ class MediaFilePaths(object):
)
def url_cache_filepath(self, media_id):
- return os.path.join(
- self.base_path, "url_cache",
- media_id[0:2], media_id[2:4], media_id[4:]
- )
+ if NEW_FORMAT_ID_RE.match(media_id):
+            # Media id is of the form <DATE>_<RANDOM_STRING>
+            # E.g.: 2017-09-28_fsdRDt24DS234dsf
+ return os.path.join(
+ self.base_path, "url_cache",
+ media_id[:10], media_id[11:]
+ )
+ else:
+ return os.path.join(
+ self.base_path, "url_cache",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ )
+
+ def url_cache_filepath_dirs_to_delete(self, media_id):
+ "The dirs to try and remove if we delete the media_id file"
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return [
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[:10],
+ ),
+ ]
+ else:
+ return [
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[0:2], media_id[2:4],
+ ),
+ os.path.join(
+ self.base_path, "url_cache",
+ media_id[0:2],
+ ),
+ ]
def url_cache_thumbnail(self, media_id, width, height, content_type,
method):
+        # Media id is of the form <DATE>_<RANDOM_STRING>
+        # E.g.: 2017-09-28_fsdRDt24DS234dsf
+
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (
width, height, top_level_type, sub_type, method
)
- return os.path.join(
- self.base_path, "url_cache_thumbnails",
- media_id[0:2], media_id[2:4], media_id[4:],
- file_name
- )
+
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ file_name
+ )
+ else:
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ file_name
+ )
+
+ def url_cache_thumbnail_directory(self, media_id):
+        # Media id is of the form <DATE>_<RANDOM_STRING>
+        # E.g.: 2017-09-28_fsdRDt24DS234dsf
+
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ )
+ else:
+ return os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ )
+
+ def url_cache_thumbnail_dirs_to_delete(self, media_id):
+ "The dirs to try and remove if we delete the media_id thumbnails"
+        # Media id is of the form <DATE>_<RANDOM_STRING>
+        # E.g.: 2017-09-28_fsdRDt24DS234dsf
+ if NEW_FORMAT_ID_RE.match(media_id):
+ return [
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10], media_id[11:],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[:10],
+ ),
+ ]
+ else:
+ return [
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4], media_id[4:],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2], media_id[2:4],
+ ),
+ os.path.join(
+ self.base_path, "url_cache_thumbnails",
+ media_id[0:2],
+ ),
+ ]
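
To make the two naming schemes concrete (hypothetical media ids; this assumes
`MediaFilePaths` is constructed with the media base path, as its use of
`self.base_path` suggests):

```python
paths = MediaFilePaths("/var/lib/synapse/media")

# Old-format id: a plain random string, sharded by leading characters.
print(paths.url_cache_filepath("AbCdEfGhIj"))
# -> /var/lib/synapse/media/url_cache/Ab/Cd/EfGhIj

# New-format id: <date>_<random>, as now generated by
# preview_url_resource; the date prefix matches NEW_FORMAT_ID_RE,
# so it is sharded by date instead.
print(paths.url_cache_filepath("2017-09-28_fsdRDt24DS234dsf"))
# -> /var/lib/synapse/media/url_cache/2017-09-28/fsdRDt24DS234dsf
```
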
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b81a336c5d..895b480d5c 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -36,6 +36,9 @@ import cgi
import ujson as json
import urlparse
import itertools
+import datetime
+import errno
+import shutil
import logging
logger = logging.getLogger(__name__)
@@ -70,6 +73,10 @@ class PreviewUrlResource(Resource):
self.downloads = {}
+ self._cleaner_loop = self.clock.looping_call(
+ self._expire_url_cache_data, 10 * 1000
+ )
+
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@@ -130,7 +137,7 @@ class PreviewUrlResource(Resource):
cache_result = yield self.store.get_url_cache(url, ts)
if (
cache_result and
- cache_result["download_ts"] + cache_result["expires"] > ts and
+ cache_result["expires_ts"] > ts and
cache_result["response_code"] / 100 == 2
):
respond_with_json_bytes(
@@ -239,7 +246,7 @@ class PreviewUrlResource(Resource):
url,
media_info["response_code"],
media_info["etag"],
- media_info["expires"],
+ media_info["expires"] + media_info["created_ts"],
json.dumps(og),
media_info["filesystem_id"],
media_info["created_ts"],
@@ -253,8 +260,7 @@ class PreviewUrlResource(Resource):
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
- # XXX: horrible duplication with base_resource's _download_remote_file()
- file_id = random_string(24)
+ file_id = datetime.date.today().isoformat() + '_' + random_string(16)
fname = self.filepaths.url_cache_filepath(file_id)
self.media_repo._makedirs(fname)
@@ -328,6 +334,88 @@ class PreviewUrlResource(Resource):
"etag": headers["ETag"][0] if "ETag" in headers else None,
})
+ @defer.inlineCallbacks
+ def _expire_url_cache_data(self):
+ """Clean up expired url cache content, media and thumbnails.
+ """
+ now = self.clock.time_msec()
+
+ # First we delete expired url cache entries
+ media_ids = yield self.store.get_expired_url_cache(now)
+
+ removed_media = []
+ for media_id in media_ids:
+ fname = self.filepaths.url_cache_filepath(media_id)
+ try:
+ os.remove(fname)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ removed_media.append(media_id)
+
+ try:
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except:
+ pass
+
+ yield self.store.delete_url_cache(removed_media)
+
+ if removed_media:
+ logger.info("Deleted %d entries from url cache", len(removed_media))
+
+ # Now we delete old images associated with the url cache.
+ # These may be cached for a bit on the client (i.e., they
+ # may have a room open with a preview url thing open).
+ # So we wait a couple of days before deleting, just in case.
+ expire_before = now - 2 * 24 * 60 * 60 * 1000
+ media_ids = yield self.store.get_url_cache_media_before(expire_before)
+
+ removed_media = []
+ for media_id in media_ids:
+ fname = self.filepaths.url_cache_filepath(media_id)
+ try:
+ os.remove(fname)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ try:
+ dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except:
+ pass
+
+ thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
+ try:
+ shutil.rmtree(thumbnail_dir)
+ except OSError as e:
+ # If the path doesn't exist, meh
+ if e.errno != errno.ENOENT:
+ logger.warn("Failed to remove media: %r: %s", media_id, e)
+ continue
+
+ removed_media.append(media_id)
+
+ try:
+ dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
+ for dir in dirs:
+ os.rmdir(dir)
+ except:
+ pass
+
+ yield self.store.delete_url_cache_media(removed_media)
+
+ if removed_media:
+ logger.info("Deleted %d media from url cache", len(removed_media))
+
def decode_and_calc_og(body, media_uri, request_encoding=None):
from lxml import etree
diff --git a/synapse/server.py b/synapse/server.py
index 5b892cc390..10e3e9a4f1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
+from synapse.events.spamcheck import SpamChecker
from synapse.federation import initialize_http_replication
from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.transport.client import TransportLayerClient
@@ -148,6 +149,7 @@ class HomeServer(object):
'groups_server_handler',
'groups_attestation_signing',
'groups_attestation_renewer',
+ 'spam_checker',
]
def __init__(self, hostname, **kwargs):
@@ -333,6 +335,9 @@ class HomeServer(object):
def build_groups_attestation_renewer(self):
return GroupAttestionRenewer(self)
+ def build_spam_checker(self):
+ return SpamChecker(self)
+
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 3b5e0a4fb9..87aeaf71d6 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -113,30 +113,37 @@ class KeyStore(SQLBaseStore):
keys[key_id] = key
defer.returnValue(keys)
- @defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
Args:
server_name (str): The name of the server.
- key_id (str): The version of the key for the server.
from_server (str): Where the verification key was looked up
- ts_now_ms (int): The time now in milliseconds
- verification_key (VerifyKey): The NACL verify key.
+ time_now_ms (int): The time now in milliseconds
+ verify_key (nacl.signing.VerifyKey): The NACL verify key.
"""
- yield self._simple_upsert(
- table="server_signature_keys",
- keyvalues={
- "server_name": server_name,
- "key_id": "%s:%s" % (verify_key.alg, verify_key.version),
- },
- values={
- "from_server": from_server,
- "ts_added_ms": time_now_ms,
- "verify_key": buffer(verify_key.encode()),
- },
- desc="store_server_verify_key",
- )
+ key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+
+ def _txn(txn):
+ self._simple_upsert_txn(
+ txn,
+ table="server_signature_keys",
+ keyvalues={
+ "server_name": server_name,
+ "key_id": key_id,
+ },
+ values={
+ "from_server": from_server,
+ "ts_added_ms": time_now_ms,
+ "verify_key": buffer(verify_key.encode()),
+ },
+ )
+ txn.call_after(
+ self._get_server_verify_key.invalidate,
+ (server_name, key_id)
+ )
+
+ return self.runInteraction("store_server_verify_key", _txn)
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py
index 82bb61b811..7110a71279 100644
--- a/synapse/storage/media_repository.py
+++ b/synapse/storage/media_repository.py
@@ -62,7 +62,7 @@ class MediaRepositoryStore(SQLBaseStore):
def get_url_cache_txn(txn):
# get the most recently cached result (relative to the given ts)
sql = (
- "SELECT response_code, etag, expires, og, media_id, download_ts"
+ "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts <= ?"
" ORDER BY download_ts DESC LIMIT 1"
@@ -74,7 +74,7 @@ class MediaRepositoryStore(SQLBaseStore):
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
sql = (
- "SELECT response_code, etag, expires, og, media_id, download_ts"
+ "SELECT response_code, etag, expires_ts, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts > ?"
" ORDER BY download_ts ASC LIMIT 1"
@@ -86,14 +86,14 @@ class MediaRepositoryStore(SQLBaseStore):
return None
return dict(zip((
- 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
+ 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
), row))
return self.runInteraction(
"get_url_cache", get_url_cache_txn
)
- def store_url_cache(self, url, response_code, etag, expires, og, media_id,
+ def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
download_ts):
return self._simple_insert(
"local_media_repository_url_cache",
@@ -101,7 +101,7 @@ class MediaRepositoryStore(SQLBaseStore):
"url": url,
"response_code": response_code,
"etag": etag,
- "expires": expires,
+ "expires_ts": expires_ts,
"og": og,
"media_id": media_id,
"download_ts": download_ts,
@@ -238,3 +238,64 @@ class MediaRepositoryStore(SQLBaseStore):
},
)
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
+
+ def get_expired_url_cache(self, now_ts):
+ sql = (
+ "SELECT media_id FROM local_media_repository_url_cache"
+ " WHERE expires_ts < ?"
+ " ORDER BY expires_ts ASC"
+ " LIMIT 500"
+ )
+
+ def _get_expired_url_cache_txn(txn):
+ txn.execute(sql, (now_ts,))
+ return [row[0] for row in txn]
+
+ return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
+
+ def delete_url_cache(self, media_ids):
+ sql = (
+ "DELETE FROM local_media_repository_url_cache"
+ " WHERE media_id = ?"
+ )
+
+ def _delete_url_cache_txn(txn):
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
+
+ def get_url_cache_media_before(self, before_ts):
+ sql = (
+ "SELECT media_id FROM local_media_repository"
+ " WHERE created_ts < ? AND url_cache IS NOT NULL"
+ " ORDER BY created_ts ASC"
+ " LIMIT 500"
+ )
+
+ def _get_url_cache_media_before_txn(txn):
+ txn.execute(sql, (before_ts,))
+ return [row[0] for row in txn]
+
+ return self.runInteraction(
+ "get_url_cache_media_before", _get_url_cache_media_before_txn,
+ )
+
+ def delete_url_cache_media(self, media_ids):
+ def _delete_url_cache_media_txn(txn):
+ sql = (
+ "DELETE FROM local_media_repository"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ sql = (
+ "DELETE FROM local_media_repository_thumbnails"
+ " WHERE media_id = ?"
+ )
+
+ txn.executemany(sql, [(media_id,) for media_id in media_ids])
+
+ return self.runInteraction(
+ "delete_url_cache_media", _delete_url_cache_media_txn,
+ )
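
These helpers are shaped for a periodic sweep that works in bounded batches:
each query returns at most 500 media IDs, so a caller loops until a fetch
comes back empty. A schematic driver for the expired-entry case (the loop is
illustrative, not the actual code in the media repository; file deletion is
elided):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def expire_url_cache(store, now_ts):
        while True:
            media_ids = yield store.get_expired_url_cache(now_ts)
            if not media_ids:
                break  # nothing expired remains
            # ... delete the on-disk files for each media_id here ...
            yield store.delete_url_cache(media_ids)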
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 72b670b83b..a0af8456f5 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 43
+SCHEMA_VERSION = 44
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql
new file mode 100644
index 0000000000..e2b775f038
--- /dev/null
+++ b/synapse/storage/schema/delta/44/expire_url_cache.sql
@@ -0,0 +1,38 @@
+/* Copyright 2017 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
+
+-- We need to replace the relative `expires` duration with an absolute
+-- `expires_ts` timestamp so that we can index on it: SQLite doesn't support
+-- indices on expressions (such as `expires + download_ts`) until 3.9.
+CREATE TABLE local_media_repository_url_cache_new(
+ url TEXT,
+ response_code INTEGER,
+ etag TEXT,
+ expires_ts BIGINT,
+ og TEXT,
+ media_id TEXT,
+ download_ts BIGINT
+);
+
+INSERT INTO local_media_repository_url_cache_new
+ SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
+
+DROP TABLE local_media_repository_url_cache;
+ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
+
+CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
+CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
+CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);
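
The INSERT above turns the old relative expiry into an absolute timestamp,
on the assumption implied by the addition that `expires` held a duration in
milliseconds: expires_ts = download_ts + expires. A quick sanity check with
made-up values:

    download_ts = 1506000000000         # fetch time, in ms since the epoch
    expires = 60 * 60 * 1000            # old column: a one-hour lifetime
    expires_ts = download_ts + expires  # what the migration stores

    assert expires_ts == 1506003600000
    assert 1506003000000 < expires_ts   # 50 minutes in: still fresh
    assert 1506007200000 > expires_ts   # two hours in: expired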
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
new file mode 100644
index 0000000000..4288312b8a
--- /dev/null
+++ b/synapse/util/module_loader.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+
+from synapse.config._base import ConfigError
+
+
+def load_module(provider):
+    """Loads a module and its config.
+
+    Takes a dict with keys 'module' (the fully-qualified name of the
+    provider class) and 'config' (the config dict to parse).
+
+    Returns:
+        Tuple of (provider class, parsed config object)
+    """
+ # We need to import the module, and then pick the class out of
+ # that, so we split based on the last dot.
+ module, clz = provider['module'].rsplit(".", 1)
+ module = importlib.import_module(module)
+ provider_class = getattr(module, clz)
+
+ try:
+ provider_config = provider_class.parse_config(provider["config"])
+ except Exception as e:
+ raise ConfigError(
+ "Failed to parse config for %r: %r" % (provider['module'], e)
+ )
+
+ return provider_class, provider_config
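
A hypothetical use of load_module, in the shape the spam-checker config
consumes it (the module, class and config keys here are invented for
illustration; the only real contract is that the class exposes parse_config):

    # my_module.py: a provider class must expose parse_config()
    class ExampleSpamChecker(object):
        def __init__(self, config):
            self.config = config

        @staticmethod
        def parse_config(config):
            if "blocked_words" not in config:
                raise Exception("blocked_words is required")
            return config

    # loading it from a config dict:
    provider = {
        "module": "my_module.ExampleSpamChecker",
        "config": {"blocked_words": ["spam"]},
    }
    checker_class, checker_config = load_module(provider)
    checker = checker_class(checker_config)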
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index da2c9e44e7..570312da84 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -12,17 +12,65 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import time
+import signedjson.key
+import signedjson.sign
+from mock import Mock
+from synapse.api.errors import SynapseError
from synapse.crypto import keyring
+from synapse.util import async, logcontext
from synapse.util.logcontext import LoggingContext
-from tests import utils, unittest
+from tests import unittest, utils
from twisted.internet import defer
+class MockPerspectiveServer(object):
+ def __init__(self):
+ self.server_name = "mock_server"
+ self.key = signedjson.key.generate_signing_key(0)
+
+ def get_verify_keys(self):
+ vk = signedjson.key.get_verify_key(self.key)
+ return {
+ "%s:%s" % (vk.alg, vk.version): vk,
+ }
+
+ def get_signed_key(self, server_name, verify_key):
+ key_id = "%s:%s" % (verify_key.alg, verify_key.version)
+        res = {
+            "server_name": server_name,
+            "old_verify_keys": {},
+            # valid_until_ts is in milliseconds, so the one-hour validity
+            # window must be scaled to milliseconds too
+            "valid_until_ts": time.time() * 1000 + 3600 * 1000,
+ "verify_keys": {
+ key_id: {
+ "key": signedjson.key.encode_verify_key_base64(verify_key)
+ }
+ }
+ }
+ signedjson.sign.sign_json(res, self.server_name, self.key)
+ return res
+
+
class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- self.hs = yield utils.setup_test_homeserver(handlers=None)
+ self.mock_perspective_server = MockPerspectiveServer()
+ self.http_client = Mock()
+ self.hs = yield utils.setup_test_homeserver(
+ handlers=None,
+ http_client=self.http_client,
+ )
+ self.hs.config.perspectives = {
+ self.mock_perspective_server.server_name:
+ self.mock_perspective_server.get_verify_keys()
+ }
+
+ def check_context(self, _, expected):
+ self.assertEquals(
+ getattr(LoggingContext.current_context(), "test_key", None),
+ expected
+ )
@defer.inlineCallbacks
def test_wait_for_previous_lookups(self):
@@ -30,11 +78,6 @@ class KeyringTestCase(unittest.TestCase):
kr = keyring.Keyring(self.hs)
- def check_context(_, expected):
- self.assertEquals(
- LoggingContext.current_context().test_key, expected
- )
-
lookup_1_deferred = defer.Deferred()
lookup_2_deferred = defer.Deferred()
@@ -50,7 +93,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertTrue(wait_1_deferred.called)
# ... so we should have preserved the LoggingContext.
self.assertIs(LoggingContext.current_context(), context_one)
- wait_1_deferred.addBoth(check_context, "one")
+ wait_1_deferred.addBoth(self.check_context, "one")
with LoggingContext("two") as context_two:
context_two.test_key = "two"
@@ -64,7 +107,7 @@ class KeyringTestCase(unittest.TestCase):
self.assertFalse(wait_2_deferred.called)
# ... so we should have reset the LoggingContext.
self.assertIs(LoggingContext.current_context(), sentinel_context)
- wait_2_deferred.addBoth(check_context, "two")
+ wait_2_deferred.addBoth(self.check_context, "two")
# let the first lookup complete (in the sentinel context)
lookup_1_deferred.callback(None)
@@ -72,3 +115,115 @@ class KeyringTestCase(unittest.TestCase):
# now the second wait should complete and restore our
# loggingcontext.
yield wait_2_deferred
+
+ @defer.inlineCallbacks
+ def test_verify_json_objects_for_server_awaits_previous_requests(self):
+ key1 = signedjson.key.generate_signing_key(1)
+
+ kr = keyring.Keyring(self.hs)
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server10", key1)
+
+ persp_resp = {
+ "server_keys": [
+ self.mock_perspective_server.get_signed_key(
+ "server10",
+ signedjson.key.get_verify_key(key1)
+ ),
+ ]
+ }
+ persp_deferred = defer.Deferred()
+
+ @defer.inlineCallbacks
+ def get_perspectives(**kwargs):
+ self.assertEquals(
+ LoggingContext.current_context().test_key, "11",
+ )
+ with logcontext.PreserveLoggingContext():
+ yield persp_deferred
+ defer.returnValue(persp_resp)
+ self.http_client.post_json.side_effect = get_perspectives
+
+ with LoggingContext("11") as context_11:
+ context_11.test_key = "11"
+
+ # start off a first set of lookups
+ res_deferreds = kr.verify_json_objects_for_server(
+ [("server10", json1),
+ ("server11", {})
+ ]
+ )
+
+ # the unsigned json should be rejected pretty quickly
+ self.assertTrue(res_deferreds[1].called)
+ try:
+ yield res_deferreds[1]
+                self.fail("unsigned json didn't cause a failure")
+ except SynapseError:
+ pass
+
+ self.assertFalse(res_deferreds[0].called)
+ res_deferreds[0].addBoth(self.check_context, None)
+
+ # wait a tick for it to send the request to the perspectives server
+ # (it first tries the datastore)
+ yield async.sleep(0.005)
+ self.http_client.post_json.assert_called_once()
+
+ self.assertIs(LoggingContext.current_context(), context_11)
+
+ context_12 = LoggingContext("12")
+ context_12.test_key = "12"
+ with logcontext.PreserveLoggingContext(context_12):
+ # a second request for a server with outstanding requests
+ # should block rather than start a second call
+ self.http_client.post_json.reset_mock()
+ self.http_client.post_json.return_value = defer.Deferred()
+
+ res_deferreds_2 = kr.verify_json_objects_for_server(
+ [("server10", json1)],
+ )
+ yield async.sleep(0.005)
+ self.http_client.post_json.assert_not_called()
+ res_deferreds_2[0].addBoth(self.check_context, None)
+
+ # complete the first request
+ with logcontext.PreserveLoggingContext():
+ persp_deferred.callback(persp_resp)
+ self.assertIs(LoggingContext.current_context(), context_11)
+
+ with logcontext.PreserveLoggingContext():
+ yield res_deferreds[0]
+ yield res_deferreds_2[0]
+
+ @defer.inlineCallbacks
+ def test_verify_json_for_server(self):
+ kr = keyring.Keyring(self.hs)
+
+ key1 = signedjson.key.generate_signing_key(1)
+ yield self.hs.datastore.store_server_verify_key(
+ "server9", "", time.time() * 1000,
+ signedjson.key.get_verify_key(key1),
+ )
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server9", key1)
+
+ sentinel_context = LoggingContext.current_context()
+
+ with LoggingContext("one") as context_one:
+ context_one.test_key = "one"
+
+            # use `d` rather than `defer` so we don't shadow the
+            # twisted.internet.defer module imported above
+            d = kr.verify_json_for_server("server9", {})
+            try:
+                yield d
+                self.fail("should fail on unsigned json")
+            except SynapseError:
+                pass
+            self.assertIs(LoggingContext.current_context(), context_one)
+
+            d = kr.verify_json_for_server("server9", json1)
+            self.assertFalse(d.called)
+            self.assertIs(LoggingContext.current_context(), sentinel_context)
+            yield d
+
+ self.assertIs(LoggingContext.current_context(), context_one)
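
All of these assertions hinge on Synapse's logcontext convention: code that
waits on a deferred it does not own should drop to the sentinel context
while waiting and restore its own context afterwards, which is what the
mocked get_perspectives above does. A minimal sketch of the pattern:

    from twisted.internet import defer
    from synapse.util import logcontext

    @defer.inlineCallbacks
    def wait_politely(other_deferred):
        # drop to the sentinel context while waiting on someone
        # else's deferred, then restore our context on the way out
        with logcontext.PreserveLoggingContext():
            result = yield other_deferred
        defer.returnValue(result)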
diff --git a/tests/test_dns.py b/tests/test_dns.py
index c394c57ee7..d08b0f4333 100644
--- a/tests/test_dns.py
+++ b/tests/test_dns.py
@@ -24,15 +24,17 @@ from synapse.http.endpoint import resolve_service
from tests.utils import MockClock
+@unittest.DEBUG  # enables debug logging for everything in this test case
class DnsTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_resolve(self):
dns_client_mock = Mock()
- service_name = "test_service.examle.com"
+ service_name = "test_service.example.com"
host_name = "example.com"
ip_address = "127.0.0.1"
+ ip6_address = "::1"
answer_srv = dns.RRHeader(
type=dns.SRV,
@@ -48,8 +50,22 @@ class DnsTestCase(unittest.TestCase):
)
)
- dns_client_mock.lookupService.return_value = ([answer_srv], None, None)
- dns_client_mock.lookupAddress.return_value = ([answer_a], None, None)
+ answer_aaaa = dns.RRHeader(
+ type=dns.AAAA,
+ payload=dns.Record_AAAA(
+ address=ip6_address,
+ )
+ )
+
+ dns_client_mock.lookupService.return_value = defer.succeed(
+ ([answer_srv], None, None),
+ )
+ dns_client_mock.lookupAddress.return_value = defer.succeed(
+ ([answer_a], None, None),
+ )
+ dns_client_mock.lookupIPV6Address.return_value = defer.succeed(
+ ([answer_aaaa], None, None),
+ )
cache = {}
@@ -59,10 +75,12 @@ class DnsTestCase(unittest.TestCase):
dns_client_mock.lookupService.assert_called_once_with(service_name)
dns_client_mock.lookupAddress.assert_called_once_with(host_name)
+ dns_client_mock.lookupIPV6Address.assert_called_once_with(host_name)
- self.assertEquals(len(servers), 1)
+ self.assertEquals(len(servers), 2)
self.assertEquals(servers, cache[service_name])
self.assertEquals(servers[0].host, ip_address)
+ self.assertEquals(servers[1].host, ip6_address)
@defer.inlineCallbacks
def test_from_cache_expired_and_dns_fail(self):
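
The updated expectations mirror the IPv6 fix (PR #2462): each SRV target is
now resolved with both an A and an AAAA lookup, yielding one server entry
per address, which is why the mock must answer lookupIPV6Address as well.
A simplified sketch of the per-host resolution (illustrative, not the
actual resolve_service implementation):

    import socket

    from twisted.internet import defer
    from twisted.names import dns

    @defer.inlineCallbacks
    def resolve_host(dns_client, host):
        """Gather IPv4 and IPv6 addresses for one SRV target."""
        a_answers, _, _ = yield dns_client.lookupAddress(host)
        aaaa_answers, _, _ = yield dns_client.lookupIPV6Address(host)
        ips = [a.payload.dottedQuad()
               for a in a_answers if a.type == dns.A]
        ips.extend(
            socket.inet_ntop(socket.AF_INET6, a.payload.address)
            for a in aaaa_answers if a.type == dns.AAAA
        )
        defer.returnValue(ips)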
diff --git a/tests/utils.py b/tests/utils.py
index 4f7e32b3ab..3c81a3e16d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -56,6 +56,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
config.worker_replication_url = ""
config.worker_app = None
config.email_enable_notifs = False
+ config.block_non_admin_invites = False
config.use_frozen_dicts = True
config.database_config = {"name": "sqlite3"}
|