summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py19
-rw-r--r--synapse/crypto/event_signing.py2
-rw-r--r--synapse/events/snapshot.py4
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/events/validator.py2
-rw-r--r--synapse/state/__init__.py10
-rw-r--r--synapse/storage/databases/main/stream.py4
-rw-r--r--synapse/types/__init__.py12
-rw-r--r--synapse/types/state.py26
-rw-r--r--synapse/util/__init__.py20
-rw-r--r--synapse/util/frozenutils.py6
11 files changed, 59 insertions, 48 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index a203ed533a..b97ee59f15 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -17,9 +17,9 @@
 """ This is an implementation of a Matrix homeserver.
 """
 
-import json
 import os
 import sys
+from typing import Any, Dict
 
 from synapse.util.rust import check_rust_lib_up_to_date
 from synapse.util.stringutils import strtobool
@@ -61,11 +61,20 @@ try:
 except ImportError:
     pass
 
-# Use the standard library json implementation instead of simplejson.
+# Teach canonicaljson how to serialise immutabledicts.
 try:
-    from canonicaljson import set_json_library
-
-    set_json_library(json)
+    from canonicaljson import register_preserialisation_callback
+    from immutabledict import immutabledict
+
+    def _immutabledict_cb(d: immutabledict) -> Dict[str, Any]:
+        try:
+            return d._dict
+        except Exception:
+            # Paranoia: fall back to a `dict()` call, in case a future version of
+            # immutabledict removes `_dict` from the implementation.
+            return dict(d)
+
+    register_preserialisation_callback(immutabledict, _immutabledict_cb)
 except ImportError:
     pass
 
