diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 52f4f54215..8e46957d15 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -33,14 +33,12 @@ import logging
from collections import namedtuple
from typing import Dict, List, Tuple, Type
-from six import iteritems
-
from sortedcontainers import SortedDict
from twisted.internet import defer
+from synapse.api.presence import UserPresenceState
from synapse.metrics import LaterGauge
-from synapse.storage.presence import UserPresenceState
from synapse.util.metrics import Measure
from .units import Edu
@@ -48,7 +46,7 @@ from .units import Edu
logger = logging.getLogger(__name__)
-class FederationRemoteSendQueue(object):
+class FederationRemoteSendQueue:
"""A drop in replacement for FederationSender"""
def __init__(self, hs):
@@ -57,6 +55,11 @@ class FederationRemoteSendQueue(object):
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
+ # We may have multiple federation sender instances, so we need to track
+ # their positions separately.
+ self._sender_instances = hs.config.worker.federation_shard_config.instances
+ self._sender_positions = {}
+
# Pending presence map user_id -> UserPresenceState
self.presence_map = {} # type: Dict[str, UserPresenceState]
@@ -263,7 +266,14 @@ class FederationRemoteSendQueue(object):
def get_current_token(self):
return self.pos - 1
- def federation_ack(self, token):
+ def federation_ack(self, instance_name, token):
+ if self._sender_instances:
+ # If we have configured multiple federation sender instances we need
+ # to track their positions separately, and only clear the queue up
+ # to the token all instances have acked.
+ self._sender_positions[instance_name] = token
+ token = min(self._sender_positions.values())
+
self._clear_queue_before_pos(token)
async def get_replication_rows(
@@ -327,7 +337,7 @@ class FederationRemoteSendQueue(object):
# stream position.
keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]}
- for ((destination, edu_key), pos) in iteritems(keyed_edus):
+ for ((destination, edu_key), pos) in keyed_edus.items():
rows.append(
(
pos,
@@ -355,13 +365,13 @@ class FederationRemoteSendQueue(object):
)
-class BaseFederationRow(object):
+class BaseFederationRow:
"""Base class for rows to be sent in the federation stream.
Specifies how to identify, serialize and deserialize the different types.
"""
- TypeId = "" # Unique string that ids the type. Must be overriden in sub classes.
+ TypeId = "" # Unique string that ids the type. Must be overridden in sub classes.
@staticmethod
def from_data(data):
@@ -530,10 +540,10 @@ def process_rows_for_federation(transaction_queue, rows):
states=[state], destinations=destinations
)
- for destination, edu_map in iteritems(buff.keyed_edus):
+ for destination, edu_map in buff.keyed_edus.items():
for key, edu in edu_map.items():
transaction_queue.send_edu(edu, key)
- for destination, edu_list in iteritems(buff.edus):
+ for destination, edu_list in buff.edus.items():
for edu in edu_list:
transaction_queue.send_edu(edu, None)
|