summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/client_ips.py124
-rw-r--r--synapse/storage/databases/main/deviceinbox.py92
-rw-r--r--synapse/storage/databases/main/devices.py35
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py85
-rw-r--r--synapse/storage/databases/main/profile.py2
-rw-r--r--synapse/storage/databases/main/registration.py18
6 files changed, 299 insertions, 57 deletions
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index b81d9218ce..1dc7f0ebe3 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -478,6 +478,58 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
         return {(d["user_id"], d["device_id"]): d for d in res}
 
+    async def get_user_ip_and_agents(
+        self, user: UserID, since_ts: int = 0
+    ) -> List[LastConnectionInfo]:
+        """Fetch the IPs and user agents for a user since the given timestamp.
+
+        The result might be slightly out of date as client IPs are inserted in batches.
+
+        Args:
+            user: The user for which to fetch IP addresses and user agents.
+            since_ts: The timestamp after which to fetch IP addresses and user agents,
+                in milliseconds.
+
+        Returns:
+            A list of dictionaries, each containing:
+             * `access_token`: The access token used.
+             * `ip`: The IP address used.
+             * `user_agent`: The last user agent seen for this access token and IP
+               address combination.
+             * `last_seen`: The timestamp at which this access token and IP address
+               combination was last seen, in milliseconds.
+
+            Only the latest user agent for each access token and IP address combination
+            is available.
+        """
+        user_id = user.to_string()
+
+        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
+            txn.execute(
+                """
+                SELECT access_token, ip, user_agent, last_seen FROM user_ips
+                WHERE last_seen >= ? AND user_id = ?
+                ORDER BY last_seen
+                DESC
+                """,
+                (since_ts, user_id),
+            )
+            return cast(List[Tuple[str, str, str, int]], txn.fetchall())
+
+        rows = await self.db_pool.runInteraction(
+            desc="get_user_ip_and_agents", func=get_recent
+        )
+
+        return [
+            {
+                "access_token": access_token,
+                "ip": ip,
+                "user_agent": user_agent,
+                "last_seen": last_seen,
+            }
+            for access_token, ip, user_agent, last_seen in rows
+        ]
+
 
 class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
     def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
@@ -622,49 +674,43 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
     async def get_user_ip_and_agents(
         self, user: UserID, since_ts: int = 0
     ) -> List[LastConnectionInfo]:
+        """Fetch the IPs and user agents for a user since the given timestamp.
+
+        Args:
+            user: The user for which to fetch IP addresses and user agents.
+            since_ts: The timestamp after which to fetch IP addresses and user agents,
+                in milliseconds.
+
+        Returns:
+            A list of dictionaries, each containing:
+             * `access_token`: The access token used.
+             * `ip`: The IP address used.
+             * `user_agent`: The last user agent seen for this access token and IP
+               address combination.
+             * `last_seen`: The timestamp at which this access token and IP address
+               combination was last seen, in milliseconds.
+
+            Only the latest user agent for each access token and IP address combination
+            is available.
         """
