summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-09-21 13:34:26 -0400
committerGitHub <noreply@github.com>2021-09-21 13:34:26 -0400
commit4054dfa409fa17b45ab8f265813994956ed97bae (patch)
tree8a2c2a5700dcd62f1b8ad8593d0b25c9c7b32e55
parentAdd types to http.site (#10867) (diff)
downloadsynapse-4054dfa409fa17b45ab8f265813994956ed97bae.tar.xz
Add type hints for event streams. (#10856)
-rw-r--r--changelog.d/10856.misc1
-rw-r--r--synapse/handlers/account_data.py13
-rw-r--r--synapse/handlers/appservice.py6
-rw-r--r--synapse/handlers/initial_sync.py2
-rw-r--r--synapse/handlers/presence.py8
-rw-r--r--synapse/handlers/receipts.py13
-rw-r--r--synapse/handlers/room.py18
-rw-r--r--synapse/handlers/sync.py6
-rw-r--r--synapse/handlers/typing.py13
-rw-r--r--synapse/module_api/__init__.py2
-rw-r--r--synapse/notifier.py2
-rw-r--r--synapse/storage/databases/main/receipts.py6
-rw-r--r--synapse/streams/__init__.py22
-rw-r--r--synapse/streams/events.py49
-rw-r--r--tests/handlers/test_receipts.py2
-rw-r--r--tests/handlers/test_typing.py46
-rw-r--r--tests/rest/client/test_shadow_banned.py10
-rw-r--r--tests/rest/client/test_typing.py10
18 files changed, 169 insertions, 60 deletions
diff --git a/changelog.d/10856.misc b/changelog.d/10856.misc
new file mode 100644
index 0000000000..f09af2e00a
--- /dev/null
+++ b/changelog.d/10856.misc
@@ -0,0 +1 @@
+Add missing type hints to handlers.
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index e9e7a78546..96273e2f81 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import random
-from typing import TYPE_CHECKING, Any, List, Tuple
+from typing import TYPE_CHECKING, Collection, List, Optional, Tuple
 
 from synapse.replication.http.account_data import (
     ReplicationAddTagRestServlet,
@@ -21,6 +21,7 @@ from synapse.replication.http.account_data import (
     ReplicationRoomAccountDataRestServlet,
     ReplicationUserAccountDataRestServlet,
 )
+from synapse.streams import EventSource
 from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
@@ -163,7 +164,7 @@ class AccountDataHandler:
             return response["max_stream_id"]
 
 
-class AccountDataEventSource:
+class AccountDataEventSource(EventSource[int, JsonDict]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
@@ -171,7 +172,13 @@ class AccountDataEventSource:
         return self.store.get_max_account_data_stream_id()
 
     async def get_new_events(
-        self, user: UserID, from_key: int, **kwargs: Any
+        self,
+        user: UserID,
+        from_key: int,
+        limit: Optional[int],
+        room_ids: Collection[str],
+        is_guest: bool,
+        explicit_room_id: Optional[str] = None,
     ) -> Tuple[List[JsonDict], int]:
         user_id = user.to_string()
         last_stream_id = from_key
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 8bde9ed66f..b7213b67a5 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -254,7 +254,7 @@ class ApplicationServicesHandler:
     async def _handle_typing(
         self, service: ApplicationService, new_token: int
     ) -> List[JsonDict]:
-        typing_source = self.event_sources.sources["typing"]
+        typing_source = self.event_sources.sources.typing
         # Get the typing events from just before current
         typing, _ = await typing_source.get_new_events_as(
             service=service,
@@ -269,7 +269,7 @@ class ApplicationServicesHandler:
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "read_receipt"
         )
-        receipts_source = self.event_sources.sources["receipt"]
+        receipts_source = self.event_sources.sources.receipt
         receipts, _ = await receipts_source.get_new_events_as(
             service=service, from_key=from_key
         )
@@ -279,7 +279,7 @@ class ApplicationServicesHandler:
         self, service: ApplicationService, users: Collection[Union[str, UserID]]
     ) -> List[JsonDict]:
         events: List[JsonDict] = []
-        presence_source = self.event_sources.sources["presence"]
+        presence_source = self.event_sources.sources.presence
         from_key = await self.store.get_type_stream_id_for_appservice(
             service, "presence"
         )
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index c942086e74..9ad39a65d8 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -125,7 +125,7 @@ class InitialSyncHandler(BaseHandler):
 
         now_token = self.hs.get_event_sources().get_current_token()
 
-        presence_stream = self.hs.get_event_sources().sources["presence"]
+        presence_stream = self.hs.get_event_sources().sources.presence
         presence, _ = await presence_stream.get_new_events(
             user, from_key=None, include_offline=False
         )
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 841c8815b0..983c837c66 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -65,6 +65,7 @@ from synapse.replication.http.streams import ReplicationGetStreamUpdates
 from synapse.replication.tcp.commands import ClearUserSyncsCommand
 from synapse.replication.tcp.streams import PresenceFederationStream, PresenceStream
 from synapse.storage.databases.main import DataStore
+from synapse.streams import EventSource
 from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.descriptors import _CacheContext, cached
@@ -1500,7 +1501,7 @@ def format_user_presence_state(
     return content
 
 
-class PresenceEventSource:
+class PresenceEventSource(EventSource[int, UserPresenceState]):
     def __init__(self, hs: "HomeServer"):
         # We can't call get_presence_handler here because there's a cycle:
         #
@@ -1519,10 +1520,11 @@ class PresenceEventSource:
         self,
         user: UserID,
         from_key: Optional[int],
+        limit: Optional[int] = None,
         room_ids: Optional[List[str]] = None,
-        include_offline: bool = True,
+        is_guest: bool = False,
         explicit_room_id: Optional[str] = None,
-        **kwargs: Any,
+        include_offline: bool = True,
     ) -> Tuple[List[UserPresenceState], int]:
         # The process for getting presence events are:
         #  1. Get the rooms the user is in.
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index c7567ac05f..5881f09ebd 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 from synapse.api.constants import ReadReceiptEventFields
 from synapse.appservice import ApplicationService
 from synapse.handlers._base import BaseHandler
+from synapse.streams import EventSource
 from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
 
 if TYPE_CHECKING:
@@ -162,7 +163,7 @@ class ReceiptsHandler(BaseHandler):
             await self.federation_sender.send_read_receipt(receipt)
 
 
-class ReceiptEventSource:
+class ReceiptEventSource(EventSource[int, JsonDict]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.config = hs.config
@@ -216,7 +217,13 @@ class ReceiptEventSource:
         return visible_events
 
     async def get_new_events(
-        self, from_key: int, room_ids: List[str], user: UserID, **kwargs: Any
+        self,
+        user: UserID,
+        from_key: int,
+        limit: Optional[int],
+        room_ids: Iterable[str],
+        is_guest: bool,
+        explicit_room_id: Optional[str] = None,
     ) -> Tuple[List[JsonDict], int]:
         from_key = int(from_key)
         to_key = self.get_current_key()
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index abdd506164..287ea2fd06 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -20,7 +20,16 @@ import math
 import random
 import string
 from collections import OrderedDict
-from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Awaitable,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Tuple,
+)
 
 from synapse.api.constants import (
     EventContentFields,
@@ -47,6 +56,7 @@ from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
 from synapse.rest.admin._base import assert_user_is_admin
 from synapse.storage.state import StateFilter
+from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
     MutableStateMap,
@@ -1173,7 +1183,7 @@ class RoomContextHandler:
         return results
 
 
-class RoomEventSource:
+class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
@@ -1181,8 +1191,8 @@ class RoomEventSource:
         self,
         user: UserID,
         from_key: RoomStreamToken,
-        limit: int,
-        room_ids: List[str],
+        limit: Optional[int],
+        room_ids: Collection[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
     ) -> Tuple[List[EventBase], RoomStreamToken]:
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index e93db4bdcc..2c7c6d63a9 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -443,7 +443,7 @@ class SyncHandler:
 
             room_ids = sync_result_builder.joined_room_ids
 
-            typing_source = self.event_sources.sources["typing"]
+            typing_source = self.event_sources.sources.typing
             typing, typing_key = await typing_source.get_new_events(
                 user=sync_config.user,
                 from_key=typing_key,
@@ -465,7 +465,7 @@ class SyncHandler:
 
             receipt_key = since_token.receipt_key if since_token else 0
 
-            receipt_source = self.event_sources.sources["receipt"]
+            receipt_source = self.event_sources.sources.receipt
             receipts, receipt_key = await receipt_source.get_new_events(
                 user=sync_config.user,
                 from_key=receipt_key,
@@ -1415,7 +1415,7 @@ class SyncHandler:
         sync_config = sync_result_builder.sync_config
         user = sync_result_builder.sync_config.user
 
-        presence_source = self.event_sources.sources["presence"]
+        presence_source = self.event_sources.sources.presence
 
         since_token = sync_result_builder.since_token
         presence_key = None
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 4492c8567b..9326330c90 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -14,7 +14,7 @@
 import logging
 import random
 from collections import namedtuple
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
 from synapse.appservice import ApplicationService
@@ -23,6 +23,7 @@ from synapse.metrics.background_process_metrics import (
     wrap_as_background_process,
 )
 from synapse.replication.tcp.streams import TypingStream
+from synapse.streams import EventSource
 from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
@@ -439,7 +440,7 @@ class TypingWriterHandler(FollowerTypingHandler):
         raise Exception("Typing writer instance got typing info over replication")
 
 
-class TypingNotificationEventSource:
+class TypingNotificationEventSource(EventSource[int, JsonDict]):
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
@@ -485,7 +486,13 @@ class TypingNotificationEventSource:
             return (events, handler._latest_room_serial)
 
     async def get_new_events(
-        self, from_key: int, room_ids: Iterable[str], **kwargs: Any
+        self,
+        user: UserID,
+        from_key: int,
+        limit: Optional[int],
+        room_ids: Iterable[str],
+        is_guest: bool,
+        explicit_room_id: Optional[str] = None,
     ) -> Tuple[List[JsonDict], int]:
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 2d403532fa..3196c2bec6 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -91,7 +91,7 @@ class ModuleApi:
         self._auth = hs.get_auth()
         self._auth_handler = auth_handler
         self._server_name = hs.hostname
-        self._presence_stream = hs.get_event_sources().sources["presence"]
+        self._presence_stream = hs.get_event_sources().sources.presence
         self._state = hs.get_state_handler()
         self._clock: Clock = hs.get_clock()
         self._send_email_handler = hs.get_send_email_handler()
diff --git a/synapse/notifier.py b/synapse/notifier.py
index bbe337949a..1a9f84ba45 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -584,7 +584,7 @@ class Notifier:
             events: List[EventBase] = []
             end_token = from_token
 
-            for name, source in self.event_sources.sources.items():
+            for name, source in self.event_sources.sources.get_sources():
                 keyname = "%s_key" % name
                 before_id = getattr(before_token, keyname)
                 after_id = getattr(after_token, keyname)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index edeaacd7a6..01a4281301 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -153,12 +153,12 @@ class ReceiptsWorkerStore(SQLBaseStore):
         }
 
     async def get_linearized_receipts_for_rooms(
-        self, room_ids: List[str], to_key: int, from_key: Optional[int] = None
+        self, room_ids: Iterable[str], to_key: int, from_key: Optional[int] = None
     ) -> List[dict]:
         """Get receipts for multiple rooms for sending to clients.
 
         Args:
-            room_id: List of room_ids.
+            room_id: The room IDs to fetch receipts of.
             to_key: Max stream id to fetch receipts up to.
             from_key: Min stream id to fetch receipts from. None fetches
                 from the start.
diff --git a/synapse/streams/__init__.py b/synapse/streams/__init__.py
index 5e83dba2ed..806b671305 100644
--- a/synapse/streams/__init__.py
+++ b/synapse/streams/__init__.py
@@ -11,3 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from typing import Collection, Generic, List, Optional, Tuple, TypeVar
+
+from synapse.types import UserID
+
+# The key, this is either a stream token or int.
+K = TypeVar("K")
+# The return type.
+R = TypeVar("R")
+
+
+class EventSource(Generic[K, R]):
+    async def get_new_events(
+        self,
+        user: UserID,
+        from_key: K,
+        limit: Optional[int],
+        room_ids: Collection[str],
+        is_guest: bool,
+        explicit_room_id: Optional[str] = None,
+    ) -> Tuple[List[R], K]:
+        ...
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 99b0aac2fb..21591d0bfd 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -12,29 +12,40 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Iterator, Tuple
+
+import attr
 
 from synapse.handlers.account_data import AccountDataEventSource
 from synapse.handlers.presence import PresenceEventSource
 from synapse.handlers.receipts import ReceiptEventSource
 from synapse.handlers.room import RoomEventSource
 from synapse.handlers.typing import TypingNotificationEventSource
+from synapse.streams import EventSource
 from synapse.types import StreamToken
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
-class EventSources:
-    SOURCE_TYPES = {
-        "room": RoomEventSource,
-        "presence": PresenceEventSource,
-        "typing": TypingNotificationEventSource,
-        "receipt": ReceiptEventSource,
-        "account_data": AccountDataEventSource,
-    }
 
-    def __init__(self, hs):
-        self.sources: Dict[str, Any] = {
-            name: cls(hs) for name, cls in EventSources.SOURCE_TYPES.items()
-        }
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class _EventSourcesInner:
+    room: RoomEventSource
+    presence: PresenceEventSource
+    typing: TypingNotificationEventSource
+    receipt: ReceiptEventSource
+    account_data: AccountDataEventSource
+
+    def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
+        for attribute in _EventSourcesInner.__attrs_attrs__:  # type: ignore[attr-defined]
+            yield attribute.name, getattr(self, attribute.name)
+
+
+class EventSources:
+    def __init__(self, hs: "HomeServer"):
+        self.sources = _EventSourcesInner(
+            *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__)  # type: ignore[attr-defined]
+        )
         self.store = hs.get_datastore()
 
     def get_current_token(self) -> StreamToken:
@@ -44,11 +55,11 @@ class EventSources:
         groups_key = self.store.get_group_stream_token()
 
         token = StreamToken(
-            room_key=self.sources["room"].get_current_key(),
-            presence_key=self.sources["presence"].get_current_key(),
-            typing_key=self.sources["typing"].get_current_key(),
-            receipt_key=self.sources["receipt"].get_current_key(),
-            account_data_key=self.sources["account_data"].get_current_key(),
+            room_key=self.sources.room.get_current_key(),
+            presence_key=self.sources.presence.get_current_key(),
+            typing_key=self.sources.typing.get_current_key(),
+            receipt_key=self.sources.receipt.get_current_key(),
+            account_data_key=self.sources.account_data.get_current_key(),
             push_rules_key=push_rules_key,
             to_device_key=to_device_key,
             device_list_key=device_list_key,
@@ -67,7 +78,7 @@ class EventSources:
             The current token for pagination.
         """
         token = StreamToken(
-            room_key=self.sources["room"].get_current_key(),
+            room_key=self.sources.room.get_current_key(),
             presence_key=0,
             typing_key=0,
             receipt_key=0,
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 732a12c9bd..5de89c873b 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -23,7 +23,7 @@ from tests import unittest
 
 class ReceiptsTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
-        self.event_source = hs.get_event_sources().sources["receipt"]
+        self.event_source = hs.get_event_sources().sources.receipt
 
     # In the first param of _test_filters_hidden we use "hidden" instead of
     # ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index fa3cff598e..000f9b9fde 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -89,7 +89,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.handler = hs.get_typing_handler()
 
-        self.event_source = hs.get_event_sources().sources["typing"]
+        self.event_source = hs.get_event_sources().sources.typing
 
         self.datastore = hs.get_datastore()
         self.datastore.get_destination_retry_timings = Mock(
@@ -171,7 +171,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 1)
         events = self.get_success(
-            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+            self.event_source.get_new_events(
+                user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+            )
         )
         self.assertEquals(
             events[0],
@@ -239,7 +241,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 1)
         events = self.get_success(
-            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+            self.event_source.get_new_events(
+                user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+            )
         )
         self.assertEquals(
             events[0],
@@ -276,7 +280,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 0)
         events = self.get_success(
-            self.event_source.get_new_events(room_ids=[OTHER_ROOM_ID], from_key=0)
+            self.event_source.get_new_events(
+                user=U_APPLE,
+                from_key=0,
+                limit=None,
+                room_ids=[OTHER_ROOM_ID],
+                is_guest=False,
+            )
         )
         self.assertEquals(events[0], [])
         self.assertEquals(events[1], 0)
@@ -324,7 +334,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 1)
         events = self.get_success(
-            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+            self.event_source.get_new_events(
+                user=U_APPLE, from_key=0, limit=None, room_ids=[ROOM_ID], is_guest=False
+            )
         )
         self.assertEquals(
             events[0],
@@ -350,7 +362,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 1)
         events = self.get_success(
-            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+            self.event_source.get_new_events(
+                user=U_APPLE,
+                from_key=0,
+                limit=None,
+                room_ids=[ROOM_ID],
+                is_guest=False,
+            )
         )
         self.assertEquals(
             events[0],
@@ -369,7 +387,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 2)
         events = self.get_success(
-            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+            self.event_source.get_new_events(
+                user=U_APPLE,
+                from_key=1,
+                limit=None,
+                room_ids=[ROOM_ID],
+                is_guest=False,
+            )
         )
         self.assertEquals(
             events[0],
@@ -392,7 +416,13 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 3)
         events = self.get_success(
-            self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+            self.event_source.get_new_events(
+                user=U_APPLE,
+                from_key=0,
+                limit=None,
+                room_ids=[ROOM_ID],
+                is_guest=False,
+            )
         )
         self.assertEquals(
             events[0],
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index 6a0d9a82be..b0c44af033 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -193,7 +193,7 @@ class RoomTestCase(_ShadowBannedBase):
         self.assertEquals(200, channel.code)
 
         # There should be no typing events.
-        event_source = self.hs.get_event_sources().sources["typing"]
+        event_source = self.hs.get_event_sources().sources.typing
         self.assertEquals(event_source.get_current_key(), 0)
 
         # The other user can join and send typing events.
@@ -210,7 +210,13 @@ class RoomTestCase(_ShadowBannedBase):
         # These appear in the room.
         self.assertEquals(event_source.get_current_key(), 1)
         events = self.get_success(
-            event_source.get_new_events(from_key=0, room_ids=[room_id])
+            event_source.get_new_events(
+                user=UserID.from_string(self.other_user_id),
+                from_key=0,
+                limit=None,
+                room_ids=[room_id],
+                is_guest=False,
+            )
         )
         self.assertEquals(
             events[0],
diff --git a/tests/rest/client/test_typing.py b/tests/rest/client/test_typing.py
index b54b004733..ee0abd5295 100644
--- a/tests/rest/client/test_typing.py
+++ b/tests/rest/client/test_typing.py
@@ -41,7 +41,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
             federation_client=Mock(),
         )
 
-        self.event_source = hs.get_event_sources().sources["typing"]
+        self.event_source = hs.get_event_sources().sources.typing
 
         hs.get_federation_handler = Mock()
 
@@ -76,7 +76,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 1)
         events = self.get_success(
-            self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+            self.event_source.get_new_events(
+                user=UserID.from_string(self.user_id),
+                from_key=0,
+                limit=None,
+                room_ids=[self.room_id],
+                is_guest=False,
+            )
         )
         self.assertEquals(
             events[0],