diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 9d5310851c..1535cc5339 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Callable
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.types import Requester
+from synapse.types import Requester, StateMap
class ThirdPartyEventRules:
@@ -38,7 +39,7 @@ class ThirdPartyEventRules:
if module is not None:
self.third_party_rules = module(
- config=config, http_client=hs.get_simple_http_client()
+ config=config, module_api=hs.get_module_api(),
)
async def check_event_allowed(
@@ -59,12 +60,14 @@ class ThirdPartyEventRules:
prev_state_ids = await context.get_prev_state_ids()
# Retrieve the state events from the database.
- state_events = {}
- for key, event_id in prev_state_ids.items():
- state_events[key] = await self.store.get_event(event_id, allow_none=True)
+ events = await self.store.get_events(prev_state_ids.values())
+ state_events = {(ev.type, ev.state_key): ev for ev in events.values()}
- ret = await self.third_party_rules.check_event_allowed(event, state_events)
- return ret
+ # The module can modify the event slightly if it wants, but caution should be
+ # exercised, and it's likely to go very wrong if applied to events received over
+ # federation.
+
+ return await self.third_party_rules.check_event_allowed(event, state_events)
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
@@ -106,6 +109,48 @@ class ThirdPartyEventRules:
if self.third_party_rules is None:
return True
+ state_events = await self._get_state_map_for_room(room_id)
+
+ ret = await self.third_party_rules.check_threepid_can_be_invited(
+ medium, address, state_events
+ )
+ return ret
+
+ async def check_visibility_can_be_modified(
+ self, room_id: str, new_visibility: str
+ ) -> bool:
+ """Check if a room is allowed to be published to, or removed from, the public room
+ list.
+
+ Args:
+ room_id: The ID of the room.
+ new_visibility: The new visibility state. Either "public" or "private".
+
+ Returns:
+ True if the room's visibility can be modified, False if not.
+ """
+ if self.third_party_rules is None:
+ return True
+
+ check_func = getattr(
+ self.third_party_rules, "check_visibility_can_be_modified", None
+ )
+ if not check_func or not isinstance(check_func, Callable):
+ return True
+
+ state_events = await self._get_state_map_for_room(room_id)
+
+ return await check_func(room_id, state_events, new_visibility)
+
+ async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
+ """Given a room ID, return the state events of that room.
+
+ Args:
+ room_id: The ID of the room.
+
+ Returns:
+ A dict mapping (event type, state key) to state event.
+ """
state_ids = await self.store.get_filtered_current_state_ids(room_id)
room_state_events = await self.store.get_events(state_ids.values())
@@ -113,7 +158,4 @@ class ThirdPartyEventRules:
for key, event_id in state_ids.items():
state_events[key] = room_state_events[event_id]
- ret = await self.third_party_rules.check_threepid_can_be_invited(
- medium, address, state_events
- )
- return ret
+ return state_events
|