diff options
author | David Baker <dave@matrix.org> | 2017-10-02 16:20:41 +0100 |
---|---|---|
committer | David Baker <dave@matrix.org> | 2017-10-02 16:20:41 +0100 |
commit | 27955056e092b440ed00da13fe6326dd438bb900 (patch) | |
tree | 8480cf26404b89993e6ffa6cadbabc477be19661 /synapse | |
parent | Merge pull request #2472 from matrix-org/erikj/groups_rooms (diff) | |
parent | Merge pull request #2480 from matrix-org/rav/federation_client_logging (diff) | |
download | synapse-27955056e092b440ed00da13fe6326dd438bb900.tar.xz |
Merge branch 'develop' into erikj/groups_merged
Diffstat (limited to 'synapse')
26 files changed, 930 insertions, 273 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index dbf22eca00..ec83e6adb7 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.22.1" +__version__ = "0.23.0-rc2" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index e3da45b416..72858cca1f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -519,6 +519,14 @@ class Auth(object): ) def is_server_admin(self, user): + """ Check if the given user is a local server admin. + + Args: + user (str): mxid of user to check + + Returns: + bool: True if the user is an admin + """ return self.store.is_server_admin(user) @defer.inlineCallbacks diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index b22cacf8dc..3f9d9d5f8b 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -34,6 +34,7 @@ from .password_auth_providers import PasswordAuthProviderConfig from .emailconfig import EmailConfig from .workers import WorkerConfig from .push import PushConfig +from .spam_checker import SpamCheckerConfig class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, @@ -41,7 +42,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig, VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig, AppServiceConfig, KeyConfig, SAML2Config, CasConfig, JWTConfig, PasswordConfig, EmailConfig, - WorkerConfig, PasswordAuthProviderConfig, PushConfig,): + WorkerConfig, PasswordAuthProviderConfig, PushConfig, + SpamCheckerConfig,): pass diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 83762d089a..90824cab7f 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -15,13 +15,15 @@ from ._base import Config, ConfigError -import importlib +from synapse.util.module_loader import load_module class PasswordAuthProviderConfig(Config): def read_config(self, config): self.password_providers = [] + provider_config = None + # We want to be backwards compatible with the old `ldap_config` # param. ldap_config = config.get("ldap_config", {}) @@ -38,19 +40,15 @@ class PasswordAuthProviderConfig(Config): if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider": from ldap_auth_provider import LdapAuthProvider provider_class = LdapAuthProvider + try: + provider_config = provider_class.parse_config(provider["config"]) + except Exception as e: + raise ConfigError( + "Failed to parse config for %r: %r" % (provider['module'], e) + ) else: - # We need to import the module, and then pick the class out of - # that, so we split based on the last dot. - module, clz = provider['module'].rsplit(".", 1) - module = importlib.import_module(module) - provider_class = getattr(module, clz) + (provider_class, provider_config) = load_module(provider) - try: - provider_config = provider_class.parse_config(provider["config"]) - except Exception as e: - raise ConfigError( - "Failed to parse config for %r: %r" % (provider['module'], e) - ) self.password_providers.append((provider_class, provider_config)) def default_config(self, **kwargs): diff --git a/synapse/config/server.py b/synapse/config/server.py index 89d61a0503..c9a1715f1f 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -43,6 +43,12 @@ class ServerConfig(Config): self.filter_timeline_limit = config.get("filter_timeline_limit", -1) + # Whether we should block invites sent to users on this server + # (other than those sent by local server admins) + self.block_non_admin_invites = config.get( + "block_non_admin_invites", False, + ) + if self.public_baseurl is not None: if self.public_baseurl[-1] != '/': self.public_baseurl += '/' @@ -194,6 +200,10 @@ class ServerConfig(Config): # and sync operations. The default value is -1, means no upper limit. # filter_timeline_limit: 5000 + # Whether room invites to users on this server should be blocked + # (except those sent by local server admins). The default is False. + # block_non_admin_invites: True + # List of ports that Synapse should listen on, their purpose and their # configuration. listeners: diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py new file mode 100644 index 0000000000..3fec42bdb0 --- /dev/null +++ b/synapse/config/spam_checker.py @@ -0,0 +1,35 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.util.module_loader import load_module + +from ._base import Config + + +class SpamCheckerConfig(Config): + def read_config(self, config): + self.spam_checker = None + + provider = config.get("spam_checker", None) + if provider is not None: + self.spam_checker = load_module(provider) + + def default_config(self, **kwargs): + return """\ + # spam_checker: + # module: "my_custom_project.SuperSpamChecker" + # config: + # example_option: 'things' + """ diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index c2bd64d6c2..f1fd488b90 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -13,14 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from synapse.util import logcontext from twisted.web.http import HTTPClient from twisted.internet.protocol import Factory from twisted.internet import defer, reactor from synapse.http.endpoint import matrix_federation_endpoint -from synapse.util.logcontext import ( - preserve_context_over_fn, preserve_context_over_deferred -) import simplejson as json import logging @@ -43,14 +40,10 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1): for i in range(5): try: - protocol = yield preserve_context_over_fn( - endpoint.connect, factory - ) - server_response, server_certificate = yield preserve_context_over_deferred( - protocol.remote_key - ) - defer.returnValue((server_response, server_certificate)) - return + with logcontext.PreserveLoggingContext(): + protocol = yield endpoint.connect(factory) + server_response, server_certificate = yield protocol.remote_key + defer.returnValue((server_response, server_certificate)) except SynapseKeyClientError as e: logger.exception("Error getting key for %r" % (server_name,)) if e.status.startswith("4"): diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 51851d04e5..054bac456d 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -18,7 +18,7 @@ from synapse.crypto.keyclient import fetch_server_key from synapse.api.errors import SynapseError, Codes from synapse.util import unwrapFirstError, logcontext from synapse.util.logcontext import ( - preserve_context_over_fn, PreserveLoggingContext, + PreserveLoggingContext, preserve_fn ) from synapse.util.metrics import Measure @@ -57,7 +57,8 @@ Attributes: json_object(dict): The JSON object to verify. deferred(twisted.internet.defer.Deferred): A deferred (server_name, key_id, verify_key) tuple that resolves when - a verify key has been fetched + a verify key has been fetched. The deferreds' callbacks are run with no + logcontext. """ @@ -82,9 +83,11 @@ class Keyring(object): self.key_downloads = {} def verify_json_for_server(self, server_name, json_object): - return self.verify_json_objects_for_server( - [(server_name, json_object)] - )[0] + return logcontext.make_deferred_yieldable( + self.verify_json_objects_for_server( + [(server_name, json_object)] + )[0] + ) def verify_json_objects_for_server(self, server_and_json): """Bulk verifies signatures of json objects, bulk fetching keys as @@ -94,8 +97,10 @@ class Keyring(object): server_and_json (list): List of pairs of (server_name, json_object) Returns: - list of deferreds indicating success or failure to verify each - json object's signature for the given server_name. + List<Deferred>: for each input pair, a deferred indicating success + or failure to verify each json object's signature for the given + server_name. The deferreds run their callbacks in the sentinel + logcontext. """ verify_requests = [] @@ -122,93 +127,71 @@ class Keyring(object): verify_requests.append(verify_request) - @defer.inlineCallbacks - def handle_key_deferred(verify_request): - server_name = verify_request.server_name - try: - _, key_id, verify_key = yield verify_request.deferred - except IOError as e: - logger.warn( - "Got IOError when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), - ) - raise SynapseError( - 502, - "Error downloading keys for %s" % (server_name,), - Codes.UNAUTHORIZED, - ) - except Exception as e: - logger.exception( - "Got Exception when downloading keys for %s: %s %s", - server_name, type(e).__name__, str(e.message), - ) - raise SynapseError( - 401, - "No key for %s with id %s" % (server_name, key_ids), - Codes.UNAUTHORIZED, - ) + preserve_fn(self._start_key_lookups)(verify_requests) - json_object = verify_request.json_object + # Pass those keys to handle_key_deferred so that the json object + # signatures can be verified + handle = preserve_fn(_handle_key_deferred) + return [ + handle(rq) for rq in verify_requests + ] - logger.debug("Got key %s %s:%s for server %s, verifying" % ( - key_id, verify_key.alg, verify_key.version, server_name, - )) - try: - verify_signed_json(json_object, server_name, verify_key) - except: - raise SynapseError( - 401, - "Invalid signature for server %s with key %s:%s" % ( - server_name, verify_key.alg, verify_key.version - ), - Codes.UNAUTHORIZED, - ) + @defer.inlineCallbacks + def _start_key_lookups(self, verify_requests): + """Sets off the key fetches for each verify request + + Once each fetch completes, verify_request.deferred will be resolved. + + Args: + verify_requests (List[VerifyKeyRequest]): + """ + # create a deferred for each server we're going to look up the keys + # for; we'll resolve them once we have completed our lookups. + # These will be passed into wait_for_previous_lookups to block + # any other lookups until we have finished. + # The deferreds are called with no logcontext. server_to_deferred = { - server_name: defer.Deferred() - for server_name, _ in server_and_json + rq.server_name: defer.Deferred() + for rq in verify_requests } - with PreserveLoggingContext(): + # We want to wait for any previous lookups to complete before + # proceeding. + yield self.wait_for_previous_lookups( + [rq.server_name for rq in verify_requests], + server_to_deferred, + ) - # We want to wait for any previous lookups to complete before - # proceeding. - wait_on_deferred = self.wait_for_previous_lookups( - [server_name for server_name, _ in server_and_json], - server_to_deferred, - ) + # Actually start fetching keys. + self._get_server_verify_keys(verify_requests) - # Actually start fetching keys. - wait_on_deferred.addBoth( - lambda _: self.get_server_verify_keys(verify_requests) - ) + # When we've finished fetching all the keys for a given server_name, + # resolve the deferred passed to `wait_for_previous_lookups` so that + # any lookups waiting will proceed. + # + # map from server name to a set of request ids + server_to_request_ids = {} - # When we've finished fetching all the keys for a given server_name, - # resolve the deferred passed to `wait_for_previous_lookups` so that - # any lookups waiting will proceed. - server_to_request_ids = {} - - def remove_deferreds(res, server_name, verify_request): - request_id = id(verify_request) - server_to_request_ids[server_name].discard(request_id) - if not server_to_request_ids[server_name]: - d = server_to_deferred.pop(server_name, None) - if d: - d.callback(None) - return res - - for verify_request in verify_requests: - server_name = verify_request.server_name - request_id = id(verify_request) - server_to_request_ids.setdefault(server_name, set()).add(request_id) - deferred.addBoth(remove_deferreds, server_name, verify_request) + for verify_request in verify_requests: + server_name = verify_request.server_name + request_id = id(verify_request) + server_to_request_ids.setdefault(server_name, set()).add(request_id) - # Pass those keys to handle_key_deferred so that the json object - # signatures can be verified - return [ - preserve_context_over_fn(handle_key_deferred, verify_request) - for verify_request in verify_requests - ] + def remove_deferreds(res, verify_request): + server_name = verify_request.server_name + request_id = id(verify_request) + server_to_request_ids[server_name].discard(request_id) + if not server_to_request_ids[server_name]: + d = server_to_deferred.pop(server_name, None) + if d: + d.callback(None) + return res + + for verify_request in verify_requests: + verify_request.deferred.addBoth( + remove_deferreds, verify_request, + ) @defer.inlineCallbacks def wait_for_previous_lookups(self, server_names, server_to_deferred): @@ -245,7 +228,7 @@ class Keyring(object): self.key_downloads[server_name] = deferred deferred.addBoth(rm, server_name) - def get_server_verify_keys(self, verify_requests): + def _get_server_verify_keys(self, verify_requests): """Tries to find at least one key for each verify request For each verify_request, verify_request.deferred is called back with @@ -314,21 +297,23 @@ class Keyring(object): if not missing_keys: break - for verify_request in requests_missing_keys.values(): - verify_request.deferred.errback(SynapseError( - 401, - "No key for %s with id %s" % ( - verify_request.server_name, verify_request.key_ids, - ), - Codes.UNAUTHORIZED, - )) + with PreserveLoggingContext(): + for verify_request in requests_missing_keys: + verify_request.deferred.errback(SynapseError( + 401, + "No key for %s with id %s" % ( + verify_request.server_name, verify_request.key_ids, + ), + Codes.UNAUTHORIZED, + )) def on_err(err): - for verify_request in verify_requests: - if not verify_request.deferred.called: - verify_request.deferred.errback(err) + with PreserveLoggingContext(): + for verify_request in verify_requests: + if not verify_request.deferred.called: + verify_request.deferred.errback(err) - do_iterations().addErrback(on_err) + preserve_fn(do_iterations)().addErrback(on_err) @defer.inlineCallbacks def get_keys_from_store(self, server_name_and_key_ids): @@ -738,3 +723,47 @@ class Keyring(object): ], consumeErrors=True, ).addErrback(unwrapFirstError)) + + +@defer.inlineCallbacks +def _handle_key_deferred(verify_request): + server_name = verify_request.server_name + try: + with PreserveLoggingContext(): + _, key_id, verify_key = yield verify_request.deferred + except IOError as e: + logger.warn( + "Got IOError when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e.message), + ) + raise SynapseError( + 502, + "Error downloading keys for %s" % (server_name,), + Codes.UNAUTHORIZED, + ) + except Exception as e: + logger.exception( + "Got Exception when downloading keys for %s: %s %s", + server_name, type(e).__name__, str(e.message), + ) + raise SynapseError( + 401, + "No key for %s with id %s" % (server_name, verify_request.key_ids), + Codes.UNAUTHORIZED, + ) + + json_object = verify_request.json_object + + logger.debug("Got key %s %s:%s for server %s, verifying" % ( + key_id, verify_key.alg, verify_key.version, server_name, + )) + try: + verify_signed_json(json_object, server_name, verify_key) + except: + raise SynapseError( + 401, + "Invalid signature for server %s with key %s:%s" % ( + server_name, verify_key.alg, verify_key.version + ), + Codes.UNAUTHORIZED, + ) diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py new file mode 100644 index 0000000000..e739f105b2 --- /dev/null +++ b/synapse/events/spamcheck.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class SpamChecker(object): + def __init__(self, hs): + self.spam_checker = None + + module = None + config = None + try: + module, config = hs.config.spam_checker + except: + pass + + if module is not None: + self.spam_checker = module(config=config) + + def check_event_for_spam(self, event): + """Checks if a given event is considered "spammy" by this server. + + If the server considers an event spammy, then it will be rejected if + sent by a local user. If it is sent by a user on another server, then + users receive a blank event. + + Args: + event (synapse.events.EventBase): the event to be checked + + Returns: + bool: True if the event is spammy. + """ + if self.spam_checker is None: + return False + + return self.spam_checker.check_event_for_spam(event) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 2339cc9034..a0f5d40eb3 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -12,28 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - -from twisted.internet import defer - -from synapse.events.utils import prune_event - -from synapse.crypto.event_signing import check_event_content_hash - -from synapse.api.errors import SynapseError - -from synapse.util import unwrapFirstError -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred - import logging +from synapse.api.errors import SynapseError +from synapse.crypto.event_signing import check_event_content_hash +from synapse.events.utils import prune_event +from synapse.util import unwrapFirstError, logcontext +from twisted.internet import defer logger = logging.getLogger(__name__) class FederationBase(object): def __init__(self, hs): - pass + self.spam_checker = hs.get_spam_checker() @defer.inlineCallbacks def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False, @@ -57,56 +49,52 @@ class FederationBase(object): """ deferreds = self._check_sigs_and_hashes(pdus) - def callback(pdu): - return pdu + @defer.inlineCallbacks + def handle_check_result(pdu, deferred): + try: + res = yield logcontext.make_deferred_yieldable(deferred) + except SynapseError: + res = None - def errback(failure, pdu): - failure.trap(SynapseError) - return None - - def try_local_db(res, pdu): if not res: # Check local db. - return self.store.get_event( + res = yield self.store.get_event( pdu.event_id, allow_rejected=True, allow_none=True, ) - return res - def try_remote(res, pdu): if not res and pdu.origin != origin: - return self.get_pdu( - destinations=[pdu.origin], - event_id=pdu.event_id, - outlier=outlier, - timeout=10000, - ).addErrback(lambda e: None) - return res - - def warn(res, pdu): + try: + res = yield self.get_pdu( + destinations=[pdu.origin], + event_id=pdu.event_id, + outlier=outlier, + timeout=10000, + ) + except SynapseError: + pass + if not res: logger.warn( "Failed to find copy of %s with valid signature", pdu.event_id, ) - return res - for pdu, deferred in zip(pdus, deferreds): - deferred.addCallbacks( - callback, errback, errbackArgs=[pdu] - ).addCallback( - try_local_db, pdu - ).addCallback( - try_remote, pdu - ).addCallback( - warn, pdu - ) + defer.returnValue(res) - valid_pdus = yield preserve_context_over_deferred(defer.gatherResults( - deferreds, - consumeErrors=True - )).addErrback(unwrapFirstError) + handle = logcontext.preserve_fn(handle_check_result) + deferreds2 = [ + handle(pdu, deferred) + for pdu, deferred in zip(pdus, deferreds) + ] + + valid_pdus = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + deferreds2, + consumeErrors=True, + ) + ).addErrback(unwrapFirstError) if include_none: defer.returnValue(valid_pdus) @@ -114,15 +102,24 @@ class FederationBase(object): defer.returnValue([p for p in valid_pdus if p]) def _check_sigs_and_hash(self, pdu): - return self._check_sigs_and_hashes([pdu])[0] + return logcontext.make_deferred_yieldable( + self._check_sigs_and_hashes([pdu])[0], + ) def _check_sigs_and_hashes(self, pdus): - """Throws a SynapseError if a PDU does not have the correct - signatures. + """Checks that each of the received events is correctly signed by the + sending server. + + Args: + pdus (list[FrozenEvent]): the events to be checked Returns: - FrozenEvent: Either the given event or it redacted if it failed the - content hash check. + list[Deferred]: for each input event, a deferred which: + * returns the original event if the checks pass + * returns a redacted version of the event (if the signature + matched but the hash did not) + * throws a SynapseError if the signature check failed. + The deferreds run their callbacks in the sentinel logcontext. """ redacted_pdus = [ @@ -130,26 +127,38 @@ class FederationBase(object): for pdu in pdus ] - deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([ + deferreds = self.keyring.verify_json_objects_for_server([ (p.origin, p.get_pdu_json()) for p in redacted_pdus ]) + ctx = logcontext.LoggingContext.current_context() + def callback(_, pdu, redacted): - if not check_event_content_hash(pdu): - logger.warn( - "Event content has been tampered, redacting %s: %s", - pdu.event_id, pdu.get_pdu_json() - ) - return redacted - return pdu + with logcontext.PreserveLoggingContext(ctx): + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s: %s", + pdu.event_id, pdu.get_pdu_json() + ) + return redacted + + if self.spam_checker.check_event_for_spam(pdu): + logger.warn( + "Event contains spam, redacting %s: %s", + pdu.event_id, pdu.get_pdu_json() + ) + return redacted + + return pdu def errback(failure, pdu): failure.trap(SynapseError) - logger.warn( - "Signature check failed for %s", - pdu.event_id, - ) + with logcontext.PreserveLoggingContext(ctx): + logger.warn( + "Signature check failed for %s", + pdu.event_id, + ) return failure for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 861441708b..7c5e5d957f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -22,7 +22,7 @@ from synapse.api.constants import Membership from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, ) -from synapse.util import unwrapFirstError +from synapse.util import unwrapFirstError, logcontext from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred @@ -189,10 +189,10 @@ class FederationClient(FederationBase): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield preserve_context_over_deferred(defer.gatherResults( + pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults( self._check_sigs_and_hashes(pdus), consumeErrors=True, - )).addErrback(unwrapFirstError) + ).addErrback(unwrapFirstError)) defer.returnValue(pdus) @@ -252,7 +252,7 @@ class FederationClient(FederationBase): pdu = pdu_list[0] # Check signatures are correct. - signed_pdu = yield self._check_sigs_and_hashes([pdu])[0] + signed_pdu = yield self._check_sigs_and_hash(pdu) break diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4669199b2d..18f87cad67 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1074,6 +1074,9 @@ class FederationHandler(BaseHandler): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") + if self.hs.config.block_non_admin_invites: + raise SynapseError(403, "This server does not accept room invites") + membership = event.content.get("membership") if event.type != EventTypes.Member or membership != Membership.INVITE: raise SynapseError(400, "The event was not an m.room.member invite event") @@ -2090,6 +2093,14 @@ class FederationHandler(BaseHandler): @defer.inlineCallbacks @log_function def on_exchange_third_party_invite_request(self, origin, room_id, event_dict): + """Handle an exchange_third_party_invite request from a remote server + + The remote server will call this when it wants to turn a 3pid invite + into a normal m.room.member invite. + + Returns: + Deferred: resolves (to None) + """ builder = self.event_builder_factory.new(event_dict) message_handler = self.hs.get_handlers().message_handler @@ -2108,9 +2119,12 @@ class FederationHandler(BaseHandler): raise e yield self._check_signature(event, context) + # XXX we send the invite here, but send_membership_event also sends it, + # so we end up making two requests. I think this is redundant. returned_invite = yield self.send_invite(origin, event) # TODO: Make sure the signatures actually are correct. event.signatures.update(returned_invite.signatures) + member_handler = self.hs.get_handlers().room_member_handler yield member_handler.send_membership_event(None, event, context) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5b8f20b73c..e22d4803b9 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from twisted.internet import defer from synapse.api.constants import EventTypes, Membership @@ -59,6 +58,8 @@ class MessageHandler(BaseHandler): self.action_generator = hs.get_action_generator() + self.spam_checker = hs.get_spam_checker() + @defer.inlineCallbacks def purge_history(self, room_id, event_id): event = yield self.store.get_event(event_id) @@ -322,6 +323,12 @@ class MessageHandler(BaseHandler): token_id=requester.access_token_id, txn_id=txn_id ) + + if self.spam_checker.check_event_for_spam(event): + raise SynapseError( + 403, "Spam is not permitted here", Codes.FORBIDDEN + ) + yield self.send_nonmember_event( requester, event, @@ -413,6 +420,51 @@ class MessageHandler(BaseHandler): [serialize_event(c, now) for c in room_state.values()] ) + @defer.inlineCallbacks + def get_joined_members(self, requester, room_id): + """Get all the joined members in the room and their profile information. + + If the user has left the room return the state events from when they left. + + Args: + requester(Requester): The user requesting state events. + room_id(str): The room ID to get all state events from. + Returns: + A dict of user_id to profile info + """ + user_id = requester.user.to_string() + if not requester.app_service: + # We check AS auth after fetching the room membership, as it + # requires us to pull out all joined members anyway. + membership, _ = yield self._check_in_room_or_world_readable( + room_id, user_id + ) + if membership != Membership.JOIN: + raise NotImplementedError( + "Getting joined members after leaving is not implemented" + ) + + users_with_profile = yield self.state.get_current_user_in_room(room_id) + + # If this is an AS, double check that they are allowed to see the members. + # This can either be because the AS user is in the room or becuase there + # is a user in the room that the AS is "interested in" + if requester.app_service and user_id not in users_with_profile: + for uid in users_with_profile: + if requester.app_service.is_interested_in_user(uid): + break + else: + # Loop fell through, AS has no interested users in room + raise AuthError(403, "Appservice not in room") + + defer.returnValue({ + user_id: { + "avatar_url": profile.avatar_url, + "display_name": profile.display_name, + } + for user_id, profile in users_with_profile.iteritems() + }) + @measure_func("_create_new_client_event") @defer.inlineCallbacks def _create_new_client_event(self, builder, requester=None, prev_event_ids=None): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index dadc19d45b..d6ad57171c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -193,6 +193,8 @@ class RoomMemberHandler(BaseHandler): if action in ["kick", "unban"]: effective_membership_state = "leave" + # if this is a join with a 3pid signature, we may need to turn a 3pid + # invite into a normal invite before we can handle the join. if third_party_signed is not None: replication = self.hs.get_replication_layer() yield replication.exchange_third_party_invite( @@ -210,6 +212,16 @@ class RoomMemberHandler(BaseHandler): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") + if (effective_membership_state == "invite" and + self.hs.config.block_non_admin_invites): + is_requester_admin = yield self.auth.is_server_admin( + requester.user, + ) + if not is_requester_admin: + raise SynapseError( + 403, "Invites have been disabled on this server", + ) + latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id) current_state_ids = yield self.state_handler.get_current_state_ids( room_id, latest_event_ids=latest_event_ids, @@ -473,6 +485,16 @@ class RoomMemberHandler(BaseHandler): requester, txn_id ): + if self.hs.config.block_non_admin_invites: + is_requester_admin = yield self.auth.is_server_admin( + requester.user, + ) + if not is_requester_admin: + raise SynapseError( + 403, "Invites have been disabled on this server", + Codes.FORBIDDEN, + ) + invitee = yield self._lookup_3pid( id_server, medium, address ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 69c1bc189e..219529936f 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -306,11 +306,6 @@ class SyncHandler(object): timeline_limit = sync_config.filter_collection.timeline_limit() block_all_timeline = sync_config.filter_collection.blocks_all_room_timeline() - # Pull out the current state, as we always want to include those events - # in the timeline if they're there. - current_state_ids = yield self.state.get_current_state_ids(room_id) - current_state_ids = frozenset(current_state_ids.itervalues()) - if recents is None or newly_joined_room or timeline_limit < len(recents): limited = True else: @@ -318,6 +313,15 @@ class SyncHandler(object): if recents: recents = sync_config.filter_collection.filter_room_timeline(recents) + + # We check if there are any state events, if there are then we pass + # all current state events to the filter_events function. This is to + # ensure that we always include current state in the timeline + current_state_ids = frozenset() + if any(e.is_state() for e in recents): + current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = frozenset(current_state_ids.itervalues()) + recents = yield filter_events_for_client( self.store, sync_config.user.to_string(), @@ -354,6 +358,15 @@ class SyncHandler(object): loaded_recents = sync_config.filter_collection.filter_room_timeline( events ) + + # We check if there are any state events, if there are then we pass + # all current state events to the filter_events function. This is to + # ensure that we always include current state in the timeline + current_state_ids = frozenset() + if any(e.is_state() for e in loaded_recents): + current_state_ids = yield self.state.get_current_state_ids(room_id) + current_state_ids = frozenset(current_state_ids.itervalues()) + loaded_recents = yield filter_events_for_client( self.store, sync_config.user.to_string(), @@ -1042,7 +1055,18 @@ class SyncHandler(object): # We want to figure out if we joined the room at some point since # the last sync (even if we have since left). This is to make sure # we do send down the room, and with full state, where necessary + old_state_ids = None + if room_id in joined_room_ids and non_joins: + # Always include if the user (re)joined the room, especially + # important so that device list changes are calculated correctly. + # If there are non join member events, but we are still in the room, + # then the user must have left and joined + newly_joined_rooms.append(room_id) + + # User is in the room so we don't need to do the invite/leave checks + continue + if room_id in joined_room_ids or has_join: old_state_ids = yield self.get_state_at(room_id, since_token) old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None) @@ -1054,8 +1078,9 @@ class SyncHandler(object): if not old_mem_ev or old_mem_ev.membership != Membership.JOIN: newly_joined_rooms.append(room_id) - if room_id in joined_room_ids: - continue + # If user is in the room then we don't need to do the invite/leave checks + if room_id in joined_room_ids: + continue if not non_joins: continue diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py index d8923c9abb..a97532162f 100644 --- a/synapse/http/endpoint.py +++ b/synapse/http/endpoint.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import socket from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS from twisted.internet import defer, reactor @@ -30,7 +31,10 @@ logger = logging.getLogger(__name__) SERVER_CACHE = {} - +# our record of an individual server which can be tried to reach a destination. +# +# "host" is actually a dotted-quad or ipv6 address string. Except when there's +# no SRV record, in which case it is the original hostname. _Server = collections.namedtuple( "_Server", "priority weight host port expires" ) @@ -219,9 +223,10 @@ class SRVClientEndpoint(object): return self.default_server else: raise ConnectError( - "Not server available for %s" % self.service_name + "No server available for %s" % self.service_name ) + # look for all servers with the same priority min_priority = self.servers[0].priority weight_indexes = list( (index, server.weight + 1) @@ -231,11 +236,22 @@ class SRVClientEndpoint(object): total_weight = sum(weight for index, weight in weight_indexes) target_weight = random.randint(0, total_weight) - for index, weight in weight_indexes: target_weight -= weight if target_weight <= 0: server = self.servers[index] + # XXX: this looks totally dubious: + # + # (a) we never reuse a server until we have been through + # all of the servers at the same priority, so if the + # weights are A: 100, B:1, we always do ABABAB instead of + # AAAA...AAAB (approximately). + # + # (b) After using all the servers at the lowest priority, + # we move onto the next priority. We should only use the + # second priority if servers at the top priority are + # unreachable. + # del self.servers[index] self.used_servers.append(server) return server @@ -280,26 +296,21 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t continue payload = answer.payload - host = str(payload.target) - srv_ttl = answer.ttl - try: - answers, _, _ = yield dns_client.lookupAddress(host) - except DNSNameError: - continue + hosts = yield _get_hosts_for_srv_record( + dns_client, str(payload.target) + ) - for answer in answers: - if answer.type == dns.A and answer.payload: - ip = answer.payload.dottedQuad() - host_ttl = min(srv_ttl, answer.ttl) + for (ip, ttl) in hosts: + host_ttl = min(answer.ttl, ttl) - servers.append(_Server( - host=ip, - port=int(payload.port), - priority=int(payload.priority), - weight=int(payload.weight), - expires=int(clock.time()) + host_ttl, - )) + servers.append(_Server( + host=ip, + port=int(payload.port), + priority=int(payload.priority), + weight=int(payload.weight), + expires=int(clock.time()) + host_ttl, + )) servers.sort() cache[service_name] = list(servers) @@ -317,3 +328,80 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=t raise e defer.returnValue(servers) + + +@defer.inlineCallbacks +def _get_hosts_for_srv_record(dns_client, host): + """Look up each of the hosts in a SRV record + + Args: + dns_client (twisted.names.dns.IResolver): + host (basestring): host to look up + + Returns: + Deferred[list[(str, int)]]: a list of (host, ttl) pairs + + """ + ip4_servers = [] + ip6_servers = [] + + def cb(res): + # lookupAddress and lookupIP6Address return a three-tuple + # giving the answer, authority, and additional sections of the + # response. + # + # we only care about the answers. + + return res[0] + + def eb(res, record_type): + if res.check(DNSNameError): + return [] + logger.warn("Error looking up %s for %s: %s", + record_type, host, res, res.value) + return res + + # no logcontexts here, so we can safely fire these off and gatherResults + d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb) + d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb) + results = yield defer.DeferredList( + [d1, d2], consumeErrors=True) + + # if all of the lookups failed, raise an exception rather than blowing out + # the cache with an empty result. + if results and all(s == defer.FAILURE for (s, _) in results): + defer.returnValue(results[0][1]) + + for (success, result) in results: + if success == defer.FAILURE: + continue + + for answer in result: + if not answer.payload: + continue + + try: + if answer.type == dns.A: + ip = answer.payload.dottedQuad() + ip4_servers.append((ip, answer.ttl)) + elif answer.type == dns.AAAA: + ip = socket.inet_ntop( + socket.AF_INET6, answer.payload.address, + ) + ip6_servers.append((ip, answer.ttl)) + else: + # the most likely candidate here is a CNAME record. + # rfc2782 says srvs may not point to aliases. + logger.warn( + "Ignoring unexpected DNS record type %s for %s", + answer.type, host, + ) + continue + except Exception as e: + logger.warn("Ignoring invalid DNS response for %s: %s", + host, e) + continue + + # keep the ipv4 results before the ipv6 results, mostly to match historical + # behaviour. + defer.returnValue(ip4_servers + ip6_servers) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 8b94e6f29f..8c8b7fa656 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -204,18 +204,15 @@ class MatrixFederationHttpClient(object): raise logger.warn( - "{%s} Sending request failed to %s: %s %s: %s - %s", + "{%s} Sending request failed to %s: %s %s: %s", txn_id, destination, method, url_bytes, - type(e).__name__, _flatten_response_never_received(e), ) - log_result = "%s - %s" % ( - type(e).__name__, _flatten_response_never_received(e), - ) + log_result = _flatten_response_never_received(e) if retries_left and not timeout: if long_retries: @@ -618,12 +615,14 @@ class _JsonProducer(object): def _flatten_response_never_received(e): if hasattr(e, "reasons"): - return ", ".join( + reasons = ", ".join( _flatten_response_never_received(f.value) for f in e.reasons ) + + return "%s:[%s]" % (type(e).__name__, reasons) else: - return "%s: %s" % (type(e).__name__, e.message,) + return repr(e) def check_content_type_is_json(headers): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index cd388770c8..6c379d53ac 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -398,22 +398,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet): def __init__(self, hs): super(JoinedRoomMemberListRestServlet, self).__init__(hs) - self.state = hs.get_state_handler() + self.message_handler = hs.get_handlers().message_handler @defer.inlineCallbacks def on_GET(self, request, room_id): - yield self.auth.get_user_by_req(request) + requester = yield self.auth.get_user_by_req(request) - users_with_profile = yield self.state.get_current_user_in_room(room_id) + users_with_profile = yield self.message_handler.get_joined_members( + requester, room_id, + ) defer.returnValue((200, { - "joined": { - user_id: { - "avatar_url": profile.avatar_url, - "display_name": profile.display_name, - } - for user_id, profile in users_with_profile.iteritems() - } + "joined": users_with_profile, })) diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index d92b7ff337..d5cec10127 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -14,6 +14,9 @@ # limitations under the License. import os +import re + +NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d") class MediaFilePaths(object): @@ -73,19 +76,105 @@ class MediaFilePaths(object): ) def url_cache_filepath(self, media_id): - return os.path.join( - self.base_path, "url_cache", - media_id[0:2], media_id[2:4], media_id[4:] - ) + if NEW_FORMAT_ID_RE.match(media_id): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + return os.path.join( + self.base_path, "url_cache", + media_id[:10], media_id[11:] + ) + else: + return os.path.join( + self.base_path, "url_cache", + media_id[0:2], media_id[2:4], media_id[4:], + ) + + def url_cache_filepath_dirs_to_delete(self, media_id): + "The dirs to try and remove if we delete the media_id file" + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache", + media_id[:10], + ), + ] + else: + return [ + os.path.join( + self.base_path, "url_cache", + media_id[0:2], media_id[2:4], + ), + os.path.join( + self.base_path, "url_cache", + media_id[0:2], + ), + ] def url_cache_thumbnail(self, media_id, width, height, content_type, method): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + top_level_type, sub_type = content_type.split("/") file_name = "%i-%i-%s-%s-%s" % ( width, height, top_level_type, sub_type, method ) - return os.path.join( - self.base_path, "url_cache_thumbnails", - media_id[0:2], media_id[2:4], media_id[4:], - file_name - ) + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], media_id[11:], + file_name + ) + else: + return os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + file_name + ) + + def url_cache_thumbnail_directory(self, media_id): + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + + if NEW_FORMAT_ID_RE.match(media_id): + return os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], media_id[11:], + ) + else: + return os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + ) + + def url_cache_thumbnail_dirs_to_delete(self, media_id): + "The dirs to try and remove if we delete the media_id thumbnails" + # Media id is of the form <DATE><RANDOM_STRING> + # E.g.: 2017-09-28-fsdRDt24DS234dsf + if NEW_FORMAT_ID_RE.match(media_id): + return [ + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], media_id[11:], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[:10], + ), + ] + else: + return [ + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], media_id[4:], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], media_id[2:4], + ), + os.path.join( + self.base_path, "url_cache_thumbnails", + media_id[0:2], + ), + ] diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index b81a336c5d..895b480d5c 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -36,6 +36,9 @@ import cgi import ujson as json import urlparse import itertools +import datetime +import errno +import shutil import logging logger = logging.getLogger(__name__) @@ -70,6 +73,10 @@ class PreviewUrlResource(Resource): self.downloads = {} + self._cleaner_loop = self.clock.looping_call( + self._expire_url_cache_data, 10 * 1000 + ) + def render_GET(self, request): self._async_render_GET(request) return NOT_DONE_YET @@ -130,7 +137,7 @@ class PreviewUrlResource(Resource): cache_result = yield self.store.get_url_cache(url, ts) if ( cache_result and - cache_result["download_ts"] + cache_result["expires"] > ts and + cache_result["expires_ts"] > ts and cache_result["response_code"] / 100 == 2 ): respond_with_json_bytes( @@ -239,7 +246,7 @@ class PreviewUrlResource(Resource): url, media_info["response_code"], media_info["etag"], - media_info["expires"], + media_info["expires"] + media_info["created_ts"], json.dumps(og), media_info["filesystem_id"], media_info["created_ts"], @@ -253,8 +260,7 @@ class PreviewUrlResource(Resource): # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? - # XXX: horrible duplication with base_resource's _download_remote_file() - file_id = random_string(24) + file_id = datetime.date.today().isoformat() + '_' + random_string(16) fname = self.filepaths.url_cache_filepath(file_id) self.media_repo._makedirs(fname) @@ -328,6 +334,88 @@ class PreviewUrlResource(Resource): "etag": headers["ETag"][0] if "ETag" in headers else None, }) + @defer.inlineCallbacks + def _expire_url_cache_data(self): + """Clean up expired url cache content, media and thumbnails. + """ + now = self.clock.time_msec() + + # First we delete expired url cache entries + media_ids = yield self.store.get_expired_url_cache(now) + + removed_media = [] + for media_id in media_ids: + fname = self.filepaths.url_cache_filepath(media_id) + try: + os.remove(fname) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + removed_media.append(media_id) + + try: + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except: + pass + + yield self.store.delete_url_cache(removed_media) + + if removed_media: + logger.info("Deleted %d entries from url cache", len(removed_media)) + + # Now we delete old images associated with the url cache. + # These may be cached for a bit on the client (i.e., they + # may have a room open with a preview url thing open). + # So we wait a couple of days before deleting, just in case. + expire_before = now - 2 * 24 * 60 * 60 * 1000 + media_ids = yield self.store.get_url_cache_media_before(expire_before) + + removed_media = [] + for media_id in media_ids: + fname = self.filepaths.url_cache_filepath(media_id) + try: + os.remove(fname) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + try: + dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except: + pass + + thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id) + try: + shutil.rmtree(thumbnail_dir) + except OSError as e: + # If the path doesn't exist, meh + if e.errno != errno.ENOENT: + logger.warn("Failed to remove media: %r: %s", media_id, e) + continue + + removed_media.append(media_id) + + try: + dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id) + for dir in dirs: + os.rmdir(dir) + except: + pass + + yield self.store.delete_url_cache_media(removed_media) + + if removed_media: + logger.info("Deleted %d media from url cache", len(removed_media)) + def decode_and_calc_og(body, media_uri, request_encoding=None): from lxml import etree diff --git a/synapse/server.py b/synapse/server.py index 5b892cc390..10e3e9a4f1 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.crypto.keyring import Keyring from synapse.events.builder import EventBuilderFactory +from synapse.events.spamcheck import SpamChecker from synapse.federation import initialize_http_replication from synapse.federation.send_queue import FederationRemoteSendQueue from synapse.federation.transport.client import TransportLayerClient @@ -148,6 +149,7 @@ class HomeServer(object): 'groups_server_handler', 'groups_attestation_signing', 'groups_attestation_renewer', + 'spam_checker', ] def __init__(self, hostname, **kwargs): @@ -333,6 +335,9 @@ class HomeServer(object): def build_groups_attestation_renewer(self): return GroupAttestionRenewer(self) + def build_spam_checker(self): + return SpamChecker(self) + def remove_pusher(self, app_id, push_key, user_id): return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py index 3b5e0a4fb9..87aeaf71d6 100644 --- a/synapse/storage/keys.py +++ b/synapse/storage/keys.py @@ -113,30 +113,37 @@ class KeyStore(SQLBaseStore): keys[key_id] = key defer.returnValue(keys) - @defer.inlineCallbacks def store_server_verify_key(self, server_name, from_server, time_now_ms, verify_key): """Stores a NACL verification key for the given server. Args: server_name (str): The name of the server. - key_id (str): The version of the key for the server. from_server (str): Where the verification key was looked up - ts_now_ms (int): The time now in milliseconds - verification_key (VerifyKey): The NACL verify key. + time_now_ms (int): The time now in milliseconds + verify_key (nacl.signing.VerifyKey): The NACL verify key. """ - yield self._simple_upsert( - table="server_signature_keys", - keyvalues={ - "server_name": server_name, - "key_id": "%s:%s" % (verify_key.alg, verify_key.version), - }, - values={ - "from_server": from_server, - "ts_added_ms": time_now_ms, - "verify_key": buffer(verify_key.encode()), - }, - desc="store_server_verify_key", - ) + key_id = "%s:%s" % (verify_key.alg, verify_key.version) + + def _txn(txn): + self._simple_upsert_txn( + txn, + table="server_signature_keys", + keyvalues={ + "server_name": server_name, + "key_id": key_id, + }, + values={ + "from_server": from_server, + "ts_added_ms": time_now_ms, + "verify_key": buffer(verify_key.encode()), + }, + ) + txn.call_after( + self._get_server_verify_key.invalidate, + (server_name, key_id) + ) + + return self.runInteraction("store_server_verify_key", _txn) def store_server_keys_json(self, server_name, key_id, from_server, ts_now_ms, ts_expires_ms, key_json_bytes): diff --git a/synapse/storage/media_repository.py b/synapse/storage/media_repository.py index 82bb61b811..7110a71279 100644 --- a/synapse/storage/media_repository.py +++ b/synapse/storage/media_repository.py @@ -62,7 +62,7 @@ class MediaRepositoryStore(SQLBaseStore): def get_url_cache_txn(txn): # get the most recently cached result (relative to the given ts) sql = ( - "SELECT response_code, etag, expires, og, media_id, download_ts" + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" " FROM local_media_repository_url_cache" " WHERE url = ? AND download_ts <= ?" " ORDER BY download_ts DESC LIMIT 1" @@ -74,7 +74,7 @@ class MediaRepositoryStore(SQLBaseStore): # ...or if we've requested a timestamp older than the oldest # copy in the cache, return the oldest copy (if any) sql = ( - "SELECT response_code, etag, expires, og, media_id, download_ts" + "SELECT response_code, etag, expires_ts, og, media_id, download_ts" " FROM local_media_repository_url_cache" " WHERE url = ? AND download_ts > ?" " ORDER BY download_ts ASC LIMIT 1" @@ -86,14 +86,14 @@ class MediaRepositoryStore(SQLBaseStore): return None return dict(zip(( - 'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts' + 'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts' ), row)) return self.runInteraction( "get_url_cache", get_url_cache_txn ) - def store_url_cache(self, url, response_code, etag, expires, og, media_id, + def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id, download_ts): return self._simple_insert( "local_media_repository_url_cache", @@ -101,7 +101,7 @@ class MediaRepositoryStore(SQLBaseStore): "url": url, "response_code": response_code, "etag": etag, - "expires": expires, + "expires_ts": expires_ts, "og": og, "media_id": media_id, "download_ts": download_ts, @@ -238,3 +238,64 @@ class MediaRepositoryStore(SQLBaseStore): }, ) return self.runInteraction("delete_remote_media", delete_remote_media_txn) + + def get_expired_url_cache(self, now_ts): + sql = ( + "SELECT media_id FROM local_media_repository_url_cache" + " WHERE expires_ts < ?" + " ORDER BY expires_ts ASC" + " LIMIT 500" + ) + + def _get_expired_url_cache_txn(txn): + txn.execute(sql, (now_ts,)) + return [row[0] for row in txn] + + return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn) + + def delete_url_cache(self, media_ids): + sql = ( + "DELETE FROM local_media_repository_url_cache" + " WHERE media_id = ?" + ) + + def _delete_url_cache_txn(txn): + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction("delete_url_cache", _delete_url_cache_txn) + + def get_url_cache_media_before(self, before_ts): + sql = ( + "SELECT media_id FROM local_media_repository" + " WHERE created_ts < ? AND url_cache IS NOT NULL" + " ORDER BY created_ts ASC" + " LIMIT 500" + ) + + def _get_url_cache_media_before_txn(txn): + txn.execute(sql, (before_ts,)) + return [row[0] for row in txn] + + return self.runInteraction( + "get_url_cache_media_before", _get_url_cache_media_before_txn, + ) + + def delete_url_cache_media(self, media_ids): + def _delete_url_cache_media_txn(txn): + sql = ( + "DELETE FROM local_media_repository" + " WHERE media_id = ?" + ) + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + sql = ( + "DELETE FROM local_media_repository_thumbnails" + " WHERE media_id = ?" + ) + + txn.executemany(sql, [(media_id,) for media_id in media_ids]) + + return self.runInteraction( + "delete_url_cache_media", _delete_url_cache_media_txn, + ) diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index 72b670b83b..a0af8456f5 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) # Remember to update this number every time a change is made to database # schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 43 +SCHEMA_VERSION = 44 dir_path = os.path.abspath(os.path.dirname(__file__)) diff --git a/synapse/storage/schema/delta/44/expire_url_cache.sql b/synapse/storage/schema/delta/44/expire_url_cache.sql new file mode 100644 index 0000000000..e2b775f038 --- /dev/null +++ b/synapse/storage/schema/delta/44/expire_url_cache.sql @@ -0,0 +1,38 @@ +/* Copyright 2017 New Vector Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL; + +-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support +-- indices on expressions until 3.9. +CREATE TABLE local_media_repository_url_cache_new( + url TEXT, + response_code INTEGER, + etag TEXT, + expires_ts BIGINT, + og TEXT, + media_id TEXT, + download_ts BIGINT +); + +INSERT INTO local_media_repository_url_cache_new + SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache; + +DROP TABLE local_media_repository_url_cache; +ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache; + +CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts); +CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts); +CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id); diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py new file mode 100644 index 0000000000..4288312b8a --- /dev/null +++ b/synapse/util/module_loader.py @@ -0,0 +1,42 @@ +# -*- coding: utf-8 -*- +# Copyright 2017 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + +from synapse.config._base import ConfigError + + +def load_module(provider): + """ Loads a module with its config + Take a dict with keys 'module' (the module name) and 'config' + (the config dict). + + Returns + Tuple of (provider class, parsed config object) + """ + # We need to import the module, and then pick the class out of + # that, so we split based on the last dot. + module, clz = provider['module'].rsplit(".", 1) + module = importlib.import_module(module) + provider_class = getattr(module, clz) + + try: + provider_config = provider_class.parse_config(provider["config"]) + except Exception as e: + raise ConfigError( + "Failed to parse config for %r: %r" % (provider['module'], e) + ) + + return provider_class, provider_config |