summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8537.misc1
-rw-r--r--synapse/events/builder.py21
-rw-r--r--synapse/handlers/message.py30
-rw-r--r--synapse/handlers/room.py1
-rw-r--r--synapse/handlers/room_member.py59
-rw-r--r--tests/handlers/test_message.py1
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/replication/test_federation_sender_shard.py2
-rw-r--r--tests/storage/test_redaction.py4
9 files changed, 56 insertions, 65 deletions
diff --git a/changelog.d/8537.misc b/changelog.d/8537.misc
new file mode 100644
index 0000000000..26309b5b93
--- /dev/null
+++ b/changelog.d/8537.misc
@@ -0,0 +1 @@
+Factor out common code between `RoomMemberHandler._locally_reject_invite` and `EventCreationHandler.create_event`.
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index b6c47be646..df4f950fec 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -97,32 +97,37 @@ class EventBuilder:
     def is_state(self):
         return self._state_key is not None
 
-    async def build(self, prev_event_ids: List[str]) -> EventBase:
+    async def build(
+        self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
+    ) -> EventBase:
         """Transform into a fully signed and hashed event
 
         Args:
             prev_event_ids: The event IDs to use as the prev events
+            auth_event_ids: The event IDs to use as the auth events.
+                Should normally be set to None, which will cause them to be calculated
+                based on the room state at the prev_events.
 
         Returns:
             The signed and hashed event.
         """
-
-        state_ids = await self._state.get_current_state_ids(
-            self.room_id, prev_event_ids
-        )
-        auth_ids = self._auth.compute_auth_events(self, state_ids)
+        if auth_event_ids is None:
+            state_ids = await self._state.get_current_state_ids(
+                self.room_id, prev_event_ids
+            )
+            auth_event_ids = self._auth.compute_auth_events(self, state_ids)
 
         format_version = self.room_version.event_format
         if format_version == EventFormatVersions.V1:
             # The types of auth/prev events changes between event versions.
             auth_events = await self._store.add_event_hashes(
-                auth_ids
+                auth_event_ids
             )  # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
             prev_events = await self._store.add_event_hashes(
                 prev_event_ids
             )  # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
         else:
-            auth_events = auth_ids
+            auth_events = auth_event_ids
             prev_events = prev_event_ids
 
         old_depth = await self._store.get_max_depth_of(prev_event_ids)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c52e6824d3..f18f882596 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -437,9 +437,9 @@ class EventCreationHandler:
         self,
         requester: Requester,
         event_dict: dict,
-        token_id: Optional[str] = None,
         txn_id: Optional[str] = None,
         prev_event_ids: Optional[List[str]] = None,
+        auth_event_ids: Optional[List[str]] = None,
         require_consent: bool = True,
     ) -> Tuple[EventBase, EventContext]:
         """
@@ -453,13 +453,18 @@ class EventCreationHandler:
         Args:
             requester
             event_dict: An entire event
-            token_id
             txn_id
             prev_event_ids:
                 the forward extremities to use as the prev_events for the
                 new event.
 
                 If None, they will be requested from the database.
+
+            auth_event_ids:
+                The event ids to use as the auth_events for the new event.
+                Should normally be left as None, which will cause them to be calculated
+                based on the room state at the prev_events.
+
             require_consent: Whether to check if the requester has
                 consented to the privacy policy.
         Raises:
@@ -511,14 +516,17 @@ class EventCreationHandler:
         if require_consent and not is_exempt:
             await self.assert_accepted_privacy_policy(requester)
 
-        if token_id is not None:
-            builder.internal_metadata.token_id = token_id
+        if requester.access_token_id is not None:
+            builder.internal_metadata.token_id = requester.access_token_id
 
         if txn_id is not None:
             builder.internal_metadata.txn_id = txn_id
 
         event, context = await self.create_new_client_event(
-            builder=builder, requester=requester, prev_event_ids=prev_event_ids,
+            builder=builder,
+            requester=requester,
+            prev_event_ids=prev_event_ids,
+            auth_event_ids=auth_event_ids,
         )
 
         # In an ideal world we wouldn't need the second part of this condition. However,
@@ -726,7 +734,7 @@ class EventCreationHandler:
                     return event, event.internal_metadata.stream_ordering
 
             event, context = await self.create_event(
-                requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
+                requester, event_dict, txn_id=txn_id
             )
 
             assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
@@ -757,6 +765,7 @@ class EventCreationHandler:
         builder: EventBuilder,
         requester: Optional[Requester] = None,
         prev_event_ids: Optional[List[str]] = None,
+        auth_event_ids: Optional[List[str]] = None,
     ) -> Tuple[EventBase, EventContext]:
         """Create a new event for a local client
 
@@ -769,6 +778,11 @@ class EventCreationHandler:
 
                 If None, they will be requested from the database.
 
