summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/account_data.py83
-rw-r--r--synapse/storage/databases/main/appservice.py2
-rw-r--r--synapse/storage/databases/main/event_federation.py2
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/purge_events.py1
-rw-r--r--synapse/storage/databases/main/relations.py65
-rw-r--r--synapse/storage/databases/main/signatures.py54
-rw-r--r--synapse/storage/databases/main/stream.py22
-rw-r--r--synapse/storage/databases/main/transactions.py48
9 files changed, 209 insertions, 75 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index ef475e18c7..5bfa408f74 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -26,6 +26,7 @@ from synapse.storage.database import (
     LoggingTransaction,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import (
     AbstractStreamIdGenerator,
@@ -44,7 +45,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class AccountDataWorkerStore(CacheInvalidationWorkerStore):
+class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore):
     def __init__(
         self,
         database: DatabasePool,
@@ -158,9 +159,9 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
             "get_account_data_for_user", get_account_data_for_user_txn
         )
 
-    @cached(num_args=2, max_entries=5000)
+    @cached(num_args=2, max_entries=5000, tree=True)
     async def get_global_account_data_by_type_for_user(
-        self, data_type: str, user_id: str
+        self, user_id: str, data_type: str
     ) -> Optional[JsonDict]:
         """
         Returns:
@@ -179,7 +180,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
         else:
             return None
 
-    @cached(num_args=2)
+    @cached(num_args=2, tree=True)
     async def get_account_data_for_room(
         self, user_id: str, room_id: str
     ) -> Dict[str, JsonDict]:
@@ -210,7 +211,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
-    @cached(num_args=3, max_entries=5000)
+    @cached(num_args=3, max_entries=5000, tree=True)
     async def get_account_data_for_room_and_type(
         self, user_id: str, room_id: str, account_data_type: str
     ) -> Optional[JsonDict]:
@@ -392,7 +393,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
             for row in rows:
                 if not row.room_id:
                     self.get_global_account_data_by_type_for_user.invalidate(
-                        (row.data_type, row.user_id)
+                        (row.user_id, row.data_type)
                     )
                 self.get_account_data_for_user.invalidate((row.user_id,))
                 self.get_account_data_for_room.invalidate((row.user_id, row.room_id))
@@ -476,7 +477,7 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
             self._account_data_stream_cache.entity_has_changed(user_id, next_id)
             self.get_account_data_for_user.invalidate((user_id,))
             self.get_global_account_data_by_type_for_user.invalidate(
-                (account_data_type, user_id)
+                (user_id, account_data_type)
             )
 
         return self._account_data_id_gen.get_current_token()
@@ -546,6 +547,74 @@ class AccountDataWorkerStore(CacheInvalidationWorkerStore):
         for ignored_user_id in previously_ignored_users ^ currently_ignored_users:
             self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,))
 
+    async def purge_account_data_for_user(self, user_id: str) -> None:
+        """
+        Removes the account data for a user.
+
+        This is intended to be used upon user deactivation and also removes any
+        derived information from account data (e.g. push rules and ignored users).
+
+        Args:
+            user_id: The user ID to remove data for.
+        """
+
+        def purge_account_data_for_user_txn(txn: LoggingTransaction) -> None:
+            # Purge from the primary account_data tables.
+            self.db_pool.simple_delete_txn(
+                txn, table="account_data", keyvalues={"user_id": user_id}
+            )
+
+            self.db_pool.simple_delete_txn(
+                txn, table="room_account_data", keyvalues={"user_id": user_id}
+            )
+
+            # Purge from ignored_users where this user is the ignorer.
+            # N.B. We don't purge where this user is the ignoree, because that
+            #      interferes with other users' account data.
+            #      It's also not this user's data to delete!
+            self.db_pool.simple_delete_txn(
+                txn, table="ignored_users", keyvalues={"ignorer_user_id": user_id}
+            )
+
+            # Remove the push rules
+            self.db_pool.simple_delete_txn(
+                txn, table="push_rules", keyvalues={"user_name": user_id}
+            )
+            self.db_pool.simple_delete_txn(
+                txn, table="push_rules_enable", keyvalues={"user_name": user_id}
+            )
+            self.db_pool.simple_delete_txn(
+                txn, table="push_rules_stream", keyvalues={"user_id": user_id}
+            )
+
+            # Invalidate caches as appropriate
+            self._invalidate_cache_and_stream(
+                txn, self.get_account_data_for_room_and_type, (user_id,)
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_account_data_for_user, (user_id,)
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_global_account_data_by_type_for_user, (user_id,)
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_account_data_for_room, (user_id,)
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_push_rules_for_user, (user_id,)
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_push_rules_enabled_for_user, (user_id,)
+            )
+            # This user might be contained in the ignored_by cache for other users,
+            # so we have to invalidate it all.
+            self._invalidate_all_cache_and_stream(txn, self.ignored_by)
+
+        await self.db_pool.runInteraction(
+            "purge_account_data_for_user_txn",
+            purge_account_data_for_user_txn,
+        )
+
 
 class AccountDataStore(AccountDataWorkerStore):
     pass
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 92c95a41d7..2bb5288431 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -384,7 +384,7 @@ class ApplicationServiceTransactionWorkerStore(
             "get_new_events_for_appservice", get_new_events_for_appservice_txn
         )
 
-        events = await self.get_events_as_list(event_ids)
+        events = await self.get_events_as_list(event_ids, get_prev_content=True)
 
         return upper_bound, events
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index a556f17dac..ca71f073fc 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -65,7 +65,7 @@ class _NoChainCoverIndex(Exception):
         super().__init__("Unexpectedly no chain cover for events in %s" % (room_id,))
 
 
-class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
+class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBaseStore):
     def __init__(
         self,
         database: DatabasePool,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1ae1ebe108..b7554154ac 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1389,6 +1389,8 @@ class PersistEventsStore:
                 "received_ts",
                 "sender",
                 "contains_url",
+                "state_key",
+                "rejection_reason",
             ),
             values=(
                 (
@@ -1405,8 +1407,10 @@ class PersistEventsStore:
                     self._clock.time_msec(),
                     event.sender,
                     "url" in event.content and isinstance(event.content["url"], str),
+                    event.get_state_key(),
+                    context.rejected or None,
                 )
-                for event, _ in events_and_contexts
+                for event, context in events_and_contexts
             ),
         )
 
@@ -1456,6 +1460,7 @@ class PersistEventsStore:
         for event, context in events_and_contexts:
             if context.rejected:
                 # Insert the event_id into the rejections table
+                # (events.rejection_reason has already been done)
                 self._store_rejections_txn(txn, event.event_id, context.rejected)
                 to_remove.add(event)
 
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 91b0576b85..e87a8fb85d 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -390,7 +390,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "event_search",
             "events",
             "group_rooms",
-            "public_room_list_stream",
             "receipts_graph",
             "receipts_linearized",
             "room_aliases",
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2cb5d06c13..37468a5183 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,17 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    Union,
-    cast,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
 
 import attr
 from frozendict import frozendict
@@ -43,6 +33,7 @@ from synapse.storage.relations import (
     PaginationChunk,
     RelationPaginationToken,
 )
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -51,6 +42,30 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+    latest_event: EventBase
+    count: int
+    current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+    """
+    The bundled aggregations for an event.
+
+    Some values require additional processing during serialization.
+    """
+
+    annotations: Optional[JsonDict] = None
+    references: Optional[JsonDict] = None
+    replace: Optional[EventBase] = None
+    thread: Optional[_ThreadAggregation] = None
+
+    def __bool__(self) -> bool:
+        return bool(self.annotations or self.references or self.replace or self.thread)
+
+
 class RelationsWorkerStore(SQLBaseStore):
     def __init__(
         self,
@@ -60,7 +75,6 @@ class RelationsWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self._msc1849_enabled = hs.config.experimental.msc1849_enabled
         self._msc3440_enabled = hs.config.experimental.msc3440_enabled
 
     @cached(tree=True)
@@ -585,7 +599,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
     async def _get_bundled_aggregation_for_event(
         self, event: EventBase, user_id: str
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[BundledAggregations]:
         """Generate bundled aggregations for an event.
 
         Note that this does not use a cache, but depends on cached methods.
@@ -616,24 +630,24 @@ class RelationsWorkerStore(SQLBaseStore):
         # The bundled aggregations to include, a mapping of relation type to a
         # type-specific value. Some types include the direct return type here
         # while others need more processing during serialization.
-        aggregations: Dict[str, Any] = {}
+        aggregations = BundledAggregations()
 
         annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
         if annotations.chunk:
-            aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+            aggregations.annotations = annotations.to_dict()
 
         references = await self.get_relations_for_event(
             event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
-            aggregations[RelationTypes.REFERENCE] = references.to_dict()
+            aggregations.references = references.to_dict()
 
         edit = None
         if event.type == EventTypes.Message:
             edit = await self.get_applicable_edit(event_id, room_id)
 
         if edit:
-            aggregations[RelationTypes.REPLACE] = edit
+            aggregations.replace = edit
 
         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
@@ -644,11 +658,11 @@ class RelationsWorkerStore(SQLBaseStore):
                 event_id, room_id, user_id
             )
             if latest_thread_event:
-                aggregations[RelationTypes.THREAD] = {
-                    "latest_event": latest_thread_event,
-                    "count": thread_count,
-                    "current_user_participated": participated,
-                }
+                aggregations.thread = _ThreadAggregation(
+                    latest_event=latest_thread_event,
+                    count=thread_count,
+                    current_user_participated=participated,
+                )
 
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
@@ -657,7 +671,7 @@ class RelationsWorkerStore(SQLBaseStore):
         self,
         events: Iterable[EventBase],
         user_id: str,
-    ) -> Dict[str, Dict[str, Any]]:
+    ) -> Dict[str, BundledAggregations]:
         """Generate bundled aggregations for events.
 
         Args:
@@ -668,15 +682,12 @@ class RelationsWorkerStore(SQLBaseStore):
             A map of event ID to the bundled aggregation for the event. Not all
             events may have bundled aggregations in the results.
         """