-        Fetch IP/User Agent connection since a given timestamp.
-        """
-        user_id = user.to_string()
-        results: Dict[Tuple[str, str], Tuple[str, int]] = {}
+        results: Dict[Tuple[str, str], LastConnectionInfo] = {
+            (connection["access_token"], connection["ip"]): connection
+            for connection in await super().get_user_ip_and_agents(user, since_ts)
+        }
 
+        # Overlay data that is pending insertion on top of the results from the
+        # database.
+        user_id = user.to_string()
         for key in self._batch_row_update:
-            (
-                uid,
-                access_token,
-                ip,
-            ) = key
+            uid, access_token, ip = key
             if uid == user_id:
                 user_agent, _, last_seen = self._batch_row_update[key]
                 if last_seen >= since_ts:
-                    results[(access_token, ip)] = (user_agent, last_seen)
-
-        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
-            txn.execute(
-                """
-                SELECT access_token, ip, user_agent, last_seen FROM user_ips
-                WHERE last_seen >= ? AND user_id = ?
-                ORDER BY last_seen
-                DESC
-                """,
-                (since_ts, user_id),
-            )
-            return cast(List[Tuple[str, str, str, int]], txn.fetchall())
-
-        rows = await self.db_pool.runInteraction(
-            desc="get_user_ip_and_agents", func=get_recent
-        )
+                    results[(access_token, ip)] = {
+                        "access_token": access_token,
+                        "ip": ip,
+                        "user_agent": user_agent,
+                        "last_seen": last_seen,
+                    }
 
-        results.update(
-            ((access_token, ip), (user_agent, last_seen))
-            for access_token, ip, user_agent, last_seen in rows
-        )
-        return [
-            {
-                "access_token": access_token,
-                "ip": ip,
-                "user_agent": user_agent,
-                "last_seen": last_seen,
-            }
-            for (access_token, ip), (user_agent, last_seen) in results.items()
-        ]
+        return list(results.values())
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 8143168107..b0ccab0c9b 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -19,9 +19,10 @@ from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -555,6 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
+    REMOVE_DELETED_DEVICES = "remove_deleted_devices_from_device_inbox"
 
     def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
@@ -570,6 +572,11 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            self.REMOVE_DELETED_DEVICES,
+            self._remove_deleted_devices_from_device_inbox,
+        )
+
     async def _background_drop_index_device_inbox(self, progress, batch_size):
         def reindex_txn(conn):
             txn = conn.cursor()
@@ -582,6 +589,89 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
 
         return 1
 
+    async def _remove_deleted_devices_from_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """A background update that deletes all device_inboxes for deleted devices.
+
+        This should only need to be run once (when users upgrade to v1.46.0)
+
+        Args:
+            progress: JsonDict used to store progress of this background update
+            batch_size: the maximum number of rows to retrieve in a single select query
+
+        Returns:
+            The number of deleted rows
+        """
+
+        def _remove_deleted_devices_from_device_inbox_txn(
+            txn: LoggingTransaction,
+        ) -> int:
+            """stream_id is not unique
+            we need to use an inclusive `stream_id >= ?` clause,
+            since we might not have deleted all dead device messages for the stream_id
+            returned from the previous query
+
+            Then delete only rows matching the `(user_id, device_id, stream_id)` tuple,
+            to avoid problems of deleting a large number of rows all at once
+            due to a single device having lots of device messages.
+            """
+
+            last_stream_id = progress.get("stream_id", 0)
+
+            sql = """
+                SELECT device_id, user_id, stream_id
+                FROM device_inbox
+                WHERE
+                    stream_id >= ?
+                    AND (device_id, user_id) NOT IN (
+                        SELECT device_id, user_id FROM devices
+                    )
+                ORDER BY stream_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_stream_id, batch_size))
+            rows = txn.fetchall()
+
+            num_deleted = 0
+            for row in rows:
+                num_deleted += self.db_pool.simple_delete_txn(
+                    txn,
+                    "device_inbox",
+                    {"device_id": row[0], "user_id": row[1], "stream_id": row[2]},
+                )
+
+            if rows:
+                # send more than stream_id to progress
+                # otherwise it can happen in large deployments that
+                # no change of status is visible in the log file
+                # it may be that the stream_id does not change in several runs
+                self.db_pool.updates._background_update_progress_txn(
+                    txn,
+                    self.REMOVE_DELETED_DEVICES,
+                    {
+                        "device_id": rows[-1][0],
+                        "user_id": rows[-1][1],
+                        "stream_id": rows[-1][2],
+                    },
+                )
+
+            return num_deleted
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_deleted_devices_from_device_inbox",
+            _remove_deleted_devices_from_device_inbox_txn,
+        )
+
+        # The task is finished when no more lines are deleted.
+        if not number_deleted:
+            await self.db_pool.updates._end_background_update(
+                self.REMOVE_DELETED_DEVICES
+            )
+
+        return number_deleted
+
 
 class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
     pass
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a01bf2c5b7..b15cd030e0 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1134,19 +1134,14 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             raise StoreError(500, "Problem storing device.")
 
     async def delete_device(self, user_id: str, device_id: str) -> None:
-        """Delete a device.
+        """Delete a device and its device_inbox.
 
         Args:
             user_id: The ID of the user which owns the device
             device_id: The ID of the device to delete
         """
-        await self.db_pool.simple_delete_one(
-            table="devices",
-            keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
-            desc="delete_device",
-        )
 
-        self.device_id_exists_cache.invalidate((user_id, device_id))
+        await self.delete_devices(user_id, [device_id])
 
     async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
         """Deletes several devices.
