Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/__init__.py                                        |  2
-rw-r--r--  synapse/storage/background_updates.py                              |  4
-rw-r--r--  synapse/storage/database.py                                        | 21
-rw-r--r--  synapse/storage/databases/__init__.py                              |  2
-rw-r--r--  synapse/storage/databases/main/devices.py                          | 12
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py                  | 39
-rw-r--r--  synapse/storage/databases/main/roommember.py                       |  2
-rw-r--r--  synapse/storage/databases/main/schema/delta/58/15unread_count.sql  |  6
-rw-r--r--  synapse/storage/keys.py                                            |  2
-rw-r--r--  synapse/storage/persist_events.py                                  |  4
-rw-r--r--  synapse/storage/prepare_database.py                                |  2
-rw-r--r--  synapse/storage/purge_events.py                                    |  2
-rw-r--r--  synapse/storage/relations.py                                       |  6
-rw-r--r--  synapse/storage/state.py                                           |  4
-rw-r--r--  synapse/storage/util/id_generators.py                              |  4
15 files changed, 46 insertions, 66 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 5ef3853559..8e5d78f6f7 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -37,7 +37,7 @@ from synapse.storage.state import StateGroupStorage
 __all__ = ["DataStores", "DataStore"]
 
 
-class Storage(object):
+class Storage:
     """The high level interfaces for talking to various storage layers.
     """
 
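Most of this commit is a mechanical sweep dropping the redundant `(object)` base: under Python 3 every class is already new-style, so the two spellings define identical classes. A minimal sketch (not from the commit) showing the equivalence:

    # Under Python 3, explicit and implicit object bases are equivalent:
    # both classes are new-style and share the same MRO root.
    class WithBase(object):
        pass

    class WithoutBase:
        pass

    assert WithBase.__mro__[-1] is object
    assert WithoutBase.__mro__[-1] is object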
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 67a89cd51a..810721ebe9 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -24,7 +24,7 @@ from . import engines
 logger = logging.getLogger(__name__)
 
 
-class BackgroundUpdatePerformance(object):
+class BackgroundUpdatePerformance:
     """Tracks the how long a background update is taking to update its items"""
 
     def __init__(self, name):
@@ -71,7 +71,7 @@ class BackgroundUpdatePerformance(object):
             return float(self.total_item_count) / float(self.total_duration_ms)
 
 
-class BackgroundUpdater(object):
+class BackgroundUpdater:
     """ Background updates are updates to the database that run in the
     background. Each update processes a batch of data at once. We attempt to
     limit the impact of each update by monitoring how long each batch takes to
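The docstring above (truncated by the diff context) describes the batching strategy: each update runs over a batch of rows, and the batch size is tuned from how long the previous batch took. A minimal sketch of that feedback loop, under assumed names (`process_batch` and `target_ms` are illustrative, not Synapse's API):

    import time

    def run_update(process_batch, batch_size=100, target_ms=100.0):
        """Run process_batch repeatedly, resizing batches to hit target_ms."""
        while True:
            start = time.monotonic()
            n_done = process_batch(batch_size)  # returns number of items handled
            if n_done == 0:
                return  # update complete
            elapsed_ms = (time.monotonic() - start) * 1000
            if elapsed_ms > 0:
                # Scale the next batch so it takes roughly target_ms.
                batch_size = max(1, int(batch_size * target_ms / elapsed_ms))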
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 78ca6d8346..ed8a9bffb1 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -248,7 +248,7 @@ class LoggingTransaction:
         self.txn.close()
 
 
-class PerformanceCounters(object):
+class PerformanceCounters:
     def __init__(self):
         self.current_counters = {}
         self.previous_counters = {}
@@ -286,7 +286,7 @@ class PerformanceCounters(object):
 R = TypeVar("R")
 
 
-class DatabasePool(object):
+class DatabasePool:
     """Wraps a single physical database and connection pool.
 
     A single database may be used by multiple data stores.
@@ -1104,7 +1104,7 @@ class DatabasePool(object):
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcol: Iterable[str],
+        retcol: str,
         allow_none: Literal[False] = False,
         desc: str = "simple_select_one_onecol",
     ) -> Any:
@@ -1115,7 +1115,7 @@ class DatabasePool(object):
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcol: Iterable[str],
+        retcol: str,
         allow_none: Literal[True] = True,
         desc: str = "simple_select_one_onecol",
     ) -> Optional[Any]:
@@ -1125,7 +1125,7 @@ class DatabasePool(object):
         self,
         table: str,
         keyvalues: Dict[str, Any],
-        retcol: Iterable[str],
+        retcol: str,
         allow_none: bool = False,
         desc: str = "simple_select_one_onecol",
     ) -> Optional[Any]:
@@ -1156,7 +1156,7 @@ class DatabasePool(object):
         txn: LoggingTransaction,
         table: str,
         keyvalues: Dict[str, Any],
-        retcol: Iterable[str],
+        retcol: str,
         allow_none: Literal[False] = False,
     ) -> Any:
         ...
@@ -1168,7 +1168,7 @@ class DatabasePool(object):
         txn: LoggingTransaction,
         table: str,
         keyvalues: Dict[str, Any],
-        retcol: Iterable[str],
+        retcol: str,
         allow_none: Literal[True] = True,
     ) -> Optional[Any]:
         ...
@@ -1179,7 +1179,7 @@ class DatabasePool(object):
         txn: LoggingTransaction,
         table: str,
         keyvalues: Dict[str, Any],
-        retcol: Iterable[str],
+        retcol: str,
         allow_none: bool = False,
     ) -> Optional[Any]:
         ret = cls.simple_select_onecol_txn(
@@ -1196,10 +1196,7 @@ class DatabasePool(object):
 
     @staticmethod
     def simple_select_onecol_txn(
-        txn: LoggingTransaction,
-        table: str,
-        keyvalues: Dict[str, Any],
-        retcol: Iterable[str],
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
     ) -> List[Any]:
         sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
 
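The `retcol` annotation on `simple_select_one_onecol` and its `_txn` variant is corrected from `Iterable[str]` to `str`: the method selects exactly one column, so callers pass a single column name. The surrounding `@overload` stubs use `Literal` so the return type narrows on `allow_none`. A self-contained sketch of the same overload pattern (the `fetch_onecol` name is illustrative, not Synapse's API):

    from typing import Any, Literal, Optional, overload

    @overload
    def fetch_onecol(retcol: str, allow_none: Literal[False] = False) -> Any:
        ...

    @overload
    def fetch_onecol(retcol: str, allow_none: Literal[True] = True) -> Optional[Any]:
        ...

    def fetch_onecol(retcol: str, allow_none: bool = False) -> Optional[Any]:
        # Stand-in for "SELECT <retcol> FROM <table> ..." returning one value.
        value = None
        if value is None and not allow_none:
            raise RuntimeError("no row matched")
        return value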
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 0ac854aee2..7f08bd8285 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -24,7 +24,7 @@ from synapse.storage.prepare_database import prepare_database
 logger = logging.getLogger(__name__)
 
 
-class Databases(object):
+class Databases:
     """The various databases.
 
     These are low level interfaces to physical databases.
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index f8fe948122..add4e3ea0e 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -291,15 +291,9 @@ class DeviceWorkerStore(SQLBaseStore):
                 prev_id = stream_id
 
                 if device is not None:
-                    key_json = device.key_json
-                    if key_json:
-                        result["keys"] = db_to_json(key_json)
-
-                        if device.signatures:
-                            for sig_user_id, sigs in device.signatures.items():
-                                result["keys"].setdefault("signatures", {}).setdefault(
-                                    sig_user_id, {}
-                                ).update(sigs)
+                    keys = device.keys
+                    if keys:
+                        result["keys"] = keys
 
                     device_display_name = device.display_name
                     if device_display_name:
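With `DeviceKeyLookupResult.keys` now carrying the parsed key dict, and `get_e2e_device_keys_and_signatures` merging cross-signatures into it before it reaches callers, the per-call-site `db_to_json` plus signature-merge loop collapses to a simple copy. A sketch of the resulting shape, with placeholder values (field names follow the Matrix device-keys format):

    # Placeholder data; real values come from e2e_device_keys_json.
    keys = {
        "algorithms": ["m.olm.v1.curve25519-aes-sha2"],
        "keys": {"curve25519:DEVICEID": "...", "ed25519:DEVICEID": "..."},
        "signatures": {"@alice:example.org": {"ed25519:DEVICEID": "..."}},
    }
    result = {"device_id": "DEVICEID"}
    if keys:  # None/empty means no uploaded keys for this device
        result["keys"] = keys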
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 09af033233..fba3098ea2 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -43,12 +43,8 @@ class DeviceKeyLookupResult:
 
     # the key data from e2e_device_keys_json. Typically includes fields like
     # "algorithm", "keys" (including the curve25519 identity key and the ed25519 signing
-    # key) and "signatures" (a signature of the structure by the ed25519 key)
-    key_json = attr.ib(type=Optional[str])
-
-    # cross-signing sigs on this device.
-    # dict from (signing user_id)->(signing device_id)->sig
-    signatures = attr.ib(type=Optional[Dict[str, Dict[str, str]]], factory=dict)
+    # key) and "signatures" (a map from (user id) to (key id/device_id) to signature).
+    keys = attr.ib(type=Optional[JsonDict])
 
 
 class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -70,15 +66,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             for device_id, device in user_devices.items():
                 result = {"device_id": device_id}
 
-                key_json = device.key_json
-                if key_json:
-                    result["keys"] = db_to_json(key_json)
-
-                    if device.signatures:
-                        for sig_user_id, sigs in device.signatures.items():
-                            result["keys"].setdefault("signatures", {}).setdefault(
-                                sig_user_id, {}
-                            ).update(sigs)
+                keys = device.keys
+                if keys:
+                    result["keys"] = keys
 
                 device_display_name = device.display_name
                 if device_display_name:
@@ -114,16 +104,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         for user_id, device_keys in results.items():
             rv[user_id] = {}
             for device_id, device_info in device_keys.items():
-                r = db_to_json(device_info.key_json)
+                r = device_info.keys
                 r["unsigned"] = {}
                 display_name = device_info.display_name
                 if display_name is not None:
                     r["unsigned"]["device_display_name"] = display_name
-                if device_info.signatures:
-                    for sig_user_id, sigs in device_info.signatures.items():
-                        r.setdefault("signatures", {}).setdefault(
-                            sig_user_id, {}
-                        ).update(sigs)
                 rv[user_id][device_id] = r
 
         return rv
@@ -140,6 +125,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
         Any cross-signatures made on the keys by the owner of the device are also
         included.
 
+        The cross-signatures are added to the `signatures` field within the `keys`
+        object in the response.
+
         Args:
             query_list: List of pairs of user_ids and device_ids. Device id can be None
                 to indicate "all devices for this user"
@@ -170,7 +158,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             (user_id, device_id)
             for user_id, dev in result.items()
             for device_id, d in dev.items()
-            if d is not None
+            if d is not None and d.keys is not None
         )
 
         for batch in batch_iter(signature_query, 50):
@@ -183,8 +171,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             # add each cross-signing signature to the correct device in the result dict.
             for (user_id, key_id, device_id, signature) in cross_sigs_result:
                 target_device_result = result[user_id][device_id]
-                target_device_signatures = target_device_result.signatures
-
+                target_device_signatures = target_device_result.keys.setdefault(
+                    "signatures", {}
+                )
                 signing_user_signatures = target_device_signatures.setdefault(
                     user_id, {}
                 )
@@ -240,7 +229,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
             if include_deleted_devices:
                 deleted_devices.remove((user_id, device_id))
             result.setdefault(user_id, {})[device_id] = DeviceKeyLookupResult(
-                display_name, key_json
+                display_name, db_to_json(key_json) if key_json else None
             )
 
         if include_deleted_devices:
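Two related changes here: `db_to_json` now runs once, at the point where `DeviceKeyLookupResult` is built (storing `None` for devices without uploaded keys, which the signature-query filter `d.keys is not None` then skips), and cross-signing signatures are merged directly into `keys["signatures"]`. A minimal sketch of that `setdefault` merge, with placeholder ids:

    # Placeholder ids; in the real code these come from the signatures query.
    device_keys = {"keys": {"ed25519:DEV1": "..."}}  # parsed key_json
    user_id, key_id, signature = "@alice:example.org", "ed25519:SSK", "base64sig"

    # Nested dicts are created on demand, then the signature is slotted in
    # under keys["signatures"][user_id][key_id].
    device_keys.setdefault("signatures", {}).setdefault(user_id, {})[key_id] = signature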
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index c46f5cd524..91a8b43da3 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -999,7 +999,7 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
         await self.db_pool.runInteraction("forget_membership", f)
 
 
-class _JoinedHostsCache(object):
+class _JoinedHostsCache:
     """Cache for joined hosts in a room that is optimised to handle updates
     via state deltas.
     """
diff --git a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
index b451e8663a..317fba8a5d 100644
--- a/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
+++ b/synapse/storage/databases/main/schema/delta/58/15unread_count.sql
@@ -19,8 +19,8 @@
 
 -- Add columns to event_push_actions and event_push_actions_staging to track unread
 -- messages and calculate unread counts.
-ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
-ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
+ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT;
+ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT;
 
 -- Add column to event_push_summary
-ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;
\ No newline at end of file
+ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT;
\ No newline at end of file
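Dropping `NOT NULL DEFAULT 0` makes the new columns nullable, so rows that predate the feature read back as NULL instead of a spurious zero, letting readers distinguish "count never computed" from "zero unread". A hedged read-side sketch of that distinction (the row dict is a stand-in, not Synapse's query code):

    row = {"unread": None}  # stand-in for a fetched event_push_actions row

    unread = row["unread"]
    if unread is None:
        # NULL: the row was written before the column existed, or the count
        # has not been calculated yet; the old DEFAULT 0 could not express this.
        unread = 0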
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 4769b21529..afd10f7bae 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -22,6 +22,6 @@ logger = logging.getLogger(__name__)
 
 
 @attr.s(slots=True, frozen=True)
-class FetchKeyResult(object):
+class FetchKeyResult:
     verify_key = attr.ib()  # VerifyKey: the key itself
     valid_until_ts = attr.ib()  # int: how long we can use this key for
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index f15b95e633..dbaeef91dd 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -69,7 +69,7 @@ stale_forward_extremities_counter = Histogram(
 )
 
 
-class _EventPeristenceQueue(object):
+class _EventPeristenceQueue:
     """Queues up events so that they can be persisted in bulk with only one
     concurrent transaction per room.
     """
@@ -172,7 +172,7 @@ class _EventPeristenceQueue(object):
             pass
 
 
-class EventsPersistenceStorage(object):
+class EventsPersistenceStorage:
     """High level interface for handling persisting newly received events.
 
     Takes care of batching up events by room, and calculating the necessary
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 1c5f305132..964d8d9eb8 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -569,7 +569,7 @@ def _get_or_create_schema_state(txn, database_engine):
 
 
 @attr.s()
-class _DirectoryListing(object):
+class _DirectoryListing:
     """Helper class to store schema file name and the
     absolute path to it.
 
diff --git a/synapse/storage/purge_events.py b/synapse/storage/purge_events.py
index 79d9f06e2e..bfa0a9fd06 100644
--- a/synapse/storage/purge_events.py
+++ b/synapse/storage/purge_events.py
@@ -20,7 +20,7 @@ from typing import Set
 logger = logging.getLogger(__name__)
 
 
-class PurgeEventsStorage(object):
+class PurgeEventsStorage:
     """High level interface for purging rooms and event history.
     """
 
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index d471ec9860..d30e3f11e7 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -23,7 +23,7 @@ logger = logging.getLogger(__name__)
 
 
 @attr.s
-class PaginationChunk(object):
+class PaginationChunk:
     """Returned by relation pagination APIs.
 
     Attributes:
@@ -51,7 +51,7 @@ class PaginationChunk(object):
 
 
 @attr.s(frozen=True, slots=True)
-class RelationPaginationToken(object):
+class RelationPaginationToken:
     """Pagination token for relation pagination API.
 
     As the results are in topological order, we can use the
@@ -82,7 +82,7 @@ class RelationPaginationToken(object):
 
 
 @attr.s(frozen=True, slots=True)
-class AggregationPaginationToken(object):
+class AggregationPaginationToken:
     """Pagination token for relation aggregation pagination API.
 
     As the results are ordered by count and then MAX(stream_ordering) of the
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 96a1b59d64..8f68d968f0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -29,7 +29,7 @@ T = TypeVar("T")
 
 
 @attr.s(slots=True)
-class StateFilter(object):
+class StateFilter:
     """A filter used when querying for state.
 
     Attributes:
@@ -326,7 +326,7 @@ class StateFilter(object):
         return member_filter, non_member_filter
 
 
-class StateGroupStorage(object):
+class StateGroupStorage:
     """High level interface to fetching state for event.
     """
 
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9f3d23f0a5..76bc3afdfa 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -25,7 +25,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 
-class IdGenerator(object):
+class IdGenerator:
     def __init__(self, db_conn, table, column):
         self._lock = threading.Lock()
         self._next_id = _load_current_id(db_conn, table, column)
@@ -59,7 +59,7 @@ def _load_current_id(db_conn, table, column, step=1):
     return (max if step > 0 else min)(current_id, step)
 
 
-class StreamIdGenerator(object):
+class StreamIdGenerator:
     """Used to generate new stream ids when persisting events while keeping
     track of which transactions have been completed.
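`IdGenerator.get_next` (and `StreamIdGenerator`'s more elaborate variant) hands out monotonically increasing ids under a lock, seeded from the current maximum in the table via `_load_current_id`. A stripped-down sketch of that pattern (not Synapse's implementation):

    import threading

    class SimpleIdGenerator:
        def __init__(self, current_id: int):
            self._lock = threading.Lock()
            self._next_id = current_id  # e.g. seeded from MAX(column) in a table

        def get_next(self) -> int:
            with self._lock:
                self._next_id += 1
                return self._next_id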