summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py7
-rw-r--r--synapse/replication/tcp/commands.py32
-rw-r--r--synapse/replication/tcp/protocol.py6
-rw-r--r--synapse/replication/tcp/resource.py13
-rw-r--r--synapse/replication/tcp/streams.py22
5 files changed, 79 insertions, 1 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 90fb6c1336..6d2513c4e2 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -20,6 +20,7 @@ from twisted.internet.protocol import ReconnectingClientFactory
 
 from .commands import (
     FederationAckCommand, UserSyncCommand, RemovePusherCommand, InvalidateCacheCommand,
+    UserIpCommand,
 )
 from .protocol import ClientReplicationStreamProtocol
 
@@ -178,6 +179,12 @@ class ReplicationClientHandler(object):
         cmd = InvalidateCacheCommand(cache_func.__name__, keys)
         self.send_command(cmd)
 
+    def send_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+        """Tell the master that the user made a request.
+        """
+        cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
+        self.send_command(cmd)
+
     def await_sync(self, data):
         """Returns a deferred that is resolved when we receive a SYNC command
         with given data.
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index 84d2a2272a..a009214e43 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -304,6 +304,36 @@ class InvalidateCacheCommand(Command):
         return " ".join((self.cache_func, json.dumps(self.keys)))
 
 
+class UserIpCommand(Command):
+    """Sent periodically when a worker sees activity from a client.
+
+    Format::
+
+        USER_IP <user_id>, <access_token>, <ip>, <device_id>, <last_seen>, <user_agent>
+    """
+    NAME = "USER_IP"
+
+    def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+        self.user_id = user_id
+        self.access_token = access_token
+        self.ip = ip
+        self.user_agent = user_agent
+        self.device_id = device_id
+        self.last_seen = last_seen
+
+    @classmethod
+    def from_line(cls, line):
+        user_id, access_token, ip, device_id, last_seen, user_agent = line.split(" ", 5)
+
+        return cls(user_id, access_token, ip, user_agent, device_id, int(last_seen))
+
+    def to_line(self):
+        return " ".join((
+            self.user_id, self.access_token, self.ip, self.device_id,
+            str(self.last_seen), self.user_agent,
+        ))
+
+
 # Map of command name to command type.
 COMMAND_MAP = {
     cmd.NAME: cmd
@@ -320,6 +350,7 @@ COMMAND_MAP = {
         SyncCommand,
         RemovePusherCommand,
         InvalidateCacheCommand,
+        UserIpCommand,
     )
 }
 
@@ -342,5 +373,6 @@ VALID_CLIENT_COMMANDS = (
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
     InvalidateCacheCommand.NAME,
+    UserIpCommand.NAME,
     ErrorCommand.NAME,
 )
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index 9fee2a484b..062272f8dd 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -406,6 +406,12 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
     def on_INVALIDATE_CACHE(self, cmd):
         self.streamer.on_invalidate_cache(cmd.cache_func, cmd.keys)
 
+    def on_USER_IP(self, cmd):
+        self.streamer.on_user_ip(
+            cmd.user_id, cmd.access_token, cmd.ip, cmd.user_agent, cmd.device_id,
+            cmd.last_seen,
+        )
+
     @defer.inlineCallbacks
     def subscribe_to_stream(self, stream_name, token):
         """Subscribe the remote to a streams.
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index 8b2c4c3043..3ea3ca5a6f 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -35,6 +35,7 @@ user_sync_counter = metrics.register_counter("user_sync")
 federation_ack_counter = metrics.register_counter("federation_ack")
 remove_pusher_counter = metrics.register_counter("remove_pusher")
 invalidate_cache_counter = metrics.register_counter("invalidate_cache")
+user_ip_cache_counter = metrics.register_counter("user_ip_cache")
 
 logger = logging.getLogger(__name__)
 
@@ -67,6 +68,7 @@ class ReplicationStreamer(object):
         self.store = hs.get_datastore()
         self.presence_handler = hs.get_presence_handler()
         self.clock = hs.get_clock()
+        self.notifier = hs.get_notifier()
 
         # Current connections.
         self.connections = []
@@ -99,7 +101,7 @@ class ReplicationStreamer(object):
         if not hs.config.send_federation:
             self.federation_sender = hs.get_federation_sender()
 
-        hs.get_notifier().add_replication_callback(self.on_notifier_poke)
+        self.notifier.add_replication_callback(self.on_notifier_poke)
 
         # Keeps track of whether we are currently checking for updates
         self.is_looping = False
@@ -237,6 +239,15 @@ class ReplicationStreamer(object):
         invalidate_cache_counter.inc()
         getattr(self.store, cache_func).invalidate(tuple(keys))
 
+    @measure_func("repl.on_user_ip")
+    def on_user_ip(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+        """The client saw a user request
+        """
+        user_ip_cache_counter.inc()
+        self.store.insert_client_ip(
+            user_id, access_token, ip, user_agent, device_id, last_seen,
+        )
+
     def send_sync_to_all_connections(self, data):
         """Sends a SYNC command to all clients.
 
diff --git a/synapse/replication/tcp/streams.py b/synapse/replication/tcp/streams.py
index 369d5f2428..fbafe12cc2 100644
--- a/synapse/replication/tcp/streams.py
+++ b/synapse/replication/tcp/streams.py
@@ -112,6 +112,12 @@ AccountDataStreamRow = namedtuple("AccountDataStream", (
     "data_type",  # str
     "data",  # dict
 ))
+CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
+    "room_id",  # str
+    "type",  # str
+    "state_key",  # str
+    "event_id",  # str, optional
+))
 
 
 class Stream(object):
@@ -443,6 +449,21 @@ class AccountDataStream(Stream):
         defer.returnValue(results)
 
 
+class CurrentStateDeltaStream(Stream):
+    """Current state for a room was changed
+    """
+    NAME = "current_state_deltas"
+    ROW_TYPE = CurrentStateDeltaStreamRow
+
+    def __init__(self, hs):
+        store = hs.get_datastore()
+
+        self.current_token = store.get_max_current_state_delta_stream_id
+        self.update_function = store.get_all_updated_current_state_deltas
+
+        super(CurrentStateDeltaStream, self).__init__(hs)
+
+
 STREAMS_MAP = {
     stream.NAME: stream
     for stream in (
@@ -460,5 +481,6 @@ STREAMS_MAP = {
         FederationStream,
         TagAccountDataStream,
         AccountDataStream,
+        CurrentStateDeltaStream,
     )
 }