From f63bedef07360216a8de71dc38f00f1aea503903 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 7 Mar 2022 09:00:05 -0500 Subject: Invalidate caches when an event with a relation is redacted. (#12121) The caches for the target of the relation must be cleared so that the bundled aggregations are re-calculated after the redaction is processed. --- synapse/storage/databases/main/cache.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'synapse/storage/databases/main/cache.py') diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index c428dd5596..abd54c7dc7 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -200,6 +200,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_relations_for_event.invalidate((relates_to,)) self.get_aggregation_groups_for_event.invalidate((relates_to,)) self.get_applicable_edit.invalidate((relates_to,)) + self.get_thread_summary.invalidate((relates_to,)) + self.get_thread_participated.invalidate((relates_to,)) async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): """Invalidates the cache and adds it to the cache stream so slaves -- cgit 1.5.1 From 88cd6f937807e64c05458cec86ef0ba0c1c656b3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 09:03:59 -0500 Subject: Allow retrieving the relations of a redacted event. (#12130) This is allowed per MSC2675, although the original implementation did not allow for it and would return an empty chunk / not bundle aggregations. The main thing to improve is that the various caches get cleared properly when an event is redacted, and that edits must not leak if the original event is redacted (as that would presumably leak something similar to the original event content). --- changelog.d/12130.bugfix | 1 + changelog.d/12189.bugfix | 1 + changelog.d/12189.misc | 1 - synapse/rest/client/relations.py | 82 +++++++++++++---------------- synapse/storage/databases/main/cache.py | 4 ++ synapse/storage/databases/main/events.py | 11 ++-- synapse/storage/databases/main/relations.py | 60 +++++++++++---------- tests/rest/client/test_relations.py | 45 ++++++++++++++-- 8 files changed, 122 insertions(+), 83 deletions(-) create mode 100644 changelog.d/12130.bugfix create mode 100644 changelog.d/12189.bugfix delete mode 100644 changelog.d/12189.misc (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/12130.bugfix b/changelog.d/12130.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12130.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12189.bugfix b/changelog.d/12189.bugfix new file mode 100644 index 0000000000..df9b0dc413 --- /dev/null +++ b/changelog.d/12189.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when redacting events with relations. diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc deleted file mode 100644 index 015e808e63..0000000000 --- a/changelog.d/12189.misc +++ /dev/null @@ -1 +0,0 @@ -Support skipping some arguments when generating cache keys. diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 07fa1cdd4c..d9a6be43f7 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -27,7 +27,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns -from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.storage.relations import AggregationPaginationToken from synapse.types import JsonDict, StreamToken if TYPE_CHECKING: @@ -82,28 +82,25 @@ class RelationPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = await StreamToken.from_string(self.store, from_token_str) - to_token = None - if to_token_str: - to_token = await StreamToken.from_string(self.store, to_token_str) - - pagination_chunk = await self.store.get_relations_for_event( - event_id=parent_id, - room_id=room_id, - relation_type=relation_type, - event_type=event_type, - limit=limit, - direction=direction, - from_token=from_token, - to_token=to_token, - ) + # Return the relations + from_token = None + if from_token_str: + from_token = await StreamToken.from_string(self.store, from_token_str) + to_token = None + if to_token_str: + to_token = await StreamToken.from_string(self.store, to_token_str) + + pagination_chunk = await self.store.get_relations_for_event( + event_id=parent_id, + event=event, + room_id=room_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + direction=direction, + from_token=from_token, + to_token=to_token, + ) events = await self.store.get_events_as_list( [c["event_id"] for c in pagination_chunk.chunk] @@ -193,27 +190,23 @@ class RelationAggregationPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - if event.internal_metadata.is_redacted(): - # If the event is redacted, return an empty list of relations - pagination_chunk = PaginationChunk(chunk=[]) - else: - # Return the relations - from_token = None - if from_token_str: - from_token = AggregationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = AggregationPaginationToken.from_string(to_token_str) - - pagination_chunk = await self.store.get_aggregation_groups_for_event( - event_id=parent_id, - room_id=room_id, - event_type=event_type, - limit=limit, - from_token=from_token, - to_token=to_token, - ) + # Return the relations + from_token = None + if from_token_str: + from_token = AggregationPaginationToken.from_string(from_token_str) + + to_token = None + if to_token_str: + to_token = AggregationPaginationToken.from_string(to_token_str) + + pagination_chunk = await self.store.get_aggregation_groups_for_event( + event_id=parent_id, + room_id=room_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) return 200, await pagination_chunk.to_dict(self.store) @@ -295,6 +288,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): result = await self.store.get_relations_for_event( event_id=parent_id, + event=event, room_id=room_id, relation_type=relation_type, event_type=event_type, diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index abd54c7dc7..d6a2df1afe 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -191,6 +191,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if redacts: self._invalidate_get_event_cache(redacts) + # Caches which might leak edits must be invalidated for the event being + # redacted. + self.get_relations_for_event.invalidate((redacts,)) + self.get_applicable_edit.invalidate((redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1dc83aa5e3..1a322882bf 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1619,9 +1619,12 @@ class PersistEventsStore: txn.call_after(prefill) - def _store_redaction(self, txn, event): - # invalidate the cache for the redacted event + def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: + # Invalidate the caches for the redacted event, note that these caches + # are also cleared as part of event replication in _invalidate_caches_for_event. txn.call_after(self.store._invalidate_get_event_cache, event.redacts) + txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) + txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) self.db_pool.simple_upsert_txn( txn, @@ -1812,9 +1815,7 @@ class PersistEventsStore: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) if rel_type == RelationTypes.THREAD: - txn.call_after( - self.store.get_thread_summary.invalidate, (parent_id, event.room_id) - ) + txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and # potentially error-prone) so it is always invalidated. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 36aa1092f6..be1500092b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -91,10 +91,11 @@ class RelationsWorkerStore(SQLBaseStore): self._msc3440_enabled = hs.config.experimental.msc3440_enabled - @cached(tree=True) + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, event_id: str, + event: EventBase, room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, @@ -108,6 +109,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + event: The matching EventBase to event_id. room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. @@ -122,9 +124,13 @@ class RelationsWorkerStore(SQLBaseStore): List of event IDs that match relations requested. The rows are of the form `{"event_id": "..."}`. """ + # We don't use `event_id`, it's there so that we can cache based on + # it. The `event_id` must match the `event.event_id`. + assert event.event_id == event_id where_clause = ["relates_to_id = ?", "room_id = ?"] - where_args: List[Union[str, int]] = [event_id, room_id] + where_args: List[Union[str, int]] = [event.event_id, room_id] + is_redacted = event.internal_metadata.is_redacted() if relation_type is not None: where_clause.append("relation_type = ?") @@ -157,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore): order = "ASC" sql = """ - SELECT event_id, topological_ordering, stream_ordering + SELECT event_id, relation_type, topological_ordering, stream_ordering FROM event_relations INNER JOIN events USING (event_id) WHERE %s @@ -178,9 +184,12 @@ class RelationsWorkerStore(SQLBaseStore): last_stream_id = None events = [] for row in txn: - events.append({"event_id": row[0]}) - last_topo_id = row[1] - last_stream_id = row[2] + # Do not include edits for redacted events as they leak event + # content. + if not is_redacted or row[1] != RelationTypes.REPLACE: + events.append({"event_id": row[0]}) + last_topo_id = row[2] + last_stream_id = row[3] # If there are more events, generate the next pagination key. next_token = None @@ -776,7 +785,7 @@ class RelationsWorkerStore(SQLBaseStore): ) references = await self.get_relations_for_event( - event_id, room_id, RelationTypes.REFERENCE, direction="f" + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations.references = await references.to_dict(cast("DataStore", self)) @@ -797,41 +806,36 @@ class RelationsWorkerStore(SQLBaseStore): A map of event ID to the bundled aggregation for the event. Not all events may have bundled aggregations in the results. """ - # The already processed event IDs. Tracked separately from the result - # since the result omits events which do not have bundled aggregations. - seen_event_ids = set() - - # State events and redacted events do not get bundled aggregations. - events = [ - event - for event in events - if not event.is_state() and not event.internal_metadata.is_redacted() - ] + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} # Fetch other relations per event. - for event in events: - # De-duplicate events by ID to handle the same event requested multiple - # times. The caches that _get_bundled_aggregation_for_event use should - # capture this, but best to reduce work. - if event.event_id in seen_event_ids: - continue - seen_event_ids.add(event.event_id) - + for event in events_by_id.values(): event_result = await self._get_bundled_aggregation_for_event(event, user_id) if event_result: results[event.event_id] = event_result - # Fetch any edits. - edits = await self._get_applicable_edits(seen_event_ids) + # Fetch any edits (but not for redacted events). + edits = await self._get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) for event_id, edit in edits.items(): results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. if self._msc3440_enabled: - summaries = await self._get_thread_summaries(seen_event_ids) + summaries = await self._get_thread_summaries(events_by_id.keys()) # Only fetch participated for a limited selection based on what had # summaries. participated = await self._get_threads_participated( diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index a40a5de399..f9ae6e663f 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1475,12 +1475,13 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self.assertEqual(relations, {}) def test_redact_parent_annotation(self) -> None: - """Test that annotations of an event are redacted when the original event + """Test that annotations of an event are viewable when the original event is redacted. """ # Add a relation channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] # The relations should exist. event_ids, relations = self._make_relation_requests() @@ -1494,11 +1495,45 @@ class RelationRedactionTestCase(BaseRelationsTestCase): # Redact the original event. self._redact(self.parent_id) - # The relations are not returned. + # The relations are returned. event_ids, relations = self._make_relation_requests() - self.assertEqual(event_ids, []) - self.assertEqual(relations, {}) + self.assertEquals(event_ids, [related_event_id]) + self.assertEquals( + relations["m.annotation"], + {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]}, + ) # There's nothing to aggregate. chunk = self._get_aggregations() - self.assertEqual(chunk, []) + self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}]) + + @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) + def test_redact_parent_thread(self) -> None: + """ + Test that thread replies are still available when the root event is redacted. + """ + channel = self._send_relation( + RelationTypes.THREAD, + EventTypes.Message, + content={"body": "reply 1", "msgtype": "m.text"}, + ) + self.assertEqual(200, channel.code, channel.json_body) + related_event_id = channel.json_body["event_id"] + + # Redact one of the reactions. + self._redact(self.parent_id) + + # The unredacted relation should still exist. + event_ids, relations = self._make_relation_requests() + self.assertEquals(len(event_ids), 1) + self.assertDictContainsSubset( + { + "count": 1, + "current_user_participated": True, + }, + relations[RelationTypes.THREAD], + ) + self.assertEqual( + relations[RelationTypes.THREAD]["latest_event"]["event_id"], + related_event_id, + ) -- cgit 1.5.1 From c486fa5fd9082643e40a55ffa59d902aa6db4c2b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 10:37:04 -0400 Subject: Add some missing type hints to cache datastore. (#12216) --- changelog.d/12216.misc | 1 + synapse/storage/databases/main/cache.py | 57 +++++++++++++++++++++------------ 2 files changed, 37 insertions(+), 21 deletions(-) create mode 100644 changelog.d/12216.misc (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/12216.misc b/changelog.d/12216.misc new file mode 100644 index 0000000000..dc398ac1e0 --- /dev/null +++ b/changelog.d/12216.misc @@ -0,0 +1 @@ +Add missing type hints for cache storage. diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index d6a2df1afe..2d7511d613 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -23,6 +23,7 @@ from synapse.replication.tcp.streams.events import ( EventsStream, EventsStreamCurrentStateRow, EventsStreamEventRow, + EventsStreamRow, ) from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -31,6 +32,7 @@ from synapse.storage.database import ( LoggingTransaction, ) from synapse.storage.engines import PostgresEngine +from synapse.util.caches.descriptors import _CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -82,7 +84,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if last_id == current_id: return [], current_id, False - def get_all_updated_caches_txn(txn): + def get_all_updated_caches_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: # We purposefully don't bound by the current token, as we want to # send across cache invalidations as quickly as possible. Cache # invalidations are idempotent, so duplicates are fine. @@ -107,7 +111,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "get_all_updated_caches", get_all_updated_caches_txn ) - def process_replication_rows(self, stream_name, instance_name, token, rows): + def process_replication_rows( + self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any] + ) -> None: if stream_name == EventsStream.NAME: for row in rows: self._process_event_stream_row(token, row) @@ -142,10 +148,11 @@ class CacheInvalidationWorkerStore(SQLBaseStore): super().process_replication_rows(stream_name, instance_name, token, rows) - def _process_event_stream_row(self, token, row): + def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None: data = row.data if row.type == EventsStreamEventRow.TypeId: + assert isinstance(data, EventsStreamEventRow) self._invalidate_caches_for_event( token, data.event_id, @@ -157,9 +164,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled=False, ) elif row.type == EventsStreamCurrentStateRow.TypeId: - self._curr_state_delta_stream_cache.entity_has_changed( - row.data.room_id, token - ) + assert isinstance(data, EventsStreamCurrentStateRow) + self._curr_state_delta_stream_cache.entity_has_changed(data.room_id, token) if data.type == EventTypes.Member: self.get_rooms_for_user_with_stream_ordering.invalidate( @@ -170,15 +176,15 @@ class CacheInvalidationWorkerStore(SQLBaseStore): def _invalidate_caches_for_event( self, - stream_ordering, - event_id, - room_id, - etype, - state_key, - redacts, - relates_to, - backfilled, - ): + stream_ordering: int, + event_id: str, + room_id: str, + etype: str, + state_key: Optional[str], + redacts: Optional[str], + relates_to: Optional[str], + backfilled: bool, + ) -> None: self._invalidate_get_event_cache(event_id) self.have_seen_event.invalidate((room_id, event_id)) @@ -207,7 +213,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_thread_summary.invalidate((relates_to,)) self.get_thread_participated.invalidate((relates_to,)) - async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): + async def invalidate_cache_and_stream( + self, cache_name: str, keys: Tuple[Any, ...] + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -227,7 +235,12 @@ class CacheInvalidationWorkerStore(SQLBaseStore): keys, ) - def _invalidate_cache_and_stream(self, txn, cache_func, keys): + def _invalidate_cache_and_stream( + self, + txn: LoggingTransaction, + cache_func: _CachedFunction, + keys: Tuple[Any, ...], + ) -> None: """Invalidates the cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -238,7 +251,9 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn.call_after(cache_func.invalidate, keys) self._send_invalidation_to_replication(txn, cache_func.__name__, keys) - def _invalidate_all_cache_and_stream(self, txn, cache_func): + def _invalidate_all_cache_and_stream( + self, txn: LoggingTransaction, cache_func: _CachedFunction + ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. """ @@ -279,8 +294,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): ) def _send_invalidation_to_replication( - self, txn, cache_name: str, keys: Optional[Iterable[Any]] - ): + self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] + ) -> None: """Notifies replication that given cache has been invalidated. Note that this does *not* invalidate the cache locally. @@ -315,7 +330,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, - "invalidation_ts": self.clock.time_msec(), + "invalidation_ts": self._clock.time_msec(), }, ) -- cgit 1.5.1 From 7ca8ee67a5165e33f03454218c81be96397e7591 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 25 Mar 2022 14:58:56 +0000 Subject: Add cache for `get_membership_from_event_ids` (#12272) This should speed up push rule calculations for rooms with large numbers of local users when the main push rule cache fails. Co-authored-by: reivilibre --- changelog.d/12272.misc | 1 + synapse/push/bulk_push_rule_evaluator.py | 30 +++++++++++----------- synapse/storage/databases/main/cache.py | 4 +++ synapse/storage/databases/main/events.py | 7 ++++++ synapse/storage/databases/main/roommember.py | 37 +++++++++++++++++++++++++--- synapse/storage/persist_events.py | 15 ++++++++--- 6 files changed, 72 insertions(+), 22 deletions(-) create mode 100644 changelog.d/12272.misc (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/12272.misc b/changelog.d/12272.misc new file mode 100644 index 0000000000..95589f3361 --- /dev/null +++ b/changelog.d/12272.misc @@ -0,0 +1 @@ +Add a new cache `_get_membership_from_event_id` to speed up push rule calculations in large rooms. diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 030898e4d0..a402a3e403 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -24,6 +24,7 @@ from synapse.event_auth import get_user_power_level from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY +from synapse.storage.databases.main.roommember import EventIdMembership from synapse.util.async_helpers import Linearizer from synapse.util.caches import CacheMetric, register_cache from synapse.util.caches.descriptors import lru_cache @@ -292,7 +293,7 @@ def _condition_checker( return True -MemberMap = Dict[str, Tuple[str, str]] +MemberMap = Dict[str, Optional[EventIdMembership]] Rule = Dict[str, dict] RulesByUser = Dict[str, List[Rule]] StateGroup = Union[object, int] @@ -306,7 +307,7 @@ class RulesForRoomData: *only* include data, and not references to e.g. the data stores. """ - # event_id -> (user_id, state) + # event_id -> EventIdMembership member_map: MemberMap = attr.Factory(dict) # user_id -> rules rules_by_user: RulesByUser = attr.Factory(dict) @@ -447,11 +448,10 @@ class RulesForRoom: res = self.data.member_map.get(event_id, None) if res: - user_id, state = res - if state == Membership.JOIN: - rules = self.data.rules_by_user.get(user_id, None) + if res.membership == Membership.JOIN: + rules = self.data.rules_by_user.get(res.user_id, None) if rules: - ret_rules_by_user[user_id] = rules + ret_rules_by_user[res.user_id] = rules continue # If a user has left a room we remove their push rule. If they @@ -502,24 +502,26 @@ class RulesForRoom: """ sequence = self.data.sequence - rows = await self.store.get_membership_from_event_ids(member_event_ids.values()) - - members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows} + members = await self.store.get_membership_from_event_ids( + member_event_ids.values() + ) - # If the event is a join event then it will be in current state evnts + # If the event is a join event then it will be in current state events # map but not in the DB, so we have to explicitly insert it. if event.type == EventTypes.Member: for event_id in member_event_ids.values(): if event_id == event.event_id: - members[event_id] = (event.state_key, event.membership) + members[event_id] = EventIdMembership( + user_id=event.state_key, membership=event.membership + ) if logger.isEnabledFor(logging.DEBUG): logger.debug("Found members %r: %r", self.room_id, members.values()) joined_user_ids = { - user_id - for user_id, membership in members.values() - if membership == Membership.JOIN + entry.user_id + for entry in members.values() + if entry and entry.membership == Membership.JOIN } logger.debug("Joined: %r", joined_user_ids) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2d7511d613..dd4e83a2ad 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -192,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) + # The `_get_membership_from_event_id` is immutable, except for the + # case where we look up an event *before* persisting it. + self._get_membership_from_event_id.invalidate((event_id,)) + if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1f60aef180..d253243125 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1745,6 +1745,13 @@ class PersistEventsStore: (event.state_key,), ) + # The `_get_membership_from_event_id` is immutable, except for the + # case where we look up an event *before* persisting it. + txn.call_after( + self.store._get_membership_from_event_id.invalidate, + (event.event_id,), + ) + # We update the local_current_membership table only if the event is # "current", i.e., its something that has just happened. # diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index bef675b845..3248da5356 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update" _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" +@attr.s(frozen=True, slots=True, auto_attribs=True) +class EventIdMembership: + """Returned by `get_membership_from_event_ids`""" + + user_id: str + membership: str + + class RoomMemberWorkerStore(EventsWorkerStore): def __init__( self, @@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): retcols=("user_id", "display_name", "avatar_url", "event_id"), keyvalues={"membership": Membership.JOIN}, batch_size=500, - desc="_get_membership_from_event_ids", + desc="_get_joined_profiles_from_event_ids", ) return { @@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore): return set(room_ids) + @cached(max_entries=5000) + async def _get_membership_from_event_id( + self, member_event_id: str + ) -> Optional[EventIdMembership]: + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_membership_from_event_id", list_name="member_event_ids" + ) async def get_membership_from_event_ids( self, member_event_ids: Iterable[str] - ) -> List[dict]: - """Get user_id and membership of a set of event IDs.""" + ) -> Dict[str, Optional[EventIdMembership]]: + """Get user_id and membership of a set of event IDs. + + Returns: + Mapping from event ID to `EventIdMembership` if the event is a + membership event, otherwise the value is None. + """ - return await self.db_pool.simple_select_many_batch( + rows = await self.db_pool.simple_select_many_batch( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore): desc="get_membership_from_event_ids", ) + return { + row["event_id"]: EventIdMembership( + membership=row["membership"], user_id=row["user_id"] + ) + for row in rows + } + async def is_local_host_in_room_ignoring_users( self, room_id: str, ignore_users: Collection[str] ) -> bool: diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py index 7d543fdbe0..b402922817 100644 --- a/synapse/storage/persist_events.py +++ b/synapse/storage/persist_events.py @@ -1023,8 +1023,13 @@ class EventsPersistenceStorage: # Check if any of the changes that we don't have events for are joins. if events_to_check: - rows = await self.main_store.get_membership_from_event_ids(events_to_check) - is_still_joined = any(row["membership"] == Membership.JOIN for row in rows) + members = await self.main_store.get_membership_from_event_ids( + events_to_check + ) + is_still_joined = any( + member and member.membership == Membership.JOIN + for member in members.values() + ) if is_still_joined: return True @@ -1060,9 +1065,11 @@ class EventsPersistenceStorage: ), event_id in current_state.items() if typ == EventTypes.Member and not self.is_mine_id(state_key) ] - rows = await self.main_store.get_membership_from_event_ids(remote_event_ids) + members = await self.main_store.get_membership_from_event_ids(remote_event_ids) potentially_left_users.update( - row["user_id"] for row in rows if row["membership"] == Membership.JOIN + member.user_id + for member in members.values() + if member and member.membership == Membership.JOIN ) return False -- cgit 1.5.1