summary refs log tree commit diff
path: root/synapse/replication
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication')
-rw-r--r--synapse/replication/slave/storage/_base.py31
-rw-r--r--synapse/replication/slave/storage/account_data.py49
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py23
-rw-r--r--synapse/replication/slave/storage/devices.py24
-rw-r--r--synapse/replication/slave/storage/events.py57
-rw-r--r--synapse/replication/slave/storage/presence.py19
-rw-r--r--synapse/replication/slave/storage/push_rule.py23
-rw-r--r--synapse/replication/slave/storage/pushers.py16
-rw-r--r--synapse/replication/slave/storage/receipts.py24
-rw-r--r--synapse/replication/slave/storage/room.py11
-rw-r--r--synapse/replication/tcp/client.py196
-rw-r--r--synapse/replication/tcp/protocol.py54
-rw-r--r--synapse/replication/tcp/streams.py104
13 files changed, 422 insertions, 209 deletions
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index ab133db872..b962641166 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -15,7 +15,6 @@
 
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
-from twisted.internet import defer
 
 from ._slaved_id_tracker import SlavedIdTracker
 
@@ -34,8 +33,7 @@ class BaseSlavedStore(SQLBaseStore):
         else:
             self._cache_id_gen = None
 
-        self.expire_cache_url = hs.config.worker_replication_url + "/expire_cache"
-        self.http_client = hs.get_simple_http_client()
+        self.hs = hs
 
     def stream_positions(self):
         pos = {}
@@ -43,35 +41,20 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos
 
-    def process_replication(self, result):
-        stream = result.get("caches")
-        if stream:
-            for row in stream["rows"]:
-                (
-                    position, cache_func, keys, invalidation_ts,
-                ) = row
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "caches":
+            self._cache_id_gen.advance(token)
+            for row in rows:
                 try:
-                    getattr(self, cache_func).invalidate(tuple(keys))
+                    getattr(self, row.cache_func).invalidate(tuple(row.keys))
                 except AttributeError:
                     # We probably haven't pulled in the cache in this worker,
                     # which is fine.
                     pass
-            self._cache_id_gen.advance(int(stream["position"]))
-        return defer.succeed(None)
 
     def _invalidate_cache_and_stream(self, txn, cache_func, keys):
         txn.call_after(cache_func.invalidate, keys)
         txn.call_after(self._send_invalidation_poke, cache_func, keys)
 
-    @defer.inlineCallbacks
     def _send_invalidation_poke(self, cache_func, keys):
-        try:
-            yield self.http_client.post_json_get_json(self.expire_cache_url, {
-                "invalidate": [{
-                    "name": cache_func.__name__,
-                    "keys": list(keys),
-                }]
-            })
-        except:
-            logger.exception("Failed to poke on expire_cache")
+        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 77c64722c7..efbd87918e 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -69,38 +69,25 @@ class SlavedAccountDataStore(BaseSlavedStore):
         result["tag_account_data"] = position
         return result
 
