summary refs log tree commit diff
path: root/synapse/events
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events')
-rw-r--r--synapse/events/third_party_rules.py51
1 files changed, 45 insertions, 6 deletions
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 9d5310851c..fed459198a 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -12,10 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Callable
 
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.types import Requester
+from synapse.module_api import ModuleApi
+from synapse.types import Requester, StateMap
 
 
 class ThirdPartyEventRules:
@@ -38,7 +40,7 @@ class ThirdPartyEventRules:
 
         if module is not None:
             self.third_party_rules = module(
-                config=config, http_client=hs.get_simple_http_client()
+                config=config, module_api=ModuleApi(hs, hs.get_auth_handler()),
             )
 
     async def check_event_allowed(
@@ -106,6 +108,46 @@ class ThirdPartyEventRules:
         if self.third_party_rules is None:
             return True
 
+        state_events = await self._get_state_map_for_room(room_id)
+
+        ret = await self.third_party_rules.check_threepid_can_be_invited(
+            medium, address, state_events
+        )
+        return ret
+
+    async def check_visibility_can_be_modified(
+        self, room_id: str, new_visibility: str
+    ) -> bool:
+        """Check if a room is allowed to be published to, or removed from, the public room
+        list.
+
+        Args:
+            room_id: The ID of the room.
+            new_visibility: The new visibility state. Either "public" or "private".
+
+        Returns:
+            True if the room's visibility can be modified, False if not.
+        """
+        if self.third_party_rules is None:
+            return True
+
+        check_func = getattr(self.third_party_rules, "check_visibility_can_be_modified")
+        if not check_func or not isinstance(check_func, Callable):
+            return True
+
+        state_events = await self._get_state_map_for_room(room_id)
+
+        return await check_func(room_id, state_events, new_visibility)
+
+    async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
+        """Given a room ID, return the state events of that room.
+
+        Args:
+            room_id: The ID of the room.
+
+        Returns:
+            A dict mapping (event type, state key) to state event.
+        """
         state_ids = await self.store.get_filtered_current_state_ids(room_id)
         room_state_events = await self.store.get_events(state_ids.values())
 
@@ -113,7 +155,4 @@ class ThirdPartyEventRules:
         for key, event_id in state_ids.items():
             state_events[key] = room_state_events[event_id]
 
-        ret = await self.third_party_rules.check_threepid_can_be_invited(
-            medium, address, state_events
-        )
-        return ret
+        return state_events