summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15240.misc1
-rw-r--r--synapse/federation/sender/per_destination_queue.py2
-rw-r--r--synapse/handlers/federation.py29
-rw-r--r--synapse/visibility.py67
-rw-r--r--tests/test_visibility.py40
5 files changed, 109 insertions, 30 deletions
diff --git a/changelog.d/15240.misc b/changelog.d/15240.misc
new file mode 100644
index 0000000000..2b7edf916e
--- /dev/null
+++ b/changelog.d/15240.misc
@@ -0,0 +1 @@
+Refactor `filter_events_for_server`.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index ffc9d95ee7..478187ce44 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -547,6 +547,8 @@ class PerDestinationQueue:
                         self._server_name,
                         new_pdus,
                         redact=False,
+                        filter_out_erased_senders=True,
+                        filter_out_remote_partial_state_events=True,
                     )
 
                     # If we've filtered out all the extremities, fall back to
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5f2057269d..80156ef343 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -392,7 +392,7 @@ class FederationHandler:
                 get_prev_content=False,
             )
 
-            # We set `check_history_visibility_only` as we might otherwise get false
+            # We unset `filter_out_erased_senders` as we might otherwise get false
             # positives from users having been erased.
             filtered_extremities = await filter_events_for_server(
                 self._storage_controllers,
@@ -400,7 +400,8 @@ class FederationHandler:
                 self.server_name,
                 events_to_check,
                 redact=False,
-                check_history_visibility_only=True,
+                filter_out_erased_senders=False,
+                filter_out_remote_partial_state_events=False,
             )
             if filtered_extremities:
                 extremities_to_request.append(bp.event_id)
@@ -1331,7 +1332,13 @@ class FederationHandler:
         )
 
         events = await filter_events_for_server(
-            self._storage_controllers, origin, self.server_name, events
+            self._storage_controllers,
+            origin,
+            self.server_name,
+            events,
+            redact=True,
+            filter_out_erased_senders=True,
+            filter_out_remote_partial_state_events=True,
         )
 
         return events
@@ -1362,7 +1369,13 @@ class FederationHandler:
         await self._event_auth_handler.assert_host_in_room(event.room_id, origin)
 
         events = await filter_events_for_server(
-            self._storage_controllers, origin, self.server_name, [event]
+            self._storage_controllers,
+            origin,
+            self.server_name,
+            [event],
+            redact=True,
+            filter_out_erased_senders=True,
+            filter_out_remote_partial_state_events=True,
         )
         event = events[0]
         return event
@@ -1390,7 +1403,13 @@ class FederationHandler:
         )
 
         missing_events = await filter_events_for_server(
-            self._storage_controllers, origin, self.server_name, missing_events
+            self._storage_controllers,
+            origin,
+            self.server_name,
+            missing_events,
+            redact=True,
+            filter_out_erased_senders=True,
+            filter_out_remote_partial_state_events=True,
         )
 
         return missing_events
diff --git a/synapse/visibility.py b/synapse/visibility.py
index e442de3173..468e22f8f6 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -14,7 +14,17 @@
 # limitations under the License.
 import logging
 from enum import Enum, auto