diff --git a/synapse/crypto/event_signing.py b/synapse/crypto/event_signing.py
index 23b799ac32..1a293f1df0 100644
--- a/synapse/crypto/event_signing.py
+++ b/synapse/crypto/event_signing.py
@@ -51,7 +51,7 @@ def check_event_content_hash(
     # some malformed events lack a 'hashes'. Protect against it being missing
     # or a weird type by basically treating it the same as an unhashed event.
     hashes = event.get("hashes")
-    # nb it might be a frozendict or a dict
+    # nb it might be a immutabledict or a dict
     if not isinstance(hashes, collections.abc.Mapping):
         raise SynapseError(
             400, "Malformed 'hashes': %s" % (type(hashes),), Codes.UNAUTHORIZED
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index c04ad08cbb..9b4d692cf4 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,7 +15,7 @@ from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import attr
-from frozendict import frozendict
+from immutabledict import immutabledict
 
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
@@ -489,4 +489,4 @@ def _decode_state_dict(
     if input is None:
         return None
 
-    return frozendict({(etype, state_key): v for etype, state_key, v in input})
+    return immutabledict({(etype, state_key): v for etype, state_key, v in input})
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index b9c15ffcdb..e41c7a4b83 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -567,7 +567,7 @@ PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
 def copy_and_fixup_power_levels_contents(
     old_power_levels: PowerLevelsContent,
 ) -> Dict[str, Union[int, Dict[str, int]]]:
-    """Copy the content of a power_levels event, unfreezing frozendicts along the way.
+    """Copy the content of a power_levels event, unfreezing immutabledicts along the way.
 
     We accept as input power level values which are strings, provided they represent an
     integer, e.g. `"`100"` instead of 100. Such strings are converted to integers
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index fb1737b910..6f0e4386d3 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -258,7 +258,7 @@ POWER_LEVELS_SCHEMA = {
 def _create_power_level_validator() -> Type[jsonschema.Draft7Validator]:
     validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
 
-    # by default jsonschema does not consider a frozendict to be an object so
+    # by default jsonschema does not consider a immutabledict to be an object so
     # we need to use a custom type checker
     # https://python-jsonschema.readthedocs.io/en/stable/validate/?highlight=object#validating-with-additional-types
     type_checker = validator.TYPE_CHECKER.redefine(
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 4dc25df67e..6031095249 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -33,7 +33,7 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
+from immutabledict import immutabledict
 from prometheus_client import Counter, Histogram
 
 from synapse.api.constants import EventTypes
@@ -105,14 +105,18 @@ class _StateCacheEntry:
         #
         # This can be None if we have a `state_group` (as then we can fetch the
         # state from the DB.)
-        self._state = frozendict(state) if state is not None else None
+        self._state: Optional[StateMap[str]] = (
+            immutabledict(state) if state is not None else None
+        )
 
         # the ID of a state group if one and only one is involved.
         # otherwise, None otherwise?
         self.state_group = state_group
 
         self.prev_group = prev_group
-        self.delta_ids = frozendict(delta_ids) if delta_ids is not None else None
+        self.delta_ids: Optional[StateMap[str]] = (
+            immutabledict(delta_ids) if delta_ids is not None else None
+        )
 
     async def get_state(
         self,
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index ac5fbf6b86..2b8779bbb8 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -50,7 +50,7 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
+from immutabledict import immutabledict
 from typing_extensions import Literal
 
 from twisted.internet import defer
@@ -557,7 +557,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 if p > min_pos
             }
 
-        return RoomStreamToken(None, min_pos, frozendict(positions))
+        return RoomStreamToken(None, min_pos, immutabledict(positions))
 
     async def get_room_events_stream_for_rooms(
         self,
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 33363867c4..c09b9cf87d 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -35,7 +35,7 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
+from immutabledict import immutabledict
 from signedjson.key import decode_verify_key_bytes
 from signedjson.types import VerifyKey
 from typing_extensions import Final, TypedDict
@@ -490,12 +490,12 @@ class RoomStreamToken:
     )
     stream: int = attr.ib(validator=attr.validators.instance_of(int))
 
-    instance_map: "frozendict[str, int]" = attr.ib(
-        factory=frozendict,
+    instance_map: "immutabledict[str, int]" = attr.ib(
+        factory=immutabledict,
         validator=attr.validators.deep_mapping(
             key_validator=attr.validators.instance_of(str),
             value_validator=attr.validators.instance_of(int),
-            mapping_validator=attr.validators.instance_of(frozendict),
+            mapping_validator=attr.validators.instance_of(immutabledict),
         ),
     )
 
@@ -531,7 +531,7 @@ class RoomStreamToken:
                 return cls(
                     topological=None,
                     stream=stream,
-                    instance_map=frozendict(instance_map),
+                    instance_map=immutabledict(instance_map),
                 )
         except CancelledError:
             raise
@@ -566,7 +566,7 @@ class RoomStreamToken:
             for instance in set(self.instance_map).union(other.instance_map)
         }
 
-        return RoomStreamToken(None, max_stream, frozendict(instance_map))
+        return RoomStreamToken(None, max_stream, immutabledict(instance_map))
 
     def as_historical_tuple(self) -> Tuple[int, int]:
         """Returns a tuple of `(topological, stream)` for historical tokens.
diff --git a/synapse/types/state.py b/synapse/types/state.py
index 4b3071acce..1e78a74047 100644
--- a/synapse/types/state.py
+++ b/synapse/types/state.py
@@ -28,7 +28,7 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
+from immutabledict import immutabledict
 
 from synapse.api.constants import EventTypes
 from synapse.types import MutableStateMap, StateKey, StateMap
@@ -56,7 +56,7 @@ class StateFilter:
             appear in `types`.
     """
 
-    types: "frozendict[str, Optional[FrozenSet[str]]]"
+    types: "immutabledict[str, Optional[FrozenSet[str]]]"
     include_others: bool = False
 
     def __attrs_post_init__(self) -> None:
@@ -67,7 +67,7 @@ class StateFilter:
             object.__setattr__(
                 self,
                 "types",
-                frozendict({k: v for k, v in self.types.items() if v is not None}),
+                immutabledict({k: v for k, v in self.types.items() if v is not None}),
             )
 
     @staticmethod
@@ -112,7 +112,7 @@ class StateFilter:
             type_dict.setdefault(typ, set()).add(s)  # type: ignore
 
         return StateFilter(
-            types=frozendict(
+            types=immutabledict(
                 (k, frozenset(v) if v is not None else None)
                 for k, v in type_dict.items()
             )
@@ -139,7 +139,7 @@ class StateFilter:
             The new state filter
         """
         return StateFilter(
-            types=frozendict({EventTypes.Member: frozenset(members)}),
+            types=immutabledict({EventTypes.Member: frozenset(members)}),
             include_others=True,
         )
 
@@ -159,7 +159,7 @@ class StateFilter:
                 types_with_frozen_values[state_types] = None
 
         return StateFilter(
-            frozendict(types_with_frozen_values), include_others=include_others
+            immutabledict(types_with_frozen_values), include_others=include_others
         )
 
     def return_expanded(self) -> "StateFilter":
@@ -217,7 +217,7 @@ class StateFilter:
             # We want to return all non-members, but only particular
             # memberships
             return StateFilter(
-                types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
+                types=immutabledict({EventTypes.Member: self.types[EventTypes.Member]}),
                 include_others=True,
             )
         else:
@@ -381,14 +381,16 @@ class StateFilter:
             if state_keys is None:
                 member_filter = StateFilter.all()
             else:
-                member_filter = StateFilter(frozendict({EventTypes.Member: state_keys}))
+                member_filter = StateFilter(
+                    immutabledict({EventTypes.Member: state_keys})
+                )
         elif self.include_others:
             member_filter = StateFilter.all()
         else:
             member_filter = StateFilter.none()
 
         non_member_filter = StateFilter(
-            types=frozendict(
+            types=immutabledict(
                 {k: v for k, v in self.types.items() if k != EventTypes.Member}
             ),
             include_others=self.include_others,
@@ -578,8 +580,8 @@ class StateFilter:
         return False
 
 
-_ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
+_ALL_STATE_FILTER = StateFilter(types=immutabledict(), include_others=True)
 _ALL_NON_MEMBER_STATE_FILTER = StateFilter(
-    types=frozendict({EventTypes.Member: frozenset()}), include_others=True
+    types=immutabledict({EventTypes.Member: frozenset()}), include_others=True
 )
-_NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
+_NONE_STATE_FILTER = StateFilter(types=immutabledict(), include_others=False)
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 7be9d5f113..9ddd26ccaa 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -18,7 +18,7 @@ import typing
 from typing import Any, Callable, Dict, Generator, Optional, Sequence
 
 import attr
-from frozendict import frozendict
+from immutabledict import immutabledict
 from matrix_common.versionstring import get_distribution_version_string
 from typing_extensions import ParamSpec
 
@@ -41,22 +41,18 @@ def _reject_invalid_json(val: Any) -> None:
     raise ValueError("Invalid JSON value: '%s'" % val)
 
 
-def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
-    """Helper for json_encoder. Makes frozendicts serializable by returning
+def _handle_immutabledict(obj: Any) -> Dict[Any, Any]:
+    """Helper for json_encoder. Makes immutabledicts serializable by returning
     the underlying dict
     """
-    if type(obj) is frozendict:
+    if type(obj) is immutabledict:
         # fishing the protected dict out of the object is a bit nasty,
         # but we don't really want the overhead of copying the dict.
         try:
             # Safety: we catch the AttributeError immediately below.
-            # See https://github.com/matrix-org/python-canonicaljson/issues/36#issuecomment-927816293
-            # for discussion on how frozendict's internals have changed over time.
-            return obj._dict  # type: ignore[attr-defined]
+            return obj._dict
         except AttributeError:
-            # When the C implementation of frozendict is used,
-            # there isn't a `_dict` attribute with a dict
-            # so we resort to making a copy of the frozendict
+            # If all else fails, resort to making a copy of the immutabledict
             return dict(obj)
     raise TypeError(
         "Object of type %s is not JSON serializable" % obj.__class__.__name__
@@ -64,11 +60,11 @@ def _handle_frozendict(obj: Any) -> Dict[Any, Any]:
 
 
 # A custom JSON encoder which:
-#   * handles frozendicts
+#   * handles immutabledicts
 #   * produces valid JSON (no NaNs etc)
 #   * reduces redundant whitespace
 json_encoder = json.JSONEncoder(
-    allow_nan=False, separators=(",", ":"), default=_handle_frozendict
+    allow_nan=False, separators=(",", ":"), default=_handle_immutabledict
 )
 
 # Create a custom decoder to reject Python extensions to JSON.
diff --git a/synapse/util/frozenutils.py b/synapse/util/frozenutils.py
index 7223af1a36..889caa2601 100644
--- a/synapse/util/frozenutils.py
+++ b/synapse/util/frozenutils.py
@@ -14,14 +14,14 @@
 import collections.abc
 from typing import Any
 
-from frozendict import frozendict
+from immutabledict import immutabledict
 
 
 def freeze(o: Any) -> Any:
     if isinstance(o, dict):
-        return frozendict({k: freeze(v) for k, v in o.items()})
+        return immutabledict({k: freeze(v) for k, v in o.items()})
 
-    if isinstance(o, frozendict):
+    if isinstance(o, immutabledict):
         return o
 
     if isinstance(o, (bytes, str)):