summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/app/federation_sender.py66
-rw-r--r--synapse/federation/send_queue.py246
2 files changed, 221 insertions, 91 deletions
diff --git a/synapse/app/federation_sender.py b/synapse/app/federation_sender.py
index 145c01f3a3..477e16e0fa 100644
--- a/synapse/app/federation_sender.py
+++ b/synapse/app/federation_sender.py
@@ -23,7 +23,6 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
 from synapse.http.site import SynapseSite
 from synapse.federation import send_queue
-from synapse.federation.units import Edu
 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
 from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
 from synapse.replication.slave.storage.events import SlavedEventStore
@@ -33,7 +32,6 @@ from synapse.replication.slave.storage.transactions import TransactionStore
 from synapse.replication.slave.storage.devices import SlavedDeviceStore
 from synapse.replication.tcp.client import ReplicationClientHandler
 from synapse.storage.engines import create_engine
-from synapse.storage.presence import UserPresenceState
 from synapse.util.async import Linearizer
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext, preserve_fn
@@ -277,69 +275,7 @@ class FederationSenderHandler(object):
         # The federation stream contains things that we want to send out, e.g.
         # presence, typing, etc.
         if stream_name == "federation":
-            # The federation stream containis a bunch of different types of
-            # rows that need to be handled differently. We parse the rows, put
-            # them into the appropriate collection and then send them off.
-            presence_to_send = {}
-            keyed_edus = {}
-            edus = {}
-            failures = {}
-            device_destinations = set()
-
-            # Parse the rows in the stream
-            for row in rows:
-                typ = row.type
-                content = row.data
-
-                if typ == send_queue.PRESENCE_TYPE:
-                    destination = content["destination"]
-                    state = UserPresenceState.from_dict(content["state"])
-
-                    presence_to_send.setdefault(destination, []).append(state)
-                elif typ == send_queue.KEYED_EDU_TYPE:
-                    key = content["key"]
-                    edu = Edu(**content["edu"])
-
-                    keyed_edus.setdefault(
-                        edu.destination, {}
-                    )[(edu.destination, tuple(key))] = edu
-                elif typ == send_queue.EDU_TYPE:
-                    edu = Edu(**content)
-
-                    edus.setdefault(edu.destination, []).append(edu)
-                elif typ == send_queue.FAILURE_TYPE:
-                    destination = content["destination"]
-                    failure = content["failure"]
-
-                    failures.setdefault(destination, []).append(failure)
-                elif typ == send_queue.DEVICE_MESSAGE_TYPE:
-                    device_destinations.add(content["destination"])
-                else:
-                    raise Exception("Unrecognised federation type: %r", typ)
-
-            # We've finished collecting, send everything off
-            for destination, states in presence_to_send.items():
-                self.federation_sender.send_presence(destination, states)
-
-            for destination, edu_map in keyed_edus.items():
-                for key, edu in edu_map.items():
-                    self.federation_sender.send_edu(
-                        edu.destination, edu.edu_type, edu.content, key=key,
-                    )
-
-            for destination, edu_list in edus.items():
-                for edu in edu_list:
-                    self.federation_sender.send_edu(
-                        edu.destination, edu.edu_type, edu.content, key=None,
-                    )
-
-            for destination, failure_list in failures.items():
-                for failure in failure_list:
-                    self.federation_sender.send_failure(destination, failure)
-
-            for destination in device_destinations:
-                self.federation_sender.send_device_messages(destination)
-
+            send_queue.process_rows_for_federation(self.federation_sender, rows)
             preserve_fn(self.update_token)(token)
 
         # We also need to poke the federation sender when new events happen
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 78c852ed69..8a6392c697 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,22 +31,17 @@ Events are replicated via a separate events stream.
 
 from .units import Edu
 
+from synapse.storage.presence import UserPresenceState
 from synapse.util.metrics import Measure
 import synapse.metrics
 
 from blist import sorteddict
+from collections import namedtuple
 
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
 
