summary refs log tree commit diff
path: root/synapse/streams/events.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/streams/events.py')
-rw-r--r--synapse/streams/events.py49
1 files changed, 30 insertions, 19 deletions
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 99b0aac2fb..21591d0bfd 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -12,29 +12,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Iterator, Tuple
+
+import attr
 
 from synapse.handlers.account_data import AccountDataEventSource
 from synapse.handlers.presence import PresenceEventSource
 from synapse.handlers.receipts import ReceiptEventSource
 from synapse.handlers.room import RoomEventSource
 from synapse.handlers.typing import TypingNotificationEventSource
+from synapse.streams import EventSource
 from synapse.types import StreamToken
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
-class EventSources:
-    SOURCE_TYPES = {
-        "room": RoomEventSource,
-        "presence": PresenceEventSource,
-        "typing": TypingNotificationEventSource,
-        "receipt": ReceiptEventSource,
-        "account_data": AccountDataEventSource,
-    }
 
-    def __init__(self, hs):
-        self.sources: Dict[str, Any] = {
-            name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
-        }
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _EventSourcesInner:
+    room: RoomEventSource
+    presence: PresenceEventSource
+    typing: TypingNotificationEventSource
+    receipt: ReceiptEventSource
+    account_data: AccountDataEventSource
+
+    def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
+        for attribute in _EventSourcesInner.__attrs_attrs__:  # type: ignore[attr-defined]
+            yield attribute.name, getattr(self, attribute.name)
+
+
+class EventSources:
+    def __init__(self, hs: "HomeServer"):
+        self.sources = _EventSourcesInner(
+            *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__)  # type: ignore[attr-defined]
+        )
         self.store = hs.get_datastore()
 
     def get_current_token(self) -> StreamToken:
@@ -44,11 +55,11 @@ class EventSources:
         groups_key = self.store.get_group_stream_token()
 
         token = StreamToken(
-            room_key=self.sources["room"].get_current_key(),
-            presence_key=self.sources["presence"].get_current_key(),
-            typing_key=self.sources["typing"].get_current_key(),
-            receipt_key=self.sources["receipt"].get_current_key(),
-            account_data_key=self.sources["account_data"].get_current_key(),
+            room_key=self.sources.room.get_current_key(),
+            presence_key=self.sources.presence.get_current_key(),
+            typing_key=self.sources.typing.get_current_key(),
+            receipt_key=self.sources.receipt.get_current_key(),
+            account_data_key=self.sources.account_data.get_current_key(),
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
             device_list_key=device_list_key,
@@ -67,7 +78,7 @@ class EventSources:
             The current token for pagination.
         """
         token = StreamToken(
-            room_key=self.sources["room"].get_current_key(),
+            room_key=self.sources.room.get_current_key(),
             presence_key=0,
             typing_key=0,
             receipt_key=0,