summary refs log tree commit diff
path: root/synapse/events
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events')
-rw-r--r--synapse/events/__init__.py40
-rw-r--r--synapse/events/builder.py16
-rw-r--r--synapse/events/spamcheck.py4
-rw-r--r--synapse/events/third_party_rules.py245
4 files changed, 240 insertions, 65 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 6286ad999a..0298af4c02 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -105,28 +105,28 @@ class _EventInternalMetadata:
         self._dict = dict(internal_metadata_dict)
 
         # the stream ordering of this event. None, until it has been persisted.
-        self.stream_ordering = None  # type: Optional[int]
+        self.stream_ordering: Optional[int] = None
 
         # whether this event is an outlier (ie, whether we have the state at that point
         # in the DAG)
         self.outlier = False
 
-    out_of_band_membership = DictProperty("out_of_band_membership")  # type: bool
-    send_on_behalf_of = DictProperty("send_on_behalf_of")  # type: str
-    recheck_redaction = DictProperty("recheck_redaction")  # type: bool
-    soft_failed = DictProperty("soft_failed")  # type: bool
-    proactively_send = DictProperty("proactively_send")  # type: bool
-    redacted = DictProperty("redacted")  # type: bool
-    txn_id = DictProperty("txn_id")  # type: str
-    token_id = DictProperty("token_id")  # type: int
-    historical = DictProperty("historical")  # type: bool
+    out_of_band_membership: bool = DictProperty("out_of_band_membership")
+    send_on_behalf_of: str = DictProperty("send_on_behalf_of")
+    recheck_redaction: bool = DictProperty("recheck_redaction")
+    soft_failed: bool = DictProperty("soft_failed")
+    proactively_send: bool = DictProperty("proactively_send")
+    redacted: bool = DictProperty("redacted")
+    txn_id: str = DictProperty("txn_id")
+    token_id: int = DictProperty("token_id")
+    historical: bool = DictProperty("historical")
 
     # XXX: These are set by StreamWorkerStore._set_before_and_after.
     # I'm pretty sure that these are never persisted to the database, so shouldn't
     # be here
-    before = DictProperty("before")  # type: RoomStreamToken
-    after = DictProperty("after")  # type: RoomStreamToken
-    order = DictProperty("order")  # type: Tuple[int, int]
+    before: RoomStreamToken = DictProperty("before")
+    after: RoomStreamToken = DictProperty("after")
+    order: Tuple[int, int] = DictProperty("order")
 
     def get_dict(self) -> JsonDict:
         return dict(self._dict)
@@ -291,6 +291,20 @@ class EventBase(metaclass=abc.ABCMeta):
 
         return pdu_json
 
+    def get_templated_pdu_json(self) -> JsonDict:
+        """
+        Return a JSON object suitable for a templated event, as used in the
+        make_{join,leave,knock} workflow.
+        """
+        # By using _dict directly we don't pull in signatures/unsigned.
+        template_json = dict(self._dict)
+        # The hashes (similar to the signature) need to be recalculated by the
+        # joining/leaving/knocking server after (potentially) modifying the
+        # event.
+        template_json.pop("hashes")
+
+        return template_json
+
     def __set__(self, instance, value):
         raise AttributeError("Unrecognized attribute %s" % (instance,))
 
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index 26e3950859..87e2bb123b 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -132,12 +132,12 @@ class EventBuilder:
         format_version = self.room_version.event_format
         if format_version == EventFormatVersions.V1:
             # The types of auth/prev events changes between event versions.
-            auth_events = await self._store.add_event_hashes(
-                auth_event_ids
-            )  # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
-            prev_events = await self._store.add_event_hashes(
-                prev_event_ids
-            )  # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
+            auth_events: Union[
+                List[str], List[Tuple[str, Dict[str, str]]]
+            ] = await self._store.add_event_hashes(auth_event_ids)
+            prev_events: Union[
+                List[str], List[Tuple[str, Dict[str, str]]]
+            ] = await self._store.add_event_hashes(prev_event_ids)
         else:
             auth_events = auth_event_ids
             prev_events = prev_event_ids
@@ -156,7 +156,7 @@ class EventBuilder:
         # the db)
         depth = min(depth, MAX_DEPTH)
 
-        event_dict = {
+        event_dict: Dict[str, Any] = {
             "auth_events": auth_events,
             "prev_events": prev_events,
             "type": self.type,
@@ -166,7 +166,7 @@ class EventBuilder:
             "unsigned": self.unsigned,
             "depth": depth,
             "prev_state": [],
-        }  # type: Dict[str, Any]
+        }
 
         if self.is_state():
             event_dict["state_key"] = self._state_key
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index efec16c226..57f1d53fa8 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -76,7 +76,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"):
     """Wrapper that loads spam checkers configured using the old configuration, and
     registers the spam checker hooks they implement.
     """
