summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/__init__.py8
-rw-r--r--synapse/federation/federation_base.py33
-rw-r--r--synapse/federation/federation_client.py61
-rw-r--r--synapse/federation/federation_server.py171
-rw-r--r--synapse/federation/replication.py73
-rw-r--r--synapse/federation/transaction_queue.py8
-rw-r--r--synapse/federation/transport/client.py3
-rw-r--r--synapse/federation/transport/server.py11
8 files changed, 170 insertions, 198 deletions
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 2e32d245ba..f5f0bdfca3 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -15,11 +15,3 @@
 
 """ This package includes all the federation specific logic.
 """
-
-from .replication import ReplicationLayer
-
-
-def initialize_http_replication(hs):
-    transport = hs.get_federation_transport_client()
-
-    return ReplicationLayer(hs, transport)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index a0f5d40eb3..79eaa31031 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -16,7 +16,9 @@ import logging
 
 from synapse.api.errors import SynapseError
 from synapse.crypto.event_signing import check_event_content_hash
+from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
+from synapse.http.servlet import assert_params_in_request
 from synapse.util import unwrapFirstError, logcontext
 from twisted.internet import defer
 
@@ -25,7 +27,13 @@ logger = logging.getLogger(__name__)
 
 class FederationBase(object):
     def __init__(self, hs):
+        self.hs = hs
+
+        self.server_name = hs.hostname
+        self.keyring = hs.get_keyring()
         self.spam_checker = hs.get_spam_checker()
+        self.store = hs.get_datastore()
+        self._clock = hs.get_clock()
 
     @defer.inlineCallbacks
     def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -169,3 +177,28 @@ class FederationBase(object):
             )
 
         return deferreds
+
+
+def event_from_pdu_json(pdu_json, outlier=False):
+    """Construct a FrozenEvent from an event json received over federation
+
+    Args:
+        pdu_json (object): pdu as received over federation
+        outlier (bool): True to mark this event as an outlier
+
+    Returns:
+        FrozenEvent
+
+    Raises:
+        SynapseError: if the pdu is missing required fields
+    """
+    # we could probably enforce a bunch of other fields here (room_id, sender,
+    # origin, etc etc)
+    assert_params_in_request(pdu_json, ('event_id', 'type'))
+    event = FrozenEvent(
+        pdu_json
+    )
+
+    event.internal_metadata.outlier = outlier
+
+    return event
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index b8f02f5391..38440da5b5 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -14,29 +14,29 @@
 # limitations under the License.
 
 
+import copy
+import itertools
+import logging
+import random
+
 from twisted.internet import defer
 
-from .federation_base import FederationBase
 from synapse.api.constants import Membership
-
 from synapse.api.errors import (
-    CodeMessageException, HttpResponseException, SynapseError,
+    CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
 )
-from synapse.util import unwrapFirstError, logcontext
+from synapse.events import builder
+from synapse.federation.federation_base import (
+    FederationBase,
+    event_from_pdu_json,
+)
+import synapse.metrics
+from synapse.util import logcontext, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
-from synapse.util.logutils import log_function
 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
-from synapse.events import FrozenEvent, builder
-import synapse.metrics
-
+from synapse.util.logutils import log_function
 from synapse.util.retryutils import NotRetryingDestination
 
-import copy
-import itertools
-import logging
-import random
-
-
 logger = logging.getLogger(__name__)
 
 
@@ -58,6 +58,7 @@ class FederationClient(FederationBase):
             self._clear_tried_cache, 60 * 1000,
         )
         self.state = hs.get_state_handler()
+        self.transport_layer = hs.get_federation_transport_client()
 
     def _clear_tried_cache(self):
         """Clear pdu_destination_tried cache"""
@@ -184,7 +185,7 @@ class FederationClient(FederationBase):
         logger.debug("backfill transaction_data=%s", repr(transaction_data))
 
         pdus = [
-            self.event_from_pdu_json(p, outlier=False)
+            event_from_pdu_json(p, outlier=False)
             for p in transaction_data["pdus"]
         ]
 
