diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 8e5d78f6f7..bbff3c8d5b 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -47,6 +47,9 @@ class Storage:
# interfaces.
self.main = stores.main
- self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
+
+ self.persistence = None
+ if stores.persist_events:
+ self.persistence = EventsPersistenceStorage(hs, stores)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ed8a9bffb1..79ec8f119d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -952,7 +952,7 @@ class DatabasePool:
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[str]],
+ value_values: Iterable[Iterable[Any]],
) -> None:
"""
Upsert, many times.
@@ -981,7 +981,7 @@ class DatabasePool:
key_names: Iterable[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[str]],
+ value_values: Iterable[Iterable[Any]],
) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 985b12df91..aa5d490624 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -75,7 +75,7 @@ class Databases:
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
- if hs.config.worker.writers.events == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.events:
persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases:
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 2ae2fbd5d7..0cb12f4c61 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -160,19 +160,25 @@ class DataStore(
)
if isinstance(self.database_engine, PostgresEngine):
+            # We set `writers` to an empty list here as we don't care about
+            # missing updates over restarts: we won't have anything in our
+            # caches to invalidate. (This also reduces the number of writes
+            # to the DB that happen.)
self._cache_id_gen = MultiWriterIdGenerator(
db_conn,
database,
- instance_name="master",
+ stream_name="caches",
+ instance_name=hs.get_instance_name(),
table="cache_invalidation_stream_by_instance",
instance_column="instance_name",
id_column="stream_id",
sequence_name="cache_invalidation_stream_seq",
+ writers=[],
)
else:
self._cache_id_gen = None
- super(DataStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._presence_on_startup = self._get_active_presence(db_conn)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 4436b1a83d..ef81d73573 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -29,22 +29,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-class AccountDataWorkerStore(SQLBaseStore):
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
+class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
)
- super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
@@ -315,7 +313,7 @@ class AccountDataStore(AccountDataWorkerStore):
],
)
- super(AccountDataStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream id for the private user data stream
@@ -341,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -389,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
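Note: the recurring `with await gen.get_next()` to `async with gen.get_next()` change throughout this patch reflects `get_next()` now returning an async context manager directly. A toy sketch of that shape (not the real `StreamIdGenerator`/`MultiWriterIdGenerator`):

    import asyncio
    import contextlib
    import itertools


    class ToyIdGenerator:
        """Illustration only: hands out increasing stream IDs."""

        def __init__(self) -> None:
            self._counter = itertools.count(1)

        @contextlib.asynccontextmanager
        async def get_next(self):
            # The real generators only advance their "persisted up to" position
            # when the context exits; here we simply yield the next value.
            yield next(self._counter)


    async def main() -> None:
        gen = ToyIdGenerator()
        # New style: entered with `async with`, not `with await`.
        async with gen.get_next() as next_id:
            print("allocated stream id", next_id)


    asyncio.run(main())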
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 454c0bc50c..85f6b1e3fd 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -52,7 +52,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
- super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_app_services(self):
return self.services_cache
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index d568789124..2611b92281 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -31,7 +31,7 @@ LAST_SEEN_GRANULARITY = 10 * 60 * 1000
class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"user_ips_device_index",
@@ -358,7 +358,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
name="client_ip_last_seen", keylen=4, max_entries=50000
)
- super(ClientIpStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 0044433110..d42faa3f1f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -283,7 +283,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"device_inbox_stream_index",
@@ -313,7 +313,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceInboxStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index add4e3ea0e..fdf394c612 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
            The new stream ID.
"""
- with await self._device_list_id_gen.get_next() as stream_id:
+ async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore):
}
async def get_users_whose_devices_changed(
- self, from_key: str, user_ids: Iterable[str]
+ self, from_key: int, user_ids: Iterable[str]
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
@@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
The set of user_ids whose devices have changed since `from_key`
"""
- from_key = int(from_key)
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
@@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
async def get_users_whose_signatures_changed(
- self, user_id: str, from_key: str
+ self, user_id: str, from_key: int
) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since
`from_key`.
@@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
A set of user IDs with updated signatures.
"""
- from_key = int(from_key)
+
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
sql = """
SELECT DISTINCT user_ids FROM user_signature_stream
@@ -702,7 +701,7 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx",
@@ -827,7 +826,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
@@ -1094,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1109,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index fba3098ea2..22e1ed15d0 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
-@attr.s
+@attr.s(slots=True)
class DeviceKeyLookupResult:
"""The type returned by get_e2e_device_keys_and_signatures"""
@@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key (dict): the key data
"""
- with await self._cross_signing_id_gen.get_next() as stream_id:
+ async with self._cross_signing_id_gen.get_next() as stream_id:
return await self.db_pool.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
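Note: the `@attr.s(slots=True)` additions in this patch make the attrs classes use `__slots__`, so assignments to undeclared attributes fail loudly and instances are slightly smaller. A small standalone illustration (not Synapse code):

    import attr


    @attr.s(slots=True)
    class Point:
        x = attr.ib(default=0)
        y = attr.ib(default=0)


    p = Point(1, 2)
    try:
        p.z = 3  # not a declared attribute, rejected because of __slots__
    except AttributeError as e:
        print("rejected:", e)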
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 0b69aa6a94..6d3689c09e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
if stream_ordering <= self.stream_ordering_month_ago:
- raise StoreError(400, "stream_ordering too old")
+ raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
SELECT event_id FROM stream_ordering_to_exterm
@@ -600,7 +600,7 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventFederationStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 5233ed83e2..62f1738732 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -68,7 +68,7 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
self.stream_ordering_month_ago = None
@@ -661,7 +661,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventPushActionsStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
@@ -969,7 +969,7 @@ def _action_has_highlight(actions):
return False
-@attr.s
+@attr.s(slots=True)
class _EventPushSummary:
"""Summary of pending event push actions for a given user in a given room.
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b3d27a2ee7..18def01f50 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,7 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -32,7 +32,7 @@ from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
@@ -97,18 +97,21 @@ class PersistEventsStore:
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
- self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator
- self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator
+ self._backfill_id_gen = (
+ self.store._backfill_id_gen
+ ) # type: MultiWriterIdGenerator
+ self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator
# This should only exist on instances that are configured to write
assert (
- hs.config.worker.writers.events == hs.get_instance_name()
+ hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates(
@@ -153,15 +156,15 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
@@ -213,7 +216,7 @@ class PersistEventsStore:
Returns:
Filtered event ids
"""
- results = []
+ results = [] # type: List[str]
def _get_events_which_are_prevs_txn(txn, batch):
sql = """
@@ -631,7 +634,9 @@ class PersistEventsStore:
)
@classmethod
- def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+ def _filter_events_and_contexts_for_duplicates(
+ cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
+ ) -> List[Tuple[EventBase, EventContext]]:
"""Ensure that we don't have the same event twice.
Pick the earliest non-outlier if there is one, else the earliest one.
@@ -641,7 +646,9 @@ class PersistEventsStore:
Returns:
list[(EventBase, EventContext)]: filtered list
"""
- new_events_and_contexts = OrderedDict()
+ new_events_and_contexts = (
+ OrderedDict()
+ ) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
@@ -655,7 +662,12 @@ class PersistEventsStore:
new_events_and_contexts[event.event_id] = (event, context)
return list(new_events_and_contexts.values())
- def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+ def _update_room_depths_txn(
+ self,
+ txn,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ backfilled: bool,
+ ):
"""Update min_depth for each room
Args:
@@ -664,7 +676,7 @@ class PersistEventsStore:
we are persisting
backfilled (bool): True if the events were backfilled
"""
- depth_updates = {}
+ depth_updates = {} # type: Dict[str, int]
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -800,6 +812,7 @@ class PersistEventsStore:
table="events",
values=[
{
+ "instance_name": self._instance_name,
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth,
"depth": event.depth,
@@ -1095,6 +1108,10 @@ class PersistEventsStore:
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
+
+ def str_or_none(val: Any) -> Optional[str]:
+ return val if isinstance(val, str) else None
+
self.db_pool.simple_insert_many_txn(
txn,
table="room_memberships",
@@ -1105,8 +1122,8 @@ class PersistEventsStore:
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
+ "display_name": str_or_none(event.content.get("displayname")),
+ "avatar_url": str_or_none(event.content.get("avatar_url")),
}
for event in events
],
@@ -1436,7 +1453,7 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
- events_by_room = {}
+ events_by_room = {} # type: Dict[str, List[EventBase]]
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
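Note: the `str_or_none` helper above guards against membership events whose `displayname` or `avatar_url` content fields are not actually strings; anything else is stored as NULL. A standalone sketch of the same idea:

    from typing import Any, Optional


    def str_or_none(val: Any) -> Optional[str]:
        # Mirrors the helper in _store_room_members_txn: non-string values
        # are dropped rather than written to the database.
        return val if isinstance(val, str) else None


    content = {"displayname": {"not": "a string"}, "avatar_url": "mxc://example.org/abc"}
    print(str_or_none(content.get("displayname")))  # None
    print(str_or_none(content.get("avatar_url")))   # mxc://example.org/abc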
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e53c6373a8..5e4af2eb51 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -29,7 +29,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a7a73cc3d8..f95679ebc4 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import division
-
import itertools
import logging
import threading
@@ -42,7 +40,8 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
@@ -76,29 +75,60 @@ class EventRedactBehaviour(Names):
class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventsWorkerStore, self).__init__(database, db_conn, hs)
-
- if hs.config.worker.writers.events == hs.get_instance_name():
- # We are the process in charge of generating stream ids for events,
- # so instantiate ID generators based on the database
- self._stream_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering",
+ super().__init__(database, db_conn, hs)
+
+ if isinstance(database.engine, PostgresEngine):
+            # If we're using Postgres then we can use `MultiWriterIdGenerator`
+ # regardless of whether this process writes to the streams or not.
+ self._stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="events",
+ instance_name=hs.get_instance_name(),
+ table="events",
+ instance_column="instance_name",
+ id_column="stream_ordering",
+ sequence_name="events_stream_seq",
+ writers=hs.config.worker.writers.events,
)
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ self._backfill_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ stream_name="backfill",
+ instance_name=hs.get_instance_name(),
+ table="events",
+ instance_column="instance_name",
+ id_column="stream_ordering",
+ sequence_name="events_backfill_stream_seq",
+ positive=False,
+ writers=hs.config.worker.writers.events,
)
else:
- # Another process is in charge of persisting events and generating
- # stream IDs: rely on the replication streams to let us know which
- # IDs we can process.
- self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
- self._backfill_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering", step=-1
- )
+            # We shouldn't be running in worker mode with SQLite, but it's useful
+            # to support it for unit tests.
+            #
+            # If this process is the writer then we need to use
+ # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+ # updated over replication. (Multiple writers are not supported for
+ # SQLite).
+ if hs.get_instance_name() in hs.config.worker.writers.events:
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn, "events", "stream_ordering",
+ )
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ )
+ else:
+ self._stream_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering"
+ )
+ self._backfill_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering", step=-1
+ )
self._get_event_cache = Cache(
"*getEvent*",
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ccfbb2135e..7218191965 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with await self._group_updates_id_gen.get_next() as next_id:
+ async with self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 86557d5512..cc538c5c10 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -17,12 +17,14 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
+BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
+ "media_repository_drop_index_wo_method"
+)
+
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MediaRepositoryBackgroundUpdateStore, self).__init__(
- database, db_conn, hs
- )
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
update_name="local_media_repository_url_idx",
@@ -32,12 +34,65 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
where_clause="url_cache IS NOT NULL",
)
+        # The following background updates add the method to the unique constraint
+        # of the thumbnail tables. That fixes an issue where thumbnails of the
+        # same resolution but different methods could overwrite one another.
+        # This can happen with custom thumbnail configs or with dynamic thumbnailing.
+ self.db_pool.updates.register_background_index_update(
+ update_name="local_media_repository_thumbnails_method_idx",
+ index_name="local_media_repository_thumbn_media_id_width_height_method_key",
+ table="local_media_repository_thumbnails",
+ columns=[
+ "media_id",
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_type",
+ "thumbnail_method",
+ ],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ update_name="remote_media_repository_thumbnails_method_idx",
+ index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key",
+ table="remote_media_cache_thumbnails",
+ columns=[
+ "media_origin",
+ "media_id",
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_type",
+ "thumbnail_method",
+ ],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
+ self._drop_media_index_without_method,
+ )
+
+ async def _drop_media_index_without_method(self, progress, batch_size):
+ def f(txn):
+ txn.execute(
+ "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+ )
+ txn.execute(
+ "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+ )
+
+ await self.db_pool.runInteraction("drop_media_indices_without_method", f)
+ await self.db_pool.updates._end_background_update(
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+ )
+ return 1
+
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py
index 686052bd83..92099f95ce 100644
--- a/synapse/storage/databases/main/metrics.py
+++ b/synapse/storage/databases/main/metrics.py
@@ -12,10 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import typing
-from collections import Counter
-from synapse.metrics import BucketCollector
+from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
@@ -23,6 +21,26 @@ from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)
+# Collect metrics on the number of forward extremities that exist.
+_extremities_collecter = GaugeBucketCollector(
+ "synapse_forward_extremities",
+ "Number of rooms on the server with the given number of forward extremities"
+ " or fewer",
+ buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500],
+)
+
+# we also expose metrics on the "number of excess extremity events", which is
+# (E-1)*N, where E is the number of extremities and N is the number of state
+# events in the room. This is an approximation to the number of state events
+# we could remove from state resolution by reducing the graph to a single
+# forward extremity.
+_excess_state_events_collecter = GaugeBucketCollector(
+ "synapse_excess_extremity_events",
+ "Number of rooms on the server with the given number of excess extremity "
+ "events, or fewer",
+ buckets=[0] + [1 << n for n in range(12)],
+)
+
class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
"""Functions to pull various metrics from the DB, for e.g. phone home
@@ -32,18 +50,6 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
- # Collect metrics on the number of forward extremities that exist.
- # Counter of number of extremities to count
- self._current_forward_extremities_amount = (
- Counter()
- ) # type: typing.Counter[int]
-
- BucketCollector(
- "synapse_forward_extremities",
- lambda: self._current_forward_extremities_amount,
- buckets=[1, 2, 3, 5, 7, 10, 15, 20, 50, 100, 200, 500, "+Inf"],
- )
-
# Read the extrems every 60 minutes
def read_forward_extremities():
# run as a background process to make sure that the database transactions
@@ -58,14 +64,25 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
def fetch(txn):
txn.execute(
"""
- select count(*) c from event_forward_extremities
- group by room_id
+ SELECT t1.c, t2.c
+ FROM (
+ SELECT room_id, COUNT(*) c FROM event_forward_extremities
+ GROUP BY room_id
+ ) t1 LEFT JOIN (
+ SELECT room_id, COUNT(*) c FROM current_state_events
+ GROUP BY room_id
+ ) t2 ON t1.room_id = t2.room_id
"""
)
return txn.fetchall()
res = await self.db_pool.runInteraction("read_forward_extremities", fetch)
- self._current_forward_extremities_amount = Counter([x[0] for x in res])
+
+ _extremities_collecter.update_data(x[0] for x in res)
+
+ _excess_state_events_collecter.update_data(
+ (x[0] - 1) * x[1] for x in res if x[1]
+ )
async def count_daily_messages(self):
"""
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index e3ac512f11..097a16cb2e 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -28,7 +28,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@@ -41,7 +41,14 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"""
def _count_users(txn):
- sql = "SELECT COALESCE(count(*), 0) FROM monthly_active_users"
+ # Exclude app service users
+ sql = """
+ SELECT COALESCE(count(*), 0)
+ FROM monthly_active_users
+ LEFT JOIN users
+ ON monthly_active_users.user_id=users.name
+ WHERE (users.appservice_id IS NULL OR users.appservice_id = '');
+ """
txn.execute(sql)
(count,) = txn.fetchone()
return count
@@ -120,7 +127,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_stats_only = hs.config.mau_stats_only
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c9f655dfb7..dbbb99cb95 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ea833829ae..ecfc6717b3 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
The set of state groups that are referenced by deleted events.
"""
+ parsed_token = await RoomStreamToken.parse(self, token)
+
return await self.db_pool.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
- token,
+ parsed_token,
delete_local_events,
)
- def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
- token = RoomStreamToken.parse(token_str)
-
+ def _purge_history_txn(self, txn, room_id, token, delete_local_events):
# Tables that should be pruned:
# event_auth
# event_backward_extremities
@@ -69,6 +69,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# room_depth
# state_groups
# state_groups_state
+ # destination_rooms
# we will build a temporary table listing the events so that we don't
# have to keep shovelling the list back and forth across the
@@ -336,6 +337,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# and finally, the tables with an index on room_id (or no useful index)
for table in (
"current_state_events",
+ "destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
"event_json",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 0de802a86b..711d5aa23d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -13,11 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import abc
import logging
from typing import List, Tuple, Union
+from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -27,6 +27,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
@@ -60,6 +61,8 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False):
return rules
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
ReceiptsWorkerStore,
@@ -67,17 +70,14 @@ class PushRulesWorkerStore(
RoomMemberWorkerStore,
EventsWorkerStore,
SQLBaseStore,
+ metaclass=abc.ABCMeta,
):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs):
- super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after:
@@ -540,6 +540,25 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
+ # ensure we have a push_rules_enable row
+ # enabledness defaults to true
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ INSERT INTO push_rules_enable (id, user_name, rule_id, enabled)
+ VALUES (?, ?, ?, ?)
+ ON CONFLICT DO NOTHING
+ """
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = """
+ INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled)
+ VALUES (?, ?, ?, ?)
+ """
+ else:
+ raise RuntimeError("Unknown database engine")
+
+ new_enable_id = self._push_rules_enable_id_gen.get_next()
+ txn.execute(sql, (new_enable_id, user_id, rule_id, 1))
+
async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
Delete a push rule. Args specify the row to be deleted and can be
@@ -552,6 +571,12 @@ class PushRuleStore(PushRulesWorkerStore):
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+ # we don't use simple_delete_one_txn because that would fail if the
+ # user did not have a push_rule_enable row.
+ self.db_pool.simple_delete_txn(
+ txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}
+ )
+
self.db_pool.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
@@ -560,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -570,10 +595,29 @@ class PushRuleStore(PushRulesWorkerStore):
event_stream_ordering,
)
- async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
- event_stream_ordering = self._stream_id_gen.get_current_token()
+ async def set_push_rule_enabled(
+ self, user_id: str, rule_id: str, enabled: bool, is_default_rule: bool
+ ) -> None:
+ """
+ Sets the `enabled` state of a push rule.
+ Args:
+ user_id: the user ID of the user who wishes to enable/disable the rule
+ e.g. '@tina:example.org'
+ rule_id: the full rule ID of the rule to be enabled/disabled
+ e.g. 'global/override/.m.rule.roomnotif'
+ or 'global/override/myCustomRule'
+ enabled: True if the rule is to be enabled, False if it is to be
+ disabled
+            is_default_rule: True if and only if this is a server-default rule.
+                This skips the existence check, since only user-created rules
+                are stored in the database `push_rules` table.
+
+ Raises:
+ NotFoundError if the rule does not exist.
+ """
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
@@ -582,12 +626,47 @@ class PushRuleStore(PushRulesWorkerStore):
user_id,
rule_id,
enabled,
+ is_default_rule,
)
def _set_push_rule_enabled_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ enabled,
+ is_default_rule,
):
new_id = self._push_rules_enable_id_gen.get_next()
+
+ if not is_default_rule:
+ # first check it exists; we need to lock for key share so that a
+ # transaction that deletes the push rule will conflict with this one.
+ # We also need a push_rule_enable row to exist for every push_rules
+ # row, otherwise it is possible to simultaneously delete a push rule
+ # (that has no _enable row) and enable it, resulting in a dangling
+ # _enable row. To solve this: we either need to use SERIALISABLE or
+ # ensure we always have a push_rule_enable row for every push_rule
+ # row. We chose the latter.
+ for_key_share = "FOR KEY SHARE"
+ if not isinstance(self.database_engine, PostgresEngine):
+ # For key share is not applicable/available on SQLite
+ for_key_share = ""
+ sql = (
+ """
+ SELECT 1 FROM push_rules
+ WHERE user_name = ? AND rule_id = ?
+ %s
+ """
+ % for_key_share
+ )
+ txn.execute(sql, (user_id, rule_id))
+ if txn.fetchone() is None:
+ # needed to set NOT_FOUND code.
+ raise NotFoundError("Push rule does not exist.")
+
self.db_pool.simple_upsert_txn(
txn,
"push_rules_enable",
@@ -606,8 +685,30 @@ class PushRuleStore(PushRulesWorkerStore):
)
async def set_push_rule_actions(
- self, user_id, rule_id, actions, is_default_rule
+ self,
+ user_id: str,
+ rule_id: str,
+ actions: List[Union[dict, str]],
+ is_default_rule: bool,
) -> None:
+ """
+ Sets the `actions` state of a push rule.
+
+ Will throw NotFoundError if the rule does not exist; the Code for this
+ is NOT_FOUND.
+
+ Args:
+ user_id: the user ID of the user who wishes to enable/disable the rule
+ e.g. '@tina:example.org'
+ rule_id: the full rule ID of the rule to be enabled/disabled
+ e.g. 'global/override/.m.rule.roomnotif'
+ or 'global/override/myCustomRule'
+ actions: A list of actions (each action being a dict or string),
+ e.g. ["notify", {"set_tweak": "highlight", "value": false}]
+            is_default_rule: True if and only if this is a server-default rule.
+                This skips the existence check, since only user-created rules
+                are stored in the database `push_rules` table.
+ """
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -629,12 +730,19 @@ class PushRuleStore(PushRulesWorkerStore):
update_stream=False,
)
else:
- self.db_pool.simple_update_one_txn(
- txn,
- "push_rules",
- {"user_name": user_id, "rule_id": rule_id},
- {"actions": actions_json},
- )
+ try:
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "push_rules",
+ {"user_name": user_id, "rule_id": rule_id},
+ {"actions": actions_json},
+ )
+ except StoreError as serr:
+ if serr.code == 404:
+ # this sets the NOT_FOUND error Code
+ raise NotFoundError("Push rule does not exist")
+ else:
+ raise
self._insert_push_rules_update_txn(
txn,
@@ -646,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
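Note: the engine-specific SQL above (`ON CONFLICT DO NOTHING` on Postgres, `INSERT OR IGNORE` on SQLite) makes the insert a no-op when the enable row already exists. A minimal SQLite-only illustration of the idiom, using a simplified stand-in for the real table schema:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE push_rules_enable ("
        " id INTEGER, user_name TEXT, rule_id TEXT, enabled INTEGER,"
        " UNIQUE (user_name, rule_id))"
    )

    sql = """
        INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled)
        VALUES (?, ?, ?, ?)
    """
    conn.execute(sql, (1, "@tina:example.org", "global/override/myCustomRule", 1))
    # Re-running the same insert is silently ignored instead of raising a
    # UNIQUE constraint violation.
    conn.execute(sql, (2, "@tina:example.org", "global/override/myCustomRule", 1))

    print(conn.execute("SELECT COUNT(*) FROM push_rules_enable").fetchone()[0])  # 1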
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c388468273..df8609b97b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 4a0d5a320e..c79ddff680 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -31,17 +31,15 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-class ReceiptsWorkerStore(SQLBaseStore):
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
+class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_receipt_stream_id` which can be called in the initializer.
"""
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs):
- super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -388,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
db_conn, "receipts_linearized", "stream_id"
)
- super(ReceiptsStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
@@ -526,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
- with await self._receipts_id_gen.get_next() as stream_id:
+ async with self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 01f20c03c2..a83df7759d 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -36,11 +36,14 @@ logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
self.clock = hs.get_clock()
+ # Note: we don't check this sequence for consistency as we'd have to
+ # call `find_max_generated_user_id_localpart` each time, which is
+ # expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
database.engine, find_max_generated_user_id_localpart, "user_id_seq",
)
@@ -116,6 +119,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_expiration_ts_for_user",
)
+ async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+ """
+        Returns whether a user account has expired.
+
+ Args:
+ user_id: The user's ID
+ current_ts: The current timestamp
+
+ Returns:
+ Whether the user account has expired
+ """
+ expiration_ts = await self.get_expiration_ts_for_user(user_id)
+ return expiration_ts is not None and current_ts >= expiration_ts
+
async def set_account_validity_for_user(
self,
user_id: str,
@@ -379,7 +396,7 @@ class RegistrationWorkerStore(SQLBaseStore):
async def get_user_by_external_id(
self, auth_provider: str, external_id: str
- ) -> str:
+ ) -> Optional[str]:
"""Look up a user by their external auth id
Args:
@@ -387,7 +404,7 @@ class RegistrationWorkerStore(SQLBaseStore):
external_id: id on that system
Returns:
- str|None: the mxid of the user, or None if they are not known
+ the mxid of the user, or None if they are not known
"""
return await self.db_pool.simple_select_one_onecol(
table="user_external_ids",
@@ -764,7 +781,7 @@ class RegistrationWorkerStore(SQLBaseStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.clock = hs.get_clock()
self.config = hs.config
@@ -892,7 +909,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 717df97301..3c7630857f 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -69,7 +69,7 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -104,7 +104,8 @@ class RoomWorkerStore(SQLBaseStore):
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
rooms.creator, state.encryption, state.is_federatable AS federatable,
rooms.is_public AS public, state.join_rules, state.guest_access,
- state.history_visibility, curr.current_state_events AS state_events
+ state.history_visibility, curr.current_state_events AS state_events,
+ state.avatar, state.topic
FROM rooms
LEFT JOIN room_stats_state state USING (room_id)
LEFT JOIN room_stats_current curr USING (room_id)
@@ -862,7 +863,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -1073,7 +1074,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -1136,7 +1137,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1203,7 +1204,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1283,7 +1284,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
@@ -1327,6 +1328,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
desc="add_event_report",
)
+ async def get_event_reports_paginate(
+ self,
+ start: int,
+ limit: int,
+ direction: str = "b",
+ user_id: Optional[str] = None,
+ room_id: Optional[str] = None,
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ """Retrieve a paginated list of event reports
+
+ Args:
+ start: event offset to begin the query from
+ limit: number of rows to retrieve
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`)
+ user_id: search for user_id. Ignored if user_id is None
+ room_id: search for room_id. Ignored if room_id is None
+ Returns:
+ event_reports: json list of event reports
+ count: total number of event reports matching the filter criteria
+ """
+
+ def _get_event_reports_paginate_txn(txn):
+ filters = []
+ args = []
+
+ if user_id:
+ filters.append("er.user_id LIKE ?")
+ args.extend(["%" + user_id + "%"])
+ if room_id:
+ filters.append("er.room_id LIKE ?")
+ args.extend(["%" + room_id + "%"])
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+ sql = """
+ SELECT COUNT(*) as total_event_reports
+ FROM event_reports AS er
+ {}
+ """.format(
+ where_clause
+ )
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = """
+ SELECT
+ er.id,
+ er.received_ts,
+ er.room_id,
+ er.event_id,
+ er.user_id,
+ er.reason,
+ er.content,
+ events.sender,
+ room_aliases.room_alias,
+ event_json.json AS event_json
+ FROM event_reports AS er
+ LEFT JOIN room_aliases
+ ON room_aliases.room_id = er.room_id
+ JOIN events
+ ON events.event_id = er.event_id
+ JOIN event_json
+ ON event_json.event_id = er.event_id
+ {where_clause}
+ ORDER BY er.received_ts {order}
+ LIMIT ?
+ OFFSET ?
+ """.format(
+ where_clause=where_clause, order=order,
+ )
+
+ args += [limit, start]
+ txn.execute(sql, args)
+ event_reports = self.db_pool.cursor_to_dict(txn)
+
+ if count > 0:
+ for row in event_reports:
+ try:
+ row["content"] = db_to_json(row["content"])
+ row["event_json"] = db_to_json(row["event_json"])
+ except Exception:
+ continue
+
+ return event_reports, count
+
+ return await self.db_pool.runInteraction(
+ "get_event_reports_paginate", _get_event_reports_paginate_txn
+ )
+
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
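Note: a hedged usage sketch for the new admin query, assuming a store object that includes `RoomStore` (parameter names and result keys are taken from the method above):

    async def list_recent_reports(store, room_id=None):
        # Fetch the ten most recent event reports, optionally filtered by room.
        event_reports, total = await store.get_event_reports_paginate(
            start=0, limit=10, direction="b", room_id=room_id,
        )
        for report in event_reports:
            print(report["received_ts"], report["room_id"], report["user_id"], report["reason"])
        return total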
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 91a8b43da3..86ffe2479e 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
@@ -37,7 +36,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import Collection, get_domain_from_id
+from synapse.types import Collection, PersistedEventPosition, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -55,7 +54,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Is the current_state_events.membership up to date? Or is the
# background update still running?
@@ -387,7 +386,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# for rooms the server is participating in.
if self._current_state_events_membership_up_to_date:
sql = """
- SELECT room_id, e.stream_ordering
+ SELECT room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN events AS e USING (room_id, event_id)
WHERE
@@ -397,7 +396,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
else:
sql = """
- SELECT room_id, e.stream_ordering
+ SELECT room_id, e.instance_name, e.stream_ordering
FROM current_state_events AS c
INNER JOIN room_memberships AS m USING (room_id, event_id)
INNER JOIN events AS e USING (room_id, event_id)
@@ -408,7 +407,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"""
txn.execute(sql, (user_id, Membership.JOIN))
- return frozenset(GetRoomsForUserWithStreamOrdering(*row) for row in txn)
+ return frozenset(
+ GetRoomsForUserWithStreamOrdering(
+ room_id, PersistedEventPosition(instance, stream_id)
+ )
+ for room_id, instance, stream_id in txn
+ )
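# An illustrative sketch of consuming the reworked namedtuple, assuming
# `entries` is the frozenset built above; the helper name is hypothetical.
def rooms_by_join_position(entries):
    # Each entry now carries a PersistedEventPosition in `event_pos` rather
    # than a bare stream_ordering integer.
    return {entry.room_id: entry.event_pos for entry in entries}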
async def get_users_server_still_shares_room_with(
self, user_ids: Collection[str]
@@ -819,7 +823,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
@@ -973,7 +977,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomMemberStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
diff --git a/synapse/storage/databases/main/schema/delta/56/event_labels.sql b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
index 5e29c1da19..ccf287971c 100644
--- a/synapse/storage/databases/main/schema/delta/56/event_labels.sql
+++ b/synapse/storage/databases/main/schema/delta/56/event_labels.sql
@@ -13,7 +13,7 @@
* limitations under the License.
*/
--- room_id and topoligical_ordering are denormalised from the events table in order to
+-- room_id and topological_ordering are denormalised from the events table in order to
-- make the index work.
CREATE TABLE IF NOT EXISTS event_labels (
event_id TEXT,
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
new file mode 100644
index 0000000000..b64926e9c9
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -0,0 +1,33 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the thumbnail method to the unique key constraint of the thumbnail
+ * tables. Otherwise you can't have a scaled and a cropped thumbnail with the
+ * same resolution, which happens quite often with dynamic thumbnailing.
+ * This is the Postgres-specific migration, which performs the change via
+ * background updates.
+ */
+
+-- add new index that includes method to local media
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('local_media_repository_thumbnails_method_idx', '{}');
+
+-- add new index that includes method to remote media
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+
+-- drop old index
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
new file mode 100644
index 0000000000..1d0c04b53a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
@@ -0,0 +1,44 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the thumbnail method to the unique key constraint of the thumbnail
+ * tables. Otherwise you can't have a scaled and a cropped thumbnail with the
+ * same resolution, which happens quite often with dynamic thumbnailing.
+ * This is the SQLite-specific migration: SQLite can't modify the unique
+ * constraint of a table without recreating it.
+ */
+
+CREATE TABLE local_media_repository_thumbnails_new ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO local_media_repository_thumbnails_new
+ SELECT media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method, thumbnail_length
+ FROM local_media_repository_thumbnails;
+
+DROP TABLE local_media_repository_thumbnails;
+
+ALTER TABLE local_media_repository_thumbnails_new RENAME TO local_media_repository_thumbnails;
+
+CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id);
+
+
+
+CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails_new ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO remote_media_cache_thumbnails_new
+ SELECT media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_method, thumbnail_type, thumbnail_length, filesystem_id
+ FROM remote_media_cache_thumbnails;
+
+DROP TABLE remote_media_cache_thumbnails;
+
+ALTER TABLE remote_media_cache_thumbnails_new RENAME TO remote_media_cache_thumbnails;
diff --git a/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
new file mode 100644
index 0000000000..847aebd85e
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ Delete stuck 'enabled' bits that correspond to deleted or non-existent push rules.
+ We ignore rules that are server-default rules because they are not defined
+ in the `push_rules` table.
+**/
+
+DELETE FROM push_rules_enable WHERE
+ rule_id NOT LIKE 'global/%/.m.rule.%'
+ AND NOT EXISTS (
+ SELECT 1 FROM push_rules
+ WHERE push_rules.user_name = push_rules_enable.user_name
+ AND push_rules.rule_id = push_rules_enable.rule_id
+ );
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
new file mode 100644
index 0000000000..98ff76d709
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE events ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
new file mode 100644
index 0000000000..c31f9af82a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
+
+SELECT setval('events_stream_seq', (
+ SELECT COALESCE(MAX(stream_ordering), 1) FROM events
+));
+
+CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+
+-- If the server has never backfilled a room then `-MIN(...)` would give a
+-- negative result, hence the `GREATEST(...)`.
+SELECT setval('events_backfill_stream_seq', (
+ SELECT GREATEST(COALESCE(-MIN(stream_ordering), 1), 1) FROM events
+));
diff --git a/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
new file mode 100644
index 0000000000..ebfbed7925
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
@@ -0,0 +1,42 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This schema delta alters the schema to enable 'catching up' remote homeservers
+-- after there has been a connectivity problem for any reason.
+
+-- This stores, for each (destination, room) pair, the stream_ordering of the
+-- latest event in that room that needs to be sent to that destination.
+CREATE TABLE IF NOT EXISTS destination_rooms (
+ -- the destination in question.
+ destination TEXT NOT NULL REFERENCES destinations (destination),
+ -- the ID of the room in question
+ room_id TEXT NOT NULL REFERENCES rooms (room_id),
+ -- the stream_ordering of the event
+ stream_ordering BIGINT NOT NULL,
+ PRIMARY KEY (destination, room_id)
+ -- We don't declare a foreign key on stream_ordering here because that'd mean
+ -- we'd need to either maintain an index (expensive) or do a table scan of
+ -- destination_rooms whenever we delete an event (also potentially expensive).
+ -- In addition to that, a foreign key on stream_ordering would be redundant
+ -- as this row doesn't need to refer to a specific event; if the event gets
+ -- deleted then it doesn't affect the validity of the stream_ordering here.
+);
+
+-- This index is needed to make it so that a deletion of a room (in the rooms
+-- table) can be efficient, as otherwise a table scan would need to be performed
+-- to check that no destination_rooms rows point to the room to be deleted.
+-- Also: it makes it efficient to delete all the entries for a given room ID,
+-- such as when purging a room.
+CREATE INDEX IF NOT EXISTS destination_rooms_room_id
+ ON destination_rooms (room_id);
diff --git a/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
new file mode 100644
index 0000000000..55f5d0f732
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This delta file fixes a regression introduced by 58/12room_stats.sql: it removes the hacky
+-- populate_stats_process_rooms_2 background job and restores the functionality under the
+-- original name.
+-- See https://github.com/matrix-org/synapse/issues/8238 for details
+
+DELETE FROM background_updates WHERE update_name = 'populate_stats_process_rooms';
+UPDATE background_updates SET update_name = 'populate_stats_process_rooms'
+ WHERE update_name = 'populate_stats_process_rooms_2';
diff --git a/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
new file mode 100644
index 0000000000..a67aa5e500
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
@@ -0,0 +1,21 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This column tracks the stream_ordering of the event that was most recently
+-- successfully transmitted to the destination.
+-- A value of NULL means that we have not sent an event successfully yet
+-- (at least, not since the introduction of this column).
+ALTER TABLE destinations
+ ADD COLUMN last_successful_stream_ordering BIGINT;
diff --git a/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
new file mode 100644
index 0000000000..985fd949a2
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/18stream_positions.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE stream_positions (
+ stream_name TEXT NOT NULL,
+ instance_name TEXT NOT NULL,
+ stream_id BIGINT NOT NULL
+);
+
+CREATE UNIQUE INDEX stream_positions_idx ON stream_positions(stream_name, instance_name);
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 9c5f0229c1..141207fb16 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -89,7 +89,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
if not hs.config.enable_search:
return
@@ -342,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SearchStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5c6168e301..3c1e33819b 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -56,7 +56,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
"""Get the room_version of a given room
@@ -320,7 +320,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -506,4 +506,4 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 55a250ef06..5beb302be3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -61,7 +61,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
class StatsStore(StateDeltasStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StatsStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.clock = self.hs.get_clock()
@@ -74,9 +74,6 @@ class StatsStore(StateDeltasStore):
"populate_stats_process_rooms", self._populate_stats_process_rooms
)
self.db_pool.updates.register_background_update_handler(
- "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2
- )
- self.db_pool.updates.register_background_update_handler(
"populate_stats_process_users", self._populate_stats_process_users
)
# we no longer need to perform clean-up, but we will give ourselves
@@ -148,31 +145,10 @@ class StatsStore(StateDeltasStore):
return len(users_to_work_on)
async def _populate_stats_process_rooms(self, progress, batch_size):
- """
- This was a background update which regenerated statistics for rooms.
-
- It has been replaced by StatsStore._populate_stats_process_rooms_2. This background
- job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure
- someone upgrading from <v1.0.0, this background task has been turned into a no-op
- so that the potentially expensive task is not run twice.
-
- Further context: https://github.com/matrix-org/synapse/pull/7977
- """
- await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms"
- )
- return 1
-
- async def _populate_stats_process_rooms_2(self, progress, batch_size):
- """
- This is a background update which regenerates statistics for rooms.
-
- It replaces StatsStore._populate_stats_process_rooms. See its docstring for the
- reasoning.
- """
+ """This is a background update which regenerates statistics for rooms."""
if not self.stats_enabled:
await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms_2"
+ "populate_stats_process_rooms"
)
return 1
@@ -189,13 +165,13 @@ class StatsStore(StateDeltasStore):
return [r for r, in txn]
rooms_to_work_on = await self.db_pool.runInteraction(
- "populate_stats_rooms_2_get_batch", _get_next_batch
+ "populate_stats_rooms_get_batch", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms_2"
+ "populate_stats_process_rooms"
)
return 1
@@ -204,9 +180,9 @@ class StatsStore(StateDeltasStore):
progress["last_room_id"] = room_id
await self.db_pool.runInteraction(
- "_populate_stats_process_rooms_2",
+ "_populate_stats_process_rooms",
self.db_pool.updates._background_update_progress_txn,
- "populate_stats_process_rooms_2",
+ "populate_stats_process_rooms",
progress,
)
@@ -234,6 +210,7 @@ class StatsStore(StateDeltasStore):
* topic
* avatar
* canonical_alias
+ * guest_access
A is_federatable key can also be included with a boolean value.
@@ -258,6 +235,7 @@ class StatsStore(StateDeltasStore):
"topic",
"avatar",
"canonical_alias",
+ "guest_access",
):
field = fields.get(col, sentinel)
if field is not sentinel and (not isinstance(field, str) or "\0" in field):
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index db20a3db30..37249f1e3f 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -35,7 +35,6 @@ what sort order was used:
- topological tokems: "t%d-%d", where the integers map to the topological
and stream ordering columns respectively.
"""
-
import abc
import logging
from collections import namedtuple
@@ -54,7 +53,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.types import Collection, RoomStreamToken
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
@@ -79,8 +78,8 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
- from_token: Optional[Tuple[int, int]],
- to_token: Optional[Tuple[int, int]],
+ from_token: Optional[Tuple[Optional[int], int]],
+ to_token: Optional[Tuple[Optional[int], int]],
engine: BaseDatabaseEngine,
) -> str:
"""Creates an SQL expression to bound the columns by the pagination
@@ -259,16 +258,14 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
return " AND ".join(clauses), args
-class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_room_max_stream_ordering` and `get_room_min_stream_ordering`
which can be called in the initializer.
"""
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
- super(StreamWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
self._send_federation = hs.should_send_federation()
@@ -307,14 +304,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()
+ def get_room_max_token(self) -> RoomStreamToken:
+ return RoomStreamToken(None, self.get_room_max_stream_ordering())
+
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
- from_key: str,
- to_key: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
- ) -> Dict[str, Tuple[List[EventBase], str]]:
+ ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -333,9 +333,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
"""
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
-
- room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
+ room_ids = self._events_stream_cache.get_entities_changed(
+ room_ids, from_key.stream
+ )
if not room_ids:
return {}
@@ -364,16 +364,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def get_rooms_that_changed(
- self, room_ids: Collection[str], from_key: str
+ self, room_ids: Collection[str], from_key: RoomStreamToken
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
-
- Args:
- room_ids
- from_key: The room_key portion of a StreamToken
"""
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ from_id = from_key.stream
return {
room_id
for room_id in room_ids
@@ -383,11 +379,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_room(
self,
room_id: str,
- from_key: str,
- to_key: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
- ) -> Tuple[List[EventBase], str]:
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -408,8 +404,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if from_key == to_key:
return [], from_key
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ from_id = from_key.stream
+ to_id = to_key.stream
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
@@ -441,7 +437,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
- key = "s%d" % min(r.stream_ordering for r in rows)
+ key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -450,10 +446,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
- self, user_id: str, from_key: str, to_key: str
+ self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ from_id = from_key.stream
+ to_id = to_key.stream
if from_key == to_key:
return []
@@ -491,8 +487,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
async def get_recent_events_for_room(
- self, room_id: str, limit: int, end_token: str
- ) -> Tuple[List[EventBase], str]:
+ self, room_id: str, limit: int, end_token: RoomStreamToken
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@@ -518,8 +514,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
async def get_recent_event_ids_for_room(
- self, room_id: str, limit: int, end_token: str
- ) -> Tuple[List[_EventDictReturn], str]:
+ self, room_id: str, limit: int, end_token: RoomStreamToken
+ ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@@ -535,8 +531,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0:
return [], end_token
- end_token = RoomStreamToken.parse(end_token)
-
rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
@@ -619,26 +613,28 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
allow_none=allow_none,
)
- async def get_stream_token_for_event(self, event_id: str) -> str:
- """The stream token for an event
- Args:
- event_id: The id of the event to look up a stream token for.
- Raises:
- StoreError if the event wasn't in the database.
- Returns:
- A "s%d" stream token.
+ async def get_position_for_event(self, event_id: str) -> PersistedEventPosition:
+ """Get the persisted position for an event
"""
- stream_id = await self.get_stream_id_for_event(event_id)
- return "s%d" % (stream_id,)
+ row = await self.db_pool.simple_select_one(
+ table="events",
+ keyvalues={"event_id": event_id},
+ retcols=("stream_ordering", "instance_name"),
+ desc="get_position_for_event",
+ )
- async def get_topological_token_for_event(self, event_id: str) -> str:
+ return PersistedEventPosition(
+ row["instance_name"] or "master", row["stream_ordering"]
+ )
+
+ async def get_topological_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A "t%d-%d" topological token.
+ A `RoomStreamToken` topological token.
"""
row = await self.db_pool.simple_select_one(
table="events",
@@ -646,7 +642,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=("stream_ordering", "topological_ordering"),
desc="get_topological_token_for_event",
)
- return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
+ return RoomStreamToken(row["topological_ordering"], row["stream_ordering"])
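# A hedged usage sketch, assuming a worker store exposing the two accessors
# above; `store` and `event_id` are placeholders.
async def describe_event(store, event_id):
    # The former returns a PersistedEventPosition, the latter a RoomStreamToken.
    pos = await store.get_position_for_event(event_id)
    topo_token = await store.get_topological_token_for_event(event_id)
    return pos, topo_token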
async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
"""Gets the topological token in a room after or at the given stream
@@ -695,8 +691,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
else:
topo = None
internal = event.internal_metadata
- internal.before = str(RoomStreamToken(topo, stream - 1))
- internal.after = str(RoomStreamToken(topo, stream))
+ internal.before = RoomStreamToken(topo, stream - 1)
+ internal.after = RoomStreamToken(topo, stream)
internal.order = (int(topo) if topo else 0, int(stream))
async def get_events_around(
@@ -951,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
- ) -> Tuple[List[_EventDictReturn], str]:
+ ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@@ -986,8 +982,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token,
- to_token=to_token,
+ from_token=from_token.as_tuple(),
+ to_token=to_token.as_tuple() if to_token else None,
engine=self.database_engine,
)
@@ -1051,17 +1047,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
- return rows, str(next_token)
+ return rows, next_token
async def paginate_room_events(
self,
room_id: str,
- from_key: str,
- to_key: Optional[str] = None,
+ from_key: RoomStreamToken,
+ to_key: Optional[RoomStreamToken] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
- ) -> Tuple[List[EventBase], str]:
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@@ -1080,10 +1076,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`).
"""
- from_key = RoomStreamToken.parse(from_key)
- if to_key:
- to_key = RoomStreamToken.parse(to_key)
-
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 96ffe26cc9..9f120d3cb6 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 5b31aab700..97aed1500e 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -15,13 +15,14 @@
import logging
from collections import namedtuple
-from typing import Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
@@ -47,7 +48,7 @@ class TransactionStore(SQLBaseStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(TransactionStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
@@ -164,7 +165,9 @@ class TransactionStore(SQLBaseStore):
allow_none=True,
)
- if result and result["retry_last_ts"] > 0:
+ # check we have a row and that retry_last_ts is neither NULL nor zero
+ # (retry_last_ts can't be negative)
+ if result and result["retry_last_ts"]:
return result
else:
return None
@@ -215,6 +218,7 @@ class TransactionStore(SQLBaseStore):
retry_interval = EXCLUDED.retry_interval
WHERE
EXCLUDED.retry_interval = 0
+ OR destinations.retry_interval IS NULL
OR destinations.retry_interval < EXCLUDED.retry_interval
"""
@@ -246,7 +250,11 @@ class TransactionStore(SQLBaseStore):
"retry_interval": retry_interval,
},
)
- elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
+ elif (
+ retry_interval == 0
+ or prev_row["retry_interval"] is None
+ or prev_row["retry_interval"] < retry_interval
+ ):
self.db_pool.simple_update_one_txn(
txn,
"destinations",
@@ -273,3 +281,196 @@ class TransactionStore(SQLBaseStore):
await self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)
+
+ async def store_destination_rooms_entries(
+ self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+ ) -> None:
+ """
+ Updates or creates `destination_rooms` entries in batch for a single event.
+
+ Args:
+ destinations: list of destinations
+ room_id: the room_id of the event
+ stream_ordering: the stream_ordering of the event
+ """
+
+ return await self.db_pool.runInteraction(
+ "store_destination_rooms_entries",
+ self._store_destination_rooms_entries_txn,
+ destinations,
+ room_id,
+ stream_ordering,
+ )
+
+ def _store_destination_rooms_entries_txn(
+ self,
+ txn: LoggingTransaction,
+ destinations: Iterable[str],
+ room_id: str,
+ stream_ordering: int,
+ ) -> None:
+
+ # ensure we have a `destinations` row for this destination, as there is
+ # a foreign key constraint.
+ if isinstance(self.database_engine, PostgresEngine):
+ q = """
+ INSERT INTO destinations (destination)
+ VALUES (?)
+ ON CONFLICT DO NOTHING;
+ """
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ q = """
+ INSERT OR IGNORE INTO destinations (destination)
+ VALUES (?);
+ """
+ else:
+ raise RuntimeError("Unknown database engine")
+
+ txn.execute_batch(q, ((destination,) for destination in destinations))
+
+ rows = [(destination, room_id) for destination in destinations]
+
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ "destination_rooms",
+ ["destination", "room_id"],
+ rows,
+ ["stream_ordering"],
+ [(stream_ordering,)] * len(rows),
+ )
+
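# Illustrative only: record that the event at `stream_ordering` still needs to
# reach these destinations. All values below are made up.
async def mark_event_for_destinations(store):
    await store.store_destination_rooms_entries(
        destinations=["remote1.example.com", "remote2.example.com"],
        room_id="!abcdefg:example.com",
        stream_ordering=1234,
    )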
+ async def get_destination_last_successful_stream_ordering(
+ self, destination: str
+ ) -> Optional[int]:
+ """
+ Gets the stream ordering of the PDU most-recently successfully sent
+ to the specified destination, or None if this information has not been
+ tracked yet.
+
+ Args:
+ destination: the destination to query
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ "destinations",
+ {"destination": destination},
+ "last_successful_stream_ordering",
+ allow_none=True,
+ desc="get_last_successful_stream_ordering",
+ )
+
+ async def set_destination_last_successful_stream_ordering(
+ self, destination: str, last_successful_stream_ordering: int
+ ) -> None:
+ """
+ Marks that we have successfully sent the PDUs up to and including the
+ one specified.
+
+ Args:
+ destination: the destination we have successfully sent to
+ last_successful_stream_ordering: the stream_ordering of the most
+ recent successfully-sent PDU
+ """
+ return await self.db_pool.simple_upsert(
+ "destinations",
+ keyvalues={"destination": destination},
+ values={"last_successful_stream_ordering": last_successful_stream_ordering},
+ desc="set_last_successful_stream_ordering",
+ )
+
+ async def get_catch_up_room_event_ids(
+ self, destination: str, last_successful_stream_ordering: int,
+ ) -> List[str]:
+ """
+ Returns at most 50 event IDs, corresponding to the oldest events that
+ have not yet been sent to the destination.
+
+ Args:
+ destination: the destination in question
+ last_successful_stream_ordering: the stream_ordering of the
+ most-recently successfully-transmitted event to the destination
+
+ Returns:
+ list of event_ids
+ """
+ return await self.db_pool.runInteraction(
+ "get_catch_up_room_event_ids",
+ self._get_catch_up_room_event_ids_txn,
+ destination,
+ last_successful_stream_ordering,
+ )
+
+ @staticmethod
+ def _get_catch_up_room_event_ids_txn(
+ txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
+ ) -> List[str]:
+ q = """
+ SELECT event_id FROM destination_rooms
+ JOIN events USING (stream_ordering)
+ WHERE destination = ?
+ AND stream_ordering > ?
+ ORDER BY stream_ordering
+ LIMIT 50
+ """
+ txn.execute(
+ q, (destination, last_successful_stream_ordering),
+ )
+ event_ids = [row[0] for row in txn]
+ return event_ids
+
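# A hedged sketch of how catch-up might consume the helpers above: fetch the
# next chunk of events for a destination, starting from the last stream
# ordering known to have been delivered (0 if none). The function name is
# hypothetical.
async def events_to_catch_up(store, destination):
    last = await store.get_destination_last_successful_stream_ordering(destination)
    return await store.get_catch_up_room_event_ids(destination, last or 0)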
+ async def get_catch_up_outstanding_destinations(
+ self, after_destination: Optional[str]
+ ) -> List[str]:
+ """
+ Gets at most 25 destinations which have outstanding PDUs to be caught up
+ on, and which are not currently being backed off from.
+
+ Args:
+ after_destination:
+ If provided, all destinations must be lexicographically greater
+ than this one.
+
+ Returns:
+ list of up to 25 destinations with outstanding catch-up.
+ These are the lexicographically first destinations which are
+ lexicographically greater than after_destination (if provided).
+ """
+ time = self.hs.get_clock().time_msec()
+
+ return await self.db_pool.runInteraction(
+ "get_catch_up_outstanding_destinations",
+ self._get_catch_up_outstanding_destinations_txn,
+ time,
+ after_destination,
+ )
+
+ @staticmethod
+ def _get_catch_up_outstanding_destinations_txn(
+ txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
+ ) -> List[str]:
+ q = """
+ SELECT destination FROM destinations
+ WHERE destination IN (
+ SELECT destination FROM destination_rooms
+ WHERE destination_rooms.stream_ordering >
+ destinations.last_successful_stream_ordering
+ )
+ AND destination > ?
+ AND (
+ retry_last_ts IS NULL OR
+ retry_last_ts + retry_interval < ?
+ )
+ ORDER BY destination
+ LIMIT 25
+ """
+ txn.execute(
+ q,
+ (
+ # everything is lexicographically greater than "" so this gives
+ # us the first batch of up to 25.
+ after_destination or "",
+ now_time_ms,
+ ),
+ )
+
+ destinations = [row[0] for row in txn]
+ return destinations
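# An illustrative paging loop over the 25-destination batches returned above;
# the function name is hypothetical.
async def all_outstanding_destinations(store):
    results = []
    after = None
    while True:
        batch = await store.get_catch_up_outstanding_destinations(after)
        if not batch:
            break
        results.extend(batch)
        # Batches come back in lexicographic order, so the last entry is the
        # cursor for the next page.
        after = batch[-1]
    return results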
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index b89668d561..3b9211a6d2 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -23,7 +23,7 @@ from synapse.types import JsonDict
from synapse.util import json_encoder, stringutils
-@attr.s
+@attr.s(slots=True)
class UIAuthSessionData:
session_id = attr.ib(type=str)
# The dictionary from the client root level, not the 'auth' key.
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f2f9a5799a..5a390ff2f6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -38,7 +38,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, database: DatabasePool, db_conn, hs):
- super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -564,7 +564,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, database: DatabasePool, db_conn, hs):
- super(UserDirectoryStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 2f7c95fc74..f9575b1f1f 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore):
return
# They are there, delete them.
- self.simple_delete_one_txn(
+ self.db_pool.simple_delete_one_txn(
txn, "erased_users", keyvalues={"user_id": user_id}
)
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 139085b672..acb24e33af 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -181,7 +181,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index e924f1ca3b..0e31cc811a 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -24,7 +24,7 @@ from synapse.storage.databases.state.bg_updates import StateBackgroundUpdateStor
from synapse.storage.state import StateFilter
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
from synapse.util.caches.descriptors import cached
from synapse.util.caches.dictionary_cache import DictionaryCache
@@ -52,7 +52,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
# event IDs for the state types in a given state group to avoid hammering
@@ -99,6 +99,9 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
self._state_group_seq_gen = build_sequence_generator(
self.database_engine, get_max_state_group_txn, "state_group_id_seq"
)
+ self._state_group_seq_gen.check_consistency(
+ db_conn, table="state_groups", id_column="id"
+ )
@cached(max_entries=10000, iterable=True)
async def get_state_group_delta(self, state_group):
@@ -205,7 +208,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
async def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
- ) -> Dict[int, StateMap[str]]:
+ ) -> Dict[int, MutableStateMap[str]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index dbaeef91dd..72939f3984 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,7 +18,7 @@
import itertools
import logging
from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Histogram
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import StateMap
+from synapse.types import Collection, PersistedEventPosition, RoomStreamToken, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -185,18 +185,21 @@ class EventsPersistenceStorage:
# store for now.
self.main_store = stores.main
self.state_store = stores.state
+
+ assert stores.persist_events
self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
self.is_mine_id = hs.is_mine_id
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
async def persist_events(
self,
- events_and_contexts: List[Tuple[EventBase, EventContext]],
+ events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
backfilled: bool = False,
- ) -> int:
+ ) -> RoomStreamToken:
"""
Write events to the database
Args:
@@ -208,7 +211,7 @@ class EventsPersistenceStorage:
Returns:
the stream ordering of the latest persisted event
"""
- partitioned = {}
+ partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
@@ -226,11 +229,11 @@ class EventsPersistenceStorage:
defer.gatherResults(deferreds, consumeErrors=True)
)
- return self.main_store.get_current_events_token()
+ return self.main_store.get_room_max_token()
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
- ) -> Tuple[int, int]:
+ ) -> Tuple[PersistedEventPosition, RoomStreamToken]:
"""
Returns:
The stream ordering of `event`, and the stream ordering of the
@@ -244,8 +247,10 @@ class EventsPersistenceStorage:
await make_deferred_yieldable(deferred)
- max_persisted_id = self.main_store.get_current_events_token()
- return (event.internal_metadata.stream_ordering, max_persisted_id)
+ event_stream_id = event.internal_metadata.stream_ordering
+
+ pos = PersistedEventPosition(self._instance_name, event_stream_id)
+ return pos, self.main_store.get_room_max_token()
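# A hedged sketch of a caller adapting to the new return type: persist_event
# now yields a PersistedEventPosition for the event plus the room's maximum
# stream token, rather than two bare integers. The wrapper name is hypothetical.
async def persist_and_report(persistence, event, context):
    pos, max_token = await persistence.persist_event(event, context)
    return pos, max_token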
def _maybe_start_persisting(self, room_id: str):
async def persisting_queue(item):
@@ -305,7 +310,9 @@ class EventsPersistenceStorage:
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
- events_by_room = {}
+ events_by_room = (
+ {}
+ ) # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
@@ -436,7 +443,7 @@ class EventsPersistenceStorage:
self,
room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]],
- latest_event_ids: List[str],
+ latest_event_ids: Collection[str],
):
"""Calculates the new forward extremities for a room given events to
persist.
@@ -470,7 +477,7 @@ class EventsPersistenceStorage:
# Remove any events which are prev_events of any existing events.
existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
result
- )
+ ) # type: Collection[str]
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index ee60e2a718..4957e77f4c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -19,12 +19,15 @@ import logging
import os
import re
from collections import Counter
-from typing import TextIO
+from typing import Optional, TextIO
import attr
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -63,7 +66,12 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
)
-def prepare_database(db_conn, database_engine, config, databases=["main", "state"]):
+def prepare_database(
+ db_conn: Connection,
+ database_engine: BaseDatabaseEngine,
+ config: Optional[HomeServerConfig],
+ databases: Collection[str] = ["main", "state"],
+):
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -73,16 +81,24 @@ def prepare_database(db_conn, database_engine, config, databases=["main", "state
Args:
db_conn:
database_engine:
- config (synapse.config.homeserver.HomeServerConfig|None):
+ config:
application config, or None if we are connecting to an existing
database which we expect to be configured already
- databases (list[str]): The name of the databases that will be used
+ databases: The names of the databases that will be used
with this physical database. Defaults to all databases.
"""
try:
cur = db_conn.cursor()
+ # sqlite does not automatically start transactions for DDL / SELECT statements,
+ # so we start one before running anything. This ensures that any upgrades
+ # are either applied completely, or not at all.
+ #
+ # (psycopg2 automatically starts a transaction as soon as we run any statements
+ # at all, so this is redundant but harmless there.)
+ cur.execute("BEGIN TRANSACTION")
+
logger.info("%r: Checking existing schema version", databases)
version_info = _get_or_create_schema_state(cur, database_engine)
@@ -622,7 +638,7 @@ def _get_or_create_schema_state(txn, database_engine):
return None
-@attr.s()
+@attr.s(slots=True)
class _DirectoryListing:
"""Helper class to store schema file name and the
absolute path to it.
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index d30e3f11e7..cec96ad6a7 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -22,7 +22,7 @@ from synapse.api.errors import SynapseError
logger = logging.getLogger(__name__)
-@attr.s
+@attr.s(slots=True)
class PaginationChunk:
"""Returned by relation pagination APIs.
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 8c4a83a840..f152f63321 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
)
GetRoomsForUserWithStreamOrdering = namedtuple(
- "_GetRoomsForUserWithStreamOrdering", ("room_id", "stream_ordering")
+ "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
)
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 8f68d968f0..08a69f2f96 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -20,7 +20,7 @@ import attr
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
logger = logging.getLogger(__name__)
@@ -349,7 +349,7 @@ class StateGroupStorage:
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
- ) -> Dict[int, StateMap[str]]:
+ ) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
@@ -532,7 +532,7 @@ class StateGroupStorage:
def _get_state_for_groups(
self, groups: Iterable[int], state_filter: StateFilter = StateFilter.all()
- ) -> Awaitable[Dict[int, StateMap[str]]]:
+ ) -> Awaitable[Dict[int, MutableStateMap[str]]]:
"""Gets the state at each of a list of state groups, optionally
filtering by type/state_key
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b7eb4f8ac9..02fbb656e8 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,16 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import contextlib
import heapq
import logging
import threading
from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
+import attr
from typing_extensions import Deque
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.util.sequence import PostgresSequenceGenerator
@@ -86,7 +87,7 @@ class StreamIdGenerator:
upwards, -1 to grow downwards.
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -101,10 +102,10 @@ class StreamIdGenerator:
)
self._unfinished_ids = deque() # type: Deque[int]
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -113,7 +114,7 @@ class StreamIdGenerator:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_id
@@ -121,12 +122,12 @@ class StreamIdGenerator:
with self._lock:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
- async def get_next_mult(self, n):
+ def get_next_mult(self, n):
"""
Usage:
- with await stream_id_gen.get_next(n) as stream_ids:
+ async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -140,7 +141,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_ids
@@ -149,7 +150,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
@@ -184,12 +185,16 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
+ stream_name: A name for the stream.
instance_name: The name of this instance.
table: Database table associated with stream.
instance_column: Column that stores the row's writer's instance name
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
+ writers: A list of known writers to use to populate current positions
+ on startup. Can be empty if nothing uses `get_current_token` or
+ `get_positions` (e.g. caches stream).
positive: Whether the IDs are positive (true) or negative (false).
When using negative IDs we go backwards from -1 to -2, -3, etc.
"""
@@ -198,16 +203,20 @@ class MultiWriterIdGenerator:
self,
db_conn,
db: DatabasePool,
+ stream_name: str,
instance_name: str,
table: str,
instance_column: str,
id_column: str,
sequence_name: str,
+ writers: List[str],
positive: bool = True,
):
self._db = db
+ self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
+ self._writers = writers
self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
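
For orientation, a sketch of how a caller might construct the generator with the new `stream_name`/`writers` arguments; the stream, table, and sequence names below are placeholders for illustration, not real Synapse schema:

    from typing import List

    from synapse.storage.database import DatabasePool
    from synapse.storage.util.id_generators import MultiWriterIdGenerator


    def make_example_id_gen(
        db_conn, db: DatabasePool, instance_name: str, writers: List[str]
    ) -> MultiWriterIdGenerator:
        # Hypothetical wiring for a stream with a known set of writers.
        return MultiWriterIdGenerator(
            db_conn,
            db,
            stream_name="example_stream",
            instance_name=instance_name,
            table="example_stream_rows",
            instance_column="instance_name",
            id_column="stream_id",
            sequence_name="example_stream_seq",
            writers=writers,
        )
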
@@ -216,14 +225,16 @@ class MultiWriterIdGenerator:
# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
- self._current_positions = self._load_current_ids(
- db_conn, table, instance_column, id_column
- )
+ self._current_positions = {} # type: Dict[str, int]
# Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # Set of local IDs that we've processed that are larger than the current
+ # position, due to there being smaller unpersisted IDs.
+ self._finished_ids = set() # type: Set[int]
+
# We track the max position where we know everything before has been
# persisted. This is done by a) looking at the min across all instances
# and b) noting that if we have seen a run of persisted positions
@@ -236,37 +247,113 @@ class MultiWriterIdGenerator:
# gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of
# positions.
+ #
+ # We start at 1 here as a) the first generated stream ID will be 2, and
+ # b) other parts of the code assume that stream IDs are strictly greater
+ # than 0.
self._persisted_upto_position = (
- min(self._current_positions.values()) if self._current_positions else 0
+ min(self._current_positions.values()) if self._current_positions else 1
)
self._known_persisted_positions = [] # type: List[int]
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+ # We check that the table and sequence haven't diverged.
+ self._sequence_gen.check_consistency(
+ db_conn, table=table, id_column=id_column, positive=positive
+ )
+
+ # This goes and fills out the above state from the database.
+ self._load_current_ids(db_conn, table, instance_column, id_column)
+
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
- ) -> Dict[str, int]:
- # If positive stream aggregate via MAX. For negative stream use MIN
- # *and* negate the result to get a positive number.
- sql = """
- SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
- GROUP BY %(instance)s
- """ % {
- "instance": instance_column,
- "id": id_column,
- "table": table,
- "agg": "MAX" if self._positive else "-MIN",
- }
-
+ ):
cur = db_conn.cursor()
- cur.execute(sql)
- # `cur` is an iterable over returned rows, which are 2-tuples.
- current_positions = dict(cur)
+ # Load the current positions of all writers for the stream.
+ if self._writers:
+ # We delete any stale entries in the positions table. This is
+ # important if we add back a writer after a long time; we want to
+ # consider that a "new" writer, rather than using the old stale
+ # entry here.
+ sql = """
+ DELETE FROM stream_positions
+ WHERE
+ stream_name = ?
+ AND instance_name != ALL(?)
+ """
+ sql = self._db.engine.convert_param_style(sql)
+ cur.execute(sql, (self._stream_name, self._writers))
+
+ sql = """
+ SELECT instance_name, stream_id FROM stream_positions
+ WHERE stream_name = ?
+ """
+ sql = self._db.engine.convert_param_style(sql)
+
+ cur.execute(sql, (self._stream_name,))
+
+ self._current_positions = {
+ instance: stream_id * self._return_factor
+ for instance, stream_id in cur
+ if instance in self._writers
+ }
- cur.close()
+ # We set the `_persisted_upto_position` to be the minimum of all current
+ # positions. If empty we use the max stream ID from the DB table.
+ min_stream_id = min(self._current_positions.values(), default=None)
+
+ if min_stream_id is None:
+ # We add a GREATEST here to ensure that the result is always
+ # positive. (This can be a problem for e.g. backfill streams where
+ # the server has never backfilled).
+ sql = """
+ SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ FROM %(table)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if self._positive else "-MIN",
+ }
+ cur.execute(sql)
+ (stream_id,) = cur.fetchone()
+ self._persisted_upto_position = stream_id
+ else:
+ # If we have a min_stream_id then we pull out everything greater
+ # than it from the DB so that we can prefill
+ # `_known_persisted_positions` and get a more accurate
+ # `_persisted_upto_position`.
+ #
+ # We also check if any of the later rows are from this instance, in
+ # which case we use that for this instance's current position. This
+ # is to handle the case where we didn't finish persisting to the
+ # stream positions table before restart (or the stream position
+ # table otherwise got out of date).
+
+ sql = """
+ SELECT %(instance)s, %(id)s FROM %(table)s
+ WHERE ? %(cmp)s %(id)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "instance": instance_column,
+ "cmp": "<=" if self._positive else ">=",
+ }
+ sql = self._db.engine.convert_param_style(sql)
+ cur.execute(sql, (min_stream_id,))
- return current_positions
+ self._persisted_upto_position = min_stream_id
+
+ with self._lock:
+ for (instance, stream_id,) in cur:
+ stream_id = self._return_factor * stream_id
+ self._add_persisted_position(stream_id)
+
+ if instance == self._instance_name:
+ self._current_positions[instance] = stream_id
+
+ cur.close()
def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)
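
A compact model of the fallback rule implemented above, for illustration only (it sidesteps the SQL, sign handling, and the `_known_persisted_positions` prefill):

    from typing import Dict, Optional


    def initial_persisted_upto(
        current_positions: Dict[str, int], table_max_id: Optional[int]
    ) -> int:
        """Simplified model of how the starting position is picked at startup."""
        min_stream_id = min(current_positions.values(), default=None)
        if min_stream_id is None:
            # No writer positions recorded: fall back to the largest ID already
            # in the stream table, clamped to at least 1 (mirroring the
            # GREATEST/COALESCE in the query above).
            return max(table_max_id or 1, 1)
        return min_stream_id


    assert initial_persisted_upto({}, None) == 1
    assert initial_persisted_upto({"worker1": 10, "worker2": 7}, 12) == 7
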
@@ -274,59 +361,23 @@ class MultiWriterIdGenerator:
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
- # Assert the fetched ID is actually greater than what we currently
- # believe the ID to be. If not, then the sequence and table have got
- # out of sync somehow.
- with self._lock:
- assert self._current_positions.get(self._instance_name, 0) < next_id
-
- self._unfinished_ids.add(next_id)
+ return _MultiWriterCtxManager(self)
- @contextlib.contextmanager
- def manager():
- try:
- # Multiply by the return factor so that the ID has correct sign.
- yield self._return_factor * next_id
- finally:
- self._mark_id_as_finished(next_id)
-
- return manager()
-
- async def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int):
"""
Usage:
- with await stream_id_gen.get_next_mult(5) as stream_ids:
+ async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
- next_ids = await self._db.runInteraction(
- "_load_next_mult_id", self._load_next_mult_id_txn, n
- )
- # Assert the fetched ID is actually greater than any ID we've already
- # seen. If not, then the sequence and table have got out of sync
- # somehow.
- with self._lock:
- assert max(self._current_positions.values(), default=0) < min(next_ids)
-
- self._unfinished_ids.update(next_ids)
-
- @contextlib.contextmanager
- def manager():
- try:
- yield [self._return_factor * i for i in next_ids]
- finally:
- for i in next_ids:
- self._mark_id_as_finished(i)
-
- return manager()
+ return _MultiWriterCtxManager(self, n)
def get_next_txn(self, txn: LoggingTransaction):
"""
@@ -344,21 +395,63 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
+        # Update the `stream_positions` table with the newly allocated stream
+        # ID (unless self._writers is empty, in which case we don't bother,
+        # as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+ if self._writers:
+ txn.call_after(
+ run_as_background_process,
+ "MultiWriterIdGenerator._update_table",
+ self._db.runInteraction,
+ "MultiWriterIdGenerator._update_table",
+ self._update_stream_positions_table_txn,
+ )
+
return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
- current poistion if possible.
+ current position if possible.
"""
with self._lock:
self._unfinished_ids.discard(next_id)
+ self._finished_ids.add(next_id)
+
+ new_cur = None
+
+ if self._unfinished_ids:
+ # If there are unfinished IDs then the new position will be the
+ # largest finished ID less than the minimum unfinished ID.
+
+ finished = set()
+
+                min_unfinished = min(self._unfinished_ids)
+                for s in self._finished_ids:
+                    if s < min_unfinished:
+ if new_cur is None or new_cur < s:
+ new_cur = s
+ else:
+ finished.add(s)
+
+ # We clear these out since they're now all less than the new
+ # position.
+ self._finished_ids = finished
+ else:
+ # There are no unfinished IDs so the new position is simply the
+ # largest finished one.
+ new_cur = max(self._finished_ids)
+
+ # We clear these out since they're now all less than the new
+ # position.
+ self._finished_ids.clear()
- # Figure out if its safe to advance the position by checking there
- # aren't any lower allocated IDs that are yet to finish.
- if all(c > next_id for c in self._unfinished_ids):
+ if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
- self._current_positions[self._instance_name] = max(curr, next_id)
+ self._current_positions[self._instance_name] = max(curr, new_cur)
self._add_persisted_position(next_id)
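
A toy walk-through of the new gap handling, for illustration only (the helper below mirrors the logic of `_mark_id_as_finished` but is not the real implementation):

    from typing import Optional, Set


    def advance_position(
        finished: Set[int], unfinished: Set[int], current: int
    ) -> int:
        """Advance to the largest finished ID not blocked by a smaller
        unfinished one."""
        new_cur: Optional[int] = None
        if unfinished:
            blocker = min(unfinished)
            candidates = {s for s in finished if s < blocker}
            if candidates:
                new_cur = max(candidates)
        elif finished:
            new_cur = max(finished)
        return max(current, new_cur) if new_cur is not None else current


    # IDs 1-5 allocated; 3 is still in flight, so we can only advance to 2.
    assert advance_position({1, 2, 4, 5}, {3}, 0) == 2
    # Once 3 completes, everything is finished and we jump straight to 5.
    assert advance_position({1, 2, 3, 4, 5}, set(), 2) == 5
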
@@ -367,19 +460,28 @@ class MultiWriterIdGenerator:
equal to it have been successfully persisted.
"""
- # Currently we don't support this operation, as it's not obvious how to
- # condense the stream positions of multiple writers into a single int.
- raise NotImplementedError()
+ return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
"""
+ # If we don't have an entry for the given instance name, we assume it's a
+ # new writer.
+ #
+ # For new writers we assume their initial position to be the current
+ # persisted up to position. This stops Synapse from doing a full table
+ # scan when a new writer announces itself over replication.
with self._lock:
- return self._return_factor * self._current_positions.get(instance_name, 0)
+ return self._return_factor * self._current_positions.get(
+ instance_name, self._persisted_upto_position
+ )
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
+
+ Note that this won't necessarily include all configured writers if some
+ writers haven't written anything yet.
"""
with self._lock:
@@ -428,7 +530,7 @@ class MultiWriterIdGenerator:
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
- min_curr = min(self._current_positions.values())
+ min_curr = min(self._current_positions.values(), default=0)
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
@@ -449,3 +551,95 @@ class MultiWriterIdGenerator:
# There was a gap in seen positions, so there is nothing more to
# do.
break
+
+ def _update_stream_positions_table_txn(self, txn):
+        """Update the `stream_positions` table with the newly persisted position.
+ """
+
+ if not self._writers:
+ return
+
+ # We upsert the value, ensuring on conflict that we always increase the
+ # value (or decrease if stream goes backwards).
+ sql = """
+ INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+ VALUES (?, ?, ?)
+ ON CONFLICT (stream_name, instance_name)
+ DO UPDATE SET
+ stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+ """ % {
+ "agg": "GREATEST" if self._positive else "LEAST",
+ }
+
+        pos = self.get_current_token_for_writer(self._instance_name)
+ txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+ """Helper class to convert a plain context manager to an async one.
+
+ This is mainly useful if you have a plain context manager but the interface
+ requires an async one.
+ """
+
+ inner = attr.ib()
+
+ async def __aenter__(self):
+ return self.inner.__enter__()
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+ """Async context manager returned by MultiWriterIdGenerator
+ """
+
+ id_gen = attr.ib(type=MultiWriterIdGenerator)
+ multiple_ids = attr.ib(type=Optional[int], default=None)
+ stream_ids = attr.ib(type=List[int], factory=list)
+
+ async def __aenter__(self) -> Union[int, List[int]]:
+ self.stream_ids = await self.id_gen._db.runInteraction(
+ "_load_next_mult_id",
+ self.id_gen._load_next_mult_id_txn,
+ self.multiple_ids or 1,
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ with self.id_gen._lock:
+ assert max(self.id_gen._current_positions.values(), default=0) < min(
+ self.stream_ids
+ )
+
+ self.id_gen._unfinished_ids.update(self.stream_ids)
+
+ if self.multiple_ids is None:
+ return self.stream_ids[0] * self.id_gen._return_factor
+ else:
+ return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+ async def __aexit__(self, exc_type, exc, tb):
+ for i in self.stream_ids:
+ self.id_gen._mark_id_as_finished(i)
+
+ if exc_type is not None:
+ return False
+
+        # Update the `stream_positions` table with the newly allocated stream
+        # ID (unless self._writers is empty, in which case we don't bother,
+        # as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ if self.id_gen._writers:
+ await self.id_gen._db.runInteraction(
+ "MultiWriterIdGenerator._update_table",
+ self.id_gen._update_stream_positions_table_txn,
+ )
+
+ return False
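
The wrapper pattern above generalises: any plain context manager can be exposed through `async with`, replacing the old `with await` idiom. A standalone sketch of the same idea, independent of Synapse (`reserve` is a made-up example context manager):

    import asyncio
    from contextlib import contextmanager

    import attr


    @attr.s(slots=True)
    class AsyncCtxManagerWrapper:
        """Expose a plain context manager through the async CM protocol,
        mirroring the helper above."""

        inner = attr.ib()

        async def __aenter__(self):
            return self.inner.__enter__()

        async def __aexit__(self, exc_type, exc, tb):
            return self.inner.__exit__(exc_type, exc, tb)


    @contextmanager
    def reserve(n: int):
        try:
            yield list(range(n))
        finally:
            pass  # release bookkeeping would go here


    async def main():
        async with AsyncCtxManagerWrapper(reserve(3)) as ids:
            print(ids)  # [0, 1, 2]


    asyncio.run(main())

This is also why `get_next`/`get_next_mult` are no longer coroutines: they return the async context manager directly, and the database work happens inside `__aenter__`.
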
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index ffc1894748..2dd95e2709 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -13,11 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
+import logging
import threading
from typing import Callable, List, Optional
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.engines import (
+ BaseDatabaseEngine,
+ IncorrectDatabaseSetup,
+ PostgresEngine,
+)
+from synapse.storage.types import Connection, Cursor
+
+logger = logging.getLogger(__name__)
+
+
+_INCONSISTENT_SEQUENCE_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated
+table '%(table)s'. This can happen if Synapse has been downgraded and
+then upgraded again, or due to a bad migration.
+
+To fix this error, shut down Synapse (including any and all workers)
+and run the following SQL:
+
+ SELECT setval('%(seq)s', (
+ %(max_id_sql)s
+ ));
+
+See docs/postgres.md for more information.
+"""
class SequenceGenerator(metaclass=abc.ABCMeta):
@@ -28,6 +51,19 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
"""Gets the next ID in the sequence"""
...
+ @abc.abstractmethod
+ def check_consistency(
+ self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ ):
+ """Should be called during start up to test that the current value of
+ the sequence is greater than or equal to the maximum ID in the table.
+
+ This is to handle various cases where the sequence value can get out
+ of sync with the table, e.g. if Synapse gets rolled back to a previous
+        version and then rolled forwards again.
+ """
+ ...
+
class PostgresSequenceGenerator(SequenceGenerator):
"""An implementation of SequenceGenerator which uses a postgres sequence"""
@@ -45,6 +81,50 @@ class PostgresSequenceGenerator(SequenceGenerator):
)
return [i for (i,) in txn]
+ def check_consistency(
+ self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ ):
+ txn = db_conn.cursor()
+
+ # First we get the current max ID from the table.
+ table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if positive else "-MIN",
+ }
+
+ txn.execute(table_sql)
+ row = txn.fetchone()
+ if not row:
+ # Table is empty, so nothing to do.
+ txn.close()
+ return
+
+ # Now we fetch the current value from the sequence and compare with the
+ # above.
+ max_stream_id = row[0]
+ txn.execute(
+ "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
+ )
+ last_value, is_called = txn.fetchone()
+ txn.close()
+
+ # If `is_called` is False then `last_value` is actually the value that
+ # will be generated next, so we decrement to get the true "last value".
+ if not is_called:
+ last_value -= 1
+
+ if max_stream_id > last_value:
+ logger.warning(
+ "Postgres sequence %s is behind table %s: %d < %d",
+                self._sequence_name,
+                table,
+                last_value,
+                max_stream_id,
+ )
+ raise IncorrectDatabaseSetup(
+ _INCONSISTENT_SEQUENCE_ERROR
+ % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
+ )
+
GetFirstCallbackType = Callable[[Cursor], int]
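
The heart of `check_consistency` is a comparison between the table's maximum ID and the sequence state reported by Postgres; a small model of that comparison, for illustration (relying on the `last_value`/`is_called` semantics described in the comments above):

    def sequence_is_behind(max_stream_id: int, last_value: int, is_called: bool) -> bool:
        """Return True if the sequence would hand out IDs already used in the table."""
        if not is_called:
            # The sequence has never been used: `last_value` is the value that
            # nextval() will return next, so the true "last value" is one lower.
            last_value -= 1
        return max_stream_id > last_value


    assert not sequence_is_behind(max_stream_id=10, last_value=10, is_called=True)
    assert sequence_is_behind(max_stream_id=10, last_value=10, is_called=False)
    assert sequence_is_behind(max_stream_id=11, last_value=10, is_called=True)
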
@@ -81,6 +161,12 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+ def check_consistency(
+ self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ ):
+ # There is nothing to do for in memory sequences
+ pass
+
def build_sequence_generator(
database_engine: BaseDatabaseEngine,
|