summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/__init__.py7
-rw-r--r--synapse/federation/federation_client.py69
-rw-r--r--synapse/federation/replication.py5
-rw-r--r--synapse/federation/send_queue.py298
-rw-r--r--synapse/federation/transaction_queue.py102
-rw-r--r--synapse/federation/transport/client.py9
-rw-r--r--synapse/federation/transport/server.py19
7 files changed, 425 insertions, 84 deletions
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 979fdf2431..2e32d245ba 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -17,10 +17,9 @@
 """
 
 from .replication import ReplicationLayer
-from .transport.client import TransportLayerClient
 
 
-def initialize_http_replication(homeserver):
-    transport = TransportLayerClient(homeserver)
+def initialize_http_replication(hs):
+    transport = hs.get_federation_transport_client()
 
-    return ReplicationLayer(homeserver, transport)
+    return ReplicationLayer(hs, transport)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 94e76b1978..6e23c207ee 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
 
 from .federation_base import FederationBase
 from synapse.api.constants import Membership
-from .units import Edu
 
 from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
@@ -45,10 +44,6 @@ logger = logging.getLogger(__name__)
 # synapse.federation.federation_client is a silly name
 metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
 
-sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations")
-
-sent_edus_counter = metrics.register_counter("sent_edus")
-
 sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
 
 
@@ -93,63 +88,6 @@ class FederationClient(FederationBase):
         self._get_pdu_cache.start()
 
     @log_function
-    def send_pdu(self, pdu, destinations):
-        """Informs the replication layer about a new PDU generated within the
-        home server that should be transmitted to others.
-
-        TODO: Figure out when we should actually resolve the deferred.
-
-        Args:
-            pdu (Pdu): The new Pdu.
-
-        Returns:
-            Deferred: Completes when we have successfully processed the PDU
-            and replicated it to any interested remote home servers.
-        """
-        order = self._order
-        self._order += 1
-
-        sent_pdus_destination_dist.inc_by(len(destinations))
-
-        logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
-
-        # TODO, add errback, etc.
-        self._transaction_queue.enqueue_pdu(pdu, destinations, order)
-
-        logger.debug(
-            "[%s] transaction_layer.enqueue_pdu... done",
-            pdu.event_id
-        )
-
-    def send_presence(self, destination, states):
-        if destination != self.server_name:
-            self._transaction_queue.enqueue_presence(destination, states)
-
-    @log_function
-    def send_edu(self, destination, edu_type, content, key=None):
-        edu = Edu(
-            origin=self.server_name,
-            destination=destination,
-            edu_type=edu_type,
-            content=content,
-        )
-
-        sent_edus_counter.inc()
-
-        self._transaction_queue.enqueue_edu(edu, key=key)
-
-    @log_function
-    def send_device_messages(self, destination):
-        """Sends the device messages in the local database to the remote
-        destination"""
-        self._transaction_queue.enqueue_device_messages(destination)
-
-    @log_function
-    def send_failure(self, failure, destination):
-        self._transaction_queue.enqueue_failure(failure, destination)
-        return defer.succeed(None)
-
-    @log_function
     def make_query(self, destination, query_type, args,
                    retry_on_dns_fail=False):
         """Sends a federation Query to a remote homeserver of the given type
@@ -717,12 +655,15 @@ class FederationClient(FederationBase):
         raise RuntimeError("Failed to send to any server.")
 
     def get_public_rooms(self, destination, limit=None, since_token=None,
-                         search_filter=None):
+                         search_filter=None, include_all_networks=False,
+                         third_party_instance_id=None):
         if destination == self.server_name:
             return
 
         return self.transport_layer.get_public_rooms(
-            destination, limit, since_token, search_filter
+            destination, limit, since_token, search_filter,
+            include_all_networks=include_all_networks,
+            third_party_instance_id=third_party_instance_id,
         )
 
     @defer.inlineCallbacks
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index ea66a5dcbc..62d865ec4b 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -20,8 +20,6 @@ a given transport.
 from .federation_client import FederationClient
 from .federation_server import FederationServer
 
