diff options
Diffstat (limited to 'synapse/events/third_party_rules.py')
-rw-r--r-- | synapse/events/third_party_rules.py | 64 |
1 files changed, 53 insertions, 11 deletions
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 9d5310851c..1535cc5339 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -12,10 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Callable from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import Requester +from synapse.types import Requester, StateMap class ThirdPartyEventRules: @@ -38,7 +39,7 @@ class ThirdPartyEventRules: if module is not None: self.third_party_rules = module( - config=config, http_client=hs.get_simple_http_client() + config=config, module_api=hs.get_module_api(), ) async def check_event_allowed( @@ -59,12 +60,14 @@ class ThirdPartyEventRules: prev_state_ids = await context.get_prev_state_ids() # Retrieve the state events from the database. - state_events = {} - for key, event_id in prev_state_ids.items(): - state_events[key] = await self.store.get_event(event_id, allow_none=True) + events = await self.store.get_events(prev_state_ids.values()) + state_events = {(ev.type, ev.state_key): ev for ev in events.values()} - ret = await self.third_party_rules.check_event_allowed(event, state_events) - return ret + # The module can modify the event slightly if it wants, but caution should be + # exercised, and it's likely to go very wrong if applied to events received over + # federation. + + return await self.third_party_rules.check_event_allowed(event, state_events) async def on_create_room( self, requester: Requester, config: dict, is_requester_admin: bool @@ -106,6 +109,48 @@ class ThirdPartyEventRules: if self.third_party_rules is None: return True + state_events = await self._get_state_map_for_room(room_id) + + ret = await self.third_party_rules.check_threepid_can_be_invited( + medium, address, state_events + ) + return ret + + async def check_visibility_can_be_modified( + self, room_id: str, new_visibility: str + ) -> bool: + """Check if a room is allowed to be published to, or removed from, the public room + list. + + Args: + room_id: The ID of the room. + new_visibility: The new visibility state. Either "public" or "private". + + Returns: + True if the room's visibility can be modified, False if not. + """ + if self.third_party_rules is None: + return True + + check_func = getattr( + self.third_party_rules, "check_visibility_can_be_modified", None + ) + if not check_func or not isinstance(check_func, Callable): + return True + + state_events = await self._get_state_map_for_room(room_id) + + return await check_func(room_id, state_events, new_visibility) + + async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]: + """Given a room ID, return the state events of that room. + + Args: + room_id: The ID of the room. + + Returns: + A dict mapping (event type, state key) to state event. + """ state_ids = await self.store.get_filtered_current_state_ids(room_id) room_state_events = await self.store.get_events(state_ids.values()) @@ -113,7 +158,4 @@ class ThirdPartyEventRules: for key, event_id in state_ids.items(): state_events[key] = room_state_events[event_id] - ret = await self.third_party_rules.check_threepid_can_be_invited( - medium, address, state_events - ) - return ret + return state_events |