@@ -244,7 +245,7 @@ class FederationClient(FederationBase):
                 logger.debug("transaction_data %r", transaction_data)
 
                 pdu_list = [
-                    self.event_from_pdu_json(p, outlier=outlier)
+                    event_from_pdu_json(p, outlier=outlier)
                     for p in transaction_data["pdus"]
                 ]
 
@@ -266,6 +267,9 @@ class FederationClient(FederationBase):
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
+            except FederationDeniedError as e:
+                logger.info(e.message)
+                continue
             except Exception as e:
                 pdu_attempts[destination] = now
 
@@ -336,11 +340,11 @@ class FederationClient(FederationBase):
         )
 
         pdus = [
-            self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
+            event_from_pdu_json(p, outlier=True) for p in result["pdus"]
         ]
 
         auth_chain = [
-            self.event_from_pdu_json(p, outlier=True)
+            event_from_pdu_json(p, outlier=True)
             for p in result.get("auth_chain", [])
         ]
 
@@ -441,7 +445,7 @@ class FederationClient(FederationBase):
         )
 
         auth_chain = [
-            self.event_from_pdu_json(p, outlier=True)
+            event_from_pdu_json(p, outlier=True)
             for p in res["auth_chain"]
         ]
 
@@ -570,12 +574,12 @@ class FederationClient(FederationBase):
                 logger.debug("Got content: %s", content)
 
                 state = [
-                    self.event_from_pdu_json(p, outlier=True)
+                    event_from_pdu_json(p, outlier=True)
                     for p in content.get("state", [])
                 ]
 
                 auth_chain = [
-                    self.event_from_pdu_json(p, outlier=True)
+                    event_from_pdu_json(p, outlier=True)
                     for p in content.get("auth_chain", [])
                 ]
 
@@ -650,7 +654,7 @@ class FederationClient(FederationBase):
 
         logger.debug("Got response to send_invite: %s", pdu_dict)
 
-        pdu = self.event_from_pdu_json(pdu_dict)
+        pdu = event_from_pdu_json(pdu_dict)
 
         # Check signatures are correct.
         pdu = yield self._check_sigs_and_hash(pdu)
@@ -740,7 +744,7 @@ class FederationClient(FederationBase):
         )
 
         auth_chain = [
-            self.event_from_pdu_json(e)
+            event_from_pdu_json(e)
             for e in content["auth_chain"]
         ]
 
@@ -788,7 +792,7 @@ class FederationClient(FederationBase):
             )
 
             events = [
-                self.event_from_pdu_json(e)
+                event_from_pdu_json(e)
                 for e in content.get("events", [])
             ]
 
@@ -805,15 +809,6 @@ class FederationClient(FederationBase):
 
         defer.returnValue(signed_events)
 
-    def event_from_pdu_json(self, pdu_json, outlier=False):
-        event = FrozenEvent(
-            pdu_json
-        )
-
-        event.internal_metadata.outlier = outlier
-
-        return event
-
     @defer.inlineCallbacks
     def forward_third_party_invite(self, destinations, room_id, event_dict):
         for destination in destinations:
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index a2327f24b6..bea7fd0b71 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -12,25 +12,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+
+import simplejson as json
 from twisted.internet import defer
 
-from .federation_base import FederationBase
-from .units import Transaction, Edu
+from synapse.api.errors import AuthError, FederationError, SynapseError, NotFoundError
+from synapse.crypto.event_signing import compute_event_signature
+from synapse.federation.federation_base import (
+    FederationBase,
+    event_from_pdu_json,
+)
 
+from synapse.federation.persistence import TransactionActions
+from synapse.federation.units import Edu, Transaction
+import synapse.metrics
+from synapse.types import get_domain_from_id
 from synapse.util import async
+from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
 from synapse.util.logutils import log_function