-        # If bundled aggregations are disabled, nothing to do.
-        if not self._msc1849_enabled:
-            return {}
 
         # TODO Parallelize.
         results = {}
         for event in events:
             event_result = await self._get_bundled_aggregation_for_event(event, user_id)
-            if event_result is not None:
+            if event_result:
                 results[event.event_id] = event_result
 
         return results
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 3201623fe4..0518b8b910 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -12,16 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Iterable, List, Tuple
+from typing import Collection, Dict, List, Tuple
 
 from unpaddedbase64 import encode_base64
 
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.types import Cursor
+from synapse.crypto.event_signing import compute_event_reference_hash
+from synapse.storage.databases.main.events_worker import (
+    EventRedactBehaviour,
+    EventsWorkerStore,
+)
 from synapse.util.caches.descriptors import cached, cachedList
 
 
-class SignatureWorkerStore(SQLBaseStore):
+class SignatureWorkerStore(EventsWorkerStore):
     @cached()
     def get_event_reference_hash(self, event_id):
         # This is a dummy function to allow get_event_reference_hashes
@@ -32,7 +35,7 @@ class SignatureWorkerStore(SQLBaseStore):
         cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
     )
     async def get_event_reference_hashes(
-        self, event_ids: Iterable[str]
+        self, event_ids: Collection[str]
     ) -> Dict[str, Dict[str, bytes]]:
         """Get all hashes for given events.
 
@@ -41,18 +44,27 @@ class SignatureWorkerStore(SQLBaseStore):
 
         Returns:
              A mapping of event ID to a mapping of algorithm to hash.
+             Returns an empty dict for a given event id if that event is unknown.
         """