-from typing import Collection, Dict, FrozenSet, List, Optional, Tuple
+from typing import (
+    Collection,
+    Dict,
+    FrozenSet,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 from typing_extensions import Final
@@ -565,29 +575,43 @@ async def filter_events_for_server(
     storage: StorageControllers,
     target_server_name: str,
     local_server_name: str,
-    events: List[EventBase],
-    redact: bool = True,
-    check_history_visibility_only: bool = False,
+    events: Sequence[EventBase],
+    *,
+    redact: bool,
+    filter_out_erased_senders: bool,
+    filter_out_remote_partial_state_events: bool,
 ) -> List[EventBase]:
-    """Filter a list of events based on whether given server is allowed to
+    """Filter a list of events based on whether the target server is allowed to
     see them.
 
+    For a fully stated room, the target server is allowed to see an event E if:
+      - the state at E has world readable or shared history vis, OR
+      - the state at E says that the target server is in the room.
+
+    For a partially stated room, the target server is allowed to see E if:
+      - E was created by this homeserver, AND:
+          - the partial state at E has world readable or shared history vis, OR
+          - the partial state at E says that the target server is in the room.
+
+    TODO: state before or state after?
+
     Args:
         storage
-        server_name
+        target_server_name
+        local_server_name
         events
-        redact: Whether to return a redacted version of the event, or
-            to filter them out entirely.
-        check_history_visibility_only: Whether to only check the
-            history visibility, rather than things like if the sender has been
+        redact: Controls what to do with events which have been filtered out.
+            If True, include their redacted forms; if False, omit them entirely.
+        filter_out_erased_senders: If true, also filter out events whose sender has been
             erased. This is used e.g. during pagination to decide whether to
             backfill or not.
-
+        filter_out_remote_partial_state_events: If True, also filter out events in
+            partial state rooms created by other homeservers.
     Returns
         The filtered events.
     """
 
-    def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool:
+    def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool:
         if erased_senders and erased_senders[event.sender]:
             logger.info("Sender of %s has been erased, redacting", event.event_id)
             return True
@@ -616,7 +640,7 @@ async def filter_events_for_server(
         # server has no users in the room: redact
         return False
 
-    if not check_history_visibility_only:
+    if filter_out_erased_senders:
         erased_senders = await storage.main.are_users_erased(e.sender for e in events)
     else:
         # We don't want to check whether users are erased, which is equivalent
@@ -631,15 +655,15 @@ async def filter_events_for_server(
     # otherwise a room could be fully joined after we retrieve those, which would then bypass
     # this check but would base the filtering on an outdated view of the membership events.
 
-    partial_state_invisible_events = set()
-    if not check_history_visibility_only:
+    partial_state_invisible_event_ids: Set[str] = set()
+    if filter_out_remote_partial_state_events:
         for e in events:
             sender_domain = get_domain_from_id(e.sender)
             if (
                 sender_domain != local_server_name
                 and await storage.main.is_partial_state_room(e.room_id)
             ):
-                partial_state_invisible_events.add(e)
+                partial_state_invisible_event_ids.add(e.event_id)
 
     # Let's check to see if all the events have a history visibility
     # of "shared" or "world_readable". If that's the case then we don't
@@ -658,17 +682,20 @@ async def filter_events_for_server(
         target_server_name,
     )
 
-    to_return = []
-    for e in events:
+    def include_event_in_output(e: EventBase) -> bool:
         erased = is_sender_erased(e, erased_senders)
         visible = check_event_is_visible(
             event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
         )
 
-        if e in partial_state_invisible_events:
+        if e.event_id in partial_state_invisible_event_ids:
             visible = False
 
-        if visible and not erased:
+        return visible and not erased
+
+    to_return = []
+    for e in events:
+        if include_event_in_output(e):
             to_return.append(e)
         elif redact:
             to_return.append(prune_event(e))
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 2801a950a8..9ed330f554 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -63,7 +63,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "test_server", "hs", events_to_filter
+                self._storage_controllers,
+                "test_server",
+                "hs",
+                events_to_filter,
+                redact=True,
+                filter_out_erased_senders=True,
+                filter_out_remote_partial_state_events=True,
             )
         )
 
@@ -85,7 +91,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         self.assertEqual(
             self.get_success(
                 filter_events_for_server(
-                    self._storage_controllers, "remote_hs", "hs", [outlier]
+                    self._storage_controllers,
+                    "remote_hs",
+                    "hs",
+                    [outlier],
+                    redact=True,
+                    filter_out_erased_senders=True,
+                    filter_out_remote_partial_state_events=True,
                 )
             ),
             [outlier],
@@ -96,7 +108,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "remote_hs", "local_hs", [outlier, evt]
+                self._storage_controllers,
+                "remote_hs",
+                "local_hs",
+                [outlier, evt],
+                redact=True,
+                filter_out_erased_senders=True,
+                filter_out_remote_partial_state_events=True,
             )
         )
         self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
@@ -108,7 +126,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # be redacted)
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "other_server", "local_hs", [outlier, evt]
+                self._storage_controllers,
+                "other_server",
+                "local_hs",
+                [outlier, evt],
+                redact=True,
+                filter_out_erased_senders=True,
+                filter_out_remote_partial_state_events=True,
             )
         )
         self.assertEqual(filtered[0], outlier)
@@ -143,7 +167,13 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # ... and the filtering happens.
         filtered = self.get_success(
             filter_events_for_server(
-                self._storage_controllers, "test_server", "local_hs", events_to_filter
+                self._storage_controllers,
+                "test_server",
+                "local_hs",
+                events_to_filter,
+                redact=True,
+                filter_out_erased_senders=True,
+                filter_out_remote_partial_state_events=True,
             )
         )