diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 9cce62ae6c..a3fddea042 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -46,6 +46,7 @@ from .events_forward_extremities import EventForwardExtremitiesStore
from .filtering import FilteringStore
from .group_server import GroupServerStore
from .keys import KeyStore
+from .lock import LockStore
from .media_repository import MediaRepositoryStore
from .metrics import ServerMetricsStore
from .monthly_active_users import MonthlyActiveUsersStore
@@ -119,6 +120,7 @@ class DataStore(
CacheInvalidationWorkerStore,
ServerMetricsStore,
EventForwardExtremitiesStore,
+ LockStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
self.hs = hs
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index c0ea445550..f23f8c6ecf 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -14,18 +14,20 @@
import itertools
import logging
from queue import Empty, PriorityQueue
-from typing import Collection, Dict, Iterable, List, Set, Tuple
+from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError
-from synapse.events import EventBase
+from synapse.api.room_versions import RoomVersion
+from synapse.events import EventBase, make_event_from_dict
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
@@ -1044,6 +1046,107 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
_delete_old_forward_extrem_cache_txn,
)
+ async def insert_received_event_to_staging(
+ self, origin: str, event: EventBase
+ ) -> None:
+ """Insert a newly received event from federation into the staging area."""
+
+ # We use an upsert here to handle the case where we see the same event
+ # from the same server multiple times.
+ await self.db_pool.simple_upsert(
+ table="federation_inbound_events_staging",
+ keyvalues={
+ "origin": origin,
+ "event_id": event.event_id,
+ },
+ values={},
+ insertion_values={
+ "room_id": event.room_id,
+ "received_ts": self._clock.time_msec(),
+ "event_json": json_encoder.encode(event.get_dict()),
+ "internal_metadata": json_encoder.encode(
+ event.internal_metadata.get_dict()
+ ),
+ },
+ desc="insert_received_event_to_staging",
+ )
+
+ async def remove_received_event_from_staging(
+ self,
+ origin: str,
+ event_id: str,
+ ) -> None:
+ """Remove the given event from the staging area"""
+ await self.db_pool.simple_delete(
+ table="federation_inbound_events_staging",
+ keyvalues={
+ "origin": origin,
+ "event_id": event_id,
+ },
+ desc="remove_received_event_from_staging",
+ )
+
+ async def get_next_staged_event_id_for_room(
+ self,
+ room_id: str,
+ ) -> Optional[Tuple[str, str]]:
+ """Get the next event ID in the staging area for the given room."""
+
+ def _get_next_staged_event_id_for_room_txn(txn):
+ sql = """
+ SELECT origin, event_id
+ FROM federation_inbound_events_staging
+ WHERE room_id = ?
+ ORDER BY received_ts ASC
+ LIMIT 1
+ """
+
+ txn.execute(sql, (room_id,))
+
+ return txn.fetchone()
+
+ return await self.db_pool.runInteraction(
+ "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn
+ )
+
+ async def get_next_staged_event_for_room(
+ self,
+ room_id: str,
+ room_version: RoomVersion,
+ ) -> Optional[Tuple[str, EventBase]]:
+ """Get the next event in the staging area for the given room."""
+
+ def _get_next_staged_event_for_room_txn(txn):
+ sql = """
+ SELECT event_json, internal_metadata, origin
+ FROM federation_inbound_events_staging
+ WHERE room_id = ?
+ ORDER BY received_ts ASC
+ LIMIT 1
+ """
+ txn.execute(sql, (room_id,))
+
+ return txn.fetchone()
+
+ row = await self.db_pool.runInteraction(
+ "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn
+ )
+
+ if not row:
+ return None
+
+ event_d = db_to_json(row[0])
+ internal_metadata_d = db_to_json(row[1])
+ origin = row[2]
+
+ event = make_event_from_dict(
+ event_dict=event_d,
+ room_version=room_version,
+ internal_metadata_dict=internal_metadata_d,
+ )
+
+ return origin, event
+
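+    # A rough sketch of how the staging helpers above are expected to be used
+    # by a caller draining the staging area for a room, combined with the
+    # `LockStore` lock. Illustrative only: the lock name and the
+    # `_process_staged_event` helper are assumptions, not part of this change.
+    #
+    #     lock = await self.store.try_acquire_lock("federation_inbound", room_id)
+    #     if not lock:
+    #         return
+    #     async with lock:
+    #         while True:
+    #             next_event = await self.store.get_next_staged_event_for_room(
+    #                 room_id, room_version
+    #             )
+    #             if not next_event:
+    #                 break
+    #             origin, event = next_event
+    #             await self._process_staged_event(origin, event)  # hypothetical
+    #             await self.store.remove_received_event_from_staging(
+    #                 origin, event.event_id
+    #             )
+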
class EventFederationStore(EventFederationWorkerStore):
"""Responsible for storing and serving up the various graphs associated
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index cbe4be1437..da3a7df27b 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -29,6 +29,25 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+_REPLACE_STREAM_ORDERING_SQL_COMMANDS = (
+ # there should be no leftover rows without a stream_ordering2, but just in case...
+ "UPDATE events SET stream_ordering2 = stream_ordering WHERE stream_ordering2 IS NULL",
+ # finally, we can drop the rule and switch the columns
+ "DROP RULE populate_stream_ordering2 ON events",
+ "ALTER TABLE events DROP COLUMN stream_ordering",
+ "ALTER TABLE events RENAME COLUMN stream_ordering2 TO stream_ordering",
+)
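+
+# The `stream_ordering2` column and the `populate_stream_ordering2` rule used
+# above are assumed to be created by a separate schema delta that is not part
+# of this excerpt, very roughly along these lines:
+#
+#     ALTER TABLE events ADD COLUMN stream_ordering2 BIGINT;
+#     CREATE RULE populate_stream_ordering2 AS ON INSERT TO events
+#         DO ALSO UPDATE events SET stream_ordering2 = NEW.stream_ordering
+#         WHERE stream_ordering = NEW.stream_ordering;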
+
+
+class _BackgroundUpdates:
+ EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
+ EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
+ DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
+ POPULATE_STREAM_ORDERING2 = "populate_stream_ordering2"
+ INDEX_STREAM_ORDERING2 = "index_stream_ordering2"
+ REPLACE_STREAM_ORDERING_COLUMN = "replace_stream_ordering_column"
+
+
@attr.s(slots=True, frozen=True)
class _CalculateChainCover:
"""Return value for _calculate_chain_cover_txn."""
@@ -48,19 +67,15 @@ class _CalculateChainCover:
class EventsBackgroundUpdatesStore(SQLBaseStore):
-
- EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
- EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url"
- DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
-
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
- self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
+ _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME,
+ self._background_reindex_origin_server_ts,
)
self.db_pool.updates.register_background_update_handler(
- self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
+ _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME,
self._background_reindex_fields_sender,
)
@@ -85,7 +100,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
self.db_pool.updates.register_background_update_handler(
- self.DELETE_SOFT_FAILED_EXTREMITIES, self._cleanup_extremities_bg_update
+ _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES,
+ self._cleanup_extremities_bg_update,
)
self.db_pool.updates.register_background_update_handler(
@@ -139,6 +155,24 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._purged_chain_cover_index,
)
+        # Background updates for replacing `stream_ordering` with a BIGINT
+        # (these only run on Postgres).
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.POPULATE_STREAM_ORDERING2,
+ self._background_populate_stream_ordering2,
+ )
+ self.db_pool.updates.register_background_index_update(
+ _BackgroundUpdates.INDEX_STREAM_ORDERING2,
+ index_name="events_stream_ordering",
+ table="events",
+ columns=["stream_ordering2"],
+ unique=True,
+ )
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN,
+ self._background_replace_stream_ordering_column,
+ )
+
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -190,18 +224,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
}
self.db_pool.updates._background_update_progress_txn(
- txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
+ txn, _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress
)
return len(rows)
result = await self.db_pool.runInteraction(
- self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
+ _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn
)
if not result:
await self.db_pool.updates._end_background_update(
- self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
+ _BackgroundUpdates.EVENT_FIELDS_SENDER_URL_UPDATE_NAME
)
return result
@@ -264,18 +298,18 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
}
self.db_pool.updates._background_update_progress_txn(
- txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
+ txn, _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
return len(rows_to_update)
result = await self.db_pool.runInteraction(
- self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
+ _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
)
if not result:
await self.db_pool.updates._end_background_update(
- self.EVENT_ORIGIN_SERVER_TS_NAME
+ _BackgroundUpdates.EVENT_ORIGIN_SERVER_TS_NAME
)
return result
@@ -454,7 +488,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
if not num_handled:
await self.db_pool.updates._end_background_update(
- self.DELETE_SOFT_FAILED_EXTREMITIES
+ _BackgroundUpdates.DELETE_SOFT_FAILED_EXTREMITIES
)
def _drop_table_txn(txn):
@@ -1009,3 +1043,71 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
await self.db_pool.updates._end_background_update("purged_chain_cover")
return result
+
+ async def _background_populate_stream_ordering2(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Populate events.stream_ordering2, then replace stream_ordering
+
+ This is to deal with the fact that stream_ordering was initially created as a
+ 32-bit integer field.
+ """
+ batch_size = max(batch_size, 1)
+
+ def process(txn: Cursor) -> int:
+ last_stream = progress.get("last_stream", -(1 << 31))
+ txn.execute(
+ """
+ UPDATE events SET stream_ordering2=stream_ordering
+ WHERE stream_ordering IN (
+ SELECT stream_ordering FROM events WHERE stream_ordering > ?
+ ORDER BY stream_ordering LIMIT ?
+ )
+ RETURNING stream_ordering;
+ """,
+ (last_stream, batch_size),
+ )
+ row_count = txn.rowcount
+ if row_count == 0:
+ return 0
+ last_stream = max(row[0] for row in txn)
+ logger.info("populated stream_ordering2 up to %i", last_stream)
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.POPULATE_STREAM_ORDERING2,
+ {"last_stream": last_stream},
+ )
+ return row_count
+
+ result = await self.db_pool.runInteraction(
+ "_background_populate_stream_ordering2", process
+ )
+
+ if result != 0:
+ return result
+
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.POPULATE_STREAM_ORDERING2
+ )
+ return 0
+
+ async def _background_replace_stream_ordering_column(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Drop the old 'stream_ordering' column and rename 'stream_ordering2' into its place."""
+
+ def process(txn: Cursor) -> None:
+        for sql in _REPLACE_STREAM_ORDERING_SQL_COMMANDS:
+ logger.info("completing stream_ordering migration: %s", sql)
+ txn.execute(sql)
+
+ await self.db_pool.runInteraction(
+ "_background_replace_stream_ordering_column", process
+ )
+
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.REPLACE_STREAM_ORDERING_COLUMN
+ )
+
+ return 0
diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py
new file mode 100644
index 0000000000..e76188328c
--- /dev/null
+++ b/synapse/storage/databases/main/lock.py
@@ -0,0 +1,334 @@
+# Copyright 2021 Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from types import TracebackType
+from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
+
+from twisted.internet.interfaces import IReactorCore
+
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.types import Connection
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+# How often to renew an acquired lock by updating the `last_renewed_ts` time in
+# the lock table.
+_RENEWAL_INTERVAL_MS = 30 * 1000
+
+# How long before an acquired lock times out.
+_LOCK_TIMEOUT_MS = 2 * 60 * 1000
+
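+# The `worker_locks` table is assumed to be created by a separate schema delta
+# that is not part of this excerpt; its shape, as implied by the queries below,
+# is roughly:
+#
+#     CREATE TABLE worker_locks (
+#         lock_name TEXT NOT NULL,
+#         lock_key TEXT NOT NULL,
+#         instance_name TEXT NOT NULL,
+#         token TEXT NOT NULL,
+#         last_renewed_ts BIGINT NOT NULL
+#     );
+#     CREATE UNIQUE INDEX worker_locks_key ON worker_locks (lock_name, lock_key);
+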
+
+class LockStore(SQLBaseStore):
+ """Provides a best effort distributed lock between worker instances.
+
+ Locks are identified by a name and key. A lock is acquired by inserting into
+ the `worker_locks` table if a) there is no existing row for the name/key or
+ b) the existing row has a `last_renewed_ts` older than `_LOCK_TIMEOUT_MS`.
+
+    When a lock is taken out, the instance inserts a random `token`; the
+    instance that holds that token holds the lock until it drops the lock (or
+    it times out).
+
+ The instance that holds the lock should regularly update the
+ `last_renewed_ts` column with the current time.
+ """
+
+ def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ super().__init__(database, db_conn, hs)
+
+ self._reactor = hs.get_reactor()
+ self._instance_name = hs.get_instance_id()
+
+ # A map from `(lock_name, lock_key)` to the token of any locks that we
+ # think we currently hold.
+ self._live_tokens: Dict[Tuple[str, str], str] = {}
+
+ # When we shut down we want to remove the locks. Technically this can
+ # lead to a race, as we may drop the lock while we are still processing.
+ # However, a) it should be a small window, b) the lock is best effort
+ # anyway and c) we want to really avoid leaking locks when we restart.
+ hs.get_reactor().addSystemEventTrigger(
+ "before",
+ "shutdown",
+ self._on_shutdown,
+ )
+
+ @wrap_as_background_process("LockStore._on_shutdown")
+ async def _on_shutdown(self) -> None:
+ """Called when the server is shutting down"""
+ logger.info("Dropping held locks due to shutdown")
+
+        # Take a copy of the tokens, since `_drop_lock` mutates
+        # `self._live_tokens` as we iterate.
+        for (lock_name, lock_key), token in list(self._live_tokens.items()):
+ await self._drop_lock(lock_name, lock_key, token)
+
+ logger.info("Dropped locks due to shutdown")
+
+ async def try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional["Lock"]:
+ """Try to acquire a lock for the given name/key. Will return an async
+ context manager if the lock is successfully acquired, which *must* be
+ used (otherwise the lock will leak).
+ """
+
+ now = self._clock.time_msec()
+ token = random_string(6)
+
+ if self.db_pool.engine.can_native_upsert:
+
+ def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
+ # We take out the lock if either a) there is no row for the lock
+ # already or b) the existing row has timed out.
+ sql = """
+ INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
+ VALUES (?, ?, ?, ?, ?)
+ ON CONFLICT (lock_name, lock_key)
+ DO UPDATE
+ SET
+ token = EXCLUDED.token,
+ instance_name = EXCLUDED.instance_name,
+ last_renewed_ts = EXCLUDED.last_renewed_ts
+ WHERE
+ worker_locks.last_renewed_ts < ?
+ """
+ txn.execute(
+ sql,
+ (
+ lock_name,
+ lock_key,
+ self._instance_name,
+ token,
+ now,
+ now - _LOCK_TIMEOUT_MS,
+ ),
+ )
+
+ # We only acquired the lock if we inserted or updated the table.
+ return bool(txn.rowcount)
+
+ did_lock = await self.db_pool.runInteraction(
+ "try_acquire_lock",
+ _try_acquire_lock_txn,
+ # We can autocommit here as we're executing a single query, this
+ # will avoid serialization errors.
+ db_autocommit=True,
+ )
+ if not did_lock:
+ return None
+
+ else:
+            # If the database doesn't support native upserts (i.e. we're on an
+            # old SQLite), we emulate the above logic by first clearing out any
+            # existing stale locks and then upserting.
+
+ def _try_acquire_lock_emulated_txn(txn: LoggingTransaction) -> bool:
+ sql = """
+ DELETE FROM worker_locks
+ WHERE
+ lock_name = ?
+ AND lock_key = ?
+ AND last_renewed_ts < ?
+ """
+ txn.execute(
+ sql,
+ (lock_name, lock_key, now - _LOCK_TIMEOUT_MS),
+ )
+
+ inserted = self.db_pool.simple_upsert_txn_emulated(
+ txn,
+ table="worker_locks",
+ keyvalues={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ },
+ values={},
+ insertion_values={
+ "token": token,
+ "last_renewed_ts": self._clock.time_msec(),
+ "instance_name": self._instance_name,
+ },
+ )
+
+ return inserted
+
+ did_lock = await self.db_pool.runInteraction(
+ "try_acquire_lock_emulated", _try_acquire_lock_emulated_txn
+ )
+
+ if not did_lock:
+ return None
+
+ self._live_tokens[(lock_name, lock_key)] = token
+
+ return Lock(
+ self._reactor,
+ self._clock,
+ self,
+ lock_name=lock_name,
+ lock_key=lock_key,
+ token=token,
+ )
+
+ async def _is_lock_still_valid(
+ self, lock_name: str, lock_key: str, token: str
+ ) -> bool:
+ """Checks whether this instance still holds the lock."""
+ last_renewed_ts = await self.db_pool.simple_select_one_onecol(
+ table="worker_locks",
+ keyvalues={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ "token": token,
+ },
+ retcol="last_renewed_ts",
+ allow_none=True,
+ desc="is_lock_still_valid",
+ )
+ return (
+ last_renewed_ts is not None
+ and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts
+ )
+
+ async def _renew_lock(self, lock_name: str, lock_key: str, token: str) -> None:
+ """Attempt to renew the lock if we still hold it."""
+ await self.db_pool.simple_update(
+ table="worker_locks",
+ keyvalues={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ "token": token,
+ },
+ updatevalues={"last_renewed_ts": self._clock.time_msec()},
+ desc="renew_lock",
+ )
+
+ async def _drop_lock(self, lock_name: str, lock_key: str, token: str) -> None:
+ """Attempt to drop the lock, if we still hold it"""
+ await self.db_pool.simple_delete(
+ table="worker_locks",
+ keyvalues={
+ "lock_name": lock_name,
+ "lock_key": lock_key,
+ "token": token,
+ },
+ desc="drop_lock",
+ )
+
+ self._live_tokens.pop((lock_name, lock_key), None)
+
+
+class Lock:
+ """An async context manager that manages an acquired lock, ensuring it is
+ regularly renewed and dropping it when the context manager exits.
+
+ The lock object has an `is_still_valid` method which can be used to
+ double-check the lock is still valid, if e.g. processing work in a loop.
+
+ For example:
+
+ lock = await self.store.try_acquire_lock(...)
+ if not lock:
+ return
+
+ async with lock:
+ for item in work:
+ await process(item)
+
+ if not await lock.is_still_valid():
+ break
+ """
+
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ clock: Clock,
+ store: LockStore,
+ lock_name: str,
+ lock_key: str,
+ token: str,
+ ) -> None:
+ self._reactor = reactor
+ self._clock = clock
+ self._store = store
+ self._lock_name = lock_name
+ self._lock_key = lock_key
+
+ self._token = token
+
+ self._looping_call = clock.looping_call(
+ self._renew, _RENEWAL_INTERVAL_MS, store, lock_name, lock_key, token
+ )
+
+ self._dropped = False
+
+ @staticmethod
+ @wrap_as_background_process("Lock._renew")
+ async def _renew(
+ store: LockStore,
+ lock_name: str,
+ lock_key: str,
+ token: str,
+ ) -> None:
+ """Renew the lock.
+
+ Note: this is a static method, rather than using self.*, so that we
+ don't end up with a reference to `self` in the reactor, which would stop
+ this from being cleaned up if we dropped the context manager.
+ """
+ await store._renew_lock(lock_name, lock_key, token)
+
+ async def is_still_valid(self) -> bool:
+ """Check if the lock is still held by us"""
+ return await self._store._is_lock_still_valid(
+ self._lock_name, self._lock_key, self._token
+ )
+
+ async def __aenter__(self) -> None:
+ if self._dropped:
+ raise Exception("Cannot reuse a Lock object")
+
+ async def __aexit__(
+ self,
+ _exctype: Optional[Type[BaseException]],
+ _excinst: Optional[BaseException],
+ _exctb: Optional[TracebackType],
+ ) -> bool:
+ if self._looping_call.running:
+ self._looping_call.stop()
+
+ await self._store._drop_lock(self._lock_name, self._lock_key, self._token)
+ self._dropped = True
+
+ return False
+
+ def __del__(self) -> None:
+ if not self._dropped:
+ # We should not be dropped without the lock being released (unless
+ # we're shutting down), but if we are then let's at least stop
+ # renewing the lock.
+ if self._looping_call.running:
+ self._looping_call.stop()
+
+ if self._reactor.running:
+ logger.error(
+ "Lock for (%s, %s) dropped without being released",
+ self._lock_name,
+ self._lock_key,
+ )
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e5c5cf8ff0..e31c5864ac 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -53,6 +53,9 @@ class TokenLookupResult:
valid_until_ms: The timestamp the token expires, if any.
token_owner: The "owner" of the token. This is either the same as the
user, or a server admin who is logged in as the user.
+ token_used: True if this token was used at least once in a request.
+ This field can be out of date since `get_user_by_access_token` is
+ cached.
"""
user_id = attr.ib(type=str)
@@ -62,6 +65,7 @@ class TokenLookupResult:
device_id = attr.ib(type=Optional[str], default=None)
valid_until_ms = attr.ib(type=Optional[int], default=None)
token_owner = attr.ib(type=str)
+ token_used = attr.ib(type=bool, default=False)
# Make the token owner default to the user ID, which is the common case.
@token_owner.default
@@ -69,6 +73,29 @@ class TokenLookupResult:
return self.user_id
+@attr.s(frozen=True, slots=True)
+class RefreshTokenLookupResult:
+ """Result of looking up a refresh token."""
+
+ user_id = attr.ib(type=str)
+ """The user this token belongs to."""
+
+ device_id = attr.ib(type=str)
+ """The device associated with this refresh token."""
+
+ token_id = attr.ib(type=int)
+ """The ID of this refresh token."""
+
+ next_token_id = attr.ib(type=Optional[int])
+ """The ID of the refresh token which replaced this one."""
+
+ has_next_refresh_token_been_refreshed = attr.ib(type=bool)
+ """True if the next refresh token was used for another refresh."""
+
+ has_next_access_token_been_used = attr.ib(type=bool)
+ """True if the next access token was already used at least once."""
+
+
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
@@ -441,7 +468,8 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
access_tokens.id as token_id,
access_tokens.device_id,
access_tokens.valid_until_ms,
- access_tokens.user_id as token_owner
+ access_tokens.user_id as token_owner,
+ access_tokens.used as token_used
FROM users
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
WHERE token = ?
@@ -449,8 +477,15 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn)
+
if rows:
- return TokenLookupResult(**rows[0])
+ row = rows[0]
+
+ # This field is nullable, ensure it comes out as a boolean
+ if row["token_used"] is None:
+ row["token_used"] = False
+
+ return TokenLookupResult(**row)
return None
@@ -1072,6 +1107,111 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="update_access_token_last_validated",
)
+ @cached()
+ async def mark_access_token_as_used(self, token_id: int) -> None:
+ """
+ Mark the access token as used, which invalidates the refresh token used
+ to obtain it.
+
+ Because get_user_by_access_token is cached, this function might be
+ called multiple times for the same token, effectively doing unnecessary
+ SQL updates. Because updating the `used` field only goes one way (from
+ False to True) it is safe to cache this function as well to avoid this
+ issue.
+
+ Args:
+ token_id: The ID of the access token to update.
+ Raises:
+ StoreError if there was a problem updating this.
+ """
+ await self.db_pool.simple_update_one(
+ "access_tokens",
+ {"id": token_id},
+ {"used": True},
+ desc="mark_access_token_as_used",
+ )
+
+ async def lookup_refresh_token(
+ self, token: str
+ ) -> Optional[RefreshTokenLookupResult]:
+ """Lookup a refresh token with hints about its validity."""
+
+ def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
+ txn.execute(
+ """
+ SELECT
+ rt.id token_id,
+ rt.user_id,
+ rt.device_id,
+ rt.next_token_id,
+ (nrt.next_token_id IS NOT NULL) has_next_refresh_token_been_refreshed,
+ at.used has_next_access_token_been_used
+ FROM refresh_tokens rt
+ LEFT JOIN refresh_tokens nrt ON rt.next_token_id = nrt.id
+ LEFT JOIN access_tokens at ON at.refresh_token_id = nrt.id
+ WHERE rt.token = ?
+ """,
+ (token,),
+ )
+ row = txn.fetchone()
+
+ if row is None:
+ return None
+
+ return RefreshTokenLookupResult(
+ token_id=row[0],
+ user_id=row[1],
+ device_id=row[2],
+ next_token_id=row[3],
+ has_next_refresh_token_been_refreshed=row[4],
+ # This column is nullable, ensure it's a boolean
+ has_next_access_token_been_used=(row[5] or False),
+ )
+
+ return await self.db_pool.runInteraction(
+ "lookup_refresh_token", _lookup_refresh_token_txn
+ )
+
+ async def replace_refresh_token(self, token_id: int, next_token_id: int) -> None:
+ """
+ Set the successor of a refresh token, removing the existing successor
+ if any.
+
+ Args:
+ token_id: ID of the refresh token to update.
+ next_token_id: ID of its successor.
+ """
+
+ def _replace_refresh_token_txn(txn) -> None:
+ # First check if there was an existing refresh token
+ old_next_token_id = self.db_pool.simple_select_one_onecol_txn(
+ txn,
+ "refresh_tokens",
+ {"id": token_id},
+ "next_token_id",
+ allow_none=True,
+ )
+
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "refresh_tokens",
+ {"id": token_id},
+ {"next_token_id": next_token_id},
+ )
+
+ # Delete the old "next" token if it exists. This should cascade and
+ # delete the associated access_token
+ if old_next_token_id is not None:
+ self.db_pool.simple_delete_one_txn(
+ txn,
+ "refresh_tokens",
+ {"id": old_next_token_id},
+ )
+
+ await self.db_pool.runInteraction(
+ "replace_refresh_token", _replace_refresh_token_txn
+ )
+
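+    # A rough sketch of how the refresh-token helpers above fit together when a
+    # client refreshes its session. Illustrative only: the handler-level error
+    # handling and the `generate_token` helper are assumptions, not part of
+    # this change.
+    #
+    #     existing = await store.lookup_refresh_token(refresh_token)
+    #     if existing is None or existing.has_next_refresh_token_been_refreshed:
+    #         raise SynapseError(403, "Refresh token is no longer valid")
+    #
+    #     new_refresh_token = generate_token()  # hypothetical
+    #     new_refresh_token_id = await store.add_refresh_token_to_user(
+    #         existing.user_id, new_refresh_token, existing.device_id
+    #     )
+    #     await store.replace_refresh_token(existing.token_id, new_refresh_token_id)
+    #
+    #     new_access_token = generate_token()  # hypothetical
+    #     await store.add_access_token_to_user(
+    #         existing.user_id,
+    #         new_access_token,
+    #         existing.device_id,
+    #         valid_until_ms=None,
+    #         refresh_token_id=new_refresh_token_id,
+    #     )
+    #
+    #     # `mark_access_token_as_used` is then called the first time the new
+    #     # access token is seen in a request, which is what makes
+    #     # `has_next_access_token_been_used` true for the old refresh token.
+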
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@@ -1263,6 +1403,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+ self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
async def add_access_token_to_user(
self,
@@ -1271,14 +1412,18 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
+ refresh_token_id: Optional[int] = None,
) -> int:
"""Adds an access token for the given user.
Args:
user_id: The user ID.
token: The new access token to add.
- device_id: ID of the device to associate with the access token
+ device_id: ID of the device to associate with the access token.
valid_until_ms: when the token is valid until. None for no expiry.
+            puppets_user_id: The user that this access token puppets, if any
+                (e.g. for a server admin logged in as another user).
+ refresh_token_id: ID of the refresh token generated alongside this
+ access token.
Raises:
StoreError if there was a problem adding this.
Returns:
@@ -1297,12 +1442,47 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"valid_until_ms": valid_until_ms,
"puppets_user_id": puppets_user_id,
"last_validated": now,
+ "refresh_token_id": refresh_token_id,
+ "used": False,
},
desc="add_access_token_to_user",
)
return next_id
+ async def add_refresh_token_to_user(
+ self,
+ user_id: str,
+ token: str,
+ device_id: Optional[str],
+ ) -> int:
+ """Adds a refresh token for the given user.
+
+ Args:
+ user_id: The user ID.
+            token: The new refresh token to add.
+ device_id: ID of the device to associate with the refresh token.
+ Raises:
+ StoreError if there was a problem adding this.
+ Returns:
+ The token ID
+ """
+ next_id = self._refresh_tokens_id_gen.get_next()
+
+ await self.db_pool.simple_insert(
+ "refresh_tokens",
+ {
+ "id": next_id,
+ "user_id": user_id,
+ "device_id": device_id,
+ "token": token,
+ "next_token_id": None,
+ },
+ desc="add_refresh_token_to_user",
+ )
+
+ return next_id
+
def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn, "access_tokens", {"token": token}, "device_id"
@@ -1545,7 +1725,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
device_id: Optional[str] = None,
) -> List[Tuple[str, int, Optional[str]]]:
"""
- Invalidate access tokens belonging to a user
+ Invalidate access and refresh tokens belonging to a user
Args:
user_id: ID of user the tokens belong to
@@ -1565,7 +1745,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
items = keyvalues.items()
where_clause = " AND ".join(k + " = ?" for k, _ in items)
values = [v for _, v in items] # type: List[Union[str, int]]
+        # Conveniently, refresh_tokens and access_tokens both use the user_id
+        # and device_id fields, so we can reuse the same WHERE clause and values
+        # for both tables. The one caveat is the `except_token_id` param, which
+        # is tricky to apply to refresh tokens, so for now we take a copy of the
+        # clause and values before it is added. (`except_token_id` only appears
+        # to be used by the "set password" handler.)
+ refresh_where_clause = where_clause
+ refresh_values = values.copy()
if except_token_id:
+ # TODO: support that for refresh tokens
where_clause += " AND id != ?"
values.append(except_token_id)
@@ -1583,6 +1769,11 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
txn.execute("DELETE FROM access_tokens WHERE %s" % where_clause, values)
+ txn.execute(
+ "DELETE FROM refresh_tokens WHERE %s" % refresh_where_clause,
+ refresh_values,
+ )
+
return tokens_and_devices
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
@@ -1599,6 +1790,14 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f)
+ async def delete_refresh_token(self, refresh_token: str) -> None:
+ def f(txn):
+ self.db_pool.simple_delete_one_txn(
+ txn, table="refresh_tokens", keyvalues={"token": refresh_token}
+ )
+
+ await self.db_pool.runInteraction("delete_refresh_token", f)
+
async def add_user_pending_deactivation(self, user_id: str) -> None:
"""
Adds a user to the table of users who need to be parted from all the rooms they're
|