summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-08-13 11:21:07 +0100
committerErik Johnston <erik@matrix.org>2020-09-02 16:15:23 +0100
commit34bc0bec988a115f6959b588a996869cba6a88a6 (patch)
tree0a74eb03ab407dbfd3f377351e07e3b79d7b5480
parentConvert stuff (diff)
downloadsynapse-erikj/event_token_type.tar.xz
Change return type of persist_events github/erikj/event_token_type erikj/event_token_type
-rw-r--r--synapse/handlers/_base.py6
-rw-r--r--synapse/handlers/federation.py26
-rw-r--r--synapse/handlers/message.py23
-rw-r--r--synapse/handlers/room.py16
-rw-r--r--synapse/handlers/room_member.py24
-rw-r--r--synapse/handlers/room_member_worker.py8
-rw-r--r--synapse/push/pusherpool.py2
-rw-r--r--synapse/replication/http/federation.py4
-rw-r--r--synapse/replication/http/membership.py4
-rw-r--r--synapse/replication/http/send_event.py4
-rw-r--r--synapse/replication/tcp/client.py8
-rw-r--r--synapse/types.py6
12 files changed, 73 insertions, 58 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index ba2bf99800..a4fee37dd9 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 import synapse.state
 import synapse.storage
@@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.types import UserID
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -30,7 +34,7 @@ class BaseHandler(object):
     Common base class for the event handlers.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         """
         Args:
             hs (synapse.server.HomeServer):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index a0aa884ae5..0e75029f0b 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1276,7 +1276,7 @@ class FederationHandler(BaseHandler):
 
     async def do_invite_join(
         self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
-    ) -> Tuple[str, int]:
+    ) -> Tuple[str, EventStreamToken]:
         """ Attempts to join the `joinee` to the room `room_id` via the
         servers contained in `target_hosts`.
 
@@ -1373,7 +1373,7 @@ class FederationHandler(BaseHandler):
             await self._replication.wait_for_stream_position(
                 self.config.worker.events_shard_config.get_instance(room_id),
                 "events",
-                max_stream_id,
+                max_stream_id.stream,
             )
 
             # Check whether this room is the result of an upgrade of a room we already know
@@ -1965,7 +1965,7 @@ class FederationHandler(BaseHandler):
         state: List[EventBase],
         event: EventBase,
         room_version: RoomVersion,
-    ) -> int:
+    ) -> EventStreamToken:
         """Checks the auth chain is valid (and passes auth checks) for the
         state and event. Then persists the auth chain and state atomically.
         Persists the event separately. Notifies about the persisted events
@@ -2917,7 +2917,7 @@ class FederationHandler(BaseHandler):
         room_id: str,
         event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
         backfilled: bool = False,
-    ) -> int:
+    ) -> EventStreamToken:
         """Persists events and tells the notifier/pushers about them, if
         necessary.
 
@@ -2938,9 +2938,9 @@ class FederationHandler(BaseHandler):
                 event_and_contexts=event_and_contexts,
                 backfilled=backfilled,
             )
-            return result["max_stream_id"]
+            return EventStreamToken.parse(result["max_stream_id"])
         else:
-            max_stream_id = await self.storage.persistence.persist_events(
+            max_stream_token = await self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
@@ -2951,12 +2951,12 @@ class FederationHandler(BaseHandler):
 
             if not backfilled:  # Never notify for backfilled events
                 for event, _ in event_and_contexts:
-                    await self._notify_persisted_event(event, max_stream_id)
+                    await self._notify_persisted_event(event, max_stream_token)
 
-            return max_stream_id
+            return max_stream_token
 
     async def _notify_persisted_event(
-        self, event: EventBase, max_stream_id: int
+        self, event: EventBase, max_stream_token: EventStreamToken,
     ) -> None:
         """Checks to see if notifier/pushers should be notified about the
         event or not.
@@ -2982,12 +2982,14 @@ class FederationHandler(BaseHandler):
         elif event.internal_metadata.is_outlier():
             return
 
-        event_stream_id = event.internal_metadata.stream_ordering
+        event_stream_token = EventStreamToken(event.internal_metadata.stream_ordering)
         self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id, extra_users=extra_users
+            event, event_stream_token, max_stream_token, extra_users=extra_users
         )
 
-        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
+        await self.pusher_pool.on_new_notifications(
+            event_stream_token.stream, max_stream_token.stream
+        )
 
     async def _clean_room_for_join(self, room_id: str) -> None:
         """Called to clean up any data in DB for a given room, ready for the
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 00c98f1d83..8652b50518 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -922,9 +922,9 @@ class EventCreationHandler(object):
                     ratelimit=ratelimit,
                     extra_users=extra_users,
                 )