-from synapse.util.caches.response_cache import ResponseCache
-from synapse.events import FrozenEvent
-from synapse.types import get_domain_from_id
-import synapse.metrics
-
-from synapse.api.errors import AuthError, FederationError, SynapseError
-
-from synapse.crypto.event_signing import compute_event_signature
-
-import simplejson as json
-import logging
 
 # when processing incoming transactions, we try to handle multiple rooms in
 # parallel, up to this limit.
@@ -53,50 +54,19 @@ class FederationServer(FederationBase):
         super(FederationServer, self).__init__(hs)
 
         self.auth = hs.get_auth()
+        self.handler = hs.get_handlers().federation_handler
 
         self._server_linearizer = async.Linearizer("fed_server")
         self._transaction_linearizer = async.Linearizer("fed_txn_handler")
 
+        self.transaction_actions = TransactionActions(self.store)
+
+        self.registry = hs.get_federation_registry()
+
         # We cache responses to state queries, as they take a while and often
         # come in waves.
         self._state_resp_cache = ResponseCache(hs, timeout_ms=30000)
 
-    def set_handler(self, handler):
-        """Sets the handler that the replication layer will use to communicate
-        receipt of new PDUs from other home servers. The required methods are
-        documented on :py:class:`.ReplicationHandler`.
-        """
-        self.handler = handler
-
-    def register_edu_handler(self, edu_type, handler):
-        if edu_type in self.edu_handlers:
-            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
-
-        self.edu_handlers[edu_type] = handler
-
-    def register_query_handler(self, query_type, handler):
-        """Sets the handler callable that will be used to handle an incoming
-        federation Query of the given type.
-
-        Args:
-            query_type (str): Category name of the query, which should match
-                the string used by make_query.
-            handler (callable): Invoked to handle incoming queries of this type
-
-        handler is invoked as:
-            result = handler(args)
-
-        where 'args' is a dict mapping strings to strings of the query
-          arguments. It should return a Deferred that will eventually yield an
-          object to encode as JSON.
-        """
-        if query_type in self.query_handlers:
-            raise KeyError(
-                "Already have a Query handler for %s" % (query_type,)
-            )
-
-        self.query_handlers[query_type] = handler
-
     @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, versions, limit):
@@ -172,7 +142,7 @@ class FederationServer(FederationBase):
                 p["age_ts"] = request_time - int(p["age"])
                 del p["age"]
 
-            event = self.event_from_pdu_json(p)
+            event = event_from_pdu_json(p)
             room_id = event.room_id
             pdus_by_room.setdefault(room_id, []).append(event)
 
@@ -230,16 +200,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def received_edu(self, origin, edu_type, content):
         received_edus_counter.inc()
-
-        if edu_type in self.edu_handlers:
-            try:
-                yield self.edu_handlers[edu_type](origin, content)
-            except SynapseError as e:
-                logger.info("Failed to handle edu %r: %r", edu_type, e)
-            except Exception as e:
-                logger.exception("Failed to handle edu %r", edu_type)
-        else:
-            logger.warn("Received EDU of type %s with no handler", edu_type)
+        yield self.registry.on_edu(edu_type, origin, content)
 
     @defer.inlineCallbacks
     @log_function
@@ -329,14 +290,8 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_query_request(self, query_type, args):
         received_queries_counter.inc(query_type)
-
-        if query_type in self.query_handlers:
-            response = yield self.query_handlers[query_type](args)
-            defer.returnValue((200, response))
-        else:
-            defer.returnValue(
-                (404, "No handler for Query type '%s'" % (query_type,))
-            )
+        resp = yield self.registry.on_query(query_type, args)
+        defer.returnValue((200, resp))
 
     @defer.inlineCallbacks
     def on_make_join_request(self, room_id, user_id):
@@ -346,7 +301,7 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     def on_invite_request(self, origin, content):
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         ret_pdu = yield self.handler.on_invite_request(origin, pdu)
         time_now = self._clock.time_msec()
         defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
@@ -354,7 +309,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_send_join_request(self, origin, content):
         logger.debug("on_send_join_request: content: %s", content)
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
         res_pdus = yield self.handler.on_send_join_request(origin, pdu)
         time_now = self._clock.time_msec()
