diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 48cc9c1ac5..64d373e9d7 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,16 +11,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import enum
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
import attr
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import SynapseError
from synapse.events import EventBase, relation_from_event
from synapse.logging.tracing import SynapseTags, set_attribute, trace
-from synapse.storage.databases.main.relations import _RelatedEvent
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
+from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.visibility import filter_events_for_client
@@ -31,6 +33,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class ThreadsListInclude(str, enum.Enum):
+ """Valid values for the 'include' flag of /threads."""
+
+ all = "all"
+ participated = "participated"
+
+
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
@@ -66,18 +75,17 @@ class RelationsHandler:
self._clock = hs.get_clock()
self._event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
+ self._event_creation_handler = hs.get_event_creation_handler()
async def get_relations(
self,
requester: Requester,
event_id: str,
room_id: str,
+ pagin_config: PaginationConfig,
+ include_original_event: bool,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
- limit: int = 5,
- direction: str = "b",
- from_token: Optional[StreamToken] = None,
- to_token: Optional[StreamToken] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
@@ -87,13 +95,10 @@ class RelationsHandler:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
+ pagin_config: The pagination config rules to apply, if any.
+ include_original_event: Whether to include the parent event.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
- limit: Only fetch the most recent `limit` events.
- direction: Whether to fetch the most recent first (`"b"`) or the
- oldest first (`"f"`).
- from_token: Fetch rows from the given token, or from the start if None.
- to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
The pagination chunk.
@@ -121,10 +126,10 @@ class RelationsHandler:
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
- limit=limit,
- direction=direction,
- from_token=from_token,
- to_token=to_token,
+ limit=pagin_config.limit,
+ direction=pagin_config.direction,
+ from_token=pagin_config.from_token,
+ to_token=pagin_config.to_token,
)
events = await self._main_store.get_events_as_list(
@@ -138,31 +143,32 @@ class RelationsHandler:
is_peeking=(member_event_id is None),
)
- now = self._clock.time_msec()
- # Do not bundle aggregations when retrieving the original event because
- # we want the content before relations are applied to it.
- original_event = self._event_serializer.serialize_event(
- event, now, bundle_aggregations=None
- )
# The relations returned for the requested event do include their
# bundled aggregations.
aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)
- serialized_events = self._event_serializer.serialize_events(
- events, now, bundle_aggregations=aggregations
- )
- return_value = {
- "chunk": serialized_events,
- "original_event": original_event,
+ now = self._clock.time_msec()
+ return_value: JsonDict = {
+ "chunk": self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ ),
}
+ if include_original_event:
+ # Do not bundle aggregations when retrieving the original event because
+ # we want the content before relations are applied to it.
+ return_value["original_event"] = self._event_serializer.serialize_event(
+ event, now, bundle_aggregations=None
+ )
if next_token:
return_value["next_batch"] = await next_token.to_string(self._main_store)
- if from_token:
- return_value["prev_batch"] = await from_token.to_string(self._main_store)
+ if pagin_config.from_token:
+ return_value["prev_batch"] = await pagin_config.from_token.to_string(
+ self._main_store
+ )
return return_value
@@ -201,6 +207,59 @@ class RelationsHandler:
return related_events, next_token
+ async def redact_events_related_to(
+ self,
+ requester: Requester,
+ event_id: str,
+ initial_redaction_event: EventBase,
+ relation_types: List[str],
+ ) -> None:
+ """Redacts all events related to the given event ID with one of the given
+ relation types.
+
+ This method is expected to be called when redacting the event referred to by
+ the given event ID.
+
+ If an event cannot be redacted (e.g. because of insufficient permissions), log
+ the error and try to redact the next one.
+
+ Args:
+ requester: The requester to redact events on behalf of.
+ event_id: The event IDs to look and redact relations of.
+ initial_redaction_event: The redaction for the event referred to by
+ event_id.
+ relation_types: The types of relations to look for.
+
+ Raises:
+ ShadowBanError if the requester is shadow-banned
+ """
+ related_event_ids = (
+ await self._main_store.get_all_relations_for_event_with_types(
+ event_id, relation_types
+ )
+ )
+
+ for related_event_id in related_event_ids:
+ try:
+ await self._event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.Redaction,
+ "content": initial_redaction_event.content,
+ "room_id": initial_redaction_event.room_id,
+ "sender": requester.user.to_string(),
+ "redacts": related_event_id,
+ },
+ ratelimit=False,
+ )
+ except SynapseError as e:
+ logger.warning(
+ "Failed to redact event %s (related to event %s): %s",
+ related_event_id,
+ event_id,
+ e.msg,
+ )
+
@trace
async def get_annotations_for_event(
self,
@@ -490,3 +549,79 @@ class RelationsHandler:
results.setdefault(event_id, BundledAggregations()).replace = edit
return results
+
+ async def get_threads(
+ self,
+ requester: Requester,
+ room_id: str,
+ include: ThreadsListInclude,
+ limit: int = 5,
+ from_token: Optional[ThreadsNextBatch] = None,
+ ) -> JsonDict:
+ """Get related events of a event, ordered by topological ordering.
+
+ Args:
+ requester: The user requesting the relations.
+ room_id: The room the event belongs to.
+ include: One of "all" or "participated" to indicate which threads should
+ be returned.
+ limit: Only fetch the most recent `limit` events.
+ from_token: Fetch rows from the given token, or from the start if None.
+
+ Returns:
+ The pagination chunk.
+ """
+
+ user_id = requester.user.to_string()
+
+ # TODO Properly handle a user leaving a room.
+ (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+ room_id, requester, allow_departed_users=True
+ )
+
+ # Note that ignored users are not passed into get_relations_for_event
+ # below. Ignored users are handled in filter_events_for_client (and by
+ # not passing them in here we should get a better cache hit rate).
+ thread_roots, next_batch = await self._main_store.get_threads(
+ room_id=room_id, limit=limit, from_token=from_token
+ )
+
+ events = await self._main_store.get_events_as_list(thread_roots)
+
+ if include == ThreadsListInclude.participated:
+ # Pre-seed thread participation with whether the requester sent the event.
+ participated = {event.event_id: event.sender == user_id for event in events}
+ # For events the requester did not send, check the database for whether
+ # the requester sent a threaded reply.
+ participated.update(
+ await self._main_store.get_threads_participated(
+ [eid for eid, p in participated.items() if not p],
+ user_id,
+ )
+ )
+
+ # Limit the returned threads to those the user has participated in.
+ events = [event for event in events if participated[event.event_id]]
+
+ events = await filter_events_for_client(
+ self._storage_controllers,
+ user_id,
+ events,
+ is_peeking=(member_event_id is None),
+ )
+
+ aggregations = await self.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+
+ now = self._clock.time_msec()
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
+
+ return_value: JsonDict = {"chunk": serialized_events}
+
+ if next_batch:
+ return_value["next_batch"] = str(next_batch)
+
+ return return_value
|