+        events = await self.get_events(
+            event_ids,
+            redact_behaviour=EventRedactBehaviour.AS_IS,
+            allow_rejected=True,
+        )
 
-        def f(txn):
-            return {
-                event_id: self._get_event_reference_hashes_txn(txn, event_id)
-                for event_id in event_ids
-            }
+        hashes: Dict[str, Dict[str, bytes]] = {}
+        for event_id in event_ids:
+            event = events.get(event_id)
+            if event is None:
+                hashes[event_id] = {}
+            else:
+                ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
+                hashes[event_id] = {ref_alg: ref_hash_bytes}
 
-        return await self.db_pool.runInteraction("get_event_reference_hashes", f)
+        return hashes
 
     async def add_event_hashes(
-        self, event_ids: Iterable[str]
+        self, event_ids: Collection[str]
     ) -> List[Tuple[str, Dict[str, str]]]:
         """
 
@@ -70,24 +82,6 @@ class SignatureWorkerStore(SQLBaseStore):
 
         return list(encoded_hashes.items())
 
-    def _get_event_reference_hashes_txn(
-        self, txn: Cursor, event_id: str
-    ) -> Dict[str, bytes]:
-        """Get all the hashes for a given PDU.
-        Args:
-            txn:
-            event_id: Id for the Event.
-        Returns:
-            A mapping of algorithm -> hash.
-        """
-        query = (
-            "SELECT algorithm, hash"
-            " FROM event_reference_hashes"
-            " WHERE event_id = ?"
-        )
-        txn.execute(query, (event_id,))
-        return {k: v for k, v in txn}
-
 
 class SignatureStore(SignatureWorkerStore):
     """Persistence for event signatures and hashes"""
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 319464b1fa..a898f847e7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -81,6 +81,14 @@ class _EventDictReturn:
     stream_ordering: int
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventsAround:
+    events_before: List[EventBase]
+    events_after: List[EventBase]
+    start: RoomStreamToken
+    end: RoomStreamToken
+
+
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
@@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         before_limit: int,
         after_limit: int,
         event_filter: Optional[Filter] = None,
-    ) -> dict:
+    ) -> _EventsAround:
         """Retrieve events and pagination tokens around a given event in a
         room.
         """
@@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             list(results["after"]["event_ids"]), get_prev_content=True
         )
 
-        return {
-            "events_before": events_before,
-            "events_after": events_after,
-            "start": results["before"]["token"],
-            "end": results["after"]["token"],
-        }
+        return _EventsAround(
+            events_before=events_before,
+            events_after=events_after,
+            start=results["before"]["token"],
+            end=results["after"]["token"],
+        )
 
     def _get_events_around_txn(
         self,
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 4b78b4d098..ba79e19f7f 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -561,6 +561,54 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             "get_destinations_paginate_txn", get_destinations_paginate_txn
         )
 
+    async def get_destination_rooms_paginate(
+        self, destination: str, start: int, limit: int, direction: str = "f"
+    ) -> Tuple[List[JsonDict], int]:
+        """Function to retrieve a paginated list of destination's rooms.
+        This will return a json list of rooms and the
+        total number of rooms.
+
+        Args:
+            destination: the destination to query
+            start: start number to begin the query from
+            limit: number of rows to retrieve
+            direction: sort ascending or descending by room_id
+        Returns:
+            A tuple of a dict of rooms and a count of total rooms.
+        """
+
+        def get_destination_rooms_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
+
+            if direction == "b":
+                order = "DESC"
+            else:
+                order = "ASC"
+
+            sql = """
+                SELECT COUNT(*) as total_rooms
+                FROM destination_rooms
+                WHERE destination = ?
+                """
+            txn.execute(sql, [destination])
+            count = cast(Tuple[int], txn.fetchone())[0]
+
+            rooms = self.db_pool.simple_select_list_paginate_txn(
+                txn=txn,
+                table="destination_rooms",
+                orderby="room_id",
+                start=start,
+                limit=limit,
+                retcols=("room_id", "stream_ordering"),
+                order_direction=order,
+            )
+            return rooms, count
+
+        return await self.db_pool.runInteraction(
+            "get_destination_rooms_paginate_txn", get_destination_rooms_paginate_txn
+        )
+
     async def is_destination_known(self, destination: str) -> bool:
         """Check if a destination is known to the server."""
         result = await self.db_pool.simple_select_one_onecol(