diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 8f797e3ae9..e96f9999a8 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,16 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import enum
import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
import attr
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
-from synapse.storage.databases.main.relations import _RelatedEvent
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.logging.opentracing import trace
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
+from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict, Requester, UserID
+from synapse.util.async_helpers import gather_results
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@@ -30,6 +35,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class ThreadsListInclude(str, enum.Enum):
+ """Valid values for the 'include' flag of /threads."""
+
+ all = "all"
+ participated = "participated"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
@@ -65,18 +77,17 @@ class RelationsHandler:
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
+ self._event_creation_handler = hs.get_event_creation_handler()
async def get_relations(
self,
requester: Requester,
event_id: str,
room_id: str,
+ pagin_config: PaginationConfig,
+ include_original_event: bool,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- limit: int = 5,
- direction: str = "b",
- from_token: Optional[StreamToken] = None,
- to_token: Optional[StreamToken] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
@@ -86,13 +97,10 @@ class RelationsHandler:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
+ pagin_config: The pagination config rules to apply, if any.
+ include_original_event: Whether to include the parent event.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- limit: Only fetch the most recent `limit` events.
- direction: Whether to fetch the most recent first (`"b"`) or the
- oldest first (`"f"`).
- from_token: Fetch rows from the given token, or from the start if None.
- to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
The pagination chunk.
@@ -102,7 +110,7 @@ class RelationsHandler:
# TODO Properly handle a user leaving a room.
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
- room_id, user_id, allow_departed_users=True
+ room_id, requester, allow_departed_users=True
)
# This gets the original event and checks that a) the event exists and
@@ -120,10 +128,10 @@ class RelationsHandler:
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
- limit=limit,
- direction=direction,
- from_token=from_token,
- to_token=to_token,
+ limit=pagin_config.limit,
+ direction=pagin_config.direction,
+ from_token=pagin_config.from_token,
+ to_token=pagin_config.to_token,
)
events = await self._main_store.get_events_as_list(
@@ -137,113 +145,189 @@ class RelationsHandler:
is_peeking=(member_event_id is None),
)
- now = self._clock.time_msec()
- # Do not bundle aggregations when retrieving the original event because
- # we want the content before relations are applied to it.
- original_event = self._event_serializer.serialize_event(
- event, now, bundle_aggregations=None
- )
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)
- serialized_events = self._event_serializer.serialize_events(
- events, now, bundle_aggregations=aggregations
- )
- return_value = {
- "chunk": serialized_events,
- "original_event": original_event,
+ now = self._clock.time_msec()
+ return_value: JsonDict = {
+ "chunk": self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ ),
}
+ if include_original_event:
+ # Do not bundle aggregations when retrieving the original event because
+ # we want the content before relations are applied to it.
+ return_value["original_event"] = self._event_serializer.serialize_event(
+ event, now, bundle_aggregations=None
+ )
if next_token:
return_value["next_batch"] = await next_token.to_string(self._main_store)
- if from_token:
- return_value["prev_batch"] = await from_token.to_string(self._main_store)
+ if pagin_config.from_token:
+ return_value["prev_batch"] = await pagin_config.from_token.to_string(
+ self._main_store
+ )
return return_value
- async def get_relations_for_event(
+ async def redact_events_related_to(
self,
+ requester: Requester,
event_id: str,
- event: EventBase,
- room_id: str,
- relation_type: str,
- ignored_users: FrozenSet[str] = frozenset(),
- ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
- """Get a list of events which relate to an event, ordered by topological ordering.
+ initial_redaction_event: EventBase,
+ relation_types: List[str],
+ ) -> None:
+ """Redacts all events related to the given event ID with one of the given
+ relation types.
- Args:
- event_id: Fetch events that relate to this event ID.
- event: The matching EventBase to event_id.
- room_id: The room the event belongs to.
- relation_type: The type of relation.
- ignored_users: The users ignored by the requesting user.
+ This method is expected to be called when redacting the event referred to by
+ the given event ID.
- Returns:
- List of event IDs that match relations requested. The rows are of
- the form `{"event_id": "..."}`.
- """
+ If an event cannot be redacted (e.g. because of insufficient permissions), log
+ the error and try to redact the next one.
- # Call the underlying storage method, which is cached.
- related_events, next_token = await self._main_store.get_relations_for_event(
- event_id, event, room_id, relation_type, direction="f"
+ Args:
+ requester: The requester to redact events on behalf of.
+ event_id: The event IDs to look and redact relations of.
+ initial_redaction_event: The redaction for the event referred to by
+ event_id.
+ relation_types: The types of relations to look for.
+
+ Raises:
+ ShadowBanError if the requester is shadow-banned
+ """
+ related_event_ids = (
+ await self._main_store.get_all_relations_for_event_with_types(
+ event_id, relation_types
+ )
)
- # Filter out ignored users and convert to the expected format.
- related_events = [
- event for event in related_events if event.sender not in ignored_users
- ]
-
- return related_events, next_token
+ for related_event_id in related_event_ids:
+ try:
+ await self._event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.Redaction,
+ "content": initial_redaction_event.content,
+ "room_id": initial_redaction_event.room_id,
+ "sender": requester.user.to_string(),
+ "redacts": related_event_id,
+ },
+ ratelimit=False,
+ )
+ except SynapseError as e:
+ logger.warning(
+ "Failed to redact event %s (related to event %s): %s",
+ related_event_id,
+ event_id,
+ e.msg,
+ )
- async def get_annotations_for_event(
- self,
- event_id: str,
- room_id: str,
- limit: int = 5,
- ignored_users: FrozenSet[str] = frozenset(),
- ) -> List[JsonDict]:
- """Get a list of annotations on the event, grouped by event type and
+ async def get_annotations_for_events(
+ self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+ ) -> Dict[str, List[JsonDict]]:
+ """Get a list of annotations to the given events, grouped by event type and
aggregation key, sorted by count.
- This is used e.g. to get the what and how many reactions have happend
+ This is used e.g. to get the what and how many reactions have happened
on an event.
Args:
- event_id: Fetch events that relate to this event ID.
- room_id: The room the event belongs to.
- limit: Only fetch the `limit` groups.
+ event_ids: Fetch events that relate to these event IDs.
ignored_users: The users ignored by the requesting user.
Returns:
- List of groups of annotations that match. Each row is a dict with
- `type`, `key` and `count` fields.
+ A map of event IDs to a list of groups of annotations that match.
+ Each entry is a dict with `type`, `key` and `count` fields.
"""
# Get the base results for all users.
- full_results = await self._main_store.get_aggregation_groups_for_event(
- event_id, room_id, limit
+ full_results = await self._main_store.get_aggregation_groups_for_events(
+ event_ids
)
+ # Avoid additional logic if there are no ignored users.
+ if not ignored_users:
+ return {
+ event_id: results
+ for event_id, results in full_results.items()
+ if results
+ }
+
# Then subtract off the results for any ignored users.
ignored_results = await self._main_store.get_aggregation_groups_for_users(
- event_id, room_id, limit, ignored_users
+ [event_id for event_id, results in full_results.items() if results],
+ ignored_users,
)
- filtered_results = []
- for result in full_results:
- key = (result["type"], result["key"])
- if key in ignored_results:
- result = result.copy()
- result["count"] -= ignored_results[key]
- if result["count"] <= 0:
- continue
- filtered_results.append(result)
+ filtered_results = {}
+ for event_id, results in full_results.items():
+ # If no annotations, skip.
+ if not results:
+ continue
+
+ # If there are not ignored results for this event, copy verbatim.
+ if event_id not in ignored_results:
+ filtered_results[event_id] = results
+ continue
+
+ # Otherwise, subtract out the ignored results.
+ event_ignored_results = ignored_results[event_id]
+ for result in results:
+ key = (result["type"], result["key"])
+ if key in event_ignored_results:
+ # Ensure to not modify the cache.
+ result = result.copy()
+ result["count"] -= event_ignored_results[key]
+ if result["count"] <= 0:
+ continue
+ filtered_results.setdefault(event_id, []).append(result)
return filtered_results
+ async def get_references_for_events(
+ self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+ ) -> Dict[str, List[_RelatedEvent]]:
+ """Get a list of references to the given events.
+
+ Args:
+ event_ids: Fetch events that relate to this event ID.
+ ignored_users: The users ignored by the requesting user.
+
+ Returns:
+ A map of event IDs to a list related events.
+ """
+
+ related_events = await self._main_store.get_references_for_events(event_ids)
+
+ # Avoid additional logic if there are no ignored users.
+ if not ignored_users:
+ return {
+ event_id: results
+ for event_id, results in related_events.items()
+ if results
+ }
+
+ # Filter out ignored users.
+ results = {}
+ for event_id, events in related_events.items():
+ # If no references, skip.
+ if not events:
+ continue
+
+ # Filter ignored users out.
+ events = [event for event in events if event.sender not in ignored_users]
+ # If there are no events left, skip this event.
+ if not events:
+ continue
+
+ results[event_id] = events
+
+ return results
+
async def _get_threads_for_events(
self,
events_by_id: Dict[str, EventBase],
@@ -306,61 +390,69 @@ class RelationsHandler:
results = {}
for event_id, summary in summaries.items():
- if summary:
- thread_count, latest_thread_event = summary
-
- # Subtract off the count of any ignored users.
- for ignored_user in ignored_users:
- thread_count -= ignored_results.get((event_id, ignored_user), 0)
-
- # This is gnarly, but if the latest event is from an ignored user,
- # attempt to find one that isn't from an ignored user.
- if latest_thread_event.sender in ignored_users:
- room_id = latest_thread_event.room_id
-
- # If the root event is not found, something went wrong, do
- # not include a summary of the thread.
- event = await self._event_handler.get_event(user, room_id, event_id)
- if event is None:
- continue
+ # If no thread, skip.
+ if not summary:
+ continue
- potential_events, _ = await self.get_relations_for_event(
- event_id,
- event,
- room_id,
- RelationTypes.THREAD,
- ignored_users,
- )
+ thread_count, latest_thread_event = summary
- # If all found events are from ignored users, do not include
- # a summary of the thread.
- if not potential_events:
- continue
+ # Subtract off the count of any ignored users.
+ for ignored_user in ignored_users:
+ thread_count -= ignored_results.get((event_id, ignored_user), 0)
- # The *last* event returned is the one that is cared about.
- event = await self._event_handler.get_event(
- user, room_id, potential_events[-1].event_id
- )
- # It is unexpected that the event will not exist.
- if event is None:
- logger.warning(
- "Unable to fetch latest event in a thread with event ID: %s",
- potential_events[-1].event_id,
- )
- continue
- latest_thread_event = event
-
- results[event_id] = _ThreadAggregation(
- latest_event=latest_thread_event,
- count=thread_count,
- # If there's a thread summary it must also exist in the
- # participated dictionary.
- current_user_participated=events_by_id[event_id].sender == user_id
- or participated[event_id],
+ # This is gnarly, but if the latest event is from an ignored user,
+ # attempt to find one that isn't from an ignored user.
+ if latest_thread_event.sender in ignored_users:
+ room_id = latest_thread_event.room_id
+
+ # If the root event is not found, something went wrong, do
+ # not include a summary of the thread.
+ event = await self._event_handler.get_event(user, room_id, event_id)
+ if event is None:
+ continue
+
+ # Attempt to find another event to use as the latest event.
+ potential_events, _ = await self._main_store.get_relations_for_event(
+ event_id, event, room_id, RelationTypes.THREAD, direction="f"
)
+ # Filter out ignored users.
+ potential_events = [
+ event
+ for event in potential_events
+ if event.sender not in ignored_users
+ ]
+
+ # If all found events are from ignored users, do not include
+ # a summary of the thread.
+ if not potential_events:
+ continue
+
+ # The *last* event returned is the one that is cared about.
+ event = await self._event_handler.get_event(
+ user, room_id, potential_events[-1].event_id
+ )
+ # It is unexpected that the event will not exist.
+ if event is None:
+ logger.warning(
+ "Unable to fetch latest event in a thread with event ID: %s",
+ potential_events[-1].event_id,
+ )
+ continue
+ latest_thread_event = event
+
+ results[event_id] = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ count=thread_count,
+ # If there's a thread summary it must also exist in the
+ # participated dictionary.
+ current_user_participated=events_by_id[event_id].sender == user_id
+ or participated[event_id],
+ )
+
return results
+ @trace
async def get_bundled_aggregations(
self, events: Iterable[EventBase], user_id: str
) -> Dict[str, BundledAggregations]:
@@ -435,48 +527,131 @@ class RelationsHandler:
# (as that is what makes it part of the thread).
relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
- # Fetch other relations per event.
- for event in events_by_id.values():
- # Fetch any annotations (ie, reactions) to bundle with this event.
- annotations = await self.get_annotations_for_event(
- event.event_id, event.room_id, ignored_users=ignored_users
+ async def _fetch_annotations() -> None:
+ """Fetch any annotations (ie, reactions) to bundle with this event."""
+ annotations_by_event_id = await self.get_annotations_for_events(
+ events_by_id.keys(), ignored_users=ignored_users
)
- if annotations:
- results.setdefault(
- event.event_id, BundledAggregations()
- ).annotations = {"chunk": annotations}
-
- # Fetch any references to bundle with this event.
- references, next_token = await self.get_relations_for_event(
- event.event_id,
- event,
- event.room_id,
- RelationTypes.REFERENCE,
- ignored_users=ignored_users,
+ for event_id, annotations in annotations_by_event_id.items():
+ if annotations:
+ results.setdefault(event_id, BundledAggregations()).annotations = {
+ "chunk": annotations
+ }
+
+ async def _fetch_references() -> None:
+ """Fetch any references to bundle with this event."""
+ references_by_event_id = await self.get_references_for_events(
+ events_by_id.keys(), ignored_users=ignored_users
+ )
+ for event_id, references in references_by_event_id.items():
+ if references:
+ results.setdefault(event_id, BundledAggregations()).references = {
+ "chunk": [{"event_id": ev.event_id} for ev in references]
+ }
+
+ async def _fetch_edits() -> None:
+ """
+ Fetch any edits (but not for redacted events).
+
+ Note that there is no use in limiting edits by ignored users since the
+ parent event should be ignored in the first place if the user is ignored.
+ """
+ edits = await self._main_store.get_applicable_edits(
+ [
+ event_id
+ for event_id, event in events_by_id.items()
+ if not event.internal_metadata.is_redacted()
+ ]
+ )
+ for event_id, edit in edits.items():
+ results.setdefault(event_id, BundledAggregations()).replace = edit
+
+ # Parallelize the calls for annotations, references, and edits since they
+ # are unrelated.
+ await make_deferred_yieldable(
+ gather_results(
+ (
+ run_in_background(_fetch_annotations),
+ run_in_background(_fetch_references),
+ run_in_background(_fetch_edits),
+ )
)
- if references:
- aggregations = results.setdefault(event.event_id, BundledAggregations())
- aggregations.references = {
- "chunk": [{"event_id": ev.event_id} for ev in references]
- }
-
- if next_token:
- aggregations.references["next_batch"] = await next_token.to_string(
- self._main_store
- )
-
- # Fetch any edits (but not for redacted events).
- #
- # Note that there is no use in limiting edits by ignored users since the
- # parent event should be ignored in the first place if the user is ignored.
- edits = await self._main_store.get_applicable_edits(
- [
- event_id
- for event_id, event in events_by_id.items()
- if not event.internal_metadata.is_redacted()
- ]
)
- for event_id, edit in edits.items():
- results.setdefault(event_id, BundledAggregations()).replace = edit
return results
+
+ async def get_threads(
+ self,
+ requester: Requester,
+ room_id: str,
+ include: ThreadsListInclude,
+ limit: int = 5,
+ from_token: Optional[ThreadsNextBatch] = None,
+ ) -> JsonDict:
+ """Get related events of a event, ordered by topological ordering.
+
+ Args:
+ requester: The user requesting the relations.
+ room_id: The room the event belongs to.
+ include: One of "all" or "participated" to indicate which threads should
+ be returned.
+ limit: Only fetch the most recent `limit` events.
+ from_token: Fetch rows from the given token, or from the start if None.
+
+ Returns:
+ The pagination chunk.
+ """
+
+ user_id = requester.user.to_string()
+
+ # TODO Properly handle a user leaving a room.
+ (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+ room_id, requester, allow_departed_users=True
+ )
+
+ # Note that ignored users are not passed into get_threads
+ # below. Ignored users are handled in filter_events_for_client (and by
+ # not passing them in here we should get a better cache hit rate).
+ thread_roots, next_batch = await self._main_store.get_threads(
+ room_id=room_id, limit=limit, from_token=from_token
+ )
+
+ events = await self._main_store.get_events_as_list(thread_roots)
+
+ if include == ThreadsListInclude.participated:
+ # Pre-seed thread participation with whether the requester sent the event.
+ participated = {event.event_id: event.sender == user_id for event in events}
+ # For events the requester did not send, check the database for whether
+ # the requester sent a threaded reply.
+ participated.update(
+ await self._main_store.get_threads_participated(
+ [eid for eid, p in participated.items() if not p],
+ user_id,
+ )
+ )
+
+ # Limit the returned threads to those the user has participated in.
+ events = [event for event in events if participated[event.event_id]]
+
+ events = await filter_events_for_client(
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+
+ aggregations = await self.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+
+ now = self._clock.time_msec()
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
+
+ return_value: JsonDict = {"chunk": serialized_events}
+
+ if next_batch:
+ return_value["next_batch"] = str(next_batch)
+
+ return return_value
|