summary refs log tree commit diff
path: root/synapse/federation/send_queue.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/send_queue.py')
-rw-r--r--synapse/federation/send_queue.py375
1 files changed, 309 insertions, 66 deletions
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index bbb0195228..93e5acebc1 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,21 +31,19 @@ Events are replicated via a separate events stream.
 
 from .units import Edu
 
+from synapse.storage.presence import UserPresenceState
 from synapse.util.metrics import Measure
 import synapse.metrics
 
 from blist import sorteddict
-import ujson
+from collections import namedtuple
 
+import logging
 
-metrics = synapse.metrics.get_metrics_for(__name__)
+logger = logging.getLogger(__name__)
 
 
-PRESENCE_TYPE = "p"
-KEYED_EDU_TYPE = "k"
-EDU_TYPE = "e"
-FAILURE_TYPE = "f"
-DEVICE_MESSAGE_TYPE = "d"
+metrics = synapse.metrics.get_metrics_for(__name__)
 
 
 class FederationRemoteSendQueue(object):
@@ -55,18 +53,19 @@ class FederationRemoteSendQueue(object):
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self.is_mine_id = hs.is_mine_id
 
-        self.presence_map = {}
-        self.presence_changed = sorteddict()
+        self.presence_map = {}  # Pending presence map user_id -> UserPresenceState
+        self.presence_changed = sorteddict()  # Stream position -> user_id
 
-        self.keyed_edu = {}
-        self.keyed_edu_changed = sorteddict()
+        self.keyed_edu = {}  # (destination, key) -> EDU
+        self.keyed_edu_changed = sorteddict()  # stream position -> (destination, key)
 
-        self.edus = sorteddict()
+        self.edus = sorteddict()  # stream position -> Edu
 
-        self.failures = sorteddict()
+        self.failures = sorteddict()  # stream position -> (destination, Failure)
 
-        self.device_messages = sorteddict()
+        self.device_messages = sorteddict()  # stream position -> destination
 
         self.pos = 1
         self.pos_time = sorteddict()
@@ -122,7 +121,9 @@ class FederationRemoteSendQueue(object):
                 del self.presence_changed[key]
 
             user_ids = set(
-                user_id for uids in self.presence_changed.values() for _, user_id in uids
+                user_id
+                for uids in self.presence_changed.itervalues()
+                for user_id in uids
             )
 
             to_del = [
@@ -189,18 +190,20 @@ class FederationRemoteSendQueue(object):
 
         self.notifier.on_new_replication_data()
 
-    def send_presence(self, destination, states):
-        """As per TransactionQueue"""
+    def send_presence(self, states):
+        """As per TransactionQueue
+
+        Args:
+            states (list(UserPresenceState))
+        """
         pos = self._next_pos()
 
-        self.presence_map.update({
-            state.user_id: state
-            for state in states
-        })
+        # We only want to send presence for our own users, so lets always just
+        # filter here just in case.
+        local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
 
-        self.presence_changed[pos] = [
-            (destination, state.user_id) for state in states
-        ]
+        self.presence_map.update({state.user_id: state for state in local_states})
+        self.presence_changed[pos] = [state.user_id for state in local_states]
 
         self.notifier.on_new_replication_data()
 
@@ -220,10 +223,15 @@ class FederationRemoteSendQueue(object):
     def get_current_token(self):
         return self.pos - 1
 
-    def get_replication_rows(self, token, limit, federation_ack=None):
-        """
+    def federation_ack(self, token):
+        self._clear_queue_before_pos(token)
+
+    def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
+        """Get rows to be sent over federation between the two tokens
+
         Args:
-            token (int)
+            from_token (int)
+            to_token(int)
             limit (int)
             federation_ack (int): Optional. The position where the worker is
                 explicitly acknowledged it has handled. Allows us to drop
@@ -232,9 +240,11 @@ class FederationRemoteSendQueue(object):
         # TODO: Handle limit.
 
         # To handle restarts where we wrap around
-        if token > self.pos:
-            token = -1
+        if from_token > self.pos:
+            from_token = -1
 
+        # list of tuple(int, BaseFederationRow), where the first is the position
+        # of the federation stream.
         rows = []
 
         # There should be only one reader, so lets delete everything its
@@ -244,62 +254,295 @@ class FederationRemoteSendQueue(object):
 
         # Fetch changed presence
         keys = self.presence_changed.keys()
-        i = keys.bisect_right(token)
-        dest_user_ids = set(
-            (pos, dest_user_id)
-            for pos in keys[i:]
-            for dest_user_id in self.presence_changed[pos]
-        )
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        dest_user_ids = [
+            (pos, user_id)
+            for pos in keys[i:j]
+            for user_id in self.presence_changed[pos]
+        ]
 
-        for (key, (dest, user_id)) in dest_user_ids:
-            rows.append((key, PRESENCE_TYPE, ujson.dumps({
-                "destination": dest,
-                "state": self.presence_map[user_id].as_dict(),
-            })))
+        for (key, user_id) in dest_user_ids:
+            rows.append((key, PresenceRow(
+                state=self.presence_map[user_id],
+            )))
 
         # Fetch changes keyed edus
         keys = self.keyed_edu_changed.keys()
-        i = keys.bisect_right(token)
-        keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
-
-        for (pos, (destination, edu_key)) in keyed_edus:
-            rows.append(
-                (pos, KEYED_EDU_TYPE, ujson.dumps({
-                    "key": edu_key,
-                    "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
-                }))
-            )
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        # We purposefully clobber based on the key here, python dict comprehensions
+        # always use the last value, so this will correctly point to the last
+        # stream position.
+        keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]}
+
+        for ((destination, edu_key), pos) in keyed_edus.iteritems():
+            rows.append((pos, KeyedEduRow(
+                key=edu_key,
+                edu=self.keyed_edu[(destination, edu_key)],
+            )))
 
         # Fetch changed edus
         keys = self.edus.keys()