-from .transaction_queue import TransactionQueue
-
 from .persistence import TransactionActions
 
 import logging
@@ -66,9 +64,6 @@ class ReplicationLayer(FederationClient, FederationServer):
         self._clock = hs.get_clock()
 
         self.transaction_actions = TransactionActions(self.store)
-        self._transaction_queue = TransactionQueue(hs, transport_layer)
-
-        self._order = 0
 
         self.hs = hs
 
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
new file mode 100644
index 0000000000..5c9f7a86f0
--- /dev/null
+++ b/synapse/federation/send_queue.py
@@ -0,0 +1,298 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A federation sender that forwards things to be sent across replication to
+a worker process.
+
+It assumes there is a single worker process feeding off of it.
+
+Each row in the replication stream consists of a type and some json, where the
+types indicate whether they are presence, or edus, etc.
+
+Ephemeral or non-event data are queued up in-memory. When the worker requests
+updates since a particular point, all in-memory data since before that point is
+dropped. We also expire things in the queue after 5 minutes, to ensure that a
+dead worker doesn't cause the queues to grow limitlessly.
+
+Events are replicated via a separate events stream.
+"""
+
+from .units import Edu
+
+from synapse.util.metrics import Measure
+import synapse.metrics
+
+from blist import sorteddict
+import ujson
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+
+PRESENCE_TYPE = "p"
+KEYED_EDU_TYPE = "k"
+EDU_TYPE = "e"
+FAILURE_TYPE = "f"
+DEVICE_MESSAGE_TYPE = "d"
+
+
+class FederationRemoteSendQueue(object):
+    """A drop in replacement for TransactionQueue"""
+
+    def __init__(self, hs):
+        self.server_name = hs.hostname
+        self.clock = hs.get_clock()
+
+        self.presence_map = {}
+        self.presence_changed = sorteddict()
+
+        self.keyed_edu = {}
+        self.keyed_edu_changed = sorteddict()
+
+        self.edus = sorteddict()
+
+        self.failures = sorteddict()
+
+        self.device_messages = sorteddict()
+
+        self.pos = 1
+        self.pos_time = sorteddict()
+
+        # EVERYTHING IS SAD. In particular, python only makes new scopes when
+        # we make a new function, so we need to make a new function so the inner
+        # lambda binds to the queue rather than to the name of the queue which
+        # changes. ARGH.
+        def register(name, queue):
+            metrics.register_callback(
+                queue_name + "_size",
+                lambda: len(queue),
+            )
+
+        for queue_name in [
+            "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
+            "edus", "failures", "device_messages", "pos_time",
+        ]:
+            register(queue_name, getattr(self, queue_name))
+
+        self.clock.looping_call(self._clear_queue, 30 * 1000)
+
+    def _next_pos(self):
+        pos = self.pos
+        self.pos += 1
+        self.pos_time[self.clock.time_msec()] = pos
+        return pos
+
+    def _clear_queue(self):
+        """Clear the queues for anything older than N minutes"""
+
+        FIVE_MINUTES_AGO = 5 * 60 * 1000
+        now = self.clock.time_msec()
+
+        keys = self.pos_time.keys()
+        time = keys.bisect_left(now - FIVE_MINUTES_AGO)
+        if not keys[:time]:
+            return
+
+        position_to_delete = max(keys[:time])
+        for key in keys[:time]:
+            del self.pos_time[key]
+
+        self._clear_queue_before_pos(position_to_delete)
+
+    def _clear_queue_before_pos(self, position_to_delete):
+        """Clear all the queues from before a given position"""
+        with Measure(self.clock, "send_queue._clear"):
+            # Delete things out of presence maps
+            keys = self.presence_changed.keys()
+            i = keys.bisect_left(position_to_delete)
+            for key in keys[:i]:
+                del self.presence_changed[key]
+
+            user_ids = set(
+                user_id for uids in self.presence_changed.values() for _, user_id in uids
+            )
+
+            to_del = [
+                user_id for user_id in self.presence_map if user_id not in user_ids
+            ]
+            for user_id in to_del:
+                del self.presence_map[user_id]
+
+            # Delete things out of keyed edus
+            keys = self.keyed_edu_changed.keys()
+            i = keys.bisect_left(position_to_delete)
+            for key in keys[:i]:
+                del self.keyed_edu_changed[key]
+
+            live_keys = set()
+            for edu_key in self.keyed_edu_changed.values():
+                live_keys.add(edu_key)
+
+            to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
+            for edu_key in to_del:
+                del self.keyed_edu[edu_key]
+
+            # Delete things out of edu map
+            keys = self.edus.keys()
+            i = keys.bisect_left(position_to_delete)
+            for key in keys[:i]:
+                del self.edus[key]
+
+            # Delete things out of failure map
+            keys = self.failures.keys()
+            i = keys.bisect_left(position_to_delete)
+            for key in keys[:i]:
+                del self.failures[key]
+
+            # Delete things out of device map
+            keys = self.device_messages.keys()
+            i = keys.bisect_left(position_to_delete)
+            for key in keys[:i]:
+                del self.device_messages[key]
+
+    def notify_new_events(self, current_id):
+        """As per TransactionQueue"""
+        # We don't need to replicate this as it gets sent down a different
+        # stream.
+        pass
+
+    def send_edu(self, destination, edu_type, content, key=None):
+        """As per TransactionQueue"""
+        pos = self._next_pos()
+
+        edu = Edu(
+            origin=self.server_name,
+            destination=destination,
+            edu_type=edu_type,
+            content=content,
+        )
+
+        if key:
+            assert isinstance(key, tuple)
+            self.keyed_edu[(destination, key)] = edu
+            self.keyed_edu_changed[pos] = (destination, key)
+        else:
+            self.edus[pos] = edu
+
+    def send_presence(self, destination, states):
+        """As per TransactionQueue"""
+        pos = self._next_pos()
+
+        self.presence_map.update({
+            state.user_id: state
+            for state in states
+        })
+
+        self.presence_changed[pos] = [
+            (destination, state.user_id) for state in states
+        ]
+
+    def send_failure(self, failure, destination):
+        """As per TransactionQueue"""
+        pos = self._next_pos()
+
+        self.failures[pos] = (destination, str(failure))
+
+    def send_device_messages(self, destination):
+        """As per TransactionQueue"""
+        pos = self._next_pos()
+        self.device_messages[pos] = destination
+
+    def get_current_token(self):
+        return self.pos - 1
+
+    def get_replication_rows(self, token, limit, federation_ack=None):
+        """
+        Args:
+            token (int)
+            limit (int)
+            federation_ack (int): Optional. The position where the worker is
+                explicitly acknowledged it has handled. Allows us to drop
+                data from before that point
+        """
+        # TODO: Handle limit.
+
+        # To handle restarts where we wrap around
+        if token > self.pos:
+            token = -1
+
+        rows = []
+
+        # There should be only one reader, so lets delete everything its
+        # acknowledged its seen.
+        if federation_ack:
+            self._clear_queue_before_pos(federation_ack)
+
+        # Fetch changed presence
+        keys = self.presence_changed.keys()
+        i = keys.bisect_right(token)
+        dest_user_ids = set(
+            (pos, dest_user_id)
+            for pos in keys[i:]
+            for dest_user_id in self.presence_changed[pos]
+        )
+
+        for (key, (dest, user_id)) in dest_user_ids:
+            rows.append((key, PRESENCE_TYPE, ujson.dumps({
+                "destination": dest,
+                "state": self.presence_map[user_id].as_dict(),
+            })))
+
+        # Fetch changes keyed edus
+        keys = self.keyed_edu_changed.keys()
+        i = keys.bisect_right(token)
+        keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
+
+        for (pos, (destination, edu_key)) in keyed_edus:
+            rows.append(
+                (pos, KEYED_EDU_TYPE, ujson.dumps({
+                    "key": edu_key,
+                    "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
+                }))
+            )
+
+        # Fetch changed edus
+        keys = self.edus.keys()
+        i = keys.bisect_right(token)
+        edus = set((k, self.edus[k]) for k in keys[i:])
+
+        for (pos, edu) in edus:
+            rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
+
+        # Fetch changed failures
+        keys = self.failures.keys()
+        i = keys.bisect_right(token)
+        failures = set((k, self.failures[k]) for k in keys[i:])
+
+        for (pos, (destination, failure)) in failures:
+            rows.append((pos, FAILURE_TYPE, ujson.dumps({
+                "destination": destination,
+                "failure": failure,
+            })))
+
+        # Fetch changed device messages
+        keys = self.device_messages.keys()
+        i = keys.bisect_right(token)
+        device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+
+        for (pos, destination) in device_messages:
+            rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
+                "destination": destination,
+            })))
+
+        # Sort rows based on pos
+        rows.sort()
+
+        return rows
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index f8ca93e4c3..51b656d74a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 from .persistence import TransactionActions
 from .units import Transaction, Edu
 
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
 from synapse.util.logcontext import preserve_context_over_fn
@@ -26,6 +27,7 @@ from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
 from synapse.util.metrics import measure_func
+from synapse.types import get_domain_from_id
 from synapse.handlers.presence import format_user_presence_state
 import synapse.metrics
 
@@ -36,6 +38,12 @@ logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
+client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
+sent_pdus_destination_dist = client_metrics.register_distribution(
+    "sent_pdu_destinations"
+)
+sent_edus_counter = client_metrics.register_counter("sent_edus")
+
 
 class TransactionQueue(object):
     """This class makes sure we only have one transaction in flight at
