summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-05-15 11:19:43 -0400
committerGitHub <noreply@github.com>2020-05-15 11:19:43 -0400
commit5355421295315a75df6a39b34f25a9eea293545f (patch)
tree7b629e81d55ab7643ebce05af16b4d0638a6f45a /synapse
parentFix a small typo in the arguments of simple_update in update_remote_profile_c... (diff)
downloadsynapse-5355421295315a75df6a39b34f25a9eea293545f.tar.xz
Add type hints to event_auth code. (#7505)
Diffstat (limited to 'synapse')
-rw-r--r--synapse/event_auth.py78
1 files changed, 46 insertions, 32 deletions
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 5a5b568a95..c582355146 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import Set, Tuple
+from typing import List, Optional, Set, Tuple
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -29,18 +29,19 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersion,
 )
-from synapse.types import UserID, get_domain_from_id
+from synapse.events import EventBase
+from synapse.types import StateMap, UserID, get_domain_from_id
 
 logger = logging.getLogger(__name__)
 
 
 def check(
     room_version_obj: RoomVersion,
-    event,
-    auth_events,
-    do_sig_check=True,
-    do_size_check=True,
-):
+    event: EventBase,
+    auth_events: StateMap[EventBase],
+    do_sig_check: bool = True,
+    do_size_check: bool = True,
+) -> None:
     """ Checks if this event is correctly authed.
 
     Args:
@@ -189,7 +190,7 @@ def check(
     logger.debug("Allowing! %s", event)
 
 
-def _check_size_limits(event):
+def _check_size_limits(event: EventBase) -> None:
     def too_big(field):
         raise EventSizeError("%s too large" % (field,))
 
@@ -207,13 +208,18 @@ def _check_size_limits(event):
         too_big("event")
 
 
-def _can_federate(event, auth_events):
+def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
     creation_event = auth_events.get((EventTypes.Create, ""))
+    # There should always be a creation event, but if not don't federate.
+    if not creation_event:
+        return False
 
     return creation_event.content.get("m.federate", True) is True
 
 
-def _is_membership_change_allowed(event, auth_events):
+def _is_membership_change_allowed(
+    event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
     membership = event.content["membership"]
 
     # Check if this is the room creator joining:
@@ -339,21 +345,25 @@ def _is_membership_change_allowed(event, auth_events):
         raise AuthError(500, "Unknown membership %s" % membership)
 
 
-def _check_event_sender_in_room(event, auth_events):
+def _check_event_sender_in_room(
+    event: EventBase, auth_events: StateMap[EventBase]
+) -> None:
     key = (EventTypes.Member, event.user_id)
     member_event = auth_events.get(key)
 
-    return _check_joined_room(member_event, event.user_id, event.room_id)
+    _check_joined_room(member_event, event.user_id, event.room_id)
 
 
-def _check_joined_room(member, user_id, room_id):
+def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None:
     if not member or member.membership != Membership.JOIN:
         raise AuthError(
             403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
         )
 
 
-def get_send_level(etype, state_key, power_levels_event):
+def get_send_level(
+    etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase]
+) -> int:
     """Get the power level required to send an event of a given type
 
     The federation spec [1] refers to this as "Required Power Level".
@@ -361,13 +371,13 @@ def get_send_level(etype, state_key, power_levels_event):
     https://matrix.org/docs/spec/server_server/unstable.html#definitions
 
     Args:
-        etype (str): type of event
-        state_key (str|None): state_key of state event, or None if it is not
+        etype: type of event
+        state_key: state_key of state event, or None if it is not
             a state event.
-        power_levels_event (synapse.events.EventBase|None): power levels event
+        power_levels_event: power levels event
             in force at this point in the room
     Returns:
-        int: power level required to send this event.
+        power level required to send this event.
     """
 
     if power_levels_event:
@@ -388,7 +398,7 @@ def get_send_level(etype, state_key, power_levels_event):
     return int(send_level)
 
 
-def _can_send_event(event, auth_events):
+def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
     power_levels_event = _get_power_level_event(auth_events)
 
     send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
@@ -410,7 +420,9 @@ def _can_send_event(event, auth_events):
     return True
 
 
-def check_redaction(room_version_obj: RoomVersion, event, auth_events):
+def check_redaction(
+    room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+) -> bool:
     """Check whether the event sender is allowed to redact the target event.
 
     Returns:
@@ -442,7 +454,9 @@ def check_redaction(room_version_obj: RoomVersion, event, auth_events):
     raise AuthError(403, "You don't have permission to redact events")
 
 
-def _check_power_levels(room_version_obj, event, auth_events):
+def _check_power_levels(
+    room_version_obj: RoomVersion, event: EventBase, auth_events: StateMap[EventBase],
+) -> None:
     user_list = event.content.get("users", {})
     # Validate users
     for k, v in user_list.items():
@@ -473,7 +487,7 @@ def _check_power_levels(room_version_obj, event, auth_events):
         ("redact", None),
         ("kick", None),
         ("invite", None),
-    ]
+    ]  # type: List[Tuple[str, Optional[str]]]
 
     old_list = current_state.content.get("users", {})
     for user in set(list(old_list) + list(user_list)):
@@ -503,12 +517,12 @@ def _check_power_levels(room_version_obj, event, auth_events):
             new_loc = new_loc.get(dir, {})
 
         if level_to_check in old_loc:
-            old_level = int(old_loc[level_to_check])
+            old_level = int(old_loc[level_to_check])  # type: Optional[int]
         else:
             old_level = None
 
         if level_to_check in new_loc:
-            new_level = int(new_loc[level_to_check])
+            new_level = int(new_loc[level_to_check])  # type: Optional[int]
         else:
             new_level = None
 
@@ -534,21 +548,21 @@ def _check_power_levels(room_version_obj, event, auth_events):
             )
 
 
-def _get_power_level_event(auth_events):
+def _get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
     return auth_events.get((EventTypes.PowerLevels, ""))
 
 
-def get_user_power_level(user_id, auth_events):
+def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
     """Get a user's power level
 
     Args:
-        user_id (str): user's id to look up in power_levels
-        auth_events (dict[(str, str), synapse.events.EventBase]):
+        user_id: user's id to look up in power_levels
+        auth_events:
             state in force at this point in the room (or rather, a subset of
             it including at least the create event and power levels event.
 
     Returns:
-        int: the user's power level in this room.
+        the user's power level in this room.
     """
     power_level_event = _get_power_level_event(auth_events)
     if power_level_event:
@@ -574,7 +588,7 @@ def get_user_power_level(user_id, auth_events):
             return 0
 
 
-def _get_named_level(auth_events, name, default):
+def _get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
     power_level_event = _get_power_level_event(auth_events)
 
     if not power_level_event:
@@ -587,7 +601,7 @@ def _get_named_level(auth_events, name, default):
         return default
 
 
-def _verify_third_party_invite(event, auth_events):
+def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase]):
     """
     Validates that the invite event is authorized by a previous third-party invite.
 
@@ -662,7 +676,7 @@ def get_public_keys(invite_event):
     return public_keys
 
 
-def auth_types_for_event(event) -> Set[Tuple[str, str]]:
+def auth_types_for_event(event: EventBase) -> Set[Tuple[str, str]]:
     """Given an event, return a list of (EventType, StateKey) that may be
     needed to auth the event. The returned list may be a superset of what
     would actually be required depending on the full state of the room.