summary refs log tree commit diff
path: root/synapse/replication/tcp
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-05-07 13:51:08 +0100
committerGitHub <noreply@github.com>2020-05-07 13:51:08 +0100
commitd7983b63a6746d92225295f1e9d521f847cf8ba7 (patch)
tree13d581210d94c26bd75036592996f6f53f7d4bb2 /synapse/replication/tcp
parentMerge pull request #7398 from Starbix/alpine-3.11 (diff)
downloadsynapse-d7983b63a6746d92225295f1e9d521f847cf8ba7.tar.xz
Support any process writing to cache invalidation stream. (#7436)
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r--synapse/replication/tcp/client.py6
-rw-r--r--synapse/replication/tcp/commands.py33
-rw-r--r--synapse/replication/tcp/handler.py42
-rw-r--r--synapse/replication/tcp/resource.py22
-rw-r--r--synapse/replication/tcp/streams/_base.py87
-rw-r--r--synapse/replication/tcp/streams/events.py4
-rw-r--r--synapse/replication/tcp/streams/federation.py12
7 files changed, 100 insertions, 106 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 3bbf3c3569..20cb8a654f 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -100,10 +100,10 @@ class ReplicationDataHandler:
             token: stream token for this batch of rows
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
-        self.store.process_replication_rows(stream_name, token, rows)
+        self.store.process_replication_rows(stream_name, instance_name, token, rows)
 
-    async def on_position(self, stream_name: str, token: int):
-        self.store.process_replication_rows(stream_name, token, [])
+    async def on_position(self, stream_name: str, instance_name: str, token: int):
+        self.store.process_replication_rows(stream_name, instance_name, token, [])
 
     def on_remote_server_up(self, server: str):
         """Called when get a new REMOTE_SERVER_UP command."""
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index f58e384d17..c04f622816 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -341,37 +341,6 @@ class RemovePusherCommand(Command):
         return " ".join((self.app_id, self.push_key, self.user_id))
 
 
-class InvalidateCacheCommand(Command):
-    """Sent by the client to invalidate an upstream cache.
-
-    THIS IS NOT RELIABLE, AND SHOULD *NOT* BE USED ACCEPT FOR THINGS THAT ARE
-    NOT DISASTROUS IF WE DROP ON THE FLOOR.
-
-    Mainly used to invalidate destination retry timing caches.
-
-    Format::
-
-        INVALIDATE_CACHE <cache_func> <keys_json>
-
-    Where <keys_json> is a json list.
-    """
-
-    NAME = "INVALIDATE_CACHE"
-
-    def __init__(self, cache_func, keys):
-        self.cache_func = cache_func
-        self.keys = keys
-
-    @classmethod
-    def from_line(cls, line):
-        cache_func, keys_json = line.split(" ", 1)
-
-        return cls(cache_func, json.loads(keys_json))
-
-    def to_line(self):
-        return " ".join((self.cache_func, _json_encoder.encode(self.keys)))
-
-
 class UserIpCommand(Command):
     """Sent periodically when a worker sees activity from a client.
 
@@ -439,7 +408,6 @@ _COMMANDS = (
     UserSyncCommand,
     FederationAckCommand,
     RemovePusherCommand,
-    InvalidateCacheCommand,
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
@@ -467,7 +435,6 @@ VALID_CLIENT_COMMANDS = (
     ClearUserSyncsCommand.NAME,
     FederationAckCommand.NAME,
     RemovePusherCommand.NAME,
-    InvalidateCacheCommand.NAME,
     UserIpCommand.NAME,
     ErrorCommand.NAME,
     RemoteServerUpCommand.NAME,
diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py
index b14a3d9fca..7c5d6c76e7 100644
--- a/synapse/replication/tcp/handler.py
+++ b/synapse/replication/tcp/handler.py
@@ -15,18 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    Iterator,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
 
 from prometheus_client import Counter
 
@@ -38,7 +27,6 @@ from synapse.replication.tcp.commands import (
     ClearUserSyncsCommand,
     Command,
     FederationAckCommand,
-    InvalidateCacheCommand,
     PositionCommand,
     RdataCommand,
     RemoteServerUpCommand,
@@ -171,7 +159,7 @@ class ReplicationCommandHandler:
             return
 
         for stream_name, stream in self._streams.items():
-            current_token = stream.current_token()
+            current_token = stream.current_token(self._instance_name)
             self.send_command(
                 PositionCommand(stream_name, self._instance_name, current_token)
             )
@@ -210,18 +198,6 @@ class ReplicationCommandHandler:
 
             self._notifier.on_new_replication_data()
 
-    async def on_INVALIDATE_CACHE(
-        self, conn: AbstractConnection, cmd: InvalidateCacheCommand
-    ):
-        invalidate_cache_counter.inc()
-
-        if self._is_master:
-            # We invalidate the cache locally, but then also stream that to other
-            # workers.
-            await self._store.invalidate_cache_and_stream(
-                cmd.cache_func, tuple(cmd.keys)
-            )
-
     async def on_USER_IP(self, conn: AbstractConnection, cmd: UserIpCommand):
         user_ip_cache_counter.inc()
 
@@ -295,7 +271,7 @@ class ReplicationCommandHandler:
             rows: a list of Stream.ROW_TYPE objects as returned by
                 Stream.parse_row.
         """
-        logger.debug("Received rdata %s -> %s", stream_name, token)
+        logger.debug("Received rdata %s (%s) -> %s", stream_name, instance_name, token)
         await self._replication_data_handler.on_rdata(
             stream_name, instance_name, token, rows
         )
@@ -326,7 +302,7 @@ class ReplicationCommandHandler:
             self._pending_batches.pop(stream_name, [])
 
             # Find where we previously streamed up to.
-            current_token = stream.current_token()
+            current_token = stream.current_token(cmd.instance_name)
 
             # If the position token matches our current token then we're up to
             # date and there's nothing to do. Otherwise, fetch all updates
@@ -363,7 +339,9 @@ class ReplicationCommandHandler:
             logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
 
             # We've now caught up to position sent to us, notify handler.
-            await self._replication_data_handler.on_position(stream_name, cmd.token)
+            await self._replication_data_handler.on_position(
+                cmd.stream_name, cmd.instance_name, cmd.token
+            )
 
             self._streams_by_connection.setdefault(conn, set()).add(stream_name)
 
@@ -491,12 +469,6 @@ class ReplicationCommandHandler:
         cmd = RemovePusherCommand(app_id, push_key, user_id)
         self.send_command(cmd)
 
-    def send_invalidate_cache(self, cache_func: Callable, keys: tuple):
-        """Poke the master to invalidate a cache.
-        """
-        cmd = InvalidateCacheCommand(cache_func.__name__, keys)
-        self.send_command(cmd)
-
     def send_user_ip(
         self,
         user_id: str,
diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py
index b690abedad..002171ce7c 100644
--- a/synapse/replication/tcp/resource.py
+++ b/synapse/replication/tcp/resource.py
@@ -25,7 +25,12 @@ from twisted.internet.protocol import Factory
 
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
-from synapse.replication.tcp.streams import STREAMS_MAP, FederationStream, Stream
+from synapse.replication.tcp.streams import (
+    STREAMS_MAP,
+    CachesStream,
+    FederationStream,
+    Stream,
+)
 from synapse.util.metrics import Measure
 
 stream_updates_counter = Counter(
@@ -71,11 +76,16 @@ class ReplicationStreamer(object):
         self.store = hs.get_datastore()
         self.clock = hs.get_clock()
         self.notifier = hs.get_notifier()
+        self._instance_name = hs.get_instance_name()
 
         self._replication_torture_level = hs.config.replication_torture_level
 
         # Work out list of streams that this instance is the source of.
         self.streams = []  # type: List[Stream]
+
+        # All workers can write to the cache invalidation stream.
+        self.streams.append(CachesStream(hs))
+
         if hs.config.worker_app is None:
             for stream in STREAMS_MAP.values():
                 if stream == FederationStream and hs.config.send_federation:
@@ -83,6 +93,10 @@ class ReplicationStreamer(object):
                     # has been disabled on the master.
                     continue
 
+                if stream == CachesStream:
+                    # We've already added it above.
+                    continue
+
                 self.streams.append(stream(hs))
 
         self.streams_by_name = {stream.NAME: stream for stream in self.streams}
@@ -145,7 +159,9 @@ class ReplicationStreamer(object):
                         random.shuffle(all_streams)
 
                     for stream in all_streams:
-                        if stream.last_token == stream.current_token():
+                        if stream.last_token == stream.current_token(
+                            self._instance_name
+                        ):
                             continue
 
                         if self._replication_torture_level:
@@ -157,7 +173,7 @@ class ReplicationStreamer(object):
                             "Getting stream: %s: %s -> %s",
                             stream.NAME,
                             stream.last_token,
-                            stream.current_token(),
+                            stream.current_token(self._instance_name),
                         )
                         try:
                             updates, current_token, limited = await stream.get_updates()
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 084604e8b0..b48a6a3e91 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -95,20 +95,25 @@ class Stream(object):
     def __init__(
         self,
         local_instance_name: str,
-        current_token_function: Callable[[], Token],
+        current_token_function: Callable[[str], Token],
         update_function: UpdateFunction,
     ):
         """Instantiate a Stream
 
-        current_token_function and update_function are callbacks which should be
-        implemented by subclasses.
+        `current_token_function` and `update_function` are callbacks which
+        should be implemented by subclasses.
 
-        current_token_function is called to get the current token of the underlying
-        stream. It is only meaningful on the process that is the source of the
-        replication stream (ie, usually the master).
+        `current_token_function` takes an instance name, which is a writer to
+        the stream, and returns the position in the stream of the writer (as
+        viewed from the current process). On the writer process this is where
+        the writer has successfully written up to, whereas on other processes
+        this is the position which we have received updates up to over
+        replication. (Note that most streams have a single writer and so their
+        implementations ignore the instance name passed in).
 
-        update_function is called to get updates for this stream between a pair of
-        stream tokens. See the UpdateFunction type definition for more info.
+        `update_function` is called to get updates for this stream between a
+        pair of stream tokens. See the `UpdateFunction` type definition for more
+        info.
 
         Args:
             local_instance_name: The instance name of the current process
@@ -120,13 +125,13 @@ class Stream(object):
         self.update_function = update_function
 
         # The token from which we last asked for updates
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)
 
     def discard_updates_and_advance(self):
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
-        self.last_token = self.current_token()
+        self.last_token = self.current_token(self.local_instance_name)
 
     async def get_updates(self) -> StreamUpdateResult:
         """Gets all updates since the last time this function was called (or
@@ -138,7 +143,7 @@ class Stream(object):
             position in stream, and `limited` is whether there are more updates
             to fetch.
         """
-        current_token = self.current_token()
+        current_token = self.current_token(self.local_instance_name)
         updates, current_token, limited = await self.get_updates_since(
             self.local_instance_name, self.last_token, current_token
         )
@@ -170,6 +175,16 @@ class Stream(object):
         return updates, upto_token, limited
 
 
+def current_token_without_instance(
+    current_token: Callable[[], int]
+) -> Callable[[str], int]:
+    """Takes a current token callback function for a single writer stream
+    that doesn't take an instance name parameter and wraps it in a function that
+    does accept an instance name parameter but ignores it.
+    """
+    return lambda instance_name: current_token()
+
+
 def db_query_to_update_function(
     query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]]
 ) -> UpdateFunction:
@@ -235,7 +250,7 @@ class BackfillStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_backfill_token,
+            current_token_without_instance(store.get_current_backfill_token),
             db_query_to_update_function(store.get_all_new_backfill_event_rows),
         )
 
@@ -271,7 +286,9 @@ class PresenceStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(), store.get_current_presence_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(store.get_current_presence_token),
+            update_function,
         )
 
 
@@ -296,7 +313,9 @@ class TypingStream(Stream):
             update_function = make_http_update_function(hs, self.NAME)
 
         super().__init__(
-            hs.get_instance_name(), typing_handler.get_current_token, update_function
+            hs.get_instance_name(),
+            current_token_without_instance(typing_handler.get_current_token),
+            update_function,
         )
 
 
@@ -319,7 +338,7 @@ class ReceiptsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_receipt_stream_id,
+            current_token_without_instance(store.get_max_receipt_stream_id),
             db_query_to_update_function(store.get_all_updated_receipts),
         )
 
@@ -339,7 +358,7 @@ class PushRulesStream(Stream):
             hs.get_instance_name(), self._current_token, self._update_function
         )
 
-    def _current_token(self) -> int:
+    def _current_token(self, instance_name: str) -> int:
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
@@ -373,7 +392,7 @@ class PushersStream(Stream):
 
         super().__init__(
             hs.get_instance_name(),
-            store.get_pushers_stream_token,
+            current_token_without_instance(store.get_pushers_stream_token),
             db_query_to_update_function(store.get_all_updated_pushers_rows),
         )
 
@@ -402,12 +421,26 @@ class CachesStream(Stream):
     ROW_TYPE = CachesStreamRow
 
     def __init__(self, hs):
-        store = hs.get_datastore()
+        self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_cache_stream_token,
-            db_query_to_update_function(store.get_all_updated_caches),
+            self.store.get_cache_stream_token,
+            self._update_function,
+        )
+
+    async def _update_function(
+        self, instance_name: str, from_token: int, upto_token: int, limit: int
+    ):
+        rows = await self.store.get_all_updated_caches(
+            instance_name, from_token, upto_token, limit
         )
+        updates = [(row[0], row[1:]) for row in rows]
+        limited = False
+        if len(updates) >= limit:
+            upto_token = updates[-1][0]
+            limited = True
+
+        return updates, upto_token, limited
 
 
 class PublicRoomsStream(Stream):
@@ -431,7 +464,7 @@ class PublicRoomsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_current_public_room_stream_id,
+            current_token_without_instance(store.get_current_public_room_stream_id),
             db_query_to_update_function(store.get_all_new_public_rooms),
         )
 
@@ -452,7 +485,7 @@ class DeviceListsStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(store.get_all_device_list_changes_for_remotes),
         )
 
@@ -470,7 +503,7 @@ class ToDeviceStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_to_device_stream_token,
+            current_token_without_instance(store.get_to_device_stream_token),
             db_query_to_update_function(store.get_all_new_device_messages),
         )
 
@@ -490,7 +523,7 @@ class TagAccountDataStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_max_account_data_stream_id,
+            current_token_without_instance(store.get_max_account_data_stream_id),
             db_query_to_update_function(store.get_all_updated_tags),
         )
 
@@ -510,7 +543,7 @@ class AccountDataStream(Stream):
         self.store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self.store.get_max_account_data_stream_id,
+            current_token_without_instance(self.store.get_max_account_data_stream_id),
             db_query_to_update_function(self._update_function),
         )
 
@@ -541,7 +574,7 @@ class GroupServerStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_group_stream_token,
+            current_token_without_instance(store.get_group_stream_token),
             db_query_to_update_function(store.get_all_groups_changes),
         )
 
@@ -559,7 +592,7 @@ class UserSignatureStream(Stream):
         store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            store.get_device_stream_token,
+            current_token_without_instance(store.get_device_stream_token),
             db_query_to_update_function(
                 store.get_all_user_signature_changes_for_remotes
             ),
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index 890e75d827..f370390331 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -20,7 +20,7 @@ from typing import List, Tuple, Type
 
 import attr
 
-from ._base import Stream, StreamUpdateResult, Token
+from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
 
 
 """Handling of the 'events' replication stream
@@ -119,7 +119,7 @@ class EventsStream(Stream):
         self._store = hs.get_datastore()
         super().__init__(
             hs.get_instance_name(),
-            self._store.get_current_events_token,
+            current_token_without_instance(self._store.get_current_events_token),
             self._update_function,
         )
 
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index b0505b8a2c..9bcd13b009 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -15,7 +15,11 @@
 # limitations under the License.
 from collections import namedtuple
 
-from synapse.replication.tcp.streams._base import Stream, make_http_update_function
+from synapse.replication.tcp.streams._base import (
+    Stream,
+    current_token_without_instance,
+    make_http_update_function,
+)
 
 
 class FederationStream(Stream):
@@ -41,7 +45,9 @@ class FederationStream(Stream):
             # will be a real FederationSender, which has stubs for current_token and
             # get_replication_rows.)
             federation_sender = hs.get_federation_sender()
-            current_token = federation_sender.get_current_token
+            current_token = current_token_without_instance(
+                federation_sender.get_current_token
+            )
             update_function = federation_sender.get_replication_rows
 
         elif hs.should_send_federation():
@@ -58,7 +64,7 @@ class FederationStream(Stream):
         super().__init__(hs.get_instance_name(), current_token, update_function)
 
     @staticmethod
-    def _stub_current_token():
+    def _stub_current_token(instance_name: str) -> int:
         # dummy current-token method for use on workers
         return 0