summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/_base.py2
-rw-r--r--synapse/storage/devices.py50
-rw-r--r--synapse/storage/events.py13
-rw-r--r--synapse/storage/events_worker.py6
-rw-r--r--synapse/storage/registration.py60
-rw-r--r--synapse/storage/stream.py2
-rw-r--r--synapse/storage/transactions.py28
7 files changed, 63 insertions, 98 deletions
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 29589853c6..2f940dbae6 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -30,12 +30,12 @@ from prometheus_client import Histogram
 from twisted.internet import defer
 
 from synapse.api.errors import StoreError
+from synapse.logging.context import LoggingContext, PreserveLoggingContext
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.types import get_domain_from_id
 from synapse.util import batch_iter
 from synapse.util.caches.descriptors import Cache
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.stringutils import exception_to_unicode
 
 # import a function which will return a monotonic time, in seconds
diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py
index 3413a46675..d2b113a4e7 100644
--- a/synapse/storage/devices.py
+++ b/synapse/storage/devices.py
@@ -24,6 +24,7 @@ from synapse.api.errors import StoreError
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import Cache, SQLBaseStore, db_to_json
 from synapse.storage.background_updates import BackgroundUpdateStore
+from synapse.util import batch_iter
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 
 logger = logging.getLogger(__name__)
@@ -391,22 +392,47 @@ class DeviceWorkerStore(SQLBaseStore):
 
         return now_stream_id, []
 
-    @defer.inlineCallbacks
-    def get_user_whose_devices_changed(self, from_key):
-        """Get set of users whose devices have changed since `from_key`.
+    def get_users_whose_devices_changed(self, from_key, user_ids):
+        """Get set of users whose devices have changed since `from_key` that
+        are in the given list of user_ids.
+
+        Args:
+            from_key (str): The device lists stream token
+            user_ids (Iterable[str])
+
+        Returns:
+            Deferred[set[str]]: The set of user_ids whose devices have changed
+            since `from_key`
         """
         from_key = int(from_key)
-        changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
-        if changed is not None:
-            defer.returnValue(set(changed))
 
-        sql = """
-            SELECT DISTINCT user_id FROM device_lists_stream WHERE stream_id > ?
-        """
-        rows = yield self._execute(
-            "get_user_whose_devices_changed", None, sql, from_key
+        # Get set of users who *may* have changed. Users not in the returned
+        # list have definitely not changed.
+        to_check = list(
+            self._device_list_stream_cache.get_entities_changed(user_ids, from_key)
+        )
+
+        if not to_check:
+            return defer.succeed(set())
+
+        def _get_users_whose_devices_changed_txn(txn):
+            changes = set()
+
+            sql = """
+                SELECT DISTINCT user_id FROM device_lists_stream
+                WHERE stream_id > ?
+                AND user_id IN (%s)
+            """
+
+            for chunk in batch_iter(to_check, 100):
+                txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk)
+                changes.update(user_id for user_id, in txn)
+
+            return changes
+
+        return self.runInteraction(
+            "get_users_whose_devices_changed", _get_users_whose_devices_changed_txn
         )
-        defer.returnValue(set(row[0] for row in rows))
 
     def get_all_device_list_changes_for_remotes(self, from_key, to_key):
         """Return a list of `(stream_id, user_id, destination)` which is the
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index fefba39ea1..b486ca50eb 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -33,6 +33,8 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
+from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
+from synapse.logging.utils import log_function
 from synapse.metrics import BucketCollector
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.state import StateResolutionStore
@@ -45,8 +47,6 @@ from synapse.util import batch_iter
 from synapse.util.async_helpers import ObservableDeferred
 from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
 from synapse.util.frozenutils import frozendict_json_encoder
-from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
-from synapse.util.logutils import log_function
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -253,7 +253,14 @@ class EventsStore(
         )
 
         # Read the extrems every 60 minutes
-        hs.get_clock().looping_call(self._read_forward_extremities, 60 * 60 * 1000)
+        def read_forward_extremities():
+            # run as a background process to make sure that the database transactions
+            # have a logcontext to report to
+            return run_as_background_process(
+                "read_forward_extremities", self._read_forward_extremities
+            )
+
+        hs.get_clock().looping_call(read_forward_extremities, 60 * 60 * 1000)
 
     @defer.inlineCallbacks
     def _read_forward_extremities(self):
