diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 45ca6620a8..691080ce74 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import List, Tuple
+from typing import List, Optional, Tuple
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams import ToDeviceStream
@@ -115,7 +115,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
async def get_new_messages_for_device(
self,
user_id: str,
- device_id: str,
+ device_id: Optional[str],
last_stream_id: int,
current_stream_id: int,
limit: int = 100,
@@ -163,7 +163,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
@trace
async def delete_messages_for_device(
- self, user_id: str, device_id: str, up_to_stream_id: int
+ self, user_id: str, device_id: Optional[str], up_to_stream_id: int
) -> int:
"""
Args:
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index d788dc0fc6..757da3d55d 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Dict, List
+from typing import Dict, List, Optional
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
@@ -109,7 +109,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
return users
@cached(num_args=1)
- async def user_last_seen_monthly_active(self, user_id: str) -> int:
+ async def user_last_seen_monthly_active(self, user_id: str) -> Optional[int]:
"""
Checks if a given user is part of the monthly active user group
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 0309661841..b7072f1f5e 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -22,7 +22,6 @@ from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, LoggingTransaction
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
@@ -312,49 +311,23 @@ class TransactionStore(TransactionWorkerStore):
stream_ordering: the stream_ordering of the event
"""
- return await self.db_pool.runInteraction(
- "store_destination_rooms_entries",
- self._store_destination_rooms_entries_txn,
- destinations,
- room_id,
- stream_ordering,
+ await self.db_pool.simple_upsert_many(
+ table="destinations",
+ key_names=("destination",),
+ key_values=[(d,) for d in destinations],
+ value_names=[],
+ value_values=[],
+ desc="store_destination_rooms_entries_dests",
)
- def _store_destination_rooms_entries_txn(
- self,
- txn: LoggingTransaction,
- destinations: Iterable[str],
- room_id: str,
- stream_ordering: int,
- ) -> None:
-
- # ensure we have a `destinations` row for this destination, as there is
- # a foreign key constraint.
- if isinstance(self.database_engine, PostgresEngine):
- q = """
- INSERT INTO destinations (destination)
- VALUES (?)
- ON CONFLICT DO NOTHING;
- """
- elif isinstance(self.database_engine, Sqlite3Engine):
- q = """
- INSERT OR IGNORE INTO destinations (destination)
- VALUES (?);
- """
- else:
- raise RuntimeError("Unknown database engine")
-
- txn.execute_batch(q, ((destination,) for destination in destinations))
-
rows = [(destination, room_id) for destination in destinations]
-
- self.db_pool.simple_upsert_many_txn(
- txn,
+ await self.db_pool.simple_upsert_many(
table="destination_rooms",
key_names=("destination", "room_id"),
key_values=rows,
value_names=["stream_ordering"],
value_values=[(stream_ordering,)] * len(rows),
+ desc="store_destination_rooms_entries_rooms",
)
async def get_destination_last_successful_stream_ordering(
|