-        i = keys.bisect_right(token)
-        edus = set((k, self.edus[k]) for k in keys[i:])
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        edus = ((k, self.edus[k]) for k in keys[i:j])
 
         for (pos, edu) in edus:
-            rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
+            rows.append((pos, EduRow(edu)))
 
         # Fetch changed failures
         keys = self.failures.keys()
-        i = keys.bisect_right(token)
-        failures = set((k, self.failures[k]) for k in keys[i:])
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        failures = ((k, self.failures[k]) for k in keys[i:j])
 
         for (pos, (destination, failure)) in failures:
-            rows.append((pos, FAILURE_TYPE, ujson.dumps({
-                "destination": destination,
-                "failure": failure,
-            })))
+            rows.append((pos, FailureRow(
+                destination=destination,
+                failure=failure,
+            )))
 
         # Fetch changed device messages
         keys = self.device_messages.keys()
-        i = keys.bisect_right(token)
-        device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+        i = keys.bisect_right(from_token)
+        j = keys.bisect_right(to_token) + 1
+        device_messages = {self.device_messages[k]: k for k in keys[i:j]}
 
-        for (pos, destination) in device_messages:
-            rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
-                "destination": destination,
-            })))
+        for (destination, pos) in device_messages.iteritems():
+            rows.append((pos, DeviceRow(
+                destination=destination,
+            )))
 
         # Sort rows based on pos
         rows.sort()
 
