author    Brendan Abolivier <babolivier@matrix.org>  2020-09-04 11:02:10 +0100
committer Brendan Abolivier <babolivier@matrix.org>  2020-09-04 11:02:10 +0100
commit    cc23d81a74caead82fa97bddd535b29bb9e1df56 (patch)
tree      7013f8ae7aaac8867313ac9e86ac90f93abbbff8 /synapse/storage
parent    Merge branch 'develop' into matrix-org-hotfixes (diff)
parent    Revert "Add experimental support for sharding event persister. (#8170)" (#8242) (diff)
download  synapse-cc23d81a74caead82fa97bddd535b29bb9e1df56.tar.xz
Merge branch 'develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/databases/__init__.py                                                 2
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py                                   109
-rw-r--r--  synapse/storage/databases/main/event_federation.py                                    2
-rw-r--r--  synapse/storage/databases/main/events.py                                              4
-rw-r--r--  synapse/storage/databases/main/events_worker.py                                      66
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql           16
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres  26
-rw-r--r--  synapse/storage/util/id_generators.py                                               10
8 files changed, 91 insertions, 144 deletions
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index c73d54fb67..0ac854aee2 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -68,7 +68,7 @@ class Databases(object):
 
                     # If we're on a process that can persist events also
                     # instantiate a `PersistEventsStore`
-                    if hs.get_instance_name() in hs.config.worker.writers.events:
+                    if hs.config.worker.writers.events == hs.get_instance_name():
                         persist_events = PersistEventsStore(hs, database, main)
 
                 if "state" in database_config.databases:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index cc0b15ae07..09af033233 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import make_in_list_sql_clause
+from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -45,8 +46,9 @@ class DeviceKeyLookupResult:
     # key) and "signatures" (a signature of the structure by the ed25519 key)
     key_json = attr.ib(type=Optional[str])
 
-    # cross-signing sigs
-    signatures = attr.ib(type=Optional[Dict], default=None)
+    # cross-signing sigs on this device.
+    # dict from (signing user_id)->(signing device_id)->sig
+    signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)
 
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -133,7 +135,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         include_all_devices: bool = False,
         include_deleted_devices: bool = False,
     ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
-        """Fetch a list of device keys, together with their cross-signatures.
+        """Fetch a list of device keys
+
+        Any cross-signatures made on the keys by the owner of the device are also
+        included.
 
         Args:
             query_list: List of pairs of user_ids and device_ids. Device id can be None
@@ -154,22 +159,51 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
 
         result = await self.db_pool.runInteraction(
             "get_e2e_device_keys",
-            self._get_e2e_device_keys_and_signatures_txn,
+            self._get_e2e_device_keys_txn,
             query_list,
             include_all_devices,
             include_deleted_devices,
         )
 
+        # get the (user_id, device_id) tuples to look up cross-signatures for
+        signature_query = (
+            (user_id, device_id)
+            for user_id, dev in result.items()
+            for device_id, d in dev.items()
+            if d is not None
+        )
+
+        for batch in batch_iter(signature_query, 50):
+            cross_sigs_result = await self.db_pool.runInteraction(
+                "get_e2e_cross_signing_signatures",
+                self._get_e2e_cross_signing_signatures_for_devices_txn,
+                batch,
+            )
+
+            # add each cross-signing signature to the correct device in the result dict.
+            for (user_id, key_id, device_id, signature) in cross_sigs_result:
+                target_device_result = result[user_id][device_id]
+                target_device_signatures = target_device_result.signatures
+
+                signing_user_signatures = target_device_signatures.setdefault(
+                    user_id, {}
+                )
+                signing_user_signatures[key_id] = signature
+
         log_kv(result)
         return result
 
-    def _get_e2e_device_keys_and_signatures_txn(
+    def _get_e2e_device_keys_txn(
         self, txn, query_list, include_all_devices=False, include_deleted_devices=False
     ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
+        """Get information on devices from the database
+
+        The results include the device's keys and self-signatures, but *not* any
+        cross-signing signatures which have been added subsequently (for which, see
+        get_e2e_device_keys_and_signatures)
+        """
         query_clauses = []
         query_params = []
-        signature_query_clauses = []
-        signature_query_params = []
 
         if include_all_devices is False:
             include_deleted_devices = False
@@ -180,20 +214,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         for (user_id, device_id) in query_list:
             query_clause = "user_id = ?"
             query_params.append(user_id)
-            signature_query_clause = "target_user_id = ?"
-            signature_query_params.append(user_id)
 
             if device_id is not None:
                 query_clause += " AND device_id = ?"
                 query_params.append(device_id)
-                signature_query_clause += " AND target_device_id = ?"
-                signature_query_params.append(device_id)
-
-            signature_query_clause += " AND user_id = ?"
-            signature_query_params.append(user_id)
 
             query_clauses.append(query_clause)
-            signature_query_clauses.append(signature_query_clause)
 
         sql = (
             "SELECT user_id, device_id, "
@@ -221,41 +247,36 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             for user_id, device_id in deleted_devices:
                 result.setdefault(user_id, {})[device_id] = None
 
-        # get signatures on the device
-        signature_sql = ("SELECT *  FROM e2e_cross_signing_signatures WHERE %s") % (
-            " OR ".join("(" + q + ")" for q in signature_query_clauses)
-        )
+        return result
 
-        txn.execute(signature_sql, signature_query_params)
-        rows = self.db_pool.cursor_to_dict(txn)
-
-        # add each cross-signing signature to the correct device in the result dict.
-        for row in rows:
-            signing_user_id = row["user_id"]
-            signing_key_id = row["key_id"]
-            target_user_id = row["target_user_id"]
-            target_device_id = row["target_device_id"]
-            signature = row["signature"]
-
-            target_user_result = result.get(target_user_id)
-            if not target_user_result:
-                continue
+    def _get_e2e_cross_signing_signatures_for_devices_txn(
+        self, txn: Cursor, device_query: Iterable[Tuple[str, str]]
+    ) -> List[Tuple[str, str, str, str]]:
+        """Get cross-signing signatures for a given list of devices
 
-            target_device_result = target_user_result.get(target_device_id)
-            if not target_device_result:
-                # note that target_device_result will be None for deleted devices.
-                continue
+        Returns signatures made by the owners of the devices.
 
-            target_device_signatures = target_device_result.signatures
-            if target_device_signatures is None:
-                target_device_signatures = target_device_result.signatures = {}
+        Returns: a list of results; each entry in the list is a tuple of
+            (user_id, key_id, target_device_id, signature).
+        """
+        signature_query_clauses = []
+        signature_query_params = []
 
-            signing_user_signatures = target_device_signatures.setdefault(
-                signing_user_id, {}
+        for (user_id, device_id) in device_query:
+            signature_query_clauses.append(
+                "target_user_id = ? AND target_device_id = ? AND user_id = ?"
             )
-            signing_user_signatures[signing_key_id] = signature
+            signature_query_params.extend([user_id, device_id, user_id])
 
-        return result
+        signature_sql = """
+            SELECT user_id, key_id, target_device_id, signature
+            FROM e2e_cross_signing_signatures WHERE %s
+            """ % (
+            " OR ".join("(" + q + ")" for q in signature_query_clauses)
+        )
+
+        txn.execute(signature_sql, signature_query_params)
+        return txn.fetchall()
 
     async def get_e2e_one_time_keys(
         self, user_id: str, device_id: str, key_ids: List[str]
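
The refactor above splits the cross-signing signature lookup out of the device-key transaction: the device-key results are flattened into (user_id, device_id) pairs and fetched 50 at a time via `batch_iter`, so each `get_e2e_cross_signing_signatures` transaction ORs together a bounded number of clause groups. A self-contained sketch of that batching pattern; the stand-in `batch_iter` mirrors the behaviour of `synapse.util.iterutils.batch_iter`, and the device pairs are made-up example data:

```python
from itertools import islice
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
    # Yield successive tuples of at most `size` items; stop on the empty tuple.
    sourceiter = iter(iterable)
    return iter(lambda: tuple(islice(sourceiter, size)), ())

device_pairs = [("@alice:hs", "DEV1"), ("@alice:hs", "DEV2"), ("@bob:hs", "DEV3")]
for batch in batch_iter(device_pairs, 2):
    # each batch would become one signature-lookup transaction
    print(batch)
```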
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 4c3c162acf..0b69aa6a94 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         """
 
         if stream_ordering <= self.stream_ordering_month_ago:
-            raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
+            raise StoreError(400, "stream_ordering too old")
 
         sql = """
                 SELECT event_id FROM stream_ordering_to_exterm
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b94fe7ac17..b3d27a2ee7 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -97,7 +97,6 @@ class PersistEventsStore:
         self.store = main_data_store
         self.database_engine = db.engine
         self._clock = hs.get_clock()
-        self._instance_name = hs.get_instance_name()
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
@@ -109,7 +108,7 @@ class PersistEventsStore:
 
         # This should only exist on instances that are configured to write
         assert (
-            hs.get_instance_name() in hs.config.worker.writers.events
+            hs.config.worker.writers.events == hs.get_instance_name()
         ), "Can only instantiate EventsStore on master"
 
     async def _persist_events_and_state_updates(
@@ -801,7 +800,6 @@ class PersistEventsStore:
             table="events",
             values=[
                 {
-                    "instance_name": self._instance_name,
                     "stream_ordering": event.internal_metadata.stream_ordering,
                     "topological_ordering": event.depth,
                     "depth": event.depth,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 17f5997b89..a7a73cc3d8 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -42,8 +42,7 @@ from synapse.replication.tcp.streams import BackfillStream
 from synapse.replication.tcp.streams.events import EventsStream
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool
-from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import Collection, get_domain_from_id
 from synapse.util.caches.descriptors import Cache, cached
 from synapse.util.iterutils import batch_iter
@@ -79,54 +78,27 @@ class EventsWorkerStore(SQLBaseStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super(EventsWorkerStore, self).__init__(database, db_conn, hs)
 
-        if isinstance(database.engine, PostgresEngine):
-            # If we're using Postgres then we can use `MultiWriterIdGenerator`
-            # regardless of whether this process writes to the streams or not.
-            self._stream_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
-                sequence_name="events_stream_seq",
+        if hs.config.worker.writers.events == hs.get_instance_name():
+            # We are the process in charge of generating stream ids for events,
+            # so instantiate ID generators based on the database
+            self._stream_id_gen = StreamIdGenerator(
+                db_conn, "events", "stream_ordering",
             )
-            self._backfill_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                instance_name=hs.get_instance_name(),
-                table="events",
-                instance_column="instance_name",
-                id_column="stream_ordering",
-                sequence_name="events_backfill_stream_seq",
-                positive=False,
+            self._backfill_id_gen = StreamIdGenerator(
+                db_conn,
+                "events",
+                "stream_ordering",
+                step=-1,
+                extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
             )
         else:
-            # We shouldn't be running in worker mode with SQLite, but it's useful
-            # to support it for unit tests.
-            #
-            # If this process is the writer then we need to use
-            # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
-            # updated over replication. (Multiple writers are not supported for
-            # SQLite).
-            if hs.get_instance_name() in hs.config.worker.writers.events:
-                self._stream_id_gen = StreamIdGenerator(
-                    db_conn, "events", "stream_ordering",
-                )
-                self._backfill_id_gen = StreamIdGenerator(
-                    db_conn,
-                    "events",
-                    "stream_ordering",
-                    step=-1,
-                    extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
-                )
-            else:
-                self._stream_id_gen = SlavedIdTracker(
-                    db_conn, "events", "stream_ordering"
-                )
-                self._backfill_id_gen = SlavedIdTracker(
-                    db_conn, "events", "stream_ordering", step=-1
-                )
+            # Another process is in charge of persisting events and generating
+            # stream IDs: rely on the replication streams to let us know which
+            # IDs we can process.
+            self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
+            self._backfill_id_gen = SlavedIdTracker(
+                db_conn, "events", "stream_ordering", step=-1
+            )
 
         self._get_event_cache = Cache(
             "*getEvent*",
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
deleted file mode 100644
index 98ff76d709..0000000000
--- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
+++ /dev/null
@@ -1,16 +0,0 @@
-/* Copyright 2020 The Matrix.org Foundation C.I.C.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-ALTER TABLE events ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
deleted file mode 100644
index 97c1e6a0c5..0000000000
--- a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
+++ /dev/null
@@ -1,26 +0,0 @@
-/* Copyright 2020 The Matrix.org Foundation C.I.C.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
-
-SELECT setval('events_stream_seq', (
-    SELECT COALESCE(MAX(stream_ordering), 1) FROM events
-));
-
-CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
-
-SELECT setval('events_backfill_stream_seq', (
-    SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
-));
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 8fd21c2bf8..9f3d23f0a5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -231,12 +231,8 @@ class MultiWriterIdGenerator:
         # gaps should be relatively rare, it's still worth doing the bookkeeping
         # that allows us to skip forwards when there are gapless runs of
         # positions.
-        #
-        # We start at 1 here as a) the first generated stream ID will be 2, and
-        # b) other parts of the code assume that stream IDs are strictly greater
-        # than 0.
         self._persisted_upto_position = (
-            min(self._current_positions.values()) if self._current_positions else 1
+            min(self._current_positions.values()) if self._current_positions else 0
         )
         self._known_persisted_positions = []  # type: List[int]
 
@@ -366,7 +362,9 @@ class MultiWriterIdGenerator:
         equal to it have been successfully persisted.
         """
 
-        return self.get_persisted_upto_position()
+        # Currently we don't support this operation, as it's not obvious how to
+        # condense the stream positions of multiple writers into a single int.
+        raise NotImplementedError()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
         """Returns the position of the given writer.