-    def process_replication(self, result):
-        stream = result.get("user_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id, data_type = row[:3]
-                self.get_global_account_data_by_type_for_user.invalidate(
-                    (data_type, user_id,)
-                )
-                self.get_account_data_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "tag_account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        stream = result.get("room_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_account_data_for_user.invalidate((user_id,))
+        elif stream_name == "account_data":
+            self._account_data_id_gen.advance(token)
+            for row in rows:
+                if not row.room_id:
+                    self.get_global_account_data_by_type_for_user.invalidate(
+                        (row.data_type, row.user_id,)
+                    )
+                self.get_account_data_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-        stream = result.get("tag_account_data")
-        if stream:
-            self._account_data_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
-                self.get_tags_for_user.invalidate((user_id,))
-                self._account_data_stream_cache.entity_has_changed(
-                    user_id, position
-                )
-
-        return super(SlavedAccountDataStore, self).process_replication(result)
+        return super(SlavedAccountDataStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index f9102e0d89..6f3fb64770 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -53,21 +53,18 @@ class SlavedDeviceInboxStore(BaseSlavedStore):
         result["to_device"] = self._device_inbox_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("to_device")
-        if stream:
-            self._device_inbox_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                entity = row[1]
-
-                if entity.startswith("@"):
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "to_device":
+            self._device_inbox_id_gen.advance(token)
+            for row in rows:
+                if row.entity.startswith("@"):
                     self._device_inbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
                 else:
                     self._device_federation_outbox_stream_cache.entity_has_changed(
-                        entity, stream_id
+                        row.entity, token
                     )
-
-        return super(SlavedDeviceInboxStore, self).process_replication(result)
+        return super(SlavedDeviceInboxStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index ca46aa17b6..4d4a435471 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -51,22 +51,18 @@ class SlavedDeviceStore(BaseSlavedStore):
         result["device_lists"] = self._device_list_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("device_lists")
-        if stream:
-            self._device_list_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                stream_id = row[0]
-                user_id = row[1]
-                destination = row[2]
-
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "device_lists":
+            self._device_list_id_gen.advance(token)
+            for row in rows:
                 self._device_list_stream_cache.entity_has_changed(
-                    user_id, stream_id
+                    row.user_id, token
                 )
 
-                if destination:
+                if row.destination:
                     self._device_list_federation_stream_cache.entity_has_changed(
-                        destination, stream_id
+                        row.destination, token
                     )
-
-        return super(SlavedDeviceStore, self).process_replication(result)
+        return super(SlavedDeviceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index d4db1e452e..5fd47706ef 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -201,48 +201,25 @@ class SlavedEventStore(BaseSlavedStore):
         result["backfill"] = -self._backfill_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("events")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-
-            if stream["rows"]:
-                logger.info("Got %d event rows", len(stream["rows"]))
-
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=False,
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "events":
+            self._stream_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=False,
                 )
-
-        stream = result.get("backfill")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                self._process_replication_row(
-                    row, backfilled=True,
+        elif stream_name == "backfill":
+            self._backfill_id_gen.advance(-token)
+            for row in rows:
+                self.invalidate_caches_for_event(
+                    -token, row.event_id, row.room_id, row.type, row.state_key,
+                    row.redacts,
+                    backfilled=True,
                 )
-
-        stream = result.get("forward_ex_outliers")
-        if stream:
-            self._stream_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        stream = result.get("backward_ex_outliers")
-        if stream:
-            self._backfill_id_gen.advance(-int(stream["position"]))
-            for row in stream["rows"]:
-                event_id = row[1]
-                self._invalidate_get_event_cache(event_id)
-
-        return super(SlavedEventStore, self).process_replication(result)
-
-    def _process_replication_row(self, row, backfilled):
-        stream_ordering = row[0] if not backfilled else -row[0]
-        self.invalidate_caches_for_event(
-            stream_ordering, row[1], row[2], row[3], row[4], row[5],
-            backfilled=backfilled,
+        return super(SlavedEventStore, self).process_replication_rows(
+            stream_name, token, rows
         )
 
     def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index e4a2414d78..dffc80adc3 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -48,15 +48,14 @@ class SlavedPresenceStore(BaseSlavedStore):
         result["presence"] = position
         return result
 
-    def process_replication(self, result):
-        stream = result.get("presence")
-        if stream:
-            self._presence_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, user_id = row[:2]
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "presence":
+            self._presence_id_gen.advance(token)
+            for row in rows:
                 self.presence_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-                self._get_presence_for_user.invalidate((user_id,))
-
-        return super(SlavedPresenceStore, self).process_replication(result)
+                self._get_presence_for_user.invalidate((row.user_id,))
+        return super(SlavedPresenceStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 21ceb0213a..83e880fdd2 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -50,18 +50,15 @@ class SlavedPushRuleStore(SlavedEventStore):
         result["push_rules"] = self._push_rules_stream_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("push_rules")
-        if stream:
-            for row in stream["rows"]:
-                position = row[0]
-                user_id = row[2]
-                self.get_push_rules_for_user.invalidate((user_id,))
-                self.get_push_rules_enabled_for_user.invalidate((user_id,))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "push_rules":
+            self._push_rules_stream_id_gen.advance(token)
+            for row in rows:
+                self.get_push_rules_for_user.invalidate((row.user_id,))
+                self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(
-                    user_id, position
+                    row.user_id, token
                 )
-
-            self._push_rules_stream_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPushRuleStore, self).process_replication(result)
+        return super(SlavedPushRuleStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index d88206b3bb..4e8d68ece9 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -40,13 +40,9 @@ class SlavedPusherStore(BaseSlavedStore):
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        stream = result.get("deleted_pushers")
-        if stream:
-            self._pushers_id_gen.advance(int(stream["position"]))
-
-        return super(SlavedPusherStore, self).process_replication(result)
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "pushers":
+            self._pushers_id_gen.advance(token)
+        return super(SlavedPusherStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index ac9662d399..b371574ece 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -65,20 +65,22 @@ class SlavedReceiptsStore(BaseSlavedStore):
         result["receipts"] = self._receipts_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("receipts")
-        if stream:
-            self._receipts_id_gen.advance(int(stream["position"]))
-            for row in stream["rows"]:
-                position, room_id, receipt_type, user_id = row[:4]
-                self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
-                self._receipts_stream_cache.entity_has_changed(room_id, position)
-
-        return super(SlavedReceiptsStore, self).process_replication(result)
-
     def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
         self.get_receipts_for_user.invalidate((user_id, receipt_type))
         self.get_linearized_receipts_for_room.invalidate_many((room_id,))
         self.get_last_receipt_event_id_for_user.invalidate(
             (user_id, room_id, receipt_type)
         )
+
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "receipts":
+            self._receipts_id_gen.advance(token)
+            for row in rows:
+                self.invalidate_caches_for_receipt(
+                    row.room_id, row.receipt_type, row.user_id
+                )
+                self._receipts_stream_cache.entity_has_changed(row.room_id, token)
+
+        return super(SlavedReceiptsStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 6df9a25ef3..f510384033 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -46,9 +46,10 @@ class RoomStore(BaseSlavedStore):
         result["public_rooms"] = self._public_room_id_gen.get_current_token()
         return result
 
-    def process_replication(self, result):
-        stream = result.get("public_rooms")
-        if stream:
-            self._public_room_id_gen.advance(int(stream["position"]))
+    def process_replication_rows(self, stream_name, token, rows):
+        if stream_name == "public_rooms":
+            self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication(result)
+        return super(RoomStore, self).process_replication_rows(
+            stream_name, token, rows
+        )
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
new file mode 100644
index 0000000000..90fb6c1336
--- /dev/null
+++ b/synapse/replication/tcp/client.py
@@ -0,0 +1,196 @@
+# -*- coding: utf-8 -*-
+# Copyright 2017 Vector Creations Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""A replication client for use by synapse workers.
+"""
+
+from twisted.internet import reactor, defer
+from twisted.internet.protocol import ReconnectingClientFactory
+
+from .commands import (
+    FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+)
+from .protocol import ClientReplicationStreamProtocol
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class ReplicationClientFactory(ReconnectingClientFactory):
+    """Factory for building connections to the master. Will reconnect if the
+    connection is lost.
+
+    Accepts a handler that will be called when new data is available or data
+    is required.
+    """
+    maxDelay = 5  # Try at least once every N seconds
+
+    def __init__(self, hs, client_name, handler):
+        self.client_name = client_name
+        self.handler = handler
+        self.server_name = hs.config.server_name
+        self._clock = hs.get_clock()  # As self.clock is defined in super class
+
+        reactor.addSystemEventTrigger("before", "shutdown", self.stopTrying)
+
+    def startedConnecting(self, connector):
+        logger.info("Connecting to replication: %r", connector.getDestination())
+
+    def buildProtocol(self, addr):
+        logger.info("Connected to replication: %r", addr)
+        self.resetDelay()
+        return ClientReplicationStreamProtocol(
+            self.client_name, self.server_name, self._clock, self.handler
+        )
+
+    def clientConnectionLost(self, connector, reason):
+        logger.error("Lost replication conn: %r", reason)
+        ReconnectingClientFactory.clientConnectionLost(self, connector, reason)
+
+    def clientConnectionFailed(self, connector, reason):
+        logger.error("Failed to connect to replication: %r", reason)
+        ReconnectingClientFactory.clientConnectionFailed(
+            self, connector, reason
+        )
+
+
+class ReplicationClientHandler(object):
+    """A base handler that can be passed to the ReplicationClientFactory.
+
+    By default proxies incoming replication data to the SlaveStore.
+    """
+    def __init__(self, store):
+        self.store = store
+
+        # The current connection. None if we are currently (re)connecting
+        self.connection = None
+
+        # Any pending commands to be sent once a new connection has been
+        # established
+        self.pending_commands = []
+
+        # Map from string -> deferred, to wake up when receiveing a SYNC with
+        # the given string.
+        # Used for tests.
+        self.awaiting_syncs = {}
+
+    def start_replication(self, hs):
+        """Helper method to start a replication connection to the remote server
+        using TCP.
+        """
+        client_name = hs.config.worker_name
+        factory = ReplicationClientFactory(hs, client_name, self)
+        host = hs.config.worker_replication_host
+        port = hs.config.worker_replication_port
+        reactor.connectTCP(host, port, factory)
+
+    def on_rdata(self, stream_name, token, rows):
+        """Called when we get new replication data. By default this just pokes
+        the slave store.
+
+        Can be overriden in subclasses to handle more.
+        """
+        logger.info("Received rdata %s -> %s", stream_name, token)
+        self.store.process_replication_rows(stream_name, token, rows)
+
+    def on_position(self, stream_name, token):
+        """Called when we get new position data. By default this just pokes
+        the slave store.
+
+        Can be overriden in subclasses to handle more.
+        """
+        self.store.process_replication_rows(stream_name, token, [])
+
+    def on_sync(self, data):
+        """When we received a SYNC we wake up any deferreds that were waiting
+        for the sync with the given data.
+
+        Used by tests.
+        """
+        d = self.awaiting_syncs.pop(data, None)
+        if d:
+            d.callback(data)
+
+    def get_streams_to_replicate(self):
+        """Called when a new connection has been established and we need to
+        subscribe to streams.
+
+        Returns a dictionary of stream name to token.
+        """
+        args = self.store.stream_positions()
+        user_account_data = args.pop("user_account_data", None)
+        room_account_data = args.pop("room_account_data", None)
+        if user_account_data:
+            args["account_data"] = user_account_data
+        elif room_account_data:
+            args["account_data"] = room_account_data
+        return args
+
+    def get_currently_syncing_users(self):
+        """Get the list of currently syncing users (if any). This is called
+        when a connection has been established and we need to send the
+        currently syncing users. (Overriden by the synchrotron's only)
+        """
+        return []
+
+    def send_command(self, cmd):
+        """Send a command to master (when we get establish a connection if we
+        don't have one already.)
+        """
+        if self.connection:
+            self.connection.send_command(cmd)
+        else:
+            logger.warn("Queuing command as not connected: %r", cmd.NAME)
+            self.pending_commands.append(cmd)
+
+    def send_federation_ack(self, token):
+        """Ack data for the federation stream. This allows the master to drop
+        data stored purely in memory.
+        """
+        self.send_command(FederationAckCommand(token))
+
+    def send_user_sync(self, user_id, is_syncing, last_sync_ms):
+        """Poke the master that a user has started/stopped syncing.
+        """
+        self.send_command(UserSyncCommand(user_id, is_syncing, last_sync_ms))
+
+    def send_remove_pusher(self, app_id, push_key, user_id):
+        """Poke the master to remove a pusher for a user
+        """
+        cmd = RemovePusherCommand(app_id, push_key, user_id)
+        self.send_command(cmd)
+
+    def send_invalidate_cache(self, cache_func, keys):
+        """Poke the master to invalidate a cache.
+        """
+        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
+        self.send_command(cmd)
+
+    def await_sync(self, data):
+        """Returns a deferred that is resolved when we receive a SYNC command
+        with given data.
+
+        Used by tests.
+        """
+        return self.awaiting_syncs.setdefault(data, defer.Deferred())
+
+    def update_connection(self, connection):
+        """Called when a connection has been established (or lost with None).
+        """
+        self.connection = connection
+        if connection:
+            for cmd in self.pending_commands:
+                connection.send_command(cmd)
+            self.pending_commands = []
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 6864204616..5770b7125a 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,6 +51,7 @@ indicate which side is sending, these are *not* included on the wire::
 
 from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
+from twisted.python.failure import Failure
 
 from commands import (
     COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
@@ -60,6 +61,7 @@ from commands import (
 from streams import STREAMS_MAP
 
 from synapse.util.stringutils import random_string
+from synapse.metrics.metric import CounterMetric
 
 import logging
 import synapse.metrics
@@ -69,11 +71,8 @@ import fcntl
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
-inbound_commands_counter = metrics.register_counter(
-    "inbound_commands", labels=["command", "name", "conn_id"],
-)
-outbound_commands_counter = metrics.register_counter(
-    "outbound_commands", labels=["command", "name", "conn_id"],
+connection_close_counter = metrics.register_counter(
+    "close_reason", labels=["reason_type"],
 )
 
 
@@ -86,6 +85,8 @@ logger = logging.getLogger(__name__)
 
 
 PING_TIME = 5000
+PING_TIMEOUT_MULTIPLIER = 5
+PING_TIMEOUT_MS = PING_TIME * PING_TIMEOUT_MULTIPLIER
 
 
 class ConnectionStates(object):
@@ -135,6 +136,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         # The LoopingCall for sending pings.
         self._send_ping_loop = None
 
+        self.inbound_commands_counter = CounterMetric(
+            "inbound_commands", labels=["command"],
+        )
+        self.outbound_commands_counter = CounterMetric(
+            "outbound_commands", labels=["command"],
+        )
+
     def connectionMade(self):
         logger.info("[%s] Connection established", self.id())
 
@@ -160,7 +168,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         now = self.clock.time_msec()
 
         if self.time_we_closed:
-            if now - self.time_we_closed > PING_TIME * 3:
+            if now - self.time_we_closed > PING_TIMEOUT_MS:
                 logger.info(
                     "[%s] Failed to close connection gracefully, aborting", self.id()
                 )
@@ -169,7 +177,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             if now - self.last_sent_command >= PING_TIME:
                 self.send_command(PingCommand(now))
 
-            if self.received_ping and now - self.last_received_command > PING_TIME * 3:
+            if self.received_ping and now - self.last_received_command > PING_TIMEOUT_MS:
                 logger.info(
                     "[%s] Connection hasn't received command in %r ms. Closing.",
                     self.id(), now - self.last_received_command
@@ -193,7 +201,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
         self.last_received_command = self.clock.time_msec()
 
-        inbound_commands_counter.inc(cmd_name, self.name, self.conn_id)
+        self.inbound_commands_counter.inc(cmd_name)
 
         cmd_cls = COMMAND_MAP[cmd_name]
         try:
@@ -214,6 +222,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             logger.exception("[%s] Failed to handle line: %r", self.id(), line)
 
     def close(self):
+        logger.warn("[%s] Closing connection", self.id())
         self.time_we_closed = self.clock.time_msec()
         self.transport.loseConnection()
         self.on_connection_closed()
@@ -242,7 +251,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             self._queue_command(cmd)
             return
 
-        outbound_commands_counter.inc(cmd.NAME, self.name, self.conn_id)
+        self.outbound_commands_counter.inc(cmd.NAME)
 
         string = "%s %s" % (cmd.NAME, cmd.to_line(),)
         if "\n" in string:
@@ -307,6 +316,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     def connectionLost(self, reason):
         logger.info("[%s] Replication connection closed: %r", self.id(), reason)
+        if isinstance(reason, Failure):
+            connection_close_counter.inc(reason.type.__name__)
+        else:
+            connection_close_counter.inc(reason.__class__.__name__)
 
         try:
             # Remove us from list of connections to be monitored
@@ -495,7 +508,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
     def on_SERVER(self, cmd):
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
-            self.transport.abortConnection()
+            self.send_error("Wrong remote")
 
     def on_RDATA(self, cmd):
         try:
@@ -604,3 +617,24 @@ metrics.register_callback(
     },
     labels=["name", "conn_id"],
 )
+
+
+metrics.register_callback(
+    "inbound_commands",
+    lambda: {
+        (k[0], p.name, p.conn_id): count
+        for p in connected_connections
+        for k, count in p.inbound_commands_counter.counts.iteritems()
+    },
+    labels=["command", "name", "conn_id"],
+)
+
+metrics.register_callback(
+    "outbound_commands",
+    lambda: {
+        (k[0], p.name, p.conn_id): count
+        for p in connected_connections
+        for k, count in p.outbound_commands_counter.counts.iteritems()
+    },
+    labels=["command", "name", "conn_id"],
+)
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index 4de4ebe84d..369d5f2428 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -36,34 +36,82 @@ logger = logging.getLogger(__name__)
 MAX_EVENTS_BEHIND = 10000
 
 
-EventStreamRow = namedtuple("EventStreamRow",
-                            ("event_id", "room_id", "type", "state_key", "redacts"))
-BackfillStreamRow = namedtuple("BackfillStreamRow",
-                               ("event_id", "room_id", "type", "state_key", "redacts"))
-PresenceStreamRow = namedtuple("PresenceStreamRow",
-                               ("user_id", "state", "last_active_ts",
-                                "last_federation_update_ts", "last_user_sync_ts",
-                                "status_msg", "currently_active"))
-TypingStreamRow = namedtuple("TypingStreamRow",
-                             ("room_id", "user_ids"))
-ReceiptsStreamRow = namedtuple("ReceiptsStreamRow",
-                               ("room_id", "receipt_type", "user_id", "event_id",
-                                "data"))
-PushRulesStreamRow = namedtuple("PushRulesStreamRow", ("user_id",))
-PushersStreamRow = namedtuple("PushersStreamRow",
-                              ("user_id", "app_id", "pushkey", "deleted",))
-CachesStreamRow = namedtuple("CachesStreamRow",
-                             ("cache_func", "keys", "invalidation_ts",))
-PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow",
-                                  ("room_id", "visibility", "appservice_id",
-                                   "network_id",))
-DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", ("user_id", "destination",))
-ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",))
-FederationStreamRow = namedtuple("FederationStreamRow", ("type", "data",))
-TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow",
-                                     ("user_id", "room_id", "data"))
-AccountDataStreamRow = namedtuple("AccountDataStream",
-                                  ("user_id", "room_id", "data_type", "data"))
+EventStreamRow = namedtuple("EventStreamRow", (
+    "event_id",  # str
+    "room_id",  # str
+    "type",  # str
+    "state_key",  # str, optional
+    "redacts",  # str, optional
+))
+BackfillStreamRow = namedtuple("BackfillStreamRow", (
+    "event_id",  # str
+    "room_id",  # str
+    "type",  # str
+    "state_key",  # str, optional
+    "redacts",  # str, optional
+))
+PresenceStreamRow = namedtuple("PresenceStreamRow", (
+    "user_id",  # str
+    "state",  # str
+    "last_active_ts",  # int
+    "last_federation_update_ts",  # int
+    "last_user_sync_ts",  # int
+    "status_msg",   # str
+    "currently_active",  # bool
+))
+TypingStreamRow = namedtuple("TypingStreamRow", (
+    "room_id",  # str
+    "user_ids",  # list(str)
+))
+ReceiptsStreamRow = namedtuple("ReceiptsStreamRow", (
+    "room_id",  # str
+    "receipt_type",  # str
+    "user_id",  # str
+    "event_id",  # str
+    "data",  # dict
+))
+PushRulesStreamRow = namedtuple("PushRulesStreamRow", (
+    "user_id",  # str
+))
+PushersStreamRow = namedtuple("PushersStreamRow", (
+    "user_id",  # str
+    "app_id",  # str
+    "pushkey",  # str
+    "deleted",  # bool
+))
+CachesStreamRow = namedtuple("CachesStreamRow", (
+    "cache_func",  # str
+    "keys",  # list(str)
+    "invalidation_ts",  # int
+))
+PublicRoomsStreamRow = namedtuple("PublicRoomsStreamRow", (
+    "room_id",  # str
+    "visibility",  # str
+    "appservice_id",  # str, optional
+    "network_id",  # str, optional
+))
+DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
+    "user_id",  # str
+    "destination",  # str
+))
+ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
+    "entity",  # str
+))
+FederationStreamRow = namedtuple("FederationStreamRow", (
+    "type",  # str, the type of data as defined in the BaseFederationRows
+    "data",  # dict, serialization of a federation.send_queue.BaseFederationRow
+))
+TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
+    "user_id",  # str
+    "room_id",  # str
+    "data",  # dict
+))
+AccountDataStreamRow = namedtuple("AccountDataStream", (
+    "user_id",  # str
+    "room_id",  # str
+    "data_type",  # str
+    "data",  # dict
+))
 
 
 class Stream(object):