Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/database.py                                        4
-rw-r--r--  synapse/storage/databases/main/__init__.py                         1
-rw-r--r--  synapse/storage/databases/main/appservice.py                      98
-rw-r--r--  synapse/storage/databases/main/events_bg_updates.py                7
-rw-r--r--  synapse/storage/databases/main/events_worker.py                   62
-rw-r--r--  synapse/storage/databases/main/media_repository.py               131
-rw-r--r--  synapse/storage/databases/main/profile.py                          4
-rw-r--r--  synapse/storage/databases/main/registration.py                   202
-rw-r--r--  synapse/storage/databases/main/room.py                           104
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/22puppet_token.sql  17
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql  2
11 files changed, 488 insertions, 144 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 0217e63108..a0572b2952 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -94,7 +94,7 @@ def make_pool(
         cp_openfun=lambda conn: engine.on_new_connection(
             LoggingDatabaseConnection(conn, engine, "on_new_connection")
         ),
-        **db_config.config.get("args", {})
+        **db_config.config.get("args", {}),
     )
 
 
@@ -632,7 +632,7 @@ class DatabasePool:
                 func,
                 *args,
                 db_autocommit=db_autocommit,
-                **kwargs
+                **kwargs,
             )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9b16f45f3e..43660ec4fb 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -146,7 +146,6 @@ class DataStore(
             db_conn, "e2e_cross_signing_keys", "stream_id"
         )
 
-        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 637a938bac..26eef6eb61 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -15,21 +15,31 @@
 # limitations under the License.
 import logging
 import re
-from typing import List
+from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
 
-from synapse.appservice import ApplicationService, AppServiceTransaction
+from synapse.appservice import (
+    ApplicationService,
+    ApplicationServiceState,
+    AppServiceTransaction,
+)
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
+from synapse.storage.types import Connection
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
-def _make_exclusive_regex(services_cache):
+def _make_exclusive_regex(
+    services_cache: List[ApplicationService],
+) -> Optional[Pattern]:
     # We precompile a regex constructed from all the regexes that the AS's
     # have registered for exclusive users.
     exclusive_user_regexes = [
@@ -39,17 +49,19 @@ def _make_exclusive_regex(services_cache):
     ]
     if exclusive_user_regexes:
         exclusive_user_regex = "|".join("(" + r + ")" for r in exclusive_user_regexes)
-        exclusive_user_regex = re.compile(exclusive_user_regex)
+        exclusive_user_pattern = re.compile(
+            exclusive_user_regex
+        )  # type: Optional[Pattern]
     else:
         # We handle this case specially otherwise the constructed regex
         # will always match
-        exclusive_user_regex = None
+        exclusive_user_pattern = None
 
-    return exclusive_user_regex
+    return exclusive_user_pattern
 
 
 class ApplicationServiceWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         self.services_cache = load_appservices(
             hs.hostname, hs.config.app_service_config_files
         )
@@ -60,7 +72,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
     def get_app_services(self):
         return self.services_cache
 
-    def get_if_app_services_interested_in_user(self, user_id):
+    def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
         """Check if the user is one associated with an app service (exclusively)
         """
         if self.exclusive_user_regex:
@@ -68,7 +80,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         else:
             return False
 
-    def get_app_service_by_user_id(self, user_id):
+    def get_app_service_by_user_id(self, user_id: str) -> Optional[ApplicationService]:
         """Retrieve an application service from their user ID.
 
         All application services have associated with them a particular user ID.
@@ -77,35 +89,35 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
         a user ID to an application service.
 
         Args:
-            user_id(str): The user ID to see if it is an application service.
+            user_id: The user ID to see if it is an application service.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.sender == user_id:
                 return service
         return None
 
-    def get_app_service_by_token(self, token):
+    def get_app_service_by_token(self, token: str) -> Optional[ApplicationService]:
         """Get the application service with the given appservice token.
 
         Args:
-            token (str): The application service token.
+            token: The application service token.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.token == token:
                 return service
         return None
 
-    def get_app_service_by_id(self, as_id):
+    def get_app_service_by_id(self, as_id: str) -> Optional[ApplicationService]:
         """Get the application service with the given appservice ID.
 
         Args:
-            as_id (str): The application service ID.
+            as_id: The application service ID.
         Returns:
-            synapse.appservice.ApplicationService or None.
+            The application service or None.
         """
         for service in self.services_cache:
             if service.id == as_id:
@@ -124,11 +136,13 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
 class ApplicationServiceTransactionWorkerStore(
     ApplicationServiceWorkerStore, EventsWorkerStore
 ):
-    async def get_appservices_by_state(self, state):
+    async def get_appservices_by_state(
+        self, state: ApplicationServiceState
+    ) -> List[ApplicationService]:
         """Get a list of application services based on their state.
 
         Args:
-            state(ApplicationServiceState): The state to filter on.
+            state: The state to filter on.
         Returns:
             A list of ApplicationServices, which may be empty.
         """
@@ -145,13 +159,15 @@ class ApplicationServiceTransactionWorkerStore(
                     services.append(service)
         return services
 
-    async def get_appservice_state(self, service):
+    async def get_appservice_state(
+        self, service: ApplicationService
+    ) -> Optional[ApplicationServiceState]:
         """Get the application service state.
 
         Args:
-            service(ApplicationService): The service whose state to set.
+            service: The service whose state to set.
         Returns:
-            An ApplicationServiceState.
+            An ApplicationServiceState, or None if the state was never set.
         """
         result = await self.db_pool.simple_select_one(
             "application_services_state",
@@ -164,12 +180,14 @@ class ApplicationServiceTransactionWorkerStore(
             return result.get("state")
         return None
 
-    async def set_appservice_state(self, service, state) -> None:
+    async def set_appservice_state(
+        self, service: ApplicationService, state: ApplicationServiceState
+    ) -> None:
         """Set the application service state.
 
         Args:
-            service(ApplicationService): The service whose state to set.
-            state(ApplicationServiceState): The connectivity state to apply.
+            service: The service whose state to set.
+            state: The connectivity state to apply.
         """
         await self.db_pool.simple_upsert(
             "application_services_state", {"as_id": service.id}, {"state": state}
@@ -226,13 +244,14 @@ class ApplicationServiceTransactionWorkerStore(
             "create_appservice_txn", _create_appservice_txn
         )
 
-    async def complete_appservice_txn(self, txn_id, service) -> None:
+    async def complete_appservice_txn(
+        self, txn_id: int, service: ApplicationService
+    ) -> None:
         """Completes an application service transaction.
 
         Args:
-            txn_id(str): The transaction ID being completed.
-            service(ApplicationService): The application service which was sent
-            this transaction.
+            txn_id: The transaction ID being completed.
+            service: The application service which was sent this transaction.
         """
         txn_id = int(txn_id)
 
@@ -242,7 +261,7 @@ class ApplicationServiceTransactionWorkerStore(
             # has probably missed some events), so whine loudly but still continue,
             # since it shouldn't fail completion of the transaction.
-            last_txn = self._get_last_txn(txn, service.id)
-            if (last_txn + 1) != txn_id:
+            last_txn_id = self._get_last_txn(txn, service.id)
+            if (last_txn_id + 1) != txn_id:
                 logger.error(
                     "appservice: Completing a transaction which has an ID > 1 from "
                     "the last ID sent to this AS. We've either dropped events or "
@@ -272,12 +291,13 @@ class ApplicationServiceTransactionWorkerStore(
             "complete_appservice_txn", _complete_appservice_txn
         )
 
-    async def get_oldest_unsent_txn(self, service):
-        """Get the oldest transaction which has not been sent for this
-        service.
+    async def get_oldest_unsent_txn(
+        self, service: ApplicationService
+    ) -> Optional[AppServiceTransaction]:
+        """Get the oldest transaction which has not been sent for this service.
 
         Args:
-            service(ApplicationService): The app service to get the oldest txn.
+            service: The app service to get the oldest txn.
         Returns:
             An AppServiceTransaction or None.
         """
@@ -313,7 +333,7 @@ class ApplicationServiceTransactionWorkerStore(
             service=service, id=entry["txn_id"], events=events, ephemeral=[]
         )
 
-    def _get_last_txn(self, txn, service_id):
+    def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
         txn.execute(
             "SELECT last_txn FROM application_services_state WHERE as_id=?",
             (service_id,),
@@ -324,7 +344,7 @@ class ApplicationServiceTransactionWorkerStore(
         else:
             return int(last_txn_id[0])  # select 'last_txn' col
 
-    async def set_appservice_last_pos(self, pos) -> None:
+    async def set_appservice_last_pos(self, pos: int) -> None:
         def set_appservice_last_pos_txn(txn):
             txn.execute(
                 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
@@ -334,7 +354,9 @@ class ApplicationServiceTransactionWorkerStore(
             "set_appservice_last_pos", set_appservice_last_pos_txn
         )
 
-    async def get_new_events_for_appservice(self, current_id, limit):
+    async def get_new_events_for_appservice(
+        self, current_id: int, limit: int
+    ) -> Tuple[int, List[EventBase]]:
         """Get all new events for an appservice"""
 
         def get_new_events_for_appservice_txn(txn):
@@ -394,7 +416,7 @@ class ApplicationServiceTransactionWorkerStore(
         )
 
     async def set_type_stream_id_for_appservice(
-        self, service: ApplicationService, type: str, pos: int
+        self, service: ApplicationService, type: str, pos: Optional[int]
     ) -> None:
         if type not in ("read_receipt", "presence"):
             raise ValueError(
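
Note: the precompiled exclusive-user regex above exists so that
get_if_app_services_interested_in_user can answer with a single match call
instead of looping over every appservice namespace. A minimal sketch of the
idea; the helper and the example namespace are invented for illustration,
and the real function reads its regexes out of ApplicationService objects:

    import re
    from typing import List, Optional, Pattern

    def make_exclusive_regex(exclusive_regexes: List[str]) -> Optional[Pattern]:
        # OR together each AS namespace regex, each in its own group.
        if not exclusive_regexes:
            # Return None rather than compiling an empty pattern,
            # which would match every user ID.
            return None
        return re.compile("|".join("(" + r + ")" for r in exclusive_regexes))

    pattern = make_exclusive_regex([r"@irc_.*:example\.com"])
    assert pattern is not None
    assert pattern.match("@irc_alice:example.com")
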
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 5e4af2eb51..97b6754846 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -92,6 +92,13 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             where_clause="NOT have_censored",
         )
 
+        self.db_pool.updates.register_background_index_update(
+            "users_have_local_media",
+            index_name="users_have_local_media",
+            table="local_media_repository",
+            columns=["user_id", "created_ts"],
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
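
Note: register_background_index_update only takes effect once a matching row
is queued in the background_updates table; the 58/22users_have_local_media.sql
delta at the bottom of this diff does exactly that. A sketch of the pairing,
using the names from this diff:

    # Registered at startup (from the hunk above): tells the background
    # updater how to build the index in batches, off the startup path.
    self.db_pool.updates.register_background_index_update(
        "users_have_local_media",              # must match update_name below
        index_name="users_have_local_media",
        table="local_media_repository",
        columns=["user_id", "created_ts"],
    )

    # Queued by the schema delta, which is what makes the updater run it:
    # INSERT INTO background_updates (update_name, progress_json)
    #     VALUES ('users_have_local_media', '{}');
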
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 6e7f16f39c..4732685f6e 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -31,6 +31,7 @@ from synapse.api.room_versions import (
     RoomVersions,
 )
 from synapse.events import EventBase, make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.events.utils import prune_event
 from synapse.logging.context import PreserveLoggingContext, current_context
 from synapse.metrics.background_process_metrics import (
@@ -44,7 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, JsonDict, get_domain_from_id
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.iterutils import batch_iter
@@ -525,6 +526,57 @@ class EventsWorkerStore(SQLBaseStore):
 
         return event_map
 
+    async def get_stripped_room_state_from_event_context(
+        self,
+        context: EventContext,
+        state_types_to_include: List[EventTypes],
+        membership_user_id: Optional[str] = None,
+    ) -> List[JsonDict]:
+        """
+        Retrieve the stripped state from a room, given an event context to retrieve state
+        from as well as the state types to include. Optionally, include the membership
+        events from a specific user.
+
+        "Stripped" state means that only the `type`, `state_key`, `content` and `sender` keys
+        are included from each state event.
+
+        Args:
+            context: The event context to retrieve state of the room from.
+            state_types_to_include: The type of state events to include.
+            membership_user_id: An optional user ID to include the stripped membership state
+                events of. This is useful when generating the stripped state of a room for
+                invites. We want to send membership events of the inviter, so that the
+                invitee can display the inviter's profile information if the room lacks any.
+
+        Returns:
+            A list of dictionaries, each representing a stripped state event from the room.
+        """
+        current_state_ids = await context.get_current_state_ids()
+
+        # We know this event is not an outlier, so this must be
+        # non-None.
+        assert current_state_ids is not None
+
+        # The state to include
+        state_to_include_ids = [
+            e_id
+            for k, e_id in current_state_ids.items()
+            if k[0] in state_types_to_include
+            or (membership_user_id and k == (EventTypes.Member, membership_user_id))
+        ]
+
+        state_to_include = await self.get_events(state_to_include_ids)
+
+        return [
+            {
+                "type": e.type,
+                "state_key": e.state_key,
+                "content": e.content,
+                "sender": e.sender,
+            }
+            for e in state_to_include.values()
+        ]
+
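
Note: the "stripped" state returned by the new
get_stripped_room_state_from_event_context above keeps only the type,
state_key, content and sender of each event. A sketch of the output shape
for an invite, with illustrative values:

    stripped_state = [
        {
            "type": "m.room.create",
            "state_key": "",
            "content": {"creator": "@alice:example.com"},
            "sender": "@alice:example.com",
        },
        {
            # Present because membership_user_id was set to the inviter,
            # so the invitee can render the inviter's profile.
            "type": "m.room.member",
            "state_key": "@alice:example.com",
            "content": {"membership": "join", "displayname": "Alice"},
            "sender": "@alice:example.com",
        },
    ]
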
     def _do_fetch(self, conn):
         """Takes a database connection and waits for requests for events from
         the _event_fetch_list queue.
@@ -1065,11 +1117,13 @@ class EventsWorkerStore(SQLBaseStore):
         def get_all_new_forward_event_rows(txn):
             sql = (
                 "SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
+                " LEFT JOIN room_memberships USING (event_id)"
+                " LEFT JOIN rejections USING (event_id)"
                 " WHERE ? < stream_ordering AND stream_ordering <= ?"
                 " AND instance_name = ?"
                 " ORDER BY stream_ordering ASC"
@@ -1100,12 +1154,14 @@ class EventsWorkerStore(SQLBaseStore):
         def get_ex_outlier_stream_rows_txn(txn):
             sql = (
                 "SELECT event_stream_ordering, e.event_id, e.room_id, e.type,"
-                " state_key, redacts, relates_to_id"
+                " state_key, redacts, relates_to_id, membership, rejections.reason IS NOT NULL"
                 " FROM events AS e"
                 " INNER JOIN ex_outlier_stream AS out USING (event_id)"
                 " LEFT JOIN redactions USING (event_id)"
                 " LEFT JOIN state_events USING (event_id)"
                 " LEFT JOIN event_relations USING (event_id)"
+                " LEFT JOIN room_memberships USING (event_id)"
+                " LEFT JOIN rejections USING (event_id)"
                 " WHERE ? < event_stream_ordering"
                 " AND event_stream_ordering <= ?"
                 " AND out.instance_name = ?"
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index cc538c5c10..4b2f224718 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -93,6 +93,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
 
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
+        self.server_name = hs.hostname
 
     async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
         """Get the metadata for a local piece of media
@@ -115,6 +116,109 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_local_media",
         )
 
+    async def get_local_media_by_user_paginate(
+        self, start: int, limit: int, user_id: str
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        """Get a paginated list of metadata for the local media that a
+        given user has uploaded
+
+        Args:
+            start: offset in the list
+            limit: maximum amount of media_ids to retrieve
+            user_id: fully-qualified user id
+        Returns:
+            A paginated list of metadata for the user's media,
+            plus the total count of the user's media
+        """
+
+        def get_local_media_by_user_paginate_txn(txn):
+
+            args = [user_id]
+            sql = """
+                SELECT COUNT(*) as total_media
+                FROM local_media_repository
+                WHERE user_id = ?
+            """
+            txn.execute(sql, args)
+            count = txn.fetchone()[0]
+
+            sql = """
+                SELECT
+                    "media_id",
+                    "media_type",
+                    "media_length",
+                    "upload_name",
+                    "created_ts",
+                    "last_access_ts",
+                    "quarantined_by",
+                    "safe_from_quarantine"
+                FROM local_media_repository
+                WHERE user_id = ?
+                ORDER BY created_ts DESC, media_id DESC
+                LIMIT ? OFFSET ?
+            """
+
+            args += [limit, start]
+            txn.execute(sql, args)
+            media = self.db_pool.cursor_to_dict(txn)
+            return media, count
+
+        return await self.db_pool.runInteraction(
+            "get_local_media_by_user_paginate_txn", get_local_media_by_user_paginate_txn
+        )
+
+    async def get_local_media_before(
+        self, before_ts: int, size_gt: int, keep_profiles: bool,
+    ) -> Optional[List[str]]:
+
+        # To find files that have never been accessed (last_access_ts IS NULL),
+        # compare against `created_ts` instead.
+        sql = """
+            SELECT media_id
+            FROM local_media_repository AS lmr
+            WHERE
+                ( last_access_ts < ?
+                OR ( created_ts < ? AND last_access_ts IS NULL ) )
+                AND media_length > ?
+        """
+
+        if keep_profiles:
+            sql_keep = """
+                AND (
+                    NOT EXISTS
+                        (SELECT 1
+                         FROM profiles
+                         WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id)
+                    AND NOT EXISTS
+                        (SELECT 1
+                         FROM groups
+                         WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id)
+                    AND NOT EXISTS
+                        (SELECT 1
+                         FROM room_memberships
+                         WHERE room_memberships.avatar_url = '{media_prefix}' || lmr.media_id)
+                    AND NOT EXISTS
+                        (SELECT 1
+                         FROM user_directory
+                         WHERE user_directory.avatar_url = '{media_prefix}' || lmr.media_id)
+                    AND NOT EXISTS
+                        (SELECT 1
+                         FROM room_stats_state
+                         WHERE room_stats_state.avatar = '{media_prefix}' || lmr.media_id)
+                )
+            """.format(
+                media_prefix="mxc://%s/" % (self.server_name,),
+            )
+            sql += sql_keep
+
+        def _get_local_media_before_txn(txn):
+            txn.execute(sql, (before_ts, before_ts, size_gt))
+            return [row[0] for row in txn]
+
+        return await self.db_pool.runInteraction(
+            "get_local_media_before", _get_local_media_before_txn
+        )
+
     async def store_local_media(
         self,
         media_id,
@@ -348,6 +452,33 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
             desc="get_remote_media_thumbnails",
         )
 
+    async def get_remote_media_thumbnail(
+        self, origin: str, media_id: str, t_width: int, t_height: int, t_type: str,
+    ) -> Optional[Dict[str, Any]]:
+        """Fetch the thumbnail info of the given width, height and type.
+        """
+
+        return await self.db_pool.simple_select_one(
+            table="remote_media_cache_thumbnails",
+            keyvalues={
+                "media_origin": origin,
+                "media_id": media_id,
+                "thumbnail_width": t_width,
+                "thumbnail_height": t_height,
+                "thumbnail_type": t_type,
+            },
+            retcols=(
+                "thumbnail_width",
+                "thumbnail_height",
+                "thumbnail_method",
+                "thumbnail_type",
+                "thumbnail_length",
+                "filesystem_id",
+            ),
+            allow_none=True,
+            desc="get_remote_media_thumbnail",
+        )
+
     async def store_remote_media_thumbnail(
         self,
         origin,
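
Note: a hedged usage sketch for the new pagination helper; the store handle,
user ID and page size are assumptions for illustration:

    async def iter_user_media(store, user_id: str, page_size: int = 100):
        # Walk the user's media in pages of `page_size`, newest first,
        # matching the ORDER BY created_ts DESC in the query above.
        start = 0
        while True:
            media, total = await store.get_local_media_by_user_paginate(
                start, page_size, user_id
            )
            for item in media:
                yield item
            start += page_size
            if start >= total:
                return
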
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index a6d1eb908a..0e25ca3d7a 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore):
             avatar_url=profile["avatar_url"], display_name=profile["displayname"]
         )
 
-    async def get_profile_displayname(self, user_localpart: str) -> str:
+    async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
         return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
@@ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="get_profile_displayname",
         )
 
-    async def get_profile_avatar_url(self, user_localpart: str) -> str:
+    async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
         return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
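
Note: the Optional return types reflect that the displayname and avatar_url
columns are nullable, so callers must now handle None. A small sketch; the
store handle and fallback are invented for illustration:

    displayname = await store.get_profile_displayname("alice")
    if displayname is None:
        # No display name set; fall back to the localpart.
        displayname = "alice"
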
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 4c843b7679..e5d07ce72a 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,29 +16,64 @@
 # limitations under the License.
 import logging
 import re
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+
+import attr
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import DatabasePool
-from synapse.storage.types import Cursor
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.databases.main.stats import StatsStore
+from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import UserID
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 
 logger = logging.getLogger(__name__)
 
 
-class RegistrationWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+@attr.s(frozen=True, slots=True)
+class TokenLookupResult:
+    """Result of looking up an access token.
+
+    Attributes:
+        user_id: The user that this token authenticates as
+        is_guest: Whether the user is a guest user.
+        shadow_banned: Whether the user has been shadow-banned.
+        token_id: The ID of the access token looked up
+        device_id: The device associated with the token, if any.
+        valid_until_ms: The timestamp the token expires, if any.
+        token_owner: The "owner" of the token. This is either the same as the
+            user, or a server admin who is logged in as the user.
+    """
+
+    user_id = attr.ib(type=str)
+    is_guest = attr.ib(type=bool, default=False)
+    shadow_banned = attr.ib(type=bool, default=False)
+    token_id = attr.ib(type=Optional[int], default=None)
+    device_id = attr.ib(type=Optional[str], default=None)
+    valid_until_ms = attr.ib(type=Optional[int], default=None)
+    token_owner = attr.ib(type=str)
+
+    # Make the token owner default to the user ID, which is the common case.
+    @token_owner.default
+    def _default_token_owner(self):
+        return self.user_id
+
+
+class RegistrationWorkerStore(CacheInvalidationWorkerStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
-        self.clock = hs.get_clock()
 
         # Note: we don't check this sequence for consistency as we'd have to
         # call `find_max_generated_user_id_localpart` each time, which is
@@ -55,7 +90,7 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         # Create a background job for culling expired 3PID validity tokens
         if hs.config.run_background_tasks:
-            self.clock.looping_call(
+            self._clock.looping_call(
                 self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
             )
 
@@ -92,21 +127,19 @@ class RegistrationWorkerStore(SQLBaseStore):
         if not info:
             return False
 
-        now = self.clock.time_msec()
+        now = self._clock.time_msec()
         trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
         is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
         return is_trial
 
     @cached()
-    async def get_user_by_access_token(self, token: str) -> Optional[dict]:
+    async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
         """Get a user from the given access token.
 
         Args:
             token: The access token of a user.
         Returns:
-            None, if the token did not match, otherwise dict
-            including the keys `name`, `is_guest`, `device_id`, `token_id`,
-            `valid_until_ms`.
+            None if the token did not match, otherwise a `TokenLookupResult`.
         """
         return await self.db_pool.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
@@ -236,13 +269,13 @@ class RegistrationWorkerStore(SQLBaseStore):
             desc="get_renewal_token_for_user",
         )
 
-    async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
+    async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
         """Selects users whose account will expire in the [now, now + renew_at] time
         window (see configuration for account_validity for information on what renew_at
         refers to).
 
         Returns:
-            A list of dictionaries mapping user ID to expiration time (in milliseconds).
+            A list of dictionaries, each with a user ID and expiration time (in milliseconds).
         """
 
         def select_users_txn(txn, now_ms, renew_at):
@@ -257,7 +290,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         return await self.db_pool.runInteraction(
             "get_users_expiring_soon",
             select_users_txn,
-            self.clock.time_msec(),
+            self._clock.time_msec(),
             self.config.account_validity.renew_at,
         )
 
@@ -327,19 +360,24 @@ class RegistrationWorkerStore(SQLBaseStore):
 
         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
-    def _query_for_auth(self, txn, token):
-        sql = (
-            "SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
-            " access_tokens.device_id, access_tokens.valid_until_ms"
-            " FROM users"
-            " INNER JOIN access_tokens on users.name = access_tokens.user_id"
-            " WHERE token = ?"
-        )
+    def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
+        sql = """
+            SELECT users.name as user_id,
+                users.is_guest,
+                users.shadow_banned,
+                access_tokens.id as token_id,
+                access_tokens.device_id,
+                access_tokens.valid_until_ms,
+                access_tokens.user_id as token_owner
+            FROM users
+            INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
+            WHERE token = ?
+        """
 
         txn.execute(sql, (token,))
         rows = self.db_pool.cursor_to_dict(txn)
         if rows:
-            return rows[0]
+            return TokenLookupResult(**rows[0])
 
         return None
 
@@ -803,7 +841,7 @@ class RegistrationWorkerStore(SQLBaseStore):
         await self.db_pool.runInteraction(
             "cull_expired_threepid_validation_tokens",
             cull_expired_threepid_validation_tokens_txn,
-            self.clock.time_msec(),
+            self._clock.time_msec(),
         )
 
     @wrap_as_background_process("account_validity_set_expiration_dates")
@@ -890,10 +928,10 @@ class RegistrationWorkerStore(SQLBaseStore):
 
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
-        self.clock = hs.get_clock()
+        self._clock = hs.get_clock()
         self.config = hs.config
 
         self.db_pool.updates.register_background_index_update(
@@ -1016,13 +1054,56 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
 
         return 1
 
+    async def set_user_deactivated_status(
+        self, user_id: str, deactivated: bool
+    ) -> None:
+        """Set the `deactivated` property for the provided user to the provided value.
+
+        Args:
+            user_id: The ID of the user to set the status for.
+            deactivated: The value to set for `deactivated`.
+        """
 
-class RegistrationStore(RegistrationBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+        await self.db_pool.runInteraction(
+            "set_user_deactivated_status",
+            self.set_user_deactivated_status_txn,
+            user_id,
+            deactivated,
+        )
+
+    def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
+        self.db_pool.simple_update_one_txn(
+            txn=txn,
+            table="users",
+            keyvalues={"name": user_id},
+            updatevalues={"deactivated": 1 if deactivated else 0},
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.get_user_deactivated_status, (user_id,)
+        )
+        txn.call_after(self.is_guest.invalidate, (user_id,))
+
+    @cached()
+    async def is_guest(self, user_id: str) -> bool:
+        res = await self.db_pool.simple_select_one_onecol(
+            table="users",
+            keyvalues={"name": user_id},
+            retcol="is_guest",
+            allow_none=True,
+            desc="is_guest",
+        )
+
+        return res if res else False
+
+
+class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
 
+        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+
     async def add_access_token_to_user(
         self,
         user_id: str,
@@ -1138,19 +1219,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
     def _register_user(
         self,
         txn,
-        user_id,
-        password_hash,
-        was_guest,
-        make_guest,
-        appservice_id,
-        create_profile_with_displayname,
-        admin,
-        user_type,
-        shadow_banned,
+        user_id: str,
+        password_hash: Optional[str],
+        was_guest: bool,
+        make_guest: bool,
+        appservice_id: Optional[str],
+        create_profile_with_displayname: Optional[str],
+        admin: bool,
+        user_type: Optional[str],
+        shadow_banned: bool,
     ):
         user_id_obj = UserID.from_string(user_id)
 
-        now = int(self.clock.time())
+        now = int(self._clock.time())
 
         try:
             if was_guest:
@@ -1374,18 +1455,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
 
         await self.db_pool.runInteraction("delete_access_token", f)
 
-    @cached()
-    async def is_guest(self, user_id: str) -> bool:
-        res = await self.db_pool.simple_select_one_onecol(
-            table="users",
-            keyvalues={"name": user_id},
-            retcol="is_guest",
-            allow_none=True,
-            desc="is_guest",
-        )
-
-        return res if res else False
-
     async def add_user_pending_deactivation(self, user_id: str) -> None:
         """
         Adds a user to the table of users who need to be parted from all the rooms they're
@@ -1479,7 +1548,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
                 txn,
                 table="threepid_validation_session",
                 keyvalues={"session_id": session_id},
-                updatevalues={"validated_at": self.clock.time_msec()},
+                updatevalues={"validated_at": self._clock.time_msec()},
             )
 
             return next_link
@@ -1547,35 +1616,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
             start_or_continue_validation_session_txn,
         )
 
-    async def set_user_deactivated_status(
-        self, user_id: str, deactivated: bool
-    ) -> None:
-        """Set the `deactivated` property for the provided user to the provided value.
-
-        Args:
-            user_id: The ID of the user to set the status for.
-            deactivated: The value to set for `deactivated`.
-        """
-
-        await self.db_pool.runInteraction(
-            "set_user_deactivated_status",
-            self.set_user_deactivated_status_txn,
-            user_id,
-            deactivated,
-        )
-
-    def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
-        self.db_pool.simple_update_one_txn(
-            txn=txn,
-            table="users",
-            keyvalues={"name": user_id},
-            updatevalues={"deactivated": 1 if deactivated else 0},
-        )
-        self._invalidate_cache_and_stream(
-            txn, self.get_user_deactivated_status, (user_id,)
-        )
-        txn.call_after(self.is_guest.invalidate, (user_id,))
-
 
 def find_max_generated_user_id_localpart(cur: Cursor) -> int:
     """
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index e83d961c20..dc0c4b5499 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1411,6 +1411,65 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
             desc="add_event_report",
         )
 
+    async def get_event_report(self, report_id: int) -> Optional[Dict[str, Any]]:
+        """Retrieve an event report
+
+        Args:
+            report_id: ID of the event report in the database
+        Returns:
+            A dict with the event report's details, or None if not found.
+        """
+
+        def _get_event_report_txn(txn, report_id):
+
+            sql = """
+                SELECT
+                    er.id,
+                    er.received_ts,
+                    er.room_id,
+                    er.event_id,
+                    er.user_id,
+                    er.content,
+                    events.sender,
+                    room_stats_state.canonical_alias,
+                    room_stats_state.name,
+                    event_json.json AS event_json
+                FROM event_reports AS er
+                LEFT JOIN events
+                    ON events.event_id = er.event_id
+                JOIN event_json
+                    ON event_json.event_id = er.event_id
+                JOIN room_stats_state
+                    ON room_stats_state.room_id = er.room_id
+                WHERE er.id = ?
+            """
+
+            txn.execute(sql, [report_id])
+            row = txn.fetchone()
+
+            if not row:
+                return None
+
+            event_report = {
+                "id": row[0],
+                "received_ts": row[1],
+                "room_id": row[2],
+                "event_id": row[3],
+                "user_id": row[4],
+                "score": db_to_json(row[5]).get("score"),
+                "reason": db_to_json(row[5]).get("reason"),
+                "sender": row[6],
+                "canonical_alias": row[7],
+                "name": row[8],
+                "event_json": db_to_json(row[9]),
+            }
+
+            return event_report
+
+        return await self.db_pool.runInteraction(
+            "get_event_report", _get_event_report_txn, report_id
+        )
+
     async def get_event_reports_paginate(
         self,
         start: int,
@@ -1468,18 +1527,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
                     er.room_id,
                     er.event_id,
                     er.user_id,
-                    er.reason,
                     er.content,
                     events.sender,
-                    room_aliases.room_alias,
-                    event_json.json AS event_json
+                    room_stats_state.canonical_alias,
+                    room_stats_state.name
                 FROM event_reports AS er
-                LEFT JOIN room_aliases
-                    ON room_aliases.room_id = er.room_id
-                JOIN events
+                LEFT JOIN events
                     ON events.event_id = er.event_id
-                JOIN event_json
-                    ON event_json.event_id = er.event_id
+                JOIN room_stats_state
+                    ON room_stats_state.room_id = er.room_id
                 {where_clause}
                 ORDER BY er.received_ts {order}
                 LIMIT ?
@@ -1490,15 +1546,29 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
 
             args += [limit, start]
             txn.execute(sql, args)
-            event_reports = self.db_pool.cursor_to_dict(txn)
-
-            if count > 0:
-                for row in event_reports:
-                    try:
-                        row["content"] = db_to_json(row["content"])
-                        row["event_json"] = db_to_json(row["event_json"])
-                    except Exception:
-                        continue
+
+            event_reports = []
+            for row in txn:
+                try:
+                    s = db_to_json(row[5]).get("score")
+                    r = db_to_json(row[5]).get("reason")
+                except Exception:
+                    logger.error("Unable to parse json from event_reports: %s", row[0])
+                    continue
+                event_reports.append(
+                    {
+                        "id": row[0],
+                        "received_ts": row[1],
+                        "room_id": row[2],
+                        "event_id": row[3],
+                        "user_id": row[4],
+                        "score": s,
+                        "reason": r,
+                        "sender": row[6],
+                        "canonical_alias": row[7],
+                        "name": row[8],
+                    }
+                )
 
             return event_reports, count
 
diff --git a/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
new file mode 100644
index 0000000000..00a9431a97
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/22puppet_token.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Whether the access token is an admin token for controlling another user.
+ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;
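
Note: this new puppets_user_id column is what the COALESCE join in
_query_for_auth (in the registration.py hunk above) keys off: when the
column is set, the token authenticates as the puppeted user while
token_owner still records the admin who owns the token. A sketch of how a
caller might tell the two apart; the logging is an assumption for
illustration:

    result = await store.get_user_by_access_token(token)
    if result is not None and result.token_owner != result.user_id:
        # A server admin's token is acting as another user; audit it.
        logger.info(
            "Token owned by %s is puppeting %s",
            result.token_owner,
            result.user_id,
        )
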
diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
new file mode 100644
index 0000000000..a2842687f1
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
@@ -0,0 +1,2 @@
+INSERT INTO background_updates (update_name, progress_json) VALUES
+  ('users_have_local_media', '{}');
\ No newline at end of file