diff --git a/synapse/storage/events_worker.py b/synapse/storage/events_worker.py
index 6d680d405a..09db872511 100644
--- a/synapse/storage/events_worker.py
+++ b/synapse/storage/events_worker.py
@@ -29,14 +29,14 @@ from synapse.api.room_versions import EventFormatVersions
 from synapse.events import FrozenEvent, event_type_from_format_version  # noqa: F401
 from synapse.events.snapshot import EventContext  # noqa: F401
 from synapse.events.utils import prune_event
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import get_domain_from_id
-from synapse.util.logcontext import (
+from synapse.logging.context import (
     LoggingContext,
     PreserveLoggingContext,
     make_deferred_yieldable,
     run_in_background,
 )
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import get_domain_from_id
 from synapse.util.metrics import Measure
 
 from ._base import SQLBaseStore
diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py
index 983ce13291..aea5b3276b 100644
--- a/synapse/storage/registration.py
+++ b/synapse/storage/registration.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, ThreepidValidationError
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage import background_updates
 from synapse.storage._base import SQLBaseStore
 from synapse.types import UserID
@@ -432,19 +433,6 @@ class RegistrationWorkerStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def get_3pid_guest_access_token(self, medium, address):
-        ret = yield self._simple_select_one(
-            "threepid_guest_access_tokens",
-            {"medium": medium, "address": address},
-            ["guest_access_token"],
-            True,
-            "get_3pid_guest_access_token",
-        )
-        if ret:
-            defer.returnValue(ret["guest_access_token"])
-        defer.returnValue(None)
-
-    @defer.inlineCallbacks
     def get_user_id_by_threepid(self, medium, address, require_verified=False):
         """Returns user id from threepid
 
@@ -619,9 +607,15 @@ class RegistrationStore(
         )
 
         # Create a background job for culling expired 3PID validity tokens
-        hs.get_clock().looping_call(
-            self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
-        )
+        def start_cull():
+            # run as a background process to make sure that the database transactions
+            # have a logcontext to report to
+            return run_as_background_process(
+                "cull_expired_threepid_validation_tokens",
+                self.cull_expired_threepid_validation_tokens,
+            )
+
+        hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
 
     @defer.inlineCallbacks
     def _backgroud_update_set_deactivated_flag(self, progress, batch_size):
@@ -972,40 +966,6 @@ class RegistrationStore(
 
         defer.returnValue(res if res else False)
 
-    @defer.inlineCallbacks
-    def save_or_get_3pid_guest_access_token(
-        self, medium, address, access_token, inviter_user_id
-    ):
-        """
-        Gets the 3pid's guest access token if exists, else saves access_token.
-
-        Args:
-            medium (str): Medium of the 3pid. Must be "email".
-            address (str): 3pid address.
-            access_token (str): The access token to persist if none is
-                already persisted.
-            inviter_user_id (str): User ID of the inviter.
-
-        Returns:
-            deferred str: Whichever access token is persisted at the end
-            of this function call.
-        """
-
-        def insert(txn):
-            txn.execute(
-                "INSERT INTO threepid_guest_access_tokens "
-                "(medium, address, guest_access_token, first_inviter) "
-                "VALUES (?, ?, ?, ?)",
-                (medium, address, access_token, inviter_user_id),
-            )
-
-        try:
-            yield self.runInteraction("save_3pid_guest_access_token", insert)
-            defer.returnValue(access_token)
-        except self.database_engine.module.IntegrityError:
-            ret = yield self.get_3pid_guest_access_token(medium, address)
-            defer.returnValue(ret)
-
     def add_user_pending_deactivation(self, user_id):
         """
         Adds a user to the table of users who need to be parted from all the rooms they're
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d9482a3843..386a9dbe14 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -41,12 +41,12 @@ from six.moves import range
 
 from twisted.internet import defer
 
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
-from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 
 logger = logging.getLogger(__name__)
 
diff --git a/synapse/storage/transactions.py b/synapse/storage/transactions.py
index b1188f6bcb..fd18619178 100644
--- a/synapse/storage/transactions.py
+++ b/synapse/storage/transactions.py
@@ -133,34 +133,6 @@ class TransactionStore(SQLBaseStore):
             desc="set_received_txn_response",
         )
 
-    def prep_send_transaction(self, transaction_id, destination, origin_server_ts):
-        """Persists an outgoing transaction and calculates the values for the
-        previous transaction id list.
-
-        This should be called before sending the transaction so that it has the
-        correct value for the `prev_ids` key.
-
-        Args:
-            transaction_id (str)
-            destination (str)
-            origin_server_ts (int)
-
-        Returns:
-            list: A list of previous transaction ids.
-        """
-        return defer.succeed([])
-
-    def delivered_txn(self, transaction_id, destination, code, response_dict):
-        """Persists the response for an outgoing transaction.
-
-        Args:
-            transaction_id (str)
-            destination (str)
-            code (int)
-            response_json (str)
-        """
-        pass
-
     @defer.inlineCallbacks
     def get_destination_retry_timings(self, destination):
         """Gets the current retry timings (if any) for a given destination.