@@ -44,13 +52,14 @@ class TransactionQueue(object):
     It batches pending PDUs into single transactions.
     """
 
-    def __init__(self, hs, transport_layer):
+    def __init__(self, hs):
         self.server_name = hs.hostname
 
         self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
         self.transaction_actions = TransactionActions(self.store)
 
-        self.transport_layer = transport_layer
+        self.transport_layer = hs.get_federation_transport_client()
 
         self.clock = hs.get_clock()
 
@@ -95,6 +104,11 @@ class TransactionQueue(object):
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
 
+        self._order = 1
+
+        self._is_processing = False
+        self._last_poked_id = -1
+
     def can_send_to(self, destination):
         """Can we send messages to the given server?
 
@@ -115,11 +129,61 @@ class TransactionQueue(object):
         else:
             return not destination.startswith("localhost")
 
-    def enqueue_pdu(self, pdu, destinations, order):
+    @defer.inlineCallbacks
+    def notify_new_events(self, current_id):
+        """This gets called when we have some new events we might want to
+        send out to other servers.
+        """
+        self._last_poked_id = max(current_id, self._last_poked_id)
+
+        if self._is_processing:
+            return
+
+        try:
+            self._is_processing = True
+            while True:
+                last_token = yield self.store.get_federation_out_pos("events")
+                next_token, events = yield self.store.get_all_new_events_stream(
+                    last_token, self._last_poked_id, limit=20,
+                )
+
+                logger.debug("Handling %s -> %s", last_token, next_token)
+
+                if not events and next_token >= self._last_poked_id:
+                    break
+
+                for event in events:
+                    users_in_room = yield self.state.get_current_user_in_room(
+                        event.room_id, latest_event_ids=[event.event_id],
+                    )
+
+                    destinations = set(
+                        get_domain_from_id(user_id) for user_id in users_in_room
+                    )
+
+                    if event.type == EventTypes.Member:
+                        if event.content["membership"] == Membership.JOIN:
+                            destinations.add(get_domain_from_id(event.state_key))
+
+                    logger.debug("Sending %s to %r", event, destinations)
+
+                    self._send_pdu(event, destinations)
+
+                yield self.store.update_federation_out_pos(
+                    "events", next_token
+                )
+
+        finally:
+            self._is_processing = False
+
+    def _send_pdu(self, pdu, destinations):
         # We loop through all destinations to see whether we already have
         # a transaction in progress. If we do, stick it in the pending_pdus
         # table and we'll get back to it later.
 
