summary refs log tree commit diff
path: root/synapse/storage/controllers/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/controllers/state.py')
-rw-r--r--synapse/storage/controllers/state.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 10d219c045..46957723a1 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -37,6 +37,7 @@ from synapse.storage.util.partial_state_events_tracker import (
     PartialCurrentStateTracker,
     PartialStateEventsTracker,
 )
+from synapse.synapse_rust.acl import ServerAclEvaluator
 from synapse.types import MutableStateMap, StateMap, get_domain_from_id
 from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer
@@ -501,6 +502,31 @@ class StateStorageController:
 
         return event.content.get("alias")
 
+    @cached()
+    async def get_server_acl_for_room(
+        self, room_id: str
+    ) -> Optional[ServerAclEvaluator]:
+        """Get the server ACL evaluator for room, if any
+
+        This does up-front parsing of the content to ignore bad data and pre-compile
+        regular expressions.
+
+        Args:
+            room_id: The room ID
+
+        Returns:
+            The server ACL evaluator, if any
+        """
+
+        acl_event = await self.get_current_state_event(
+            room_id, EventTypes.ServerACL, ""
+        )
+
+        if not acl_event:
+            return None
+
+        return server_acl_evaluator_from_event(acl_event)
+
     @trace
     @tag_args
     async def get_current_state_deltas(
@@ -760,3 +786,36 @@ class StateStorageController:
                 cache.state_group = object()
 
         return frozenset(cache.hosts_to_joined_users)
+
+
+def server_acl_evaluator_from_event(acl_event: EventBase) -> "ServerAclEvaluator":
+    """
+    Create a ServerAclEvaluator from a m.room.server_acl event's content.
+
+    This does up-front parsing of the content to ignore bad data. It then creates
+    the ServerAclEvaluator which will pre-compile regular expressions from the globs.
+    """
+
+    # first of all, parse if literal IPs are blocked.
+    allow_ip_literals = acl_event.content.get("allow_ip_literals", True)
+    if not isinstance(allow_ip_literals, bool):
+        logger.warning("Ignoring non-bool allow_ip_literals flag")
+        allow_ip_literals = True
+
+    # next, parse the deny list by ignoring any non-strings.
+    deny = acl_event.content.get("deny", [])
+    if not isinstance(deny, (list, tuple)):
+        logger.warning("Ignoring non-list deny ACL %s", deny)
+        deny = []
+    else:
+        deny = [s for s in deny if isinstance(s, str)]
+
+    # then the allow list.
+    allow = acl_event.content.get("allow", [])
+    if not isinstance(allow, (list, tuple)):
+        logger.warning("Ignoring non-list allow ACL %s", allow)
+        allow = []
+    else:
+        allow = [s for s in allow if isinstance(s, str)]
+
+    return ServerAclEvaluator(allow_ip_literals, allow, deny)