summary refs log tree commit diff
path: root/synapse/replication/tcp/streams
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-01-22 16:53:28 +0000
committerErik Johnston <erik@matrix.org>2020-01-22 16:53:28 +0000
commit57a60365da0c47f286ea4608d766abbca5762233 (patch)
treeaaef0948f26f3352092b787d32e1dda0743d697e /synapse/replication/tcp/streams
parentPull out more info about room key requests (diff)
parentRemove unnecessary abstractions in admin handler (#6751) (diff)
downloadsynapse-erikj/debug_direct_message_checks.tar.xz
Merge branch 'develop' of github.com:matrix-org/synapse into erikj/debug_direct_message_checks github/erikj/debug_direct_message_checks erikj/debug_direct_message_checks
Diffstat (limited to 'synapse/replication/tcp/streams')
-rw-r--r--synapse/replication/tcp/streams/_base.py104
-rw-r--r--synapse/replication/tcp/streams/events.py25
-rw-r--r--synapse/replication/tcp/streams/federation.py4
3 files changed, 74 insertions, 59 deletions
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 8512923eae..a8d568b14a 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -14,12 +14,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import itertools
 import logging
 from collections import namedtuple
+from typing import Any, List, Optional
 
-from twisted.internet import defer
+import attr
 
 logger = logging.getLogger(__name__)
 
@@ -67,10 +67,24 @@ PushersStreamRow = namedtuple(
     "PushersStreamRow",
     ("user_id", "app_id", "pushkey", "deleted"),  # str  # str  # str  # bool
 )
-CachesStreamRow = namedtuple(
-    "CachesStreamRow",
-    ("cache_func", "keys", "invalidation_ts"),  # str  # list(str)  # int
-)
+
+
+@attr.s
+class CachesStreamRow:
+    """Stream to inform workers they should invalidate their cache.
+
+    Attributes:
+        cache_func: Name of the cached function.
+        keys: The entry in the cache to invalidate. If None then will
+            invalidate all.
+        invalidation_ts: Timestamp of when the invalidation took place.
+    """
+
+    cache_func = attr.ib(type=str)
+    keys = attr.ib(type=Optional[List[Any]])
+    invalidation_ts = attr.ib(type=int)
+
+
 PublicRoomsStreamRow = namedtuple(
     "PublicRoomsStreamRow",
     (
@@ -104,8 +118,9 @@ class Stream(object):
     time it was called up until the point `advance_current_token` was called.
     """
 
-    NAME = None  # The name of the stream
-    ROW_TYPE = None  # The type of the row. Used by the default impl of parse_row.
+    NAME = None  # type: str  # The name of the stream
+    # The type of the row. Used by the default impl of parse_row.
+    ROW_TYPE = None  # type: Any
     _LIMITED = True  # Whether the update function takes a limit
 
     @classmethod
@@ -143,8 +158,7 @@ class Stream(object):
         self.upto_token = self.current_token()
         self.last_token = self.upto_token
 
-    @defer.inlineCallbacks
-    def get_updates(self):
+    async def get_updates(self):
         """Gets all updates since the last time this function was called (or
         since the stream was constructed if it hadn't been called before),
         until the `upto_token`
@@ -155,13 +169,12 @@ class Stream(object):
                 list of ``(token, row)`` entries. ``row`` will be json-serialised and
                 sent over the replication steam.
         """
-        updates, current_token = yield self.get_updates_since(self.last_token)
+        updates, current_token = await self.get_updates_since(self.last_token)
         self.last_token = current_token
 
         return updates, current_token
 
-    @defer.inlineCallbacks
-    def get_updates_since(self, from_token):
+    async def get_updates_since(self, from_token):
         """Like get_updates except allows specifying from when we should
         stream updates
 
@@ -181,15 +194,16 @@ class Stream(object):
         if from_token == current_token:
             return [], current_token
 
+        logger.info("get_updates_since: %s", self.__class__)
         if self._LIMITED:
-            rows = yield self.update_function(
+            rows = await self.update_function(
                 from_token, current_token, limit=MAX_EVENTS_BEHIND + 1
             )
 
             # never turn more than MAX_EVENTS_BEHIND + 1 into updates.
             rows = itertools.islice(rows, MAX_EVENTS_BEHIND + 1)
         else:
-            rows = yield self.update_function(from_token, current_token)
+            rows = await self.update_function(from_token, current_token)
 
         updates = [(row[0], row[1:]) for row in rows]
 
@@ -231,8 +245,8 @@ class BackfillStream(Stream):
 
     def __init__(self, hs):
         store = hs.get_datastore()
-        self.current_token = store.get_current_backfill_token
-        self.update_function = store.get_all_new_backfill_event_rows
+        self.current_token = store.get_current_backfill_token  # type: ignore
+        self.update_function = store.get_all_new_backfill_event_rows  # type: ignore
 
         super(BackfillStream, self).__init__(hs)
 
@@ -246,8 +260,8 @@ class PresenceStream(Stream):
         store = hs.get_datastore()
         presence_handler = hs.get_presence_handler()
 
-        self.current_token = store.get_current_presence_token
-        self.update_function = presence_handler.get_all_presence_updates
+        self.current_token = store.get_current_presence_token  # type: ignore
+        self.update_function = presence_handler.get_all_presence_updates  # type: ignore
 
         super(PresenceStream, self).__init__(hs)
 
@@ -260,8 +274,8 @@ class TypingStream(Stream):
     def __init__(self, hs):
         typing_handler = hs.get_typing_handler()
 
-        self.current_token = typing_handler.get_current_token
-        self.update_function = typing_handler.get_all_typing_updates
+        self.current_token = typing_handler.get_current_token  # type: ignore
+        self.update_function = typing_handler.get_all_typing_updates  # type: ignore
 
         super(TypingStream, self).__init__(hs)
 
@@ -273,8 +287,8 @@ class ReceiptsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_max_receipt_stream_id
-        self.update_function = store.get_all_updated_receipts
+        self.current_token = store.get_max_receipt_stream_id  # type: ignore
+        self.update_function = store.get_all_updated_receipts  # type: ignore
 
         super(ReceiptsStream, self).__init__(hs)
 
@@ -294,9 +308,8 @@ class PushRulesStream(Stream):
         push_rules_token, _ = self.store.get_push_rules_stream_token()
         return push_rules_token
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, to_token, limit):
-        rows = yield self.store.get_all_push_rule_updates(from_token, to_token, limit)
+    async def update_function(self, from_token, to_token, limit):
+        rows = await self.store.get_all_push_rule_updates(from_token, to_token, limit)
         return [(row[0], row[2]) for row in rows]
 
 
@@ -310,8 +323,8 @@ class PushersStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_pushers_stream_token
-        self.update_function = store.get_all_updated_pushers_rows
+        self.current_token = store.get_pushers_stream_token  # type: ignore
+        self.update_function = store.get_all_updated_pushers_rows  # type: ignore
 
         super(PushersStream, self).__init__(hs)
 
@@ -327,8 +340,8 @@ class CachesStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_cache_stream_token
-        self.update_function = store.get_all_updated_caches
+        self.current_token = store.get_cache_stream_token  # type: ignore
+        self.update_function = store.get_all_updated_caches  # type: ignore
 
         super(CachesStream, self).__init__(hs)
 
@@ -343,8 +356,8 @@ class PublicRoomsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_current_public_room_stream_id
-        self.update_function = store.get_all_new_public_rooms
+        self.current_token = store.get_current_public_room_stream_id  # type: ignore
+        self.update_function = store.get_all_new_public_rooms  # type: ignore
 
         super(PublicRoomsStream, self).__init__(hs)
 
@@ -360,8 +373,8 @@ class DeviceListsStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_device_stream_token
-        self.update_function = store.get_all_device_list_changes_for_remotes
+        self.current_token = store.get_device_stream_token  # type: ignore
+        self.update_function = store.get_all_device_list_changes_for_remotes  # type: ignore
 
         super(DeviceListsStream, self).__init__(hs)
 
@@ -376,8 +389,8 @@ class ToDeviceStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_to_device_stream_token
-        self.update_function = store.get_all_new_device_messages
+        self.current_token = store.get_to_device_stream_token  # type: ignore
+        self.update_function = store.get_all_new_device_messages  # type: ignore
 
         super(ToDeviceStream, self).__init__(hs)
 
@@ -392,8 +405,8 @@ class TagAccountDataStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_max_account_data_stream_id
-        self.update_function = store.get_all_updated_tags
+        self.current_token = store.get_max_account_data_stream_id  # type: ignore
+        self.update_function = store.get_all_updated_tags  # type: ignore
 
         super(TagAccountDataStream, self).__init__(hs)
 
@@ -408,13 +421,12 @@ class AccountDataStream(Stream):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-        self.current_token = self.store.get_max_account_data_stream_id
+        self.current_token = self.store.get_max_account_data_stream_id  # type: ignore
 
         super(AccountDataStream, self).__init__(hs)
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, to_token, limit):
-        global_results, room_results = yield self.store.get_all_updated_account_data(
+    async def update_function(self, from_token, to_token, limit):
+        global_results, room_results = await self.store.get_all_updated_account_data(
             from_token, from_token, to_token, limit
         )
 
@@ -434,8 +446,8 @@ class GroupServerStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_group_stream_token
-        self.update_function = store.get_all_groups_changes
+        self.current_token = store.get_group_stream_token  # type: ignore
+        self.update_function = store.get_all_groups_changes  # type: ignore
 
         super(GroupServerStream, self).__init__(hs)
 
@@ -451,7 +463,7 @@ class UserSignatureStream(Stream):
     def __init__(self, hs):
         store = hs.get_datastore()
 
-        self.current_token = store.get_device_stream_token
-        self.update_function = store.get_all_user_signature_changes_for_remotes
+        self.current_token = store.get_device_stream_token  # type: ignore
+        self.update_function = store.get_all_user_signature_changes_for_remotes  # type: ignore
 
         super(UserSignatureStream, self).__init__(hs)
diff --git a/synapse/replication/tcp/streams/events.py b/synapse/replication/tcp/streams/events.py
index d97669c886..b3afabb8cd 100644
--- a/synapse/replication/tcp/streams/events.py
+++ b/synapse/replication/tcp/streams/events.py
@@ -13,12 +13,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import heapq
+from typing import Tuple, Type
 
 import attr
 
-from twisted.internet import defer
-
 from ._base import Stream
 
 
@@ -63,7 +63,8 @@ class BaseEventsStreamRow(object):
     Specifies how to identify, serialize and deserialize the different types.
     """
 
-    TypeId = None  # Unique string that ids the type. Must be overriden in sub classes.
+    # Unique string that ids the type. Must be overriden in sub classes.
+    TypeId = None  # type: str
 
     @classmethod
     def from_data(cls, data):
@@ -99,9 +100,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
     event_id = attr.ib()  # str, optional
 
 
-TypeToRow = {
-    Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow)
-}
+_EventRows = (
+    EventsStreamEventRow,
+    EventsStreamCurrentStateRow,
+)  # type: Tuple[Type[BaseEventsStreamRow], ...]
+
+TypeToRow = {Row.TypeId: Row for Row in _EventRows}
 
 
 class EventsStream(Stream):
@@ -112,20 +116,19 @@ class EventsStream(Stream):
 
     def __init__(self, hs):
         self._store = hs.get_datastore()
-        self.current_token = self._store.get_current_events_token
+        self.current_token = self._store.get_current_events_token  # type: ignore
 
         super(EventsStream, self).__init__(hs)
 
-    @defer.inlineCallbacks
-    def update_function(self, from_token, current_token, limit=None):
-        event_rows = yield self._store.get_all_new_forward_event_rows(
+    async def update_function(self, from_token, current_token, limit=None):
+        event_rows = await self._store.get_all_new_forward_event_rows(
             from_token, current_token, limit
         )
         event_updates = (
             (row[0], EventsStreamEventRow.TypeId, row[1:]) for row in event_rows
         )
 
-        state_rows = yield self._store.get_all_updated_current_state_deltas(
+        state_rows = await self._store.get_all_updated_current_state_deltas(
             from_token, current_token, limit
         )
         state_updates = (
diff --git a/synapse/replication/tcp/streams/federation.py b/synapse/replication/tcp/streams/federation.py
index dc2484109d..615f3dc9ac 100644
--- a/synapse/replication/tcp/streams/federation.py
+++ b/synapse/replication/tcp/streams/federation.py
@@ -37,7 +37,7 @@ class FederationStream(Stream):
     def __init__(self, hs):
         federation_sender = hs.get_federation_sender()
 
-        self.current_token = federation_sender.get_current_token
-        self.update_function = federation_sender.get_replication_rows
+        self.current_token = federation_sender.get_current_token  # type: ignore
+        self.update_function = federation_sender.get_replication_rows  # type: ignore
 
         super(FederationStream, self).__init__(hs)