summary refs log tree commit diff
path: root/synapse/event_auth.py
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2022-02-11 13:07:55 +0000
committerGitHub <noreply@github.com>2022-02-11 13:07:55 +0000
commit4ef39f3353ae2b2c2ccae78b93bb3a05f1911fe6 (patch)
treefcc475cd759628d61fd53291c4424789d3ededa5 /synapse/event_auth.py
parentAdds misc missing type hints (#11953) (diff)
downloadsynapse-4ef39f3353ae2b2c2ccae78b93bb3a05f1911fe6.tar.xz
fix import cycle (#11965)
Diffstat (limited to '')
-rw-r--r--synapse/event_auth.py54
1 files changed, 31 insertions, 23 deletions
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index 19b55a9559..eca00bc975 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+import typing
 from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 from canonicaljson import encode_canonical_json
@@ -34,15 +35,18 @@ from synapse.api.room_versions import (
     EventFormatVersions,
     RoomVersion,
 )
-from synapse.events import EventBase
-from synapse.events.builder import EventBuilder
 from synapse.types import StateMap, UserID, get_domain_from_id
 
+if typing.TYPE_CHECKING:
+    # conditional imports to avoid import cycle
+    from synapse.events import EventBase
+    from synapse.events.builder import EventBuilder
+
 logger = logging.getLogger(__name__)
 
 
 def validate_event_for_room_version(
-    room_version_obj: RoomVersion, event: EventBase
+    room_version_obj: RoomVersion, event: "EventBase"
 ) -> None:
     """Ensure that the event complies with the limits, and has the right signatures
 
@@ -113,7 +117,9 @@ def validate_event_for_room_version(
 
 
 def check_auth_rules_for_event(
-    room_version_obj: RoomVersion, event: EventBase, auth_events: Iterable[EventBase]
+    room_version_obj: RoomVersion,
+    event: "EventBase",
+    auth_events: Iterable["EventBase"],
 ) -> None:
     """Check that an event complies with the auth rules
 
@@ -256,7 +262,7 @@ def check_auth_rules_for_event(
     logger.debug("Allowing! %s", event)
 
 
-def _check_size_limits(event: EventBase) -> None:
+def _check_size_limits(event: "EventBase") -> None:
     if len(event.user_id) > 255:
         raise EventSizeError("'user_id' too large")
     if len(event.room_id) > 255:
@@ -271,7 +277,7 @@ def _check_size_limits(event: EventBase) -> None:
         raise EventSizeError("event too large")
 
 
-def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
+def _can_federate(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool:
     creation_event = auth_events.get((EventTypes.Create, ""))
     # There should always be a creation event, but if not don't federate.
     if not creation_event:
@@ -281,7 +287,7 @@ def _can_federate(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
 
 
 def _is_membership_change_allowed(
-    room_version: RoomVersion, event: EventBase, auth_events: StateMap[EventBase]
+    room_version: RoomVersion, event: "EventBase", auth_events: StateMap["EventBase"]
 ) -> None:
     """
     Confirms that the event which changes membership is an allowed change.
@@ -471,7 +477,7 @@ def _is_membership_change_allowed(
 
 
 def _check_event_sender_in_room(
-    event: EventBase, auth_events: StateMap[EventBase]
+    event: "EventBase", auth_events: StateMap["EventBase"]
 ) -> None:
     key = (EventTypes.Member, event.user_id)
     member_event = auth_events.get(key)
@@ -479,7 +485,9 @@ def _check_event_sender_in_room(
     _check_joined_room(member_event, event.user_id, event.room_id)
 
 
-def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str) -> None:
+def _check_joined_room(
+    member: Optional["EventBase"], user_id: str, room_id: str
+) -> None:
     if not member or member.membership != Membership.JOIN:
         raise AuthError(
             403, "User %s not in room %s (%s)" % (user_id, room_id, repr(member))
@@ -487,7 +495,7 @@ def _check_joined_room(member: Optional[EventBase], user_id: str, room_id: str)
 
 
 def get_send_level(
-    etype: str, state_key: Optional[str], power_levels_event: Optional[EventBase]
+    etype: str, state_key: Optional[str], power_levels_event: Optional["EventBase"]
 ) -> int:
     """Get the power level required to send an event of a given type
 
@@ -523,7 +531,7 @@ def get_send_level(
     return int(send_level)
 
 
-def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
+def _can_send_event(event: "EventBase", auth_events: StateMap["EventBase"]) -> bool:
     power_levels_event = get_power_level_event(auth_events)
 
     send_level = get_send_level(event.type, event.get("state_key"), power_levels_event)
@@ -547,8 +555,8 @@ def _can_send_event(event: EventBase, auth_events: StateMap[EventBase]) -> bool:
 
 def check_redaction(
     room_version_obj: RoomVersion,
-    event: EventBase,
-    auth_events: StateMap[EventBase],
+    event: "EventBase",
+    auth_events: StateMap["EventBase"],
 ) -> bool:
     """Check whether the event sender is allowed to redact the target event.
 
@@ -585,8 +593,8 @@ def check_redaction(
 
 def check_historical(
     room_version_obj: RoomVersion,
-    event: EventBase,
-    auth_events: StateMap[EventBase],
+    event: "EventBase",
+    auth_events: StateMap["EventBase"],
 ) -> None:
     """Check whether the event sender is allowed to send historical related
     events like "insertion", "batch", and "marker".
@@ -616,8 +624,8 @@ def check_historical(
 
 def _check_power_levels(
     room_version_obj: RoomVersion,
-    event: EventBase,
-    auth_events: StateMap[EventBase],
+    event: "EventBase",
+    auth_events: StateMap["EventBase"],
 ) -> None:
     user_list = event.content.get("users", {})
     # Validate users
@@ -710,11 +718,11 @@ def _check_power_levels(
             )
 
 
-def get_power_level_event(auth_events: StateMap[EventBase]) -> Optional[EventBase]:
+def get_power_level_event(auth_events: StateMap["EventBase"]) -> Optional["EventBase"]:
     return auth_events.get((EventTypes.PowerLevels, ""))
 
 
-def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
+def get_user_power_level(user_id: str, auth_events: StateMap["EventBase"]) -> int:
     """Get a user's power level
 
     Args:
@@ -750,7 +758,7 @@ def get_user_power_level(user_id: str, auth_events: StateMap[EventBase]) -> int:
             return 0
 
 
-def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -> int:
+def get_named_level(auth_events: StateMap["EventBase"], name: str, default: int) -> int:
     power_level_event = get_power_level_event(auth_events)
 
     if not power_level_event:
@@ -764,7 +772,7 @@ def get_named_level(auth_events: StateMap[EventBase], name: str, default: int) -
 
 
 def _verify_third_party_invite(
-    event: EventBase, auth_events: StateMap[EventBase]
+    event: "EventBase", auth_events: StateMap["EventBase"]
 ) -> bool:
     """
     Validates that the invite event is authorized by a previous third-party invite.
@@ -829,7 +837,7 @@ def _verify_third_party_invite(
     return False
 
 
-def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
+def get_public_keys(invite_event: "EventBase") -> List[Dict[str, Any]]:
     public_keys = []
     if "public_key" in invite_event.content:
         o = {"public_key": invite_event.content["public_key"]}
@@ -841,7 +849,7 @@ def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
 
 
 def auth_types_for_event(
-    room_version: RoomVersion, event: Union[EventBase, EventBuilder]
+    room_version: RoomVersion, event: Union["EventBase", "EventBuilder"]
 ) -> Set[Tuple[str, str]]:
     """Given an event, return a list of (EventType, StateKey) that may be
     needed to auth the event. The returned list may be a superset of what