summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/appservice.py4
-rw-r--r--synapse/handlers/push_rules.py6
-rw-r--r--synapse/handlers/receipts.py25
-rw-r--r--synapse/notifier.py17
-rw-r--r--synapse/push/__init__.py2
-rw-r--r--synapse/push/emailpusher.py2
-rw-r--r--synapse/push/httppusher.py2
-rw-r--r--synapse/push/pusherpool.py12
-rw-r--r--synapse/replication/tcp/client.py4
-rw-r--r--synapse/storage/databases/main/e2e_room_keys.py2
-rw-r--r--synapse/storage/databases/main/receipts.py6
-rw-r--r--synapse/streams/events.py15
-rw-r--r--synapse/types/__init__.py59
13 files changed, 87 insertions, 69 deletions
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 7de7bd3289..c200a45f3a 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -216,7 +216,7 @@ class ApplicationServicesHandler:
 
     def notify_interested_services_ephemeral(
         self,
-        stream_key: str,
+        stream_key: StreamKeyType,
         new_token: Union[int, RoomStreamToken],
         users: Collection[Union[str, UserID]],
     ) -> None:
@@ -326,7 +326,7 @@ class ApplicationServicesHandler:
     async def _notify_interested_services_ephemeral(
         self,
         services: List[ApplicationService],
-        stream_key: str,
+        stream_key: StreamKeyType,
         new_token: int,
         users: Collection[Union[str, UserID]],
     ) -> None:
diff --git a/synapse/handlers/push_rules.py b/synapse/handlers/push_rules.py
index 7ed88a3611..87b428ab1c 100644
--- a/synapse/handlers/push_rules.py
+++ b/synapse/handlers/push_rules.py
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError, UnrecognizedRequestError
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.push_rule import RuleNotFoundException
 from synapse.synapse_rust.push import get_base_rule_ids
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, StreamKeyType, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -114,7 +114,9 @@ class PushRulesHandler:
             user_id: the user ID the change is for.
         """
         stream_id = self._main_store.get_max_push_rules_stream_id()
-        self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
+        self._notifier.on_new_event(
+            StreamKeyType.PUSH_RULES, stream_id, users=[user_id]
+        )
 
     async def push_rules_for_user(
         self, user: UserID
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index a7a29b758b..69ac468f75 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -130,11 +130,10 @@ class ReceiptsHandler:
 
     async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
         """Takes a list of receipts, stores them and informs the notifier."""
-        min_batch_id: Optional[int] = None
-        max_batch_id: Optional[int] = None
 
+        receipts_persisted: List[ReadReceipt] = []
         for receipt in receipts:
-            res = await self.store.insert_receipt(
+            stream_id = await self.store.insert_receipt(
                 receipt.room_id,
                 receipt.receipt_type,
                 receipt.user_id,
@@ -143,30 +142,26 @@ class ReceiptsHandler:
                 receipt.data,
             )
 
-            if not res:
-                # res will be None if this receipt is 'old'
+            if stream_id is None:
+                # stream_id will be None if this receipt is 'old'
                 continue
 
-            stream_id, max_persisted_id = res
+            receipts_persisted.append(receipt)
 
-            if min_batch_id is None or stream_id < min_batch_id:
-                min_batch_id = stream_id
-            if max_batch_id is None or max_persisted_id > max_batch_id:
-                max_batch_id = max_persisted_id
-
-        # Either both of these should be None or neither.
-        if min_batch_id is None or max_batch_id is None:
+        if not receipts_persisted:
             # no new receipts
             return False
 
-        affected_room_ids = list({r.room_id for r in receipts})
+        max_batch_id = self.store.get_max_receipt_stream_id()
+
+        affected_room_ids = list({r.room_id for r in receipts_persisted})
 
         self.notifier.on_new_event(
             StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids
         )
         # Note that the min here shouldn't be relied upon to be accurate.
         await self.hs.get_pusherpool().on_new_receipts(
-            min_batch_id, max_batch_id, affected_room_ids
+            {r.user_id for r in receipts_persisted}
         )
 
         return True
diff --git a/synapse/notifier.py b/synapse/notifier.py
index fc39e5c963..99e7715896 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -126,7 +126,7 @@ class _NotifierUserStream:
 
     def notify(
         self,
-        stream_key: str,
+        stream_key: StreamKeyType,
         stream_id: Union[int, RoomStreamToken],
         time_now_ms: int,
     ) -> None:
@@ -454,7 +454,7 @@ class Notifier:
 
     def on_new_event(
         self,
-        stream_key: str,
+        stream_key: StreamKeyType,
         new_token: Union[int, RoomStreamToken],
         users: Optional[Collection[Union[str, UserID]]] = None,
         rooms: Optional[StrCollection] = None,
@@ -655,30 +655,29 @@ class Notifier:
             events: List[Union[JsonDict, EventBase]] = []
             end_token = from_token
 
-            for name, source in self.event_sources.sources.get_sources():
-                keyname = "%s_key" % name
-                before_id = getattr(before_token, keyname)
-                after_id = getattr(after_token, keyname)
+            for keyname, source in self.event_sources.sources.get_sources():
+                before_id = before_token.get_field(keyname)
+                after_id = after_token.get_field(keyname)
                 if before_id == after_id:
                     continue
 
                 new_events, new_key = await source.get_new_events(
                     user=user,
-                    from_key=getattr(from_token, keyname),
+                    from_key=from_token.get_field(keyname),
                     limit=limit,
                     is_guest=is_peeking,
                     room_ids=room_ids,
                     explicit_room_id=explicit_room_id,
                 )
 
-                if name == "room":
+                if keyname == StreamKeyType.ROOM:
                     new_events = await filter_events_for_client(
                         self._storage_controllers,
                         user.to_string(),
                         new_events,
                         is_peeking=is_peeking,
                     )
-                elif name == "presence":
+                elif keyname == StreamKeyType.PRESENCE:
                     now = self.clock.time_msec()
                     new_events[:] = [
                         {
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 9e3a98741a..9e5eb2a445 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -182,7 +182,7 @@ class Pusher(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
-    def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+    def on_new_receipts(self) -> None:
         raise NotImplementedError()
 
     @abc.abstractmethod
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 1710dd51b9..cf45fd09a8 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -99,7 +99,7 @@ class EmailPusher(Pusher):
                 pass
             self.timed_call = None
 
-    def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+    def on_new_receipts(self) -> None:
         # We could wake up and cancel the timer but there tend to be quite a
         # lot of read receipts so it's probably less work to just let the
         # timer fire
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 50027680cb..725910a659 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -160,7 +160,7 @@ class HttpPusher(Pusher):
         if should_check_for_notifs:
             self._start_processing()
 
-    def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+    def on_new_receipts(self) -> None:
         # Note that the min here shouldn't be relied upon to be accurate.
 
         # We could check the receipts are actually m.read receipts here,
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 6517e3566f..15a2cc932f 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -292,20 +292,12 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_notifications")
 
-    async def on_new_receipts(
-        self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
-    ) -> None:
+    async def on_new_receipts(self, users_affected: StrCollection) -> None:
         if not self.pushers:
             # nothing to do here.
             return
 
         try:
-            # Need to subtract 1 from the minimum because the lower bound here
-            # is not inclusive
-            users_affected = await self.store.get_users_sent_receipts_between(
-                min_stream_id - 1, max_stream_id
-            )
-
             for u in users_affected:
                 # Don't push if the user account has expired
                 expired = await self._account_validity_handler.is_user_expired(u)
@@ -314,7 +306,7 @@ class PusherPool:
 
                 if u in self.pushers:
                     for p in self.pushers[u].values():
-                        p.on_new_receipts(min_stream_id, max_stream_id)
+                        p.on_new_receipts()
 
         except Exception:
             logger.exception("Exception in pusher on_new_receipts")
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index f4f2b29e96..d5337fe588 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -129,9 +129,7 @@ class ReplicationDataHandler:
             self.notifier.on_new_event(
                 StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows]
             )
-            await self._pusher_pool.on_new_receipts(
-                token, token, {row.room_id for row in rows}
-            )
+            await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
         elif stream_name == ToDeviceStream.NAME:
             entities = [row.entity for row in rows if row.entity.startswith("@")]
             if entities:
diff --git a/synapse/storage/databases/main/e2e_room_keys.py b/synapse/storage/databases/main/e2e_room_keys.py
index d01f28cc80..bc7c6a6346 100644
--- a/synapse/storage/databases/main/e2e_room_keys.py
+++ b/synapse/storage/databases/main/e2e_room_keys.py
@@ -208,7 +208,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore):
                     "message": "Set room key",
                     "room_id": room_id,
                     "session_id": session_id,
-                    StreamKeyType.ROOM: room_key,
+                    StreamKeyType.ROOM.value: room_key,
                 }
             )
 
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0231f9407b..3bab1024ea 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -742,7 +742,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         event_ids: List[str],
         thread_id: Optional[str],
         data: dict,
-    ) -> Optional[Tuple[int, int]]:
+    ) -> Optional[int]:
         """Insert a receipt, either from local client or remote server.
 
         Automatically does conversion between linearized and graph
