| diff --git a/synapse/metrics/metric.py b/synapse/metrics/metric.py
index e87b2b80a7..920cde1dd0 100644
--- a/synapse/metrics/metric.py
+++ b/synapse/metrics/metric.py
@@ -83,6 +83,9 @@ class CounterMetric(BaseMetric):
     def render(self):
         return map_concat(self.render_item, sorted(self.counts.keys()))
 
+    def unregister_counter(self, *values):
+        self.counts.pop(values, None)
+
 
 class CallbackMetric(BaseMetric):
     """A metric that returns the numeric value returned by a callback whenever
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
 index 6864204616..4f44836c2f 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -51,6 +51,7 @@ indicate which side is sending, these are *not* included on the wire::
 
 from twisted.internet import defer
 from twisted.protocols.basic import LineOnlyReceiver
+from twisted.python.failure import Failure
 
 from commands import (
     COMMAND_MAP, VALID_CLIENT_COMMANDS, VALID_SERVER_COMMANDS,
@@ -75,6 +76,9 @@ inbound_commands_counter = metrics.register_counter(
 outbound_commands_counter = metrics.register_counter(
     "outbound_commands", labels=["command", "name", "conn_id"],
 )
+connection_close_counter = metrics.register_counter(
+    "close_reason", labels=["reason_type"],
+)
 
 
 # A list of all connected protocols. This allows us to send metrics about the
@@ -307,6 +311,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
 
     def connectionLost(self, reason):
         logger.info("[%s] Replication connection closed: %r", self.id(), reason)
+        if isinstance(reason, Failure):
+            connection_close_counter.inc(reason.type.__name__)
+        else:
+            connection_close_counter.inc(reason.__class__.__name__)
 
         try:
             # Remove us from list of connections to be monitored
@@ -326,6 +334,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         self.state = ConnectionStates.CLOSED
         self.pending_commands = []
 
+        for cmd in COMMAND_MAP:
+            outbound_commands_counter.unregister_counter(
+                cmd.NAME, self.name, self.conn_id
+            )
+            inbound_commands_counter.unregister_counter(
+                cmd.NAME, self.name, self.conn_id
+            )
+
         if self.transport:
             self.transport.unregisterProducer()
 |