summary refs log tree commit diff
diff options
context:
space:
mode:
author Patrick Cloke <clokep@users.noreply.github.com> 2021-10-01 07:02:32 -0400
committer GitHub <noreply@github.com> 2021-10-01 07:02:32 -0400
commit 7e440520c9b370ce008c6a65c5dd87a360a6457c (patch)
tree de7da69fea19d210f2f9b8650a140aaf97c95285
parent Clean-up registration tests (#10945) (diff)
download synapse-7e440520c9b370ce008c6a65c5dd87a360a6457c.tar.xz
Add type hints to filtering classes. (#10958)
-rw-r--r-- changelog.d/10958.misc 1
-rw-r--r-- synapse/api/filtering.py 117
-rw-r--r-- synapse/storage/databases/main/filtering.py 8
3 files changed, 81 insertions, 45 deletions
diff --git a/changelog.d/10958.misc b/changelog.d/10958.misc
new file mode 100644
index 0000000000..409ecc35cb
--- /dev/null
+++ b/changelog.d/10958.misc
@@ -0,0 +1 @@
+Add type hints to filtering classes.
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index ad1ff6a9df..20e91a115d 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -15,7 +15,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from typing import List
+from typing import (
+    TYPE_CHECKING,
+    Awaitable,
+    Container,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    TypeVar,
+    Union,
+)
 
 import jsonschema
 from jsonschema import FormatChecker
@@ -23,7 +33,11 @@ from jsonschema import FormatChecker
 from synapse.api.constants import EventContentFields
 from synapse.api.errors import SynapseError
 from synapse.api.presence import UserPresenceState
-from synapse.types import RoomID, UserID
+from synapse.events import EventBase
+from synapse.types import JsonDict, RoomID, UserID
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 FILTER_SCHEMA = {
     "additionalProperties": False,
@@ -120,25 +134,29 @@ USER_FILTER_SCHEMA = {
 
 
 @FormatChecker.cls_checks("matrix_room_id")
-def matrix_room_id_validator(room_id_str):
+def matrix_room_id_validator(room_id_str: str) -> RoomID:
     return RoomID.from_string(room_id_str)
 
 
 @FormatChecker.cls_checks("matrix_user_id")
-def matrix_user_id_validator(user_id_str):
+def matrix_user_id_validator(user_id_str: str) -> UserID:
     return UserID.from_string(user_id_str)
 
 
 class Filtering:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.store = hs.get_datastore()
 
-    async def get_user_filter(self, user_localpart, filter_id):
+    async def get_user_filter(
+        self, user_localpart: str, filter_id: Union[int, str]
+    ) -> "FilterCollection":
         result = await self.store.get_user_filter(user_localpart, filter_id)
         return FilterCollection(result)
 
-    def add_user_filter(self, user_localpart, user_filter):
+    def add_user_filter(
+        self, user_localpart: str, user_filter: JsonDict
+    ) -> Awaitable[int]:
         self.check_valid_filter(user_filter)
         return self.store.add_user_filter(user_localpart, user_filter)
 
@@ -146,13 +164,13 @@ class Filtering:
     #   replace_user_filter at some point? There's no REST API specified for
     #   them however
 
-    def check_valid_filter(self, user_filter_json):
+    def check_valid_filter(self, user_filter_json: JsonDict) -> None:
         """Check if the provided filter is valid.
 
         This inspects all definitions contained within the filter.
 
         Args:
-            user_filter_json(dict): The filter
+            user_filter_json: The filter
         Raises:
             SynapseError: If the filter is not valid.
         """
@@ -167,8 +185,12 @@ class Filtering:
             raise SynapseError(400, str(e))
 
 
+# Filters work across events, presence EDUs, and account data.
+FilterEvent = TypeVar("FilterEvent", EventBase, UserPresenceState, JsonDict)
+
+
 class FilterCollection:
-    def __init__(self, filter_json):
+    def __init__(self, filter_json: JsonDict):
         self._filter_json = filter_json
 
         room_filter_json = self._filter_json.get("room", {})
@@ -188,25 +210,25 @@ class FilterCollection:
         self.event_fields = filter_json.get("event_fields", [])
         self.event_format = filter_json.get("event_format", "client")
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<FilterCollection %s>" % (json.dumps(self._filter_json),)
 
-    def get_filter_json(self):
+    def get_filter_json(self) -> JsonDict:
         return self._filter_json
 
-    def timeline_limit(self):
+    def timeline_limit(self) -> int:
         return self._room_timeline_filter.limit()
 
-    def presence_limit(self):
+    def presence_limit(self) -> int:
         return self._presence_filter.limit()
 
-    def ephemeral_limit(self):
+    def ephemeral_limit(self) -> int:
         return self._room_ephemeral_filter.limit()
 
-    def lazy_load_members(self):
+    def lazy_load_members(self) -> bool:
         return self._room_state_filter.lazy_load_members()
 
-    def include_redundant_members(self):
+    def include_redundant_members(self) -> bool:
         return self._room_state_filter.include_redundant_members()
 
     def filter_presence(self, events):
@@ -218,29 +240,31 @@ class FilterCollection:
     def filter_room_state(self, events):
         return self._room_state_filter.filter(self._room_filter.filter(events))
 
-    def filter_room_timeline(self, events):
+    def filter_room_timeline(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
         return self._room_timeline_filter.filter(self._room_filter.filter(events))
 
-    def filter_room_ephemeral(self, events):
+    def filter_room_ephemeral(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
         return self._room_ephemeral_filter.filter(self._room_filter.filter(events))
 
-    def filter_room_account_data(self, events):
+    def filter_room_account_data(
+        self, events: Iterable[FilterEvent]
+    ) -> List[FilterEvent]:
         return self._room_account_data.filter(self._room_filter.filter(events))
 
-    def blocks_all_presence(self):
+    def blocks_all_presence(self) -> bool:
         return (
             self._presence_filter.filters_all_types()
             or self._presence_filter.filters_all_senders()
         )
 
-    def blocks_all_room_ephemeral(self):
+    def blocks_all_room_ephemeral(self) -> bool:
         return (
             self._room_ephemeral_filter.filters_all_types()
             or self._room_ephemeral_filter.filters_all_senders()
             or self._room_ephemeral_filter.filters_all_rooms()
         )
 
-    def blocks_all_room_timeline(self):
+    def blocks_all_room_timeline(self) -> bool:
         return (
             self._room_timeline_filter.filters_all_types()
             or self._room_timeline_filter.filters_all_senders()
@@ -249,7 +273,7 @@ class FilterCollection:
 
 
 class Filter:
-    def __init__(self, filter_json):
+    def __init__(self, filter_json: JsonDict):
         self.filter_json = filter_json
 
         self.types = self.filter_json.get("types", None)
@@ -266,20 +290,20 @@ class Filter:
         self.labels = self.filter_json.get("org.matrix.labels", None)
         self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
 
-    def filters_all_types(self):
+    def filters_all_types(self) -> bool:
         return "*" in self.not_types
 
-    def filters_all_senders(self):
+    def filters_all_senders(self) -> bool:
         return "*" in self.not_senders
 
-    def filters_all_rooms(self):
+    def filters_all_rooms(self) -> bool:
         return "*" in self.not_rooms
 
-    def check(self, event):
+    def check(self, event: FilterEvent) -> bool:
         """Checks whether the filter matches the given event.
 
         Returns:
-            bool: True if the event matches
+            True if the event matches
         """
         # We usually get the full "events" as dictionaries coming through,
         # except for presence which actually gets passed around as its own
@@ -305,18 +329,25 @@ class Filter:
             room_id = event.get("room_id", None)
             ev_type = event.get("type", None)
 
-            content = event.get("content", {})
+            content = event.get("content") or {}
             # check if there is a string url field in the content for filtering purposes
             contains_url = isinstance(content.get("url"), str)
             labels = content.get(EventContentFields.LABELS, [])
 
         return self.check_fields(room_id, sender, ev_type, labels, contains_url)
 
-    def check_fields(self, room_id, sender, event_type, labels, contains_url):
+    def check_fields(
+        self,
+        room_id: Optional[str],
+        sender: Optional[str],
+        event_type: Optional[str],
+        labels: Container[str],
+        contains_url: bool,
+    ) -> bool:
         """Checks whether the filter matches the given event fields.
 
         Returns:
-            bool: True if the event fields match
+            True if the event fields match
         """
         literal_keys = {
             "rooms": lambda v: room_id == v,
@@ -343,14 +374,14 @@ class Filter:
 
         return True
 
-    def filter_rooms(self, room_ids):
+    def filter_rooms(self, room_ids: Iterable[str]) -> Set[str]:
         """Apply the 'rooms' filter to a given list of rooms.
 
         Args:
-            room_ids (list): A list of room_ids.
+            room_ids: A list of room_ids.
 
         Returns:
-            list: A list of room_ids that match the filter
+            A list of room_ids that match the filter
         """
         room_ids = set(room_ids)
 
@@ -363,23 +394,23 @@ class Filter:
 
         return room_ids
 
-    def filter(self, events):
+    def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
         return list(filter(self.check, events))
 
-    def limit(self):
+    def limit(self) -> int:
         return self.filter_json.get("limit", 10)
 
-    def lazy_load_members(self):
+    def lazy_load_members(self) -> bool:
         return self.filter_json.get("lazy_load_members", False)
 
-    def include_redundant_members(self):
+    def include_redundant_members(self) -> bool:
         return self.filter_json.get("include_redundant_members", False)
 
-    def with_room_ids(self, room_ids):
+    def with_room_ids(self, room_ids: Iterable[str]) -> "Filter":
         """Returns a new filter with the given room IDs appended.
 
         Args:
-            room_ids (iterable[unicode]): The room_ids to add
+            room_ids: The room_ids to add
 
         Returns:
             filter: A new filter including the given rooms and the old
@@ -390,8 +421,8 @@ class Filter:
         return newFilter
 
 
-def _matches_wildcard(actual_value, filter_value):
-    if filter_value.endswith("*"):
+def _matches_wildcard(actual_value: Optional[str], filter_value: str) -> bool:
+    if filter_value.endswith("*") and isinstance(actual_value, str):
         type_prefix = filter_value[:-1]
         return actual_value.startswith(type_prefix)
     else:
diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py
index bb244a03c0..434986fa64 100644
--- a/synapse/storage/databases/main/filtering.py
+++ b/synapse/storage/databases/main/filtering.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Union
+
 from canonicaljson import encode_canonical_json
 
 from synapse.api.errors import Codes, SynapseError
@@ -22,7 +24,9 @@ from synapse.util.caches.descriptors import cached
 
 class FilteringStore(SQLBaseStore):
     @cached(num_args=2)
-    async def get_user_filter(self, user_localpart, filter_id):
+    async def get_user_filter(
+        self, user_localpart: str, filter_id: Union[int, str]
+    ) -> JsonDict:
         # filter_id is BIGINT UNSIGNED, so if it isn't a number, fail
         # with a coherent error message rather than 500 M_UNKNOWN.
         try:
@@ -40,7 +44,7 @@ class FilteringStore(SQLBaseStore):
 
         return db_to_json(def_json)
 
-    async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
+    async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> int:
         def_json = encode_canonical_json(user_filter)
 
         # Need an atomic transaction to SELECT the maximal ID so far then