-        return rows
+        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+
+
+class BaseFederationRow(object):
+    """Base class for rows to be sent in the federation stream.
+
+    Specifies how to identify, serialize and deserialize the different types.
+    """
+
+    TypeId = None  # Unique string that ids the type. Must be overriden in sub classes.
+
+    @staticmethod
+    def from_data(data):
+        """Parse the data from the federation stream into a row.
+
+        Args:
+            data: The value of ``data`` from FederationStreamRow.data, type
+                depends on the type of stream
+        """
+        raise NotImplementedError()
+
+    def to_data(self):
+        """Serialize this row to be sent over the federation stream.
+
+        Returns:
+            The value to be sent in FederationStreamRow.data. The type depends
+            on the type of stream.
+        """
+        raise NotImplementedError()
+
+    def add_to_buffer(self, buff):
+        """Add this row to the appropriate field in the buffer ready for this
+        to be sent over federation.
+
+        We use a buffer so that we can batch up events that have come in at
+        the same time and send them all at once.
+
+        Args:
+            buff (BufferedToSend)
+        """
+        raise NotImplementedError()
+
+
+class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
+    "state",  # UserPresenceState
+))):
+    TypeId = "p"
+
+    @staticmethod
+    def from_data(data):
+        return PresenceRow(
+            state=UserPresenceState.from_dict(data)
+        )
+
+    def to_data(self):
+        return self.state.as_dict()
+
+    def add_to_buffer(self, buff):
+        buff.presence.append(self.state)
+
+
+class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
+    "key",  # tuple(str) - the edu key passed to send_edu
+    "edu",  # Edu
+))):
+    """Streams EDUs that have an associated key that is ued to clobber. For example,
+    typing EDUs clobber based on room_id.
+    """
+
+    TypeId = "k"
+
+    @staticmethod
+    def from_data(data):
+        return KeyedEduRow(
+            key=tuple(data["key"]),
+            edu=Edu(**data["edu"]),
+        )
+
+    def to_data(self):
+        return {
+            "key": self.key,
+            "edu": self.edu.get_internal_dict(),
+        }
+
+    def add_to_buffer(self, buff):
+        buff.keyed_edus.setdefault(
+            self.edu.destination, {}
+        )[self.key] = self.edu
+
+
+class EduRow(BaseFederationRow, namedtuple("EduRow", (
+    "edu",  # Edu
+))):
+    """Streams EDUs that don't have keys. See KeyedEduRow
+    """
+    TypeId = "e"
+
+    @staticmethod
+    def from_data(data):
+        return EduRow(Edu(**data))
+
+    def to_data(self):
+        return self.edu.get_internal_dict()
+
+    def add_to_buffer(self, buff):
+        buff.edus.setdefault(self.edu.destination, []).append(self.edu)
+
+
+class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
+    "destination",  # str
+    "failure",
+))):
+    """Streams failures to a remote server. Failures are issued when there was
+    something wrong with a transaction the remote sent us, e.g. it included
+    an event that was invalid.
+    """
+
+    TypeId = "f"
+
+    @staticmethod
+    def from_data(data):
+        return FailureRow(
+            destination=data["destination"],
+            failure=data["failure"],
+        )
+
+    def to_data(self):
+        return {
+            "destination": self.destination,
+            "failure": self.failure,
+        }
+
+    def add_to_buffer(self, buff):
+        buff.failures.setdefault(self.destination, []).append(self.failure)
+
+
+class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
+    "destination",  # str
+))):
+    """Streams the fact that either a) there is pending to device messages for
+    users on the remote, or b) a local users device has changed and needs to
+    be sent to the remote.
+    """
+    TypeId = "d"
+
+    @staticmethod
+    def from_data(data):
+        return DeviceRow(destination=data["destination"])
+
+    def to_data(self):
+        return {"destination": self.destination}
+
+    def add_to_buffer(self, buff):
+        buff.device_destinations.add(self.destination)
+
+
+TypeToRow = {
+    Row.TypeId: Row
+    for Row in (
+        PresenceRow,
+        KeyedEduRow,
+        EduRow,
+        FailureRow,
+        DeviceRow,
+    )
+}
+
+
+ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
+    "presence",  # list(UserPresenceState)
+    "keyed_edus",  # dict of destination -> { key -> Edu }
+    "edus",  # dict of destination -> [Edu]
+    "failures",  # dict of destination -> [failures]
+    "device_destinations",  # set of destinations
+))
+
+
+def process_rows_for_federation(transaction_queue, rows):
+    """Parse a list of rows from the federation stream and put them in the
+    transaction queue ready for sending to the relevant homeservers.
+
+    Args:
+        transaction_queue (TransactionQueue)
+        rows (list(synapse.replication.tcp.streams.FederationStreamRow))
+    """
+
+    # The federation stream contains a bunch of different types of
+    # rows that need to be handled differently. We parse the rows, put
+    # them into the appropriate collection and then send them off.
+
+    buff = ParsedFederationStreamData(
+        presence=[],
+        keyed_edus={},
+        edus={},
+        failures={},
+        device_destinations=set(),
+    )
+
+    # Parse the rows in the stream and add to the buffer
+    for row in rows:
+        if row.type not in TypeToRow:
+            logger.error("Unrecognized federation row type %r", row.type)
+            continue
+
+        RowType = TypeToRow[row.type]
+        parsed_row = RowType.from_data(row.data)
+        parsed_row.add_to_buffer(buff)
+
+    if buff.presence:
+        transaction_queue.send_presence(buff.presence)
+
+    for destination, edu_map in buff.keyed_edus.iteritems():
+        for key, edu in edu_map.items():
+            transaction_queue.send_edu(
+                edu.destination, edu.edu_type, edu.content, key=key,
+            )
+
+    for destination, edu_list in buff.edus.iteritems():
+        for edu in edu_list:
+            transaction_queue.send_edu(
+                edu.destination, edu.edu_type, edu.content, key=None,
+            )
+
+    for destination, failure_list in buff.failures.iteritems():
+        for failure in failure_list:
+            transaction_queue.send_failure(destination, failure)
+
+    for destination in buff.device_destinations:
+        transaction_queue.send_device_messages(destination)