diff options
author | Erik Johnston <erik@matrix.org> | 2021-05-05 16:35:16 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2021-05-05 16:35:16 +0100 |
commit | faa7d48930d1c5c92d78d4863a385e9c0974fe42 (patch) | |
tree | 57497375a92643f4a57ca765581a9fd63a1fd807 | |
parent | Compress (diff) | |
download | synapse-faa7d48930d1c5c92d78d4863a385e9c0974fe42.tar.xz |
More ensmalling
-rw-r--r-- | synapse/events/__init__.py | 119 | ||||
-rw-r--r-- | synapse/events/validator.py | 4 |
2 files changed, 75 insertions, 48 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index ca66dc457a..c04905dfed 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -15,12 +15,12 @@ # limitations under the License. import abc -import attr import os import zlib -from typing import Dict, Optional, Tuple, Type, Union +from typing import Dict, List, Optional, Tuple, Type, Union -from unpaddedbase64 import encode_base64 +import attr +from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions from synapse.types import JsonDict, RoomStreamToken @@ -239,15 +239,46 @@ class _Signatures: def get_dict(self) -> JsonDict: return _decode_dict(self._signatures_bytes) + def get(self, server_name): + return self.get_dict().get(server_name) + def update(self, other: Union[JsonDict, "_Signatures"]): if isinstance(other, _Signatures): - other_dict = _decode_dict(other) + other_dict = _decode_dict(other._signatures_bytes) else: other_dict = other signatures = self.get_dict() signatures.update(other_dict) - self._signatures_bytes = _encode_dict(self._signatures_bytes) + self._signatures_bytes = _encode_dict(signatures) + + +class _SmallListV1(str): + __slots__ = [] + + def get(self): + return self.split(",") + + @staticmethod + def create(event_ids): + return _SmallListV1(",".join(event_ids)) + + +class _SmallListV2_V3(bytes): + __slots__ = [] + + def get(self, url_safe): + i = 0 + while i * 32 < len(self): + bit = self[i * 32 : (i + 1) * 32] + i += 1 + yield "$" + encode_base64(bit, urlsafe=url_safe) + + @staticmethod + def create(event_ids): + return _SmallListV2_V3( + b"".join(decode_base64(event_id[1:]) for event_id in event_ids) + ) class EventBase(metaclass=abc.ABCMeta): @@ -257,18 +288,17 @@ class EventBase(metaclass=abc.ABCMeta): "unsigned", "rejected_reason", "_encoded_dict", - "auth_events", + "_auth_event_ids", "depth", "_content", "_hashes", "origin", "origin_server_ts", - "prev_events", + "_prev_event_ids", "redacts", "room_id", "sender", "type", - "user_id", "state_key", "internal_metadata", ] @@ -297,16 +327,13 @@ class EventBase(metaclass=abc.ABCMeta): self._encoded_dict = _encode_dict(event_dict) - self.auth_events = event_dict["auth_events"] self.depth = event_dict["depth"] self.origin = event_dict["origin"] self.origin_server_ts = event_dict["origin_server_ts"] - self.prev_events = event_dict["prev_events"] self.redacts = event_dict.get("redacts") self.room_id = event_dict["room_id"] self.sender = event_dict["sender"] self.type = event_dict["type"] - self.user_id = event_dict["sender"] if "state_key" in event_dict: self.state_key = event_dict["state_key"] @@ -321,10 +348,18 @@ class EventBase(metaclass=abc.ABCMeta): return self.get_dict()["hashes"] @property + def prev_events(self) -> List[str]: + return list(self._prev_events) + + @property def event_id(self) -> str: raise NotImplementedError() @property + def user_id(self) -> str: + return self.sender + + @property def membership(self): return self.content["membership"] @@ -355,24 +390,6 @@ class EventBase(metaclass=abc.ABCMeta): def __set__(self, instance, value): raise AttributeError("Unrecognized attribute %s" % (instance,)) - def prev_event_ids(self): - """Returns the list of prev event IDs. The order matches the order - specified in the event, though there is no meaning to it. - - Returns: - list[str]: The list of event IDs of this event's prev_events - """ - return [e for e, _ in self.prev_events] - - def auth_event_ids(self): - """Returns the list of auth event IDs. The order matches the order - specified in the event, though there is no meaning to it. - - Returns: - list[str]: The list of event IDs of this event's auth_events - """ - return [e for e, _ in self.auth_events] - def freeze(self): """'Freeze' the event dict, so it cannot be modified by accident""" @@ -413,6 +430,12 @@ class FrozenEvent(EventBase): frozen_dict = event_dict self._event_id = event_dict["event_id"] + self._auth_event_ids = _SmallListV1.create( + e for e, _ in event_dict["auth_events"] + ) + self._prev_event_ids = _SmallListV1.create( + e for e, _ in event_dict["prev_events"] + ) super().__init__( frozen_dict, @@ -427,6 +450,12 @@ class FrozenEvent(EventBase): def event_id(self) -> str: return self._event_id + def auth_event_ids(self): + return list(self._auth_event_ids.get()) + + def prev_event_ids(self): + return list(self._prev_event_ids.get()) + def __str__(self): return self.__repr__() @@ -475,6 +504,8 @@ class FrozenEventV2(EventBase): frozen_dict = event_dict self._event_id = None + self._auth_event_ids = _SmallListV2_V3.create(event_dict["auth_events"]) + self._prev_event_ids = _SmallListV2_V3.create(event_dict["prev_events"]) super().__init__( frozen_dict, @@ -496,24 +527,6 @@ class FrozenEventV2(EventBase): self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1]) return self._event_id - def prev_event_ids(self): - """Returns the list of prev event IDs. The order matches the order - specified in the event, though there is no meaning to it. - - Returns: - list[str]: The list of event IDs of this event's prev_events - """ - return self.prev_events - - def auth_event_ids(self): - """Returns the list of auth event IDs. The order matches the order - specified in the event, though there is no meaning to it. - - Returns: - list[str]: The list of event IDs of this event's auth_events - """ - return self.auth_events - def __str__(self): return self.__repr__() @@ -525,6 +538,12 @@ class FrozenEventV2(EventBase): self.state_key if self.is_state() else None, ) + def auth_event_ids(self): + return list(self._auth_event_ids.get(False)) + + def prev_event_ids(self): + return list(self._prev_event_ids.get(False)) + class FrozenEventV3(FrozenEventV2): """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format""" @@ -546,6 +565,12 @@ class FrozenEventV3(FrozenEventV2): ) return self._event_id + def auth_event_ids(self): + return list(self._auth_event_ids.get(True)) + + def prev_event_ids(self): + return list(self._prev_event_ids.get(True)) + def _event_type_from_format_version(format_version: int) -> Type[EventBase]: """Returns the python type to use to construct an Event object for the diff --git a/synapse/events/validator.py b/synapse/events/validator.py index fa6987d7cb..47a74fd5a3 100644 --- a/synapse/events/validator.py +++ b/synapse/events/validator.py @@ -38,6 +38,8 @@ class EventValidator: if event.format_version == EventFormatVersions.V1: EventID.from_string(event.event_id) + event_dict = event.get_dict() + required = [ "auth_events", "content", @@ -49,7 +51,7 @@ class EventValidator: ] for k in required: - if not hasattr(event, k): + if k not in event_dict: raise SynapseError(400, "Event does not have key %s" % (k,)) # Check that the following keys have string values |