Diffstat (limited to 'synapse/storage')
-rw-r--r--  synapse/storage/_base.py                         |  2
-rw-r--r--  synapse/storage/database.py                      | 19
-rw-r--r--  synapse/storage/databases/main/transactions.py   | 28
3 files changed, 26 insertions(+), 23 deletions(-)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index d472676acf..6b68d8720c 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -114,7 +114,7 @@ def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
         db_content = db_content.tobytes()
 
     # Decode it to a Unicode string before feeding it to the JSON decoder, since
-    # Python 3.5 does not support deserializing bytes.
+    # it only supports handling strings
     if isinstance(db_content, (bytes, bytearray)):
         db_content = db_content.decode("utf8")
 
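The hunk above only touches a comment, but the surrounding logic is easy to state on its own. A minimal, self-contained sketch of the decode-then-parse pattern db_to_json implements (a simplified illustration, not the exact Synapse helper):

import json
from typing import Any, Union

def db_to_json_sketch(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
    # The database driver can hand JSON columns back as a memoryview;
    # flatten it to bytes first.
    if isinstance(db_content, memoryview):
        db_content = db_content.tobytes()

    # Decode bytes to a Unicode string before parsing, so the JSON
    # decoder is always fed a str.
    if isinstance(db_content, (bytes, bytearray)):
        db_content = db_content.decode("utf8")

    return json.loads(db_content)

For example, db_to_json_sketch(memoryview(b'{"ok": true}')) returns {"ok": True}.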
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 9452368bf0..a761ad603b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -171,10 +171,7 @@ class LoggingDatabaseConnection:
 
 
 # The type of entry which goes on our after_callbacks and exception_callbacks lists.
-#
-# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
-# that mypy sees the type but the runtime python doesn't.
-_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]
 
 
 R = TypeVar("R")
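The quoted "Callable[..., None]" annotations removed throughout this file were only ever a workaround for older Pythons; the list entries themselves are plain (callback, args, kwargs) tuples. A sketch of how such a list is typically drained (an illustration of the pattern, not Synapse's actual runner):

from typing import Any, Callable, Dict, Iterable, List, Tuple

_CallbackListEntry = Tuple[Callable[..., None], Iterable[Any], Dict[str, Any]]

def _drain_callbacks(entries: List[_CallbackListEntry]) -> None:
    # Unpack each queued entry and invoke the callback with its stored
    # positional and keyword arguments.
    for callback, args, kwargs in entries:
        callback(*args, **kwargs)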
@@ -221,7 +218,7 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
+    def call_after(self, callback: Callable[..., None], *args: Any, **kwargs: Any):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
@@ -233,7 +230,7 @@ class LoggingTransaction:
         self.after_callbacks.append((callback, args, kwargs))
 
     def call_on_exception(
-        self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+        self, callback: Callable[..., None], *args: Any, **kwargs: Any
     ):
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
@@ -485,7 +482,7 @@ class DatabasePool:
         desc: str,
         after_callbacks: List[_CallbackListEntry],
         exception_callbacks: List[_CallbackListEntry],
-        func: "Callable[..., R]",
+        func: Callable[..., R],
         *args: Any,
         **kwargs: Any,
     ) -> R:
@@ -618,7 +615,7 @@ class DatabasePool:
     async def runInteraction(
         self,
         desc: str,
-        func: "Callable[..., R]",
+        func: Callable[..., R],
         *args: Any,
         db_autocommit: bool = False,
         **kwargs: Any,
@@ -678,7 +675,7 @@ class DatabasePool:
 
     async def runWithConnection(
         self,
-        func: "Callable[..., R]",
+        func: Callable[..., R],
         *args: Any,
         db_autocommit: bool = False,
         **kwargs: Any,
@@ -718,7 +715,9 @@ class DatabasePool:
             # pool).
             assert not self.engine.in_transaction(conn)
 
-            with LoggingContext("runWithConnection", parent_context) as context:
+            with LoggingContext(
+                str(curr_context), parent_context=parent_context
+            ) as context:
                 sched_duration_sec = monotonic_time() - start_time
                 sql_scheduling_timer.observe(sched_duration_sec)
                 context.add_database_scheduled(sched_duration_sec)
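Besides the annotation cleanups, the final hunk above names the child LoggingContext after the current context (curr_context is defined just above this hunk, outside the visible lines) and passes the parent explicitly by keyword. For orientation, a hedged usage sketch of runInteraction as typed above; the pool variable and query are invented for the example:

async def count_users(db_pool) -> int:
    # runInteraction schedules _count_txn on a database thread, passing it
    # a LoggingTransaction that proxies the underlying cursor.
    def _count_txn(txn) -> int:
        txn.execute("SELECT COUNT(*) FROM users")
        return txn.fetchone()[0]

    return await db_pool.runInteraction("count_users", _count_txn)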
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index b28ca61f80..82335e7a9d 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -14,7 +14,7 @@
 
 import logging
 from collections import namedtuple
-from typing import Dict, List, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
@@ -295,33 +295,37 @@ class TransactionStore(TransactionWorkerStore):
                 },
             )
 
-    async def bulk_store_destination_rooms_entries(
-        self, room_and_destination_to_ordering: Dict[Tuple[str, str], int]
-    ):
+    async def store_destination_rooms_entries(
+        self,
+        destinations: Iterable[str],
+        room_id: str,
+        stream_ordering: int,
+    ) -> None:
         """
-        Updates or creates `destination_rooms` entries for a number of events.
+        Updates or creates `destination_rooms` entries in batch for a single event.
 
         Args:
-            room_and_destination_to_ordering: A mapping of (room, destination) -> stream_id
+            destinations: list of destinations
+            room_id: the room_id of the event
+            stream_ordering: the stream_ordering of the event
         """
 
         await self.db_pool.simple_upsert_many(
             table="destinations",
             key_names=("destination",),
-            key_values={(d,) for _, d in room_and_destination_to_ordering.keys()},
+            key_values=[(d,) for d in destinations],
             value_names=[],
             value_values=[],
             desc="store_destination_rooms_entries_dests",
         )
 
+        rows = [(destination, room_id) for destination in destinations]
         await self.db_pool.simple_upsert_many(
             table="destination_rooms",
-            key_names=("room_id", "destination"),
-            key_values=list(room_and_destination_to_ordering.keys()),
+            key_names=("destination", "room_id"),
+            key_values=rows,
             value_names=["stream_ordering"],
-            value_values=[
-                (stream_id,) for stream_id in room_and_destination_to_ordering.values()
-            ],
+            value_values=[(stream_ordering,)] * len(rows),
             desc="store_destination_rooms_entries_rooms",
         )
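The rework replaces the bulk (room, destination) -> stream_id mapping with a per-event call: one room_id/stream_ordering pair fanned out over many destinations, which is why value_values can simply repeat (stream_ordering,) once per row. A sketch of the intended call shape (the caller and hostnames are hypothetical):

async def record_event_destinations(store) -> None:
    # `store` is a TransactionStore; a single event in one room is
    # recorded against every destination that should receive it.
    destinations = ["remote1.example.org", "remote2.example.org"]
    await store.store_destination_rooms_entries(
        destinations,
        room_id="!abc123:example.org",
        stream_ordering=12345,
    )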