-PRESENCE_TYPE = "p"
-KEYED_EDU_TYPE = "k"
-EDU_TYPE = "e"
-FAILURE_TYPE = "f"
-DEVICE_MESSAGE_TYPE = "d"
-
-
 class FederationRemoteSendQueue(object):
     """A drop in replacement for TransactionQueue"""
 
@@ -257,10 +252,10 @@ class FederationRemoteSendQueue(object):
         )
 
         for (key, (dest, user_id)) in dest_user_ids:
-            rows.append((key, PRESENCE_TYPE, {
-                "destination": dest,
-                "state": self.presence_map[user_id].as_dict(),
-            }))
+            rows.append((key, PresenceRow(
+                destination=dest,
+                state=self.presence_map[user_id],
+            )))
 
         # Fetch changes keyed edus
         keys = self.keyed_edu_changed.keys()
@@ -269,12 +264,10 @@ class FederationRemoteSendQueue(object):
         keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:j])
 
         for (pos, (destination, edu_key)) in keyed_edus:
-            rows.append(
-                (pos, KEYED_EDU_TYPE, {
-                    "key": edu_key,
-                    "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
-                })
-            )
+            rows.append((pos, KeyedEduRow(
+                key=edu_key,
+                edu=self.keyed_edu[(destination, edu_key)],
+            )))
 
         # Fetch changed edus
         keys = self.edus.keys()
@@ -283,7 +276,7 @@ class FederationRemoteSendQueue(object):
         edus = set((k, self.edus[k]) for k in keys[i:j])
 
         for (pos, edu) in edus:
-            rows.append((pos, EDU_TYPE, edu.get_internal_dict()))
+            rows.append((pos, EduRow(edu)))
 
         # Fetch changed failures
         keys = self.failures.keys()
@@ -292,10 +285,10 @@ class FederationRemoteSendQueue(object):
         failures = set((k, self.failures[k]) for k in keys[i:j])
 
         for (pos, (destination, failure)) in failures:
-            rows.append((pos, FAILURE_TYPE, {
-                "destination": destination,
-                "failure": failure,
-            }))
+            rows.append((pos, FailureRow(
+                destination=destination,
+                failure=failure,
+            )))
 
         # Fetch changed device messages
         keys = self.device_messages.keys()
@@ -304,11 +297,212 @@ class FederationRemoteSendQueue(object):
         device_messages = set((k, self.device_messages[k]) for k in keys[i:j])
 
         for (pos, destination) in device_messages:
-            rows.append((pos, DEVICE_MESSAGE_TYPE, {
-                "destination": destination,
-            }))
+            rows.append((pos, DeviceRow(
+                destination=destination,
+            )))
 
         # Sort rows based on pos
         rows.sort()
 
