diff --git a/changelog.d/14642.feature b/changelog.d/14642.feature
new file mode 100644
index 0000000000..cbc9db10c3
--- /dev/null
+++ b/changelog.d/14642.feature
@@ -0,0 +1 @@
+Allow selecting "prejoin" events by state keys in addition to event types.
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index dc5e5ac597..4d32902fea 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -2501,32 +2501,53 @@ Config settings related to the client/server API
---
### `room_prejoin_state`
-Controls for the state that is shared with users who receive an invite
-to a room. By default, the following state event types are shared with users who
-receive invites to the room:
-- m.room.join_rules
-- m.room.canonical_alias
-- m.room.avatar
-- m.room.encryption
-- m.room.name
-- m.room.create
-- m.room.topic
+This setting controls the state that is shared with users upon receiving an
+invite to a room, or in reply to a knock on a room. By default, the following
+state events are shared with users:
+
+- `m.room.join_rules`
+- `m.room.canonical_alias`
+- `m.room.avatar`
+- `m.room.encryption`
+- `m.room.name`
+- `m.room.create`
+- `m.room.topic`
To change the default behavior, use the following sub-options:
-* `disable_default_event_types`: set to true to disable the above defaults. If this
- is enabled, only the event types listed in `additional_event_types` are shared.
- Defaults to false.
-* `additional_event_types`: Additional state event types to share with users when they are invited
- to a room. By default, this list is empty (so only the default event types are shared).
+* `disable_default_event_types`: boolean. Set to `true` to disable the above
+ defaults. If this is enabled, only the event types listed in
+ `additional_event_types` are shared. Defaults to `false`.
+* `additional_event_types`: A list of additional state events to include in the
+ events to be shared. By default, this list is empty (so only the default event
+ types are shared).
+
+ Each entry in this list should be either a single string or a list of two
+ strings.
+ * A standalone string `t` represents all events with type `t` (i.e.
+ with no restrictions on state keys).
+ * A pair of strings `[t, s]` represents a single event with type `t` and
+ state key `s`. The same type can appear in two entries with different state
+ keys: in this situation, both state keys are included in prejoin state.
Example configuration:
```yaml
room_prejoin_state:
- disable_default_event_types: true
+ disable_default_event_types: false
additional_event_types:
- - org.example.custom.event.type
- - m.room.join_rules
+ # Share all events of type `org.example.custom.event.typeA`
+ - org.example.custom.event.typeA
+ # Share only events of type `org.example.custom.event.typeB` whose
+ # state_key is "foo"
+ - ["org.example.custom.event.typeB", "foo"]
+ # Share only events of type `org.example.custom.event.typeC` whose
+ # state_key is "bar" or "baz"
+ - ["org.example.custom.event.typeC", "bar"]
+ - ["org.example.custom.event.typeC", "baz"]
```
+
+*Changed in Synapse 1.74:* admins can filter the events in prejoin state based
+on their state key.
+
---
### `track_puppeted_user_ips`
diff --git a/mypy.ini b/mypy.ini
index 727536df50..37acf589c9 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -89,6 +89,12 @@ disallow_untyped_defs = False
[mypy-tests.*]
disallow_untyped_defs = False
+[mypy-tests.config.test_api]
+disallow_untyped_defs = True
+
+[mypy-tests.federation.transport.test_client]
+disallow_untyped_defs = True
+
[mypy-tests.handlers.test_sso]
disallow_untyped_defs = True
@@ -101,7 +107,7 @@ disallow_untyped_defs = True
[mypy-tests.push.test_bulk_push_rule_evaluator]
disallow_untyped_defs = True
-[mypy-tests.test_server]
+[mypy-tests.rest.*]
disallow_untyped_defs = True
[mypy-tests.state.test_profile]
@@ -110,10 +116,10 @@ disallow_untyped_defs = True
[mypy-tests.storage.*]
disallow_untyped_defs = True
-[mypy-tests.rest.*]
+[mypy-tests.test_server]
disallow_untyped_defs = True
-[mypy-tests.federation.transport.test_client]
+[mypy-tests.types.*]
disallow_untyped_defs = True
[mypy-tests.util.caches.*]
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index 3edb4b7106..d3a4b484ab 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -33,6 +33,9 @@ def validate_config(
config: the configuration value to be validated
config_path: the path within the config file. This will be used as a basis
for the error message.
+
+ Raises:
+ ConfigError, if validation fails.
"""
try:
jsonschema.validate(config, json_schema)
diff --git a/synapse/config/api.py b/synapse/config/api.py
index e46728e73f..27d50d118f 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -13,12 +13,13 @@
# limitations under the License.
import logging
-from typing import Any, Iterable
+from typing import Any, Iterable, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError
from synapse.config._util import validate_config
from synapse.types import JsonDict
+from synapse.types.state import StateFilter
logger = logging.getLogger(__name__)
@@ -26,16 +27,20 @@ logger = logging.getLogger(__name__)
class ApiConfig(Config):
section = "api"
+ room_prejoin_state: StateFilter
+ track_puppetted_users_ips: bool
+
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
validate_config(_MAIN_SCHEMA, config, ())
- self.room_prejoin_state = list(self._get_prejoin_state_types(config))
+ self.room_prejoin_state = StateFilter.from_types(
+ self._get_prejoin_state_entries(config)
+ )
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
- def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
- """Get the event types to include in the prejoin state
-
- Parses the config and returns an iterable of the event types to be included.
- """
+ def _get_prejoin_state_entries(
+ self, config: JsonDict
+ ) -> Iterable[Tuple[str, Optional[str]]]:
+ """Get the event types and state keys to include in the prejoin state."""
room_prejoin_state_config = config.get("room_prejoin_state") or {}
# backwards-compatibility support for room_invite_state_types
@@ -50,33 +55,39 @@ class ApiConfig(Config):
logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
- yield from config["room_invite_state_types"]
+ for event_type in config["room_invite_state_types"]:
+ yield event_type, None
return
if not room_prejoin_state_config.get("disable_default_event_types"):
- yield from _DEFAULT_PREJOIN_STATE_TYPES
+ yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS
- yield from room_prejoin_state_config.get("additional_event_types", [])
+ for entry in room_prejoin_state_config.get("additional_event_types", []):
+ if isinstance(entry, str):
+ yield entry, None
+ else:
+ yield entry
_ROOM_INVITE_STATE_TYPES_WARNING = """\
WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
and replaced with 'room_prejoin_state'. New features may not work correctly
-unless 'room_invite_state_types' is removed. See the sample configuration file for
-details of 'room_prejoin_state'.
+unless 'room_invite_state_types' is removed. See the config documentation at
+ https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state
+for details of 'room_prejoin_state'.
--------------------------------------------------------------------------------
"""
-_DEFAULT_PREJOIN_STATE_TYPES = [
- EventTypes.JoinRules,
- EventTypes.CanonicalAlias,
- EventTypes.RoomAvatar,
- EventTypes.RoomEncryption,
- EventTypes.Name,
+_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [
+ (EventTypes.JoinRules, ""),
+ (EventTypes.CanonicalAlias, ""),
+ (EventTypes.RoomAvatar, ""),
+ (EventTypes.RoomEncryption, ""),
+ (EventTypes.Name, ""),
# Per MSC1772.
- EventTypes.Create,
+ (EventTypes.Create, ""),
# Per MSC3173.
- EventTypes.Topic,
+ (EventTypes.Topic, ""),
]
@@ -90,7 +101,17 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
"disable_default_event_types": {"type": "boolean"},
"additional_event_types": {
"type": "array",
- "items": {"type": "string"},
+ "items": {
+ "oneOf": [
+ {"type": "string"},
+ {
+ "type": "array",
+ "items": {"type": "string"},
+ "minItems": 2,
+ "maxItems": 2,
+ },
+ ],
+ },
},
},
},
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 71853caad8..13fa93afb8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -28,8 +28,14 @@ from typing import (
)
import attr
+from canonicaljson import encode_canonical_json
-from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
+from synapse.api.constants import (
+ MAX_PDU_SIZE,
+ EventContentFields,
+ EventTypes,
+ RelationTypes,
+)
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict
@@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None:
elif not isinstance(value, (bool, str)) and value is not None:
# Other potential JSON values (bool, None, str) are safe.
raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)
+
+
+def maybe_upsert_event_field(
+ event: EventBase, container: JsonDict, key: str, value: object
+) -> bool:
+ """Upsert an event field, but only if this doesn't make the event too large.
+
+ Returns true iff the upsert took place.
+ """
+ if key in container:
+ old_value: object = container[key]
+ container[key] = value
+ # NB: here and below, we assume that passing a non-None `time_now` argument to
+ # get_pdu_json doesn't increase the size of the encoded result.
+ upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
+ if not upsert_okay:
+ container[key] = old_value
+ else:
+ container[key] = value
+ upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
+ if not upsert_okay:
+ del container[key]
+
+ return upsert_okay
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index d6e90ef259..845f683358 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
+from synapse.events.utils import maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.logging import opentracing
@@ -1739,12 +1740,15 @@ class EventCreationHandler:
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
- event.unsigned[
- "invite_room_state"
- ] = await self.store.get_stripped_room_state_from_event_context(
- context,
- self.room_prejoin_state_types,
- membership_user_id=event.sender,
+ maybe_upsert_event_field(
+ event,
+ event.unsigned,
+ "invite_room_state",
+ await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_prejoin_state_types,
+ membership_user_id=event.sender,
+ ),
)
invitee = UserID.from_string(event.state_key)
@@ -1762,11 +1766,14 @@ class EventCreationHandler:
event.signatures.update(returned_invite.signatures)
if event.content["membership"] == Membership.KNOCK:
- event.unsigned[
- "knock_room_state"
- ] = await self.store.get_stripped_room_state_from_event_context(
- context,
- self.room_prejoin_state_types,
+ maybe_upsert_event_field(
+ event,
+ event.unsigned,
+ "knock_room_state",
+ await self.store.get_stripped_room_state_from_event_context(
+ context,
+ self.room_prejoin_state_types,
+ ),
)
if event.type == EventTypes.Redaction:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 01e935edef..318fd7dc71 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -16,11 +16,11 @@ import logging
import threading
import weakref
from enum import Enum, auto
+from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Collection,
- Container,
Dict,
Iterable,
List,
@@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import (
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
+from synapse.types.state import StateFilter
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
@@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_stripped_room_state_from_event_context(
self,
context: EventContext,
- state_types_to_include: Container[str],
+ state_keys_to_include: StateFilter,
membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
@@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):
Args:
context: The event context to retrieve state of the room from.
- state_types_to_include: The type of state events to include.
+ state_keys_to_include: The state events to include, for each event type.
membership_user_id: An optional user ID to include the stripped membership state
events of. This is useful when generating the stripped state of a room for
invites. We want to send membership events of the inviter, so that the
@@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
A list of dictionaries, each representing a stripped state event from the room.
"""
- current_state_ids = await context.get_current_state_ids()
+ if membership_user_id:
+ types = chain(
+ state_keys_to_include.to_types(),
+ [(EventTypes.Member, membership_user_id)],
+ )
+ filter = StateFilter.from_types(types)
+ else:
+ filter = state_keys_to_include
+ selected_state_ids = await context.get_current_state_ids(filter)
# We know this event is not an outlier, so this must be
# non-None.
- assert current_state_ids is not None
-
- # The state to include
- state_to_include_ids = [
- e_id
- for k, e_id in current_state_ids.items()
- if k[0] in state_types_to_include
- or (membership_user_id and k == (EventTypes.Member, membership_user_id))
- ]
+ assert selected_state_ids is not None
+
+ # Confusingly, get_current_state_events may return events that are discarded by
+ # the filter, if they're in context._state_delta_due_to_event. Strip these away.
+ selected_state_ids = filter.filter_state(selected_state_ids)
- state_to_include = await self.get_events(state_to_include_ids)
+ state_to_include = await self.get_events(selected_state_ids.values())
return [
{
diff --git a/synapse/types/state.py b/synapse/types/state.py
index 0004d955b4..743a4f9217 100644
--- a/synapse/types/state.py
+++ b/synapse/types/state.py
@@ -118,6 +118,15 @@ class StateFilter:
)
)
+ def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
+ """The inverse to `from_types`."""
+ for (event_type, state_keys) in self.types.items():
+ if state_keys is None:
+ yield event_type, None
+ else:
+ for state_key in state_keys:
+ yield event_type, state_key
+
@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member
@@ -343,6 +352,15 @@ class StateFilter:
for s in state_keys
]
+ def wildcard_types(self) -> List[str]:
+ """Returns a list of event types which require us to fetch all state keys.
+ This will be empty unless `has_wildcards` returns True.
+
+ Returns:
+ A list of event types.
+ """
+ return [t for t, state_keys in self.types.items() if state_keys is None]
+
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching
diff --git a/tests/config/test_api.py b/tests/config/test_api.py
new file mode 100644
index 0000000000..6773c9a277
--- /dev/null
+++ b/tests/config/test_api.py
@@ -0,0 +1,145 @@
+from unittest import TestCase as StdlibTestCase
+
+import yaml
+
+from synapse.config import ConfigError
+from synapse.config.api import ApiConfig
+from synapse.types.state import StateFilter
+
+DEFAULT_PREJOIN_STATE_PAIRS = {
+ ("m.room.join_rules", ""),
+ ("m.room.canonical_alias", ""),
+ ("m.room.avatar", ""),
+ ("m.room.encryption", ""),
+ ("m.room.name", ""),
+ ("m.room.create", ""),
+ ("m.room.topic", ""),
+}
+
+
+class TestRoomPrejoinState(StdlibTestCase):
+ def read_config(self, source: str) -> ApiConfig:
+ config = ApiConfig()
+ config.read_config(yaml.safe_load(source))
+ return config
+
+ def test_no_prejoin_state(self) -> None:
+ config = self.read_config("foo: bar")
+ self.assertFalse(config.room_prejoin_state.has_wildcards())
+ self.assertEqual(
+ set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS
+ )
+
+ def test_disable_default_event_types(self) -> None:
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ """
+ )
+ self.assertEqual(config.room_prejoin_state, StateFilter.none())
+
+ def test_event_without_state_key(self) -> None:
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ additional_event_types:
+ - foo
+ """
+ )
+ self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
+ self.assertEqual(config.room_prejoin_state.concrete_types(), [])
+
+ def test_event_with_specific_state_key(self) -> None:
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ additional_event_types:
+ - [foo, bar]
+ """
+ )
+ self.assertFalse(config.room_prejoin_state.has_wildcards())
+ self.assertEqual(
+ set(config.room_prejoin_state.concrete_types()),
+ {("foo", "bar")},
+ )
+
+ def test_repeated_event_with_specific_state_key(self) -> None:
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ additional_event_types:
+ - [foo, bar]
+ - [foo, baz]
+ """
+ )
+ self.assertFalse(config.room_prejoin_state.has_wildcards())
+ self.assertEqual(
+ set(config.room_prejoin_state.concrete_types()),
+ {("foo", "bar"), ("foo", "baz")},
+ )
+
+ def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ additional_event_types:
+ - [foo, bar]
+ - foo
+ """
+ )
+ self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
+ self.assertEqual(config.room_prejoin_state.concrete_types(), [])
+
+ config = self.read_config(
+ """
+room_prejoin_state:
+ disable_default_event_types: true
+ additional_event_types:
+ - foo
+ - [foo, bar]
+ """
+ )
+ self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
+ self.assertEqual(config.room_prejoin_state.concrete_types(), [])
+
+ def test_bad_event_type_entry_raises(self) -> None:
+ with self.assertRaises(ConfigError):
+ self.read_config(
+ """
+room_prejoin_state:
+ additional_event_types:
+ - []
+ """
+ )
+
+ with self.assertRaises(ConfigError):
+ self.read_config(
+ """
+room_prejoin_state:
+ additional_event_types:
+ - [a]
+ """
+ )
+
+ with self.assertRaises(ConfigError):
+ self.read_config(
+ """
+room_prejoin_state:
+ additional_event_types:
+ - [a, b, c]
+ """
+ )
+
+ with self.assertRaises(ConfigError):
+ self.read_config(
+ """
+room_prejoin_state:
+ additional_event_types:
+ - [true, 1.23]
+ """
+ )
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index b1c47efac7..a79256846f 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import unittest as stdlib_unittest
+
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.utils import (
SerializeEventConfig,
copy_and_fixup_power_levels_contents,
+ maybe_upsert_event_field,
prune_event,
serialize_event,
)
from synapse.util.frozenutils import freeze
-from tests import unittest
-
def MockEvent(**kwargs):
if "event_id" not in kwargs:
@@ -34,7 +35,31 @@ def MockEvent(**kwargs):
return make_event_from_dict(kwargs)
-class PruneEventTestCase(unittest.TestCase):
+class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
+ def test_update_okay(self) -> None:
+ event = make_event_from_dict({"event_id": "$1234"})
+ success = maybe_upsert_event_field(event, event.unsigned, "key", "value")
+ self.assertTrue(success)
+ self.assertEqual(event.unsigned["key"], "value")
+
+ def test_update_not_okay(self) -> None:
+ event = make_event_from_dict({"event_id": "$1234"})
+ LARGE_STRING = "a" * 100_000
+ success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
+ self.assertFalse(success)
+ self.assertNotIn("key", event.unsigned)
+
+ def test_update_not_okay_leaves_original_value(self) -> None:
+ event = make_event_from_dict(
+ {"event_id": "$1234", "unsigned": {"key": "value"}}
+ )
+ LARGE_STRING = "a" * 100_000
+ success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
+ self.assertFalse(success)
+ self.assertEqual(event.unsigned["key"], "value")
+
+
+class PruneEventTestCase(stdlib_unittest.TestCase):
def run_test(self, evdict, matchdict, **kwargs):
"""
Asserts that a new event constructed with `evdict` will look like
@@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase):
)
-class SerializeEventTestCase(unittest.TestCase):
+class SerializeEventTestCase(stdlib_unittest.TestCase):
def serialize(self, ev, fields):
return serialize_event(
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
@@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase):
)
-class CopyPowerLevelsContentTestCase(unittest.TestCase):
+class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
def setUp(self) -> None:
self.test_content = {
"ban": 50,
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index a433e70870..bad7f0bc60 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID
from synapse.types.state import StateFilter
from synapse.util import Clock
-from tests.unittest import HomeserverTestCase, TestCase
+from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
@@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
-
-
-class StateFilterDifferenceTestCase(TestCase):
- def assert_difference(
- self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
- ) -> None:
- self.assertEqual(
- minuend.approx_difference(subtrahend),
- expected,
- f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
- )
-
- def test_state_filter_difference_no_include_other_minus_no_include_other(
- self,
- ) -> None:
- """
- Tests the StateFilter.approx_difference method
- where, in a.approx_difference(b), both a and b do not have the
- include_others flag set.
- """
- # (wildcard on state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.Create: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
- include_others=False,
- ),
- StateFilter.freeze({EventTypes.Create: None}, include_others=False),
- )
-
- # (wildcard on state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- self.assert_difference(
- StateFilter.freeze({EventTypes.Member: None}, include_others=False),
- StateFilter.freeze(
- {EventTypes.Member: {"@wombat:spqr"}},
- include_others=False,
- ),
- StateFilter.freeze({EventTypes.Member: None}, include_others=False),
- )
-
- # (wildcard on state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- )
-
- # (specific state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.CanonicalAlias: {""}},
- include_others=False,
- ),
- )
-
- # (specific state keys) - (specific state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- )
-
- # (specific state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- )
-
- def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
- """
- Tests the StateFilter.approx_difference method
- where, in a.approx_difference(b), only a has the include_others flag set.
- """
- # (wildcard on state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.Create: None},
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Create: None,
- EventTypes.Member: set(),
- EventTypes.CanonicalAlias: set(),
- },
- include_others=True,
- ),
- )
-
- # (wildcard on state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- # This also shows that the resultant state filter is normalised.
- self.assert_difference(
- StateFilter.freeze({EventTypes.Member: None}, include_others=True),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- EventTypes.Create: {""},
- },
- include_others=False,
- ),
- StateFilter(types=frozendict(), include_others=True),
- )
-
- # (wildcard on state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=False,
- ),
- StateFilter(
- types=frozendict(),
- include_others=True,
- ),
- )
-
- # (specific state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.CanonicalAlias: {""},
- EventTypes.Member: set(),
- },
- include_others=True,
- ),
- )
-
- # (specific state keys) - (specific state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- )
-
- # (specific state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- )
-
- def test_state_filter_difference_include_other_minus_include_other(self) -> None:
- """
- Tests the StateFilter.approx_difference method
- where, in a.approx_difference(b), both a and b have the include_others
- flag set.
- """
- # (wildcard on state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.Create: None},
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
- include_others=True,
- ),
- StateFilter(types=frozendict(), include_others=False),
- )
-
- # (wildcard on state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- self.assert_difference(
- StateFilter.freeze({EventTypes.Member: None}, include_others=True),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
- include_others=False,
- ),
- )
-
- # (wildcard on state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- )
-
- # (specific state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=True,
- ),
- StateFilter(
- types=frozendict(),
- include_others=False,
- ),
- )
-
- # (specific state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- EventTypes.Create: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- EventTypes.Create: set(),
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@spqr:spqr"},
- EventTypes.Create: {""},
- },
- include_others=False,
- ),
- )
-
- # (specific state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- },
- include_others=False,
- ),
- )
-
- def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
- """
- Tests the StateFilter.approx_difference method
- where, in a.approx_difference(b), only b has the include_others flag set.
- """
- # (wildcard on state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.Create: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
- include_others=True,
- ),
- StateFilter(types=frozendict(), include_others=False),
- )
-
- # (wildcard on state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- self.assert_difference(
- StateFilter.freeze({EventTypes.Member: None}, include_others=False),
- StateFilter.freeze(
- {EventTypes.Member: {"@wombat:spqr"}},
- include_others=True,
- ),
- StateFilter.freeze({EventTypes.Member: None}, include_others=False),
- )
-
- # (wildcard on state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- )
-
- # (specific state keys) - (wildcard on state keys):
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=True,
- ),
- StateFilter(
- types=frozendict(),
- include_others=False,
- ),
- )
-
- # (specific state keys) - (specific state keys)
- # This one is an over-approximation because we can't represent
- # 'all state keys except a few named examples'
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr"},
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@spqr:spqr"},
- },
- include_others=False,
- ),
- )
-
- # (specific state keys) - (no state keys)
- self.assert_difference(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- EventTypes.CanonicalAlias: {""},
- },
- include_others=False,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: set(),
- },
- include_others=True,
- ),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
- },
- include_others=False,
- ),
- )
-
- def test_state_filter_difference_simple_cases(self) -> None:
- """
- Tests some very simple cases of the StateFilter approx_difference,
- that are not explicitly tested by the more in-depth tests.
- """
-
- self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
-
- self.assert_difference(
- StateFilter.all(),
- StateFilter.none(),
- StateFilter.all(),
- )
-
-
-class StateFilterTestCase(TestCase):
- def test_return_expanded(self) -> None:
- """
- Tests the behaviour of the return_expanded() function that expands
- StateFilters to include more state types (for the sake of cache hit rate).
- """
-
- self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
-
- self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
-
- # Concrete-only state filters stay the same
- # (Case: mixed filter)
- self.assertEqual(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:test", "@alicia:test"},
- "some.other.state.type": {""},
- },
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:test", "@alicia:test"},
- "some.other.state.type": {""},
- },
- include_others=False,
- ),
- )
-
- # Concrete-only state filters stay the same
- # (Case: non-member-only filter)
- self.assertEqual(
- StateFilter.freeze(
- {"some.other.state.type": {""}}, include_others=False
- ).return_expanded(),
- StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
- )
-
- # Concrete-only state filters stay the same
- # (Case: member-only filter)
- self.assertEqual(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:test", "@alicia:test"},
- },
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:test", "@alicia:test"},
- },
- include_others=False,
- ),
- )
-
- # Wildcard member-only state filters stay the same
- self.assertEqual(
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze(
- {EventTypes.Member: None},
- include_others=False,
- ),
- )
-
- # If there is a wildcard in the non-member portion of the filter,
- # it's expanded to include ALL non-member events.
- # (Case: mixed filter)
- self.assertEqual(
- StateFilter.freeze(
- {
- EventTypes.Member: {"@wombat:test", "@alicia:test"},
- "some.other.state.type": None,
- },
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze(
- {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
- include_others=True,
- ),
- )
-
- # If there is a wildcard in the non-member portion of the filter,
- # it's expanded to include ALL non-member events.
- # (Case: non-member-only filter)
- self.assertEqual(
- StateFilter.freeze(
- {
- "some.other.state.type": None,
- },
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
- )
- self.assertEqual(
- StateFilter.freeze(
- {
- "some.other.state.type": None,
- "yet.another.state.type": {"wombat"},
- },
- include_others=False,
- ).return_expanded(),
- StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
- )
diff --git a/tests/types/__init__.py b/tests/types/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tests/types/__init__.py
diff --git a/tests/types/test_state.py b/tests/types/test_state.py
new file mode 100644
index 0000000000..eb809f9fb7
--- /dev/null
+++ b/tests/types/test_state.py
@@ -0,0 +1,627 @@
+from frozendict import frozendict
+
+from synapse.api.constants import EventTypes
+from synapse.types.state import StateFilter
+
+from tests.unittest import TestCase
+
+
+class StateFilterDifferenceTestCase(TestCase):
+ def assert_difference(
+ self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
+ ) -> None:
+ self.assertEqual(
+ minuend.approx_difference(subtrahend),
+ expected,
+ f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
+ )
+
+ def test_state_filter_difference_no_include_other_minus_no_include_other(
+ self,
+ ) -> None:
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), both a and b do not have the
+ include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ StateFilter.freeze({EventTypes.Create: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:spqr"}},
+ include_others=False,
+ ),
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.CanonicalAlias: {""}},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), only a has the include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Create: None,
+ EventTypes.Member: set(),
+ EventTypes.CanonicalAlias: set(),
+ },
+ include_others=True,
+ ),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ # This also shows that the resultant state filter is normalised.
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.Create: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter(types=frozendict(), include_others=True),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.CanonicalAlias: {""},
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ )
+
+ def test_state_filter_difference_include_other_minus_include_other(self) -> None:
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), both a and b have the include_others
+ flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=True,
+ ),
+ StateFilter(types=frozendict(), include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=True),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=False,
+ ),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ EventTypes.Create: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ EventTypes.Create: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ EventTypes.Create: {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
+ """
+ Tests the StateFilter.approx_difference method
+ where, in a.approx_difference(b), only b has the include_others flag set.
+ """
+ # (wildcard on state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.Create: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
+ include_others=True,
+ ),
+ StateFilter(types=frozendict(), include_others=False),
+ )
+
+ # (wildcard on state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:spqr"}},
+ include_others=True,
+ ),
+ StateFilter.freeze({EventTypes.Member: None}, include_others=False),
+ )
+
+ # (wildcard on state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (wildcard on state keys):
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=True,
+ ),
+ StateFilter(
+ types=frozendict(),
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (specific state keys)
+ # This one is an over-approximation because we can't represent
+ # 'all state keys except a few named examples'
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr"},
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ # (specific state keys) - (no state keys)
+ self.assert_difference(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ EventTypes.CanonicalAlias: {""},
+ },
+ include_others=False,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: set(),
+ },
+ include_others=True,
+ ),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
+ },
+ include_others=False,
+ ),
+ )
+
+ def test_state_filter_difference_simple_cases(self) -> None:
+ """
+ Tests some very simple cases of the StateFilter approx_difference,
+ that are not explicitly tested by the more in-depth tests.
+ """
+
+ self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
+
+ self.assert_difference(
+ StateFilter.all(),
+ StateFilter.none(),
+ StateFilter.all(),
+ )
+
+
+class StateFilterTestCase(TestCase):
+ def test_return_expanded(self) -> None:
+ """
+ Tests the behaviour of the return_expanded() function that expands
+ StateFilters to include more state types (for the sake of cache hit rate).
+ """
+
+ self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
+
+ self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
+
+ # Concrete-only state filters stay the same
+ # (Case: mixed filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": {""},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": {""},
+ },
+ include_others=False,
+ ),
+ )
+
+ # Concrete-only state filters stay the same
+ # (Case: non-member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {"some.other.state.type": {""}}, include_others=False
+ ).return_expanded(),
+ StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
+ )
+
+ # Concrete-only state filters stay the same
+ # (Case: member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ },
+ include_others=False,
+ ),
+ )
+
+ # Wildcard member-only state filters stay the same
+ self.assertEqual(
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {EventTypes.Member: None},
+ include_others=False,
+ ),
+ )
+
+ # If there is a wildcard in the non-member portion of the filter,
+ # it's expanded to include ALL non-member events.
+ # (Case: mixed filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ EventTypes.Member: {"@wombat:test", "@alicia:test"},
+ "some.other.state.type": None,
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze(
+ {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
+ include_others=True,
+ ),
+ )
+
+ # If there is a wildcard in the non-member portion of the filter,
+ # it's expanded to include ALL non-member events.
+ # (Case: non-member-only filter)
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ "some.other.state.type": None,
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+ )
+ self.assertEqual(
+ StateFilter.freeze(
+ {
+ "some.other.state.type": None,
+ "yet.another.state.type": {"wombat"},
+ },
+ include_others=False,
+ ).return_expanded(),
+ StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
+ )
|