summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9676.misc1
-rw-r--r--mypy.ini5
-rw-r--r--synapse/events/third_party_rules.py15
-rw-r--r--synapse/secrets.py8
-rw-r--r--synapse/storage/state.py4
-rw-r--r--synapse/visibility.py78
6 files changed, 57 insertions, 54 deletions
diff --git a/changelog.d/9676.misc b/changelog.d/9676.misc
new file mode 100644
index 0000000000..829e38b938
--- /dev/null
+++ b/changelog.d/9676.misc
@@ -0,0 +1 @@
+Add type hints to third party event rules and visibility modules.
diff --git a/mypy.ini b/mypy.ini
index e0685e097c..709a8d07a5 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -20,8 +20,9 @@ files =
   synapse/crypto,
   synapse/event_auth.py,
   synapse/events/builder.py,
-  synapse/events/validator.py,
   synapse/events/spamcheck.py,
+  synapse/events/third_party_rules.py,
+  synapse/events/validator.py,
   synapse/federation,
   synapse/groups,
   synapse/handlers,
@@ -38,6 +39,7 @@ files =
   synapse/push,
   synapse/replication,
   synapse/rest,
+  synapse/secrets.py,
   synapse/server.py,
   synapse/server_notices,
   synapse/spam_checker_api,
@@ -71,6 +73,7 @@ files =
   synapse/util/metrics.py,
   synapse/util/macaroons.py,
   synapse/util/stringutils.py,
+  synapse/visibility.py,
   tests/replication,
   tests/test_utils,
   tests/handlers/test_password_providers.py,
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 02bce8b5c9..9767d23940 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -13,12 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Callable, Union
+from typing import TYPE_CHECKING, Union
 
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import Requester, StateMap
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 
 class ThirdPartyEventRules:
     """Allows server admins to provide a Python module implementing an extra
@@ -28,7 +31,7 @@ class ThirdPartyEventRules:
     behaviours.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.third_party_rules = None
 
         self.store = hs.get_datastore()
@@ -95,10 +98,9 @@ class ThirdPartyEventRules:
         if self.third_party_rules is None:
             return True
 
-        ret = await self.third_party_rules.on_create_room(
+        return await self.third_party_rules.on_create_room(
             requester, config, is_requester_admin
         )
-        return ret
 
     async def check_threepid_can_be_invited(
         self, medium: str, address: str, room_id: str
@@ -119,10 +121,9 @@ class ThirdPartyEventRules:
 
         state_events = await self._get_state_map_for_room(room_id)
 
-        ret = await self.third_party_rules.check_threepid_can_be_invited(
+        return await self.third_party_rules.check_threepid_can_be_invited(
             medium, address, state_events
         )
-        return ret
 
     async def check_visibility_can_be_modified(
         self, room_id: str, new_visibility: str
@@ -143,7 +144,7 @@ class ThirdPartyEventRules:
         check_func = getattr(
             self.third_party_rules, "check_visibility_can_be_modified", None
         )
-        if not check_func or not isinstance(check_func, Callable):
+        if not check_func or not callable(check_func):
             return True
 
         state_events = await self._get_state_map_for_room(room_id)
diff --git a/synapse/secrets.py b/synapse/secrets.py
index fb6d90a3b7..7939db75e7 100644
--- a/synapse/secrets.py
+++ b/synapse/secrets.py
@@ -26,10 +26,10 @@ if sys.version_info[0:2] >= (3, 6):
     import secrets
 
     class Secrets:
-        def token_bytes(self, nbytes=32):
+        def token_bytes(self, nbytes: int = 32) -> bytes:
             return secrets.token_bytes(nbytes)
 
-        def token_hex(self, nbytes=32):
+        def token_hex(self, nbytes: int = 32) -> str:
             return secrets.token_hex(nbytes)
 
 
@@ -38,8 +38,8 @@ else:
     import os
 
     class Secrets:
-        def token_bytes(self, nbytes=32):
+        def token_bytes(self, nbytes: int = 32) -> bytes:
             return os.urandom(nbytes)
 
-        def token_hex(self, nbytes=32):
+        def token_hex(self, nbytes: int = 32) -> str:
             return binascii.hexlify(self.token_bytes(nbytes)).decode("ascii")
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index aa25bd8350..2e277a21c4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -449,7 +449,7 @@ class StateGroupStorage:
         return self.stores.state._get_state_groups_from_groups(groups, state_filter)
 
     async def get_state_for_events(
-        self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
     ) -> Dict[str, StateMap[EventBase]]:
         """Given a list of event_ids and type tuples, return a list of state
         dicts for each event.