+        order = self._order
+        self._order += 1
+
         destinations = set(destinations)
         destinations = set(
             dest for dest in destinations if self.can_send_to(dest)
@@ -130,6 +194,8 @@ class TransactionQueue(object):
         if not destinations:
             return
 
+        sent_pdus_destination_dist.inc_by(len(destinations))
+
         for destination in destinations:
             self.pending_pdus_by_dest.setdefault(destination, []).append(
                 (pdu, order)
@@ -139,7 +205,10 @@ class TransactionQueue(object):
                 self._attempt_new_transaction, destination
             )
 
-    def enqueue_presence(self, destination, states):
+    def send_presence(self, destination, states):
+        if not self.can_send_to(destination):
+            return
+
         self.pending_presence_by_dest.setdefault(destination, {}).update({
             state.user_id: state for state in states
         })
@@ -148,12 +217,19 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
-    def enqueue_edu(self, edu, key=None):
-        destination = edu.destination
+    def send_edu(self, destination, edu_type, content, key=None):
+        edu = Edu(
+            origin=self.server_name,
+            destination=destination,
+            edu_type=edu_type,
+            content=content,
+        )
 
         if not self.can_send_to(destination):
             return
 
+        sent_edus_counter.inc()
+
         if key:
             self.pending_edus_keyed_by_dest.setdefault(
                 destination, {}
@@ -165,7 +241,7 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
-    def enqueue_failure(self, failure, destination):
+    def send_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
             return
 
@@ -180,7 +256,7 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
-    def enqueue_device_messages(self, destination):
+    def send_device_messages(self, destination):
         if destination == self.server_name or destination == "localhost":
             return
 
@@ -191,6 +267,9 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
+    def get_current_token(self):
+        return 0
+
     @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
         # list of (pending_pdu, deferred, order)
@@ -383,6 +462,13 @@ class TransactionQueue(object):
                     code = e.code
                     response = e.response
 
+                    if e.code == 429 or 500 <= e.code:
+                        logger.info(
+                            "TX [%s] {%s} got %d response",
+                            destination, txn_id, code
+                        )
+                        raise e
+
                 logger.info(
                     "TX [%s] {%s} got %d response",
                     destination, txn_id, code
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index db45c7826c..491cdc29e1 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -249,10 +249,15 @@ class TransportLayerClient(object):
     @defer.inlineCallbacks
     @log_function
     def get_public_rooms(self, remote_server, limit, since_token,
-                         search_filter=None):
+                         search_filter=None, include_all_networks=False,
+                         third_party_instance_id=None):
         path = PREFIX + "/publicRooms"
 
-        args = {}
+        args = {
+            "include_all_networks": "true" if include_all_networks else "false",
+        }
+        if third_party_instance_id:
+            args["third_party_instance_id"] = third_party_instance_id,
         if limit:
             args["limit"] = [str(limit)]
         if since_token:
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index fec337be64..159dbd1747 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -20,9 +20,11 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import JsonResource
 from synapse.http.servlet import (
     parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
+    parse_boolean_from_args,
 )
 from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
+from synapse.types import ThirdPartyInstanceID
 
 import functools
 import logging
@@ -558,8 +560,23 @@ class PublicRoomList(BaseFederationServlet):
     def on_GET(self, origin, content, query):
         limit = parse_integer_from_args(query, "limit", 0)
         since_token = parse_string_from_args(query, "since", None)
+        include_all_networks = parse_boolean_from_args(
+            query, "include_all_networks", False
+        )
+        third_party_instance_id = parse_string_from_args(
+            query, "third_party_instance_id", None
+        )
+
+        if include_all_networks:
+            network_tuple = None
+        elif third_party_instance_id:
+            network_tuple = ThirdPartyInstanceID.from_string(third_party_instance_id)
+        else:
+            network_tuple = ThirdPartyInstanceID(None, None)
+
         data = yield self.room_list_handler.get_local_public_room_list(
-            limit, since_token
+            limit, since_token,
+            network_tuple=network_tuple
         )
         defer.returnValue((200, data))