@@ -1155,13 +1150,25 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id: The ID of the user which owns the devices
             device_ids: The IDs of the devices to delete
         """
-        await self.db_pool.simple_delete_many(
-            table="devices",
-            column="device_id",
-            iterable=device_ids,
-            keyvalues={"user_id": user_id, "hidden": False},
-            desc="delete_devices",
-        )
+
+        def _delete_devices_txn(txn: LoggingTransaction) -> None:
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="devices",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id, "hidden": False},
+            )
+
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="device_inbox",
+                column="device_id",
+                values=device_ids,
+                keyvalues={"user_id": user_id},
+            )
+
+        await self.db_pool.runInteraction("delete_devices", _delete_devices_txn)
         for device_id in device_ids:
             self.device_id_exists_cache.invalidate((user_id, device_id))
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index fc49112063..ae3a8a63e4 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -17,11 +17,15 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import attr
 
-from synapse.api.constants import EventContentFields
+from synapse.api.constants import EventContentFields, RelationTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
 from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
@@ -167,6 +171,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             self._purged_chain_cover_index,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "event_thread_relation", self._event_thread_relation
+        )
+
         ################################################################################
 
         # bg updates for replacing stream_ordering with a BIGINT
@@ -1091,6 +1099,79 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return result
 
+    async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
+        """Background update handler which will store thread relations for existing events."""
+        last_event_id = progress.get("last_event_id", "")
+
+        def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
+            txn.execute(
+                """
+                SELECT event_id, json FROM event_json
+                LEFT JOIN event_relations USING (event_id)
+                WHERE event_id > ? AND event_relations.event_id IS NULL
+                ORDER BY event_id LIMIT ?
+                """,
+                (last_event_id, batch_size),
+            )
+
+            results = list(txn)
+            missing_thread_relations = []
+            for (event_id, event_json_raw) in results:
+                try:
+                    event_json = db_to_json(event_json_raw)
+                except Exception as e:
+                    logger.warning(
+                        "Unable to load event %s (no relations will be updated): %s",
+                        event_id,
+                        e,
+                    )
+                    continue
+
+                # If there's no relation (or it is not a thread), skip!
+                relates_to = event_json["content"].get("m.relates_to")
+                if not relates_to or not isinstance(relates_to, dict):
+                    continue
+                if relates_to.get("rel_type") != RelationTypes.THREAD:
+                    continue
+
+                # Get the parent ID.
+                parent_id = relates_to.get("event_id")
+                if not isinstance(parent_id, str):
+                    continue
+
+                missing_thread_relations.append((event_id, parent_id))
+
+            # Insert the missing data.
+            self.db_pool.simple_insert_many_txn(
+                txn=txn,
+                table="event_relations",
+                values=[
+                    {
+                        "event_id": event_id,
+                        "relates_to_Id": parent_id,
+                        "relation_type": RelationTypes.THREAD,
+                    }
+                    for event_id, parent_id in missing_thread_relations
+                ],
+            )
+
+            if results:
+                latest_event_id = results[-1][0]
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "event_thread_relation", {"last_event_id": latest_event_id}
+                )
+
+            return len(results)
+
+        num_rows = await self.db_pool.runInteraction(
+            desc="event_thread_relation", func=_event_thread_relation_txn
+        )
+
+        if not num_rows:
+            await self.db_pool.updates._end_background_update("event_thread_relation")
+
+        return num_rows
+
     async def _background_populate_stream_ordering2(
         self, progress: JsonDict, batch_size: int
     ) -> int:
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index ba7075caa5..dd8e27e226 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -91,7 +91,7 @@ class ProfileWorkerStore(SQLBaseStore):
         )
 
     async def update_remote_profile_cache(
-        self, user_id: str, displayname: str, avatar_url: str
+        self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
     ) -> int:
         return await self.db_pool.simple_update(
             table="remote_profile_cache",
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 37d47aa823..6c7d6ba508 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -499,6 +499,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
 
+    async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None:
+        """Sets the user type.
+
+        Args:
+            user: user ID of the user.
+            user_type: type of the user or None for a user without a type.
+        """
+
+        def set_user_type_txn(txn):
+            self.db_pool.simple_update_one_txn(
+                txn, "users", {"name": user.to_string()}, {"user_type": user_type}
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_user_by_id, (user.to_string(),)
+            )
+
+        await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
+
     def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
         sql = """
             SELECT users.name as user_id,