diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index 4b0a9b2974..13dd6ce248 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -1,7 +1,7 @@
# Copyright 2015, 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -86,6 +86,9 @@ ROOM_EVENT_FILTER_SCHEMA = {
# cf https://github.com/matrix-org/matrix-doc/pull/2326
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
+ # MSC3440, filtering by event relations.
+ "io.element.relation_senders": {"type": "array", "items": {"type": "string"}},
+ "io.element.relation_types": {"type": "array", "items": {"type": "string"}},
},
}
@@ -146,14 +149,16 @@ def matrix_user_id_validator(user_id_str: str) -> UserID:
class Filtering:
def __init__(self, hs: "HomeServer"):
- super().__init__()
+ self._hs = hs
self.store = hs.get_datastore()
+ self.DEFAULT_FILTER_COLLECTION = FilterCollection(hs, {})
+
async def get_user_filter(
self, user_localpart: str, filter_id: Union[int, str]
) -> "FilterCollection":
result = await self.store.get_user_filter(user_localpart, filter_id)
- return FilterCollection(result)
+ return FilterCollection(self._hs, result)
def add_user_filter(
self, user_localpart: str, user_filter: JsonDict
@@ -191,21 +196,22 @@ FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
class FilterCollection:
- def __init__(self, filter_json: JsonDict):
+ def __init__(self, hs: "HomeServer", filter_json: JsonDict):
self._filter_json = filter_json
room_filter_json = self._filter_json.get("room", {})
self._room_filter = Filter(
- {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")}
+ hs,
+ {k: v for k, v in room_filter_json.items() if k in ("rooms", "not_rooms")},
)
- self._room_timeline_filter = Filter(room_filter_json.get("timeline", {}))
- self._room_state_filter = Filter(room_filter_json.get("state", {}))
- self._room_ephemeral_filter = Filter(room_filter_json.get("ephemeral", {}))
- self._room_account_data = Filter(room_filter_json.get("account_data", {}))
- self._presence_filter = Filter(filter_json.get("presence", {}))
- self._account_data = Filter(filter_json.get("account_data", {}))
+ self._room_timeline_filter = Filter(hs, room_filter_json.get("timeline", {}))
+ self._room_state_filter = Filter(hs, room_filter_json.get("state", {}))
+ self._room_ephemeral_filter = Filter(hs, room_filter_json.get("ephemeral", {}))
+ self._room_account_data = Filter(hs, room_filter_json.get("account_data", {}))
+ self._presence_filter = Filter(hs, filter_json.get("presence", {}))
+ self._account_data = Filter(hs, filter_json.get("account_data", {}))
self.include_leave = filter_json.get("room", {}).get("include_leave", False)
self.event_fields = filter_json.get("event_fields", [])
@@ -232,25 +238,37 @@ class FilterCollection:
def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members
- def filter_presence(
+ async def filter_presence(
self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
- return self._presence_filter.filter(events)
+ return await self._presence_filter.filter(events)
- def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
- return self._account_data.filter(events)
+ async def filter_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
+ return await self._account_data.filter(events)
- def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
- return self._room_state_filter.filter(self._room_filter.filter(events))
+ async def filter_room_state(self, events: Iterable[EventBase]) -> List[EventBase]:
+ return await self._room_state_filter.filter(
+ await self._room_filter.filter(events)
+ )
- def filter_room_timeline(self, events: Iterable[EventBase]) -> List[EventBase]:
- return self._room_timeline_filter.filter(self._room_filter.filter(events))
+ async def filter_room_timeline(
+ self, events: Iterable[EventBase]
+ ) -> List[EventBase]:
+ return await self._room_timeline_filter.filter(
+ await self._room_filter.filter(events)
+ )
- def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
- return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
+ async def filter_room_ephemeral(self, events: Iterable[JsonDict]) -> List[JsonDict]:
+ return await self._room_ephemeral_filter.filter(
+ await self._room_filter.filter(events)
+ )
- def filter_room_account_data(self, events: Iterable[JsonDict]) -> List[JsonDict]:
- return self._room_account_data.filter(self._room_filter.filter(events))
+ async def filter_room_account_data(
+ self, events: Iterable[JsonDict]
+ ) -> List[JsonDict]:
+ return await self._room_account_data.filter(
+ await self._room_filter.filter(events)
+ )
def blocks_all_presence(self) -> bool:
return (
@@ -274,7 +292,9 @@ class FilterCollection:
class Filter:
- def __init__(self, filter_json: JsonDict):
+ def __init__(self, hs: "HomeServer", filter_json: JsonDict):
+ self._hs = hs
+ self._store = hs.get_datastore()
self.filter_json = filter_json
self.limit = filter_json.get("limit", 10)
@@ -297,6 +317,20 @@ class Filter:
self.labels = filter_json.get("org.matrix.labels", None)
self.not_labels = filter_json.get("org.matrix.not_labels", [])
+ # Ideally these would be rejected at the endpoint if they were provided
+ # and not supported, but that would involve modifying the JSON schema
+ # based on the homeserver configuration.
+ if hs.config.experimental.msc3440_enabled:
+ self.relation_senders = self.filter_json.get(
+ "io.element.relation_senders", None
+ )
+ self.relation_types = self.filter_json.get(
+ "io.element.relation_types", None
+ )
+ else:
+ self.relation_senders = None
+ self.relation_types = None
+
def filters_all_types(self) -> bool:
return "*" in self.not_types
@@ -306,7 +340,7 @@ class Filter:
def filters_all_rooms(self) -> bool:
return "*" in self.not_rooms
- def check(self, event: FilterEvent) -> bool:
+ def _check(self, event: FilterEvent) -> bool:
"""Checks whether the filter matches the given event.
Args:
@@ -420,8 +454,30 @@ class Filter:
return room_ids
- def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
- return list(filter(self.check, events))
+ async def _check_event_relations(
+ self, events: Iterable[FilterEvent]
+ ) -> List[FilterEvent]:
+ # The event IDs to check, mypy doesn't understand the ifinstance check.
+ event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
+ event_ids_to_keep = set(
+ await self._store.events_have_relations(
+ event_ids, self.relation_senders, self.relation_types
+ )
+ )
+
+ return [
+ event
+ for event in events
+ if not isinstance(event, EventBase) or event.event_id in event_ids_to_keep
+ ]
+
+ async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
+ result = [event for event in events if self._check(event)]
+
+ if self.relation_senders or self.relation_types:
+ return await self._check_event_relations(result)
+
+ return result
def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
"""Returns a new filter with the given room IDs appended.
@@ -433,7 +489,7 @@ class Filter:
filter: A new filter including the given rooms and the old
filter's rooms.
"""
- newFilter = Filter(self.filter_json)
+ newFilter = Filter(self._hs, self.filter_json)
newFilter.rooms += room_ids
return newFilter
@@ -444,6 +500,3 @@ def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
return actual_value.startswith(type_prefix)
else:
return actual_value == filter_value
-
-
-DEFAULT_FILTER_COLLECTION = FilterCollection({})
|