-                stream_id = result["stream_id"]
-                event.internal_metadata.stream_ordering = stream_id
-                return EventStreamToken(stream_id)
+                stream_token = EventStreamToken.parse(result["stream_token"])
+                event.internal_metadata.stream_ordering = stream_token.stream
+                return stream_token
 
             stream_token = await self.persist_and_notify_client_event(
                 requester, event, context, ratelimit=ratelimit, extra_users=extra_users
@@ -975,7 +975,7 @@ class EventCreationHandler(object):
         context: EventContext,
         ratelimit: bool = True,
         extra_users: List[UserID] = [],
-    ) -> int:
+    ) -> EventStreamToken:
         """Called when we have fully built the event, have already
         calculated the push actions for the event, and checked auth.
 
@@ -1146,20 +1146,23 @@ class EventCreationHandler(object):
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")
 
-        event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
-            event, context=context
-        )
+        (
+            event_stream_token,
+            max_stream_token,
+        ) = await self.storage.persistence.persist_event(event, context=context)
 
         if self._ephemeral_events_enabled:
             # If there's an expiry timestamp on the event, schedule its expiry.
             self._message_handler.maybe_schedule_expiry(event)
 
-        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
+        await self.pusher_pool.on_new_notifications(
+            event_stream_token.stream, max_stream_token.stream
+        )
 
         def _notify():
             try:
                 self.notifier.on_new_room_event(
-                    event, event_stream_id, max_stream_id, extra_users=extra_users
+                    event, event_stream_token, max_stream_token, extra_users=extra_users
                 )
             except Exception:
                 logger.exception("Error notifying about new room event")
@@ -1171,7 +1174,7 @@ class EventCreationHandler(object):
             # matters as sometimes presence code can take a while.
             run_in_background(self._bump_active_time, requester.user)
 
-        return event_stream_id
+        return event_stream_token
 
     async def _bump_active_time(self, user: UserID) -> None:
         try:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 23ebd6a2ea..df5e75124a 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -559,7 +559,7 @@ class RoomCreationHandler(BaseHandler):
         config: JsonDict,
         ratelimit: bool = True,
         creator_join_profile: Optional[JsonDict] = None,
-    ) -> Tuple[dict, int]:
+    ) -> Tuple[dict, EventStreamToken]:
         """ Creates a new room.
 
         Args:
@@ -806,7 +806,7 @@ class RoomCreationHandler(BaseHandler):
         await self._replication.wait_for_stream_position(
             self.hs.config.worker.events_shard_config.get_instance(room_id),
             "events",
-            last_stream_id,
+            last_stream_id.stream,
         )
 
         return result, last_stream_id
@@ -822,7 +822,7 @@ class RoomCreationHandler(BaseHandler):
         room_alias: Optional[RoomAlias] = None,
         power_level_content_override: Optional[JsonDict] = None,
         creator_join_profile: Optional[JsonDict] = None,
