summary refs log tree commit diff
path: root/synapse/storage/databases/main/stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r--synapse/storage/databases/main/stream.py59
1 files changed, 40 insertions, 19 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 323c7bf7a5..4362c93186 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -357,6 +357,24 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
         )
         args.extend(event_filter.related_by_rel_types)
 
+    if event_filter.rel_types:
+        clauses.append(
+            "(%s)"
+            % " OR ".join(
+                "event_relation.relation_type = ?" for _ in event_filter.rel_types
+            )
+        )
+        args.extend(event_filter.rel_types)
+
+    if event_filter.not_rel_types:
+        clauses.append(
+            "((%s) OR event_relation.relation_type IS NULL)"
+            % " AND ".join(
+                "event_relation.relation_type != ?" for _ in event_filter.not_rel_types
+            )
+        )
+        args.extend(event_filter.not_rel_types)
+
     return " AND ".join(clauses), args
 
 
@@ -1024,28 +1042,31 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             "after": {"event_ids": events_after, "token": end_token},
         }
 
-    async def get_all_new_events_stream(
-        self, from_id: int, current_id: int, limit: int, get_prev_content: bool = False
-    ) -> Tuple[int, List[EventBase], Dict[str, Optional[int]]]:
+    async def get_all_new_event_ids_stream(
+        self,
+        from_id: int,
+        current_id: int,
+        limit: int,
+    ) -> Tuple[int, Dict[str, Optional[int]]]:
         """Get all new events
 
-        Returns all events with from_id < stream_ordering <= current_id.
+        Returns all event ids with from_id < stream_ordering <= current_id.
 
         Args:
             from_id:  the stream_ordering of the last event we processed
             current_id:  the stream_ordering of the most recently processed event
             limit: the maximum number of events to return
-            get_prev_content: whether to fetch previous event content
 
         Returns:
-            A tuple of (next_id, events, event_to_received_ts), where `next_id`
+            A tuple of (next_id, event_to_received_ts), where `next_id`
             is the next value to pass as `from_id` (it will either be the
             stream_ordering of the last returned event, or, if fewer than `limit`
             events were found, the `current_id`). The `event_to_received_ts` is
-            a dictionary mapping event ID to the event `received_ts`.
+            a dictionary mapping event ID to the event `received_ts`, sorted by ascending
+            stream_ordering.
         """
 
-        def get_all_new_events_stream_txn(
+        def get_all_new_event_ids_stream_txn(
             txn: LoggingTransaction,
         ) -> Tuple[int, Dict[str, Optional[int]]]:
             sql = (
@@ -1070,15 +1091,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             return upper_bound, event_to_received_ts
 
         upper_bound, event_to_received_ts = await self.db_pool.runInteraction(
-            "get_all_new_events_stream", get_all_new_events_stream_txn
-        )
-
-        events = await self.get_events_as_list(
-            event_to_received_ts.keys(),
-            get_prev_content=get_prev_content,
+            "get_all_new_event_ids_stream", get_all_new_event_ids_stream_txn
         )
 
-        return upper_bound, events, event_to_received_ts
+        return upper_bound, event_to_received_ts
 
     async def get_federation_out_pos(self, typ: str) -> int:
         if self._need_to_reset_federation_stream_positions:
@@ -1202,8 +1218,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             `to_token`), or `limit` is zero.
         """
 
-        assert int(limit) >= 0
-
         # Tokens really represent positions between elements, but we use
         # the convention of pointing to the event before the gap. Hence
         # we have a bit of asymmetry when it comes to equalities.
@@ -1282,8 +1296,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 # Multiple labels could cause the same event to appear multiple times.
                 needs_distinct = True
 
-        # If there is a filter on relation_senders and relation_types join to the
-        # relations table.
+        # If there is a relation_senders and relation_types filter join to the
+        # relations table to get events related to the current event.
         if event_filter and (
             event_filter.related_by_senders or event_filter.related_by_rel_types
         ):
@@ -1298,6 +1312,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                     LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
                 """
 
+        # If there is a not_rel_types filter join to the relations table to get
+        # the event's relation information.
+        if event_filter and (event_filter.rel_types or event_filter.not_rel_types):
+            join_clause += """
+                LEFT JOIN event_relations AS event_relation USING (event_id)
+            """
+
         if needs_distinct:
             select_keywords += " DISTINCT"