@@ -804,9 +804,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             data,
         )
 
-        max_persisted_id = self._receipts_id_gen.get_current_token()
-
-        return stream_id, max_persisted_id
+        return stream_id
 
     async def _insert_graph_receipt(
         self,
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index d7084d2358..609a0978a9 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Iterator, Tuple
+from typing import TYPE_CHECKING, Sequence, Tuple
 
 import attr
 
@@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource
 from synapse.handlers.typing import TypingNotificationEventSource
 from synapse.logging.opentracing import trace
 from synapse.streams import EventSource
-from synapse.types import StreamToken
+from synapse.types import StreamKeyType, StreamToken
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -37,9 +37,14 @@ class _EventSourcesInner:
     receipt: ReceiptEventSource
     account_data: AccountDataEventSource
 
-    def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
-        for attribute in attr.fields(_EventSourcesInner):
-            yield attribute.name, getattr(self, attribute.name)
+    def get_sources(self) -> Sequence[Tuple[StreamKeyType, EventSource]]:
+        return [
+            (StreamKeyType.ROOM, self.room),
+            (StreamKeyType.PRESENCE, self.presence),
+            (StreamKeyType.TYPING, self.typing),
+            (StreamKeyType.RECEIPT, self.receipt),
+            (StreamKeyType.ACCOUNT_DATA, self.account_data),
+        ]
 
 
 class EventSources:
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 76b0e3e694..406d5b1611 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -22,8 +22,8 @@ from typing import (
     Any,
     ClassVar,
     Dict,
-    Final,
     List,
+    Literal,
     Mapping,
     Match,
     MutableMapping,
@@ -34,6 +34,7 @@ from typing import (
     Type,
     TypeVar,
     Union,
+    overload,
 )
 
 import attr
@@ -649,20 +650,20 @@ class RoomStreamToken:
             return "s%d" % (self.stream,)
 
 
-class StreamKeyType:
+class StreamKeyType(Enum):
     """Known stream types.
 
     A stream is a list of entities ordered by an incrementing "stream token".
     """
 
-    ROOM: Final = "room_key"
-    PRESENCE: Final = "presence_key"
-    TYPING: Final = "typing_key"
-    RECEIPT: Final = "receipt_key"
-    ACCOUNT_DATA: Final = "account_data_key"
-    PUSH_RULES: Final = "push_rules_key"
-    TO_DEVICE: Final = "to_device_key"
-    DEVICE_LIST: Final = "device_list_key"
+    ROOM = "room_key"
+    PRESENCE = "presence_key"
+    TYPING = "typing_key"
+    RECEIPT = "receipt_key"
+    ACCOUNT_DATA = "account_data_key"
+    PUSH_RULES = "push_rules_key"
+    TO_DEVICE = "to_device_key"
+    DEVICE_LIST = "device_list_key"
     UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
 
 
@@ -784,7 +785,7 @@ class StreamToken:
     def room_stream_id(self) -> int:
         return self.room_key.stream
 
-    def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
+    def copy_and_advance(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
         """Advance the given key in the token to a new value if and only if the
         new value is after the old value.
 
@@ -797,16 +798,44 @@ class StreamToken:
             return new_token
 
         new_token = self.copy_and_replace(key, new_value)
-        new_id = int(getattr(new_token, key))
-        old_id = int(getattr(self, key))
+        new_id = new_token.get_field(key)
+        old_id = self.get_field(key)
 
         if old_id < new_id:
             return new_token
         else:
             return self
 
-    def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
-        return attr.evolve(self, **{key: new_value})
+    def copy_and_replace(self, key: StreamKeyType, new_value: Any) -> "StreamToken":
+        return attr.evolve(self, **{key.value: new_value})
+
+    @overload
+    def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken:
+        ...
+
+    @overload
+    def get_field(
+        self,
+        key: Literal[
+            StreamKeyType.ACCOUNT_DATA,
+            StreamKeyType.DEVICE_LIST,
+            StreamKeyType.PRESENCE,
+            StreamKeyType.PUSH_RULES,
+            StreamKeyType.RECEIPT,
+            StreamKeyType.TO_DEVICE,
+            StreamKeyType.TYPING,
+            StreamKeyType.UN_PARTIAL_STATED_ROOMS,
+        ],
+    ) -> int:
+        ...
+
+    @overload
+    def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+        ...
+
+    def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]:
+        """Returns the stream ID for the given key."""
+        return getattr(self, key.value)
 
 
 StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0)