diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
index 5c9f7a86f0..78c852ed69 100644
--- a/synapse/federation/send_queue.py
+++ b/synapse/federation/send_queue.py
@@ -35,7 +35,6 @@ from synapse.util.metrics import Measure
import synapse.metrics
from blist import sorteddict
-import ujson
metrics = synapse.metrics.get_metrics_for(__name__)
@@ -54,6 +53,7 @@ class FederationRemoteSendQueue(object):
def __init__(self, hs):
self.server_name = hs.hostname
self.clock = hs.get_clock()
+ self.notifier = hs.get_notifier()
self.presence_map = {}
self.presence_changed = sorteddict()
@@ -186,6 +186,8 @@ class FederationRemoteSendQueue(object):
else:
self.edus[pos] = edu
+ self.notifier.on_new_replication_data()
+
def send_presence(self, destination, states):
"""As per TransactionQueue"""
pos = self._next_pos()
@@ -199,24 +201,33 @@ class FederationRemoteSendQueue(object):
(destination, state.user_id) for state in states
]
+ self.notifier.on_new_replication_data()
+
def send_failure(self, failure, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.failures[pos] = (destination, str(failure))
+ self.notifier.on_new_replication_data()
def send_device_messages(self, destination):
"""As per TransactionQueue"""
pos = self._next_pos()
self.device_messages[pos] = destination
+ self.notifier.on_new_replication_data()
def get_current_token(self):
return self.pos - 1
- def get_replication_rows(self, token, limit, federation_ack=None):
- """
+ def federation_ack(self, token):
+ self._clear_queue_before_pos(token)
+
+ def get_replication_rows(self, from_token, to_token, limit, federation_ack=None):
+ """Get rows to be sent over federation between the two tokens
+
Args:
- token (int)
+ from_token (int)
+ to_token(int)
limit (int)
federation_ack (int): Optional. The position where the worker is
explicitly acknowledged it has handled. Allows us to drop
@@ -225,8 +236,8 @@ class FederationRemoteSendQueue(object):
# TODO: Handle limit.
# To handle restarts where we wrap around
- if token > self.pos:
- token = -1
+ if from_token > self.pos:
+ from_token = -1
rows = []
@@ -237,60 +248,65 @@ class FederationRemoteSendQueue(object):
# Fetch changed presence
keys = self.presence_changed.keys()
- i = keys.bisect_right(token)
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
dest_user_ids = set(
(pos, dest_user_id)
- for pos in keys[i:]
+ for pos in keys[i:j]
for dest_user_id in self.presence_changed[pos]
)
for (key, (dest, user_id)) in dest_user_ids:
- rows.append((key, PRESENCE_TYPE, ujson.dumps({
+ rows.append((key, PRESENCE_TYPE, {
"destination": dest,
"state": self.presence_map[user_id].as_dict(),
- })))
+ }))
# Fetch changes keyed edus
keys = self.keyed_edu_changed.keys()
- i = keys.bisect_right(token)
- keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:j])
for (pos, (destination, edu_key)) in keyed_edus:
rows.append(
- (pos, KEYED_EDU_TYPE, ujson.dumps({
+ (pos, KEYED_EDU_TYPE, {
"key": edu_key,
"edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
- }))
+ })
)
# Fetch changed edus
keys = self.edus.keys()
- i = keys.bisect_right(token)
- edus = set((k, self.edus[k]) for k in keys[i:])
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ edus = set((k, self.edus[k]) for k in keys[i:j])
for (pos, edu) in edus:
- rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
+ rows.append((pos, EDU_TYPE, edu.get_internal_dict()))
# Fetch changed failures
keys = self.failures.keys()
- i = keys.bisect_right(token)
- failures = set((k, self.failures[k]) for k in keys[i:])
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ failures = set((k, self.failures[k]) for k in keys[i:j])
for (pos, (destination, failure)) in failures:
- rows.append((pos, FAILURE_TYPE, ujson.dumps({
+ rows.append((pos, FAILURE_TYPE, {
"destination": destination,
"failure": failure,
- })))
+ }))
# Fetch changed device messages
keys = self.device_messages.keys()
- i = keys.bisect_right(token)
- device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+ i = keys.bisect_right(from_token)
+ j = keys.bisect_right(to_token) + 1
+ device_messages = set((k, self.device_messages[k]) for k in keys[i:j])
for (pos, destination) in device_messages:
- rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
+ rows.append((pos, DEVICE_MESSAGE_TYPE, {
"destination": destination,
- })))
+ }))
# Sort rows based on pos
rows.sort()
|