@@ -374,7 +329,7 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     def on_send_leave_request(self, origin, content):
         logger.debug("on_send_leave_request: content: %s", content)
-        pdu = self.event_from_pdu_json(content)
+        pdu = event_from_pdu_json(content)
         logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
         yield self.handler.on_send_leave_request(origin, pdu)
         defer.returnValue((200, {}))
@@ -411,7 +366,7 @@ class FederationServer(FederationBase):
         """
         with (yield self._server_linearizer.queue((origin, room_id))):
             auth_chain = [
-                self.event_from_pdu_json(e)
+                event_from_pdu_json(e)
                 for e in content["auth_chain"]
             ]
 
@@ -586,15 +541,6 @@ class FederationServer(FederationBase):
     def __str__(self):
         return "<ReplicationLayer(%s)>" % self.server_name
 
-    def event_from_pdu_json(self, pdu_json, outlier=False):
-        event = FrozenEvent(
-            pdu_json
-        )
-
-        event.internal_metadata.outlier = outlier
-
-        return event
-
     @defer.inlineCallbacks
     def exchange_third_party_invite(
             self,
@@ -617,3 +563,66 @@ class FederationServer(FederationBase):
             origin, room_id, event_dict
         )
         defer.returnValue(ret)
+
+
+class FederationHandlerRegistry(object):
+    """Allows classes to register themselves as handlers for a given EDU or
+    query type for incoming federation traffic.
+    """
+    def __init__(self):
+        self.edu_handlers = {}
+        self.query_handlers = {}
+
+    def register_edu_handler(self, edu_type, handler):
+        """Sets the handler callable that will be used to handle an incoming
+        federation EDU of the given type.
+
+        Args:
+            edu_type (str): The type of the incoming EDU to register handler for
+            handler (Callable[[str, dict]]): A callable invoked on incoming EDU
+                of the given type. The arguments are the origin server name and
+                the EDU contents.
+        """
+        if edu_type in self.edu_handlers:
+            raise KeyError("Already have an EDU handler for %s" % (edu_type,))
+
+        self.edu_handlers[edu_type] = handler
+
+    def register_query_handler(self, query_type, handler):
+        """Sets the handler callable that will be used to handle an incoming
+        federation query of the given type.
+
+        Args:
+            query_type (str): Category name of the query, which should match
+                the string used by make_query.
+            handler (Callable[[dict], Deferred[dict]]): Invoked to handle
+                incoming queries of this type. The return will be yielded
+                on and the result used as the response to the query request.
+        """
+        if query_type in self.query_handlers:
+            raise KeyError(
+                "Already have a Query handler for %s" % (query_type,)
+            )
+
+        self.query_handlers[query_type] = handler
+
+    @defer.inlineCallbacks
+    def on_edu(self, edu_type, origin, content):
+        handler = self.edu_handlers.get(edu_type)
+        if not handler:
+            logger.warn("No handler registered for EDU type %s", edu_type)
+
+        try:
+            yield handler(origin, content)
+        except SynapseError as e:
+            logger.info("Failed to handle edu %r: %r", edu_type, e)
+        except Exception as e:
+            logger.exception("Failed to handle edu %r", edu_type)
+
+    def on_query(self, query_type, args):
+        handler = self.query_handlers.get(query_type)
+        if not handler:
+            logger.warn("No handler registered for query type %s", query_type)
+            raise NotFoundError("No handler for Query type '%s'" % (query_type,))
+
+        return handler(args)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
deleted file mode 100644
index 62d865ec4b..0000000000
--- a/synapse/federation/replication.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""This layer is responsible for replicating with remote home servers using
-a given transport.
-"""
-
-from .federation_client import FederationClient
-from .federation_server import FederationServer
-
-from .persistence import TransactionActions
-
-import logging
-
-
-logger = logging.getLogger(__name__)
-
-
-class ReplicationLayer(FederationClient, FederationServer):
-    """This layer is responsible for replicating with remote home servers over
-    the given transport. I.e., does the sending and receiving of PDUs to
-    remote home servers.
-
-    The layer communicates with the rest of the server via a registered
-    ReplicationHandler.
-
-    In more detail, the layer:
-        * Receives incoming data and processes it into transactions and pdus.
-        * Fetches any PDUs it thinks it might have missed.
-        * Keeps the current state for contexts up to date by applying the
-          suitable conflict resolution.
-        * Sends outgoing pdus wrapped in transactions.
-        * Fills out the references to previous pdus/transactions appropriately
-          for outgoing data.
-    """
-
-    def __init__(self, hs, transport_layer):
-        self.server_name = hs.hostname
-
-        self.keyring = hs.get_keyring()
-
-        self.transport_layer = transport_layer
-
-        self.federation_client = self
-
-        self.store = hs.get_datastore()
-
-        self.handler = None
-        self.edu_handlers = {}
-        self.query_handlers = {}
-
-        self._clock = hs.get_clock()
-
-        self.transaction_actions = TransactionActions(self.store)
-
-        self.hs = hs
-
-        super(ReplicationLayer, self).__init__(hs)
-
-    def __str__(self):
-        return "<ReplicationLayer(%s)>" % self.server_name
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 3e7809b04f..a141ec9953 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -19,7 +19,7 @@ from twisted.internet import defer
 from .persistence import TransactionActions
 from .units import Transaction, Edu
 
