summary refs log tree commit diff
path: root/synapse/events/third_party_rules.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events/third_party_rules.py')
-rw-r--r--synapse/events/third_party_rules.py71
1 files changed, 67 insertions, 4 deletions
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 72ab696898..3e4d52c8d8 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -18,7 +18,7 @@ from twisted.internet.defer import CancelledError
 
 from synapse.api.errors import ModuleFailedException, SynapseError
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import UnpersistedEventContextBase
 from synapse.storage.roommember import ProfileInfo
 from synapse.types import Requester, StateMap
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
@@ -45,6 +45,8 @@ CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]]
 ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
 ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
 ON_THREEPID_BIND_CALLBACK = Callable[[str, str, str], Awaitable]
+ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
+ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK = Callable[[str, str, str], Awaitable]
 
 
 def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
@@ -78,7 +80,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
         # correctly, we need to await its result. Therefore it doesn't make a lot of
         # sense to make it go through the run() wrapper.
         if f.__name__ == "check_event_allowed":
-
             # We need to wrap check_event_allowed because its old form would return either
             # a boolean or a dict, but now we want to return the dict separately from the
             # boolean.
@@ -100,7 +101,6 @@ def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
             return wrap_check_event_allowed
 
         if f.__name__ == "on_create_room":
-
             # We need to wrap on_create_room because its old form would return a boolean
             # if the room creation is denied, but now we just want it to raise an
             # exception.
@@ -174,6 +174,12 @@ class ThirdPartyEventRules:
             ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
         ] = []
         self._on_threepid_bind_callbacks: List[ON_THREEPID_BIND_CALLBACK] = []
+        self._on_add_user_third_party_identifier_callbacks: List[
+            ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
+        ] = []
+        self._on_remove_user_third_party_identifier_callbacks: List[
+            ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
+        ] = []
 
     def register_third_party_rules_callbacks(
         self,
@@ -193,6 +199,12 @@ class ThirdPartyEventRules:
             ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
         ] = None,
         on_threepid_bind: Optional[ON_THREEPID_BIND_CALLBACK] = None,
+        on_add_user_third_party_identifier: Optional[
+            ON_ADD_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
+        ] = None,
+        on_remove_user_third_party_identifier: Optional[
+            ON_REMOVE_USER_THIRD_PARTY_IDENTIFIER_CALLBACK
+        ] = None,
     ) -> None:
         """Register callbacks from modules for each hook."""
         if check_event_allowed is not None:
@@ -230,8 +242,15 @@ class ThirdPartyEventRules:
         if on_threepid_bind is not None:
             self._on_threepid_bind_callbacks.append(on_threepid_bind)
 
+        if on_add_user_third_party_identifier is not None:
+            self._on_add_user_third_party_identifier_callbacks.append(
+                on_add_user_third_party_identifier
+            )
+
     async def check_event_allowed(
-        self, event: EventBase, context: EventContext
+        self,
+        event: EventBase,
+        context: UnpersistedEventContextBase,
     ) -> Tuple[bool, Optional[dict]]:
         """Check if a provided event should be allowed in the given context.
 
@@ -511,6 +530,9 @@ class ThirdPartyEventRules:
         local homeserver, not when it's created on an identity server (and then kept track
         of so that it can be unbound on the same IS later on).
 
+        THIS MODULE CALLBACK METHOD HAS BEEN DEPRECATED. Please use the
+        `on_add_user_third_party_identifier` callback method instead.
+
         Args:
             user_id: the user being associated with the threepid.
             medium: the threepid's medium.
@@ -523,3 +545,44 @@ class ThirdPartyEventRules:
                 logger.exception(
                     "Failed to run module API callback %s: %s", callback, e
                 )
+
+    async def on_add_user_third_party_identifier(
+        self, user_id: str, medium: str, address: str
+    ) -> None:
+        """Called when an association between a user's Matrix ID and a third-party ID
+        (email, phone number) has successfully been registered on the homeserver.
+
+        Args:
+            user_id: The User ID included in the association.
+            medium: The medium of the third-party ID (email, msisdn).
+            address: The address of the third-party ID (i.e. an email address).
+        """
+        for callback in self._on_add_user_third_party_identifier_callbacks:
+            try:
+                await callback(user_id, medium, address)
+            except Exception as e:
+                logger.exception(
+                    "Failed to run module API callback %s: %s", callback, e
+                )
+
+    async def on_remove_user_third_party_identifier(
+        self, user_id: str, medium: str, address: str
+    ) -> None:
+        """Called when an association between a user's Matrix ID and a third-party ID
+        (email, phone number) has been successfully removed on the homeserver.
+
+        This is called *after* any known bindings on identity servers for this
+        association have been removed.
+
+        Args:
+            user_id: The User ID included in the removed association.
+            medium: The medium of the third-party ID (email, msisdn).
+            address: The address of the third-party ID (i.e. an email address).
+        """
+        for callback in self._on_remove_user_third_party_identifier_callbacks:
+            try:
+                await callback(user_id, medium, address)
+            except Exception as e:
+                logger.exception(
+                    "Failed to run module API callback %s: %s", callback, e
+                )