diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py
index 979fdf2431..2e32d245ba 100644
--- a/synapse/federation/__init__.py
+++ b/synapse/federation/__init__.py
@@ -17,10 +17,9 @@
"""
from .replication import ReplicationLayer
-from .transport.client import TransportLayerClient
-def initialize_http_replication(homeserver):
- transport = TransportLayerClient(homeserver)
+def initialize_http_replication(hs):
+ transport = hs.get_federation_transport_client()
- return ReplicationLayer(homeserver, transport)
+ return ReplicationLayer(hs, transport)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 94e76b1978..b255709165 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,7 +18,6 @@ from twisted.internet import defer
from .federation_base import FederationBase
from synapse.api.constants import Membership
-from .units import Edu
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
@@ -45,10 +44,6 @@ logger = logging.getLogger(__name__)
# synapse.federation.federation_client is a silly name
metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
-sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations")
-
-sent_edus_counter = metrics.register_counter("sent_edus")
-
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
@@ -93,63 +88,6 @@ class FederationClient(FederationBase):
self._get_pdu_cache.start()
@log_function
- def send_pdu(self, pdu, destinations):
- """Informs the replication layer about a new PDU generated within the
- home server that should be transmitted to others.
-
- TODO: Figure out when we should actually resolve the deferred.
-
- Args:
- pdu (Pdu): The new Pdu.
-
- Returns:
- Deferred: Completes when we have successfully processed the PDU
- and replicated it to any interested remote home servers.
- """
- order = self._order
- self._order += 1
-
- sent_pdus_destination_dist.inc_by(len(destinations))
-
- logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
-
- # TODO, add errback, etc.
- self._transaction_queue.enqueue_pdu(pdu, destinations, order)
-
- logger.debug(
- "[%s] transaction_layer.enqueue_pdu... done",
- pdu.event_id
- )
-
- def send_presence(self, destination, states):
- if destination != self.server_name:
- self._transaction_queue.enqueue_presence(destination, states)
-
- @log_function
- def send_edu(self, destination, edu_type, content, key=None):
- edu = Edu(
- origin=self.server_name,
- destination=destination,
- edu_type=edu_type,
- content=content,
- )
-
- sent_edus_counter.inc()
-
- self._transaction_queue.enqueue_edu(edu, key=key)
-
- @log_function
- def send_device_messages(self, destination):
- """Sends the device messages in the local database to the remote
- destination"""
- self._transaction_queue.enqueue_device_messages(destination)
-
- @log_function
- def send_failure(self, failure, destination):
- self._transaction_queue.enqueue_failure(failure, destination)
- return defer.succeed(None)
-
- @log_function
def make_query(self, destination, query_type, args,
retry_on_dns_fail=False):
"""Sends a federation Query to a remote homeserver of the given type
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index ea66a5dcbc..62d865ec4b 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -20,8 +20,6 @@ a given transport.
from .federation_client import FederationClient
from .federation_server import FederationServer
-from .transaction_queue import TransactionQueue
-
from .persistence import TransactionActions
import logging
@@ -66,9 +64,6 @@ class ReplicationLayer(FederationClient, FederationServer):
self._clock = hs.get_clock()
self.transaction_actions = TransactionActions(self.store)
- self._transaction_queue = TransactionQueue(hs, transport_layer)
-
- self._order = 0
self.hs = hs
diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py
new file mode 100644
index 0000000000..5c9f7a86f0
--- /dev/null
+++ b/synapse/federation/send_queue.py
@@ -0,0 +1,298 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014-2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A federation sender that forwards things to be sent across replication to
+a worker process.
+
+It assumes there is a single worker process feeding off of it.
+
+Each row in the replication stream consists of a type and some json, where the
+types indicate whether they are presence, or edus, etc.
+
+Ephemeral or non-event data are queued up in-memory. When the worker requests
+updates since a particular point, all in-memory data since before that point is
+dropped. We also expire things in the queue after 5 minutes, to ensure that a
+dead worker doesn't cause the queues to grow limitlessly.
+
+Events are replicated via a separate events stream.
+"""
+
+from .units import Edu
+
+from synapse.util.metrics import Measure
+import synapse.metrics
+
+from blist import sorteddict
+import ujson
+
+
+metrics = synapse.metrics.get_metrics_for(__name__)
+
+
+PRESENCE_TYPE = "p"
+KEYED_EDU_TYPE = "k"
+EDU_TYPE = "e"
+FAILURE_TYPE = "f"
+DEVICE_MESSAGE_TYPE = "d"
+
+
+class FederationRemoteSendQueue(object):
+ """A drop in replacement for TransactionQueue"""
+
+ def __init__(self, hs):
+ self.server_name = hs.hostname
+ self.clock = hs.get_clock()
+
+ self.presence_map = {}
+ self.presence_changed = sorteddict()
+
+ self.keyed_edu = {}
+ self.keyed_edu_changed = sorteddict()
+
+ self.edus = sorteddict()
+
+ self.failures = sorteddict()
+
+ self.device_messages = sorteddict()
+
+ self.pos = 1
+ self.pos_time = sorteddict()
+
+ # EVERYTHING IS SAD. In particular, python only makes new scopes when
+ # we make a new function, so we need to make a new function so the inner
+ # lambda binds to the queue rather than to the name of the queue which
+ # changes. ARGH.
+        def register(name, queue):
+            metrics.register_callback(
+                name + "_size",
+                lambda: len(queue),
+            )
+
+ for queue_name in [
+ "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
+ "edus", "failures", "device_messages", "pos_time",
+ ]:
+ register(queue_name, getattr(self, queue_name))
+
+ self.clock.looping_call(self._clear_queue, 30 * 1000)
+
+ def _next_pos(self):
+ pos = self.pos
+ self.pos += 1
+ self.pos_time[self.clock.time_msec()] = pos
+ return pos
+
+ def _clear_queue(self):
+ """Clear the queues for anything older than N minutes"""
+
+ FIVE_MINUTES_AGO = 5 * 60 * 1000
+ now = self.clock.time_msec()
+
+ keys = self.pos_time.keys()
+ time = keys.bisect_left(now - FIVE_MINUTES_AGO)
+ if not keys[:time]:
+ return
+
+ position_to_delete = max(keys[:time])
+ for key in keys[:time]:
+ del self.pos_time[key]
+
+ self._clear_queue_before_pos(position_to_delete)
+
+ def _clear_queue_before_pos(self, position_to_delete):
+ """Clear all the queues from before a given position"""
+ with Measure(self.clock, "send_queue._clear"):
+ # Delete things out of presence maps
+ keys = self.presence_changed.keys()
+ i = keys.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.presence_changed[key]
+
+ user_ids = set(
+ user_id for uids in self.presence_changed.values() for _, user_id in uids
+ )
+
+ to_del = [
+ user_id for user_id in self.presence_map if user_id not in user_ids
+ ]
+ for user_id in to_del:
+ del self.presence_map[user_id]
+
+ # Delete things out of keyed edus
+ keys = self.keyed_edu_changed.keys()
+ i = keys.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.keyed_edu_changed[key]
+
+ live_keys = set()
+ for edu_key in self.keyed_edu_changed.values():
+ live_keys.add(edu_key)
+
+ to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys]
+ for edu_key in to_del:
+ del self.keyed_edu[edu_key]
+
+ # Delete things out of edu map
+ keys = self.edus.keys()
+ i = keys.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.edus[key]
+
+ # Delete things out of failure map
+ keys = self.failures.keys()
+ i = keys.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.failures[key]
+
+ # Delete things out of device map
+ keys = self.device_messages.keys()
+ i = keys.bisect_left(position_to_delete)
+ for key in keys[:i]:
+ del self.device_messages[key]
+
+ def notify_new_events(self, current_id):
+ """As per TransactionQueue"""
+ # We don't need to replicate this as it gets sent down a different
+ # stream.
+ pass
+
+ def send_edu(self, destination, edu_type, content, key=None):
+ """As per TransactionQueue"""
+ pos = self._next_pos()
+
+ edu = Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type=edu_type,
+ content=content,
+ )
+
+ if key:
+ assert isinstance(key, tuple)
+ self.keyed_edu[(destination, key)] = edu
+ self.keyed_edu_changed[pos] = (destination, key)
+ else:
+ self.edus[pos] = edu
+
+ def send_presence(self, destination, states):
+ """As per TransactionQueue"""
+ pos = self._next_pos()
+
+ self.presence_map.update({
+ state.user_id: state
+ for state in states
+ })
+
+ self.presence_changed[pos] = [
+ (destination, state.user_id) for state in states
+ ]
+
+ def send_failure(self, failure, destination):
+ """As per TransactionQueue"""
+ pos = self._next_pos()
+
+ self.failures[pos] = (destination, str(failure))
+
+ def send_device_messages(self, destination):
+ """As per TransactionQueue"""
+ pos = self._next_pos()
+ self.device_messages[pos] = destination
+
+ def get_current_token(self):
+ return self.pos - 1
+
+ def get_replication_rows(self, token, limit, federation_ack=None):
+ """
+ Args:
+ token (int)
+ limit (int)
+            federation_ack (int): Optional. The position the worker has
+                explicitly acknowledged it has handled. Allows us to drop
+                data from before that point.
+ """
+ # TODO: Handle limit.
+
+ # To handle restarts where we wrap around
+ if token > self.pos:
+ token = -1
+
+ rows = []
+
+        # There should be only one reader, so let's delete everything it's
+        # acknowledged it's seen.
+ if federation_ack:
+ self._clear_queue_before_pos(federation_ack)
+
+ # Fetch changed presence
+ keys = self.presence_changed.keys()
+ i = keys.bisect_right(token)
+ dest_user_ids = set(
+ (pos, dest_user_id)
+ for pos in keys[i:]
+ for dest_user_id in self.presence_changed[pos]
+ )
+
+ for (key, (dest, user_id)) in dest_user_ids:
+ rows.append((key, PRESENCE_TYPE, ujson.dumps({
+ "destination": dest,
+ "state": self.presence_map[user_id].as_dict(),
+ })))
+
+        # Fetch changed keyed edus
+ keys = self.keyed_edu_changed.keys()
+ i = keys.bisect_right(token)
+ keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:])
+
+ for (pos, (destination, edu_key)) in keyed_edus:
+ rows.append(
+ (pos, KEYED_EDU_TYPE, ujson.dumps({
+ "key": edu_key,
+ "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(),
+ }))
+ )
+
+ # Fetch changed edus
+ keys = self.edus.keys()
+ i = keys.bisect_right(token)
+ edus = set((k, self.edus[k]) for k in keys[i:])
+
+ for (pos, edu) in edus:
+ rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict())))
+
+ # Fetch changed failures
+ keys = self.failures.keys()
+ i = keys.bisect_right(token)
+ failures = set((k, self.failures[k]) for k in keys[i:])
+
+ for (pos, (destination, failure)) in failures:
+ rows.append((pos, FAILURE_TYPE, ujson.dumps({
+ "destination": destination,
+ "failure": failure,
+ })))
+
+ # Fetch changed device messages
+ keys = self.device_messages.keys()
+ i = keys.bisect_right(token)
+ device_messages = set((k, self.device_messages[k]) for k in keys[i:])
+
+ for (pos, destination) in device_messages:
+ rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({
+ "destination": destination,
+ })))
+
+ # Sort rows based on pos
+ rows.sort()
+
+ return rows
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index f8ca93e4c3..c94c74a67e 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction, Edu
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
@@ -26,6 +27,7 @@ from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
from synapse.util.metrics import measure_func
+from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state
import synapse.metrics
@@ -36,6 +38,12 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
+client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
+sent_pdus_destination_dist = client_metrics.register_distribution(
+ "sent_pdu_destinations"
+)
+sent_edus_counter = client_metrics.register_counter("sent_edus")
+
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
@@ -44,13 +52,14 @@ class TransactionQueue(object):
It batches pending PDUs into single transactions.
"""
- def __init__(self, hs, transport_layer):
+ def __init__(self, hs):
self.server_name = hs.hostname
self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
self.transaction_actions = TransactionActions(self.store)
- self.transport_layer = transport_layer
+ self.transport_layer = hs.get_federation_transport_client()
self.clock = hs.get_clock()
@@ -95,6 +104,11 @@ class TransactionQueue(object):
# HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec())
+ self._order = 1
+
+ self._is_processing = False
+ self._last_poked_id = -1
+
def can_send_to(self, destination):
"""Can we send messages to the given server?
@@ -115,11 +129,61 @@ class TransactionQueue(object):
else:
return not destination.startswith("localhost")
- def enqueue_pdu(self, pdu, destinations, order):
+ @defer.inlineCallbacks
+ def notify_new_events(self, current_id):
+ """This gets called when we have some new events we might want to
+ send out to other servers.
+ """
+ self._last_poked_id = max(current_id, self._last_poked_id)
+
+ if self._is_processing:
+ return
+
+ try:
+ self._is_processing = True
+ while True:
+ last_token = yield self.store.get_federation_out_pos("events")
+ next_token, events = yield self.store.get_all_new_events_stream(
+ last_token, self._last_poked_id, limit=20,
+ )
+
+ logger.debug("Handling %s -> %s", last_token, next_token)
+
+ if not events and next_token >= self._last_poked_id:
+ break
+
+ for event in events:
+ users_in_room = yield self.state.get_current_user_in_room(
+ event.room_id, latest_event_ids=[event.event_id],
+ )
+
+ destinations = set(
+ get_domain_from_id(user_id) for user_id in users_in_room
+ )
+
+ if event.type == EventTypes.Member:
+ if event.content["membership"] == Membership.JOIN:
+ destinations.add(get_domain_from_id(event.state_key))
+
+ logger.debug("Sending %s to %r", event, destinations)
+
+ self._send_pdu(event, destinations)
+
+ yield self.store.update_federation_out_pos(
+ "events", next_token
+ )
+
+ finally:
+ self._is_processing = False
+
+ def _send_pdu(self, pdu, destinations):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
+ order = self._order
+ self._order += 1
+
destinations = set(destinations)
destinations = set(
dest for dest in destinations if self.can_send_to(dest)
@@ -130,6 +194,8 @@ class TransactionQueue(object):
if not destinations:
return
+ sent_pdus_destination_dist.inc_by(len(destinations))
+
for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, order)
@@ -139,7 +205,10 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_presence(self, destination, states):
+ def send_presence(self, destination, states):
+ if not self.can_send_to(destination):
+ return
+
self.pending_presence_by_dest.setdefault(destination, {}).update({
state.user_id: state for state in states
})
@@ -148,12 +217,19 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_edu(self, edu, key=None):
- destination = edu.destination
+ def send_edu(self, destination, edu_type, content, key=None):
+ edu = Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type=edu_type,
+ content=content,
+ )
if not self.can_send_to(destination):
return
+ sent_edus_counter.inc()
+
if key:
self.pending_edus_keyed_by_dest.setdefault(
destination, {}
@@ -165,7 +241,7 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_failure(self, failure, destination):
+ def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
@@ -180,7 +256,7 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_device_messages(self, destination):
+ def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
return
@@ -191,6 +267,9 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
+ def get_current_token(self):
+ return 0
+
@defer.inlineCallbacks
def _attempt_new_transaction(self, destination):
# list of (pending_pdu, deferred, order)
|