diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index ae3afdd5d2..7c0f953365 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -1,4 +1,5 @@
# Copyright 2016 OpenMarket Ltd
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,9 +20,17 @@ from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.database import (
+ DatabasePool,
+ LoggingDatabaseConnection,
+ LoggingTransaction,
+)
from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ MultiWriterIdGenerator,
+ StreamIdGenerator,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
@@ -34,14 +43,21 @@ logger = logging.getLogger(__name__)
class DeviceInboxWorkerStore(SQLBaseStore):
- def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
- self._last_device_delete_cache = ExpiringCache(
+ self._last_device_delete_cache: ExpiringCache[
+ Tuple[str, Optional[str]], int
+ ] = ExpiringCache(
cache_name="last_device_delete_cache",
clock=self._clock,
max_len=10000,
@@ -53,14 +69,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
self._instance_name in hs.config.worker.writers.to_device
)
- self._device_inbox_id_gen = MultiWriterIdGenerator(
- db_conn=db_conn,
- db=database,
- stream_name="to_device",
- instance_name=self._instance_name,
- tables=[("device_inbox", "instance_name", "stream_id")],
- sequence_name="device_inbox_sequence",
- writers=hs.config.worker.writers.to_device,
+ self._device_inbox_id_gen: AbstractStreamIdGenerator = (
+ MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="to_device",
+ instance_name=self._instance_name,
+ tables=[("device_inbox", "instance_name", "stream_id")],
+ sequence_name="device_inbox_sequence",
+ writers=hs.config.worker.writers.to_device,
+ )
)
else:
self._can_write_to_device = True
@@ -101,6 +119,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == ToDeviceStream.NAME:
+ # If replication is happening than postgres must be being used.
+ assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
self._device_inbox_id_gen.advance(instance_name, token)
for row in rows:
if row.entity.startswith("@"):
@@ -220,11 +240,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
log_kv({"message": f"deleted {count} messages for device", "count": count})
# Update the cache, ensuring that we only ever increase the value
- last_deleted_stream_id = self._last_device_delete_cache.get(
+ updated_last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
)
self._last_device_delete_cache[(user_id, device_id)] = max(
- last_deleted_stream_id, up_to_stream_id
+ updated_last_deleted_stream_id, up_to_stream_id
)
return count
@@ -432,7 +452,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
async with self._device_inbox_id_gen.get_next() as stream_id:
- now_ms = self.clock.time_msec()
+ now_ms = self._clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
)
@@ -483,7 +503,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
async with self._device_inbox_id_gen.get_next() as stream_id:
- now_ms = self.clock.time_msec()
+ now_ms = self._clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
add_messages_txn,
|