summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/data_stores/main/client_ips.py8
-rw-r--r--synapse/storage/data_stores/main/end_to_end_keys.py217
-rw-r--r--synapse/storage/data_stores/main/events_worker.py97
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/README.md13
-rw-r--r--synapse/storage/data_stores/main/schema/full_schemas/README.txt19
-rw-r--r--synapse/storage/data_stores/main/search.py19
-rw-r--r--synapse/storage/data_stores/main/state.py18
7 files changed, 319 insertions, 72 deletions
diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py
index 320c5b0f07..add3037b69 100644
--- a/synapse/storage/data_stores/main/client_ips.py
+++ b/synapse/storage/data_stores/main/client_ips.py
@@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                 # Technically an access token might not be associated with
                 # a device so we need to check.
                 if device_id:
-                    self.db.simple_upsert_txn(
+                    # this is always an update rather than an upsert: the row should
+                    # already exist, and if it doesn't, that may be because it has been
+                    # deleted, and we don't want to re-create it.
+                    self.db.simple_update_txn(
                         txn,
                         table="devices",
                         keyvalues={"user_id": user_id, "device_id": device_id},
-                        values={
+                        updatevalues={
                             "user_agent": user_agent,
                             "last_seen": last_seen,
                             "ip": ip,
                         },
-                        lock=False,
                     )
             except Exception as e:
                 # Failed to upsert, log and continue
diff --git a/synapse/storage/data_stores/main/end_to_end_keys.py b/synapse/storage/data_stores/main/end_to_end_keys.py
index 38cd0ca9b8..e551606f9d 100644
--- a/synapse/storage/data_stores/main/end_to_end_keys.py
+++ b/synapse/storage/data_stores/main/end_to_end_keys.py
@@ -14,15 +14,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict, List
+
 from six import iteritems
 
 from canonicaljson import encode_canonical_json, json
 
+from twisted.enterprise.adbapi import Connection
 from twisted.internet import defer
 
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
 
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         Args:
             txn (twisted.enterprise.adbapi.Connection): db connection
             user_id (str): the user whose key is being requested
-            key_type (str): the type of key that is being set: either 'master'
+            key_type (str): the type of key that is being requested: either 'master'
                 for a master key, 'self_signing' for a self-signing key, or
                 'user_signing' for a user-signing key
             from_user_id (str): if specified, signatures made by this user on
@@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         """Returns a user's cross-signing key.
 
         Args:
-            user_id (str): the user whose self-signing key is being requested
-            key_type (str): the type of cross-signing key to get
+            user_id (str): the user whose key is being requested
+            key_type (str): the type of key that is being requested: either 'master'
+                for a master key, 'self_signing' for a self-signing key, or
+                'user_signing' for a user-signing key
             from_user_id (str): if specified, signatures made by this user on
                 the self-signing key will be included in the result
 
@@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             from_user_id,
         )
 
+    @cached(num_args=1)
+    def _get_bare_e2e_cross_signing_keys(self, user_id):
+        """Dummy function.  Only used to make a cache for
+        _get_bare_e2e_cross_signing_keys_bulk.
+        """
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_bare_e2e_cross_signing_keys",
+        list_name="user_ids",
+        num_args=1,
+    )
+    def _get_bare_e2e_cross_signing_keys_bulk(
+        self, user_ids: List[str]
+    ) -> Dict[str, Dict[str, dict]]:
+        """Returns the cross-signing keys for a set of users.  The output of this
+        function should be passed to _get_e2e_cross_signing_signatures_txn if
+        the signatures for the calling user need to be fetched.
+
+        Args:
+            user_ids (list[str]): the users whose keys are being requested
+
+        Returns:
+            dict[str, dict[str, dict]]: mapping from user ID to key type to key
+                data.  If a user's cross-signing keys were not found, either
+                their user ID will not be in the dict, or their user ID will map
+                to None.
+
+        """
+        return self.db.runInteraction(
+            "get_bare_e2e_cross_signing_keys_bulk",
+            self._get_bare_e2e_cross_signing_keys_bulk_txn,
+            user_ids,
+        )
+
+    def _get_bare_e2e_cross_signing_keys_bulk_txn(
+        self, txn: Connection, user_ids: List[str],
+    ) -> Dict[str, Dict[str, dict]]:
+        """Returns the cross-signing keys for a set of users.  The output of this
+        function should be passed to _get_e2e_cross_signing_signatures_txn if
+        the signatures for the calling user need to be fetched.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            user_ids (list[str]): the users whose keys are being requested
+
+        Returns:
+            dict[str, dict[str, dict]]: mapping from user ID to key type to key
+                data.  If a user's cross-signing keys were not found, their user
+                ID will not be in the dict.
+
+        """
+        result = {}
+
+        batch_size = 100
+        chunks = [
+            user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
+        ]
+        for user_chunk in chunks:
+            sql = """
+                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
+                  FROM e2e_cross_signing_keys k
+                  INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
+                                FROM e2e_cross_signing_keys
+                               GROUP BY user_id, keytype) s
+                 USING (user_id, stream_id, keytype)
+                 WHERE k.user_id IN (%s)
+            """ % (
+                ",".join("?" for u in user_chunk),
+            )
+            query_params = []
+            query_params.extend(user_chunk)
+
+            txn.execute(sql, query_params)
+            rows = self.db.cursor_to_dict(txn)
+
+            for row in rows:
+                user_id = row["user_id"]
+                key_type = row["keytype"]
+                key = json.loads(row["keydata"])
+                user_info = result.setdefault(user_id, {})
+                user_info[key_type] = key
+
+        return result
+
+    def _get_e2e_cross_signing_signatures_txn(
+        self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
+    ) -> Dict[str, Dict[str, dict]]:
+        """Returns the cross-signing signatures made by a user on a set of keys.
+
+        Args:
+            txn (twisted.enterprise.adbapi.Connection): db connection
+            keys (dict[str, dict[str, dict]]): a map of user ID to key type to
+                key data.  This dict will be modified to add signatures.
+            from_user_id (str): fetch the signatures made by this user
+
+        Returns:
+            dict[str, dict[str, dict]]: mapping from user ID to key type to key
+                data.  The return value will be the same as the keys argument,
+                with the modifications included.
+        """
+
+        # find out what cross-signing keys (a.k.a. devices) we need to get
+        # signatures for.  This is a map of (user_id, device_id) to key type
+        # (device_id is the key's public part).
+        devices = {}
+
+        for user_id, user_info in keys.items():
+            if user_info is None:
+                continue
+            for key_type, key in user_info.items():
+                device_id = None
+                for k in key["keys"].values():
+                    device_id = k
+                devices[(user_id, device_id)] = key_type
+
+        device_list = list(devices)
+
+        # split into batches
+        batch_size = 100
+        chunks = [
+            device_list[i : i + batch_size]
+            for i in range(0, len(device_list), batch_size)
+        ]
+        for user_chunk in chunks:
+            sql = """
+                SELECT target_user_id, target_device_id, key_id, signature
+                  FROM e2e_cross_signing_signatures
+                 WHERE user_id = ?
+                   AND (%s)
+            """ % (
+                " OR ".join(
+                    "(target_user_id = ? AND target_device_id = ?)" for d in devices
+                )
+            )
+            query_params = [from_user_id]
+            for item in devices:
+                # item is a (user_id, device_id) tuple
+                query_params.extend(item)
+
+            txn.execute(sql, query_params)
+            rows = self.db.cursor_to_dict(txn)
+
+            # and add the signatures to the appropriate keys
+            for row in rows:
+                key_id = row["key_id"]
+                target_user_id = row["target_user_id"]
+                target_device_id = row["target_device_id"]
+                key_type = devices[(target_user_id, target_device_id)]
+                # We need to copy everything, because the result may have come
+                # from the cache.  dict.copy only does a shallow copy, so we
+                # need to recursively copy the dicts that will be modified.
+                user_info = keys[target_user_id] = keys[target_user_id].copy()
+                target_user_key = user_info[key_type] = user_info[key_type].copy()
+                if "signatures" in target_user_key:
+                    signatures = target_user_key["signatures"] = target_user_key[
+                        "signatures"
+                    ].copy()
+                    if from_user_id in signatures:
+                        user_sigs = signatures[from_user_id] = signatures[from_user_id]
+                        user_sigs[key_id] = row["signature"]
+                    else:
+                        signatures[from_user_id] = {key_id: row["signature"]}
+                else:
+                    target_user_key["signatures"] = {
+                        from_user_id: {key_id: row["signature"]}
+                    }
+
+        return keys
+
+    @defer.inlineCallbacks
+    def get_e2e_cross_signing_keys_bulk(
+        self, user_ids: List[str], from_user_id: str = None
+    ) -> defer.Deferred:
+        """Returns the cross-signing keys for a set of users.
+
+        Args:
+            user_ids (list[str]): the users whose keys are being requested
+            from_user_id (str): if specified, signatures made by this user on
+                the self-signing keys will be included in the result
+
+        Returns:
+            Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
+                key data.  If a user's cross-signing keys were not found, either
+                their user ID will not be in the dict, or their user ID will map
+                to None.
+        """
+
+        result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
+
+        if from_user_id:
+            result = yield self.db.runInteraction(
+                "get_e2e_cross_signing_signatures",
+                self._get_e2e_cross_signing_signatures_txn,
+                result,
+                from_user_id,
+            )
+
+        return result
+
     def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
         """Return a list of changes from the user signature stream to notify remotes.
         Note that the user signature stream represents when a user signs their
@@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 },
             )
 
