diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 9d5310851c..1ca77519d5 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -12,10 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Callable
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Requester
+from synapse.module_api import ModuleApi
+from synapse.types import Requester, StateMap
class ThirdPartyEventRules:
@@ -38,7 +40,7 @@ class ThirdPartyEventRules:
if module is not None:
self.third_party_rules = module(
- config=config, http_client=hs.get_simple_http_client()
+ config=config, module_api=ModuleApi(hs, hs.get_auth_handler()),
)
async def check_event_allowed(
@@ -106,6 +108,48 @@ class ThirdPartyEventRules:
if self.third_party_rules is None:
return True
+ state_events = await self._get_state_map_for_room(room_id)
+
+ ret = await self.third_party_rules.check_threepid_can_be_invited(
+ medium, address, state_events
+ )
+ return ret
+
+ async def check_visibility_can_be_modified(
+ self, room_id: str, new_visibility: str
+ ) -> bool:
+ """Check if a room is allowed to be published to, or removed from, the public room
+ list.
+
+ Args:
+ room_id: The ID of the room.
+ new_visibility: The new visibility state. Either "public" or "private".
+
+ Returns:
+ True if the room's visibility can be modified, False if not.
+ """
+ if self.third_party_rules is None:
+ return True
+
+ check_func = getattr(
+ self.third_party_rules, "check_visibility_can_be_modified", None
+ )
+ if not check_func or not isinstance(check_func, Callable):
+ return True
+
+ state_events = await self._get_state_map_for_room(room_id)
+
+ return await check_func(room_id, state_events, new_visibility)
+
+ async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
+ """Given a room ID, return the state events of that room.
+
+ Args:
+ room_id: The ID of the room.
+
+ Returns:
+ A dict mapping (event type, state key) to state event.
+ """
state_ids = await self.store.get_filtered_current_state_ids(room_id)
room_state_events = await self.store.get_events(state_ids.values())
@@ -113,7 +157,4 @@ class ThirdPartyEventRules:
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]
- ret = await self.third_party_rules.check_threepid_can_be_invited(
- medium, address, state_events
- )
- return ret
+ return state_events
|