+            auth_event_ids:
+                The event ids to use as the auth_events for the new event.
+                Should normally be left as None, which will cause them to be calculated
+                based on the room state at the prev_events.
+
         Returns:
             Tuple of created event, context
         """
@@ -790,7 +804,9 @@ class EventCreationHandler:
             builder.type == EventTypes.Create or len(prev_event_ids) > 0
         ), "Attempting to create an event with no prev_events"
 
-        event = await builder.build(prev_event_ids=prev_event_ids)
+        event = await builder.build(
+            prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
+        )
         context = await self.state.compute_event_context(event)
         if requester:
             context.app_service = requester.app_service
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 93ed51063a..ec300d8877 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -214,7 +214,6 @@ class RoomCreationHandler(BaseHandler):
                     "replacement_room": new_room_id,
                 },
             },
-            token_id=requester.access_token_id,
         )
         old_room_version = await self.store.get_room_version_id(old_room_id)
         await self.auth.check_from_context(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 0080eeaf8d..ec784030e9 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,12 +17,10 @@ import abc
 import logging
 import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
-
-from unpaddedbase64 import encode_base64
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 from synapse import types
-from synapse.api.constants import MAX_DEPTH, AccountDataTypes, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -31,12 +29,8 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.api.room_versions import EventFormatVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.events import EventBase
-from synapse.events.builder import create_local_event_from_event_dict
 from synapse.events.snapshot import EventContext
-from synapse.events.validator import EventValidator
 from synapse.storage.roommember import RoomsForUser
 from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
 from synapse.util.async_helpers import Linearizer
@@ -193,7 +187,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 # For backwards compatibility:
                 "membership": membership,
             },
-            token_id=requester.access_token_id,
             txn_id=txn_id,
             prev_event_ids=prev_event_ids,
             require_consent=require_consent,
@@ -1133,31 +1126,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
         room_id = invite_event.room_id
         target_user = invite_event.state_key
-        room_version = await self.store.get_room_version(room_id)
 
         content["membership"] = Membership.LEAVE
 
-        # the auth events for the new event are the same as that of the invite, plus
-        # the invite itself.
-        #
-        # the prev_events are just the invite.
-        invite_hash = invite_event.event_id  # type: Union[str, Tuple]
-        if room_version.event_format == EventFormatVersions.V1:
-            alg, h = compute_event_reference_hash(invite_event)
-            invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
-
-        auth_events = tuple(invite_event.auth_events) + (invite_hash,)
-        prev_events = (invite_hash,)
-
-        # we cap depth of generated events, to ensure that they are not
-        # rejected by other servers (and so that they can be persisted in
-        # the db)
-        depth = min(invite_event.depth + 1, MAX_DEPTH)
-
         event_dict = {
-            "depth": depth,
-            "auth_events": auth_events,
-            "prev_events": prev_events,
             "type": EventTypes.Member,
             "room_id": room_id,
             "sender": target_user,
@@ -1165,24 +1137,23 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             "state_key": target_user,
         }
 
-        event = create_local_event_from_event_dict(
-            clock=self.clock,
-            hostname=self.hs.hostname,
-            signing_key=self.hs.signing_key,
-            room_version=room_version,
-            event_dict=event_dict,
+        # the auth events for the new event are the same as that of the invite, plus
+        # the invite itself.
+        #
+        # the prev_events are just the invite.
+        prev_event_ids = [invite_event.event_id]
+        auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
+
+        event, context = await self.event_creation_handler.create_event(
+            requester,
+            event_dict,
+            txn_id=txn_id,
+            prev_event_ids=prev_event_ids,
+            auth_event_ids=auth_event_ids,
         )
         event.internal_metadata.outlier = True
         event.internal_metadata.out_of_band_membership = True
-        if txn_id is not None:
-            event.internal_metadata.txn_id = txn_id
-        if requester.access_token_id is not None:
-            event.internal_metadata.token_id = requester.access_token_id
-
-        EventValidator().validate_new(event, self.config)
 
-        context = await self.state_handler.compute_event_context(event)
-        context.app_service = requester.app_service
         result_event = await self.event_creation_handler.handle_new_client_event(
             requester, event, context, extra_users=[UserID.from_string(target_user)],
         )
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 64e28bc639..9f6f21a6e2 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -66,7 +66,6 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
                     "sender": self.requester.user.to_string(),
                     "content": {"msgtype": "m.text", "body": random_string(5)},
                 },
-                token_id=self.token_id,
                 txn_id=txn_id,
             )
         )
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 914c82e7a8..8ed67640f8 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
             self.store.get_latest_event_ids_in_room(room_id)
         )
 
-        event = self.get_success(builder.build(prev_event_ids))
+        event = self.get_success(builder.build(prev_event_ids, None))
 
         self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
 
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 9c4a9c3563..779745ae9d 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -226,7 +226,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         }
 
         builder = factory.for_room_version(room_version, event_dict)
-        join_event = self.get_success(builder.build(prev_event_ids))
+        join_event = self.get_success(builder.build(prev_event_ids, None))
 
         self.get_success(federation.on_send_join_request(remote_server, join_event))
         self.replicate()
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index 1ea35d60c1..d4f9e809db 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -236,9 +236,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 self._event_id = event_id
 
             @defer.inlineCallbacks
-            def build(self, prev_event_ids):
+            def build(self, prev_event_ids, auth_event_ids):
                 built_event = yield defer.ensureDeferred(
-                    self._base_builder.build(prev_event_ids)
+                    self._base_builder.build(prev_event_ids, auth_event_ids)
                 )
 
                 built_event._event_id = self._event_id