From 34bc0bec988a115f6959b588a996869cba6a88a6 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 13 Aug 2020 11:21:07 +0100 Subject: Change return type of persist_events --- synapse/handlers/_base.py | 6 +++++- synapse/handlers/federation.py | 26 ++++++++++++++------------ synapse/handlers/message.py | 23 +++++++++++++---------- synapse/handlers/room.py | 16 ++++++++-------- synapse/handlers/room_member.py | 24 ++++++++++++------------ synapse/handlers/room_member_worker.py | 8 ++++---- synapse/push/pusherpool.py | 2 +- synapse/replication/http/federation.py | 4 ++-- synapse/replication/http/membership.py | 4 ++-- synapse/replication/http/send_event.py | 4 ++-- synapse/replication/tcp/client.py | 8 +++++++- synapse/types.py | 6 +++--- 12 files changed, 73 insertions(+), 58 deletions(-) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py index ba2bf99800..a4fee37dd9 100644 --- a/synapse/handlers/_base.py +++ b/synapse/handlers/_base.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING import synapse.state import synapse.storage @@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership from synapse.api.ratelimiting import Ratelimiter from synapse.types import UserID +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -30,7 +34,7 @@ class BaseHandler(object): Common base class for the event handlers. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): """ Args: hs (synapse.server.HomeServer): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index a0aa884ae5..0e75029f0b 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1276,7 +1276,7 @@ class FederationHandler(BaseHandler): async def do_invite_join( self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict - ) -> Tuple[str, int]: + ) -> Tuple[str, EventStreamToken]: """ Attempts to join the `joinee` to the room `room_id` via the servers contained in `target_hosts`. @@ -1373,7 +1373,7 @@ class FederationHandler(BaseHandler): await self._replication.wait_for_stream_position( self.config.worker.events_shard_config.get_instance(room_id), "events", - max_stream_id, + max_stream_id.stream, ) # Check whether this room is the result of an upgrade of a room we already know @@ -1965,7 +1965,7 @@ class FederationHandler(BaseHandler): state: List[EventBase], event: EventBase, room_version: RoomVersion, - ) -> int: + ) -> EventStreamToken: """Checks the auth chain is valid (and passes auth checks) for the state and event. Then persists the auth chain and state atomically. Persists the event separately. Notifies about the persisted events @@ -2917,7 +2917,7 @@ class FederationHandler(BaseHandler): room_id: str, event_and_contexts: Sequence[Tuple[EventBase, EventContext]], backfilled: bool = False, - ) -> int: + ) -> EventStreamToken: """Persists events and tells the notifier/pushers about them, if necessary. @@ -2938,9 +2938,9 @@ class FederationHandler(BaseHandler): event_and_contexts=event_and_contexts, backfilled=backfilled, ) - return result["max_stream_id"] + return EventStreamToken.parse(result["max_stream_id"]) else: - max_stream_id = await self.storage.persistence.persist_events( + max_stream_token = await self.storage.persistence.persist_events( event_and_contexts, backfilled=backfilled ) @@ -2951,12 +2951,12 @@ class FederationHandler(BaseHandler): if not backfilled: # Never notify for backfilled events for event, _ in event_and_contexts: - await self._notify_persisted_event(event, max_stream_id) + await self._notify_persisted_event(event, max_stream_token) - return max_stream_id + return max_stream_token async def _notify_persisted_event( - self, event: EventBase, max_stream_id: int + self, event: EventBase, max_stream_token: EventStreamToken, ) -> None: """Checks to see if notifier/pushers should be notified about the event or not. @@ -2982,12 +2982,14 @@ class FederationHandler(BaseHandler): elif event.internal_metadata.is_outlier(): return - event_stream_id = event.internal_metadata.stream_ordering + event_stream_token = EventStreamToken(event.internal_metadata.stream_ordering) self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users + event, event_stream_token, max_stream_token, extra_users=extra_users ) - await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications( + event_stream_token.stream, max_stream_token.stream + ) async def _clean_room_for_join(self, room_id: str) -> None: """Called to clean up any data in DB for a given room, ready for the diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 00c98f1d83..8652b50518 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -922,9 +922,9 @@ class EventCreationHandler(object): ratelimit=ratelimit, extra_users=extra_users, ) - stream_id = result["stream_id"] - event.internal_metadata.stream_ordering = stream_id - return EventStreamToken(stream_id) + stream_token = EventStreamToken.parse(result["stream_token"]) + event.internal_metadata.stream_ordering = stream_token.stream + return stream_token stream_token = await self.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users @@ -975,7 +975,7 @@ class EventCreationHandler(object): context: EventContext, ratelimit: bool = True, extra_users: List[UserID] = [], - ) -> int: + ) -> EventStreamToken: """Called when we have fully built the event, have already calculated the push actions for the event, and checked auth. @@ -1146,20 +1146,23 @@ class EventCreationHandler(object): if prev_state_ids: raise AuthError(403, "Changing the room create event is forbidden") - event_stream_id, max_stream_id = await self.storage.persistence.persist_event( - event, context=context - ) + ( + event_stream_token, + max_stream_token, + ) = await self.storage.persistence.persist_event(event, context=context) if self._ephemeral_events_enabled: # If there's an expiry timestamp on the event, schedule its expiry. self._message_handler.maybe_schedule_expiry(event) - await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id) + await self.pusher_pool.on_new_notifications( + event_stream_token.stream, max_stream_token.stream + ) def _notify(): try: self.notifier.on_new_room_event( - event, event_stream_id, max_stream_id, extra_users=extra_users + event, event_stream_token, max_stream_token, extra_users=extra_users ) except Exception: logger.exception("Error notifying about new room event") @@ -1171,7 +1174,7 @@ class EventCreationHandler(object): # matters as sometimes presence code can take a while. run_in_background(self._bump_active_time, requester.user) - return event_stream_id + return event_stream_token async def _bump_active_time(self, user: UserID) -> None: try: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 23ebd6a2ea..df5e75124a 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -559,7 +559,7 @@ class RoomCreationHandler(BaseHandler): config: JsonDict, ratelimit: bool = True, creator_join_profile: Optional[JsonDict] = None, - ) -> Tuple[dict, int]: + ) -> Tuple[dict, EventStreamToken]: """ Creates a new room. Args: @@ -806,7 +806,7 @@ class RoomCreationHandler(BaseHandler): await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(room_id), "events", - last_stream_id, + last_stream_id.stream, ) return result, last_stream_id @@ -822,7 +822,7 @@ class RoomCreationHandler(BaseHandler): room_alias: Optional[RoomAlias] = None, power_level_content_override: Optional[JsonDict] = None, creator_join_profile: Optional[JsonDict] = None, - ) -> int: + ) -> EventStreamToken: """Sends the initial events into a new room. `power_level_content_override` doesn't apply when initial state has @@ -844,7 +844,7 @@ class RoomCreationHandler(BaseHandler): return e - async def send(etype: str, content: JsonDict, **kwargs) -> int: + async def send(etype: str, content: JsonDict, **kwargs) -> EventStreamToken: event = create(etype, content, **kwargs) logger.debug("Sending %s in new room", etype) # Allow these events to be sent even if the user is shadow-banned to @@ -1240,7 +1240,7 @@ class RoomShutdownHandler(object): room_creator_requester = create_requester(new_room_user_id) - info, stream_id = await self._room_creation_handler.create_room( + info, stream_token = await self._room_creation_handler.create_room( room_creator_requester, config={ "preset": RoomCreationPreset.PUBLIC_CHAT, @@ -1261,7 +1261,7 @@ class RoomShutdownHandler(object): await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(new_room_id), "events", - stream_id, + stream_token.stream, ) else: new_room_id = None @@ -1279,7 +1279,7 @@ class RoomShutdownHandler(object): try: # Kick users from room target_requester = create_requester(user_id) - _, stream_id = await self.room_member_handler.update_membership( + _, stream_token = await self.room_member_handler.update_membership( requester=target_requester, target=target_requester.user, room_id=room_id, @@ -1293,7 +1293,7 @@ class RoomShutdownHandler(object): await self._replication.wait_for_stream_position( self.hs.config.worker.events_shard_config.get_instance(room_id), "events", - stream_id, + stream_token.stream, ) await self.room_member_handler.forget(target_requester.user, room_id) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0c325b358e..67ed742989 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -114,7 +114,7 @@ class RoomMemberHandler(object): room_id: str, user: UserID, content: dict, - ) -> Tuple[str, int]: + ) -> Tuple[str, EventStreamToken]: """Try and join a room that this server is not in Args: @@ -532,11 +532,11 @@ class RoomMemberHandler(object): if requester.is_guest: content["kind"] = "guest" - remote_join_response = await self._remote_join( + event_id, stream_token = await self._remote_join( requester, remote_room_hosts, room_id, target, content ) - return remote_join_response + return event_id, stream_token elif effective_membership_state == Membership.LEAVE: if not is_host_in_room: @@ -809,7 +809,7 @@ class RoomMemberHandler(object): requester: Requester, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> int: + ) -> EventStreamToken: """Invite a 3PID to a room. Args: @@ -867,11 +867,11 @@ class RoomMemberHandler(object): if invitee: # Note that update_membership with an action of "invite" can raise # a ShadowBanError, but this was done above already. - _, stream_id = await self.update_membership( + _, stream_token = await self.update_membership( requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id ) else: - stream_id = await self._make_and_store_3pid_invite( + stream_token = await self._make_and_store_3pid_invite( requester, id_server, medium, @@ -882,7 +882,7 @@ class RoomMemberHandler(object): id_access_token=id_access_token, ) - return stream_id + return stream_token async def _make_and_store_3pid_invite( self, @@ -894,7 +894,7 @@ class RoomMemberHandler(object): user: UserID, txn_id: Optional[str], id_access_token: Optional[str] = None, - ) -> int: + ) -> EventStreamToken: room_state = await self.state_handler.get_current_state(room_id) inviter_display_name = "" @@ -1050,7 +1050,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> Tuple[str, int]: + ) -> Tuple[str, EventStreamToken]: """Implements RoomMemberHandler._remote_join """ # filter ourselves out of remote_room_hosts: do_invite_join ignores it @@ -1158,7 +1158,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): txn_id: Optional[str], requester: Requester, content: JsonDict, - ) -> Tuple[str, int]: + ) -> Tuple[str, EventStreamToken]: """Generate a local invite rejection This is called after we fail to reject an invite via a remote server. It @@ -1224,10 +1224,10 @@ class RoomMemberMasterHandler(RoomMemberHandler): context = await self.state_handler.compute_event_context(event) context.app_service = requester.app_service - stream_id = await self.event_creation_handler.handle_new_client_event( + stream_token = await self.event_creation_handler.handle_new_client_event( requester, event, context, extra_users=[UserID.from_string(target_user)], ) - return event.event_id, stream_id + return event.event_id, stream_token async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py index d1db8eb2d9..cc7ab5d01a 100644 --- a/synapse/handlers/room_member_worker.py +++ b/synapse/handlers/room_member_worker.py @@ -43,7 +43,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): room_id: str, user: UserID, content: dict, - ) -> Tuple[str, int]: + ) -> Tuple[str, EventStreamToken]: """Implements RoomMemberHandler._remote_join """ if len(remote_room_hosts) == 0: @@ -59,7 +59,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): await self._user_joined_room(user, room_id) - return ret["event_id"], ret["stream_id"] + return ret["event_id"], EventStreamToken.parse(ret["stream_id"]) async def remote_reject_invite( self, @@ -67,7 +67,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): txn_id: Optional[str], requester: Requester, content: dict, - ) -> Tuple[str, int]: + ) -> Tuple[str, EventStreamToken]: """ Rejects an out-of-band invite received from a remote user @@ -79,7 +79,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler): requester=requester, content=content, ) - return ret["event_id"], EventStreamToken(ret["stream_id"]) + return ret["event_id"], EventStreamToken.parse(ret["stream_id"]) async def _user_joined_room(self, target: UserID, room_id: str) -> None: """Implements RoomMemberHandler._user_joined_room diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 3c3262a88c..15546790a5 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -178,7 +178,7 @@ class PusherPool: ) await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"]) - async def on_new_notifications(self, min_stream_id, max_stream_id): + async def on_new_notifications(self, min_stream_id: int, max_stream_id: int): if not self.pushers: # nothing to do here. return diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py index 5c8be747e1..d48e4fef82 100644 --- a/synapse/replication/http/federation.py +++ b/synapse/replication/http/federation.py @@ -125,11 +125,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): logger.info("Got %d events from federation", len(event_and_contexts)) - max_stream_id = await self.federation_handler.persist_events_and_notify( + max_stream_token = await self.federation_handler.persist_events_and_notify( room_id, event_and_contexts, backfilled ) - return 200, {"max_stream_id": max_stream_id} + return 200, {"max_stream_id": str(max_stream_token)} class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py index 6a39ffa95d..3c91863fae 100644 --- a/synapse/replication/http/membership.py +++ b/synapse/replication/http/membership.py @@ -86,7 +86,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): remote_room_hosts, room_id, user_id, event_content ) - return 200, {"event_id": event_id, "stream_id": stream_id} + return 200, {"event_id": event_id, "stream_id": str(stream_id)} class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): @@ -150,7 +150,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint): invite_event_id, txn_id, requester, event_content, ) - return 200, {"event_id": event_id, "stream_id": stream_token.stream} + return 200, {"event_id": event_id, "stream_id": str(stream_token)} class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py index f13d452426..84c1d13dd4 100644 --- a/synapse/replication/http/send_event.py +++ b/synapse/replication/http/send_event.py @@ -116,11 +116,11 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): "Got event to send with ID: %s into room: %s", event.event_id, event.room_id ) - stream_id = await self.event_creation_handler.persist_and_notify_client_event( + stream_token = await self.event_creation_handler.persist_and_notify_client_event( requester, event, context, ratelimit=ratelimit, extra_users=extra_users ) - return 200, {"stream_id": stream_id} + return 200, {"stream_token": str(stream_token)} def register_servlets(hs, http_server): diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d6ecf5b327..07d31d457b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -29,6 +29,7 @@ from synapse.replication.tcp.streams.events import ( EventsStreamEventRow, EventsStreamRow, ) +from synapse.types import EventStreamToken from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -152,7 +153,12 @@ class ReplicationDataHandler: if event.type == EventTypes.Member: extra_users = (event.state_key,) max_token = self.store.get_room_max_stream_ordering() - self.notifier.on_new_room_event(event, token, max_token, extra_users) + self.notifier.on_new_room_event( + event, + EventStreamToken(token), + EventStreamToken(max_token), + extra_users, + ) await self.pusher_pool.on_new_notifications(token, token) diff --git a/synapse/types.py b/synapse/types.py index 1525e22cac..3be9faa16b 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -446,7 +446,7 @@ class EventStreamToken: stream = attr.ib(type=int, validator=attr.validators.instance_of(int)) @classmethod - def parse(cls, string): + def parse(cls, string: str) -> "EventStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -458,7 +458,7 @@ class EventStreamToken: raise SynapseError(400, "Invalid token %r" % (string,)) @classmethod - def parse_stream_token(cls, string): + def parse_stream_token(cls, string: str) -> "EventStreamToken": try: if string[0] == "s": return cls(topological=None, stream=int(string[1:])) @@ -466,7 +466,7 @@ class EventStreamToken: pass raise SynapseError(400, "Invalid token %r" % (string,)) - def __str__(self): + def __str__(self) -> str: if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) else: -- cgit 1.4.1