diff options
author | Erik Johnston <erik@matrix.org> | 2015-02-19 10:38:48 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2015-02-19 10:38:48 +0000 |
commit | 8321e8a2e0381f52f3f434223db58f6ea280d89e (patch) | |
tree | 13c5600cbeb56c7c2837dd2df329f10a239f91ac /synapse | |
parent | Merge pull request #73 from matrix-org/hotfixes-v0.7.0f (diff) | |
parent | Update release date (diff) | |
download | synapse-8321e8a2e0381f52f3f434223db58f6ea280d89e.tar.xz |
Merge branch 'release-v0.7.1' of github.com:matrix-org/synapse
Diffstat (limited to '')
45 files changed, 1995 insertions, 304 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py index 792b8089f6..a804a25748 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py @@ -16,4 +16,4 @@ """ This is a reference implementation of a Matrix home server. """ -__version__ = "0.7.0f" +__version__ = "0.7.1" diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 0054745363..b176db8ce1 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -306,6 +306,34 @@ class Auth(object): # Can optionally look elsewhere in the request (e.g. headers) try: access_token = request.args["access_token"][0] + + # Check for application service tokens with a user_id override + try: + app_service = yield self.store.get_app_service_by_token( + access_token + ) + if not app_service: + raise KeyError + + user_id = app_service.sender + if "user_id" in request.args: + user_id = request.args["user_id"][0] + if not app_service.is_interested_in_user(user_id): + raise AuthError( + 403, + "Application service cannot masquerade as this user." + ) + + if not user_id: + raise KeyError + + defer.returnValue( + (UserID.from_string(user_id), ClientInfo("", "")) + ) + return + except KeyError: + pass # normal users won't have this query parameter set + user_info = yield self.get_user_by_token(access_token) user = user_info["user"] device_id = user_info["device_id"] @@ -344,8 +372,7 @@ class Auth(object): try: ret = yield self.store.get_user_by_token(token=token) if not ret: - raise StoreError() - + raise StoreError(400, "Unknown token") user_info = { "admin": bool(ret.get("admin", False)), "device_id": ret.get("device_id"), @@ -358,6 +385,18 @@ class Auth(object): raise AuthError(403, "Unrecognised access token.", errcode=Codes.UNKNOWN_TOKEN) + @defer.inlineCallbacks + def get_appservice_by_req(self, request): + try: + token = request.args["access_token"][0] + service = yield self.store.get_app_service_by_token(token) + if not service: + raise AuthError(403, "Unrecognised access token.", + errcode=Codes.UNKNOWN_TOKEN) + defer.returnValue(service) + except KeyError: + raise AuthError(403, "Missing access token.") + def is_server_admin(self, user): return self.store.is_server_admin(user) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 0d3fc629af..420f963d91 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -59,6 +59,7 @@ class LoginType(object): EMAIL_URL = u"m.login.email.url" EMAIL_IDENTITY = u"m.login.email.identity" RECAPTCHA = u"m.login.recaptcha" + APPLICATION_SERVICE = u"m.login.application_service" class EventTypes(object): diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 5041828f18..eddd889778 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -36,7 +36,8 @@ class Codes(object): CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" MISSING_PARAM = "M_MISSING_PARAM", - TOO_LARGE = "M_TOO_LARGE" + TOO_LARGE = "M_TOO_LARGE", + EXCLUSIVE = "M_EXCLUSIVE" class CodeMessageException(RuntimeError): diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 693c0efda6..9485719332 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -22,3 +22,4 @@ WEB_CLIENT_PREFIX = "/_matrix/client" CONTENT_REPO_PREFIX = "/_matrix/content" SERVER_KEY_PREFIX = "/_matrix/key/v1" MEDIA_PREFIX = "/_matrix/media/v1" +APP_SERVICE_PREFIX = "/_matrix/appservice/v1" diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index f5681fac20..ea20de1434 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -14,6 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys +sys.dont_write_bytecode = True + from synapse.storage import prepare_database, UpgradeDatabaseException from synapse.server import HomeServer @@ -26,13 +29,14 @@ from twisted.web.resource import Resource from twisted.web.static import File from twisted.web.server import Site from synapse.http.server import JsonResource, RootRedirect +from synapse.rest.appservice.v1 import AppServiceRestResource from synapse.rest.media.v0.content_repository import ContentRepoResource from synapse.rest.media.v1.media_repository import MediaRepositoryResource from synapse.http.server_key_resource import LocalKey from synapse.http.matrixfederationclient import MatrixFederationHttpClient from synapse.api.urls import ( CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX, - SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, + SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX ) from synapse.config.homeserver import HomeServerConfig from synapse.crypto import context_factory @@ -48,7 +52,7 @@ import synapse import logging import os import re -import sys +import subprocess import sqlite3 import syweb @@ -69,6 +73,9 @@ class SynapseHomeServer(HomeServer): def build_resource_for_federation(self): return JsonResource(self) + def build_resource_for_app_services(self): + return AppServiceRestResource(self) + def build_resource_for_web_client(self): syweb_path = os.path.dirname(syweb.__file__) webclient_path = os.path.join(syweb_path, "webclient") @@ -90,7 +97,9 @@ class SynapseHomeServer(HomeServer): "sqlite3", self.get_db_name(), check_same_thread=False, cp_min=1, - cp_max=1 + cp_max=1, + cp_openfun=prepare_database, # Prepare the database for each conn + # so that :memory: sqlite works ) def create_resource_tree(self, web_client, redirect_root_to_web_client): @@ -114,6 +123,7 @@ class SynapseHomeServer(HomeServer): (CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()), (SERVER_KEY_PREFIX, self.get_resource_for_server_key()), (MEDIA_PREFIX, self.get_resource_for_media_repository()), + (APP_SERVICE_PREFIX, self.get_resource_for_app_services()), ] if web_client: logger.info("Adding the web client.") @@ -199,6 +209,66 @@ class SynapseHomeServer(HomeServer): logger.info("Synapse now listening on port %d", unsecure_port) +def get_version_string(): + null = open(os.devnull, 'w') + cwd = os.path.dirname(os.path.abspath(__file__)) + try: + git_branch = subprocess.check_output( + ['git', 'rev-parse', '--abbrev-ref', 'HEAD'], + stderr=null, + cwd=cwd, + ).strip() + git_branch = "b=" + git_branch + except subprocess.CalledProcessError: + git_branch = "" + + try: + git_tag = subprocess.check_output( + ['git', 'describe', '--exact-match'], + stderr=null, + cwd=cwd, + ).strip() + git_tag = "t=" + git_tag + except subprocess.CalledProcessError: + git_tag = "" + + try: + git_commit = subprocess.check_output( + ['git', 'rev-parse', '--short', 'HEAD'], + stderr=null, + cwd=cwd, + ).strip() + except subprocess.CalledProcessError: + git_commit = "" + + try: + dirty_string = "-this_is_a_dirty_checkout" + is_dirty = subprocess.check_output( + ['git', 'describe', '--dirty=' + dirty_string], + stderr=null, + cwd=cwd, + ).strip().endswith(dirty_string) + + git_dirty = "dirty" if is_dirty else "" + except subprocess.CalledProcessError: + git_dirty = "" + + if git_branch or git_tag or git_commit or git_dirty: + git_version = ",".join( + s for s in + (git_branch, git_tag, git_commit, git_dirty,) + if s + ) + + return ( + "Synapse/%s (%s)" % ( + synapse.__version__, git_version, + ) + ).encode("ascii") + + return ("Synapse/%s" % (synapse.__version__,)).encode("ascii") + + def setup(): config = HomeServerConfig.load_config( "Synapse Homeserver", @@ -210,8 +280,10 @@ def setup(): check_requirements() + version_string = get_version_string() + logger.info("Server hostname: %s", config.server_name) - logger.info("Server version: %s", synapse.__version__) + logger.info("Server version: %s", version_string) if re.search(":[0-9]+$", config.server_name): domain_with_port = config.server_name @@ -228,6 +300,7 @@ def setup(): tls_context_factory=tls_context_factory, config=config, content_addr=config.content_addr, + version_string=version_string, ) hs.create_resource_tree( @@ -252,14 +325,6 @@ def setup(): logger.info("Database prepared in %s.", db_name) - db_pool = hs.get_db_pool() - - if db_name == ":memory:": - # Memory databases will need to be setup each time they are opened. - reactor.callWhenRunning( - db_pool.runWithConnection, prepare_database - ) - if config.manhole: f = twisted.manhole.telnet.ShellFactory() f.username = "matrix" @@ -270,12 +335,13 @@ def setup(): bind_port = config.bind_port if config.no_tls: bind_port = None + hs.start_listening(bind_port, config.unsecure_port) hs.get_pusherpool().start() - hs.get_state_handler().start_caching() hs.get_datastore().start_profiling() + hs.get_replication_layer().start_get_pdu_cache() if config.daemonize: print config.pid_file diff --git a/synapse/app/synctl.py b/synapse/app/synctl.py index 363c20f994..3a70a248dc 100755 --- a/synapse/app/synctl.py +++ b/synapse/app/synctl.py @@ -19,7 +19,7 @@ import os import subprocess import signal -SYNAPSE = ["python", "-m", "synapse.app.homeserver"] +SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"] CONFIGFILE = "homeserver.yaml" PIDFILE = "homeserver.pid" diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py new file mode 100644 index 0000000000..381b4cfc4a --- /dev/null +++ b/synapse/appservice/__init__.py @@ -0,0 +1,147 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.constants import EventTypes + +import logging +import re + +logger = logging.getLogger(__name__) + + +class ApplicationService(object): + """Defines an application service. This definition is mostly what is + provided to the /register AS API. + + Provides methods to check if this service is "interested" in events. + """ + NS_USERS = "users" + NS_ALIASES = "aliases" + NS_ROOMS = "rooms" + # The ordering here is important as it is used to map database values (which + # are stored as ints representing the position in this list) to namespace + # values. + NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS] + + def __init__(self, token, url=None, namespaces=None, hs_token=None, + sender=None, txn_id=None): + self.token = token + self.url = url + self.hs_token = hs_token + self.sender = sender + self.namespaces = self._check_namespaces(namespaces) + self.txn_id = txn_id + + def _check_namespaces(self, namespaces): + # Sanity check that it is of the form: + # { + # users: ["regex",...], + # aliases: ["regex",...], + # rooms: ["regex",...], + # } + if not namespaces: + return None + + for ns in ApplicationService.NS_LIST: + if type(namespaces[ns]) != list: + raise ValueError("Bad namespace value for '%s'", ns) + for regex in namespaces[ns]: + if not isinstance(regex, basestring): + raise ValueError("Expected string regex for ns '%s'", ns) + return namespaces + + def _matches_regex(self, test_string, namespace_key): + if not isinstance(test_string, basestring): + logger.error( + "Expected a string to test regex against, but got %s", + test_string + ) + return False + + for regex in self.namespaces[namespace_key]: + if re.match(regex, test_string): + return True + return False + + def _matches_user(self, event, member_list): + if (hasattr(event, "sender") and + self.is_interested_in_user(event.sender)): + return True + # also check m.room.member state key + if (hasattr(event, "type") and event.type == EventTypes.Member + and hasattr(event, "state_key") + and self.is_interested_in_user(event.state_key)): + return True + # check joined member events + for member in member_list: + if self.is_interested_in_user(member.state_key): + return True + return False + + def _matches_room_id(self, event): + if hasattr(event, "room_id"): + return self.is_interested_in_room(event.room_id) + return False + + def _matches_aliases(self, event, alias_list): + for alias in alias_list: + if self.is_interested_in_alias(alias): + return True + return False + + def is_interested(self, event, restrict_to=None, aliases_for_event=None, + member_list=None): + """Check if this service is interested in this event. + + Args: + event(Event): The event to check. + restrict_to(str): The namespace to restrict regex tests to. + aliases_for_event(list): A list of all the known room aliases for + this event. + member_list(list): A list of all joined room members in this room. + Returns: + bool: True if this service would like to know about this event. + """ + if aliases_for_event is None: + aliases_for_event = [] + if member_list is None: + member_list = [] + + if restrict_to and restrict_to not in ApplicationService.NS_LIST: + # this is a programming error, so fail early and raise a general + # exception + raise Exception("Unexpected restrict_to value: %s". restrict_to) + + if not restrict_to: + return (self._matches_user(event, member_list) + or self._matches_aliases(event, aliases_for_event) + or self._matches_room_id(event)) + elif restrict_to == ApplicationService.NS_ALIASES: + return self._matches_aliases(event, aliases_for_event) + elif restrict_to == ApplicationService.NS_ROOMS: + return self._matches_room_id(event) + elif restrict_to == ApplicationService.NS_USERS: + return self._matches_user(event, member_list) + + def is_interested_in_user(self, user_id): + return self._matches_regex(user_id, ApplicationService.NS_USERS) + + def is_interested_in_alias(self, alias): + return self._matches_regex(alias, ApplicationService.NS_ALIASES) + + def is_interested_in_room(self, room_id): + return self._matches_regex(room_id, ApplicationService.NS_ROOMS) + + def __str__(self): + return "ApplicationService: %s" % (self.__dict__,) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py new file mode 100644 index 0000000000..c2179f8d55 --- /dev/null +++ b/synapse/appservice/api.py @@ -0,0 +1,108 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from twisted.internet import defer + +from synapse.api.errors import CodeMessageException +from synapse.http.client import SimpleHttpClient +from synapse.events.utils import serialize_event + +import logging +import urllib + +logger = logging.getLogger(__name__) + + +class ApplicationServiceApi(SimpleHttpClient): + """This class manages HS -> AS communications, including querying and + pushing. + """ + + def __init__(self, hs): + super(ApplicationServiceApi, self).__init__(hs) + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def query_user(self, service, user_id): + uri = service.url + ("/users/%s" % urllib.quote(user_id)) + response = None + try: + response = yield self.get_json(uri, { + "access_token": service.hs_token + }) + if response is not None: # just an empty json object + defer.returnValue(True) + except CodeMessageException as e: + if e.code == 404: + defer.returnValue(False) + return + logger.warning("query_user to %s received %s", uri, e.code) + except Exception as ex: + logger.warning("query_user to %s threw exception %s", uri, ex) + defer.returnValue(False) + + @defer.inlineCallbacks + def query_alias(self, service, alias): + uri = service.url + ("/rooms/%s" % urllib.quote(alias)) + response = None + try: + response = yield self.get_json(uri, { + "access_token": service.hs_token + }) + if response is not None: # just an empty json object + defer.returnValue(True) + except CodeMessageException as e: + logger.warning("query_alias to %s received %s", uri, e.code) + if e.code == 404: + defer.returnValue(False) + return + except Exception as ex: + logger.warning("query_alias to %s threw exception %s", uri, ex) + defer.returnValue(False) + + @defer.inlineCallbacks + def push_bulk(self, service, events): + events = self._serialize(events) + + uri = service.url + ("/transactions/%s" % + urllib.quote(str(0))) # TODO txn_ids + response = None + try: + response = yield self.put_json( + uri=uri, + json_body={ + "events": events + }, + args={ + "access_token": service.hs_token + }) + if response: # just an empty json object + # TODO: Mark txn as sent successfully + defer.returnValue(True) + except CodeMessageException as e: + logger.warning("push_bulk to %s received %s", uri, e.code) + except Exception as ex: + logger.warning("push_bulk to %s threw exception %s", uri, ex) + defer.returnValue(False) + + @defer.inlineCallbacks + def push(self, service, event): + response = yield self.push_bulk(service, [event]) + defer.returnValue(response) + + def _serialize(self, events): + time_now = self.clock.time_msec() + return [ + serialize_event(e, time_now, as_client_event=True) for e in events + ] diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 3fb99f7125..828aced44a 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -22,6 +22,8 @@ from syutil.crypto.signing_key import ( from syutil.base64util import decode_base64, encode_base64 from synapse.api.errors import SynapseError, Codes +from synapse.util.retryutils import get_retry_limiter + from OpenSSL import crypto import logging @@ -87,12 +89,18 @@ class Keyring(object): return # Try to fetch the key from the remote server. - # TODO(markjh): Ratelimit requests to a given server. - (response, tls_certificate) = yield fetch_server_key( - server_name, self.hs.tls_context_factory + limiter = yield get_retry_limiter( + server_name, + self.clock, + self.store, ) + with limiter: + (response, tls_certificate) = yield fetch_server_key( + server_name, self.hs.tls_context_factory + ) + # Check the response. x509_certificate_bytes = crypto.dump_certificate( diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 70c9a6f46b..cd3c962d50 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -19,10 +19,13 @@ from twisted.internet import defer from .federation_base import FederationBase from .units import Edu -from synapse.api.errors import CodeMessageException +from synapse.api.errors import CodeMessageException, SynapseError +from synapse.util.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent +from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination + import logging @@ -30,6 +33,20 @@ logger = logging.getLogger(__name__) class FederationClient(FederationBase): + def __init__(self): + self._get_pdu_cache = None + + def start_get_pdu_cache(self): + self._get_pdu_cache = ExpiringCache( + cache_name="get_pdu_cache", + clock=self._clock, + max_len=1000, + expiry_ms=120*1000, + reset_expiry_on_get=False, + ) + + self._get_pdu_cache.start() + @log_function def send_pdu(self, pdu, destinations): """Informs the replication layer about a new PDU generated within the @@ -160,29 +177,58 @@ class FederationClient(FederationBase): # TODO: Rate limit the number of times we try and get the same event. + if self._get_pdu_cache: + e = self._get_pdu_cache.get(event_id) + if e: + defer.returnValue(e) + pdu = None for destination in destinations: try: - transaction_data = yield self.transport_layer.get_event( - destination, event_id + limiter = yield get_retry_limiter( + destination, + self._clock, + self.store, ) - logger.debug("transaction_data %r", transaction_data) + with limiter: + transaction_data = yield self.transport_layer.get_event( + destination, event_id + ) - pdu_list = [ - self.event_from_pdu_json(p, outlier=outlier) - for p in transaction_data["pdus"] - ] + logger.debug("transaction_data %r", transaction_data) - if pdu_list: - pdu = pdu_list[0] + pdu_list = [ + self.event_from_pdu_json(p, outlier=outlier) + for p in transaction_data["pdus"] + ] - # Check signatures are correct. - pdu = yield self._check_sigs_and_hash(pdu) + if pdu_list: + pdu = pdu_list[0] - break - except CodeMessageException: - raise + # Check signatures are correct. + pdu = yield self._check_sigs_and_hash(pdu) + + break + + except SynapseError: + logger.info( + "Failed to get PDU %s from %s because %s", + event_id, destination, e, + ) + continue + except CodeMessageException as e: + if 400 <= e.code < 500: + raise + + logger.info( + "Failed to get PDU %s from %s because %s", + event_id, destination, e, + ) + continue + except NotRetryingDestination as e: + logger.info(e.message) + continue except Exception as e: logger.info( "Failed to get PDU %s from %s because %s", @@ -190,6 +236,9 @@ class FederationClient(FederationBase): ) continue + if self._get_pdu_cache is not None: + self._get_pdu_cache[event_id] = pdu + defer.returnValue(pdu) @defer.inlineCallbacks diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 9f5c98694c..22b9663831 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -114,7 +114,15 @@ class FederationServer(FederationBase): with PreserveLoggingContext(): dl = [] for pdu in pdu_list: - dl.append(self._handle_new_pdu(transaction.origin, pdu)) + d = self._handle_new_pdu(transaction.origin, pdu) + + def handle_failure(failure): + failure.trap(FederationError) + self.send_failure(failure.value, transaction.origin) + + d.addErrback(handle_failure) + + dl.append(d) if hasattr(transaction, "edus"): for edu in [Edu(**x) for x in transaction.edus]: @@ -124,7 +132,10 @@ class FederationServer(FederationBase): edu.content ) - results = yield defer.DeferredList(dl) + for failure in getattr(transaction, "pdu_failures", []): + logger.info("Got failure %r", failure) + + results = yield defer.DeferredList(dl, consumeErrors=True) ret = [] for r in results: @@ -132,10 +143,16 @@ class FederationServer(FederationBase): ret.append({}) else: logger.exception(r[1]) - ret.append({"error": str(r[1])}) + ret.append({"error": str(r[1].value)}) logger.debug("Returning: %s", str(ret)) + response = { + "pdus": dict(zip( + (p.event_id for p in pdu_list), ret + )), + } + yield self.transaction_actions.set_response( transaction, 200, response @@ -331,7 +348,6 @@ class FederationServer(FederationBase): ) if already_seen: logger.debug("Already seen pdu %s", pdu.event_id) - defer.returnValue({}) return # Check signature. @@ -367,7 +383,13 @@ class FederationServer(FederationBase): pdu.room_id, min_depth ) - if min_depth and pdu.depth > min_depth and max_recursion > 0: + if min_depth and pdu.depth < min_depth: + # This is so that we don't notify the user about this + # message, to work around the fact that some events will + # reference really really old events we really don't want to + # send to the clients. + pdu.internal_metadata.outlier = True + elif min_depth and pdu.depth > min_depth and max_recursion > 0: for event_id, hashes in pdu.prev_events: if event_id not in have_seen: logger.debug( @@ -418,7 +440,7 @@ class FederationServer(FederationBase): except: logger.warn("Failed to get state for event: %s", pdu.event_id) - ret = yield self.handler.on_receive_pdu( + yield self.handler.on_receive_pdu( origin, pdu, backfilled=False, @@ -426,8 +448,6 @@ class FederationServer(FederationBase): auth_chain=auth_chain, ) - defer.returnValue(ret) - def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 731019ad9f..7d30c924d1 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -22,6 +22,9 @@ from .units import Transaction from synapse.api.errors import HttpResponseException from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext +from synapse.util.retryutils import ( + get_retry_limiter, NotRetryingDestination, +) import logging @@ -63,6 +66,26 @@ class TransactionQueue(object): # HACK to get unique tx id self._next_txn_id = int(self._clock.time_msec()) + def can_send_to(self, destination): + """Can we send messages to the given server? + + We can't send messages to ourselves. If we are running on localhost + then we can only federation with other servers running on localhost. + Otherwise we only federate with servers on a public domain. + + Args: + destination(str): The server we are possibly trying to send to. + Returns: + bool: True if we can send to the server. + """ + + if destination == self.server_name: + return False + if self.server_name.startswith("localhost"): + return destination.startswith("localhost") + else: + return not destination.startswith("localhost") + @defer.inlineCallbacks @log_function def enqueue_pdu(self, pdu, destinations, order): @@ -71,8 +94,9 @@ class TransactionQueue(object): # table and we'll get back to it later. destinations = set(destinations) - destinations.discard(self.server_name) - destinations.discard("localhost") + destinations = set( + dest for dest in destinations if self.can_send_to(dest) + ) logger.debug("Sending to: %s", str(destinations)) @@ -87,24 +111,27 @@ class TransactionQueue(object): (pdu, deferred, order) ) - def eb(failure): + def chain(failure): if not deferred.called: deferred.errback(failure) - else: - logger.warn("Failed to send pdu", failure) + + def log_failure(failure): + logger.warn("Failed to send pdu", failure.value) + + deferred.addErrback(log_failure) with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(eb) + self._attempt_new_transaction(destination).addErrback(chain) deferreds.append(deferred) - yield defer.DeferredList(deferreds) + yield defer.DeferredList(deferreds, consumeErrors=True) # NO inlineCallbacks def enqueue_edu(self, edu): destination = edu.destination - if destination == self.server_name: + if not self.can_send_to(destination): return deferred = defer.Deferred() @@ -112,51 +139,53 @@ class TransactionQueue(object): (edu, deferred) ) - def eb(failure): + def chain(failure): if not deferred.called: deferred.errback(failure) - else: - logger.warn("Failed to send edu", failure) + + def log_failure(failure): + logger.warn("Failed to send pdu", failure.value) + + deferred.addErrback(log_failure) with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(eb) + self._attempt_new_transaction(destination).addErrback(chain) return deferred @defer.inlineCallbacks def enqueue_failure(self, failure, destination): + if destination == self.server_name or destination == "localhost": + return + deferred = defer.Deferred() + if not self.can_send_to(destination): + return + self.pending_failures_by_dest.setdefault( destination, [] ).append( (failure, deferred) ) + def chain(f): + if not deferred.called: + deferred.errback(f) + + def log_failure(f): + logger.warn("Failed to send pdu", f.value) + + deferred.addErrback(log_failure) + + with PreserveLoggingContext(): + self._attempt_new_transaction(destination).addErrback(chain) + yield deferred @defer.inlineCallbacks @log_function def _attempt_new_transaction(self, destination): - - (retry_last_ts, retry_interval) = (0, 0) - retry_timings = yield self.store.get_destination_retry_timings( - destination - ) - if retry_timings: - (retry_last_ts, retry_interval) = ( - retry_timings.retry_last_ts, retry_timings.retry_interval - ) - if retry_last_ts + retry_interval > int(self._clock.time_msec()): - logger.info( - "TX [%s] not ready for retry yet - " - "dropping transaction for now", - destination, - ) - return - else: - logger.info("TX [%s] is ready for retry", destination) - if destination in self.pending_transactions: # XXX: pending_transactions can get stuck on by a never-ending # request at which point pending_pdus_by_dest just keeps growing. @@ -183,15 +212,6 @@ class TransactionQueue(object): logger.info("TX [%s] Nothing to send", destination) return - logger.debug( - "TX [%s] Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", - destination, - len(pending_pdus), - len(pending_edus), - len(pending_failures) - ) - # Sort based on the order field pending_pdus.sort(key=lambda t: t[2]) @@ -204,6 +224,21 @@ class TransactionQueue(object): ] try: + limiter = yield get_retry_limiter( + destination, + self._clock, + self.store, + ) + + logger.debug( + "TX [%s] Attempting new transaction" + " (pdus: %d, edus: %d, failures: %d)", + destination, + len(pending_pdus), + len(pending_edus), + len(pending_failures) + ) + self.pending_transactions[destination] = 1 logger.debug("TX [%s] Persisting transaction...", destination) @@ -229,52 +264,56 @@ class TransactionQueue(object): transaction.transaction_id, ) - # Actually send the transaction - - # FIXME (erikj): This is a bit of a hack to make the Pdu age - # keys work - def json_data_cb(): - data = transaction.get_dict() - now = int(self._clock.time_msec()) - if "pdus" in data: - for p in data["pdus"]: - if "age_ts" in p: - unsigned = p.setdefault("unsigned", {}) - unsigned["age"] = now - int(p["age_ts"]) - del p["age_ts"] - return data - - try: - response = yield self.transport_layer.send_transaction( - transaction, json_data_cb - ) - code = 200 - except HttpResponseException as e: - code = e.code - response = e.response - - logger.info("TX [%s] got %d response", destination, code) - - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as delivered...", destination) + with limiter: + # Actually send the transaction + + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def json_data_cb(): + data = transaction.get_dict() + now = int(self._clock.time_msec()) + if "pdus" in data: + for p in data["pdus"]: + if "age_ts" in p: + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] + return data + + try: + response = yield self.transport_layer.send_transaction( + transaction, json_data_cb + ) + code = 200 + + if response: + for e_id, r in getattr(response, "pdus", {}).items(): + if "error" in r: + logger.warn( + "Transaction returned error for %s: %s", + e_id, r, + ) + except HttpResponseException as e: + code = e.code + response = e.response + + logger.info("TX [%s] got %d response", destination, code) + + logger.debug("TX [%s] Sent transaction", destination) + logger.debug("TX [%s] Marking as delivered...", destination) yield self.transaction_actions.delivered( transaction, code, response ) logger.debug("TX [%s] Marked as delivered", destination) + logger.debug("TX [%s] Yielding to callbacks...", destination) for deferred in deferreds: if code == 200: - if retry_last_ts: - # this host is alive! reset retry schedule - yield self.store.set_destination_retry_timings( - destination, 0, 0 - ) deferred.callback(None) else: - self.set_retrying(destination, retry_interval) deferred.errback(RuntimeError("Got status %d" % code)) # Ensures we don't continue until all callbacks on that @@ -285,6 +324,12 @@ class TransactionQueue(object): pass logger.debug("TX [%s] Yielded to callbacks", destination) + except NotRetryingDestination: + logger.info( + "TX [%s] not ready for retry yet - " + "dropping transaction for now", + destination, + ) except RuntimeError as e: # We capture this here as there as nothing actually listens # for this finishing functions deferred. @@ -296,14 +341,12 @@ class TransactionQueue(object): except Exception as e: # We capture this here as there as nothing actually listens # for this finishing functions deferred. - logger.exception( + logger.warn( "TX [%s] Problem in _attempt_transaction: %s", destination, e, ) - self.set_retrying(destination, retry_interval) - for deferred in deferreds: if not deferred.called: deferred.errback(e) @@ -314,22 +357,3 @@ class TransactionQueue(object): # Check to see if there is anything else to send. self._attempt_new_transaction(destination) - - @defer.inlineCallbacks - def set_retrying(self, destination, retry_interval): - # track that this destination is having problems and we should - # give it a chance to recover before trying it again - - if retry_interval: - retry_interval *= 2 - # plateau at hourly retries for now - if retry_interval >= 60 * 60 * 1000: - retry_interval = 60 * 60 * 1000 - else: - retry_interval = 2000 # try again at first after 2 seconds - - yield self.store.set_destination_retry_timings( - destination, - int(self._clock.time_msec()), - retry_interval - ) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index a32eab9316..8d345bf936 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from synapse.appservice.api import ApplicationServiceApi from .register import RegistrationHandler from .room import ( RoomCreationHandler, RoomMemberHandler, RoomListHandler @@ -26,6 +27,7 @@ from .presence import PresenceHandler from .directory import DirectoryHandler from .typing import TypingNotificationHandler from .admin import AdminHandler +from .appservice import ApplicationServicesHandler from .sync import SyncHandler @@ -52,4 +54,7 @@ class Handlers(object): self.directory_handler = DirectoryHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs) self.admin_handler = AdminHandler(hs) + self.appservice_handler = ApplicationServicesHandler( + hs, ApplicationServiceApi(hs) + ) self.sync_handler = SyncHandler(hs) diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py new file mode 100644 index 0000000000..2c488a46f6 --- /dev/null +++ b/synapse/handlers/appservice.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, Membership +from synapse.api.errors import Codes, StoreError, SynapseError +from synapse.appservice import ApplicationService +from synapse.types import UserID +import synapse.util.stringutils as stringutils + +import logging + +logger = logging.getLogger(__name__) + + +# NB: Purposefully not inheriting BaseHandler since that contains way too much +# setup code which this handler does not need or use. This makes testing a lot +# easier. +class ApplicationServicesHandler(object): + + def __init__(self, hs, appservice_api): + self.store = hs.get_datastore() + self.hs = hs + self.appservice_api = appservice_api + + @defer.inlineCallbacks + def register(self, app_service): + logger.info("Register -> %s", app_service) + # check the token is recognised + try: + stored_service = yield self.store.get_app_service_by_token( + app_service.token + ) + if not stored_service: + raise StoreError(404, "Application service not found") + except StoreError: + raise SynapseError( + 403, "Unrecognised application services token. " + "Consult the home server admin.", + errcode=Codes.FORBIDDEN + ) + + app_service.hs_token = self._generate_hs_token() + + # create a sender for this application service which is used when + # creating rooms, etc.. + account = yield self.hs.get_handlers().registration_handler.register() + app_service.sender = account[0] + + yield self.store.update_app_service(app_service) + defer.returnValue(app_service) + + @defer.inlineCallbacks + def unregister(self, token): + logger.info("Unregister as_token=%s", token) + yield self.store.unregister_app_service(token) + + @defer.inlineCallbacks + def notify_interested_services(self, event): + """Notifies (pushes) all application services interested in this event. + + Pushing is done asynchronously, so this method won't block for any + prolonged length of time. + + Args: + event(Event): The event to push out to interested services. + """ + # Gather interested services + services = yield self._get_services_for_event(event) + if len(services) == 0: + return # no services need notifying + + # Do we know this user exists? If not, poke the user query API for + # all services which match that user regex. This needs to block as these + # user queries need to be made BEFORE pushing the event. + yield self._check_user_exists(event.sender) + if event.type == EventTypes.Member: + yield self._check_user_exists(event.state_key) + + # Fork off pushes to these services - XXX First cut, best effort + for service in services: + self.appservice_api.push(service, event) + + @defer.inlineCallbacks + def query_user_exists(self, user_id): + """Check if any application service knows this user_id exists. + + Args: + user_id(str): The user to query if they exist on any AS. + Returns: + True if this user exists on at least one application service. + """ + user_query_services = yield self._get_services_for_user( + user_id=user_id + ) + for user_service in user_query_services: + is_known_user = yield self.appservice_api.query_user( + user_service, user_id + ) + if is_known_user: + defer.returnValue(True) + defer.returnValue(False) + + @defer.inlineCallbacks + def query_room_alias_exists(self, room_alias): + """Check if an application service knows this room alias exists. + + Args: + room_alias(RoomAlias): The room alias to query. + Returns: + namedtuple: with keys "room_id" and "servers" or None if no + association can be found. + """ + room_alias_str = room_alias.to_string() + alias_query_services = yield self._get_services_for_event( + event=None, + restrict_to=ApplicationService.NS_ALIASES, + alias_list=[room_alias_str] + ) + for alias_service in alias_query_services: + is_known_alias = yield self.appservice_api.query_alias( + alias_service, room_alias_str + ) + if is_known_alias: + # the alias exists now so don't query more ASes. + result = yield self.store.get_association_from_room_alias( + room_alias + ) + defer.returnValue(result) + + @defer.inlineCallbacks + def _get_services_for_event(self, event, restrict_to="", alias_list=None): + """Retrieve a list of application services interested in this event. + + Args: + event(Event): The event to check. Can be None if alias_list is not. + restrict_to(str): The namespace to restrict regex tests to. + alias_list: A list of aliases to get services for. If None, this + list is obtained from the database. + Returns: + list<ApplicationService>: A list of services interested in this + event based on the service regex. + """ + member_list = None + if hasattr(event, "room_id"): + # We need to know the aliases associated with this event.room_id, + # if any. + if not alias_list: + alias_list = yield self.store.get_aliases_for_room( + event.room_id + ) + # We need to know the members associated with this event.room_id, + # if any. + member_list = yield self.store.get_room_members( + room_id=event.room_id, + membership=Membership.JOIN + ) + + services = yield self.store.get_app_services() + interested_list = [ + s for s in services if ( + s.is_interested(event, restrict_to, alias_list, member_list) + ) + ] + defer.returnValue(interested_list) + + @defer.inlineCallbacks + def _get_services_for_user(self, user_id): + services = yield self.store.get_app_services() + interested_list = [ + s for s in services if ( + s.is_interested_in_user(user_id) + ) + ] + defer.returnValue(interested_list) + + @defer.inlineCallbacks + def _is_unknown_user(self, user_id): + user = UserID.from_string(user_id) + if not self.hs.is_mine(user): + # we don't know if they are unknown or not since it isn't one of our + # users. We can't poke ASes. + defer.returnValue(False) + return + + user_info = yield self.store.get_user_by_id(user_id) + defer.returnValue(len(user_info) == 0) + + @defer.inlineCallbacks + def _check_user_exists(self, user_id): + unknown_user = yield self._is_unknown_user(user_id) + if unknown_user: + exists = yield self.query_user_exists(user_id) + defer.returnValue(exists) + defer.returnValue(True) + + def _generate_hs_token(self): + return stringutils.random_string(24) diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 7b60921040..20ab9e269c 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -37,18 +37,15 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks - def create_association(self, user_id, room_alias, room_id, servers=None): - - # TODO(erikj): Do auth. + def _create_association(self, room_alias, room_id, servers=None): + # general association creation for both human users and app services if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") # TODO(erikj): Change this. # TODO(erikj): Add transactions. - # TODO(erikj): Check if there is a current association. - if not servers: servers = yield self.store.get_joined_hosts_for_room(room_id) @@ -62,22 +59,77 @@ class DirectoryHandler(BaseHandler): ) @defer.inlineCallbacks + def create_association(self, user_id, room_alias, room_id, servers=None): + # association creation for human users + # TODO(erikj): Do user auth. + + can_create = yield self.can_modify_alias( + room_alias, + user_id=user_id + ) + if not can_create: + raise SynapseError( + 400, "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE + ) + yield self._create_association(room_alias, room_id, servers) + + @defer.inlineCallbacks + def create_appservice_association(self, service, room_alias, room_id, + servers=None): + if not service.is_interested_in_alias(room_alias.to_string()): + raise SynapseError( + 400, "This application service has not reserved" + " this kind of alias.", errcode=Codes.EXCLUSIVE + ) + + # association creation for app services + yield self._create_association(room_alias, room_id, servers) + + @defer.inlineCallbacks def delete_association(self, user_id, room_alias): + # association deletion for human users + # TODO Check if server admin + can_delete = yield self.can_modify_alias( + room_alias, + user_id=user_id + ) + if not can_delete: + raise SynapseError( + 400, "This alias is reserved by an application service.", + errcode=Codes.EXCLUSIVE + ) + + yield self._delete_association(room_alias) + + @defer.inlineCallbacks + def delete_appservice_association(self, service, room_alias): + if not service.is_interested_in_alias(room_alias.to_string()): + raise SynapseError( + 400, + "This application service has not reserved this kind of alias", + errcode=Codes.EXCLUSIVE + ) + yield self._delete_association(room_alias) + + @defer.inlineCallbacks + def _delete_association(self, room_alias): if not self.hs.is_mine(room_alias): raise SynapseError(400, "Room alias must be local") - room_id = yield self.store.delete_room_alias(room_alias) + yield self.store.delete_room_alias(room_alias) - if room_id: - yield self._update_room_alias_events(user_id, room_id) + # TODO - Looks like _update_room_alias_event has never been implemented + # if room_id: + # yield self._update_room_alias_events(user_id, room_id) @defer.inlineCallbacks def get_association(self, room_alias): room_id = None if self.hs.is_mine(room_alias): - result = yield self.store.get_association_from_room_alias( + result = yield self.get_association_from_room_alias( room_alias ) @@ -138,7 +190,7 @@ class DirectoryHandler(BaseHandler): 400, "Room Alias is not hosted on this Home Server" ) - result = yield self.store.get_association_from_room_alias( + result = yield self.get_association_from_room_alias( room_alias ) @@ -166,3 +218,27 @@ class DirectoryHandler(BaseHandler): "sender": user_id, "content": {"aliases": aliases}, }, ratelimit=False) + + @defer.inlineCallbacks + def get_association_from_room_alias(self, room_alias): + result = yield self.store.get_association_from_room_alias( + room_alias + ) + if not result: + # Query AS to see if it exists + as_handler = self.hs.get_handlers().appservice_handler + result = yield as_handler.query_room_alias_exists(room_alias) + defer.returnValue(result) + + @defer.inlineCallbacks + def can_modify_alias(self, alias, user_id=None): + services = yield self.store.get_app_services() + interested_services = [ + s for s in services if s.is_interested_in_alias(alias.to_string()) + ] + for service in interested_services: + if user_id == service.sender: + # this user IS the app service + defer.returnValue(True) + return + defer.returnValue(len(interested_services) == 0) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3aecc1ca51..0eb2ff95ca 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -802,7 +802,7 @@ class FederationHandler(BaseHandler): missing_auth = event_auth_events - seen_events if missing_auth: - logger.debug("Missing auth: %s", missing_auth) + logger.info("Missing auth: %s", missing_auth) # If we don't have all the auth events, we need to get them. try: remote_auth_chain = yield self.replication_layer.get_event_auth( @@ -856,7 +856,7 @@ class FederationHandler(BaseHandler): if different_auth and not event.internal_metadata.is_outlier(): # Do auth conflict res. - logger.debug("Different auth: %s", different_auth) + logger.info("Different auth: %s", different_auth) different_events = yield defer.gatherResults( [ @@ -892,6 +892,8 @@ class FederationHandler(BaseHandler): context.state_group = None if different_auth and not event.internal_metadata.is_outlier(): + logger.info("Different auth after resolution: %s", different_auth) + # Only do auth resolution if we have something new to say. # We can't rove an auth failure. do_resolution = False diff --git a/synapse/handlers/login.py b/synapse/handlers/login.py index d297d71c03..7447800460 100644 --- a/synapse/handlers/login.py +++ b/synapse/handlers/login.py @@ -16,12 +16,13 @@ from twisted.internet import defer from ._base import BaseHandler -from synapse.api.errors import LoginError, Codes +from synapse.api.errors import LoginError, Codes, CodeMessageException from synapse.http.client import SimpleHttpClient from synapse.util.emailutils import EmailException import synapse.util.emailutils as emailutils import bcrypt +import json import logging logger = logging.getLogger(__name__) @@ -96,16 +97,20 @@ class LoginHandler(BaseHandler): @defer.inlineCallbacks def _query_email(self, email): - httpCli = SimpleHttpClient(self.hs) - data = yield httpCli.get_json( - # TODO FIXME This should be configurable. - # XXX: ID servers need to use HTTPS - "http://%s%s" % ( - "matrix.org:8090", "/_matrix/identity/api/v1/lookup" - ), - { - 'medium': 'email', - 'address': email - } - ) - defer.returnValue(data) + http_client = SimpleHttpClient(self.hs) + try: + data = yield http_client.get_json( + # TODO FIXME This should be configurable. + # XXX: ID servers need to use HTTPS + "http://%s%s" % ( + "matrix.org:8090", "/_matrix/identity/api/v1/lookup" + ), + { + 'medium': 'email', + 'address': email + } + ) + defer.returnValue(data) + except CodeMessageException as e: + data = json.loads(e.msg) + defer.returnValue(data) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 59287010ed..8ef248ecf2 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -492,7 +492,7 @@ class PresenceHandler(BaseHandler): user, domain, remoteusers )) - yield defer.DeferredList(deferreds) + yield defer.DeferredList(deferreds, consumeErrors=True) def _start_polling_local(self, user, target_user): target_localpart = target_user.localpart @@ -548,7 +548,7 @@ class PresenceHandler(BaseHandler): self._stop_polling_remote(user, domain, remoteusers) ) - return defer.DeferredList(deferreds) + return defer.DeferredList(deferreds, consumeErrors=True) def _stop_polling_local(self, user, target_user): for localpart in self._local_pushmap.keys(): @@ -729,7 +729,7 @@ class PresenceHandler(BaseHandler): del self._remote_sendmap[user] with PreserveLoggingContext(): - yield defer.DeferredList(deferreds) + yield defer.DeferredList(deferreds, consumeErrors=True) @defer.inlineCallbacks def push_update_to_local_and_remote(self, observed_user, statuscache, @@ -768,7 +768,7 @@ class PresenceHandler(BaseHandler): ) ) - yield defer.DeferredList(deferreds) + yield defer.DeferredList(deferreds, consumeErrors=True) defer.returnValue((localusers, remote_domains)) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 0247327eb9..516a936cee 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -18,7 +18,8 @@ from twisted.internet import defer from synapse.types import UserID from synapse.api.errors import ( - SynapseError, RegistrationError, InvalidCaptchaError + AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError, + CodeMessageException ) from ._base import BaseHandler import synapse.util.stringutils as stringutils @@ -28,6 +29,7 @@ from synapse.http.client import CaptchaServerHttpClient import base64 import bcrypt +import json import logging logger = logging.getLogger(__name__) @@ -64,6 +66,8 @@ class RegistrationHandler(BaseHandler): user = UserID(localpart, self.hs.hostname) user_id = user.to_string() + yield self.check_user_id_is_valid(user_id) + token = self._generate_token(user_id) yield self.store.register( user_id=user_id, @@ -82,6 +86,7 @@ class RegistrationHandler(BaseHandler): localpart = self._generate_user_id() user = UserID(localpart, self.hs.hostname) user_id = user.to_string() + yield self.check_user_id_is_valid(user_id) token = self._generate_token(user_id) yield self.store.register( @@ -122,6 +127,27 @@ class RegistrationHandler(BaseHandler): defer.returnValue((user_id, token)) @defer.inlineCallbacks + def appservice_register(self, user_localpart, as_token): + user = UserID(user_localpart, self.hs.hostname) + user_id = user.to_string() + service = yield self.store.get_app_service_by_token(as_token) + if not service: + raise AuthError(403, "Invalid application service token.") + if not service.is_interested_in_user(user_id): + raise SynapseError( + 400, "Invalid user localpart for this application service.", + errcode=Codes.EXCLUSIVE + ) + token = self._generate_token(user_id) + yield self.store.register( + user_id=user_id, + token=token, + password_hash="" + ) + self.distributor.fire("registered_user", user) + defer.returnValue((user_id, token)) + + @defer.inlineCallbacks def check_recaptcha(self, ip, private_key, challenge, response): """Checks a recaptcha is correct.""" @@ -167,6 +193,20 @@ class RegistrationHandler(BaseHandler): # XXX: This should be a deferred list, shouldn't it? yield self._bind_threepid(c, user_id) + @defer.inlineCallbacks + def check_user_id_is_valid(self, user_id): + # valid user IDs must not clash with any user ID namespaces claimed by + # application services. + services = yield self.store.get_app_services() + interested_services = [ + s for s in services if s.is_interested_in_user(user_id) + ] + if len(interested_services) > 0: + raise SynapseError( + 400, "This user ID is reserved by an application service.", + errcode=Codes.EXCLUSIVE + ) + def _generate_token(self, user_id): # urlsafe variant uses _ and - so use . as the separator and replace # all =s with .s so http clients don't quote =s when it is used as @@ -181,21 +221,26 @@ class RegistrationHandler(BaseHandler): def _threepid_from_creds(self, creds): # TODO: get this from the homeserver rather than creating a new one for # each request - httpCli = SimpleHttpClient(self.hs) + http_client = SimpleHttpClient(self.hs) # XXX: make this configurable! trustedIdServers = ['matrix.org:8090', 'matrix.org'] if not creds['idServer'] in trustedIdServers: logger.warn('%s is not a trusted ID server: rejecting 3pid ' + 'credentials', creds['idServer']) defer.returnValue(None) - data = yield httpCli.get_json( - # XXX: This should be HTTPS - "http://%s%s" % ( - creds['idServer'], - "/_matrix/identity/api/v1/3pid/getValidated3pid" - ), - {'sid': creds['sid'], 'clientSecret': creds['clientSecret']} - ) + + data = {} + try: + data = yield http_client.get_json( + # XXX: This should be HTTPS + "http://%s%s" % ( + creds['idServer'], + "/_matrix/identity/api/v1/3pid/getValidated3pid" + ), + {'sid': creds['sid'], 'clientSecret': creds['clientSecret']} + ) + except CodeMessageException as e: + data = json.loads(e.msg) if 'medium' in data: defer.returnValue(data) @@ -205,19 +250,23 @@ class RegistrationHandler(BaseHandler): def _bind_threepid(self, creds, mxid): yield logger.debug("binding threepid") - httpCli = SimpleHttpClient(self.hs) - data = yield httpCli.post_urlencoded_get_json( - # XXX: Change when ID servers are all HTTPS - "http://%s%s" % ( - creds['idServer'], "/_matrix/identity/api/v1/3pid/bind" - ), - { - 'sid': creds['sid'], - 'clientSecret': creds['clientSecret'], - 'mxid': mxid, - } - ) - logger.debug("bound threepid") + http_client = SimpleHttpClient(self.hs) + data = None + try: + data = yield http_client.post_urlencoded_get_json( + # XXX: Change when ID servers are all HTTPS + "http://%s%s" % ( + creds['idServer'], "/_matrix/identity/api/v1/3pid/bind" + ), + { + 'sid': creds['sid'], + 'clientSecret': creds['clientSecret'], + 'mxid': mxid, + } + ) + logger.debug("bound threepid") + except CodeMessageException as e: + data = json.loads(e.msg) defer.returnValue(data) @defer.inlineCallbacks diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index c69787005f..c2762f92c7 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -181,7 +181,7 @@ class TypingNotificationHandler(BaseHandler): }, )) - yield defer.DeferredList(deferreds, consumeErrors=False) + yield defer.DeferredList(deferreds, consumeErrors=True) @defer.inlineCallbacks def _recv_edu(self, origin, content): diff --git a/synapse/http/client.py b/synapse/http/client.py index 43b2329780..22b7145ac1 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from synapse.http.agent_name import AGENT_NAME +from synapse.api.errors import CodeMessageException from syutil.jsonutil import encode_canonical_json from twisted.internet import defer, reactor @@ -44,6 +43,7 @@ class SimpleHttpClient(object): # BrowserLikePolicyForHTTPS which will do regular cert validation # 'like a browser' self.agent = Agent(reactor) + self.version_string = hs.version_string @defer.inlineCallbacks def post_urlencoded_get_json(self, uri, args={}): @@ -55,7 +55,7 @@ class SimpleHttpClient(object): uri.encode("ascii"), headers=Headers({ b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [AGENT_NAME], + b"User-Agent": [self.version_string], }), bodyProducer=FileBodyProducer(StringIO(query_bytes)) ) @@ -85,7 +85,7 @@ class SimpleHttpClient(object): @defer.inlineCallbacks def get_json(self, uri, args={}): - """ Get's some json from the given host and path + """ Gets some json from the given URI. Args: uri (str): The URI to request, not including query parameters @@ -93,15 +93,13 @@ class SimpleHttpClient(object): None. **Note**: The value of each key is assumed to be an iterable and *not* a string. - Returns: - Deferred: Succeeds when we get *any* HTTP response. - - The result of the deferred is a tuple of `(code, response)`, - where `response` is a dict representing the decoded JSON body. + Deferred: Succeeds when we get *any* 2xx HTTP response, with the + HTTP body as JSON. + Raises: + On a non-2xx HTTP response. The response body will be used as the + error message. """ - - yield if len(args): query_bytes = urllib.urlencode(args, True) uri = "%s?%s" % (uri, query_bytes) @@ -110,13 +108,62 @@ class SimpleHttpClient(object): "GET", uri.encode("ascii"), headers=Headers({ - b"User-Agent": [AGENT_NAME], + b"User-Agent": [self.version_string], }) ) body = yield readBody(response) - defer.returnValue(json.loads(body)) + if 200 <= response.code < 300: + defer.returnValue(json.loads(body)) + else: + # NB: This is explicitly not json.loads(body)'d because the contract + # of CodeMessageException is a *string* message. Callers can always + # load it into JSON if they want. + raise CodeMessageException(response.code, body) + + @defer.inlineCallbacks + def put_json(self, uri, json_body, args={}): + """ Puts some json to the given URI. + + Args: + uri (str): The URI to request, not including query parameters + json_body (dict): The JSON to put in the HTTP body, + args (dict): A dictionary used to create query strings, defaults to + None. + **Note**: The value of each key is assumed to be an iterable + and *not* a string. + Returns: + Deferred: Succeeds when we get *any* 2xx HTTP response, with the + HTTP body as JSON. + Raises: + On a non-2xx HTTP response. + """ + if len(args): + query_bytes = urllib.urlencode(args, True) + uri = "%s?%s" % (uri, query_bytes) + + json_str = json.dumps(json_body) + + response = yield self.agent.request( + "PUT", + uri.encode("ascii"), + headers=Headers({ + b"User-Agent": [self.version_string], + "Content-Type": ["application/json"] + }), + bodyProducer=FileBodyProducer(StringIO(json_str)) + ) + + body = yield readBody(response) + + if 200 <= response.code < 300: + defer.returnValue(json.loads(body)) + else: + # NB: This is explicitly not json.loads(body)'d because the contract + # of CodeMessageException is a *string* message. Callers can always + # load it into JSON if they want. + raise CodeMessageException(response.code, body) class CaptchaServerHttpClient(SimpleHttpClient): @@ -135,7 +182,7 @@ class CaptchaServerHttpClient(SimpleHttpClient): bodyProducer=FileBodyProducer(StringIO(query_bytes)), headers=Headers({ b"Content-Type": [b"application/x-www-form-urlencoded"], - b"User-Agent": [AGENT_NAME], + b"User-Agent": [self.version_string], }) ) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 1927948001..7db001cc63 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -20,7 +20,6 @@ from twisted.web.client import readBody, _AgentBase, _URI from twisted.web.http_headers import Headers from twisted.web._newclient import ResponseDone -from synapse.http.agent_name import AGENT_NAME from synapse.http.endpoint import matrix_federation_endpoint from synapse.util.async import sleep from synapse.util.logcontext import PreserveLoggingContext @@ -80,6 +79,7 @@ class MatrixFederationHttpClient(object): self.server_name = hs.hostname self.agent = MatrixFederationHttpAgent(reactor) self.clock = hs.get_clock() + self.version_string = hs.version_string @defer.inlineCallbacks def _create_request(self, destination, method, path_bytes, @@ -87,7 +87,7 @@ class MatrixFederationHttpClient(object): query_bytes=b"", retry_on_dns_fail=True): """ Creates and sends a request to the given url """ - headers_dict[b"User-Agent"] = [AGENT_NAME] + headers_dict[b"User-Agent"] = [self.version_string] headers_dict[b"Host"] = [destination] url_bytes = urlparse.urlunparse( @@ -144,16 +144,16 @@ class MatrixFederationHttpClient(object): destination, e ) - raise SynapseError(400, "Domain specified not found.") + raise logger.warn( - "Sending request failed to %s: %s %s : %s", + "Sending request failed to %s: %s %s: %s - %s", destination, method, url_bytes, - e + type(e).__name__, + _flatten_response_never_received(e), ) - _print_ex(e) if retries_left: yield sleep(2 ** (5 - retries_left)) @@ -447,14 +447,6 @@ def _readBodyToFile(response, stream, max_size): return d -def _print_ex(e): - if hasattr(e, "reasons") and e.reasons: - for ex in e.reasons: - _print_ex(ex) - else: - logger.warn(e) - - class _JsonProducer(object): """ Used by the twisted http client to create the HTTP body from json """ @@ -474,3 +466,13 @@ class _JsonProducer(object): def stopProducing(self): pass + + +def _flatten_response_never_received(e): + if hasattr(e, "reasons"): + return ", ".join( + _flatten_response_never_received(f.value) + for f in e.reasons + ) + else: + return "%s: %s" % (type(e).__name__, e.message,) diff --git a/synapse/http/server.py b/synapse/http/server.py index 589c30c199..74a101a5d7 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -14,7 +14,6 @@ # limitations under the License. -from synapse.http.agent_name import AGENT_NAME from synapse.api.errors import ( cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError ) @@ -74,6 +73,7 @@ class JsonResource(HttpServer, resource.Resource): self.clock = hs.get_clock() self.path_regexs = {} + self.version_string = hs.version_string def register_path(self, method, path_pattern, callback): self.path_regexs.setdefault(method, []).append( @@ -189,9 +189,13 @@ class JsonResource(HttpServer, resource.Resource): return # TODO: Only enable CORS for the requests that need it. - respond_with_json(request, code, response_json_object, send_cors=True, - response_code_message=response_code_message, - pretty_print=self._request_user_agent_is_curl) + respond_with_json( + request, code, response_json_object, + send_cors=True, + response_code_message=response_code_message, + pretty_print=self._request_user_agent_is_curl, + version_string=self.version_string, + ) @staticmethod def _request_user_agent_is_curl(request): @@ -221,18 +225,23 @@ class RootRedirect(resource.Resource): def respond_with_json(request, code, json_object, send_cors=False, - response_code_message=None, pretty_print=False): + response_code_message=None, pretty_print=False, + version_string=""): if not pretty_print: json_bytes = encode_pretty_printed_json(json_object) else: json_bytes = encode_canonical_json(json_object) - return respond_with_json_bytes(request, code, json_bytes, send_cors, - response_code_message=response_code_message) + return respond_with_json_bytes( + request, code, json_bytes, + send_cors=send_cors, + response_code_message=response_code_message, + version_string=version_string + ) def respond_with_json_bytes(request, code, json_bytes, send_cors=False, - response_code_message=None): + version_string="", response_code_message=None): """Sends encoded JSON in response to the given request. Args: @@ -246,7 +255,7 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False, request.setResponseCode(code, message=response_code_message) request.setHeader(b"Content-Type", b"application/json") - request.setHeader(b"Server", AGENT_NAME) + request.setHeader(b"Server", version_string) request.setHeader(b"Content-Length", b"%d" % (len(json_bytes),)) if send_cors: diff --git a/synapse/http/server_key_resource.py b/synapse/http/server_key_resource.py index 4fc491dc82..71e9a51f5c 100644 --- a/synapse/http/server_key_resource.py +++ b/synapse/http/server_key_resource.py @@ -50,6 +50,7 @@ class LocalKey(Resource): def __init__(self, hs): self.hs = hs + self.version_string = hs.version_string self.response_body = encode_canonical_json( self.response_json_object(hs.config) ) @@ -82,7 +83,10 @@ class LocalKey(Resource): return json_object def render_GET(self, request): - return respond_with_json_bytes(request, 200, self.response_body) + return respond_with_json_bytes( + request, 200, self.response_body, + version_string=self.version_string + ) def getChild(self, name, request): if name == '': diff --git a/synapse/notifier.py b/synapse/notifier.py index e3b6ead620..2475f3ffbe 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -99,6 +99,12 @@ class Notifier(object): `extra_users` param. """ yield run_on_reactor() + + # poke any interested application service. + self.hs.get_handlers().appservice_handler.notify_interested_services( + event + ) + room_id = event.room_id room_source = self.event_sources.sources["room"] @@ -135,7 +141,8 @@ class Notifier(object): with PreserveLoggingContext(): yield defer.DeferredList( - [notify(l).addErrback(eb) for l in listeners] + [notify(l).addErrback(eb) for l in listeners], + consumeErrors=True, ) @defer.inlineCallbacks @@ -203,7 +210,8 @@ class Notifier(object): with PreserveLoggingContext(): yield defer.DeferredList( - [notify(l).addErrback(eb) for l in listeners] + [notify(l).addErrback(eb) for l in listeners], + consumeErrors=True, ) @defer.inlineCallbacks diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 5788db4eba..202fcb42f4 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -66,6 +66,7 @@ class HttpPusher(Pusher): d = { 'notification': { 'id': event['event_id'], + 'room_id': event['room_id'], 'type': event['type'], 'sender': event['user_id'], 'counts': { # -- we don't mark messages as read yet so diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index ec78fc3627..c71df75404 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -5,7 +5,7 @@ logger = logging.getLogger(__name__) REQUIREMENTS = { "syutil>=0.0.3": ["syutil"], - "matrix_angular_sdk>=0.6.2": ["syweb>=0.6.2"], + "matrix_angular_sdk>=0.6.3": ["syweb>=0.6.3"], "Twisted==14.0.2": ["twisted==14.0.2"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], "pyopenssl>=0.14": ["OpenSSL>=0.14"], @@ -25,20 +25,20 @@ def github_link(project, version, egg): DEPENDENCY_LINKS = [ github_link( + project="pyca/pynacl", + version="d4d3175589b892f6ea7c22f466e0e223853516fa", + egg="pynacl-0.3.0", + ), + github_link( project="matrix-org/syutil", version="v0.0.3", egg="syutil-0.0.3", ), github_link( project="matrix-org/matrix-angular-sdk", - version="v0.6.2", - egg="matrix_angular_sdk-0.6.2", + version="v0.6.3", + egg="matrix_angular_sdk-0.6.3", ), - github_link( - project="pyca/pynacl", - version="d4d3175589b892f6ea7c22f466e0e223853516fa", - egg="pynacl-0.3.0", - ) ] diff --git a/synapse/http/agent_name.py b/synapse/rest/appservice/__init__.py index d761890863..1a84d94cd9 100644 --- a/synapse/http/agent_name.py +++ b/synapse/rest/appservice/__init__.py @@ -1,5 +1,5 @@ # -*- coding: utf-8 -*- -# Copyright 2014, 2015 OpenMarket Ltd +# Copyright 2015 OpenMarket Ltd # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from synapse import __version__ - -AGENT_NAME = ("Synapse/%s" % (__version__,)).encode("ascii") diff --git a/synapse/rest/appservice/v1/__init__.py b/synapse/rest/appservice/v1/__init__.py new file mode 100644 index 0000000000..a7877609ad --- /dev/null +++ b/synapse/rest/appservice/v1/__init__.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from . import register + +from synapse.http.server import JsonResource + + +class AppServiceRestResource(JsonResource): + """A resource for version 1 of the matrix application service API.""" + + def __init__(self, hs): + JsonResource.__init__(self, hs) + self.register_servlets(self, hs) + + @staticmethod + def register_servlets(appservice_resource, hs): + register.register_servlets(hs, appservice_resource) diff --git a/synapse/rest/appservice/v1/base.py b/synapse/rest/appservice/v1/base.py new file mode 100644 index 0000000000..65d5bcf9be --- /dev/null +++ b/synapse/rest/appservice/v1/base.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains base REST classes for constructing client v1 servlets. +""" + +from synapse.http.servlet import RestServlet +from synapse.api.urls import APP_SERVICE_PREFIX +import re + +import logging + + +logger = logging.getLogger(__name__) + + +def as_path_pattern(path_regex): + """Creates a regex compiled appservice path with the correct path + prefix. + + Args: + path_regex (str): The regex string to match. This should NOT have a ^ + as this will be prefixed. + Returns: + SRE_Pattern + """ + return re.compile("^" + APP_SERVICE_PREFIX + path_regex) + + +class AppServiceRestServlet(RestServlet): + """A base Synapse REST Servlet for the application services version 1 API. + """ + + def __init__(self, hs): + self.hs = hs + self.handler = hs.get_handlers().appservice_handler diff --git a/synapse/rest/appservice/v1/register.py b/synapse/rest/appservice/v1/register.py new file mode 100644 index 0000000000..3bd0c1220c --- /dev/null +++ b/synapse/rest/appservice/v1/register.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains REST servlets to do with registration: /register""" +from twisted.internet import defer + +from base import AppServiceRestServlet, as_path_pattern +from synapse.api.errors import CodeMessageException, SynapseError +from synapse.storage.appservice import ApplicationService + +import json +import logging + +logger = logging.getLogger(__name__) + + +class RegisterRestServlet(AppServiceRestServlet): + """Handles AS registration with the home server. + """ + + PATTERN = as_path_pattern("/register$") + + @defer.inlineCallbacks + def on_POST(self, request): + params = _parse_json(request) + + # sanity check required params + try: + as_token = params["as_token"] + as_url = params["url"] + if (not isinstance(as_token, basestring) or + not isinstance(as_url, basestring)): + raise ValueError + except (KeyError, ValueError): + raise SynapseError( + 400, "Missed required keys: as_token(str) / url(str)." + ) + + namespaces = { + "users": [], + "rooms": [], + "aliases": [] + } + + if "namespaces" in params: + self._parse_namespace(namespaces, params["namespaces"], "users") + self._parse_namespace(namespaces, params["namespaces"], "rooms") + self._parse_namespace(namespaces, params["namespaces"], "aliases") + + app_service = ApplicationService(as_token, as_url, namespaces) + + app_service = yield self.handler.register(app_service) + hs_token = app_service.hs_token + + defer.returnValue((200, { + "hs_token": hs_token + })) + + def _parse_namespace(self, target_ns, origin_ns, ns): + if ns not in target_ns or ns not in origin_ns: + return # nothing to parse / map through to. + + possible_regex_list = origin_ns[ns] + if not type(possible_regex_list) == list: + raise SynapseError(400, "Namespace %s isn't an array." % ns) + + for regex in possible_regex_list: + if not isinstance(regex, basestring): + raise SynapseError( + 400, "Regex '%s' isn't a string in namespace %s" % + (regex, ns) + ) + + target_ns[ns] = origin_ns[ns] + + +class UnregisterRestServlet(AppServiceRestServlet): + """Handles AS registration with the home server. + """ + + PATTERN = as_path_pattern("/unregister$") + + def on_POST(self, request): + params = _parse_json(request) + try: + as_token = params["as_token"] + if not isinstance(as_token, basestring): + raise ValueError + except (KeyError, ValueError): + raise SynapseError(400, "Missing required key: as_token(str)") + + yield self.handler.unregister(as_token) + + raise CodeMessageException(500, "Not implemented") + + +def _parse_json(request): + try: + content = json.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.") + return content + except ValueError: + raise SynapseError(400, "Content not JSON.") + + +def register_servlets(hs, http_server): + RegisterRestServlet(hs).register(http_server) + UnregisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 420aa89f38..6758a888b3 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -45,8 +45,6 @@ class ClientDirectoryServer(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_alias): - user, client = yield self.auth.get_user_by_req(request) - content = _parse_json(request) if "room_id" not in content: raise SynapseError(400, "Missing room_id key", @@ -70,34 +68,70 @@ class ClientDirectoryServer(ClientV1RestServlet): dir_handler = self.handlers.directory_handler try: - user_id = user.to_string() - yield dir_handler.create_association( - user_id, room_alias, room_id, servers + # try to auth as a user + user, client = yield self.auth.get_user_by_req(request) + try: + user_id = user.to_string() + yield dir_handler.create_association( + user_id, room_alias, room_id, servers + ) + yield dir_handler.send_room_alias_update_event(user_id, room_id) + except SynapseError as e: + raise e + except: + logger.exception("Failed to create association") + raise + except AuthError: + # try to auth as an application service + service = yield self.auth.get_appservice_by_req(request) + yield dir_handler.create_appservice_association( + service, room_alias, room_id, servers + ) + logger.info( + "Application service at %s created alias %s pointing to %s", + service.url, + room_alias.to_string(), + room_id ) - yield dir_handler.send_room_alias_update_event(user_id, room_id) - except SynapseError as e: - raise e - except: - logger.exception("Failed to create association") - raise defer.returnValue((200, {})) @defer.inlineCallbacks def on_DELETE(self, request, room_alias): + dir_handler = self.handlers.directory_handler + + try: + service = yield self.auth.get_appservice_by_req(request) + room_alias = RoomAlias.from_string(room_alias) + yield dir_handler.delete_appservice_association( + service, room_alias + ) + logger.info( + "Application service at %s deleted alias %s", + service.url, + room_alias.to_string() + ) + defer.returnValue((200, {})) + except AuthError: + # fallback to default user behaviour if they aren't an AS + pass + user, client = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(user) if not is_admin: raise AuthError(403, "You need to be a server admin") - dir_handler = self.handlers.directory_handler - room_alias = RoomAlias.from_string(room_alias) yield dir_handler.delete_association( user.to_string(), room_alias ) + logger.info( + "User %s deleted alias %s", + user.to_string(), + room_alias.to_string() + ) defer.returnValue((200, {})) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index d3399c446b..8d2115082b 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -110,7 +110,8 @@ class RegisterRestServlet(ClientV1RestServlet): stages = { LoginType.RECAPTCHA: self._do_recaptcha, LoginType.PASSWORD: self._do_password, - LoginType.EMAIL_IDENTITY: self._do_email_identity + LoginType.EMAIL_IDENTITY: self._do_email_identity, + LoginType.APPLICATION_SERVICE: self._do_app_service } session_info = self._get_session_info(request, session) @@ -276,6 +277,27 @@ class RegisterRestServlet(ClientV1RestServlet): self._remove_session(session) defer.returnValue(result) + @defer.inlineCallbacks + def _do_app_service(self, request, register_json, session): + if "access_token" not in request.args: + raise SynapseError(400, "Expected application service token.") + if "user" not in register_json: + raise SynapseError(400, "Expected 'user' key.") + + as_token = request.args["access_token"][0] + user_localpart = register_json["user"].encode("utf-8") + + handler = self.handlers.registration_handler + (user_id, token) = yield handler.appservice_register( + user_localpart, as_token + ) + self._remove_session(session) + defer.returnValue({ + "user_id": user_id, + "access_token": token, + "home_server": self.hs.hostname, + }) + def _parse_json(request): try: diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py index d44d5f1298..b10cbddb81 100644 --- a/synapse/rest/media/v1/base_resource.py +++ b/synapse/rest/media/v1/base_resource.py @@ -54,7 +54,7 @@ class BaseMediaResource(Resource): try: yield request_handler(self, request) except CodeMessageException as e: - logger.exception(e) + logger.info("Responding with error: %r", e) respond_with_json( request, e.code, cs_exception(e), send_cors=True ) diff --git a/synapse/server.py b/synapse/server.py index 23bdad0c7c..ba2b2593f1 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -77,6 +77,7 @@ class BaseHomeServer(object): 'resource_for_content_repo', 'resource_for_server_key', 'resource_for_media_repository', + 'resource_for_app_services', 'event_sources', 'ratelimiter', 'keyring', diff --git a/synapse/state.py b/synapse/state.py index fe5f3dc84b..80cced351d 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor +from synapse.util.expiringcache import ExpiringCache from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext @@ -51,7 +52,6 @@ class _StateCacheEntry(object): def __init__(self, state, state_group, ts): self.state = state self.state_group = state_group - self.ts = ts class StateHandler(object): @@ -69,12 +69,15 @@ class StateHandler(object): def start_caching(self): logger.debug("start_caching") - self._state_cache = {} - - def f(): - self._prune_cache() + self._state_cache = ExpiringCache( + cache_name="state_cache", + clock=self.clock, + max_len=SIZE_OF_CACHE, + expiry_ms=EVICTION_TIMEOUT_SECONDS*1000, + reset_expiry_on_get=True, + ) - self.clock.looping_call(f, 5*1000) + self._state_cache.start() @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): @@ -409,34 +412,3 @@ class StateHandler(object): return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() return sorted(events, key=key_func) - - def _prune_cache(self): - logger.debug( - "_prune_cache. before len: %d", - len(self._state_cache.keys()) - ) - - now = self.clock.time_msec() - - if len(self._state_cache.keys()) > SIZE_OF_CACHE: - sorted_entries = sorted( - self._state_cache.items(), - key=lambda k, v: v.ts, - ) - - for k, _ in sorted_entries[SIZE_OF_CACHE:]: - self._state_cache.pop(k) - - keys_to_delete = set() - - for key, cache_entry in self._state_cache.items(): - if now - cache_entry.ts > EVICTION_TIMEOUT_SECONDS*1000: - keys_to_delete.add(key) - - for k in keys_to_delete: - self._state_cache.pop(k) - - logger.debug( - "_prune_cache. after len: %d", - len(self._state_cache.keys()) - ) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index 02b1f06854..d16e7b8fac 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.api.constants import EventTypes +from .appservice import ApplicationServiceStore from .directory import DirectoryStore from .feedback import FeedbackStore from .presence import PresenceStore @@ -65,6 +66,7 @@ SCHEMAS = [ "event_signatures", "pusher", "media_repository", + "application_services", "filtering", "rejections", ] @@ -72,7 +74,9 @@ SCHEMAS = [ # Remember to update this number every time an incompatible change is made to # database schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 12 +SCHEMA_VERSION = 13 + +dir_path = os.path.abspath(os.path.dirname(__file__)) class _RollbackButIsFineException(Exception): @@ -86,6 +90,7 @@ class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore, PresenceStore, TransactionStore, DirectoryStore, KeyStore, StateStore, SignatureStore, + ApplicationServiceStore, EventFederationStore, MediaRepositoryStore, RejectionsStore, @@ -580,7 +585,6 @@ def schema_path(schema): A filesystem path pointing at a ".sql" file. """ - dir_path = os.path.dirname(__file__) schemaPath = os.path.join(dir_path, "schema", schema + ".sql") return schemaPath @@ -637,10 +641,13 @@ def prepare_database(db_conn): c.executescript(sql_script) db_conn.commit() + else: + logger.info("Database is at version %r", user_version) else: sql_script = "BEGIN TRANSACTION;\n" for sql_loc in SCHEMAS: + logger.debug("Applying schema %r", sql_loc) sql_script += read_schema(sql_loc) sql_script += "\n" sql_script += "COMMIT TRANSACTION;" diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py new file mode 100644 index 0000000000..d941b1f387 --- /dev/null +++ b/synapse/storage/appservice.py @@ -0,0 +1,244 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from twisted.internet import defer + +from synapse.api.errors import StoreError +from synapse.appservice import ApplicationService +from ._base import SQLBaseStore + + +logger = logging.getLogger(__name__) + + +class ApplicationServiceCache(object): + """Caches ApplicationServices and provides utility functions on top. + + This class is designed to be invoked on incoming events in order to avoid + hammering the database every time to extract a list of application service + regexes. + """ + + def __init__(self): + self.services = [] + + +class ApplicationServiceStore(SQLBaseStore): + + def __init__(self, hs): + super(ApplicationServiceStore, self).__init__(hs) + self.cache = ApplicationServiceCache() + self.cache_defer = self._populate_cache() + + @defer.inlineCallbacks + def unregister_app_service(self, token): + """Unregisters this service. + + This removes all AS specific regex and the base URL. The token is the + only thing preserved for future registration attempts. + """ + yield self.cache_defer # make sure the cache is ready + yield self.runInteraction( + "unregister_app_service", + self._unregister_app_service_txn, + token, + ) + # update cache TODO: Should this be in the txn? + for service in self.cache.services: + if service.token == token: + service.url = None + service.namespaces = None + service.hs_token = None + + def _unregister_app_service_txn(self, txn, token): + # kill the url to prevent pushes + txn.execute( + "UPDATE application_services SET url=NULL WHERE token=?", + (token,) + ) + + # cleanup regex + as_id = self._get_as_id_txn(txn, token) + if not as_id: + logger.warning( + "unregister_app_service_txn: Failed to find as_id for token=", + token + ) + return False + + txn.execute( + "DELETE FROM application_services_regex WHERE as_id=?", + (as_id,) + ) + return True + + @defer.inlineCallbacks + def update_app_service(self, service): + """Update an application service, clobbering what was previously there. + + Args: + service(ApplicationService): The updated service. + """ + yield self.cache_defer # make sure the cache is ready + + # NB: There is no "insert" since we provide no public-facing API to + # allocate new ASes. It relies on the server admin inserting the AS + # token into the database manually. + + if not service.token or not service.url: + raise StoreError(400, "Token and url must be specified.") + + if not service.hs_token: + raise StoreError(500, "No HS token") + + yield self.runInteraction( + "update_app_service", + self._update_app_service_txn, + service + ) + + # update cache TODO: Should this be in the txn? + for (index, cache_service) in enumerate(self.cache.services): + if service.token == cache_service.token: + self.cache.services[index] = service + logger.info("Updated: %s", service) + return + # new entry + self.cache.services.append(service) + logger.info("Updated(new): %s", service) + + def _update_app_service_txn(self, txn, service): + as_id = self._get_as_id_txn(txn, service.token) + if not as_id: + logger.warning( + "update_app_service_txn: Failed to find as_id for token=", + service.token + ) + return False + + txn.execute( + "UPDATE application_services SET url=?, hs_token=?, sender=? " + "WHERE id=?", + (service.url, service.hs_token, service.sender, as_id,) + ) + # cleanup regex + txn.execute( + "DELETE FROM application_services_regex WHERE as_id=?", + (as_id,) + ) + for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST): + if ns_str in service.namespaces: + for regex in service.namespaces[ns_str]: + txn.execute( + "INSERT INTO application_services_regex(" + "as_id, namespace, regex) values(?,?,?)", + (as_id, ns_int, regex) + ) + return True + + def _get_as_id_txn(self, txn, token): + cursor = txn.execute( + "SELECT id FROM application_services WHERE token=?", + (token,) + ) + res = cursor.fetchone() + if res: + return res[0] + + @defer.inlineCallbacks + def get_app_services(self): + yield self.cache_defer # make sure the cache is ready + defer.returnValue(self.cache.services) + + @defer.inlineCallbacks + def get_app_service_by_token(self, token, from_cache=True): + """Get the application service with the given token. + + Args: + token (str): The application service token. + from_cache (bool): True to get this service from the cache, False to + check the database. + Raises: + StoreError if there was a problem retrieving this service. + """ + yield self.cache_defer # make sure the cache is ready + + if from_cache: + for service in self.cache.services: + if service.token == token: + defer.returnValue(service) + return + defer.returnValue(None) + + # TODO: The from_cache=False impl + # TODO: This should be JOINed with the application_services_regex table. + + @defer.inlineCallbacks + def _populate_cache(self): + """Populates the ApplicationServiceCache from the database.""" + sql = ("SELECT * FROM application_services LEFT JOIN " + "application_services_regex ON application_services.id = " + "application_services_regex.as_id") + # SQL results in the form: + # [ + # { + # 'regex': "something", + # 'url': "something", + # 'namespace': enum, + # 'as_id': 0, + # 'token': "something", + # 'hs_token': "otherthing", + # 'id': 0 + # } + # ] + services = {} + results = yield self._execute_and_decode(sql) + for res in results: + as_token = res["token"] + if as_token not in services: + # add the service + services[as_token] = { + "url": res["url"], + "token": as_token, + "hs_token": res["hs_token"], + "sender": res["sender"], + "namespaces": { + ApplicationService.NS_USERS: [], + ApplicationService.NS_ALIASES: [], + ApplicationService.NS_ROOMS: [] + } + } + # add the namespace regex if one exists + ns_int = res["namespace"] + if ns_int is None: + continue + try: + services[as_token]["namespaces"][ + ApplicationService.NS_LIST[ns_int]].append( + res["regex"] + ) + except IndexError: + logger.error("Bad namespace enum '%s'. %s", ns_int, res) + + # TODO get last successful txn id f.e. service + for service in services.values(): + logger.info("Found application service: %s", service) + self.cache.services.append(ApplicationService( + token=service["token"], + url=service["url"], + namespaces=service["namespaces"], + hs_token=service["hs_token"], + sender=service["sender"] + )) diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 779f9ce544..9bf608bc90 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -288,7 +288,7 @@ class RoomMemberStore(SQLBaseStore): deferreds = [self.get_rooms_for_user(u) for u in user_id_list] - results = yield defer.DeferredList(deferreds) + results = yield defer.DeferredList(deferreds, consumeErrors=True) # A list of sets of strings giving room IDs for each user room_id_lists = [set([r.room_id for r in result[1]]) for result in results] diff --git a/synapse/storage/schema/application_services.sql b/synapse/storage/schema/application_services.sql new file mode 100644 index 0000000000..e491ad5aec --- /dev/null +++ b/synapse/storage/schema/application_services.sql @@ -0,0 +1,34 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS application_services( + id INTEGER PRIMARY KEY AUTOINCREMENT, + url TEXT, + token TEXT, + hs_token TEXT, + sender TEXT, + UNIQUE(token) ON CONFLICT ROLLBACK +); + +CREATE TABLE IF NOT EXISTS application_services_regex( + id INTEGER PRIMARY KEY AUTOINCREMENT, + as_id INTEGER NOT NULL, + namespace INTEGER, /* enum[room_id|room_alias|user_id] */ + regex TEXT, + FOREIGN KEY(as_id) REFERENCES application_services(id) +); + + + diff --git a/synapse/storage/schema/delta/v13.sql b/synapse/storage/schema/delta/v13.sql new file mode 100644 index 0000000000..e491ad5aec --- /dev/null +++ b/synapse/storage/schema/delta/v13.sql @@ -0,0 +1,34 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS application_services( + id INTEGER PRIMARY KEY AUTOINCREMENT, + url TEXT, + token TEXT, + hs_token TEXT, + sender TEXT, + UNIQUE(token) ON CONFLICT ROLLBACK +); + +CREATE TABLE IF NOT EXISTS application_services_regex( + id INTEGER PRIMARY KEY AUTOINCREMENT, + as_id INTEGER NOT NULL, + namespace INTEGER, /* enum[room_id|room_alias|user_id] */ + regex TEXT, + FOREIGN KEY(as_id) REFERENCES application_services(id) +); + + + diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index e77eba90ad..79109d0b19 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -99,8 +99,6 @@ class Clock(object): except: pass - return res - given_deferred.addCallbacks(callback=sucess, errback=err) timer = self.call_later(time_out, timed_out_fn) diff --git a/synapse/util/expiringcache.py b/synapse/util/expiringcache.py new file mode 100644 index 0000000000..1c7859297a --- /dev/null +++ b/synapse/util/expiringcache.py @@ -0,0 +1,115 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + + +logger = logging.getLogger(__name__) + + +class ExpiringCache(object): + def __init__(self, cache_name, clock, max_len=0, expiry_ms=0, + reset_expiry_on_get=False): + """ + Args: + cache_name (str): Name of this cache, used for logging. + clock (Clock) + max_len (int): Max size of dict. If the dict grows larger than this + then the oldest items get automatically evicted. Default is 0, + which indicates there is no max limit. + expiry_ms (int): How long before an item is evicted from the cache + in milliseconds. Default is 0, indicating items never get + evicted based on time. + reset_expiry_on_get (bool): If true, will reset the expiry time for + an item on access. Defaults to False. + + """ + self._cache_name = cache_name + + self._clock = clock + + self._max_len = max_len + self._expiry_ms = expiry_ms + + self._reset_expiry_on_get = reset_expiry_on_get + + self._cache = {} + + def start(self): + if not self._expiry_ms: + # Don't bother starting the loop if things never expire + return + + def f(): + self._prune_cache() + + self._clock.looping_call(f, self._expiry_ms/2) + + def __setitem__(self, key, value): + now = self._clock.time_msec() + self._cache[key] = _CacheEntry(now, value) + + # Evict if there are now too many items + if self._max_len and len(self._cache.keys()) > self._max_len: + sorted_entries = sorted( + self._cache.items(), + key=lambda k, v: v.time, + ) + + for k, _ in sorted_entries[self._max_len:]: + self._cache.pop(k) + + def __getitem__(self, key): + entry = self._cache[key] + + if self._reset_expiry_on_get: + entry.time = self._clock.time_msec() + + return entry.value + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def _prune_cache(self): + if not self._expiry_ms: + # zero expiry time means don't expire. This should never get called + # since we have this check in start too. + return + begin_length = len(self._cache) + + now = self._clock.time_msec() + + keys_to_delete = set() + + for key, cache_entry in self._cache.items(): + if now - cache_entry.time > self._expiry_ms: + keys_to_delete.add(key) + + for k in keys_to_delete: + self._cache.pop(k) + + logger.debug( + "[%s] _prune_cache before: %d, after len: %d", + self._cache_name, begin_length, len(self._cache.keys()) + ) + + +class _CacheEntry(object): + def __init__(self, time, value): + self.time = time + self.value = value diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py new file mode 100644 index 0000000000..4e82232796 --- /dev/null +++ b/synapse/util/retryutils.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import CodeMessageException + +import logging + + +logger = logging.getLogger(__name__) + + +class NotRetryingDestination(Exception): + def __init__(self, retry_last_ts, retry_interval, destination): + msg = "Not retrying server %s." % (destination,) + super(NotRetryingDestination, self).__init__(msg) + + self.retry_last_ts = retry_last_ts + self.retry_interval = retry_interval + self.destination = destination + + +@defer.inlineCallbacks +def get_retry_limiter(destination, clock, store, **kwargs): + """For a given destination check if we have previously failed to + send a request there and are waiting before retrying the destination. + If we are not ready to retry the destination, this will raise a + NotRetryingDestination exception. Otherwise, will return a Context Manager + that will mark the destination as down if an exception is thrown (excluding + CodeMessageException with code < 500) + + Example usage: + + try: + limiter = yield get_retry_limiter(destination, clock, store) + with limiter: + response = yield do_request() + except NotRetryingDestination: + # We aren't ready to retry that destination. + raise + """ + retry_last_ts, retry_interval = (0, 0) + + retry_timings = yield store.get_destination_retry_timings( + destination + ) + + if retry_timings: + retry_last_ts, retry_interval = ( + retry_timings.retry_last_ts, retry_timings.retry_interval + ) + + now = int(clock.time_msec()) + + if retry_last_ts + retry_interval > now: + raise NotRetryingDestination( + retry_last_ts=retry_last_ts, + retry_interval=retry_interval, + destination=destination, + ) + + defer.returnValue( + RetryDestinationLimiter( + destination, + clock, + store, + retry_interval, + **kwargs + ) + ) + + +class RetryDestinationLimiter(object): + def __init__(self, destination, clock, store, retry_interval, + min_retry_interval=5000, max_retry_interval=60 * 60 * 1000, + multiplier_retry_interval=2,): + """Marks the destination as "down" if an exception is thrown in the + context, except for CodeMessageException with code < 500. + + If no exception is raised, marks the destination as "up". + + Args: + destination (str) + clock (Clock) + store (DataStore) + retry_interval (int): The next retry interval taken from the + database in milliseconds, or zero if the last request was + successful. + min_retry_interval (int): The minimum retry interval to use after + a failed request, in milliseconds. + max_retry_interval (int): The maximum retry interval to use after + a failed request, in milliseconds. + multiplier_retry_interval (int): The multiplier to use to increase + the retry interval after a failed request. + """ + self.clock = clock + self.store = store + self.destination = destination + + self.retry_interval = retry_interval + self.min_retry_interval = min_retry_interval + self.max_retry_interval = max_retry_interval + self.multiplier_retry_interval = multiplier_retry_interval + + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_val, exc_tb): + def err(failure): + logger.exception( + "Failed to store set_destination_retry_timings", + failure.value + ) + + valid_err_code = False + if exc_type is CodeMessageException: + valid_err_code = 0 <= exc_val.code < 500 + + if exc_type is None or valid_err_code: + # We connected successfully. + if not self.retry_interval: + return + + retry_last_ts = 0 + self.retry_interval = 0 + else: + # We couldn't connect. + if self.retry_interval: + self.retry_interval *= self.multiplier_retry_interval + + if self.retry_interval >= self.max_retry_interval: + self.retry_interval = self.max_retry_interval + else: + self.retry_interval = self.min_retry_interval + + retry_last_ts = int(self.clock.time_msec()) + + self.store.set_destination_retry_timings( + self.destination, retry_last_ts, self.retry_interval + ).addErrback(err) |