summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-12-13 00:54:46 +0000
committerGitHub <noreply@github.com>2022-12-13 00:54:46 +0000
commite2a1adbf5d11288f2134ced1f84c6ffdd91a9357 (patch)
treea7e8ec0eee2585f55b6f275425a4007a29b6372e /synapse
parentEnable `--warn-redundant-casts` option in mypy (#14671) (diff)
downloadsynapse-e2a1adbf5d11288f2134ced1f84c6ffdd91a9357.tar.xz
Allow selecting "prejoin" events by state keys (#14642)
* Declare new config

* Parse new config

* Read new config

* Don't use trial/our TestCase where it's not needed

Before:

```
$ time trial tests/events/test_utils.py > /dev/null

real	0m2.277s
user	0m2.186s
sys	0m0.083s
```

After:
```
$ time trial tests/events/test_utils.py > /dev/null

real	0m0.566s
user	0m0.508s
sys	0m0.056s
```

* Helper to upsert to event fields

without exceeding size limits.

* Use helper when adding invite/knock state

Now that we allow admins to include events in prejoin room state with
arbitrary state keys, be a good Matrix citizen and ensure they don't
accidentally create an oversized event.

* Changelog

* Move StateFilter tests

should have done this in #14668

* Add extra methods to StateFilter

* Use StateFilter

* Ensure test file enforces typed defs; alphabetise

* Workaround surprising get_current_state_ids

* Whoops, fix mypy
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/_util.py3
-rw-r--r--synapse/config/api.py63
-rw-r--r--synapse/events/utils.py32
-rw-r--r--synapse/handlers/message.py29
-rw-r--r--synapse/storage/databases/main/events_worker.py33
-rw-r--r--synapse/types/state.py18
6 files changed, 131 insertions, 47 deletions
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index 3edb4b7106..d3a4b484ab 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -33,6 +33,9 @@ def validate_config(
         config: the configuration value to be validated
         config_path: the path within the config file. This will be used as a basis
            for the error message.
+
+    Raises:
+        ConfigError, if validation fails.
     """
     try:
         jsonschema.validate(config, json_schema)
diff --git a/synapse/config/api.py b/synapse/config/api.py
index e46728e73f..27d50d118f 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -13,12 +13,13 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Iterable
+from typing import Any, Iterable, Optional, Tuple
 
 from synapse.api.constants import EventTypes
 from synapse.config._base import Config, ConfigError
 from synapse.config._util import validate_config
 from synapse.types import JsonDict
+from synapse.types.state import StateFilter
 
 logger = logging.getLogger(__name__)
 
@@ -26,16 +27,20 @@ logger = logging.getLogger(__name__)
 class ApiConfig(Config):
     section = "api"
 
+    room_prejoin_state: StateFilter
+    track_puppetted_users_ips: bool
+
     def read_config(self, config: JsonDict, **kwargs: Any) -> None:
         validate_config(_MAIN_SCHEMA, config, ())
-        self.room_prejoin_state = list(self._get_prejoin_state_types(config))
+        self.room_prejoin_state = StateFilter.from_types(
+            self._get_prejoin_state_entries(config)
+        )
         self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
 
-    def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
-        """Get the event types to include in the prejoin state
-
-        Parses the config and returns an iterable of the event types to be included.
-        """
+    def _get_prejoin_state_entries(
+        self, config: JsonDict
+    ) -> Iterable[Tuple[str, Optional[str]]]:
+        """Get the event types and state keys to include in the prejoin state."""
         room_prejoin_state_config = config.get("room_prejoin_state") or {}
 
         # backwards-compatibility support for room_invite_state_types
@@ -50,33 +55,39 @@ class ApiConfig(Config):
 
             logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
 
-            yield from config["room_invite_state_types"]
+            for event_type in config["room_invite_state_types"]:
+                yield event_type, None
             return
 
         if not room_prejoin_state_config.get("disable_default_event_types"):
-            yield from _DEFAULT_PREJOIN_STATE_TYPES
+            yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS
 
-        yield from room_prejoin_state_config.get("additional_event_types", [])
+        for entry in room_prejoin_state_config.get("additional_event_types", []):
+            if isinstance(entry, str):
+                yield entry, None
+            else:
+                yield entry
 
 
 _ROOM_INVITE_STATE_TYPES_WARNING = """\
 WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
 and replaced with 'room_prejoin_state'. New features may not work correctly
-unless 'room_invite_state_types' is removed. See the sample configuration file for
-details of 'room_prejoin_state'.
+unless 'room_invite_state_types' is removed. See the config documentation at
+    https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state
+for details of 'room_prejoin_state'.
 --------------------------------------------------------------------------------
 """
 
-_DEFAULT_PREJOIN_STATE_TYPES = [
-    EventTypes.JoinRules,
-    EventTypes.CanonicalAlias,
-    EventTypes.RoomAvatar,
-    EventTypes.RoomEncryption,
-    EventTypes.Name,
+_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [
+    (EventTypes.JoinRules, ""),
+    (EventTypes.CanonicalAlias, ""),
+    (EventTypes.RoomAvatar, ""),
+    (EventTypes.RoomEncryption, ""),
+    (EventTypes.Name, ""),
     # Per MSC1772.
-    EventTypes.Create,
+    (EventTypes.Create, ""),
     # Per MSC3173.
-    EventTypes.Topic,
+    (EventTypes.Topic, ""),
 ]
 
 
@@ -90,7 +101,17 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
                 "disable_default_event_types": {"type": "boolean"},
                 "additional_event_types": {
                     "type": "array",
-                    "items": {"type": "string"},
+                    "items": {
+                        "oneOf": [
+                            {"type": "string"},
+                            {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "minItems": 2,
+                                "maxItems": 2,
+                            },
+                        ],
+                    },
                 },
             },
         },
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 71853caad8..13fa93afb8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -28,8 +28,14 @@ from typing import (
 )
 
 import attr
+from canonicaljson import encode_canonical_json
 
-from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
+from synapse.api.constants import (
+    MAX_PDU_SIZE,
+    EventContentFields,
+    EventTypes,
+    RelationTypes,
+)
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.types import JsonDict
@@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None:
     elif not isinstance(value, (bool, str)) and value is not None:
         # Other potential JSON values (bool, None, str) are safe.
         raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)
+
+
+def maybe_upsert_event_field(
+    event: EventBase, container: JsonDict, key: str, value: object
+) -> bool:
+    """Upsert an event field, but only if this doesn't make the event too large.
+
+    Returns true iff the upsert took place.
+    """
+    if key in container:
+        old_value: object = container[key]
+        container[key] = value
+        # NB: here and below, we assume that passing a non-None `time_now` argument to
+        # get_pdu_json doesn't increase the size of the encoded result.
+        upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
+        if not upsert_okay:
+            container[key] = old_value
+    else:
+        container[key] = value
+        upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
+        if not upsert_okay:
+            del container[key]
+
+    return upsert_okay
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d6e90ef259..845f683358 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase, relation_from_event
 from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext
+from synapse.events.utils import maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
 from synapse.logging import opentracing
@@ -1739,12 +1740,15 @@ class EventCreationHandler:
 
             if event.type == EventTypes.Member:
                 if event.content["membership"] == Membership.INVITE:
-                    event.unsigned[
-                        "invite_room_state"
-                    ] = await self.store.get_stripped_room_state_from_event_context(
-                        context,
-                        self.room_prejoin_state_types,
-                        membership_user_id=event.sender,
+                    maybe_upsert_event_field(
+                        event,
+                        event.unsigned,
+                        "invite_room_state",
+                        await self.store.get_stripped_room_state_from_event_context(
+                            context,
+                            self.room_prejoin_state_types,
+                            membership_user_id=event.sender,
+                        ),
                     )
 
                     invitee = UserID.from_string(event.state_key)
@@ -1762,11 +1766,14 @@ class EventCreationHandler:
                         event.signatures.update(returned_invite.signatures)
 
                 if event.content["membership"] == Membership.KNOCK:
-                    event.unsigned[
-                        "knock_room_state"
-                    ] = await self.store.get_stripped_room_state_from_event_context(
-                        context,
-                        self.room_prejoin_state_types,
+                    maybe_upsert_event_field(
+                        event,
+                        event.unsigned,
+                        "knock_room_state",
+                        await self.store.get_stripped_room_state_from_event_context(
+                            context,
+                            self.room_prejoin_state_types,
+                        ),
                     )
 
             if event.type == EventTypes.Redaction:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 01e935edef..318fd7dc71 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -16,11 +16,11 @@ import logging
 import threading
 import weakref
 from enum import Enum, auto
+from itertools import chain
 from typing import (
     TYPE_CHECKING,
     Any,
     Collection,
-    Container,
     Dict,
     Iterable,
     List,
@@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import (
 )
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
+from synapse.types.state import StateFilter
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
@@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
     async def get_stripped_room_state_from_event_context(
         self,
         context: EventContext,
-        state_types_to_include: Container[str],
+        state_keys_to_include: StateFilter,
         membership_user_id: Optional[str] = None,
     ) -> List[JsonDict]:
         """
@@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         Args:
             context: The event context to retrieve state of the room from.
-            state_types_to_include: The type of state events to include.
+            state_keys_to_include: The state events to include, for each event type.
             membership_user_id: An optional user ID to include the stripped membership state
                 events of. This is useful when generating the stripped state of a room for
                 invites. We want to send membership events of the inviter, so that the
@@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             A list of dictionaries, each representing a stripped state event from the room.
         """
-        current_state_ids = await context.get_current_state_ids()
+        if membership_user_id:
+            types = chain(
+                state_keys_to_include.to_types(),
+                [(EventTypes.Member, membership_user_id)],
+            )
+            filter = StateFilter.from_types(types)
+        else:
+            filter = state_keys_to_include
+        selected_state_ids = await context.get_current_state_ids(filter)
 
         # We know this event is not an outlier, so this must be
         # non-None.
-        assert current_state_ids is not None
-
-        # The state to include
-        state_to_include_ids = [
-            e_id
-            for k, e_id in current_state_ids.items()
-            if k[0] in state_types_to_include
-            or (membership_user_id and k == (EventTypes.Member, membership_user_id))
-        ]
+        assert selected_state_ids is not None
+
+        # Confusingly, get_current_state_events may return events that are discarded by
+        # the filter, if they're in context._state_delta_due_to_event. Strip these away.
+        selected_state_ids = filter.filter_state(selected_state_ids)
 
-        state_to_include = await self.get_events(state_to_include_ids)
+        state_to_include = await self.get_events(selected_state_ids.values())
 
         return [
             {
diff --git a/synapse/types/state.py b/synapse/types/state.py
index 0004d955b4..743a4f9217 100644
--- a/synapse/types/state.py
+++ b/synapse/types/state.py
@@ -118,6 +118,15 @@ class StateFilter:
             )
         )
 
+    def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
+        """The inverse to `from_types`."""
+        for (event_type, state_keys) in self.types.items():
+            if state_keys is None:
+                yield event_type, None
+            else:
+                for state_key in state_keys:
+                    yield event_type, state_key
+
     @staticmethod
     def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
         """Creates a filter that returns all non-member events, plus the member
@@ -343,6 +352,15 @@ class StateFilter:
             for s in state_keys
         ]
 
+    def wildcard_types(self) -> List[str]:
+        """Returns a list of event types which require us to fetch all state keys.
+        This will be empty unless `has_wildcards` returns True.
+
+        Returns:
+            A list of event types.
+        """
+        return [t for t, state_keys in self.types.items() if state_keys is None]
+
     def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
         """Return the filter split into two: one which assumes it's exclusively
         matching against member state, and one which assumes it's matching