-    ) -> int:
+    ) -> EventStreamToken:
         """Sends the initial events into a new room.
 
         `power_level_content_override` doesn't apply when initial state has
@@ -844,7 +844,7 @@ class RoomCreationHandler(BaseHandler):
 
             return e
 
-        async def send(etype: str, content: JsonDict, **kwargs) -> int:
+        async def send(etype: str, content: JsonDict, **kwargs) -> EventStreamToken:
             event = create(etype, content, **kwargs)
             logger.debug("Sending %s in new room", etype)
             # Allow these events to be sent even if the user is shadow-banned to
@@ -1240,7 +1240,7 @@ class RoomShutdownHandler(object):
 
             room_creator_requester = create_requester(new_room_user_id)
 
-            info, stream_id = await self._room_creation_handler.create_room(
+            info, stream_token = await self._room_creation_handler.create_room(
                 room_creator_requester,
                 config={
                     "preset": RoomCreationPreset.PUBLIC_CHAT,
@@ -1261,7 +1261,7 @@ class RoomShutdownHandler(object):
             await self._replication.wait_for_stream_position(
                 self.hs.config.worker.events_shard_config.get_instance(new_room_id),
                 "events",
-                stream_id,
+                stream_token.stream,
             )
         else:
             new_room_id = None
@@ -1279,7 +1279,7 @@ class RoomShutdownHandler(object):
             try:
                 # Kick users from room
                 target_requester = create_requester(user_id)
-                _, stream_id = await self.room_member_handler.update_membership(
+                _, stream_token = await self.room_member_handler.update_membership(
                     requester=target_requester,
                     target=target_requester.user,
                     room_id=room_id,
@@ -1293,7 +1293,7 @@ class RoomShutdownHandler(object):
                 await self._replication.wait_for_stream_position(
                     self.hs.config.worker.events_shard_config.get_instance(room_id),
                     "events",
-                    stream_id,
+                    stream_token.stream,
                 )
 
                 await self.room_member_handler.forget(target_requester.user, room_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0c325b358e..67ed742989 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -114,7 +114,7 @@ class RoomMemberHandler(object):
         room_id: str,
         user: UserID,
         content: dict,
-    ) -> Tuple[str, int]:
+    ) -> Tuple[str, EventStreamToken]:
         """Try and join a room that this server is not in
 
         Args:
@@ -532,11 +532,11 @@ class RoomMemberHandler(object):
                 if requester.is_guest:
                     content["kind"] = "guest"
 
-                remote_join_response = await self._remote_join(
+                event_id, stream_token = await self._remote_join(
                     requester, remote_room_hosts, room_id, target, content
                 )
 
-                return remote_join_response
+                return event_id, stream_token
 
         elif effective_membership_state == Membership.LEAVE:
             if not is_host_in_room:
@@ -809,7 +809,7 @@ class RoomMemberHandler(object):
         requester: Requester,
         txn_id: Optional[str],
         id_access_token: Optional[str] = None,
-    ) -> int:
+    ) -> EventStreamToken:
         """Invite a 3PID to a room.
 
         Args:
@@ -867,11 +867,11 @@ class RoomMemberHandler(object):
         if invitee:
             # Note that update_membership with an action of "invite" can raise
             # a ShadowBanError, but this was done above already.
-            _, stream_id = await self.update_membership(
+            _, stream_token = await self.update_membership(
                 requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
             )
         else:
-            stream_id = await self._make_and_store_3pid_invite(
+            stream_token = await self._make_and_store_3pid_invite(
                 requester,
                 id_server,
                 medium,
@@ -882,7 +882,7 @@ class RoomMemberHandler(object):
                 id_access_token=id_access_token,
             )
 
-        return stream_id
+        return stream_token
 
     async def _make_and_store_3pid_invite(
         self,
@@ -894,7 +894,7 @@ class RoomMemberHandler(object):
         user: UserID,
         txn_id: Optional[str],
         id_access_token: Optional[str] = None,
-    ) -> int:
+    ) -> EventStreamToken:
         room_state = await self.state_handler.get_current_state(room_id)
 
         inviter_display_name = ""
@@ -1050,7 +1050,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         room_id: str,
         user: UserID,
         content: dict,
-    ) -> Tuple[str, int]:
+    ) -> Tuple[str, EventStreamToken]:
         """Implements RoomMemberHandler._remote_join
         """
         # filter ourselves out of remote_room_hosts: do_invite_join ignores it
@@ -1158,7 +1158,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         txn_id: Optional[str],
         requester: Requester,
         content: JsonDict,
-    ) -> Tuple[str, int]:
+    ) -> Tuple[str, EventStreamToken]:
         """Generate a local invite rejection
 
         This is called after we fail to reject an invite via a remote server. It
@@ -1224,10 +1224,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
         context = await self.state_handler.compute_event_context(event)
         context.app_service = requester.app_service
-        stream_id = await self.event_creation_handler.handle_new_client_event(
+        stream_token = await self.event_creation_handler.handle_new_client_event(
             requester, event, context, extra_users=[UserID.from_string(target_user)],
         )
-        return event.event_id, stream_id
+        return event.event_id, stream_token
 
     async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_joined_room
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index d1db8eb2d9..cc7ab5d01a 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -43,7 +43,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         room_id: str,
         user: UserID,
         content: dict,
-    ) -> Tuple[str, int]:
+    ) -> Tuple[str, EventStreamToken]:
         """Implements RoomMemberHandler._remote_join
         """
         if len(remote_room_hosts) == 0:
@@ -59,7 +59,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
 
         await self._user_joined_room(user, room_id)
 
-        return ret["event_id"], ret["stream_id"]
+        return ret["event_id"], EventStreamToken.parse(ret["stream_id"])
 
     async def remote_reject_invite(
         self,
@@ -67,7 +67,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         txn_id: Optional[str],
         requester: Requester,
         content: dict,
-    ) -> Tuple[str, int]:
+    ) -> Tuple[str, EventStreamToken]:
         """
         Rejects an out-of-band invite received from a remote user
 
@@ -79,7 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
             requester=requester,
             content=content,
         )
-        return ret["event_id"], EventStreamToken(ret["stream_id"])
+        return ret["event_id"], EventStreamToken.parse(ret["stream_id"])
 
     async def _user_joined_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_joined_room
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 3c3262a88c..15546790a5 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -178,7 +178,7 @@ class PusherPool:
                 )
                 await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
 
-    async def on_new_notifications(self, min_stream_id, max_stream_id):
+    async def on_new_notifications(self, min_stream_id: int, max_stream_id: int):
         if not self.pushers:
             # nothing to do here.
             return
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 5c8be747e1..d48e4fef82 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -125,11 +125,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
 
         logger.info("Got %d events from federation", len(event_and_contexts))
 
-        max_stream_id = await self.federation_handler.persist_events_and_notify(
+        max_stream_token = await self.federation_handler.persist_events_and_notify(
             room_id, event_and_contexts, backfilled
         )
 
-        return 200, {"max_stream_id": max_stream_id}
+        return 200, {"max_stream_id": str(max_stream_token)}
 
 
 class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index 6a39ffa95d..3c91863fae 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -86,7 +86,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
             remote_room_hosts, room_id, user_id, event_content
         )
 
-        return 200, {"event_id": event_id, "stream_id": stream_id}
+        return 200, {"event_id": event_id, "stream_id": str(stream_id)}
 
 
 class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
@@ -150,7 +150,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
             invite_event_id, txn_id, requester, event_content,
         )
 
-        return 200, {"event_id": event_id, "stream_id": stream_token.stream}
+        return 200, {"event_id": event_id, "stream_id": str(stream_token)}
 
 
 class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index f13d452426..84c1d13dd4 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -116,11 +116,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
         )
 
-        stream_id = await self.event_creation_handler.persist_and_notify_client_event(
+        stream_token = await self.event_creation_handler.persist_and_notify_client_event(
             requester, event, context, ratelimit=ratelimit, extra_users=extra_users
         )
 
-        return 200, {"stream_id": stream_id}
+        return 200, {"stream_token": str(stream_token)}
 
 
 def register_servlets(hs, http_server):
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index d6ecf5b327..07d31d457b 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -29,6 +29,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
+from synapse.types import EventStreamToken
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -152,7 +153,12 @@ class ReplicationDataHandler:
                 if event.type == EventTypes.Member:
                     extra_users = (event.state_key,)
                 max_token = self.store.get_room_max_stream_ordering()
-                self.notifier.on_new_room_event(event, token, max_token, extra_users)
+                self.notifier.on_new_room_event(
+                    event,
+                    EventStreamToken(token),
+                    EventStreamToken(max_token),
+                    extra_users,
+                )
 
             await self.pusher_pool.on_new_notifications(token, token)
 
diff --git a/synapse/types.py b/synapse/types.py
index 1525e22cac..3be9faa16b 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -446,7 +446,7 @@ class EventStreamToken:
     stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
 
     @classmethod
-    def parse(cls, string):
+    def parse(cls, string: str) -> "EventStreamToken":
         try:
             if string[0] == "s":
                 return cls(topological=None, stream=int(string[1:]))
@@ -458,7 +458,7 @@ class EventStreamToken:
         raise SynapseError(400, "Invalid token %r" % (string,))
 
     @classmethod
-    def parse_stream_token(cls, string):
+    def parse_stream_token(cls, string: str) -> "EventStreamToken":
         try:
             if string[0] == "s":
                 return cls(topological=None, stream=int(string[1:]))
@@ -466,7 +466,7 @@ class EventStreamToken:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))
 
-    def __str__(self):
+    def __str__(self) -> str:
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
         else: