summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-12-13 00:54:46 +0000
committerGitHub <noreply@github.com>2022-12-13 00:54:46 +0000
commite2a1adbf5d11288f2134ced1f84c6ffdd91a9357 (patch)
treea7e8ec0eee2585f55b6f275425a4007a29b6372e /synapse/storage/databases
parentEnable `--warn-redundant-casts` option in mypy (#14671) (diff)
downloadsynapse-e2a1adbf5d11288f2134ced1f84c6ffdd91a9357.tar.xz
Allow selecting "prejoin" events by state keys (#14642)
* Declare new config

* Parse new config

* Read new config

* Don't use trial/our TestCase where it's not needed

Before:

```
$ time trial tests/events/test_utils.py > /dev/null

real	0m2.277s
user	0m2.186s
sys	0m0.083s
```

After:
```
$ time trial tests/events/test_utils.py > /dev/null

real	0m0.566s
user	0m0.508s
sys	0m0.056s
```

* Helper to upsert to event fields

without exceeding size limits.

* Use helper when adding invite/knock state

Now that we allow admins to include events in prejoin room state with
arbitrary state keys, be a good Matrix citizen and ensure they don't
accidentally create an oversized event.

* Changelog

* Move StateFilter tests

should have done this in #14668

* Add extra methods to StateFilter

* Use StateFilter

* Ensure test file enforces typed defs; alphabetise

* Workaround surprising get_current_state_ids

* Whoops, fix mypy
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/events_worker.py33
1 files changed, 19 insertions, 14 deletions
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 01e935edef..318fd7dc71 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -16,11 +16,11 @@ import logging
 import threading
 import weakref
 from enum import Enum, auto
+from itertools import chain
 from typing import (
     TYPE_CHECKING,
     Any,
     Collection,
-    Container,
     Dict,
     Iterable,
     List,
@@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import (
 )
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import JsonDict, get_domain_from_id
+from synapse.types.state import StateFilter
 from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
@@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
     async def get_stripped_room_state_from_event_context(
         self,
         context: EventContext,
-        state_types_to_include: Container[str],
+        state_keys_to_include: StateFilter,
         membership_user_id: Optional[str] = None,
     ) -> List[JsonDict]:
         """
@@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         Args:
             context: The event context to retrieve state of the room from.
-            state_types_to_include: The type of state events to include.
+            state_keys_to_include: The state events to include, for each event type.
             membership_user_id: An optional user ID to include the stripped membership state
                 events of. This is useful when generating the stripped state of a room for
                 invites. We want to send membership events of the inviter, so that the
@@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
         Returns:
             A list of dictionaries, each representing a stripped state event from the room.
         """
-        current_state_ids = await context.get_current_state_ids()
+        if membership_user_id:
+            types = chain(
+                state_keys_to_include.to_types(),
+                [(EventTypes.Member, membership_user_id)],
+            )
+            filter = StateFilter.from_types(types)
+        else:
+            filter = state_keys_to_include
+        selected_state_ids = await context.get_current_state_ids(filter)
 
         # We know this event is not an outlier, so this must be
         # non-None.
-        assert current_state_ids is not None
-
-        # The state to include
-        state_to_include_ids = [
-            e_id
-            for k, e_id in current_state_ids.items()
-            if k[0] in state_types_to_include
-            or (membership_user_id and k == (EventTypes.Member, membership_user_id))
-        ]
+        assert selected_state_ids is not None
+
+        # Confusingly, get_current_state_events may return events that are discarded by
+        # the filter, if they're in context._state_delta_due_to_event. Strip these away.
+        selected_state_ids = filter.filter_state(selected_state_ids)
 
-        state_to_include = await self.get_events(state_to_include_ids)
+        state_to_include = await self.get_events(selected_state_ids.values())
 
         return [
             {