diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index a842661a90..512254f65d 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -37,6 +37,65 @@ from synapse.util.frozenutils import freeze
USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
+class DictProperty:
+ """An object property which delegates to the `_dict` within its parent object."""
+
+ __slots__ = ["key"]
+
+ def __init__(self, key: str):
+ self.key = key
+
+ def __get__(self, instance, owner=None):
+ # if the property is accessed as a class property rather than an instance
+ # property, return the property itself rather than the value
+ if instance is None:
+ return self
+ try:
+ return instance._dict[self.key]
+ except KeyError as e1:
+ # We want this to look like a regular attribute error (mostly so that
+ # hasattr() works correctly), so we convert the KeyError into an
+ # AttributeError.
+ #
+ # To exclude the KeyError from the traceback, we explicitly
+ # 'raise from e1.__context__' (which is better than 'raise from None',
+ # becuase that would omit any *earlier* exceptions).
+ #
+ raise AttributeError(
+ "'%s' has no '%s' property" % (type(instance), self.key)
+ ) from e1.__context__
+
+ def __set__(self, instance, v):
+ instance._dict[self.key] = v
+
+ def __delete__(self, instance):
+ try:
+ del instance._dict[self.key]
+ except KeyError as e1:
+ raise AttributeError(
+ "'%s' has no '%s' property" % (type(instance), self.key)
+ ) from e1.__context__
+
+
+class DefaultDictProperty(DictProperty):
+ """An extension of DictProperty which provides a default if the property is
+ not present in the parent's _dict.
+
+ Note that this means that hasattr() on the property always returns True.
+ """
+
+ __slots__ = ["default"]
+
+ def __init__(self, key, default):
+ super().__init__(key)
+ self.default = default
+
+ def __get__(self, instance, owner=None):
+ if instance is None:
+ return self
+ return instance._dict.get(self.key, self.default)
+
+
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = dict(internal_metadata_dict)
@@ -117,51 +176,6 @@ class _EventInternalMetadata(object):
return getattr(self, "redacted", False)
-_SENTINEL = object()
-
-
-def _event_dict_property(key, default=_SENTINEL):
- """Creates a new property for the given key that delegates access to
- `self._event_dict`.
-
- The default is used if the key is missing from the `_event_dict`, if given,
- otherwise an AttributeError will be raised.
-
- Note: If a default is given then `hasattr` will always return true.
- """
-
- # We want to be able to use hasattr with the event dict properties.
- # However, (on python3) hasattr expects AttributeError to be raised. Hence,
- # we need to transform the KeyError into an AttributeError
-
- def getter_raises(self):
- try:
- return self._event_dict[key]
- except KeyError:
- raise AttributeError(key)
-
- def getter_default(self):
- return self._event_dict.get(key, default)
-
- def setter(self, v):
- try:
- self._event_dict[key] = v
- except KeyError:
- raise AttributeError(key)
-
- def delete(self):
- try:
- del self._event_dict[key]
- except KeyError:
- raise AttributeError(key)
-
- if default is _SENTINEL:
- # No default given, so use the getter that raises
- return property(getter_raises, setter, delete)
- else:
- return property(getter_default, setter, delete)
-
-
class EventBase(object):
def __init__(
self,
@@ -175,23 +189,23 @@ class EventBase(object):
self.unsigned = unsigned
self.rejected_reason = rejected_reason
- self._event_dict = event_dict
+ self._dict = event_dict
self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
- auth_events = _event_dict_property("auth_events")
- depth = _event_dict_property("depth")
- content = _event_dict_property("content")
- hashes = _event_dict_property("hashes")
- origin = _event_dict_property("origin")
- origin_server_ts = _event_dict_property("origin_server_ts")
- prev_events = _event_dict_property("prev_events")
- redacts = _event_dict_property("redacts", None)
- room_id = _event_dict_property("room_id")
- sender = _event_dict_property("sender")
- state_key = _event_dict_property("state_key")
- type = _event_dict_property("type")
- user_id = _event_dict_property("sender")
+ auth_events = DictProperty("auth_events")
+ depth = DictProperty("depth")
+ content = DictProperty("content")
+ hashes = DictProperty("hashes")
+ origin = DictProperty("origin")
+ origin_server_ts = DictProperty("origin_server_ts")
+ prev_events = DictProperty("prev_events")
+ redacts = DefaultDictProperty("redacts", None)
+ room_id = DictProperty("room_id")
+ sender = DictProperty("sender")
+ state_key = DictProperty("state_key")
+ type = DictProperty("type")
+ user_id = DictProperty("sender")
@property
def event_id(self) -> str:
@@ -205,13 +219,13 @@ class EventBase(object):
return hasattr(self, "state_key") and self.state_key is not None
def get_dict(self) -> JsonDict:
- d = dict(self._event_dict)
+ d = dict(self._dict)
d.update({"signatures": self.signatures, "unsigned": dict(self.unsigned)})
return d
def get(self, key, default=None):
- return self._event_dict.get(key, default)
+ return self._dict.get(key, default)
def get_internal_metadata_dict(self):
return self.internal_metadata.get_dict()
@@ -233,16 +247,16 @@ class EventBase(object):
raise AttributeError("Unrecognized attribute %s" % (instance,))
def __getitem__(self, field):
- return self._event_dict[field]
+ return self._dict[field]
def __contains__(self, field):
- return field in self._event_dict
+ return field in self._dict
def items(self):
- return list(self._event_dict.items())
+ return list(self._dict.items())
def keys(self):
- return six.iterkeys(self._event_dict)
+ return six.iterkeys(self._dict)
def prev_event_ids(self):
"""Returns the list of prev event IDs. The order matches the order
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index b9ee6ec1ec..db3667dc43 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -240,7 +240,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
built_event = yield self._base_builder.build(prev_event_ids)
built_event._event_id = self._event_id
- built_event._event_dict["event_id"] = self._event_id
+ built_event._dict["event_id"] = self._event_id
assert built_event.event_id == self._event_id
return built_event
|