summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-02-27 13:07:39 +0000
committerErik Johnston <erik@matrix.org>2020-02-27 13:23:14 +0000
commitbf7e5d710d446204446161598f0c9e12e4a1f1e1 (patch)
tree077bf24281b75d72a28e4117a5f1fc6048b9c89f
parentStore room version on invite (#6983) (diff)
downloadsynapse-erikj/worker_can_read_streams.tar.xz
ove stream fetch DB queries to worker stores. github/erikj/worker_can_read_streams erikj/worker_can_read_streams
-rw-r--r--synapse/replication/slave/storage/_base.py14
-rw-r--r--synapse/replication/slave/storage/events.py8
-rw-r--r--synapse/replication/slave/storage/pushers.py3
-rw-r--r--synapse/storage/data_stores/main/cache.py44
-rw-r--r--synapse/storage/data_stores/main/deviceinbox.py88
-rw-r--r--synapse/storage/data_stores/main/events.py90
-rw-r--r--synapse/storage/data_stores/main/events_worker.py90
-rw-r--r--synapse/storage/data_stores/main/room.py40
8 files changed, 199 insertions, 178 deletions
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index f45cbd37a0..751c799d94 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,8 +18,10 @@ from typing import Dict, Optional
 
 import six
 
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.data_stores.main.cache import CURRENT_STATE_CACHE_NAME
+from synapse.storage.data_stores.main.cache import (
+    CURRENT_STATE_CACHE_NAME,
+    CacheInvalidationWorkerStore,
+)
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
 
@@ -35,7 +37,7 @@ def __func__(inp):
         return inp.__func__
 
 
-class BaseSlavedStore(SQLBaseStore):
+class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
@@ -60,6 +62,12 @@ class BaseSlavedStore(SQLBaseStore):
             pos["caches"] = self._cache_id_gen.get_current_token()
         return pos
 
+    def get_cache_stream_token(self):
+        if self._cache_id_gen:
+            return self._cache_id_gen.get_current_token()
+        else:
+            return 0
+
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "caches":
             if self._cache_id_gen:
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index e73342c657..75e5fffdb8 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -93,6 +93,14 @@ class SlavedEventStore(
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
 
+    def get_current_backfill_token(self):
+        """The current minimum token that backfilled events have reached"""
+        return -self._backfill_id_gen.get_current_token()
+
+    def get_current_events_token(self):
+        """The current maximum token that events have reached"""
+        return self._stream_id_gen.get_current_token()
+
     def stream_positions(self):
         result = super(SlavedEventStore, self).stream_positions()
         result["events"] = self._stream_id_gen.get_current_token()
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index f22c2d44a3..bce8a3d115 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -33,6 +33,9 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         result["pushers"] = self._pushers_id_gen.get_current_token()
         return result
 
+    def get_pushers_stream_token(self):
+        return self._pushers_id_gen.get_current_token()
+
     def process_replication_rows(self, stream_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
diff --git a/synapse/storage/data_stores/main/cache.py b/synapse/storage/data_stores/main/cache.py
index d4c44dcc75..4dc5da3fe8 100644
--- a/synapse/storage/data_stores/main/cache.py
+++ b/synapse/storage/data_stores/main/cache.py
@@ -32,7 +32,29 @@ logger = logging.getLogger(__name__)
 CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
 
 
-class CacheInvalidationStore(SQLBaseStore):
+class CacheInvalidationWorkerStore(SQLBaseStore):
+    def get_all_updated_caches(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_updated_caches_txn(txn):
+            # We purposefully don't bound by the current token, as we want to
+            # send across cache invalidations as quickly as possible. Cache
+            # invalidations are idempotent, so duplicates are fine.
+            sql = (
+                "SELECT stream_id, cache_func, keys, invalidation_ts"
+                " FROM cache_invalidation_stream"
+                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_id, limit))
+            return txn.fetchall()
+
+        return self.db.runInteraction(
+            "get_all_updated_caches", get_all_updated_caches_txn
+        )
+
+
+class CacheInvalidationStore(CacheInvalidationWorkerStore):
     async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
         """Invalidates the cache and adds it to the cache stream so slaves
         will know to invalidate their caches.
@@ -145,26 +167,6 @@ class CacheInvalidationStore(SQLBaseStore):
                 },
             )
 
-    def get_all_updated_caches(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_updated_caches_txn(txn):
-            # We purposefully don't bound by the current token, as we want to
-            # send across cache invalidations as quickly as possible. Cache
-            # invalidations are idempotent, so duplicates are fine.
-            sql = (
-                "SELECT stream_id, cache_func, keys, invalidation_ts"
-                " FROM cache_invalidation_stream"
-                " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
-            )
-            txn.execute(sql, (last_id, limit))
-            return txn.fetchall()
-
-        return self.db.runInteraction(
-            "get_all_updated_caches", get_all_updated_caches_txn
-        )
-
     def get_cache_stream_token(self):
         if self._cache_id_gen:
             return self._cache_id_gen.get_current_token()
diff --git a/synapse/storage/data_stores/main/deviceinbox.py b/synapse/storage/data_stores/main/deviceinbox.py
index 0613b49f4a..9a1178fb39 100644
--- a/synapse/storage/data_stores/main/deviceinbox.py
+++ b/synapse/storage/data_stores/main/deviceinbox.py
@@ -207,6 +207,50 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             "delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
         )
 
+    def get_all_new_device_messages(self, last_pos, current_pos, limit):
+        """
+        Args:
+            last_pos(int):
+            current_pos(int):
+            limit(int):
+        Returns:
+            A deferred list of rows from the device inbox
+        """
+        if last_pos == current_pos:
+            return defer.succeed([])
+
+        def get_all_new_device_messages_txn(txn):
+            # We limit like this as we might have multiple rows per stream_id, and
+            # we want to make sure we always get all entries for any stream_id
+            # we return.
+            upper_pos = min(current_pos, last_pos + limit)
+            sql = (
+                "SELECT max(stream_id), user_id"
+                " FROM device_inbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY user_id"
+            )
+            txn.execute(sql, (last_pos, upper_pos))
+            rows = txn.fetchall()
+
+            sql = (
+                "SELECT max(stream_id), destination"
+                " FROM device_federation_outbox"
+                " WHERE ? < stream_id AND stream_id <= ?"
+                " GROUP BY destination"
+            )
+            txn.execute(sql, (last_pos, upper_pos))
+            rows.extend(txn)
+
+            # Order by ascending stream ordering
+            rows.sort()
+
+            return rows
+
+        return self.db.runInteraction(
+            "get_all_new_device_messages", get_all_new_device_messages_txn
+        )
+
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
@@ -411,47 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
                 rows.append((user_id, device_id, stream_id, message_json))
 
         txn.executemany(sql, rows)
-
-    def get_all_new_device_messages(self, last_pos, current_pos, limit):
-        """
-        Args:
-            last_pos(int):
-            current_pos(int):
-            limit(int):
-        Returns:
-            A deferred list of rows from the device inbox
-        """
-        if last_pos == current_pos:
-            return defer.succeed([])
-
-        def get_all_new_device_messages_txn(txn):
-            # We limit like this as we might have multiple rows per stream_id, and
-            # we want to make sure we always get all entries for any stream_id
-            # we return.
-            upper_pos = min(current_pos, last_pos + limit)
-            sql = (
-                "SELECT max(stream_id), user_id"
-                " FROM device_inbox"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY user_id"
-            )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows = txn.fetchall()
-
-            sql = (
-                "SELECT max(stream_id), destination"
-                " FROM device_federation_outbox"
-                " WHERE ? < stream_id AND stream_id <= ?"
-                " GROUP BY destination"
-            )
-            txn.execute(sql, (last_pos, upper_pos))
-            rows.extend(txn)
-
-            # Order by ascending stream ordering
-            rows.sort()
-
-            return rows
-
-        return self.db.runInteraction(
-            "get_all_new_device_messages", get_all_new_device_messages_txn
-        )
diff --git a/synapse/storage/data_stores/main/events.py b/synapse/storage/data_stores/main/events.py
index 8ae23df00a..f0ebe9ebe9 100644
--- a/synapse/storage/data_stores/main/events.py
+++ b/synapse/storage/data_stores/main/events.py
@@ -1271,96 +1271,6 @@ class EventsStore(
         """The current maximum token that events have reached"""
         return self._stream_id_gen.get_current_token()
 
-    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_new_forward_event_rows(txn):
-            sql = (
-                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? < stream_ordering AND stream_ordering <= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (last_id, current_id, limit))
-            new_event_updates = txn.fetchall()
-
-            if len(new_event_updates) == limit:
-                upper_bound = new_event_updates[-1][0]
-            else:
-                upper_bound = current_id
-
-            sql = (
-                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? < event_stream_ordering"
-                " AND event_stream_ordering <= ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
-            txn.execute(sql, (last_id, upper_bound))
-            new_event_updates.extend(txn)
-
-            return new_event_updates
-
-        return self.db.runInteraction(
-            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
-        )
-
-    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
-        if last_id == current_id:
-            return defer.succeed([])
-
-        def get_all_new_backfill_event_rows(txn):
-            sql = (
-                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > stream_ordering AND stream_ordering >= ?"
-                " ORDER BY stream_ordering ASC"
-                " LIMIT ?"
-            )
-            txn.execute(sql, (-last_id, -current_id, limit))
-            new_event_updates = txn.fetchall()
-
-            if len(new_event_updates) == limit:
-                upper_bound = new_event_updates[-1][0]
-            else:
-                upper_bound = current_id
-
-            sql = (
-                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
-                " FROM events AS e"
-                " INNER JOIN ex_outlier_stream USING (event_id)"
-                " LEFT JOIN redactions USING (event_id)"
-                " LEFT JOIN state_events USING (event_id)"
-                " LEFT JOIN event_relations USING (event_id)"
-                " WHERE ? > event_stream_ordering"
-                " AND event_stream_ordering >= ?"
-                " ORDER BY event_stream_ordering DESC"
-            )
-            txn.execute(sql, (-last_id, -upper_bound))
-            new_event_updates.extend(txn.fetchall())
-
-            return new_event_updates
-
-        return self.db.runInteraction(
-            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
-        )
-
     @cached(num_args=5, max_entries=10)
     def get_all_new_events(
         self,
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 47a3a26072..86a419ad00 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -911,3 +911,93 @@ class EventsWorkerStore(SQLBaseStore):
         complexity_v1 = round(state_events / 500, 2)
 
         return {"v1": complexity_v1}
+
+    def get_all_new_forward_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_forward_event_rows(txn):
+            sql = (
+                "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? < stream_ordering AND stream_ordering <= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (last_id, current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? < event_stream_ordering"
+                " AND event_stream_ordering <= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (last_id, upper_bound))
+            new_event_updates.extend(txn)
+
+            return new_event_updates
+
+        return self.db.runInteraction(
+            "get_all_new_forward_event_rows", get_all_new_forward_event_rows
+        )
+
+    def get_all_new_backfill_event_rows(self, last_id, current_id, limit):
+        if last_id == current_id:
+            return defer.succeed([])
+
+        def get_all_new_backfill_event_rows(txn):
+            sql = (
+                "SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? > stream_ordering AND stream_ordering >= ?"
+                " ORDER BY stream_ordering ASC"
+                " LIMIT ?"
+            )
+            txn.execute(sql, (-last_id, -current_id, limit))
+            new_event_updates = txn.fetchall()
+
+            if len(new_event_updates) == limit:
+                upper_bound = new_event_updates[-1][0]
+            else:
+                upper_bound = current_id
+
+            sql = (
+                "SELECT -event_stream_ordering, e.event_id, e.room_id, e.type,"
+                " state_key, redacts, relates_to_id"
+                " FROM events AS e"
+                " INNER JOIN ex_outlier_stream USING (event_id)"
+                " LEFT JOIN redactions USING (event_id)"
+                " LEFT JOIN state_events USING (event_id)"
+                " LEFT JOIN event_relations USING (event_id)"
+                " WHERE ? > event_stream_ordering"
+                " AND event_stream_ordering >= ?"
+                " ORDER BY event_stream_ordering DESC"
+            )
+            txn.execute(sql, (-last_id, -upper_bound))
+            new_event_updates.extend(txn.fetchall())
+
+            return new_event_updates
+
+        return self.db.runInteraction(
+            "get_all_new_backfill_event_rows", get_all_new_backfill_event_rows
+        )
diff --git a/synapse/storage/data_stores/main/room.py b/synapse/storage/data_stores/main/room.py
index e6c10c6316..aaebe427d3 100644
--- a/synapse/storage/data_stores/main/room.py
+++ b/synapse/storage/data_stores/main/room.py
@@ -732,6 +732,26 @@ class RoomWorkerStore(SQLBaseStore):
 
         return total_media_quarantined
 
+    def get_all_new_public_rooms(self, prev_id, current_id, limit):
+        def get_all_new_public_rooms(txn):
+            sql = """
+                SELECT stream_id, room_id, visibility, appservice_id, network_id
+                FROM public_room_list_stream
+                WHERE stream_id > ? AND stream_id <= ?
+                ORDER BY stream_id ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (prev_id, current_id, limit))
+            return txn.fetchall()
+
+        if prev_id == current_id:
+            return defer.succeed([])
+
+        return self.db.runInteraction(
+            "get_all_new_public_rooms", get_all_new_public_rooms
+        )
+
 
 class RoomBackgroundUpdateStore(SQLBaseStore):
     REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -1249,26 +1269,6 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def get_all_new_public_rooms(self, prev_id, current_id, limit):
-        def get_all_new_public_rooms(txn):
-            sql = """
-                SELECT stream_id, room_id, visibility, appservice_id, network_id
-                FROM public_room_list_stream
-                WHERE stream_id > ? AND stream_id <= ?
-                ORDER BY stream_id ASC
-                LIMIT ?
-            """
-
-            txn.execute(sql, (prev_id, current_id, limit))
-            return txn.fetchall()
-
-        if prev_id == current_id:
-            return defer.succeed([])
-
-        return self.db.runInteraction(
-            "get_all_new_public_rooms", get_all_new_public_rooms
-        )
-
     @defer.inlineCallbacks
     def block_room(self, room_id, user_id):
         """Marks the room as blocked. Can be called multiple times.