summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-05-05 16:35:16 +0100
committerErik Johnston <erik@matrix.org>2021-05-05 16:35:16 +0100
commitfaa7d48930d1c5c92d78d4863a385e9c0974fe42 (patch)
tree57497375a92643f4a57ca765581a9fd63a1fd807
parentCompress (diff)
downloadsynapse-faa7d48930d1c5c92d78d4863a385e9c0974fe42.tar.xz
More ensmalling
-rw-r--r--synapse/events/__init__.py119
-rw-r--r--synapse/events/validator.py4
2 files changed, 75 insertions, 48 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index ca66dc457a..c04905dfed 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -15,12 +15,12 @@
 # limitations under the License.
 
 import abc
-import attr
 import os
 import zlib
-from typing import Dict, Optional, Tuple, Type, Union
+from typing import Dict, List, Optional, Tuple, Type, Union
 
-from unpaddedbase64 import encode_base64
+import attr
+from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
 from synapse.types import JsonDict, RoomStreamToken
@@ -239,15 +239,46 @@ class _Signatures:
     def get_dict(self) -> JsonDict:
         return _decode_dict(self._signatures_bytes)
 
+    def get(self, server_name):
+        return self.get_dict().get(server_name)
+
     def update(self, other: Union[JsonDict, "_Signatures"]):
         if isinstance(other, _Signatures):
-            other_dict = _decode_dict(other)
+            other_dict = _decode_dict(other._signatures_bytes)
         else:
             other_dict = other
 
         signatures = self.get_dict()
         signatures.update(other_dict)
-        self._signatures_bytes = _encode_dict(self._signatures_bytes)
+        self._signatures_bytes = _encode_dict(signatures)
+
+
+class _SmallListV1(str):
+    __slots__ = []
+
+    def get(self):
+        return self.split(",")
+
+    @staticmethod
+    def create(event_ids):
+        return _SmallListV1(",".join(event_ids))
+
+
+class _SmallListV2_V3(bytes):
+    __slots__ = []
+
+    def get(self, url_safe):
+        i = 0
+        while i * 32 < len(self):
+            bit = self[i * 32 : (i + 1) * 32]
+            i += 1
+            yield "$" + encode_base64(bit, urlsafe=url_safe)
+
+    @staticmethod
+    def create(event_ids):
+        return _SmallListV2_V3(
+            b"".join(decode_base64(event_id[1:]) for event_id in event_ids)
+        )
 
 
 class EventBase(metaclass=abc.ABCMeta):
@@ -257,18 +288,17 @@ class EventBase(metaclass=abc.ABCMeta):
         "unsigned",
         "rejected_reason",
         "_encoded_dict",
-        "auth_events",
+        "_auth_event_ids",
         "depth",
         "_content",
         "_hashes",
         "origin",
         "origin_server_ts",
-        "prev_events",
+        "_prev_event_ids",
         "redacts",
         "room_id",
         "sender",
         "type",
-        "user_id",
         "state_key",
         "internal_metadata",
     ]
@@ -297,16 +327,13 @@ class EventBase(metaclass=abc.ABCMeta):
 
         self._encoded_dict = _encode_dict(event_dict)
 
-        self.auth_events = event_dict["auth_events"]
         self.depth = event_dict["depth"]
         self.origin = event_dict["origin"]
         self.origin_server_ts = event_dict["origin_server_ts"]
-        self.prev_events = event_dict["prev_events"]
         self.redacts = event_dict.get("redacts")
         self.room_id = event_dict["room_id"]
         self.sender = event_dict["sender"]
         self.type = event_dict["type"]
-        self.user_id = event_dict["sender"]
         if "state_key" in event_dict:
             self.state_key = event_dict["state_key"]
 
@@ -321,10 +348,18 @@ class EventBase(metaclass=abc.ABCMeta):
         return self.get_dict()["hashes"]
 
     @property
+    def prev_events(self) -> List[str]:
+        return list(self._prev_events)
+
+    @property
     def event_id(self) -> str:
         raise NotImplementedError()
 
     @property
+    def user_id(self) -> str:
+        return self.sender
+
+    @property
     def membership(self):
         return self.content["membership"]
 