-    spam_checkers = []  # type: List[Any]
+    spam_checkers: List[Any] = []
     api = hs.get_module_api()
     for module, config in hs.config.spam_checkers:
         # Older spam checkers don't accept the `api` argument, so we
@@ -239,7 +239,7 @@ class SpamChecker:
             will be used as the error message returned to the user.
         """
         for callback in self._check_event_for_spam_callbacks:
-            res = await callback(event)  # type: Union[bool, str]
+            res: Union[bool, str] = await callback(event)
             if res:
                 return res
 
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index f7944fd834..7a6eb3e516 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -11,16 +11,124 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
 
-from typing import TYPE_CHECKING, Union
-
+from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import Requester, StateMap
+from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
+logger = logging.getLogger(__name__)
+
+
+CHECK_EVENT_ALLOWED_CALLBACK = Callable[
+    [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
+]
+ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
+CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
+    [str, str, StateMap[EventBase]], Awaitable[bool]
+]
+CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
+    [str, StateMap[EventBase], str], Awaitable[bool]
+]
+
+
+def load_legacy_third_party_event_rules(hs: "HomeServer"):
+    """Wrapper that loads a third party event rules module configured using the old
+    configuration, and registers the hooks they implement.
+    """
+    if hs.config.third_party_event_rules is None:
+        return
+
+    module, config = hs.config.third_party_event_rules
+
+    api = hs.get_module_api()
+    third_party_rules = module(config=config, module_api=api)
+
+    # The known hooks. If a module implements a method which name appears in this set,
+    # we'll want to register it.
+    third_party_event_rules_methods = {
+        "check_event_allowed",
+        "on_create_room",
+        "check_threepid_can_be_invited",
+        "check_visibility_can_be_modified",
+    }
+
+    def async_wrapper(f: Optional[Callable]) -> Optional[Callable[..., Awaitable]]:
+        # f might be None if the callback isn't implemented by the module. In this
+        # case we don't want to register a callback at all so we return None.
+        if f is None:
+            return None
+
+        # We return a separate wrapper for these methods because, in order to wrap them
+        # correctly, we need to await its result. Therefore it doesn't make a lot of
+        # sense to make it go through the run() wrapper.
+        if f.__name__ == "check_event_allowed":
+
+            # We need to wrap check_event_allowed because its old form would return either
+            # a boolean or a dict, but now we want to return the dict separately from the
+            # boolean.
+            async def wrap_check_event_allowed(
+                event: EventBase,
+                state_events: StateMap[EventBase],
+            ) -> Tuple[bool, Optional[dict]]:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                res = await f(event, state_events)
+                if isinstance(res, dict):
+                    return True, res
+                else:
+                    return res, None
+
+            return wrap_check_event_allowed
+
+        if f.__name__ == "on_create_room":
+
+            # We need to wrap on_create_room because its old form would return a boolean
+            # if the room creation is denied, but now we just want it to raise an
+            # exception.
+            async def wrap_on_create_room(
+                requester: Requester, config: dict, is_requester_admin: bool
+            ) -> None:
+                # We've already made sure f is not None above, but mypy doesn't do well
+                # across function boundaries so we need to tell it f is definitely not
+                # None.
+                assert f is not None
+
+                res = await f(requester, config, is_requester_admin)
+                if res is False:
+                    raise SynapseError(
+                        403,
+                        "Room creation forbidden with these parameters",
+                    )
+
+            return wrap_on_create_room
+
+        def run(*args, **kwargs):
+            # mypy doesn't do well across function boundaries so we need to tell it
+            # f is definitely not None.
+            assert f is not None
+
+            return maybe_awaitable(f(*args, **kwargs))
+
+        return run
+
+    # Register the hooks through the module API.
+    hooks = {
+        hook: async_wrapper(getattr(third_party_rules, hook, None))
+        for hook in third_party_event_rules_methods
+    }
+
+    api.register_third_party_rules_callbacks(**hooks)
+
 
 class ThirdPartyEventRules:
     """Allows server admins to provide a Python module implementing an extra
@@ -35,36 +143,65 @@ class ThirdPartyEventRules:
 
         self.store = hs.get_datastore()
 
-        module = None
-        config = None
-        if hs.config.third_party_event_rules:
-            module, config = hs.config.third_party_event_rules
+        self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
+        self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
+        self._check_threepid_can_be_invited_callbacks: List[
+            CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
+        ] = []
+        self._check_visibility_can_be_modified_callbacks: List[
+            CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
+        ] = []
+
+    def register_third_party_rules_callbacks(
+        self,
+        check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
+        on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
+        check_threepid_can_be_invited: Optional[
+            CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
+        ] = None,
+        check_visibility_can_be_modified: Optional[
+            CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
+        ] = None,
+    ):
+        """Register callbacks from modules for each hook."""
+        if check_event_allowed is not None:
+            self._check_event_allowed_callbacks.append(check_event_allowed)
+
+        if on_create_room is not None:
+            self._on_create_room_callbacks.append(on_create_room)
+
+        if check_threepid_can_be_invited is not None:
+            self._check_threepid_can_be_invited_callbacks.append(
+                check_threepid_can_be_invited,
+            )
 
