diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index bbb0195228..93e5acebc1 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -31,21 +31,19 @@ Events are replicated via a separate events stream.
from .units import Edu
+from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
import synapse.metrics
from blist import sorteddict
-import ujson
+from collections import namedtuple
+import logging
-metrics = synapse.metrics.get_metrics_for(__name__)
+logger = logging.getLogger(__name__)
-PRESENCE_TYPE = "p"
-KEYED_EDU_TYPE = "k"
-EDU_TYPE = "e"
-FAILURE_TYPE = "f"
-DEVICE_MESSAGE_TYPE = "d"
+metrics = synapse.metrics.get_metrics_for(__name__)
class FederationRemoteSendQueue(object):
@@ -55,18 +53,19 @@ class FederationRemoteSendQueue(object):
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
+ self.is_mine_id = hs.is_mine_id
- self.presence_map = {}
- self.presence_changed = sorteddict()
+ self.presence_map = {} # Pending presence map user_id -> UserPresenceState
+ self.presence_changed = sorteddict() # Stream position -> user_id
- self.keyed_edu = {}
- self.keyed_edu_changed = sorteddict()
+ self.keyed_edu = {} # (destination, key) -> EDU
+ self.keyed_edu_changed = sorteddict() # stream position -> (destination, key)
- self.edus = sorteddict()
+ self.edus = sorteddict() # stream position -> Edu
- self.failures = sorteddict()
+ self.failures = sorteddict() # stream position -> (destination, Failure)
- self.device_messages = sorteddict()
+ self.device_messages = sorteddict() # stream position -> destination
self.pos = 1
self.pos_time = sorteddict()
@@ -122,7 +121,9 @@ class FederationRemoteSendQueue(object):
del self.presence_changed[key]
user_ids = set(
- user_id for uids in self.presence_changed.values() for _, user_id in uids
+ user_id
+ for uids in self.presence_changed.itervalues()
+ for user_id in uids
)
to_del = [
@@ -189,18 +190,20 @@ class FederationRemoteSendQueue(object):
self.notifier.on_new_replication_data()
- def send_presence(self, destination, states):
- """As per TransactionQueue"""
+ def send_presence(self, states):
+ """As per TransactionQueue
+
+ Args:
+ states (list(UserPresenceState))
+ """
pos = self._next_pos()
- self.presence_map.update({
- state.user_id: state
- for state in states
- })
+ # We only want to send presence for our own users, so let's always
+ # filter here, just in case.
+ local_states = filter(lambda s: self.is_mine_id(s.user_id), states)
- self.presence_changed[pos] = [
- (destination, state.user_id) for state in states
- ]
+ self.presence_map.update({state.user_id: state for state in local_states})
+ self.presence_changed[pos] = [state.user_id for state in local_states]
self.notifier.on_new_replication_data()
@@ -220,10 +223,15 @@ class FederationRemoteSendQueue(object):
def get_current_token(self):
return self.pos - 1
- def get_replication_rows(self, token, limit, federation_ack=None):
- """
+ def federation_ack(self, token):
+ self._clear_queue_before_pos(token)
+
+ def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
+ """Get rows to be sent over federation between the two tokens
+
Args:
- token (int)
+ from_token (int)
+ to_token (int)
limit (int)
federation_ack (int): Optional. The position where the worker is
explicitly acknowledged it has handled. Allows us to drop
@@ -232,9 +240,11 @@ class FederationRemoteSendQueue(object):
# TODO: Handle limit.
# To handle restarts where we wrap around
- if token > self.pos:
- token = -1
+ if from_token > self.pos:
+ from_token = -1
+ # list of tuple(int, BaseFederationRow), where the first element is the
+ # position in the federation stream.
rows = []
# There should be only one reader, so lets delete everything its
@@ -244,62 +254,295 @@ class FederationRemoteSendQueue(object):
# Fetch changed presence
keys = self.presence_changed.keys()
- i = keys.bisect_right(token)
- dest_user_ids = set(
- (pos, dest_user_id)
- for pos in keys[i:]
- for dest_user_id in self.presence_changed[pos]
- )
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ dest_user_ids = [
+ (pos, user_id)
+ for pos in keys[i:j]
+ for user_id in self.presence_changed[pos]
+ ]
- for (key, (dest, user_id)) in dest_user_ids:
- rows.append((key, PRESENCE_TYPE, ujson.dumps({
- "destination": dest,
- "state": self.presence_map[user_id].as_dict(),
- })))
+ for (key, user_id) in dest_user_ids:
+ rows.append((key, PresenceRow(
+ state=self.presence_map[user_id],
+ )))
# Fetch changes keyed edus
keys = self.keyed_edu_changed.keys()
- i = keys.bisect_right(token)
- keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
-
- for (pos, (destination, edu_key)) in keyed_edus:
- rows.append(
- (pos, KEYED_EDU_TYPE, ujson.dumps({
- "key": edu_key,
- "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
- }))
- )
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ # We purposefully clobber based on the key here: Python dict comprehensions
+ # always keep the last value, so this will correctly point to the last
+ # stream position.
+ keyed_edus = {self.keyed_edu_changed[k]: k for k in keys[i:j]}
+
+ for ((destination, edu_key), pos) in keyed_edus.iteritems():
+ rows.append((pos, KeyedEduRow(
+ key=edu_key,
+ edu=self.keyed_edu[(destination, edu_key)],
+ )))
# Fetch changed edus
keys = self.edus.keys()
- i = keys.bisect_right(token)
- edus = set((k, self.edus[k]) for k in keys[i:])
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ edus = ((k, self.edus[k]) for k in keys[i:j])
for (pos, edu) in edus:
- rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
+ rows.append((pos, EduRow(edu)))
# Fetch changed failures
keys = self.failures.keys()
- i = keys.bisect_right(token)
- failures = set((k, self.failures[k]) for k in keys[i:])
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ failures = ((k, self.failures[k]) for k in keys[i:j])
for (pos, (destination, failure)) in failures:
- rows.append((pos, FAILURE_TYPE, ujson.dumps({
- "destination": destination,
- "failure": failure,
- })))
+ rows.append((pos, FailureRow(
+ destination=destination,
+ failure=failure,
+ )))
# Fetch changed device messages
keys = self.device_messages.keys()
- i = keys.bisect_right(token)
- device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ device_messages = {self.device_messages[k]: k for k in keys[i:j]}
- for (pos, destination) in device_messages:
- rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
- "destination": destination,
- })))
+ for (destination, pos) in device_messages.iteritems():
+ rows.append((pos, DeviceRow(
+ destination=destination,
+ )))
# Sort rows based on pos
rows.sort()
- return rows
+ return [(pos, row.TypeId, row.to_data()) for pos, row in rows]
+
+
+class BaseFederationRow(object):
+ """Base class for rows to be sent in the federation stream.
+
+ Specifies how to identify, serialize and deserialize the different types.
+ """
+
+ TypeId = None # Unique string that identifies the type. Must be overridden in subclasses.
+
+ @staticmethod
+ def from_data(data):
+ """Parse the data from the federation stream into a row.
+
+ Args:
+ data: The value of ``data`` from FederationStreamRow.data, type
+ depends on the type of stream
+ """
+ raise NotImplementedError()
+
+ def to_data(self):
+ """Serialize this row to be sent over the federation stream.
+
+ Returns:
+ The value to be sent in FederationStreamRow.data. The type depends
+ on the type of stream.
+ """
+ raise NotImplementedError()
+
+ def add_to_buffer(self, buff):
+ """Add this row to the appropriate field in the buffer ready for this
+ to be sent over federation.
+
+ We use a buffer so that we can batch up events that have come in at
+ the same time and send them all at once.
+
+ Args:
+ buff (ParsedFederationStreamData)
+ """
+ raise NotImplementedError()
+
+
+class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
+ "state", # UserPresenceState
+))):
+ TypeId = "p"
+
+ @staticmethod
+ def from_data(data):
+ return PresenceRow(
+ state=UserPresenceState.from_dict(data)
+ )
+
+ def to_data(self):
+ return self.state.as_dict()
+
+ def add_to_buffer(self, buff):
+ buff.presence.append(self.state)
+
+
+class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
+ "key", # tuple(str) - the edu key passed to send_edu
+ "edu", # Edu
+))):
+ """Streams EDUs that have an associated key that is used to clobber. For example,
+ typing EDUs clobber based on room_id.
+ """
+
+ TypeId = "k"
+
+ @staticmethod
+ def from_data(data):
+ return KeyedEduRow(
+ key=tuple(data["key"]),
+ edu=Edu(**data["edu"]),
+ )
+
+ def to_data(self):
+ return {
+ "key": self.key,
+ "edu": self.edu.get_internal_dict(),
+ }
+
+ def add_to_buffer(self, buff):
+ buff.keyed_edus.setdefault(
+ self.edu.destination, {}
+ )[self.key] = self.edu
+
+
+class EduRow(BaseFederationRow, namedtuple("EduRow", (
+ "edu", # Edu
+))):
+ """Streams EDUs that don't have keys. See KeyedEduRow
+ """
+ TypeId = "e"
+
+ @staticmethod
+ def from_data(data):
+ return EduRow(Edu(**data))
+
+ def to_data(self):
+ return self.edu.get_internal_dict()
+
+ def add_to_buffer(self, buff):
+ buff.edus.setdefault(self.edu.destination, []).append(self.edu)
+
+
+class FailureRow(BaseFederationRow, namedtuple("FailureRow", (
+ "destination", # str
+ "failure",
+))):
+ """Streams failures to a remote server. Failures are issued when there was
+ something wrong with a transaction the remote sent us, e.g. it included
+ an event that was invalid.
+ """
+
+ TypeId = "f"
+
+ @staticmethod
+ def from_data(data):
+ return FailureRow(
+ destination=data["destination"],
+ failure=data["failure"],
+ )
+
+ def to_data(self):
+ return {
+ "destination": self.destination,
+ "failure": self.failure,
+ }
+
+ def add_to_buffer(self, buff):
+ buff.failures.setdefault(self.destination, []).append(self.failure)
+
+
+class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", (
+ "destination", # str
+))):
+ """Streams the fact that either a) there are pending to-device messages for
+ users on the remote, or b) a local user's device has changed and needs to
+ be sent to the remote.
+ """
+ TypeId = "d"
+
+ @staticmethod
+ def from_data(data):
+ return DeviceRow(destination=data["destination"])
+
+ def to_data(self):
+ return {"destination": self.destination}
+
+ def add_to_buffer(self, buff):
+ buff.device_destinations.add(self.destination)
+
+
+TypeToRow = {
+ Row.TypeId: Row
+ for Row in (
+ PresenceRow,
+ KeyedEduRow,
+ EduRow,
+ FailureRow,
+ DeviceRow,
+ )
+}
+
+
+ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
+ "presence", # list(UserPresenceState)
+ "keyed_edus", # dict of destination -> { key -> Edu }
+ "edus", # dict of destination -> [Edu]
+ "failures", # dict of destination -> [failures]
+ "device_destinations", # set of destinations
+))
+
+
+def process_rows_for_federation(transaction_queue, rows):
+ """Parse a list of rows from the federation stream and put them in the
+ transaction queue ready for sending to the relevant homeservers.
+
+ Args:
+ transaction_queue (TransactionQueue)
+ rows (list(synapse.replication.tcp.streams.FederationStreamRow))
+ """
+
+ # The federation stream contains a bunch of different types of
+ # rows that need to be handled differently. We parse the rows, put
+ # them into the appropriate collection and then send them off.
+
+ buff = ParsedFederationStreamData(
+ presence=[],
+ keyed_edus={},
+ edus={},
+ failures={},
+ device_destinations=set(),
+ )
+
+ # Parse the rows in the stream and add to the buffer
+ for row in rows:
+ if row.type not in TypeToRow:
+ logger.error("Unrecognized federation row type %r", row.type)
+ continue
+
+ RowType = TypeToRow[row.type]
+ parsed_row = RowType.from_data(row.data)
+ parsed_row.add_to_buffer(buff)
+
+ if buff.presence:
+ transaction_queue.send_presence(buff.presence)
+
+ for destination, edu_map in buff.keyed_edus.iteritems():
+ for key, edu in edu_map.items():
+ transaction_queue.send_edu(
+ edu.destination, edu.edu_type, edu.content, key=key,
+ )
+
+ for destination, edu_list in buff.edus.iteritems():
+ for edu in edu_list:
+ transaction_queue.send_edu(
+ edu.destination, edu.edu_type, edu.content, key=None,
+ )
+
+ for destination, failure_list in buff.failures.iteritems():
+ for failure in failure_list:
+ transaction_queue.send_failure(destination, failure)
+
+ for destination in buff.device_destinations:
+ transaction_queue.send_device_messages(destination)
|