diff --git a/changelog.d/10958.misc b/changelog.d/10958.misc
new file mode 100644
index 0000000000..409ecc35cb
--- /dev/null
+++ b/changelog.d/10958.misc
@@ -0,0 +1 @@
+Add type hints to filtering classes.
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index ad1ff6a9df..20e91a115d 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,7 +15,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-from typing import List
+from typing import (
+ TYPE_CHECKING,
+ Awaitable,
+ Container,
+ Iterable,
+ List,
+ Optional,
+ Set,
+ TypeVar,
+ Union,
+)
import jsonschema
from jsonschema import FormatChecker
@@ -23,7 +33,11 @@ from jsonschema import FormatChecker
from synapse.api.constants import EventContentFields
from synapse.api.errors import SynapseError
from synapse.api.presence import UserPresenceState
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomID, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
FILTER_SCHEMA = {
"additionalProperties": False,
@@ -120,25 +134,29 @@ USER_FILTER_SCHEMA = {
@FormatChecker.cls_checks("matrix_room_id")
-def matrix_room_id_validator(room_id_str):
+def matrix_room_id_validator(room_id_str: str) -> RoomID:
return RoomID.from_string(room_id_str)
@FormatChecker.cls_checks("matrix_user_id")
-def matrix_user_id_validator(user_id_str):
+def matrix_user_id_validator(user_id_str: str) -> UserID:
return UserID.from_string(user_id_str)
class Filtering:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.store = hs.get_datastore()
- async def get_user_filter(self, user_localpart, filter_id):
+ async def get_user_filter(
+ self, user_localpart: str, filter_id: Union[int, str]
+ ) -> "FilterCollection":
result = await self.store.get_user_filter(user_localpart, filter_id)
return FilterCollection(result)
- def add_user_filter(self, user_localpart, user_filter):
+ def add_user_filter(
+ self, user_localpart: str, user_filter: JsonDict
+ ) -> Awaitable[int]:
self.check_valid_filter(user_filter)
return self.store.add_user_filter(user_localpart, user_filter)
@@ -146,13 +164,13 @@ class Filtering:
# replace_user_filter at some point? There's no REST API specified for
# them however
- def check_valid_filter(self, user_filter_json):
+ def check_valid_filter(self, user_filter_json: JsonDict) -> None:
"""Check if the provided filter is valid.
This inspects all definitions contained within the filter.
Args:
- user_filter_json(dict): The filter
+ user_filter_json: The filter
Raises:
SynapseError: If the filter is not valid.
"""
@@ -167,8 +185,12 @@ class Filtering:
raise SynapseError(400, str(e))
+# Filters work across events, presence EDUs, and account data.
+FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
+
+
class FilterCollection:
- def __init__(self, filter_json):
+ def __init__(self, filter_json: JsonDict):
self._filter_json = filter_json
room_filter_json = self._filter_json.get("room", {})
@@ -188,25 +210,25 @@ class FilterCollection:
self.event_fields = filter_json.get("event_fields", [])
self.event_format = filter_json.get("event_format", "client")
- def __repr__(self):
+ def __repr__(self) -> str:
return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
- def get_filter_json(self):
+ def get_filter_json(self) -> JsonDict:
return self._filter_json
- def timeline_limit(self):
+ def timeline_limit(self) -> int:
return self._room_timeline_filter.limit()
- def presence_limit(self):
+ def presence_limit(self) -> int:
return self._presence_filter.limit()
- def ephemeral_limit(self):
+ def ephemeral_limit(self) -> int:
return self._room_ephemeral_filter.limit()
- def lazy_load_members(self):
+ def lazy_load_members(self) -> bool:
return self._room_state_filter.lazy_load_members()
- def include_redundant_members(self):
+ def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members()
def filter_presence(self, events):
@@ -218,29 +240,31 @@ class FilterCollection:
def filter_room_state(self, events):
return self._room_state_filter.filter(self._room_filter.filter(events))
- def filter_room_timeline(self, events):
+ def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return self._room_timeline_filter.filter(self._room_filter.filter(events))
- def filter_room_ephemeral(self, events):
+ def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
- def filter_room_account_data(self, events):
+ def filter_room_account_data(
+ self, events: Iterable[FilterEvent]
+ ) -> List[FilterEvent]:
return self._room_account_data.filter(self._room_filter.filter(events))
- def blocks_all_presence(self):
+ def blocks_all_presence(self) -> bool:
return (
self._presence_filter.filters_all_types()
or self._presence_filter.filters_all_senders()
)
- def blocks_all_room_ephemeral(self):
+ def blocks_all_room_ephemeral(self) -> bool:
return (
self._room_ephemeral_filter.filters_all_types()
or self._room_ephemeral_filter.filters_all_senders()
or self._room_ephemeral_filter.filters_all_rooms()
)
- def blocks_all_room_timeline(self):
+ def blocks_all_room_timeline(self) -> bool:
return (
self._room_timeline_filter.filters_all_types()
or self._room_timeline_filter.filters_all_senders()
@@ -249,7 +273,7 @@ class FilterCollection:
class Filter:
- def __init__(self, filter_json):
+ def __init__(self, filter_json: JsonDict):
self.filter_json = filter_json
self.types = self.filter_json.get("types", None)
@@ -266,20 +290,20 @@ class Filter:
self.labels = self.filter_json.get("org.matrix.labels", None)
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
- def filters_all_types(self):
+ def filters_all_types(self) -> bool:
return "*" in self.not_types
- def filters_all_senders(self):
+ def filters_all_senders(self) -> bool:
return "*" in self.not_senders
- def filters_all_rooms(self):
+ def filters_all_rooms(self) -> bool:
return "*" in self.not_rooms
- def check(self, event):
+ def check(self, event: FilterEvent) -> bool:
"""Checks whether the filter matches the given event.
Returns:
- bool: True if the event matches
+ True if the event matches
"""
# We usually get the full "events" as dictionaries coming through,
# except for presence which actually gets passed around as its own
@@ -305,18 +329,25 @@ class Filter:
room_id = event.get("room_id", None)
ev_type = event.get("type", None)
- content = event.get("content", {})
+ content = event.get("content") or {}
# check if there is a string url field in the content for filtering purposes
contains_url = isinstance(content.get("url"), str)
labels = content.get(EventContentFields.LABELS, [])
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
- def check_fields(self, room_id, sender, event_type, labels, contains_url):
+ def check_fields(
+ self,
+ room_id: Optional[str],
+ sender: Optional[str],
+ event_type: Optional[str],
+ labels: Container[str],
+ contains_url: bool,
+ ) -> bool:
"""Checks whether the filter matches the given event fields.
Returns:
- bool: True if the event fields match
+ True if the event fields match
"""
literal_keys = {
"rooms": lambda v: room_id == v,
@@ -343,14 +374,14 @@ class Filter:
return True
- def filter_rooms(self, room_ids):
+ def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
"""Apply the 'rooms' filter to a given list of rooms.
Args:
- room_ids (list): A list of room_ids.
+ room_ids: A list of room_ids.
Returns:
- list: A list of room_ids that match the filter
+ A list of room_ids that match the filter
"""
room_ids = set(room_ids)
@@ -363,23 +394,23 @@ class Filter:
return room_ids
- def filter(self, events):
+ def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
return list(filter(self.check, events))
- def limit(self):
+ def limit(self) -> int:
return self.filter_json.get("limit", 10)
- def lazy_load_members(self):
+ def lazy_load_members(self) -> bool:
return self.filter_json.get("lazy_load_members", False)
- def include_redundant_members(self):
+ def include_redundant_members(self) -> bool:
return self.filter_json.get("include_redundant_members", False)
- def with_room_ids(self, room_ids):
+ def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
"""Returns a new filter with the given room IDs appended.
Args:
- room_ids (iterable[unicode]): The room_ids to add
+ room_ids: The room_ids to add
Returns:
filter: A new filter including the given rooms and the old
@@ -390,8 +421,8 @@ class Filter:
return newFilter
-def _matches_wildcard(actual_value, filter_value):
- if filter_value.endswith("*"):
+def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
+ if filter_value.endswith("*") and isinstance(actual_value, str):
type_prefix = filter_value[:-1]
return actual_value.startswith(type_prefix)
else:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index bb244a03c0..434986fa64 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Union
+
from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
@@ -22,7 +24,9 @@ from synapse.util.caches.descriptors import cached
class FilteringStore(SQLBaseStore):
@cached(num_args=2)
- async def get_user_filter(self, user_localpart, filter_id):
+ async def get_user_filter(
+ self, user_localpart: str, filter_id: Union[int, str]
+ ) -> JsonDict:
# filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
# with a coherent error message rather than 500 M_UNKNOWN.
try:
@@ -40,7 +44,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
- async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
+ async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
|