@@ -355,24 +390,6 @@ class EventBase(metaclass=abc.ABCMeta):
     def __set__(self, instance, value):
         raise AttributeError("Unrecognized attribute %s" % (instance,))
 
-    def prev_event_ids(self):
-        """Returns the list of prev event IDs. The order matches the order
-        specified in the event, though there is no meaning to it.
-
-        Returns:
-            list[str]: The list of event IDs of this event's prev_events
-        """
-        return [e for e, _ in self.prev_events]
-
-    def auth_event_ids(self):
-        """Returns the list of auth event IDs. The order matches the order
-        specified in the event, though there is no meaning to it.
-
-        Returns:
-            list[str]: The list of event IDs of this event's auth_events
-        """
-        return [e for e, _ in self.auth_events]
-
     def freeze(self):
         """'Freeze' the event dict, so it cannot be modified by accident"""
 
@@ -413,6 +430,12 @@ class FrozenEvent(EventBase):
             frozen_dict = event_dict
 
         self._event_id = event_dict["event_id"]
+        self._auth_event_ids = _SmallListV1.create(
+            e for e, _ in event_dict["auth_events"]
+        )
+        self._prev_event_ids = _SmallListV1.create(
+            e for e, _ in event_dict["prev_events"]
+        )
 
         super().__init__(
             frozen_dict,
@@ -427,6 +450,12 @@ class FrozenEvent(EventBase):
     def event_id(self) -> str:
         return self._event_id
 
+    def auth_event_ids(self):
+        return list(self._auth_event_ids.get())
+
+    def prev_event_ids(self):
+        return list(self._prev_event_ids.get())
+
     def __str__(self):
         return self.__repr__()
 
@@ -475,6 +504,8 @@ class FrozenEventV2(EventBase):
             frozen_dict = event_dict
 
         self._event_id = None
+        self._auth_event_ids = _SmallListV2_V3.create(event_dict["auth_events"])
+        self._prev_event_ids = _SmallListV2_V3.create(event_dict["prev_events"])
 
         super().__init__(
             frozen_dict,
@@ -496,24 +527,6 @@ class FrozenEventV2(EventBase):
         self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
         return self._event_id
 
-    def prev_event_ids(self):
-        """Returns the list of prev event IDs. The order matches the order
-        specified in the event, though there is no meaning to it.
-
-        Returns:
-            list[str]: The list of event IDs of this event's prev_events
-        """
-        return self.prev_events
-
-    def auth_event_ids(self):
-        """Returns the list of auth event IDs. The order matches the order
-        specified in the event, though there is no meaning to it.
-
-        Returns:
-            list[str]: The list of event IDs of this event's auth_events
-        """
-        return self.auth_events
-
     def __str__(self):
         return self.__repr__()
 
@@ -525,6 +538,12 @@ class FrozenEventV2(EventBase):
             self.state_key if self.is_state() else None,
         )
 
+    def auth_event_ids(self):
+        return list(self._auth_event_ids.get(False))
+
+    def prev_event_ids(self):
+        return list(self._prev_event_ids.get(False))
+
 
 class FrozenEventV3(FrozenEventV2):
     """FrozenEventV3, which differs from FrozenEventV2 only in the event_id format"""
@@ -546,6 +565,12 @@ class FrozenEventV3(FrozenEventV2):
         )
         return self._event_id
 
+    def auth_event_ids(self):
+        return list(self._auth_event_ids.get(True))
+
+    def prev_event_ids(self):
+        return list(self._prev_event_ids.get(True))
+
 
 def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
     """Returns the python type to use to construct an Event object for the
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index fa6987d7cb..47a74fd5a3 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -38,6 +38,8 @@ class EventValidator:
         if event.format_version == EventFormatVersions.V1:
             EventID.from_string(event.event_id)
 
+        event_dict = event.get_dict()
+
         required = [
             "auth_events",
             "content",
@@ -49,7 +51,7 @@ class EventValidator:
         ]
 
         for k in required:
-            if not hasattr(event, k):
+            if k not in event_dict:
                 raise SynapseError(400, "Event does not have key %s" % (k,))
 
         # Check that the following keys have string values