diff options
author | Erik Johnston <erik@matrix.org> | 2020-01-31 13:02:27 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2020-01-31 13:02:27 +0000 |
commit | 5013b3a49391548401e1ecbcf3e25bfaa31bbbb3 (patch) | |
tree | 6ac896dfe7258d1fea9bd9911c5d0e2796d13137 | |
parent | Add types to function signatures in SyncHandler (diff) | |
download | synapse-5013b3a49391548401e1ecbcf3e25bfaa31bbbb3.tar.xz |
Use a FrozenEvent base class for types
-rw-r--r-- | synapse/events/__init__.py | 23 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 30 |
2 files changed, 37 insertions, 16 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index f813fa2fe7..2691ac5545 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -257,7 +257,26 @@ class EventBase(object): return [e for e, _ in self.auth_events] -class FrozenEvent(EventBase): +class FrozenEventBase(EventBase): + """Base class for fully initialised events. + """ + + @property + def event_id(self) -> str: + raise NotImplementedError() + + @property + def type(self) -> str: + raise NotImplementedError() + + @property + def state_key(self) -> str: + """Raises if there is no state key. + """ + raise NotImplementedError() + + +class FrozenEvent(FrozenEventBase): format_version = EventFormatVersions.V1 # All events of this type are V1 def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): @@ -305,7 +324,7 @@ class FrozenEvent(EventBase): ) -class FrozenEventV2(EventBase): +class FrozenEventV2(FrozenEventBase): format_version = EventFormatVersions.V2 # All events of this type are V2 def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None): diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 6bd21ff761..3933f06c54 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -25,18 +25,18 @@ from prometheus_client import Counter from synapse.api.constants import EventTypes, Membership from synapse.api.filtering import FilterCollection -from synapse.events import EventBase +from synapse.events import FrozenEventBase from synapse.logging.context import LoggingContext from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( + Collection, JsonDict, RoomStreamToken, StateMap, StreamToken, UserID, - Collection, ) from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.expiringcache import ExpiringCache @@ -84,7 +84,7 @@ class SyncConfig: @attr.s(slots=True, frozen=True) class TimelineBatch: prev_batch = attr.ib(type=StreamToken) - events = attr.ib(type=List[EventBase]) + events = attr.ib(type=List[FrozenEventBase]) limited = attr.ib(bool) def __nonzero__(self) -> bool: @@ -100,7 +100,7 @@ class TimelineBatch: class JoinedSyncResult: room_id = attr.ib(type=str) timeline = attr.ib(type=TimelineBatch) - state = attr.ib(type=StateMap[EventBase]) + state = attr.ib(type=StateMap[FrozenEventBase]) ephemeral = attr.ib(type=List[JsonDict]) account_data = attr.ib(type=List[JsonDict]) unread_notifications = attr.ib(type=JsonDict) @@ -126,7 +126,7 @@ class JoinedSyncResult: class ArchivedSyncResult: room_id = attr.ib(type=str) timeline = attr.ib(type=TimelineBatch) - state = attr.ib(type=StateMap[EventBase]) + state = attr.ib(type=StateMap[FrozenEventBase]) account_data = attr.ib(type=List[JsonDict]) def __nonzero__(self) -> bool: @@ -141,7 +141,7 @@ class ArchivedSyncResult: @attr.s(slots=True, frozen=True) class InvitedSyncResult: room_id = attr.ib(type=str) - invite = attr.ib(type=EventBase) + invite = attr.ib(type=FrozenEventBase) def __nonzero__(self) -> bool: """Invited rooms should always be reported to the client""" @@ -419,7 +419,7 @@ class SyncHandler(object): sync_config: SyncConfig, now_token: StreamToken, since_token: Optional[StreamToken] = None, - potential_recents: Optional[List[EventBase]] = None, + potential_recents: Optional[List[FrozenEventBase]] = None, newly_joined_room: bool = False, ) -> TimelineBatch: """ @@ -539,7 +539,7 @@ class SyncHandler(object): ) async def get_state_after_event( - self, event: EventBase, state_filter: StateFilter = StateFilter.all() + self, event: FrozenEventBase, state_filter: StateFilter = StateFilter.all() ) -> StateMap[str]: """ Get the room state after the given event @@ -593,7 +593,7 @@ class SyncHandler(object): room_id: str, sync_config: SyncConfig, batch: TimelineBatch, - state: StateMap[EventBase], + state: StateMap[FrozenEventBase], now_token: StreamToken, ) -> Optional[JsonDict]: """ Works out a room summary block for this room, summarising the number @@ -743,7 +743,7 @@ class SyncHandler(object): since_token: Optional[StreamToken], now_token: StreamToken, full_state: bool, - ) -> StateMap[EventBase]: + ) -> StateMap[FrozenEventBase]: """ Works out the difference in state between the start of the timeline and the previous sync. @@ -922,7 +922,7 @@ class SyncHandler(object): if t[0] == EventTypes.Member: cache.set(t[1], event_id) - state = {} # type: Dict[str, EventBase] + state = {} # type: Dict[str, FrozenEventBase] if state_ids: state = await self.store.get_events(list(state_ids.values())) @@ -1488,7 +1488,7 @@ class SyncHandler(object): user_id, since_token.room_key, now_token.room_key ) - mem_change_events_by_room_id = {} # type: Dict[str, List[EventBase]] + mem_change_events_by_room_id = {} # type: Dict[str, List[FrozenEventBase]] for event in rooms_changed: mem_change_events_by_room_id.setdefault(event.room_id, []).append(event) @@ -1607,7 +1607,9 @@ class SyncHandler(object): # This is all screaming out for a refactor, as the logic here is # subtle and the moving parts numerous. if leave_event.internal_metadata.is_out_of_band_membership(): - batch_events = [leave_event] # type: Optional[List[EventBase]] + batch_events = [ + leave_event + ] # type: Optional[List[FrozenEventBase]] else: batch_events = None @@ -2070,7 +2072,7 @@ class RoomSyncResultBuilder(object): room_id = attr.ib(type=str) rtype = attr.ib(type=str) - events = attr.ib(type=Optional[List[EventBase]]) + events = attr.ib(type=Optional[List[FrozenEventBase]]) newly_joined = attr.ib(type=bool) full_state = attr.ib(type=bool) since_token = attr.ib(type=Optional[StreamToken]) |