-from synapse.api.errors import HttpResponseException
+from synapse.api.errors import HttpResponseException, FederationDeniedError
 from synapse.util import logcontext, PreserveLoggingContext
 from synapse.util.async import run_on_reactor
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
@@ -42,6 +42,8 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
 
 sent_transactions_counter = client_metrics.register_counter("sent_transactions")
 
+events_processed_counter = client_metrics.register_counter("events_processed")
+
 
 class TransactionQueue(object):
     """This class makes sure we only have one transaction in flight at
@@ -205,6 +207,8 @@ class TransactionQueue(object):
 
                     self._send_pdu(event, destinations)
 
+                events_processed_counter.inc_by(len(events))
+
                 yield self.store.update_federation_out_pos(
                     "events", next_token
                 )
@@ -486,6 +490,8 @@ class TransactionQueue(object):
                     (e.retry_last_ts + e.retry_interval) / 1000.0
                 ),
             )
+        except FederationDeniedError as e:
+            logger.info(e)
         except Exception as e:
             logger.warn(
                 "TX [%s] Failed to send transaction: %s",
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1f3ce238f6..5488e82985 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -212,6 +212,9 @@ class TransportLayerClient(object):
 
             Fails with ``NotRetryingDestination`` if we are not yet ready
             to retry this server.
+
+            Fails with ``FederationDeniedError`` if the remote destination
+            is not in our federation whitelist
         """
         valid_memberships = {Membership.JOIN, Membership.LEAVE}
         if membership not in valid_memberships:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2b02b021ec..a66a6b0692 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
-from synapse.api.errors import Codes, SynapseError
+from synapse.api.errors import Codes, SynapseError, FederationDeniedError
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
     parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
@@ -81,6 +81,7 @@ class Authenticator(object):
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
+        self.federation_domain_whitelist = hs.config.federation_domain_whitelist
 
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
@@ -92,6 +93,12 @@ class Authenticator(object):
             "signatures": {},
         }
 
+        if (
+            self.federation_domain_whitelist is not None and
+            self.server_name not in self.federation_domain_whitelist
+        ):
+            raise FederationDeniedError(self.server_name)
+
         if content is not None:
             json_request["content"] = content
 
@@ -1183,7 +1190,7 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
 def register_servlets(hs, resource, authenticator, ratelimiter):
     for servletclass in FEDERATION_SERVLET_CLASSES:
         servletclass(
-            handler=hs.get_replication_layer(),
+            handler=hs.get_federation_server(),
             authenticator=authenticator,
             ratelimiter=ratelimiter,
             server_name=hs.hostname,