summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2022-03-31 18:39:34 +0100
committerGitHub <noreply@github.com>2022-03-31 17:39:34 +0000
commit6927d8725430356880e212edf0c61bf32cb071c3 (patch)
treedc4d94f3262e26e2b552d7d2b9f7f91922643249
parentAdd more type hints to the main state store. (#12267) (diff)
downloadsynapse-6927d8725430356880e212edf0c61bf32cb071c3.tar.xz
Handle outliers in `/federation/v1/event` (#12332)
The intention here is to avoid doing state lookups for outliers in
`/_matrix/federation/v1/event`. Unfortunately that's expanded into something of
a rewrite of `filter_events_for_server`, which ended up trying to do that
operation in a couple of places.
Diffstat (limited to '')
-rw-r--r--changelog.d/12332.misc1
-rw-r--r--synapse/visibility.py234
-rw-r--r--tests/test_visibility.py53
3 files changed, 182 insertions, 106 deletions
diff --git a/changelog.d/12332.misc b/changelog.d/12332.misc
new file mode 100644
index 0000000000..9f333e718a
--- /dev/null
+++ b/changelog.d/12332.misc
@@ -0,0 +1 @@
+Avoid trying to calculate the state at outlier events.
diff --git a/synapse/visibility.py b/synapse/visibility.py
index 49519eb8f5..250f073597 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -1,4 +1,5 @@
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright (C) The Matrix.org Foundation C.I.C. 2022
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Dict, FrozenSet, List, Optional
+from typing import Collection, Dict, FrozenSet, List, Optional, Tuple
+
+from typing_extensions import Final
 
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
 from synapse.events import EventBase
@@ -40,6 +43,8 @@ MEMBERSHIP_PRIORITY = (
     Membership.BAN,
 )
 
+_HISTORY_VIS_KEY: Final[Tuple[str, str]] = (EventTypes.RoomHistoryVisibility, "")
+
 
 async def filter_events_for_client(
     storage: Storage,
@@ -74,7 +79,7 @@ async def filter_events_for_client(
     # to clients.
     events = [e for e in events if not e.internal_metadata.is_soft_failed()]
 
-    types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
+    types = (_HISTORY_VIS_KEY, (EventTypes.Member, user_id))
 
     # we exclude outliers at this point, and then handle them separately later
     event_id_to_state = await storage.state.get_state_for_events(
@@ -157,7 +162,7 @@ async def filter_events_for_client(
         state = event_id_to_state[event.event_id]
 
         # get the room_visibility at the time of the event.
-        visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
+        visibility_event = state.get(_HISTORY_VIS_KEY, None)
         if visibility_event:
             visibility = visibility_event.content.get(
                 "history_visibility", HistoryVisibility.SHARED
@@ -293,67 +298,28 @@ async def filter_events_for_server(
             return True
         return False
 
-    def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool:
-        history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
-        if history:
-            visibility = history.content.get(
-                "history_visibility", HistoryVisibility.SHARED
-            )
-            if visibility in [HistoryVisibility.INVITED, HistoryVisibility.JOINED]:
-                # We now loop through all state events looking for
-                # membership states for the requesting server to determine
-                # if the server is either in the room or has been invited
-                # into the room.
-                for ev in state.values():
-                    if ev.type != EventTypes.Member:
-                        continue
-                    try:
-                        domain = get_domain_from_id(ev.state_key)
-                    except Exception:
-                        continue
-
-                    if domain != server_name:
-                        continue
-
-                    memtype = ev.membership
-                    if memtype == Membership.JOIN:
-                        return True
-                    elif memtype == Membership.INVITE:
-                        if visibility == HistoryVisibility.INVITED:
-                            return True
-                else:
-                    # server has no users in the room: redact
-                    return False
-
-        return True
-
-    # Lets check to see if all the events have a history visibility
-    # of "shared" or "world_readable". If that's the case then we don't
-    # need to check membership (as we know the server is in the room).
-    event_to_state_ids = await storage.state.get_state_ids_for_events(
-        frozenset(e.event_id for e in events),
-        state_filter=StateFilter.from_types(
-            types=((EventTypes.RoomHistoryVisibility, ""),)
-        ),
-    )
-
-    visibility_ids = set()
-    for sids in event_to_state_ids.values():
-        hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
-        if hist:
-            visibility_ids.add(hist)
+    def check_event_is_visible(
+        visibility: str, memberships: StateMap[EventBase]
+    ) -> bool:
+        if visibility not in (HistoryVisibility.INVITED, HistoryVisibility.JOINED):
+            return True
 
-    # If we failed to find any history visibility events then the default
-    # is "shared" visibility.
-    if not visibility_ids:
-        all_open = True
-    else:
-        event_map = await storage.main.get_events(visibility_ids)
-        all_open = all(
-            e.content.get("history_visibility")
-            in (None, HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
-            for e in event_map.values()
-        )
+        # We now loop through all membership events looking for
+        # membership states for the requesting server to determine
+        # if the server is either in the room or has been invited
+        # into the room.
+        for ev in memberships.values():
+            assert get_domain_from_id(ev.state_key) == server_name
+
+            memtype = ev.membership
+            if memtype == Membership.JOIN:
+                return True
+            elif memtype == Membership.INVITE:
+                if visibility == HistoryVisibility.INVITED:
+                    return True
+
+        # server has no users in the room: redact
+        return False
 
     if not check_history_visibility_only:
         erased_senders = await storage.main.are_users_erased(e.sender for e in events)
@@ -362,34 +328,100 @@ async def filter_events_for_server(
         # to no users having been erased.
         erased_senders = {}
 
-    if all_open:
-        # all the history_visibility state affecting these events is open, so
-        # we don't need to filter by membership state. We *do* need to check
-        # for user erasure, though.
-        if erased_senders:
-            to_return = []
-            for e in events:
-                if not is_sender_erased(e, erased_senders):
-                    to_return.append(e)
-                elif redact:
-                    to_return.append(prune_event(e))
-
-            return to_return
-
-        # If there are no erased users then we can just return the given list
-        # of events without having to copy it.
-        return events
-
-    # Ok, so we're dealing with events that have non-trivial visibility
-    # rules, so we need to also get the memberships of the room.
-
-    # first, for each event we're wanting to return, get the event_ids
-    # of the history vis and membership state at those events.
+    # Let's check to see if all the events have a history visibility
+    # of "shared" or "world_readable". If that's the case then we don't
+    # need to check membership (as we know the server is in the room).
+    event_to_history_vis = await _event_to_history_vis(storage, events)
+
+    # for any with restricted vis, we also need the memberships
+    event_to_memberships = await _event_to_memberships(
+        storage,
+        [
+            e
+            for e in events
+            if event_to_history_vis[e.event_id]
+            not in (HistoryVisibility.SHARED, HistoryVisibility.WORLD_READABLE)
+        ],
+        server_name,
+    )
+
+    to_return = []
+    for e in events:
+        erased = is_sender_erased(e, erased_senders)
+        visible = check_event_is_visible(
+            event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
+        )
+        if visible and not erased:
+            to_return.append(e)
+        elif redact:
+            to_return.append(prune_event(e))
+
+    return to_return
+
+
+async def _event_to_history_vis(
+    storage: Storage, events: Collection[EventBase]
+) -> Dict[str, str]:
+    """Get the history visibility at each of the given events
+
+    Returns a map from event id to history_visibility setting
+    """
+
+    # outliers get special treatment here. We don't have the state at that point in the
+    # room (and attempting to look it up will raise an exception), so all we can really
+    # do is assume that the requesting server is allowed to see the event. That's
+    # equivalent to there not being a history_visibility event, so we just exclude
+    # any outliers from the query.
+    event_to_state_ids = await storage.state.get_state_ids_for_events(
+        frozenset(e.event_id for e in events if not e.internal_metadata.is_outlier()),
+        state_filter=StateFilter.from_types(types=(_HISTORY_VIS_KEY,)),
+    )
+
+    visibility_ids = {
+        vis_event_id
+        for vis_event_id in (
+            state_ids.get(_HISTORY_VIS_KEY) for state_ids in event_to_state_ids.values()
+        )
+        if vis_event_id
+    }
+    vis_events = await storage.main.get_events(visibility_ids)
+
+    result: Dict[str, str] = {}
+    for event in events:
+        vis = HistoryVisibility.SHARED
+        state_ids = event_to_state_ids.get(event.event_id)
+
+        # if we didn't find any state for this event, it's an outlier, and we assume
+        # it's open
+        visibility_id = None
+        if state_ids:
+            visibility_id = state_ids.get(_HISTORY_VIS_KEY)
+
+        if visibility_id:
+            vis_event = vis_events[visibility_id]
+            vis = vis_event.content.get("history_visibility", HistoryVisibility.SHARED)
+            assert isinstance(vis, str)
+
+        result[event.event_id] = vis
+    return result
+
+
+async def _event_to_memberships(
+    storage: Storage, events: Collection[EventBase], server_name: str
+) -> Dict[str, StateMap[EventBase]]:
+    """Get the remote membership list at each of the given events
+
+    Returns a map from event id to state map, which will contain only membership events
+    for the given server.
+    """
+
+    if not events:
+        return {}
+
+    # for each event, get the event_ids of the membership state at those events.
     event_to_state_ids = await storage.state.get_state_ids_for_events(
         frozenset(e.event_id for e in events),
-        state_filter=StateFilter.from_types(
-            types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
-        ),
+        state_filter=StateFilter.from_types(types=((EventTypes.Member, None),)),
     )
 
     # We only want to pull out member events that correspond to the
@@ -405,10 +437,7 @@ async def filter_events_for_server(
         for key, event_id in key_to_eid.items()
     }
 
-    def include(typ, state_key):
-        if typ != EventTypes.Member:
-            return True
-
+    def include(state_key: str) -> bool:
         # we avoid using get_domain_from_id here for efficiency.
         idx = state_key.find(":")
         if idx == -1:
@@ -416,10 +445,14 @@ async def filter_events_for_server(
         return state_key[idx + 1 :] == server_name
 
     event_map = await storage.main.get_events(
-        [e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
+        [
+            e_id
+            for e_id, (_, state_key) in event_id_to_state_key.items()
+            if include(state_key)
+        ]
     )
 
-    event_to_state = {
+    return {
         e_id: {
             key: event_map[inner_e_id]
             for key, inner_e_id in key_to_eid.items()
@@ -427,14 +460,3 @@ async def filter_events_for_server(
         }
         for e_id, key_to_eid in event_to_state_ids.items()
     }
-
-    to_return = []
-    for e in events:
-        erased = is_sender_erased(e, erased_senders)
-        visible = check_event_is_visible(e, event_to_state[e.event_id])
-        if visible and not erased:
-            to_return.append(e)
-        elif redact:
-            to_return.append(prune_event(e))
-
-    return to_return
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index 532e3fe9cd..a02fd4f79a 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -17,6 +17,7 @@ from unittest.mock import patch
 
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.types import JsonDict, create_requester
 from synapse.visibility import filter_events_for_client, filter_events_for_server
 
@@ -73,6 +74,39 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
             self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
             self.assertEqual(filtered[i].content["a"], "b")
 
+    def test_filter_outlier(self) -> None:
+        # outlier events must be returned, for the good of the collective federation
+        self.get_success(self._inject_room_member("@resident:remote_hs"))
+        self.get_success(self._inject_visibility("@resident:remote_hs", "joined"))
+
+        outlier = self.get_success(self._inject_outlier())
+        self.assertEqual(
+            self.get_success(
+                filter_events_for_server(self.storage, "remote_hs", [outlier])
+            ),
+            [outlier],
+        )
+
+        # it should also work when there are other events in the list
+        evt = self.get_success(self._inject_message("@unerased:local_hs"))
+
+        filtered = self.get_success(
+            filter_events_for_server(self.storage, "remote_hs", [outlier, evt])
+        )
+        self.assertEqual(len(filtered), 2, f"expected 2 results, got: {filtered}")
+        self.assertEqual(filtered[0], outlier)
+        self.assertEqual(filtered[1].event_id, evt.event_id)
+        self.assertEqual(filtered[1].content, evt.content)
+
+        # ... but other servers should only be able to see the outlier (the other should
+        # be redacted)
+        filtered = self.get_success(
+            filter_events_for_server(self.storage, "other_server", [outlier, evt])
+        )
+        self.assertEqual(filtered[0], outlier)
+        self.assertEqual(filtered[1].event_id, evt.event_id)
+        self.assertNotIn("body", filtered[1].content)
+
     def test_erased_user(self) -> None:
         # 4 message events, from erased and unerased users, with a membership
         # change in the middle of them.
@@ -187,6 +221,25 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         self.get_success(self.storage.persistence.persist_event(event, context))
         return event
 
+    def _inject_outlier(self) -> EventBase:
+        builder = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": "m.room.member",
+                "sender": "@test:user",
+                "state_key": "@test:user",
+                "room_id": TEST_ROOM_ID,
+                "content": {"membership": "join"},
+            },
+        )
+
+        event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
+        event.internal_metadata.outlier = True
+        self.get_success(
+            self.storage.persistence.persist_event(event, EventContext.for_outlier())
+        )
+        return event
+
 
 class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
     def test_out_of_band_invite_rejection(self):