+        self._invalidate_cache_and_stream(
+            txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
+        )
+
     def set_e2e_cross_signing_key(self, user_id, key_type, key):
         """Set a user's cross-signing key.
 
diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py
index 9ee117ce0f..2c9142814c 100644
--- a/synapse/storage/data_stores/main/events_worker.py
+++ b/synapse/storage/data_stores/main/events_worker.py
@@ -19,8 +19,10 @@ import itertools
 import logging
 import threading
 from collections import namedtuple
+from typing import List, Optional
 
 from canonicaljson import json
+from constantly import NamedConstant, Names
 
 from twisted.internet import defer
 
@@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
 _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
 
 
+class EventRedactBehaviour(Names):
+    """
+    What to do when retrieving a redacted event from the database.
+    """
+
+    AS_IS = NamedConstant()
+    REDACT = NamedConstant()
+    BLOCK = NamedConstant()
+
+
 class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: Database, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
@@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_event(
         self,
-        event_id,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
-        allow_none=False,
-        check_room_id=None,
+        event_id: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
+        allow_none: bool = False,
+        check_room_id: Optional[str] = None,
     ):
         """Get an event from the database by event_id.
 
         Args:
-            event_id (str): The event_id of the event to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_id: The event_id of the event to fetch
+            redact_behaviour: Determine what to do with a redacted event. Possible values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * DISALLOW - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
-            allow_none (bool): If True, return None if no event found, if
+            allow_rejected: If True return rejected events.
+            allow_none: If True, return None if no event found, if
                 False throw a NotFoundError
