summary refs log tree commit diff
path: root/synapse/storage/databases/main/deviceinbox.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/deviceinbox.py')
-rw-r--r--synapse/storage/databases/main/deviceinbox.py52
1 files changed, 36 insertions, 16 deletions
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ae3afdd5d2..7c0f953365 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -1,4 +1,5 @@
 # Copyright 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.replication.tcp.streams import ToDeviceStream
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -34,14 +43,21 @@ logger = logging.getLogger(__name__)
 
 
 class DeviceInboxWorkerStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._instance_name = hs.get_instance_name()
 
         # Map of (user_id, device_id) to the last stream_id that has been
         # deleted up to. This is so that we can no op deletions.
-        self._last_device_delete_cache = ExpiringCache(
+        self._last_device_delete_cache: ExpiringCache[
+            Tuple[str, Optional[str]], int
+        ] = ExpiringCache(
             cache_name="last_device_delete_cache",
             clock=self._clock,
             max_len=10000,
@@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 self._instance_name in hs.config.worker.writers.to_device
             )
 
-            self._device_inbox_id_gen = MultiWriterIdGenerator(
-                db_conn=db_conn,
-                db=database,
-                stream_name="to_device",
-                instance_name=self._instance_name,
-                tables=[("device_inbox", "instance_name", "stream_id")],
-                sequence_name="device_inbox_sequence",
-                writers=hs.config.worker.writers.to_device,
+            self._device_inbox_id_gen: AbstractStreamIdGenerator = (
+                MultiWriterIdGenerator(
+                    db_conn=db_conn,
+                    db=database,
+                    stream_name="to_device",
+                    instance_name=self._instance_name,
+                    tables=[("device_inbox", "instance_name", "stream_id")],
+                    sequence_name="device_inbox_sequence",
+                    writers=hs.config.worker.writers.to_device,
+                )
             )
         else:
             self._can_write_to_device = True
@@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == ToDeviceStream.NAME:
+            # If replication is happening than postgres must be being used.
+            assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
             self._device_inbox_id_gen.advance(instance_name, token)
             for row in rows:
                 if row.entity.startswith("@"):
@@ -220,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         log_kv({"message": f"deleted {count} messages for device", "count": count})
 
         # Update the cache, ensuring that we only ever increase the value
-        last_deleted_stream_id = self._last_device_delete_cache.get(
+        updated_last_deleted_stream_id = self._last_device_delete_cache.get(
             (user_id, device_id), 0
         )
         self._last_device_delete_cache[(user_id, device_id)] = max(
-            last_deleted_stream_id, up_to_stream_id
+            updated_last_deleted_stream_id, up_to_stream_id
         )
 
         return count
@@ -432,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 )
 
         async with self._device_inbox_id_gen.get_next() as stream_id:
-            now_ms = self.clock.time_msec()
+            now_ms = self._clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
             )
@@ -483,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             )
 
         async with self._device_inbox_id_gen.get_next() as stream_id:
-            now_ms = self.clock.time_msec()
+            now_ms = self._clock.time_msec()
             await self.db_pool.runInteraction(
                 "add_messages_from_remote_to_device_inbox",
                 add_messages_txn,