-        return rows
+        return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+
+
+class BaseFederationRow(object):
+    TypeId = None
+
+    @staticmethod
+    def from_data(data):
+        """Parse the data from the federation stream into a row.
+        """
+        raise NotImplementedError()
+
+    def to_data(self):
+        """Serialize this row to be sent over the federation stream
+        """
+        raise NotImplementedError()
+
+    def add_to_buffer(self, buff):
+        """Add this row to the appropriate field in the buffer ready for this
+        to be sent over federation.
+
+        We use a buffer so that we can batch up events that have come in at
+        the same time and send them all at once.
+
+        Args:
+            buff (BufferedToSend)
+        """
+        raise NotImplementedError()
+
+
+class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
+    "destination",  # str
+    "state",  # UserPresenceState
+))):
+    TypeId = "p"
+
+    @staticmethod
+    def from_data(data):
+        return PresenceRow(
+            destination=data["destination"],
+            state=UserPresenceState.from_dict(data["state"])
+        )
+
+    def to_data(self):
+        return {
+            "destination": self.destination,
+            "state": self.state.as_dict()
+        }
+
+    def add_to_buffer(self, buff):
+        buff.presence.setdefault(self.destination, []).append(self.state)
+
+
+class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
+    "key",  # tuple(str) - the edu key passed to send_edu
+    "edu",  # Edu
+))):
+    TypeId = "k"
+
+    @staticmethod
+    def from_data(data):
+        return KeyedEduRow(
+            key=tuple(data["key"]),
+            edu=Edu(**data["edu"]),
+        )
+
+    def to_data(self):
+        return {
+            "key": self.key,
+            "edu": self.edu.get_internal_dict(),
+        }
+
+    def add_to_buffer(self, buff):
+        buff.keyed_edus.setdefault(
+            self.edu.destination, {}
+        )[self.key] = self.edu
+
+
+class EduRow(BaseFederationRow, namedtuple("EduRow", (
+    "edu",  # Edu
+))):
+    TypeId = "e"
+
+    @staticmethod
+    def from_data(data):
+        return EduRow(Edu(**data))
+
+    def to_data(self):
+        return self.edu.get_internal_dict()
+
+    def add_to_buffer(self, buff):
+        buff.edus.setdefault(self.edu.destination, []).append(self.edu)
+
+
+class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
+    "destination",  # str
+    "failure",
+))):
+    TypeId = "f"
+
+    @staticmethod
+    def from_data(data):
+        return FailureRow(
+            destination=data["destination"],
+            failure=data["failure"],
+        )
+
+    def to_data(self):
+        return {
+            "destination": self.destination,
+            "failure": self.failure,
+        }
+
+    def add_to_buffer(self, buff):
+        buff.failures.setdefault(self.destination, []).append(self.failure)
+
+
+class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
+    "destination",  # str
+))):
+    TypeId = "d"
+
+    @staticmethod
+    def from_data(data):
+        return DeviceRow(destination=data)
+
+    def to_data(self):
+        return self.destination
+
+    def add_to_buffer(self, buff):
+        buff.device_destinations.add(self.destination)
+
+
+TypeToRow = {
+    Row.TypeId: Row
+    for Row in (
+        PresenceRow,
+        KeyedEduRow,
+        EduRow,
+        FailureRow,
+        DeviceRow,
+    )
+}
+
+
+BufferedToSend = namedtuple("BufferedToSend", (
+    "presence",  # dict of destination -> [UserPresenceState]
+    "keyed_edus",  # dict of destination -> { key -> Edu }
+    "edus",  # dict of destination -> [Edu]
+    "failures",  # dict of destination -> [failures]
+    "device_destinations",  # set of destinations
+))
+
+
+def process_rows_for_federation(federation_sender, rows):
+    """Parse a list of rows from the federation stream and them send them out.
+
+    Args:
+        federation_sender (TransactionQueue)
+        rows (list(FederationStreamRow))
+    """
+
+    # The federation stream containis a bunch of different types of
+    # rows that need to be handled differently. We parse the rows, put
+    # them into the appropriate collection and then send them off.
+
+    buff = BufferedToSend(
+        presence={},
+        keyed_edus={},
+        edus={},
+        failures={},
+        device_destinations=set(),
+    )
+
+    # Parse the rows in the stream and add to the buffer
+    for row in rows:
+        RowType = TypeToRow[row.type]
+        parsed_row = RowType.from_data(row.data)
+        parsed_row.add_to_buffer(buff)
+
+    # We've finished collecting, send everything off
+    for destination, states in buff.presence.iteritems():
+        federation_sender.send_presence(destination, states)
+
+    for destination, edu_map in buff.keyed_edus.iteritems():
+        for key, edu in edu_map.items():
+            federation_sender.send_edu(
+                edu.destination, edu.edu_type, edu.content, key=key,
+            )
+
+    for destination, edu_list in buff.edus.iteritems():
+        for edu in edu_list:
+            federation_sender.send_edu(
+                edu.destination, edu.edu_type, edu.content, key=None,
+            )
+
+    for destination, failure_list in buff.failures.iteritems():
+        for failure in failure_list:
+            federation_sender.send_failure(destination, failure)
+
+    for destination in buff.device_destinations:
+        federation_sender.send_device_messages(destination)