summary refs log tree commit diff
path: root/synapse/api
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-10-27 11:26:30 -0400
committerGitHub <noreply@github.com>2021-10-27 11:26:30 -0400
commit19d5dc69316a28035caf6a6519ad8a116023de81 (patch)
tree1ddb638911f8c547b9aac0f18aba1eab9173315c /synapse/api
parentDelete messages from `device_inbox` table when deleting device (#10969) (diff)
downloadsynapse-19d5dc69316a28035caf6a6519ad8a116023de81.tar.xz
Refactor `Filter` to handle fields according to data being filtered. (#11194)
This avoids filtering against fields which cannot exist on an
event source. E.g. presence updates don't have a room.
Diffstat (limited to 'synapse/api')
-rw-r--r--synapse/api/filtering.py139
1 files changed, 78 insertions, 61 deletions
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index bc550ae646..4b0a9b2974 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -18,7 +18,8 @@ import json
 from typing import (
     TYPE_CHECKING,
     Awaitable,
-    Container,
+    Callable,
+    Dict,
     Iterable,
     List,
     Optional,
@@ -217,19 +218,19 @@ class FilterCollection:
         return self._filter_json
 
     def timeline_limit(self) -> int:
-        return self._room_timeline_filter.limit()
+        return self._room_timeline_filter.limit
 
     def presence_limit(self) -> int:
-        return self._presence_filter.limit()
+        return self._presence_filter.limit
 
     def ephemeral_limit(self) -> int:
-        return self._room_ephemeral_filter.limit()
+        return self._room_ephemeral_filter.limit
 
     def lazy_load_members(self) -> bool:
-        return self._room_state_filter.lazy_load_members()
+        return self._room_state_filter.lazy_load_members
 
     def include_redundant_members(self) -> bool:
-        return self._room_state_filter.include_redundant_members()
+        return self._room_state_filter.include_redundant_members
 
     def filter_presence(
         self, events: Iterable[UserPresenceState]
@@ -276,19 +277,25 @@ class Filter:
     def __init__(self, filter_json: JsonDict):
         self.filter_json = filter_json
 
-        self.types = self.filter_json.get("types", None)
-        self.not_types = self.filter_json.get("not_types", [])
+        self.limit = filter_json.get("limit", 10)
+        self.lazy_load_members = filter_json.get("lazy_load_members", False)
+        self.include_redundant_members = filter_json.get(
+            "include_redundant_members", False
+        )
+
+        self.types = filter_json.get("types", None)
+        self.not_types = filter_json.get("not_types", [])
 
-        self.rooms = self.filter_json.get("rooms", None)
-        self.not_rooms = self.filter_json.get("not_rooms", [])
+        self.rooms = filter_json.get("rooms", None)
+        self.not_rooms = filter_json.get("not_rooms", [])
 
-        self.senders = self.filter_json.get("senders", None)
-        self.not_senders = self.filter_json.get("not_senders", [])
+        self.senders = filter_json.get("senders", None)
+        self.not_senders = filter_json.get("not_senders", [])
 
-        self.contains_url = self.filter_json.get("contains_url", None)
+        self.contains_url = filter_json.get("contains_url", None)
 
-        self.labels = self.filter_json.get("org.matrix.labels", None)
-        self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
+        self.labels = filter_json.get("org.matrix.labels", None)
+        self.not_labels = filter_json.get("org.matrix.not_labels", [])
 
     def filters_all_types(self) -> bool:
         return "*" in self.not_types
@@ -302,76 +309,95 @@ class Filter:
     def check(self, event: FilterEvent) -> bool:
         """Checks whether the filter matches the given event.
 
+        Args:
+            event: The event, account data, or presence to check against this
+                filter.
+
         Returns:
-            True if the event matches
+            True if the event matches the filter.
         """
         # We usually get the full "events" as dictionaries coming through,
         # except for presence which actually gets passed around as its own
         # namedtuple type.
         if isinstance(event, UserPresenceState):
-            sender: Optional[str] = event.user_id
-            room_id = None
-            ev_type = "m.presence"
-            contains_url = False
-            labels: List[str] = []
+            user_id = event.user_id
+            field_matchers = {
+                "senders": lambda v: user_id == v,
+                "types": lambda v: "m.presence" == v,
+            }
+            return self._check_fields(field_matchers)
         else:
+            content = event.get("content")
+            # Content is assumed to be a dict below, so ensure it is. This should
+            # always be true for events, but account_data has been allowed to
+            # have non-dict content.
+            if not isinstance(content, dict):
+                content = {}
+
             sender = event.get("sender", None)
             if not sender:
                 # Presence events had their 'sender' in content.user_id, but are
                 # now handled above. We don't know if anything else uses this
                 # form. TODO: Check this and probably remove it.
-                content = event.get("content")
-                # account_data has been allowed to have non-dict content, so
-                # check type first
-                if isinstance(content, dict):
-                    sender = content.get("user_id")
+                sender = content.get("user_id")
 
             room_id = event.get("room_id", None)
             ev_type = event.get("type", None)
 
-            content = event.get("content") or {}
             # check if there is a string url field in the content for filtering purposes
-            contains_url = isinstance(content.get("url"), str)
             labels = content.get(EventContentFields.LABELS, [])
 
-        return self.check_fields(room_id, sender, ev_type, labels, contains_url)
+            field_matchers = {
+                "rooms": lambda v: room_id == v,
+                "senders": lambda v: sender == v,
+                "types": lambda v: _matches_wildcard(ev_type, v),
+                "labels": lambda v: v in labels,
+            }
+
+            result = self._check_fields(field_matchers)
+            if not result:
+                return result
+
+            contains_url_filter = self.contains_url
+            if contains_url_filter is not None:
+                contains_url = isinstance(content.get("url"), str)
+                if contains_url_filter != contains_url:
+                    return False
+
+            return True
 
-    def check_fields(
-        self,
-        room_id: Optional[str],
-        sender: Optional[str],
-        event_type: Optional[str],
-        labels: Container[str],
-        contains_url: bool,
-    ) -> bool:
+    def _check_fields(self, field_matchers: Dict[str, Callable[[str], bool]]) -> bool:
         """Checks whether the filter matches the given event fields.
 
+        Args:
+            field_matchers: A map of attribute name to callable to use for checking
+                particular fields.
+
+                The attribute name and an inverse (not_<attribute name>) must
+                exist on the Filter.
+
+                The callable should return true if the event's value matches the
+                filter's value.
+
         Returns:
             True if the event fields match
         """
-        literal_keys = {
-            "rooms": lambda v: room_id == v,
-            "senders": lambda v: sender == v,
-            "types": lambda v: _matches_wildcard(event_type, v),
-            "labels": lambda v: v in labels,
-        }
-
-        for name, match_func in literal_keys.items():
+
+        for name, match_func in field_matchers.items():
+            # If the event matches one of the disallowed values, reject it.
             not_name = "not_%s" % (name,)
             disallowed_values = getattr(self, not_name)
             if any(map(match_func, disallowed_values)):
                 return False
 
+            # Other the event does not match at least one of the allowed values,
+            # reject it.
             allowed_values = getattr(self, name)
             if allowed_values is not None:
                 if not any(map(match_func, allowed_values)):
                     return False
 
-        contains_url_filter = self.filter_json.get("contains_url")
-        if contains_url_filter is not None:
-            if contains_url_filter != contains_url:
-                return False
-
+        # Otherwise, accept it.
         return True
 
     def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
@@ -385,10 +411,10 @@ class Filter:
         """
         room_ids = set(room_ids)
 
-        disallowed_rooms = set(self.filter_json.get("not_rooms", []))
+        disallowed_rooms = set(self.not_rooms)
         room_ids -= disallowed_rooms
 
-        allowed_rooms = self.filter_json.get("rooms", None)
+        allowed_rooms = self.rooms
         if allowed_rooms is not None:
             room_ids &= set(allowed_rooms)
 
@@ -397,15 +423,6 @@ class Filter:
     def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
         return list(filter(self.check, events))
 
-    def limit(self) -> int:
-        return self.filter_json.get("limit", 10)
-
-    def lazy_load_members(self) -> bool:
-        return self.filter_json.get("lazy_load_members", False)
-
-    def include_redundant_members(self) -> bool:
-        return self.filter_json.get("include_redundant_members", False)
-
     def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
         """Returns a new filter with the given room IDs appended.