@@ -485,7 +485,7 @@ class StateGroupStorage:
         return {event: event_to_state[event] for event in event_ids}
 
     async def get_state_ids_for_events(
-        self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
+        self, event_ids: Iterable[str], state_filter: StateFilter = StateFilter.all()
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
diff --git a/synapse/visibility.py b/synapse/visibility.py
index e39d02602a..ff53a49b3a 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import operator
+from typing import Dict, FrozenSet, List, Optional
 
 from synapse.api.constants import (
     AccountDataTypes,
@@ -21,10 +21,11 @@ from synapse.api.constants import (
     HistoryVisibility,
     Membership,
 )
+from synapse.events import EventBase
 from synapse.events.utils import prune_event
 from synapse.storage import Storage
 from synapse.storage.state import StateFilter
-from synapse.types import get_domain_from_id
+from synapse.types import StateMap, get_domain_from_id
 
 logger = logging.getLogger(__name__)
 
@@ -48,32 +49,32 @@ MEMBERSHIP_PRIORITY = (
 
 async def filter_events_for_client(
     storage: Storage,
-    user_id,
-    events,
-    is_peeking=False,
-    always_include_ids=frozenset(),
-    filter_send_to_client=True,
-):
+    user_id: str,
+    events: List[EventBase],
+    is_peeking: bool = False,
+    always_include_ids: FrozenSet[str] = frozenset(),
+    filter_send_to_client: bool = True,
+) -> List[EventBase]:
     """
     Check which events a user is allowed to see. If the user can see the event but its
     sender asked for their data to be erased, prune the content of the event.
 
     Args:
         storage
-        user_id(str): user id to be checked
-        events(list[synapse.events.EventBase]): sequence of events to be checked
-        is_peeking(bool): should be True if:
+        user_id: user id to be checked
+        events: sequence of events to be checked
+        is_peeking: should be True if:
           * the user is not currently a member of the room, and:
           * the user has not been a member of the room since the given
             events
-        always_include_ids (set(event_id)): set of event ids to specifically
+        always_include_ids: set of event ids to specifically
             include (unless sender is ignored)
-        filter_send_to_client (bool): Whether we're checking an event that's going to be
+        filter_send_to_client: Whether we're checking an event that's going to be
             sent to a client. This might not always be the case since this function can
             also be called to check whether a user can see the state at a given point.
 
     Returns:
-        list[synapse.events.EventBase]
+        The filtered events.
     """
     # Filter out events that have been soft failed so that we don't relay them
     # to clients.
@@ -90,7 +91,7 @@ async def filter_events_for_client(
         AccountDataTypes.IGNORED_USER_LIST, user_id
     )
 
-    ignore_list = frozenset()
+    ignore_list = frozenset()  # type: FrozenSet[str]
     if ignore_dict_content:
         ignored_users_dict = ignore_dict_content.get("ignored_users", {})
         if isinstance(ignored_users_dict, dict):
@@ -107,19 +108,18 @@ async def filter_events_for_client(
                 room_id
             ] = await storage.main.get_retention_policy_for_room(room_id)
 
-    def allowed(event):
+    def allowed(event: EventBase) -> Optional[EventBase]:
         """
         Args:
-            event (synapse.events.EventBase): event to check
+            event: event to check
 
         Returns:
-            None|EventBase:
-               None if the user cannot see this event at all
+           None if the user cannot see this event at all
 
-               a redacted copy of the event if they can only see a redacted
-               version
+           a redacted copy of the event if they can only see a redacted
+           version
 
-               the original event if they can see it as normal.
+           the original event if they can see it as normal.
         """
         # Only run some checks if these events aren't about to be sent to clients. This is
         # because, if this is not the case, we're probably only checking if the users can
@@ -252,48 +252,46 @@ async def filter_events_for_client(
 
         return event
 
-    # check each event: gives an iterable[None|EventBase]
+    # Check each event: gives an iterable of None or (a potentially modified)
+    # EventBase.
     filtered_events = map(allowed, events)
 
-    # remove the None entries
-    filtered_events = filter(operator.truth, filtered_events)
-
-    # we turn it into a list before returning it.
-    return list(filtered_events)
+    # Turn it into a list and remove None entries before returning.
+    return [ev for ev in filtered_events if ev]
 
 
 async def filter_events_for_server(
     storage: Storage,
-    server_name,
-    events,
-    redact=True,
-    check_history_visibility_only=False,
-):
+    server_name: str,
+    events: List[EventBase],
+    redact: bool = True,
+    check_history_visibility_only: bool = False,
+) -> List[EventBase]:
     """Filter a list of events based on whether given server is allowed to
     see them.
 
     Args:
         storage
-        server_name (str)
-        events (iterable[FrozenEvent])
-        redact (bool): Whether to return a redacted version of the event, or
+        server_name
+        events
+        redact: Whether to return a redacted version of the event, or
             to filter them out entirely.
-        check_history_visibility_only (bool): Whether to only check the
+        check_history_visibility_only: Whether to only check the
             history visibility, rather than things like if the sender has been
             erased. This is used e.g. during pagination to decide whether to
             backfill or not.
 
     Returns
-        list[FrozenEvent]
+        The filtered events.
     """
 
-    def is_sender_erased(event, erased_senders):
+    def is_sender_erased(event: EventBase, erased_senders: Dict[str, bool]) -> bool:
         if erased_senders and erased_senders[event.sender]:
             logger.info("Sender of %s has been erased, redacting", event.event_id)
             return True
         return False
 
-    def check_event_is_visible(event, state):
+    def check_event_is_visible(event: EventBase, state: StateMap[EventBase]) -> bool:
         history = state.get((EventTypes.RoomHistoryVisibility, ""), None)
         if history:
             visibility = history.content.get(