diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index bc550ae646..4b0a9b2974 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -18,7 +18,8 @@ import json
from typing import (
TYPE_CHECKING,
Awaitable,
- Container,
+ Callable,
+ Dict,
Iterable,
List,
Optional,
@@ -217,19 +218,19 @@ class FilterCollection:
return self._filter_json
def timeline_limit(self) -> int:
- return self._room_timeline_filter.limit()
+ return self._room_timeline_filter.limit
def presence_limit(self) -> int:
- return self._presence_filter.limit()
+ return self._presence_filter.limit
def ephemeral_limit(self) -> int:
- return self._room_ephemeral_filter.limit()
+ return self._room_ephemeral_filter.limit
def lazy_load_members(self) -> bool:
- return self._room_state_filter.lazy_load_members()
+ return self._room_state_filter.lazy_load_members
def include_redundant_members(self) -> bool:
- return self._room_state_filter.include_redundant_members()
+ return self._room_state_filter.include_redundant_members
def filter_presence(
self, events: Iterable[UserPresenceState]
@@ -276,19 +277,25 @@ class Filter:
def __init__(self, filter_json: JsonDict):
self.filter_json = filter_json
- self.types = self.filter_json.get("types", None)
- self.not_types = self.filter_json.get("not_types", [])
+ self.limit = filter_json.get("limit", 10)
+ self.lazy_load_members = filter_json.get("lazy_load_members", False)
+ self.include_redundant_members = filter_json.get(
+ "include_redundant_members", False
+ )
+
+ self.types = filter_json.get("types", None)
+ self.not_types = filter_json.get("not_types", [])
- self.rooms = self.filter_json.get("rooms", None)
- self.not_rooms = self.filter_json.get("not_rooms", [])
+ self.rooms = filter_json.get("rooms", None)
+ self.not_rooms = filter_json.get("not_rooms", [])
- self.senders = self.filter_json.get("senders", None)
- self.not_senders = self.filter_json.get("not_senders", [])
+ self.senders = filter_json.get("senders", None)
+ self.not_senders = filter_json.get("not_senders", [])
- self.contains_url = self.filter_json.get("contains_url", None)
+ self.contains_url = filter_json.get("contains_url", None)
- self.labels = self.filter_json.get("org.matrix.labels", None)
- self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
+ self.labels = filter_json.get("org.matrix.labels", None)
+ self.not_labels = filter_json.get("org.matrix.not_labels", [])
def filters_all_types(self) -> bool:
return "*" in self.not_types
@@ -302,76 +309,95 @@ class Filter:
def check(self, event: FilterEvent) -> bool:
"""Checks whether the filter matches the given event.
+ Args:
+ event: The event, account data, or presence to check against this
+ filter.
+
Returns:
- True if the event matches
+ True if the event matches the filter.
"""
# We usually get the full "events" as dictionaries coming through,
# except for presence which actually gets passed around as its own
# namedtuple type.
if isinstance(event, UserPresenceState):
- sender: Optional[str] = event.user_id
- room_id = None
- ev_type = "m.presence"
- contains_url = False
- labels: List[str] = []
+ user_id = event.user_id
+ field_matchers = {
+ "senders": lambda v: user_id == v,
+ "types": lambda v: "m.presence" == v,
+ }
+ return self._check_fields(field_matchers)
else:
+ content = event.get("content")
+ # Content is assumed to be a dict below, so ensure it is. This should
+ # always be true for events, but account_data has been allowed to
+ # have non-dict content.
+ if not isinstance(content, dict):
+ content = {}
+
sender = event.get("sender", None)
if not sender:
# Presence events had their 'sender' in content.user_id, but are
# now handled above. We don't know if anything else uses this
# form. TODO: Check this and probably remove it.
- content = event.get("content")
- # account_data has been allowed to have non-dict content, so
- # check type first
- if isinstance(content, dict):
- sender = content.get("user_id")
+ sender = content.get("user_id")
room_id = event.get("room_id", None)
ev_type = event.get("type", None)
- content = event.get("content") or {}
# check if there is a string url field in the content for filtering purposes
- contains_url = isinstance(content.get("url"), str)
labels = content.get(EventContentFields.LABELS, [])
- return self.check_fields(room_id, sender, ev_type, labels, contains_url)
+ field_matchers = {
+ "rooms": lambda v: room_id == v,
+ "senders": lambda v: sender == v,
+ "types": lambda v: _matches_wildcard(ev_type, v),
+ "labels": lambda v: v in labels,
+ }
+
+ result = self._check_fields(field_matchers)
+ if not result:
+ return result
+
+ contains_url_filter = self.contains_url
+ if contains_url_filter is not None:
+ contains_url = isinstance(content.get("url"), str)
+ if contains_url_filter != contains_url:
+ return False
+
+ return True
- def check_fields(
- self,
- room_id: Optional[str],
- sender: Optional[str],
- event_type: Optional[str],
- labels: Container[str],
- contains_url: bool,
- ) -> bool:
+ def _check_fields(self, field_matchers: Dict[str, Callable[[str], bool]]) -> bool:
"""Checks whether the filter matches the given event fields.
+ Args:
+ field_matchers: A map of attribute name to callable to use for checking
+ particular fields.
+
+ The attribute name and an inverse (not_<attribute name>) must
+ exist on the Filter.
+
+ The callable should return true if the event's value matches the
+ filter's value.
+
Returns:
True if the event fields match
"""
- literal_keys = {
- "rooms": lambda v: room_id == v,
- "senders": lambda v: sender == v,
- "types": lambda v: _matches_wildcard(event_type, v),
- "labels": lambda v: v in labels,
- }
-
- for name, match_func in literal_keys.items():
+
+ for name, match_func in field_matchers.items():
+ # If the event matches one of the disallowed values, reject it.
not_name = "not_%s" % (name,)
disallowed_values = getattr(self, not_name)
if any(map(match_func, disallowed_values)):
return False
+ # Other the event does not match at least one of the allowed values,
+ # reject it.
allowed_values = getattr(self, name)
if allowed_values is not None:
if not any(map(match_func, allowed_values)):
return False
- contains_url_filter = self.filter_json.get("contains_url")
- if contains_url_filter is not None:
- if contains_url_filter != contains_url:
- return False
-
+ # Otherwise, accept it.
return True
def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
@@ -385,10 +411,10 @@ class Filter:
"""
room_ids = set(room_ids)
- disallowed_rooms = set(self.filter_json.get("not_rooms", []))
+ disallowed_rooms = set(self.not_rooms)
room_ids -= disallowed_rooms
- allowed_rooms = self.filter_json.get("rooms", None)
+ allowed_rooms = self.rooms
if allowed_rooms is not None:
room_ids &= set(allowed_rooms)
@@ -397,15 +423,6 @@ class Filter:
def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return list(filter(self.check, events))
- def limit(self) -> int:
- return self.filter_json.get("limit", 10)
-
- def lazy_load_members(self) -> bool:
- return self.filter_json.get("lazy_load_members", False)
-
- def include_redundant_members(self) -> bool:
- return self.filter_json.get("include_redundant_members", False)
-
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
"""Returns a new filter with the given room IDs appended.
|