-            check_room_id (str|None): if not None, check the room of the found event.
+            check_room_id: if not None, check the room of the found event.
                 If there is a mismatch, behave as per allow_none.
 
         Returns:
@@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         events = yield self.get_events_as_list(
             [event_id],
-            check_redacted=check_redacted,
+            redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
@@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_events(
         self,
-        event_ids,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
+        event_ids: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
     ):
         """Get events from the database
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_ids: The event_ids of the events to fetch
+            redact_behaviour: Determine what to do with a redacted event. Possible
+                values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * DISALLOW - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+            allow_rejected: If True return rejected events.
 
         Returns:
             Deferred : Dict from event_id to event.
         """
         events = yield self.get_events_as_list(
             event_ids,
-            check_redacted=check_redacted,
+            redact_behaviour=redact_behaviour,
             get_prev_content=get_prev_content,
             allow_rejected=allow_rejected,
         )
@@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore):
     @defer.inlineCallbacks
     def get_events_as_list(
         self,
-        event_ids,
-        check_redacted=True,
-        get_prev_content=False,
-        allow_rejected=False,
+        event_ids: List[str],
+        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
+        get_prev_content: bool = False,
+        allow_rejected: bool = False,
     ):
         """Get events from the database and return in a list in the same order
         as given by `event_ids` arg.
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            check_redacted (bool): If True, check if event has been redacted
-                and redact it.
-            get_prev_content (bool): If True and event is a state event,
+            event_ids: The event_ids of the events to fetch
+            redact_behaviour: Determine what to do with a redacted event. Possible values:
+                * AS_IS - Return the full event body with no redacted content
+                * REDACT - Return the event but with a redacted body
+                * DISALLOW - Do not return redacted events
+            get_prev_content: If True and event is a state event,
                 include the previous states content in the unsigned field.
-            allow_rejected (bool): If True return rejected events.
+            allow_rejected: If True, return rejected events.
 
         Returns:
             Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore):
                     # Update the cache to save doing the checks again.
                     entry.event.internal_metadata.recheck_redaction = False
 
-            if check_redacted and entry.redacted_event:
-                event = entry.redacted_event
-            else:
-                event = entry.event
+            event = entry.event
+
+            if entry.redacted_event:
+                if redact_behaviour == EventRedactBehaviour.BLOCK:
+                    # Skip this event
+                    continue
+                elif redact_behaviour == EventRedactBehaviour.REDACT:
+                    event = entry.redacted_event
 
             events.append(event)
 
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.md b/synapse/storage/data_stores/main/schema/full_schemas/README.md
new file mode 100644
index 0000000000..bbd3f18604
--- /dev/null
+++ b/synapse/storage/data_stores/main/schema/full_schemas/README.md
@@ -0,0 +1,13 @@
+# Building full schema dumps
+
+These schemas need to be made from a database that has had all background updates run.
+
+To do so, use `scripts-dev/make_full_schema.sh`. This will produce
+`full.sql.postgres ` and `full.sql.sqlite` files.
+
+Ensure postgres is installed and your user has the ability to run bash commands
+such as `createdb`.
+
+```
+./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
+```
diff --git a/synapse/storage/data_stores/main/schema/full_schemas/README.txt b/synapse/storage/data_stores/main/schema/full_schemas/README.txt
deleted file mode 100644
index d3f6401344..0000000000
--- a/synapse/storage/data_stores/main/schema/full_schemas/README.txt
+++ /dev/null
@@ -1,19 +0,0 @@
-Building full schema dumps
-==========================
-
-These schemas need to be made from a database that has had all background updates run.
-
-Postgres
---------
-
-$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
-
-SQLite
-------
-
-$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
-
-After
------
-
-Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".
\ No newline at end of file
diff --git a/synapse/storage/data_stores/main/search.py b/synapse/storage/data_stores/main/search.py
index 4eec2fae5e..47ebb8a214 100644
--- a/synapse/storage/data_stores/main/search.py
+++ b/synapse/storage/data_stores/main/search.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 
@@ -384,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore):
         """
         clauses = []
 