-        if module is not None:
-            self.third_party_rules = module(
-                config=config,
-                module_api=hs.get_module_api(),
+        if check_visibility_can_be_modified is not None:
+            self._check_visibility_can_be_modified_callbacks.append(
+                check_visibility_can_be_modified,
             )
 
     async def check_event_allowed(
         self, event: EventBase, context: EventContext
-    ) -> Union[bool, dict]:
+    ) -> Tuple[bool, Optional[dict]]:
         """Check if a provided event should be allowed in the given context.
 
         The module can return:
             * True: the event is allowed.
             * False: the event is not allowed, and should be rejected with M_FORBIDDEN.
-            * a dict: replacement event data.
+
+        If the event is allowed, the module can also return a dictionary to use as a
+        replacement for the event.
 
         Args:
             event: The event to be checked.
             context: The context of the event.
 
         Returns:
-            The result from the ThirdPartyRules module, as above
+            The result from the ThirdPartyRules module, as above.
         """
-        if self.third_party_rules is None:
-            return True
+        # Bail out early without hitting the store if we don't have any callbacks to run.
+        if len(self._check_event_allowed_callbacks) == 0:
+            return True, None
 
         prev_state_ids = await context.get_prev_state_ids()
 
@@ -77,29 +214,46 @@ class ThirdPartyEventRules:
         # the hashes and signatures.
         event.freeze()
 
-        return await self.third_party_rules.check_event_allowed(event, state_events)
+        for callback in self._check_event_allowed_callbacks:
+            try:
+                res, replacement_data = await callback(event, state_events)
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+                continue
+
+            # Return if the event shouldn't be allowed or if the module came up with a
+            # replacement dict for the event.
+            if res is False:
+                return res, None
+            elif isinstance(replacement_data, dict):
+                return True, replacement_data
+
+        return True, None
 
     async def on_create_room(
         self, requester: Requester, config: dict, is_requester_admin: bool
-    ) -> bool:
-        """Intercept requests to create room to allow, deny or update the
-        request config.
+    ) -> None:
+        """Intercept requests to create room to maybe deny it (via an exception) or
+        update the request config.
 
         Args:
             requester
             config: The creation config from the client.
             is_requester_admin: If the requester is an admin
-
-        Returns:
-            Whether room creation is allowed or denied.
         """
-
-        if self.third_party_rules is None:
-            return True
-
-        return await self.third_party_rules.on_create_room(
-            requester, config, is_requester_admin
-        )
+        for callback in self._on_create_room_callbacks:
+            try:
+                await callback(requester, config, is_requester_admin)
+            except Exception as e:
+                # Don't silence the errors raised by this callback since we expect it to
+                # raise an exception to deny the creation of the room; instead make sure
+                # it's a SynapseError we can send to clients.
+                if not isinstance(e, SynapseError):
+                    e = SynapseError(
+                        403, "Room creation forbidden with these parameters"
+                    )
+
+                raise e
 
     async def check_threepid_can_be_invited(
         self, medium: str, address: str, room_id: str
@@ -114,15 +268,20 @@ class ThirdPartyEventRules:
         Returns:
             True if the 3PID can be invited, False if not.
         """
-
-        if self.third_party_rules is None:
+        # Bail out early without hitting the store if we don't have any callbacks to run.
+        if len(self._check_threepid_can_be_invited_callbacks) == 0:
             return True
 
         state_events = await self._get_state_map_for_room(room_id)
 
-        return await self.third_party_rules.check_threepid_can_be_invited(
-            medium, address, state_events
-        )
+        for callback in self._check_threepid_can_be_invited_callbacks:
+            try:
+                if await callback(medium, address, state_events) is False:
+                    return False
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+
+        return True
 
     async def check_visibility_can_be_modified(
         self, room_id: str, new_visibility: str
@@ -137,18 +296,20 @@ class ThirdPartyEventRules:
         Returns:
             True if the room's visibility can be modified, False if not.
         """
-        if self.third_party_rules is None:
-            return True
-
-        check_func = getattr(
-            self.third_party_rules, "check_visibility_can_be_modified", None
-        )
-        if not check_func or not callable(check_func):
+        # Bail out early without hitting the store if we don't have any callback
+        if len(self._check_visibility_can_be_modified_callbacks) == 0:
             return True
 
         state_events = await self._get_state_map_for_room(room_id)
 
-        return await check_func(room_id, state_events, new_visibility)
+        for callback in self._check_visibility_can_be_modified_callbacks:
+            try:
+                if await callback(room_id, state_events, new_visibility) is False:
+                    return False
+            except Exception as e:
+                logger.warning("Failed to run module API callback %s: %s", callback, e)
+
+        return True
 
     async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
         """Given a room ID, return the state events of that room.