diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index eec55b6478..41b015dba1 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -33,6 +33,7 @@ from synapse.api.constants import MAX_DEPTH, EventTypes
from synapse.api.errors import StoreError
from synapse.api.room_versions import EventFormatVersions, RoomVersion
from synapse.events import EventBase, make_event_from_dict
+from synapse.logging.tracing import tag_args, trace
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
@@ -126,6 +127,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
return await self.get_events_as_list(event_ids)
+ @trace
+ @tag_args
async def get_auth_chain_ids(
self,
room_id: str,
@@ -709,6 +712,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
+ @trace
+ @tag_args
async def get_oldest_event_ids_with_depth_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
@@ -767,6 +772,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
room_id,
)
+ @trace
async def get_insertion_event_backward_extremities_in_room(
self, room_id: str
) -> List[Tuple[str, int]]:
@@ -1339,6 +1345,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
event_results.reverse()
return event_results
+ @trace
+ @tag_args
async def get_successor_events(self, event_id: str) -> List[str]:
"""Fetch all events that have the given event as a prev event
@@ -1375,6 +1383,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
_delete_old_forward_extrem_cache_txn,
)
+ @trace
async def insert_insertion_extremity(self, event_id: str, room_id: str) -> None:
await self.db_pool.simple_upsert(
table="insertion_event_extremities",
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 161aad0f89..eabf9c9739 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -74,7 +74,17 @@ receipt.
"""
import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+ cast,
+)
import attr
@@ -154,7 +164,9 @@ class NotifCounts:
highlight_count: int = 0
-def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
+def _serialize_action(
+ actions: Collection[Union[Mapping, str]], is_highlight: bool
+) -> str:
"""Custom serializer for actions. This allows us to "compress" common actions.
We use the fact that most users have the same actions for notifs (and for
@@ -227,7 +239,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
) -> NotifCounts:
"""Get the notification count, the highlight count and the unread message count
- for a given user in a given room after the given read receipt.
+ for a given user in a given room after their latest read receipt.
Note that this function assumes the user to be a current member of the room,
since it's either called by the sync handler to handle joined room entries, or by
@@ -238,9 +250,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: The user to retrieve the counts for.
Returns
- A dict containing the counts mentioned earlier in this docstring,
- respectively under the keys "notify_count", "highlight_count" and
- "unread_count".
+ A NotifCounts object containing the notification count, the highlight count
+ and the unread message count.
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
@@ -255,6 +266,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
room_id: str,
user_id: str,
) -> NotifCounts:
+ # Get the stream ordering of the user's latest receipt in the room.
result = self.get_last_receipt_for_user_txn(
txn,
user_id,
@@ -266,13 +278,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
),
)
- stream_ordering = None
if result:
_, stream_ordering = result
- if stream_ordering is None:
- # Either last_read_event_id is None, or it's an event we don't have (e.g.
- # because it's been purged), in which case retrieve the stream ordering for
+ else:
+ # If the user has no receipts in the room, retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
@@ -289,10 +299,26 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
def _get_unread_counts_by_pos_txn(
- self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ user_id: str,
+ receipt_stream_ordering: int,
) -> NotifCounts:
"""Get the number of unread messages for a user/room that have happened
since the given stream ordering.
+
+ Args:
+ txn: The database transaction.
+ room_id: The room ID to get unread counts for.
+ user_id: The user ID to get unread counts for.
+ receipt_stream_ordering: The stream ordering of the user's latest
+ receipt in the room. If there are no receipts, the stream ordering
+ of the user's join event.
+
+ Returns
+ A NotifCounts object containing the notification count, the highlight count
+ and the unread message count.
"""
counts = NotifCounts()
@@ -320,7 +346,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
OR last_receipt_stream_ordering = ?
)
""",
- (room_id, user_id, stream_ordering, stream_ordering),
+ (room_id, user_id, receipt_stream_ordering, receipt_stream_ordering),
)
row = txn.fetchone()
@@ -338,17 +364,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
AND stream_ordering > ?
AND highlight = 1
"""
- txn.execute(sql, (user_id, room_id, stream_ordering))
+ txn.execute(sql, (user_id, room_id, receipt_stream_ordering))
row = txn.fetchone()
if row:
counts.highlight_count += row[0]
# Finally we need to count push actions that aren't included in the
- # summary returned above, e.g. recent events that haven't been
- # summarised yet, or the summary is empty due to a recent read receipt.
- stream_ordering = max(stream_ordering, summary_stream_ordering)
+ # summary returned above. This might be due to recent events that haven't
+ # been summarised yet or the summary is out of date due to a recent read
+ # receipt.
+ start_unread_stream_ordering = max(
+ receipt_stream_ordering, summary_stream_ordering
+ )
notify_count, unread_count = self._get_notif_unread_count_for_user_room(
- txn, room_id, user_id, stream_ordering
+ txn, room_id, user_id, start_unread_stream_ordering
)
counts.notify_count += notify_count
@@ -733,7 +762,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
async def add_push_actions_to_staging(
self,
event_id: str,
- user_id_actions: Dict[str, List[Union[dict, str]]],
+ user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
count_as_unread: bool,
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -750,7 +779,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# This is a helper function for generating the necessary tuple that
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(
- user_id: str, actions: List[Union[dict, str]]
+ user_id: str, actions: Collection[Union[Mapping, str]]
) -> Tuple[str, str, str, int, int, int]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
@@ -1151,8 +1180,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: The database transaction.
old_rotate_stream_ordering: The previous maximum event stream ordering.
rotate_to_stream_ordering: The new maximum event stream ordering to summarise.
-
- Returns whether the archiving process has caught up or not.
"""
# Calculate the new counts that should be upserted into event_push_summary
@@ -1238,9 +1265,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
(rotate_to_stream_ordering,),
)
- async def _remove_old_push_actions_that_have_rotated(
- self,
- ) -> None:
+ async def _remove_old_push_actions_that_have_rotated(self) -> None:
"""Clear out old push actions that have been summarised."""
# We want to clear out anything that is older than a day that *has* already
@@ -1397,7 +1422,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
]
-def _action_has_highlight(actions: List[Union[dict, str]]) -> bool:
+def _action_has_highlight(actions: Collection[Union[Mapping, str]]) -> bool:
for action in actions:
if not isinstance(action, dict):
continue
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5560b38a48..1c3b804da0 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -40,6 +40,7 @@ from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
+from synapse.logging.tracing import trace
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -145,6 +146,7 @@ class PersistEventsStore:
self._backfill_id_gen: AbstractStreamIdGenerator = self.store._backfill_id_gen
self._stream_id_gen: AbstractStreamIdGenerator = self.store._stream_id_gen
+ @trace
async def _persist_events_and_state_updates(
self,
events_and_contexts: List[Tuple[EventBase, EventContext]],
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index e9ff6cfb34..90e6d82058 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -54,6 +54,7 @@ from synapse.logging.context import (
current_context,
make_deferred_yieldable,
)
+from synapse.logging.tracing import start_active_span, tag_args, trace
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
@@ -430,6 +431,8 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
+ @trace
+ @tag_args
async def get_events_as_list(
self,
event_ids: Collection[str],
@@ -1090,23 +1093,42 @@ class EventsWorkerStore(SQLBaseStore):
"""
fetched_event_ids: Set[str] = set()
fetched_events: Dict[str, _EventRow] = {}
- events_to_fetch = event_ids
- while events_to_fetch:
- row_map = await self._enqueue_events(events_to_fetch)
+ async def _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids_to_fetch: Collection[str],
+ ) -> Collection[str]:
+ """
+ Fetch all of the given event_ids and return any associated redaction event_ids
+ that we still need to fetch in the next iteration.
+ """
+ row_map = await self._enqueue_events(event_ids_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids: Set[str] = set()
- for event_id in events_to_fetch:
+ for event_id in event_ids_to_fetch:
row = row_map.get(event_id)
fetched_event_ids.add(event_id)
if row:
fetched_events[event_id] = row
redaction_ids.update(row.redactions)
- events_to_fetch = redaction_ids.difference(fetched_event_ids)
- if events_to_fetch:
- logger.debug("Also fetching redaction events %s", events_to_fetch)
+ event_ids_to_fetch = redaction_ids.difference(fetched_event_ids)
+ return event_ids_to_fetch
+
+ # Grab the initial list of events requested
+ event_ids_to_fetch = await _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids
+ )
+ # Then go and recursively find all of the associated redactions
+ with start_active_span("recursively fetching redactions"):
+ while event_ids_to_fetch:
+ logger.debug("Also fetching redaction events %s", event_ids_to_fetch)
+
+ event_ids_to_fetch = (
+ await _fetch_event_ids_and_get_outstanding_redactions(
+ event_ids_to_fetch
+ )
+ )
# build a map from event_id to EventBase
event_map: Dict[str, EventBase] = {}
@@ -1424,6 +1446,8 @@ class EventsWorkerStore(SQLBaseStore):
return {r["event_id"] for r in rows}
+ @trace
+ @tag_args
async def have_seen_events(
self, room_id: str, event_ids: Iterable[str]
) -> Set[str]:
@@ -2200,3 +2224,63 @@ class EventsWorkerStore(SQLBaseStore):
(room_id,),
)
return [row[0] for row in txn]
+
+ def mark_event_rejected_txn(
+ self,
+ txn: LoggingTransaction,
+ event_id: str,
+ rejection_reason: Optional[str],
+ ) -> None:
+ """Mark an event that was previously accepted as rejected, or vice versa
+
+ This can happen, for example, when resyncing state during a faster join.
+
+ Args:
+ txn:
+ event_id: ID of event to update
+ rejection_reason: reason it has been rejected, or None if it is now accepted
+ """
+ if rejection_reason is None:
+ logger.info(
+ "Marking previously-processed event %s as accepted",
+ event_id,
+ )
+ self.db_pool.simple_delete_txn(
+ txn,
+ "rejections",
+ keyvalues={"event_id": event_id},
+ )
+ else:
+ logger.info(
+ "Marking previously-processed event %s as rejected(%s)",
+ event_id,
+ rejection_reason,
+ )
+ self.db_pool.simple_upsert_txn(
+ txn,
+ table="rejections",
+ keyvalues={"event_id": event_id},
+ values={
+ "reason": rejection_reason,
+ "last_check": self._clock.time_msec(),
+ },
+ )
+ self.db_pool.simple_update_txn(
+ txn,
+ table="events",
+ keyvalues={"event_id": event_id},
+ updatevalues={"rejection_reason": rejection_reason},
+ )
+
+ self.invalidate_get_event_cache_after_txn(txn, event_id)
+
+ # TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
+ # call '_send_invalidation_to_replication', but we actually need the other
+ # end to call _invalidate_local_get_event_cache() rather than (just)
+ # _get_event_cache.invalidate().
+ #
+ # One solution might be to (somehow) get the workers to call
+ # _invalidate_caches_for_event() (though that will invalidate more than
+ # strictly necessary).
+ #
+ # https://github.com/matrix-org/synapse/issues/12994
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 768f95d16c..255620f996 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -14,11 +14,23 @@
# limitations under the License.
import abc
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Tuple, Union, cast
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Collection,
+ Dict,
+ List,
+ Mapping,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+ cast,
+)
from synapse.api.errors import StoreError
from synapse.config.homeserver import ExperimentalConfig
-from synapse.push.baserules import list_with_base_rules
+from synapse.push.baserules import FilteredPushRules, PushRule, compile_push_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -50,60 +62,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def _is_experimental_rule_enabled(
- rule_id: str, experimental_config: ExperimentalConfig
-) -> bool:
- """Used by `_load_rules` to filter out experimental rules when they
- have not been enabled.
- """
- if (
- rule_id == "global/override/.org.matrix.msc3786.rule.room.server_acl"
- and not experimental_config.msc3786_enabled
- ):
- return False
- if (
- rule_id == "global/underride/.org.matrix.msc3772.thread_reply"
- and not experimental_config.msc3772_enabled
- ):
- return False
- return True
-
-
def _load_rules(
rawrules: List[JsonDict],
enabled_map: Dict[str, bool],
experimental_config: ExperimentalConfig,
-) -> List[JsonDict]:
- ruleslist = []
- for rawrule in rawrules:
- rule = dict(rawrule)
- rule["conditions"] = db_to_json(rawrule["conditions"])
- rule["actions"] = db_to_json(rawrule["actions"])
- rule["default"] = False
- ruleslist.append(rule)
-
- # We're going to be mutating this a lot, so copy it. We also filter out
- # any experimental default push rules that aren't enabled.
- rules = [
- rule
- for rule in list_with_base_rules(ruleslist)
- if _is_experimental_rule_enabled(rule["rule_id"], experimental_config)
- ]
+) -> FilteredPushRules:
+ """Take the DB rows returned from the DB and convert them into a full
+ `FilteredPushRules` object.
+ """
- for i, rule in enumerate(rules):
- rule_id = rule["rule_id"]
+ ruleslist = [
+ PushRule(
+ rule_id=rawrule["rule_id"],
+ priority_class=rawrule["priority_class"],
+ conditions=db_to_json(rawrule["conditions"]),
+ actions=db_to_json(rawrule["actions"]),
+ )
+ for rawrule in rawrules
+ ]
- if rule_id not in enabled_map:
- continue
- if rule.get("enabled", True) == bool(enabled_map[rule_id]):
- continue
+ push_rules = compile_push_rules(ruleslist)
- # Rules are cached across users.
- rule = dict(rule)
- rule["enabled"] = bool(enabled_map[rule_id])
- rules[i] = rule
+ filtered_rules = FilteredPushRules(push_rules, enabled_map, experimental_config)
- return rules
+ return filtered_rules
# The ABCMeta metaclass ensures that it cannot be instantiated without
@@ -162,7 +144,7 @@ class PushRulesWorkerStore(
raise NotImplementedError()
@cached(max_entries=5000)
- async def get_push_rules_for_user(self, user_id: str) -> List[JsonDict]:
+ async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
table="push_rules",
keyvalues={"user_name": user_id},
@@ -216,11 +198,11 @@ class PushRulesWorkerStore(
@cachedList(cached_method_name="get_push_rules_for_user", list_name="user_ids")
async def bulk_get_push_rules(
self, user_ids: Collection[str]
- ) -> Dict[str, List[JsonDict]]:
+ ) -> Dict[str, FilteredPushRules]:
if not user_ids:
return {}
- results: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
+ raw_rules: Dict[str, List[JsonDict]] = {user_id: [] for user_id in user_ids}
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
@@ -234,11 +216,13 @@ class PushRulesWorkerStore(
rows.sort(key=lambda row: (-int(row["priority_class"]), -int(row["priority"])))
for row in rows:
- results.setdefault(row["user_name"], []).append(row)
+ raw_rules.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
- for user_id, rules in results.items():
+ results: Dict[str, FilteredPushRules] = {}
+
+ for user_id, rules in raw_rules.items():
results[user_id] = _load_rules(
rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental
)
@@ -345,8 +329,8 @@ class PushRuleStore(PushRulesWorkerStore):
user_id: str,
rule_id: str,
priority_class: int,
- conditions: List[Dict[str, str]],
- actions: List[Union[JsonDict, str]],
+ conditions: Sequence[Mapping[str, str]],
+ actions: Sequence[Union[Mapping[str, Any], str]],
before: Optional[str] = None,
after: Optional[str] = None,
) -> None:
@@ -817,7 +801,7 @@ class PushRuleStore(PushRulesWorkerStore):
return self._push_rules_stream_id_gen.get_current_token()
async def copy_push_rule_from_room_to_room(
- self, new_room_id: str, user_id: str, rule: dict
+ self, new_room_id: str, user_id: str, rule: PushRule
) -> None:
"""Copy a single push rule from one room to another for a specific user.
@@ -827,21 +811,27 @@ class PushRuleStore(PushRulesWorkerStore):
rule: A push rule.
"""
# Create new rule id
- rule_id_scope = "/".join(rule["rule_id"].split("/")[:-1])
+ rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id
+ new_conditions = []
+
# Change room id in each condition
- for condition in rule.get("conditions", []):
+ for condition in rule.conditions:
+ new_condition = condition
if condition.get("key") == "room_id":
- condition["pattern"] = new_room_id
+ new_condition = dict(condition)
+ new_condition["pattern"] = new_room_id
+
+ new_conditions.append(new_condition)
# Add the rule for the new room
await self.add_push_rule(
user_id=user_id,
rule_id=new_rule_id,
- priority_class=rule["priority_class"],
- conditions=rule["conditions"],
- actions=rule["actions"],
+ priority_class=rule.priority_class,
+ conditions=new_conditions,
+ actions=rule.actions,
)
async def copy_push_rules_from_room_to_room_for_user(
@@ -859,8 +849,11 @@ class PushRuleStore(PushRulesWorkerStore):
user_push_rules = await self.get_push_rules_for_user(user_id)
# Get rules relating to the old room and copy them to the new room
- for rule in user_push_rules:
- conditions = rule.get("conditions", [])
+ for rule, enabled in user_push_rules:
+ if not enabled:
+ continue
+
+ conditions = rule.conditions
if any(
(c.get("key") == "room_id" and c.get("pattern") == old_room_id)
for c in conditions
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 0090c9f225..124c70ad37 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -161,7 +161,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type: The receipt types to fetch.
Returns:
- The latest receipt, if one exists.
+ The event ID and stream ordering of the latest receipt, if one exists.
"""
clause, args = make_in_list_sql_clause(
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 0f1f0d11ea..b7d4baa6bb 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -2001,9 +2001,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+ # We join on room_stats_state despite not using any columns from it
+ # because the join can influence the number of rows returned;
+ # e.g. a room that doesn't have state, maybe because it was deleted.
+ # The query returning the total count should be consistent with
+ # the query returning the results.
sql = """
SELECT COUNT(*) as total_event_reports
FROM event_reports AS er
+ JOIN room_stats_state ON room_stats_state.room_id = er.room_id
{}
""".format(
where_clause
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 93ff4816c8..827c1f1efd 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -283,6 +283,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
Returns:
A mapping from user ID to ProfileInfo.
+
+ Preconditions:
+ - There is full state available for the room (it is not partial-stated).
"""
def _get_users_in_room_with_profiles(
@@ -1212,6 +1215,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_forgotten_rooms_for_user", _get_forgotten_rooms_for_user_txn
)
+ async def is_locally_forgotten_room(self, room_id: str) -> bool:
+ """Returns whether all local users have forgotten this room_id.
+
+ Args:
+ room_id: The room ID to query.
+
+ Returns:
+ Whether the room is forgotten.
+ """
+
+ sql = """
+ SELECT count(*) > 0 FROM local_current_membership
+ INNER JOIN room_memberships USING (room_id, event_id)
+ WHERE
+ room_id = ?
+ AND forgotten = 0;
+ """
+
+ rows = await self.db_pool.execute("is_forgotten_room", None, sql, room_id)
+
+ # `count(*)` returns always an integer
+ # If any rows still exist it means someone has not forgotten this room yet
+ return not rows[0][0]
+
async def get_rooms_user_has_been_in(self, user_id: str) -> Set[str]:
"""Get all rooms that the user has ever been in.
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index f70705a0af..0b10af0e58 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -430,6 +430,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
updatevalues={"state_group": state_group},
)
+ # the event may now be rejected where it was not before, or vice versa,
+ # in which case we need to update the rejected flags.
+ if bool(context.rejected) != (event.rejected_reason is not None):
+ self.mark_event_rejected_txn(txn, event.event_id, context.rejected)
+
self.db_pool.simple_delete_one_txn(
txn,
table="partial_state_events",
|