-        search_query = search_query = _parse_query(self.database_engine, search_term)
+        search_query = _parse_query(self.database_engine, search_term)
 
         args = []
 
@@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
-        events = yield self.get_events_as_list([r["event_id"] for r in results])
+        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+        # search results (which is a data leak)
+        events = yield self.get_events_as_list(
+            [r["event_id"] for r in results],
+            redact_behaviour=EventRedactBehaviour.BLOCK,
+        )
 
         event_map = {ev.event_id: ev for ev in events}
 
@@ -495,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore):
         """
         clauses = []
 
-        search_query = search_query = _parse_query(self.database_engine, search_term)
+        search_query = _parse_query(self.database_engine, search_term)
 
         args = []
 
@@ -600,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore):
 
         results = list(filter(lambda row: row["room_id"] in room_ids, results))
 
-        events = yield self.get_events_as_list([r["event_id"] for r in results])
+        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
+        # search results (which is a data leak)
+        events = yield self.get_events_as_list(
+            [r["event_id"] for r in results],
+            redact_behaviour=EventRedactBehaviour.BLOCK,
+        )
 
         event_map = {ev.event_id: ev for ev in events}
 
diff --git a/synapse/storage/data_stores/main/state.py b/synapse/storage/data_stores/main/state.py
index 9ef7b48c74..dcc6b43cdf 100644
--- a/synapse/storage/data_stores/main/state.py
+++ b/synapse/storage/data_stores/main/state.py
@@ -278,7 +278,7 @@ class StateGroupWorkerStore(
 
     @defer.inlineCallbacks
     def get_room_predecessor(self, room_id):
-        """Get the predecessor room of an upgraded room if one exists.
+        """Get the predecessor of an upgraded room if it exists.
         Otherwise return None.
 
         Args:
@@ -291,14 +291,22 @@ class StateGroupWorkerStore(
                     * room_id (str): The room ID of the predecessor room
                     * event_id (str): The ID of the tombstone event in the predecessor room
 
+                None if a predecessor key is not found, or is not a dictionary.
+
         Raises:
-            NotFoundError if the room is unknown
+            NotFoundError if the given room is unknown
         """
         # Retrieve the room's create event
         create_event = yield self.get_create_event_for_room(room_id)
 
-        # Return predecessor if present
-        return create_event.content.get("predecessor", None)
+        # Retrieve the predecessor key of the create event
+        predecessor = create_event.content.get("predecessor", None)
+
+        # Ensure the key is a dictionary
+        if not isinstance(predecessor, dict):
+            return None
+
+        return predecessor
 
     @defer.inlineCallbacks
     def get_create_event_for_room(self, room_id):
@@ -318,7 +326,7 @@ class StateGroupWorkerStore(
 
         # If we can't find the create event, assume we've hit a dead end
         if not create_id:
-            raise NotFoundError("Unknown room %s" % (room_id))
+            raise NotFoundError("Unknown room %s" % (room_id,))
 
         # Retrieve the room's create event and return
         create_event = yield self.get_event(create_id)