diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 8e5d78f6f7..bbff3c8d5b 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -47,6 +47,9 @@ class Storage:
# interfaces.
self.main = stores.main
- self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
+
+ self.persistence = None
+ if stores.persist_events:
+ self.persistence = EventsPersistenceStorage(hs, stores)
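
For context, a minimal sketch (toy classes, not Synapse's real ones) of the pattern this hunk introduces: `Storage.persistence` is now optional, only populated on processes configured to persist events, so callers have to handle the `None` case.

from typing import Optional

class EventsPersistenceStorage:
    async def persist_event(self, event: str) -> None:
        print("persisted", event)

class Storage:
    def __init__(self, can_persist_events: bool) -> None:
        # Only set on processes that can persist events, mirroring the
        # `stores.persist_events` check above.
        self.persistence = None  # type: Optional[EventsPersistenceStorage]
        if can_persist_events:
            self.persistence = EventsPersistenceStorage()

async def handle(storage: Storage, event: str) -> None:
    assert storage.persistence is not None, "This process cannot persist events"
    await storage.persistence.persist_event(event)
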
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index ed8a9bffb1..79ec8f119d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -952,7 +952,7 @@ class DatabasePool:
key_names: Collection[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[str]],
+ value_values: Iterable[Iterable[Any]],
) -> None:
"""
Upsert, many times.
@@ -981,7 +981,7 @@ class DatabasePool:
key_names: Iterable[str],
key_values: Collection[Iterable[Any]],
value_names: Collection[str],
- value_values: Iterable[Iterable[str]],
+ value_values: Iterable[Iterable[Any]],
) -> None:
"""
Upsert, many times, but without native UPSERT support or batching.
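
The relaxed `value_values` annotation matters later in this diff: the profiles change upserts integers and `None`s, not just strings. A small, self-contained illustration of the shapes `simple_upsert_many` now accepts (the column names here are only examples):

from typing import Any, Collection, Iterable

key_names: Collection[str] = ("user_id",)
key_values: Collection[Iterable[Any]] = [("alice",), ("bob",)]
value_names: Collection[str] = ("active", "batch", "avatar_url", "displayname")
# Value rows may now mix ints, strings and Nones rather than only strings.
value_values: Iterable[Iterable[Any]] = [
    (1, 7, None, None),
    (1, 7, None, None),
]
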
diff --git a/synapse/storage/databases/__init__.py b/synapse/storage/databases/__init__.py
index 985b12df91..aa5d490624 100644
--- a/synapse/storage/databases/__init__.py
+++ b/synapse/storage/databases/__init__.py
@@ -75,7 +75,7 @@ class Databases:
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
- if hs.config.worker.writers.events == hs.get_instance_name():
+ if hs.get_instance_name() in hs.config.worker.writers.events:
persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases:
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 2ae2fbd5d7..ccb3384db9 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -172,7 +172,7 @@ class DataStore(
else:
self._cache_id_gen = None
- super(DataStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._presence_on_startup = self._get_active_presence(db_conn)
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 4436b1a83d..ef81d73573 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -29,22 +29,20 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-class AccountDataWorkerStore(SQLBaseStore):
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
+class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max
)
- super(AccountDataWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
@abc.abstractmethod
def get_max_account_data_stream_id(self):
@@ -315,7 +313,7 @@ class AccountDataStore(AccountDataWorkerStore):
],
)
- super(AccountDataStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_max_account_data_stream_id(self) -> int:
"""Get the current max stream id for the private user data stream
@@ -341,7 +339,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as room_account_data has a unique constraint
# on (user_id, room_id, account_data_type) so simple_upsert will
# retry if there is a conflict.
@@ -389,7 +387,7 @@ class AccountDataStore(AccountDataWorkerStore):
"""
content_json = json_encoder.encode(content)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
# no need to lock here as account_data has a unique constraint on
# (user_id, account_data_type) so simple_upsert will retry if
# there is a conflict.
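
The `with await ... get_next()` to `async with ... get_next()` rewrites recur throughout this diff. A toy stand-in (not the real `StreamIdGenerator`/`MultiWriterIdGenerator`) showing the call-site difference: `get_next()` now returns an async context manager directly rather than a coroutine that resolves to a synchronous one.

import asyncio
from contextlib import asynccontextmanager

class ToyIdGenerator:
    def __init__(self) -> None:
        self._current = 0

    @asynccontextmanager
    async def get_next(self):
        # Hand out the next stream id; the real generators also track
        # completion when the block exits.
        self._current += 1
        yield self._current

async def main() -> None:
    gen = ToyIdGenerator()
    # Previously written as: with await gen.get_next() as stream_id:
    async with gen.get_next() as stream_id:
        print("allocated stream id", stream_id)

asyncio.run(main())
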
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 454c0bc50c..85f6b1e3fd 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -52,7 +52,7 @@ class ApplicationServiceWorkerStore(SQLBaseStore):
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
- super(ApplicationServiceWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_app_services(self):
return self.services_cache
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index c2fc847fbc..239c7a949c 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -31,7 +31,7 @@ LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(ClientIpBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"user_ips_device_index",
@@ -358,7 +358,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
name="client_ip_last_seen", keylen=4, max_entries=50000
)
- super(ClientIpStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 0044433110..d42faa3f1f 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -283,7 +283,7 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceInboxBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"device_inbox_stream_index",
@@ -313,7 +313,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceInboxStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
@@ -362,7 +362,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
rows.append((destination, stream_id, now_ms, edu_json))
txn.executemany(sql, rows)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id
@@ -411,7 +411,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
txn, stream_id, local_messages_by_user_then_device
)
- with await self._device_inbox_id_gen.get_next() as stream_id:
+ async with self._device_inbox_id_gen.get_next() as stream_id:
now_ms = self.clock.time_msec()
await self.db_pool.runInteraction(
"add_messages_from_remote_to_device_inbox",
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index add4e3ea0e..fdf394c612 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -377,7 +377,7 @@ class DeviceWorkerStore(SQLBaseStore):
             The new stream ID.
"""
- with await self._device_list_id_gen.get_next() as stream_id:
+ async with self._device_list_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"add_user_sig_change_to_streams",
self._add_user_signature_change_txn,
@@ -481,7 +481,7 @@ class DeviceWorkerStore(SQLBaseStore):
}
async def get_users_whose_devices_changed(
- self, from_key: str, user_ids: Iterable[str]
+ self, from_key: int, user_ids: Iterable[str]
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
@@ -493,7 +493,6 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
The set of user_ids whose devices have changed since `from_key`
"""
- from_key = int(from_key)
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
@@ -527,7 +526,7 @@ class DeviceWorkerStore(SQLBaseStore):
)
async def get_users_whose_signatures_changed(
- self, user_id: str, from_key: str
+ self, user_id: str, from_key: int
) -> Set[str]:
"""Get the users who have new cross-signing signatures made by `user_id` since
`from_key`.
@@ -539,7 +538,7 @@ class DeviceWorkerStore(SQLBaseStore):
Returns:
A set of user IDs with updated signatures.
"""
- from_key = int(from_key)
+
if self._user_signature_stream_cache.has_entity_changed(user_id, from_key):
sql = """
SELECT DISTINCT user_ids FROM user_signature_stream
@@ -702,7 +701,7 @@ class DeviceWorkerStore(SQLBaseStore):
class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"device_lists_stream_idx",
@@ -827,7 +826,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(DeviceStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
@@ -1094,7 +1093,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
@@ -1109,7 +1108,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
context = get_active_span_text_map()
- with await self._device_list_id_gen.get_next_mult(
+ async with self._device_list_id_gen.get_next_mult(
len(hosts) * len(device_ids)
) as stream_ids:
await self.db_pool.runInteraction(
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index fba3098ea2..22e1ed15d0 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
-@attr.s
+@attr.s(slots=True)
class DeviceKeyLookupResult:
"""The type returned by get_e2e_device_keys_and_signatures"""
@@ -831,7 +831,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
key (dict): the key data
"""
- with await self._cross_signing_id_gen.get_next() as stream_id:
+ async with self._cross_signing_id_gen.get_next() as stream_id:
return await self.db_pool.runInteraction(
"add_e2e_cross_signing_key",
self._set_e2e_cross_signing_key_txn,
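
A small demonstration of what `@attr.s(slots=True)` changes (a toy class, not the real `DeviceKeyLookupResult` definition): slotted attrs classes reject attributes that were never declared, which catches typos like the one below, and avoid a per-instance `__dict__`.

import attr

@attr.s(slots=True)
class ToyLookupResult:
    display_name = attr.ib(default=None)

result = ToyLookupResult(display_name="laptop")
try:
    result.displayname = "typo"  # missing underscore: rejected on a slotted class
except AttributeError as exc:
    print("rejected:", exc)
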
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 0b69aa6a94..6d3689c09e 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
if stream_ordering <= self.stream_ordering_month_ago:
- raise StoreError(400, "stream_ordering too old")
+ raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
SELECT event_id FROM stream_ordering_to_exterm
@@ -600,7 +600,7 @@ class EventFederationStore(EventFederationWorkerStore):
EVENT_AUTH_STATE_ONLY = "event_auth_state_only"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventFederationStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.EVENT_AUTH_STATE_ONLY, self._background_delete_non_state_event_auth
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 5233ed83e2..62f1738732 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -68,7 +68,7 @@ def _deserialize_action(actions, is_highlight):
class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventPushActionsWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# These get correctly set by _find_stream_orderings_for_times_txn
self.stream_ordering_month_ago = None
@@ -661,7 +661,7 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventPushActionsStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
self.EPA_HIGHLIGHT_INDEX,
@@ -969,7 +969,7 @@ def _action_has_highlight(actions):
return False
-@attr.s
+@attr.s(slots=True)
class _EventPushSummary:
"""Summary of pending event push actions for a given user in a given room.
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b3d27a2ee7..18def01f50 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -17,7 +17,7 @@
import itertools
import logging
from collections import OrderedDict, namedtuple
-from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -32,7 +32,7 @@ from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
@@ -97,18 +97,21 @@ class PersistEventsStore:
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
+ self._instance_name = hs.get_instance_name()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
- self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator
- self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator
+ self._backfill_id_gen = (
+ self.store._backfill_id_gen
+ ) # type: MultiWriterIdGenerator
+ self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator
# This should only exist on instances that are configured to write
assert (
- hs.config.worker.writers.events == hs.get_instance_name()
+ hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates(
@@ -153,15 +156,15 @@ class PersistEventsStore:
# Note: Multiple instances of this function cannot be in flight at
# the same time for the same room.
if backfilled:
- stream_ordering_manager = await self._backfill_id_gen.get_next_mult(
+ stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
- stream_ordering_manager = await self._stream_id_gen.get_next_mult(
+ stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
for (event, context), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
@@ -213,7 +216,7 @@ class PersistEventsStore:
Returns:
Filtered event ids
"""
- results = []
+ results = [] # type: List[str]
def _get_events_which_are_prevs_txn(txn, batch):
sql = """
@@ -631,7 +634,9 @@ class PersistEventsStore:
)
@classmethod
- def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
+ def _filter_events_and_contexts_for_duplicates(
+ cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
+ ) -> List[Tuple[EventBase, EventContext]]:
"""Ensure that we don't have the same event twice.
Pick the earliest non-outlier if there is one, else the earliest one.
@@ -641,7 +646,9 @@ class PersistEventsStore:
Returns:
list[(EventBase, EventContext)]: filtered list
"""
- new_events_and_contexts = OrderedDict()
+ new_events_and_contexts = (
+ OrderedDict()
+ ) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
@@ -655,7 +662,12 @@ class PersistEventsStore:
new_events_and_contexts[event.event_id] = (event, context)
return list(new_events_and_contexts.values())
- def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
+ def _update_room_depths_txn(
+ self,
+ txn,
+ events_and_contexts: List[Tuple[EventBase, EventContext]],
+ backfilled: bool,
+ ):
"""Update min_depth for each room
Args:
@@ -664,7 +676,7 @@ class PersistEventsStore:
we are persisting
backfilled (bool): True if the events were backfilled
"""
- depth_updates = {}
+ depth_updates = {} # type: Dict[str, int]
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -800,6 +812,7 @@ class PersistEventsStore:
table="events",
values=[
{
+ "instance_name": self._instance_name,
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth,
"depth": event.depth,
@@ -1095,6 +1108,10 @@ class PersistEventsStore:
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
+
+ def str_or_none(val: Any) -> Optional[str]:
+ return val if isinstance(val, str) else None
+
self.db_pool.simple_insert_many_txn(
txn,
table="room_memberships",
@@ -1105,8 +1122,8 @@ class PersistEventsStore:
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
- "display_name": event.content.get("displayname", None),
- "avatar_url": event.content.get("avatar_url", None),
+ "display_name": str_or_none(event.content.get("displayname")),
+ "avatar_url": str_or_none(event.content.get("avatar_url")),
}
for event in events
],
@@ -1436,7 +1453,7 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
- events_by_room = {}
+ events_by_room = {} # type: Dict[str, List[EventBase]]
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
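
A runnable toy version (plain dicts instead of `EventBase`/`EventContext`) of the deduplication rule documented in `_filter_events_and_contexts_for_duplicates`: keep one entry per event id, preferring the earliest non-outlier when both an outlier and a non-outlier are present.

from collections import OrderedDict

events = [
    {"event_id": "$dup", "outlier": True},
    {"event_id": "$other", "outlier": False},
    {"event_id": "$dup", "outlier": False},
]

deduped = OrderedDict()  # ordered by first insertion, like the real code
for ev in events:
    prev = deduped.get(ev["event_id"])
    if prev is not None:
        if prev["outlier"] and not ev["outlier"]:
            # Replace the outlier, re-inserting so ordering follows the kept event.
            deduped.pop(ev["event_id"], None)
            deduped[ev["event_id"]] = ev
        continue
    deduped[ev["event_id"]] = ev

print([(e["event_id"], e["outlier"]) for e in deduped.values()])
# [('$other', False), ('$dup', False)]
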
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index e53c6373a8..5e4af2eb51 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -29,7 +29,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
DELETE_SOFT_FAILED_EXTREMITIES = "delete_soft_failed_extremities"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventsBackgroundUpdatesStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index a7a73cc3d8..de9e8d1dc6 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from __future__ import division
-
import itertools
import logging
import threading
@@ -42,7 +40,8 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.engines import PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
@@ -76,29 +75,56 @@ class EventRedactBehaviour(Names):
class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(EventsWorkerStore, self).__init__(database, db_conn, hs)
-
- if hs.config.worker.writers.events == hs.get_instance_name():
- # We are the process in charge of generating stream ids for events,
- # so instantiate ID generators based on the database
- self._stream_id_gen = StreamIdGenerator(
- db_conn, "events", "stream_ordering",
+ super().__init__(database, db_conn, hs)
+
+ if isinstance(database.engine, PostgresEngine):
+            # If we're using Postgres then we can use `MultiWriterIdGenerator`
+ # regardless of whether this process writes to the streams or not.
+ self._stream_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ instance_name=hs.get_instance_name(),
+ table="events",
+ instance_column="instance_name",
+ id_column="stream_ordering",
+ sequence_name="events_stream_seq",
)
- self._backfill_id_gen = StreamIdGenerator(
- db_conn,
- "events",
- "stream_ordering",
- step=-1,
- extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ self._backfill_id_gen = MultiWriterIdGenerator(
+ db_conn=db_conn,
+ db=database,
+ instance_name=hs.get_instance_name(),
+ table="events",
+ instance_column="instance_name",
+ id_column="stream_ordering",
+ sequence_name="events_backfill_stream_seq",
+ positive=False,
)
else:
- # Another process is in charge of persisting events and generating
- # stream IDs: rely on the replication streams to let us know which
- # IDs we can process.
- self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
- self._backfill_id_gen = SlavedIdTracker(
- db_conn, "events", "stream_ordering", step=-1
- )
+            # We shouldn't be running in worker mode with SQLite, but it's useful
+ # to support it for unit tests.
+ #
+            # If this process is the writer then we need to use
+ # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
+ # updated over replication. (Multiple writers are not supported for
+ # SQLite).
+ if hs.get_instance_name() in hs.config.worker.writers.events:
+ self._stream_id_gen = StreamIdGenerator(
+ db_conn, "events", "stream_ordering",
+ )
+ self._backfill_id_gen = StreamIdGenerator(
+ db_conn,
+ "events",
+ "stream_ordering",
+ step=-1,
+ extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
+ )
+ else:
+ self._stream_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering"
+ )
+ self._backfill_id_gen = SlavedIdTracker(
+ db_conn, "events", "stream_ordering", step=-1
+ )
self._get_event_cache = Cache(
"*getEvent*",
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index ccfbb2135e..7218191965 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -1265,7 +1265,7 @@ class GroupServerStore(GroupServerWorkerStore):
return next_id
- with await self._group_updates_id_gen.get_next() as next_id:
+ async with self._group_updates_id_gen.get_next() as next_id:
res = await self.db_pool.runInteraction(
"register_user_group_membership",
_register_user_group_membership_txn,
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 86557d5512..cc538c5c10 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -17,12 +17,14 @@ from typing import Any, Dict, Iterable, List, Optional, Tuple
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool
+BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
+ "media_repository_drop_index_wo_method"
+)
+
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MediaRepositoryBackgroundUpdateStore, self).__init__(
- database, db_conn, hs
- )
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
update_name="local_media_repository_url_idx",
@@ -32,12 +34,65 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
where_clause="url_cache IS NOT NULL",
)
+        # The following updates add the method to the unique constraint of
+        # the thumbnail tables. That fixes an issue where thumbnails of the
+        # same resolution but different methods could overwrite one another.
+ # This can happen with custom thumbnail configs or with dynamic thumbnailing.
+ self.db_pool.updates.register_background_index_update(
+ update_name="local_media_repository_thumbnails_method_idx",
+ index_name="local_media_repository_thumbn_media_id_width_height_method_key",
+ table="local_media_repository_thumbnails",
+ columns=[
+ "media_id",
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_type",
+ "thumbnail_method",
+ ],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_index_update(
+ update_name="remote_media_repository_thumbnails_method_idx",
+ index_name="remote_media_repository_thumbn_media_origin_id_width_height_method_key",
+ table="remote_media_cache_thumbnails",
+ columns=[
+ "media_origin",
+ "media_id",
+ "thumbnail_width",
+ "thumbnail_height",
+ "thumbnail_type",
+ "thumbnail_method",
+ ],
+ unique=True,
+ )
+
+ self.db_pool.updates.register_background_update_handler(
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD,
+ self._drop_media_index_without_method,
+ )
+
+ async def _drop_media_index_without_method(self, progress, batch_size):
+ def f(txn):
+ txn.execute(
+ "ALTER TABLE local_media_repository_thumbnails DROP CONSTRAINT IF EXISTS local_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+ )
+ txn.execute(
+ "ALTER TABLE remote_media_cache_thumbnails DROP CONSTRAINT IF EXISTS remote_media_repository_thumbn_media_id_thumbnail_width_thum_key"
+ )
+
+ await self.db_pool.runInteraction("drop_media_indices_without_method", f)
+ await self.db_pool.updates._end_background_update(
+ BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD
+ )
+ return 1
+
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MediaRepositoryStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 1d793d3deb..e0cedd1aac 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -28,7 +28,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000
class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MonthlyActiveUsersWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@@ -120,7 +120,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MonthlyActiveUsersStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_stats_only = hs.config.mau_stats_only
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index c9f655dfb7..dbbb99cb95 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -23,11 +23,11 @@ from synapse.util.iterutils import batch_iter
class PresenceStore(SQLBaseStore):
async def update_presence(self, presence_states):
- stream_ordering_manager = await self._presence_id_gen.get_next_mult(
+ stream_ordering_manager = self._presence_id_gen.get_next_mult(
len(presence_states)
)
- with stream_ordering_manager as stream_orderings:
+ async with stream_ordering_manager as stream_orderings:
await self.db_pool.runInteraction(
"update_presence",
self._update_presence_txn,
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index d2e0685e9e..de37866d25 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,11 +13,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Dict, List, Optional, Tuple
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.databases.main.roommember import ProfileInfo
+from synapse.types import UserID
+from synapse.util.caches.descriptors import cached
+
+BATCH_SIZE = 100
class ProfileWorkerStore(SQLBaseStore):
@@ -39,6 +44,7 @@ class ProfileWorkerStore(SQLBaseStore):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)
+ @cached(max_entries=5000)
async def get_profile_displayname(self, user_localpart: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
@@ -47,6 +53,7 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_displayname",
)
+ @cached(max_entries=5000)
async def get_profile_avatar_url(self, user_localpart: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
@@ -55,6 +62,58 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
+ async def get_latest_profile_replication_batch_number(self):
+ def f(txn):
+ txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
+ rows = self.db_pool.cursor_to_dict(txn)
+ return rows[0]["maxbatch"]
+
+ return await self.db_pool.runInteraction(
+ "get_latest_profile_replication_batch_number", f
+ )
+
+ async def get_profile_batch(self, batchnum):
+ return await self.db_pool.simple_select_list(
+ table="profiles",
+ keyvalues={"batch": batchnum},
+ retcols=("user_id", "displayname", "avatar_url", "active"),
+ desc="get_profile_batch",
+ )
+
+ async def assign_profile_batch(self):
+ def f(txn):
+ sql = (
+ "UPDATE profiles SET batch = "
+ "(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
+ "WHERE user_id in ("
+ " SELECT user_id FROM profiles WHERE batch is NULL limit ?"
+ ")"
+ )
+ txn.execute(sql, (BATCH_SIZE,))
+ return txn.rowcount
+
+ return await self.db_pool.runInteraction("assign_profile_batch", f)
+
+ async def get_replication_hosts(self):
+ def f(txn):
+ txn.execute(
+ "SELECT host, last_synced_batch FROM profile_replication_status"
+ )
+ rows = self.db_pool.cursor_to_dict(txn)
+ return {r["host"]: r["last_synced_batch"] for r in rows}
+
+ return await self.db_pool.runInteraction("get_replication_hosts", f)
+
+ async def update_replication_batch_for_host(
+ self, host: str, last_synced_batch: int
+ ):
+ return await self.db_pool.simple_upsert(
+ table="profile_replication_status",
+ keyvalues={"host": host},
+ values={"last_synced_batch": last_synced_batch},
+ desc="update_replication_batch_for_host",
+ )
+
async def get_from_remote_profile_cache(
self, user_id: str
) -> Optional[Dict[str, Any]]:
@@ -72,27 +131,82 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_displayname(
- self, user_localpart: str, new_displayname: str
+ self, user_localpart: str, new_displayname: str, batchnum: int
) -> None:
- await self.db_pool.simple_update_one(
+ # Invalidate the read cache for this user
+ self.get_profile_displayname.invalidate((user_localpart,))
+
+ await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"displayname": new_displayname},
+ values={"displayname": new_displayname, "batch": batchnum},
desc="set_profile_displayname",
+ lock=False, # we can do this because user_id has a unique index
)
async def set_profile_avatar_url(
- self, user_localpart: str, new_avatar_url: str
+ self, user_localpart: str, new_avatar_url: str, batchnum: int
) -> None:
- await self.db_pool.simple_update_one(
+ # Invalidate the read cache for this user
+ self.get_profile_avatar_url.invalidate((user_localpart,))
+
+ await self.db_pool.simple_upsert(
table="profiles",
keyvalues={"user_id": user_localpart},
- updatevalues={"avatar_url": new_avatar_url},
+ values={"avatar_url": new_avatar_url, "batch": batchnum},
desc="set_profile_avatar_url",
+ lock=False, # we can do this because user_id has a unique index
+ )
+
+ async def set_profiles_active(
+ self, users: List[UserID], active: bool, hide: bool, batchnum: int,
+ ) -> None:
+ """Given a set of users, set active and hidden flags on them.
+
+ Args:
+ users: A list of UserIDs
+ active: Whether to set the users to active or inactive
+            hide: Whether to hide the users (withhold from replication). If
+ False and active is False, users will have their profiles
+ erased
+ batchnum: The batch number, used for profile replication
+ """
+ # Convert list of localparts to list of tuples containing localparts
+ user_localparts = [(user.localpart,) for user in users]
+
+ # Generate list of value tuples for each user
+ value_names = ("active", "batch")
+ values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple]
+
+ if not active and not hide:
+ # we are deactivating for real (not in hide mode)
+ # so clear the profile information
+ value_names += ("avatar_url", "displayname")
+ values = [v + (None, None) for v in values]
+
+ return await self.db_pool.runInteraction(
+ "set_profiles_active",
+ self.db_pool.simple_upsert_many_txn,
+ table="profiles",
+ key_names=("user_id",),
+ key_values=user_localparts,
+ value_names=value_names,
+ value_values=values,
)
class ProfileStore(ProfileWorkerStore):
+ def __init__(self, database, db_conn, hs):
+ super().__init__(database, db_conn, hs)
+
+ self.db_pool.updates.register_background_index_update(
+ "profile_replication_status_host_index",
+ index_name="profile_replication_status_idx",
+ table="profile_replication_status",
+ columns=["host"],
+ unique=True,
+ )
+
async def add_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> None:
@@ -115,10 +229,10 @@ class ProfileStore(ProfileWorkerStore):
async def update_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> int:
- return await self.db_pool.simple_update(
+ return await self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
- updatevalues={
+ values={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ea833829ae..d7a03cbf7d 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -69,6 +69,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# room_depth
# state_groups
# state_groups_state
+ # destination_rooms
# we will build a temporary table listing the events so that we don't
# have to keep shovelling the list back and forth across the
@@ -336,6 +337,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# and finally, the tables with an index on room_id (or no useful index)
for table in (
"current_state_events",
+ "destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
"event_json",
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 0de802a86b..711d5aa23d 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -13,11 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import abc
import logging
from typing import List, Tuple, Union
+from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -27,6 +27,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util import json_encoder
@@ -60,6 +61,8 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False):
return rules
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
ReceiptsWorkerStore,
@@ -67,17 +70,14 @@ class PushRulesWorkerStore(
RoomMemberWorkerStore,
EventsWorkerStore,
SQLBaseStore,
+ metaclass=abc.ABCMeta,
):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs):
- super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -338,7 +338,7 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
if before or after:
@@ -540,6 +540,25 @@ class PushRuleStore(PushRulesWorkerStore):
},
)
+ # ensure we have a push_rules_enable row
+ # enabledness defaults to true
+ if isinstance(self.database_engine, PostgresEngine):
+ sql = """
+ INSERT INTO push_rules_enable (id, user_name, rule_id, enabled)
+ VALUES (?, ?, ?, ?)
+ ON CONFLICT DO NOTHING
+ """
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ sql = """
+ INSERT OR IGNORE INTO push_rules_enable (id, user_name, rule_id, enabled)
+ VALUES (?, ?, ?, ?)
+ """
+ else:
+ raise RuntimeError("Unknown database engine")
+
+ new_enable_id = self._push_rules_enable_id_gen.get_next()
+ txn.execute(sql, (new_enable_id, user_id, rule_id, 1))
+
async def delete_push_rule(self, user_id: str, rule_id: str) -> None:
"""
Delete a push rule. Args specify the row to be deleted and can be
@@ -552,6 +571,12 @@ class PushRuleStore(PushRulesWorkerStore):
"""
def delete_push_rule_txn(txn, stream_id, event_stream_ordering):
+ # we don't use simple_delete_one_txn because that would fail if the
+ # user did not have a push_rule_enable row.
+ self.db_pool.simple_delete_txn(
+ txn, "push_rules_enable", {"user_name": user_id, "rule_id": rule_id}
+ )
+
self.db_pool.simple_delete_one_txn(
txn, "push_rules", {"user_name": user_id, "rule_id": rule_id}
)
@@ -560,7 +585,7 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
@@ -570,10 +595,29 @@ class PushRuleStore(PushRulesWorkerStore):
event_stream_ordering,
)
- async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
- event_stream_ordering = self._stream_id_gen.get_current_token()
+ async def set_push_rule_enabled(
+ self, user_id: str, rule_id: str, enabled: bool, is_default_rule: bool
+ ) -> None:
+ """
+ Sets the `enabled` state of a push rule.
+ Args:
+ user_id: the user ID of the user who wishes to enable/disable the rule
+ e.g. '@tina:example.org'
+ rule_id: the full rule ID of the rule to be enabled/disabled
+ e.g. 'global/override/.m.rule.roomnotif'
+ or 'global/override/myCustomRule'
+ enabled: True if the rule is to be enabled, False if it is to be
+ disabled
+ is_default_rule: True if and only if this is a server-default rule.
+                This skips the existence check, since only user-created rules
+                are guaranteed to be stored in the `push_rules` table.
+
+ Raises:
+ NotFoundError if the rule does not exist.
+ """
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
+ event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
@@ -582,12 +626,47 @@ class PushRuleStore(PushRulesWorkerStore):
user_id,
rule_id,
enabled,
+ is_default_rule,
)
def _set_push_rule_enabled_txn(
- self, txn, stream_id, event_stream_ordering, user_id, rule_id, enabled
+ self,
+ txn,
+ stream_id,
+ event_stream_ordering,
+ user_id,
+ rule_id,
+ enabled,
+ is_default_rule,
):
new_id = self._push_rules_enable_id_gen.get_next()
+
+ if not is_default_rule:
+ # first check it exists; we need to lock for key share so that a
+ # transaction that deletes the push rule will conflict with this one.
+ # We also need a push_rule_enable row to exist for every push_rules
+ # row, otherwise it is possible to simultaneously delete a push rule
+ # (that has no _enable row) and enable it, resulting in a dangling
+ # _enable row. To solve this: we either need to use SERIALISABLE or
+ # ensure we always have a push_rule_enable row for every push_rule
+ # row. We chose the latter.
+ for_key_share = "FOR KEY SHARE"
+ if not isinstance(self.database_engine, PostgresEngine):
+                    # FOR KEY SHARE is not applicable/available on SQLite
+ for_key_share = ""
+ sql = (
+ """
+ SELECT 1 FROM push_rules
+ WHERE user_name = ? AND rule_id = ?
+ %s
+ """
+ % for_key_share
+ )
+ txn.execute(sql, (user_id, rule_id))
+ if txn.fetchone() is None:
+ # needed to set NOT_FOUND code.
+ raise NotFoundError("Push rule does not exist.")
+
self.db_pool.simple_upsert_txn(
txn,
"push_rules_enable",
@@ -606,8 +685,30 @@ class PushRuleStore(PushRulesWorkerStore):
)
async def set_push_rule_actions(
- self, user_id, rule_id, actions, is_default_rule
+ self,
+ user_id: str,
+ rule_id: str,
+ actions: List[Union[dict, str]],
+ is_default_rule: bool,
) -> None:
+ """
+ Sets the `actions` state of a push rule.
+
+ Will throw NotFoundError if the rule does not exist; the Code for this
+ is NOT_FOUND.
+
+ Args:
+ user_id: the user ID of the user who wishes to enable/disable the rule
+ e.g. '@tina:example.org'
+ rule_id: the full rule ID of the rule to be enabled/disabled
+ e.g. 'global/override/.m.rule.roomnotif'
+ or 'global/override/myCustomRule'
+ actions: A list of actions (each action being a dict or string),
+ e.g. ["notify", {"set_tweak": "highlight", "value": false}]
+            is_default_rule: True if and only if this is a server-default rule.
+                This skips the existence check, since only user-created rules
+                are guaranteed to be stored in the `push_rules` table.
+ """
actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn(txn, stream_id, event_stream_ordering):
@@ -629,12 +730,19 @@ class PushRuleStore(PushRulesWorkerStore):
update_stream=False,
)
else:
- self.db_pool.simple_update_one_txn(
- txn,
- "push_rules",
- {"user_name": user_id, "rule_id": rule_id},
- {"actions": actions_json},
- )
+ try:
+ self.db_pool.simple_update_one_txn(
+ txn,
+ "push_rules",
+ {"user_name": user_id, "rule_id": rule_id},
+ {"actions": actions_json},
+ )
+ except StoreError as serr:
+ if serr.code == 404:
+ # this sets the NOT_FOUND error Code
+ raise NotFoundError("Push rule does not exist")
+ else:
+ raise
self._insert_push_rules_update_txn(
txn,
@@ -646,7 +754,7 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
- with await self._push_rules_stream_id_gen.get_next() as stream_id:
+ async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction(
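
The `push_rules_enable` insertion earlier in this file relies on "insert if not already present" semantics: `ON CONFLICT DO NOTHING` on Postgres and `INSERT OR IGNORE` on SQLite. A self-contained SQLite demonstration of the latter (schema simplified for illustration):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE push_rules_enable ("
    " id INTEGER, user_name TEXT, rule_id TEXT, enabled INTEGER,"
    " UNIQUE (user_name, rule_id))"
)
row = (1, "@tina:example.org", "global/override/myCustomRule", 1)
conn.execute("INSERT OR IGNORE INTO push_rules_enable VALUES (?, ?, ?, ?)", row)
# A second insert for the same (user_name, rule_id) is silently ignored.
conn.execute(
    "INSERT OR IGNORE INTO push_rules_enable VALUES (?, ?, ?, ?)",
    (2, "@tina:example.org", "global/override/myCustomRule", 1),
)
print(conn.execute("SELECT COUNT(*) FROM push_rules_enable").fetchone()[0])  # 1
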
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index c388468273..df8609b97b 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -281,7 +281,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering,
profile_tag="",
) -> None:
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
# (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
@@ -344,7 +344,7 @@ class PusherStore(PusherWorkerStore):
},
)
- with await self._pushers_id_gen.get_next() as stream_id:
+ async with self._pushers_id_gen.get_next() as stream_id:
await self.db_pool.runInteraction(
"delete_pusher", delete_pusher_txn, stream_id
)
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 4a0d5a320e..c79ddff680 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -31,17 +31,15 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
-class ReceiptsWorkerStore(SQLBaseStore):
+# The ABCMeta metaclass ensures that it cannot be instantiated without
+# the abstract methods being implemented.
+class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_receipt_stream_id` which can be called in the initializer.
"""
- # This ABCMeta metaclass ensures that we cannot be instantiated without
- # the abstract methods being implemented.
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs):
- super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
@@ -388,7 +386,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
db_conn, "receipts_linearized", "stream_id"
)
- super(ReceiptsStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_current_token()
@@ -526,7 +524,7 @@ class ReceiptsStore(ReceiptsWorkerStore):
"insert_receipt_conv", graph_to_linear
)
- with await self._receipts_id_gen.get_next() as stream_id:
+ async with self._receipts_id_gen.get_next() as stream_id:
event_ts = await self.db_pool.runInteraction(
"insert_linearized_receipt",
self.insert_linearized_receipt_txn,
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 01f20c03c2..a06451b7f0 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
self.clock = hs.get_clock()
@@ -116,6 +116,20 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="get_expiration_ts_for_user",
)
+ async def is_account_expired(self, user_id: str, current_ts: int) -> bool:
+ """
+ Returns whether an user account is expired.
+
+ Args:
+ user_id: The user's ID
+ current_ts: The current timestamp
+
+ Returns:
+ Whether the user account has expired
+ """
+ expiration_ts = await self.get_expiration_ts_for_user(user_id)
+ return expiration_ts is not None and current_ts >= expiration_ts
+
async def set_account_validity_for_user(
self,
user_id: str,
@@ -156,6 +170,37 @@ class RegistrationWorkerStore(SQLBaseStore):
"set_account_validity_for_user", set_account_validity_for_user_txn
)
+ async def get_expired_users(self):
+ """Get UserIDs of all expired users.
+
+ Users who are not active, or do not have profile information, are
+ excluded from the results.
+
+ Returns:
+            List[UserID]: List of expired user IDs
+ """
+
+ def get_expired_users_txn(txn, now_ms):
+ # We need to use pattern matching as profiles.user_id is confusingly just the
+ # user's localpart, whereas account_validity.user_id is a full user ID
+ sql = """
+ SELECT av.user_id from account_validity AS av
+ LEFT JOIN profiles as p
+ ON av.user_id LIKE '%%' || p.user_id || ':%%'
+ WHERE expiration_ts_ms <= ?
+ AND p.active = 1
+ """
+ txn.execute(sql, (now_ms,))
+ rows = txn.fetchall()
+
+ return [UserID.from_string(row[0]) for row in rows]
+
+ res = await self.db_pool.runInteraction(
+ "get_expired_users", get_expired_users_txn, self.clock.time_msec()
+ )
+
+ return res
+
async def set_renewal_token_for_user(
self, user_id: str, renewal_token: str
) -> None:
@@ -262,6 +307,54 @@ class RegistrationWorkerStore(SQLBaseStore):
desc="delete_account_validity_for_user",
)
+ async def get_info_for_users(
+ self, user_ids: List[str],
+ ):
+ """Return the user info for a given set of users
+
+ Args:
+ user_ids: A list of users to return information about
+
+ Returns:
+            Dict[str, Dict[str, bool]]: A dictionary mapping each user ID to
+ a dict with the following keys:
+ * expired - whether this is an expired user
+ * deactivated - whether this is a deactivated user
+ """
+ # Get information of all our local users
+ def _get_info_for_users_txn(txn):
+ rows = []
+
+ for user_id in user_ids:
+ sql = """
+ SELECT u.name, u.deactivated, av.expiration_ts_ms
+ FROM users as u
+ LEFT JOIN account_validity as av
+ ON av.user_id = u.name
+ WHERE u.name = ?
+ """
+
+ txn.execute(sql, (user_id,))
+ row = txn.fetchone()
+ if row:
+ rows.append(row)
+
+ return rows
+
+ info_rows = await self.db_pool.runInteraction(
+ "get_info_for_users", _get_info_for_users_txn
+ )
+
+ return {
+ user_id: {
+ "expired": (
+ expiration is not None and self.clock.time_msec() >= expiration
+ ),
+ "deactivated": deactivated == 1,
+ }
+ for user_id, deactivated, expiration in info_rows
+ }
+
async def is_server_admin(self, user: UserID) -> bool:
"""Determines if a user is an admin of this homeserver.
@@ -764,7 +857,7 @@ class RegistrationWorkerStore(SQLBaseStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.clock = hs.get_clock()
self.config = hs.config
@@ -892,7 +985,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RegistrationStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._account_validity = hs.config.account_validity
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 717df97301..8fab8de973 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -69,7 +69,7 @@ class RoomSortOrder(Enum):
class RoomWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -104,7 +104,8 @@ class RoomWorkerStore(SQLBaseStore):
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
rooms.creator, state.encryption, state.is_federatable AS federatable,
rooms.is_public AS public, state.join_rules, state.guest_access,
- state.history_visibility, curr.current_state_events AS state_events
+ state.history_visibility, curr.current_state_events AS state_events,
+ state.avatar, state.topic
FROM rooms
LEFT JOIN room_stats_state state USING (room_id)
LEFT JOIN room_stats_current curr USING (room_id)
@@ -343,6 +344,23 @@ class RoomWorkerStore(SQLBaseStore):
desc="is_room_blocked",
)
+ async def is_room_published(self, room_id: str) -> bool:
+ """Check whether a room has been published in the local public room
+ directory.
+
+ Args:
+ room_id
+ Returns:
+ Whether the room is currently published in the room directory
+ """
+ # Get room information
+ room_info = await self.get_room(room_id)
+ if not room_info:
+ return False
+
+ # Check the is_public value
+ return room_info.get("is_public", False)
+
async def get_rooms_paginate(
self,
start: int,
@@ -551,6 +569,11 @@ class RoomWorkerStore(SQLBaseStore):
Returns:
dict[int, int]: "min_lifetime" and "max_lifetime" for this room.
"""
+        # If the room retention feature is disabled, return a policy with no minimum
+        # nor maximum, so that we don't incorrectly filter out events when sending
+        # them to the client.
+ if not self.config.retention_enabled:
+ return {"min_lifetime": None, "max_lifetime": None}
def get_retention_policy_for_room_txn(txn):
txn.execute(
@@ -862,7 +885,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
ADD_ROOMS_ROOM_VERSION_COLUMN = "add_rooms_room_version_column"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -1073,7 +1096,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -1136,7 +1159,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"store_room_txn", store_room_txn, next_id
)
@@ -1203,7 +1226,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public", set_room_is_public_txn, next_id
)
@@ -1283,7 +1306,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
},
)
- with await self._public_room_id_gen.get_next() as next_id:
+ async with self._public_room_id_gen.get_next() as next_id:
await self.db_pool.runInteraction(
"set_room_is_public_appservice",
set_room_is_public_appservice_txn,
@@ -1327,6 +1350,101 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
desc="add_event_report",
)
+ async def get_event_reports_paginate(
+ self,
+ start: int,
+ limit: int,
+ direction: str = "b",
+ user_id: Optional[str] = None,
+ room_id: Optional[str] = None,
+ ) -> Tuple[List[Dict[str, Any]], int]:
+ """Retrieve a paginated list of event reports
+
+ Args:
+ start: event offset to begin the query from
+ limit: number of rows to retrieve
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`)
+ user_id: filter reports by a user ID substring. Ignored if user_id is None
+ room_id: filter reports by a room ID substring. Ignored if room_id is None
+ Returns:
+ event_reports: a list of event report dicts
+ count: total number of event reports matching the filter criteria
+ """
+
+ def _get_event_reports_paginate_txn(txn):
+ filters = []
+ args = []
+
+ if user_id:
+ filters.append("er.user_id LIKE ?")
+ args.extend(["%" + user_id + "%"])
+ if room_id:
+ filters.append("er.room_id LIKE ?")
+ args.extend(["%" + room_id + "%"])
+
+ if direction == "b":
+ order = "DESC"
+ else:
+ order = "ASC"
+
+ where_clause = "WHERE " + " AND ".join(filters) if len(filters) > 0 else ""
+
+ sql = """
+ SELECT COUNT(*) as total_event_reports
+ FROM event_reports AS er
+ {}
+ """.format(
+ where_clause
+ )
+ txn.execute(sql, args)
+ count = txn.fetchone()[0]
+
+ sql = """
+ SELECT
+ er.id,
+ er.received_ts,
+ er.room_id,
+ er.event_id,
+ er.user_id,
+ er.reason,
+ er.content,
+ events.sender,
+ room_aliases.room_alias,
+ event_json.json AS event_json
+ FROM event_reports AS er
+ LEFT JOIN room_aliases
+ ON room_aliases.room_id = er.room_id
+ JOIN events
+ ON events.event_id = er.event_id
+ JOIN event_json
+ ON event_json.event_id = er.event_id
+ {where_clause}
+ ORDER BY er.received_ts {order}
+ LIMIT ?
+ OFFSET ?
+ """.format(
+ where_clause=where_clause, order=order,
+ )
+
+ args += [limit, start]
+ txn.execute(sql, args)
+ event_reports = self.db_pool.cursor_to_dict(txn)
+
+ if count > 0:
+ for row in event_reports:
+ try:
+ row["content"] = db_to_json(row["content"])
+ row["event_json"] = db_to_json(row["event_json"])
+ except Exception:
+ continue
+
+ return event_reports, count
+
+ return await self.db_pool.runInteraction(
+ "get_event_reports_paginate", _get_event_reports_paginate_txn
+ )
+
def get_current_public_room_stream_id(self):
return self._public_room_id_gen.get_current_token()
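
For orientation, here is a minimal sketch of how a caller (for instance an admin API handler) might use the new get_event_reports_paginate method. The wrapper class and response shape below are assumptions for illustration only, not part of this change:

    # Illustrative only: shows the calling convention of get_event_reports_paginate.
    from typing import Any, Dict, Optional


    class EventReportsExample:
        def __init__(self, store):
            self._store = store  # a store exposing get_event_reports_paginate

        async def list_reports(
            self, start: int, limit: int, room_id: Optional[str] = None
        ) -> Dict[str, Any]:
            reports, total = await self._store.get_event_reports_paginate(
                start=start,
                limit=limit,
                direction="b",    # newest reports first
                user_id=None,     # no filter on the reporting user
                room_id=room_id,  # substring match on the room ID, if given
            )
            return {"event_reports": reports, "total": total}
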
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 91a8b43da3..4fa8767b01 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -55,7 +55,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomMemberWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Is the current_state_events.membership up to date? Or is the
# background update still running?
@@ -819,7 +819,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomMemberBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
@@ -973,7 +973,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(RoomMemberStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def forget(self, user_id: str, room_id: str) -> None:
"""Indicate that user_id wishes to discard history for room_id."""
diff --git a/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
new file mode 100644
index 0000000000..e744c02fe8
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/48/profiles_batch.sql
@@ -0,0 +1,36 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Add a batch number to track changes to profiles and the
+ * order they're made in so we can replicate user profiles
+ * to other hosts as they change
+ */
+ALTER TABLE profiles ADD COLUMN batch BIGINT DEFAULT NULL;
+
+/*
+ * Index on the batch number so we can get profiles
+ * by their batch
+ */
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+
+/*
+ * A table to track what batch of user profiles has been
+ * synced to what profile replication target.
+ */
+CREATE TABLE profile_replication_status (
+ host TEXT NOT NULL,
+ last_synced_batch BIGINT NOT NULL
+);
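
To show how the batch column and profile_replication_status fit together, here is a hedged sketch of a transaction function that selects the profiles still to be replicated to a given host. The function and its calling convention are hypothetical; only the tables and columns come from the schema above:

    # Illustrative only: fetch profiles changed since the last batch synced to `host`.
    def get_profiles_to_replicate_txn(txn, host: str, limit: int = 100):
        txn.execute(
            "SELECT last_synced_batch FROM profile_replication_status WHERE host = ?",
            (host,),
        )
        row = txn.fetchone()
        last_synced_batch = row[0] if row else -1

        txn.execute(
            """
            SELECT user_id, displayname, avatar_url, batch
            FROM profiles
            WHERE batch IS NOT NULL AND batch > ?
            ORDER BY batch ASC
            LIMIT ?
            """,
            (last_synced_batch, limit),
        )
        return txn.fetchall()
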
diff --git a/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
new file mode 100644
index 0000000000..96051ac179
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/50/profiles_deactivated_users.sql
@@ -0,0 +1,23 @@
+/* Copyright 2018 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * A flag saying whether the user owning the profile has been deactivated
+ * This really belongs on the users table, not here, but the users table
+ * stores users by their full user_id and profiles stores them by localpart,
+ * so we can't easily join between the two tables. Plus, the batch number
+ * realy ought to represent data in this table that has changed.
+ */
+ALTER TABLE profiles ADD COLUMN active SMALLINT DEFAULT 1 NOT NULL;
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
new file mode 100644
index 0000000000..7542ab8cbd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/55/profile_replication_status_index.sql
@@ -0,0 +1,16 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE UNIQUE INDEX profile_replication_status_idx ON profile_replication_status(host);
\ No newline at end of file
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
new file mode 100644
index 0000000000..b64926e9c9
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -0,0 +1,33 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the method to the unique key constraint of the thumbnail databases.
+ * Otherwise you can't have a scaled and a cropped thumbnail with the same
+ * resolution, which happens quite often with dynamic thumbnailing.
+ * This is the postgres specific migration modifying the table with a background
+ * migration.
+ */
+
+-- add new index that includes method to local media
+INSERT INTO background_updates (update_name, progress_json) VALUES
+ ('local_media_repository_thumbnails_method_idx', '{}');
+
+-- add new index that includes method to remote media
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+
+-- drop old index
+INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
+ ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
new file mode 100644
index 0000000000..1d0c04b53a
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.sqlite
@@ -0,0 +1,44 @@
+/* Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * This adds the method to the unique key constraint of the thumbnail databases.
+ * Otherwise you can't have a scaled and a cropped thumbnail with the same
+ * resolution, which happens quite often with dynamic thumbnailing.
+ * This is a sqlite specific migration, since sqlite can't modify the unique
+ * constraint of a table without recreating it.
+ */
+
+CREATE TABLE local_media_repository_thumbnails_new ( media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_type TEXT, thumbnail_method TEXT, thumbnail_length INTEGER, UNIQUE ( media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO local_media_repository_thumbnails_new
+ SELECT media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method, thumbnail_length
+ FROM local_media_repository_thumbnails;
+
+DROP TABLE local_media_repository_thumbnails;
+
+ALTER TABLE local_media_repository_thumbnails_new RENAME TO local_media_repository_thumbnails;
+
+CREATE INDEX local_media_repository_thumbnails_media_id ON local_media_repository_thumbnails (media_id);
+
+
+
+CREATE TABLE IF NOT EXISTS remote_media_cache_thumbnails_new ( media_origin TEXT, media_id TEXT, thumbnail_width INTEGER, thumbnail_height INTEGER, thumbnail_method TEXT, thumbnail_type TEXT, thumbnail_length INTEGER, filesystem_id TEXT, UNIQUE ( media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_type, thumbnail_method ) );
+
+INSERT INTO remote_media_cache_thumbnails_new
+ SELECT media_origin, media_id, thumbnail_width, thumbnail_height, thumbnail_method, thumbnail_type, thumbnail_length, filesystem_id
+ FROM remote_media_cache_thumbnails;
+
+DROP TABLE remote_media_cache_thumbnails;
+
+ALTER TABLE remote_media_cache_thumbnails_new RENAME TO remote_media_cache_thumbnails;
diff --git a/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
new file mode 100644
index 0000000000..847aebd85e
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/10_pushrules_enabled_delete_obsolete.sql
@@ -0,0 +1,28 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ Delete stuck 'enabled' bits that correspond to deleted or non-existent push rules.
+ We ignore rules that are server-default rules because they are not defined
+ in the `push_rules` table.
+**/
+
+DELETE FROM push_rules_enable WHERE
+ rule_id NOT LIKE 'global/%/.m.rule.%'
+ AND NOT EXISTS (
+ SELECT 1 FROM push_rules
+ WHERE push_rules.user_name = push_rules_enable.user_name
+ AND push_rules.rule_id = push_rules_enable.rule_id
+ );
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
new file mode 100644
index 0000000000..98ff76d709
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql
@@ -0,0 +1,16 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE events ADD COLUMN instance_name TEXT;
diff --git a/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
new file mode 100644
index 0000000000..97c1e6a0c5
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/14events_instance_name.sql.postgres
@@ -0,0 +1,26 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
+
+SELECT setval('events_stream_seq', (
+ SELECT COALESCE(MAX(stream_ordering), 1) FROM events
+));
+
+CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
+
+SELECT setval('events_backfill_stream_seq', (
+ SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
+));
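
These sequences give Postgres deployments a source of monotonically increasing stream orderings (forward events on events_stream_seq, backfill magnitudes on events_backfill_stream_seq). A minimal sketch of drawing a value from the main sequence, purely to show the mechanism the ID generators build on:

    # Illustrative only: Synapse's ID generators wrap this; the raw query is
    # shown just to make the sequence's role concrete.
    def get_next_stream_ordering_txn(txn) -> int:
        txn.execute("SELECT nextval('events_stream_seq')")
        return txn.fetchone()[0]
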
diff --git a/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
new file mode 100644
index 0000000000..ebfbed7925
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/15_catchup_destination_rooms.sql
@@ -0,0 +1,42 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This schema delta alters the schema to enable 'catching up' remote homeservers
+-- after there has been a connectivity problem for any reason.
+
+-- This stores, for each (destination, room) pair, the stream_ordering of the
+-- latest event for that destination.
+CREATE TABLE IF NOT EXISTS destination_rooms (
+ -- the destination in question.
+ destination TEXT NOT NULL REFERENCES destinations (destination),
+ -- the ID of the room in question
+ room_id TEXT NOT NULL REFERENCES rooms (room_id),
+ -- the stream_ordering of the event
+ stream_ordering BIGINT NOT NULL,
+ PRIMARY KEY (destination, room_id)
+ -- We don't declare a foreign key on stream_ordering here because that'd mean
+ -- we'd need to either maintain an index (expensive) or do a table scan of
+ -- destination_rooms whenever we delete an event (also potentially expensive).
+ -- In addition to that, a foreign key on stream_ordering would be redundant
+ -- as this row doesn't need to refer to a specific event; if the event gets
+ -- deleted then it doesn't affect the validity of the stream_ordering here.
+);
+
+-- This index is needed to make it so that a deletion of a room (in the rooms
+-- table) can be efficient, as otherwise a table scan would need to be performed
+-- to check that no destination_rooms rows point to the room to be deleted.
+-- Also: it makes it efficient to delete all the entries for a given room ID,
+-- such as when purging a room.
+CREATE INDEX IF NOT EXISTS destination_rooms_room_id
+ ON destination_rooms (room_id);
diff --git a/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
new file mode 100644
index 0000000000..55f5d0f732
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/16populate_stats_process_rooms_fix.sql
@@ -0,0 +1,22 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+-- This delta file fixes a regression introduced by 58/12room_stats.sql, removing the hacky
+-- populate_stats_process_rooms_2 background job and restoring the functionality under the
+-- original name.
+-- See https://github.com/matrix-org/synapse/issues/8238 for details
+
+DELETE FROM background_updates WHERE update_name = 'populate_stats_process_rooms';
+UPDATE background_updates SET update_name = 'populate_stats_process_rooms'
+ WHERE update_name = 'populate_stats_process_rooms_2';
diff --git a/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
new file mode 100644
index 0000000000..a67aa5e500
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/17_catchup_last_successful.sql
@@ -0,0 +1,21 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- This column tracks the stream_ordering of the event that was most recently
+-- successfully transmitted to the destination.
+-- A value of NULL means that we have not sent an event successfully yet
+-- (at least, not since the introduction of this column).
+ALTER TABLE destinations
+ ADD COLUMN last_successful_stream_ordering BIGINT;
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
index 889a9a0ce4..20c5af2eb7 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.postgres
@@ -658,10 +658,19 @@ CREATE TABLE presence_stream (
+CREATE TABLE profile_replication_status (
+ host text NOT NULL,
+ last_synced_batch bigint NOT NULL
+);
+
+
+
CREATE TABLE profiles (
user_id text NOT NULL,
displayname text,
- avatar_url text
+ avatar_url text,
+ batch bigint,
+ active smallint DEFAULT 1 NOT NULL
);
@@ -1788,6 +1797,10 @@ CREATE INDEX presence_stream_user_id ON presence_stream USING btree (user_id);
+CREATE INDEX profiles_batch_idx ON profiles USING btree (batch);
+
+
+
CREATE INDEX public_room_index ON rooms USING btree (is_public);
diff --git a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
index a0411ede7e..e28ec3fa45 100644
--- a/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
+++ b/synapse/storage/databases/main/schema/full_schemas/54/full.sql.sqlite
@@ -6,7 +6,7 @@ CREATE TABLE presence_allow_inbound( observed_user_id TEXT NOT NULL, observer_us
CREATE TABLE users( name TEXT, password_hash TEXT, creation_ts BIGINT, admin SMALLINT DEFAULT 0 NOT NULL, upgrade_ts BIGINT, is_guest SMALLINT DEFAULT 0 NOT NULL, appservice_id TEXT, consent_version TEXT, consent_server_notice_sent TEXT, user_type TEXT DEFAULT NULL, UNIQUE(name) );
CREATE TABLE access_tokens( id BIGINT PRIMARY KEY, user_id TEXT NOT NULL, device_id TEXT, token TEXT NOT NULL, last_used BIGINT, UNIQUE(token) );
CREATE TABLE user_ips ( user_id TEXT NOT NULL, access_token TEXT NOT NULL, device_id TEXT, ip TEXT NOT NULL, user_agent TEXT NOT NULL, last_seen BIGINT NOT NULL );
-CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, UNIQUE(user_id) );
+CREATE TABLE profiles( user_id TEXT NOT NULL, displayname TEXT, avatar_url TEXT, batch BIGINT DEFAULT NULL, active SMALLINT DEFAULT 1 NOT NULL, UNIQUE(user_id) );
CREATE TABLE received_transactions( transaction_id TEXT, origin TEXT, ts BIGINT, response_code INTEGER, response_json bytea, has_been_referenced smallint default 0, UNIQUE (transaction_id, origin) );
CREATE TABLE destinations( destination TEXT PRIMARY KEY, retry_last_ts BIGINT, retry_interval INTEGER );
CREATE TABLE events( stream_ordering INTEGER PRIMARY KEY, topological_ordering BIGINT NOT NULL, event_id TEXT NOT NULL, type TEXT NOT NULL, room_id TEXT NOT NULL, content TEXT, unrecognized_keys TEXT, processed BOOL NOT NULL, outlier BOOL NOT NULL, depth BIGINT DEFAULT 0 NOT NULL, origin_server_ts BIGINT, received_ts BIGINT, sender TEXT, contains_url BOOLEAN, UNIQUE (event_id) );
@@ -202,6 +202,8 @@ CREATE INDEX group_users_u_idx ON group_users(user_id);
CREATE INDEX group_invites_u_idx ON group_invites(user_id);
CREATE UNIQUE INDEX group_rooms_g_idx ON group_rooms(group_id, room_id);
CREATE INDEX group_rooms_r_idx ON group_rooms(room_id);
+CREATE INDEX profiles_batch_idx ON profiles(batch);
+CREATE TABLE profile_replication_status ( host TEXT NOT NULL, last_synced_batch BIGINT NOT NULL );
CREATE TABLE user_daily_visits ( user_id TEXT NOT NULL, device_id TEXT, timestamp BIGINT NOT NULL );
CREATE INDEX user_daily_visits_uts_idx ON user_daily_visits(user_id, timestamp);
CREATE INDEX user_daily_visits_ts_idx ON user_daily_visits(timestamp);
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index f01cf2fd02..e34fce6281 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -89,7 +89,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SearchBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
if not hs.config.enable_search:
return
@@ -342,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(SearchStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 5c6168e301..3c1e33819b 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -56,7 +56,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateGroupWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def get_room_version(self, room_id: str) -> RoomVersion:
"""Get the room_version of a given room
@@ -320,7 +320,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(MainStateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -506,4 +506,4 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 55a250ef06..5beb302be3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -61,7 +61,7 @@ TYPE_TO_ORIGIN_TABLE = {"room": ("rooms", "room_id"), "user": ("users", "name")}
class StatsStore(StateDeltasStore):
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StatsStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
self.clock = self.hs.get_clock()
@@ -74,9 +74,6 @@ class StatsStore(StateDeltasStore):
"populate_stats_process_rooms", self._populate_stats_process_rooms
)
self.db_pool.updates.register_background_update_handler(
- "populate_stats_process_rooms_2", self._populate_stats_process_rooms_2
- )
- self.db_pool.updates.register_background_update_handler(
"populate_stats_process_users", self._populate_stats_process_users
)
# we no longer need to perform clean-up, but we will give ourselves
@@ -148,31 +145,10 @@ class StatsStore(StateDeltasStore):
return len(users_to_work_on)
async def _populate_stats_process_rooms(self, progress, batch_size):
- """
- This was a background update which regenerated statistics for rooms.
-
- It has been replaced by StatsStore._populate_stats_process_rooms_2. This background
- job has been scheduled to run as part of Synapse v1.0.0, and again now. To ensure
- someone upgrading from <v1.0.0, this background task has been turned into a no-op
- so that the potentially expensive task is not run twice.
-
- Further context: https://github.com/matrix-org/synapse/pull/7977
- """
- await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms"
- )
- return 1
-
- async def _populate_stats_process_rooms_2(self, progress, batch_size):
- """
- This is a background update which regenerates statistics for rooms.
-
- It replaces StatsStore._populate_stats_process_rooms. See its docstring for the
- reasoning.
- """
+ """This is a background update which regenerates statistics for rooms."""
if not self.stats_enabled:
await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms_2"
+ "populate_stats_process_rooms"
)
return 1
@@ -189,13 +165,13 @@ class StatsStore(StateDeltasStore):
return [r for r, in txn]
rooms_to_work_on = await self.db_pool.runInteraction(
- "populate_stats_rooms_2_get_batch", _get_next_batch
+ "populate_stats_rooms_get_batch", _get_next_batch
)
# No more rooms -- complete the transaction.
if not rooms_to_work_on:
await self.db_pool.updates._end_background_update(
- "populate_stats_process_rooms_2"
+ "populate_stats_process_rooms"
)
return 1
@@ -204,9 +180,9 @@ class StatsStore(StateDeltasStore):
progress["last_room_id"] = room_id
await self.db_pool.runInteraction(
- "_populate_stats_process_rooms_2",
+ "_populate_stats_process_rooms",
self.db_pool.updates._background_update_progress_txn,
- "populate_stats_process_rooms_2",
+ "populate_stats_process_rooms",
progress,
)
@@ -234,6 +210,7 @@ class StatsStore(StateDeltasStore):
* topic
* avatar
* canonical_alias
+ * guest_access
An is_federatable key can also be included with a boolean value.
@@ -258,6 +235,7 @@ class StatsStore(StateDeltasStore):
"topic",
"avatar",
"canonical_alias",
+ "guest_access",
):
field = fields.get(col, sentinel)
if field is not sentinel and (not isinstance(field, str) or "\0" in field):
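
For reference, the background-update handlers touched in this file all follow the same contract: take the stored progress dict and a batch size, process up to batch_size items, persist the new progress, and end the update once nothing is left. A condensed, hypothetical handler showing that shape (the batch-fetching helper and update name are assumed):

    # Hypothetical handler illustrating the contract used by
    # _populate_stats_process_rooms; _get_next_batch_txn is an assumed helper.
    async def _example_background_update(self, progress, batch_size):
        last_id = progress.get("last_room_id", "")

        rows = await self.db_pool.runInteraction(
            "example_get_batch", self._get_next_batch_txn, last_id, batch_size
        )

        if not rows:
            # Nothing left to do: mark the update as finished.
            await self.db_pool.updates._end_background_update("example_update")
            return 1

        # Record how far we got so the next run resumes from here.
        progress["last_room_id"] = rows[-1]
        await self.db_pool.runInteraction(
            "example_update_progress",
            self.db_pool.updates._background_update_progress_txn,
            "example_update",
            progress,
        )

        return len(rows)
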
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index db20a3db30..92e96468b4 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -79,8 +79,8 @@ _EventDictReturn = namedtuple(
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
- from_token: Optional[Tuple[int, int]],
- to_token: Optional[Tuple[int, int]],
+ from_token: Optional[Tuple[Optional[int], int]],
+ to_token: Optional[Tuple[Optional[int], int]],
engine: BaseDatabaseEngine,
) -> str:
"""Creates an SQL expression to bound the columns by the pagination
@@ -259,16 +259,14 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
return " AND ".join(clauses), args
-class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
+class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_room_max_stream_ordering` and `get_room_min_stream_ordering`
which can be called in the initializer.
"""
- __metaclass__ = abc.ABCMeta
-
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
- super(StreamWorkerStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
self._send_federation = hs.should_send_federation()
@@ -310,11 +308,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
- from_key: str,
- to_key: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
- ) -> Dict[str, Tuple[List[EventBase], str]]:
+ ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -333,9 +331,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
"""
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
-
- room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
+ room_ids = self._events_stream_cache.get_entities_changed(
+ room_ids, from_key.stream
+ )
if not room_ids:
return {}
@@ -364,16 +362,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def get_rooms_that_changed(
- self, room_ids: Collection[str], from_key: str
+ self, room_ids: Collection[str], from_key: RoomStreamToken
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
-
- Args:
- room_ids
- from_key: The room_key portion of a StreamToken
"""
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
+ from_id = from_key.stream
return {
room_id
for room_id in room_ids
@@ -383,11 +377,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_room(
self,
room_id: str,
- from_key: str,
- to_key: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
- ) -> Tuple[List[EventBase], str]:
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -408,8 +402,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if from_key == to_key:
return [], from_key
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ from_id = from_key.stream
+ to_id = to_key.stream
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
@@ -441,7 +435,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
- key = "s%d" % min(r.stream_ordering for r in rows)
+ key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -450,10 +444,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
- self, user_id: str, from_key: str, to_key: str
+ self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
- from_id = RoomStreamToken.parse_stream_token(from_key).stream
- to_id = RoomStreamToken.parse_stream_token(to_key).stream
+ from_id = from_key.stream
+ to_id = to_key.stream
if from_key == to_key:
return []
@@ -491,8 +485,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
async def get_recent_events_for_room(
- self, room_id: str, limit: int, end_token: str
- ) -> Tuple[List[EventBase], str]:
+ self, room_id: str, limit: int, end_token: RoomStreamToken
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@@ -518,8 +512,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
async def get_recent_event_ids_for_room(
- self, room_id: str, limit: int, end_token: str
- ) -> Tuple[List[_EventDictReturn], str]:
+ self, room_id: str, limit: int, end_token: RoomStreamToken
+ ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@@ -535,8 +529,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0:
return [], end_token
- end_token = RoomStreamToken.parse(end_token)
-
rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
@@ -619,17 +611,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
allow_none=allow_none,
)
- async def get_stream_token_for_event(self, event_id: str) -> str:
+ async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
- A "s%d" stream token.
+ A stream token.
"""
stream_id = await self.get_stream_id_for_event(event_id)
- return "s%d" % (stream_id,)
+ return RoomStreamToken(None, stream_id)
async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
@@ -951,7 +943,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
- ) -> Tuple[List[_EventDictReturn], str]:
+ ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@@ -986,8 +978,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
bounds = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=from_token,
- to_token=to_token,
+ from_token=from_token.as_tuple(),
+ to_token=to_token.as_tuple() if to_token else None,
engine=self.database_engine,
)
@@ -1051,17 +1043,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
- return rows, str(next_token)
+ return rows, next_token
async def paginate_room_events(
self,
room_id: str,
- from_key: str,
- to_key: Optional[str] = None,
+ from_key: RoomStreamToken,
+ to_key: Optional[RoomStreamToken] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
- ) -> Tuple[List[EventBase], str]:
+ ) -> Tuple[List[EventBase], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@@ -1080,10 +1072,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`).
"""
- from_key = RoomStreamToken.parse(from_key)
- if to_key:
- to_key = RoomStreamToken.parse(to_key)
-
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
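
Because these store methods now take and return RoomStreamToken objects rather than "s%d" strings, parsing moves out to the callers. A hedged sketch of the calling side; the surrounding function is illustrative, while RoomStreamToken.parse and the store methods are the ones appearing in this diff:

    # Illustrative only: callers convert string tokens at the boundary.
    from synapse.types import RoomStreamToken

    async def get_new_events_example(store, room_id: str, from_key_str: str):
        # The store no longer calls RoomStreamToken.parse internally.
        from_key = RoomStreamToken.parse(from_key_str)
        to_key = RoomStreamToken(None, store.get_room_max_stream_ordering())

        events, next_key = await store.get_room_events_stream_for_room(
            room_id, from_key, to_key, limit=20
        )
        # next_key is also a RoomStreamToken rather than a string.
        return events, next_key
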
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index 96ffe26cc9..9f120d3cb6 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -210,7 +210,7 @@ class TagsStore(TagsWorkerStore):
)
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("add_tag", add_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
@@ -232,7 +232,7 @@ class TagsStore(TagsWorkerStore):
txn.execute(sql, (user_id, room_id, tag))
self._update_revision_txn(txn, user_id, room_id, next_id)
- with await self._account_data_id_gen.get_next() as next_id:
+ async with self._account_data_id_gen.get_next() as next_id:
await self.db_pool.runInteraction("remove_tag", remove_tag_txn, next_id)
self.get_tags_for_user.invalidate((user_id,))
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index 5b31aab700..97aed1500e 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -15,13 +15,14 @@
import logging
from collections import namedtuple
-from typing import Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict
from synapse.util.caches.expiringcache import ExpiringCache
@@ -47,7 +48,7 @@ class TransactionStore(SQLBaseStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(TransactionStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self._clock.looping_call(self._start_cleanup_transactions, 30 * 60 * 1000)
@@ -164,7 +165,9 @@ class TransactionStore(SQLBaseStore):
allow_none=True,
)
- if result and result["retry_last_ts"] > 0:
+ # check we have a row and retry_last_ts is not null or zero
+ # (retry_last_ts can't be negative)
+ if result and result["retry_last_ts"]:
return result
else:
return None
@@ -215,6 +218,7 @@ class TransactionStore(SQLBaseStore):
retry_interval = EXCLUDED.retry_interval
WHERE
EXCLUDED.retry_interval = 0
+ OR destinations.retry_interval IS NULL
OR destinations.retry_interval < EXCLUDED.retry_interval
"""
@@ -246,7 +250,11 @@ class TransactionStore(SQLBaseStore):
"retry_interval": retry_interval,
},
)
- elif retry_interval == 0 or prev_row["retry_interval"] < retry_interval:
+ elif (
+ retry_interval == 0
+ or prev_row["retry_interval"] is None
+ or prev_row["retry_interval"] < retry_interval
+ ):
self.db_pool.simple_update_one_txn(
txn,
"destinations",
@@ -273,3 +281,196 @@ class TransactionStore(SQLBaseStore):
await self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)
+
+ async def store_destination_rooms_entries(
+ self, destinations: Iterable[str], room_id: str, stream_ordering: int,
+ ) -> None:
+ """
+ Updates or creates `destination_rooms` entries in batch for a single event.
+
+ Args:
+ destinations: list of destinations
+ room_id: the room_id of the event
+ stream_ordering: the stream_ordering of the event
+ """
+
+ return await self.db_pool.runInteraction(
+ "store_destination_rooms_entries",
+ self._store_destination_rooms_entries_txn,
+ destinations,
+ room_id,
+ stream_ordering,
+ )
+
+ def _store_destination_rooms_entries_txn(
+ self,
+ txn: LoggingTransaction,
+ destinations: Iterable[str],
+ room_id: str,
+ stream_ordering: int,
+ ) -> None:
+
+ # ensure we have a `destinations` row for each destination, as there is
+ # a foreign key constraint.
+ if isinstance(self.database_engine, PostgresEngine):
+ q = """
+ INSERT INTO destinations (destination)
+ VALUES (?)
+ ON CONFLICT DO NOTHING;
+ """
+ elif isinstance(self.database_engine, Sqlite3Engine):
+ q = """
+ INSERT OR IGNORE INTO destinations (destination)
+ VALUES (?);
+ """
+ else:
+ raise RuntimeError("Unknown database engine")
+
+ txn.execute_batch(q, ((destination,) for destination in destinations))
+
+ rows = [(destination, room_id) for destination in destinations]
+
+ self.db_pool.simple_upsert_many_txn(
+ txn,
+ "destination_rooms",
+ ["destination", "room_id"],
+ rows,
+ ["stream_ordering"],
+ [(stream_ordering,)] * len(rows),
+ )
+
+ async def get_destination_last_successful_stream_ordering(
+ self, destination: str
+ ) -> Optional[int]:
+ """
+ Gets the stream ordering of the PDU most-recently successfully sent
+ to the specified destination, or None if this information has not been
+ tracked yet.
+
+ Args:
+ destination: the destination to query
+ """
+ return await self.db_pool.simple_select_one_onecol(
+ "destinations",
+ {"destination": destination},
+ "last_successful_stream_ordering",
+ allow_none=True,
+ desc="get_last_successful_stream_ordering",
+ )
+
+ async def set_destination_last_successful_stream_ordering(
+ self, destination: str, last_successful_stream_ordering: int
+ ) -> None:
+ """
+ Marks that we have successfully sent the PDUs up to and including the
+ one specified.
+
+ Args:
+ destination: the destination we have successfully sent to
+ last_successful_stream_ordering: the stream_ordering of the most
+ recent successfully-sent PDU
+ """
+ return await self.db_pool.simple_upsert(
+ "destinations",
+ keyvalues={"destination": destination},
+ values={"last_successful_stream_ordering": last_successful_stream_ordering},
+ desc="set_last_successful_stream_ordering",
+ )
+
+ async def get_catch_up_room_event_ids(
+ self, destination: str, last_successful_stream_ordering: int,
+ ) -> List[str]:
+ """
+ Returns at most 50 event IDs corresponding to the oldest events that
+ have not yet been sent to the destination.
+
+ Args:
+ destination: the destination in question
+ last_successful_stream_ordering: the stream_ordering of the
+ most-recently successfully-transmitted event to the destination
+
+ Returns:
+ list of event_ids
+ """
+ return await self.db_pool.runInteraction(
+ "get_catch_up_room_event_ids",
+ self._get_catch_up_room_event_ids_txn,
+ destination,
+ last_successful_stream_ordering,
+ )
+
+ @staticmethod
+ def _get_catch_up_room_event_ids_txn(
+ txn: LoggingTransaction, destination: str, last_successful_stream_ordering: int,
+ ) -> List[str]:
+ q = """
+ SELECT event_id FROM destination_rooms
+ JOIN events USING (stream_ordering)
+ WHERE destination = ?
+ AND stream_ordering > ?
+ ORDER BY stream_ordering
+ LIMIT 50
+ """
+ txn.execute(
+ q, (destination, last_successful_stream_ordering),
+ )
+ event_ids = [row[0] for row in txn]
+ return event_ids
+
+ async def get_catch_up_outstanding_destinations(
+ self, after_destination: Optional[str]
+ ) -> List[str]:
+ """
+ Gets at most 25 destinations which have outstanding PDUs to be caught up
+ with, and which are not currently being backed off from.
+
+ Args:
+ after_destination:
+ If provided, all destinations must be lexicographically greater
+ than this one.
+
+ Returns:
+ list of up to 25 destinations with outstanding catch-up.
+ These are the lexicographically smallest destinations which are
+ greater than after_destination (if provided).
+ """
+ time = self.hs.get_clock().time_msec()
+
+ return await self.db_pool.runInteraction(
+ "get_catch_up_outstanding_destinations",
+ self._get_catch_up_outstanding_destinations_txn,
+ time,
+ after_destination,
+ )
+
+ @staticmethod
+ def _get_catch_up_outstanding_destinations_txn(
+ txn: LoggingTransaction, now_time_ms: int, after_destination: Optional[str]
+ ) -> List[str]:
+ q = """
+ SELECT destination FROM destinations
+ WHERE destination IN (
+ SELECT destination FROM destination_rooms
+ WHERE destination_rooms.stream_ordering >
+ destinations.last_successful_stream_ordering
+ )
+ AND destination > ?
+ AND (
+ retry_last_ts IS NULL OR
+ retry_last_ts + retry_interval < ?
+ )
+ ORDER BY destination
+ LIMIT 25
+ """
+ txn.execute(
+ q,
+ (
+ # everything is lexicographically greater than "" so this gives
+ # us the first batch of up to 25.
+ after_destination or "",
+ now_time_ms,
+ ),
+ )
+
+ destinations = [row[0] for row in txn]
+ return destinations
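
Taken together, the new transaction-store methods support a catch-up loop roughly like the sketch below. This is a simplification: send_event_ids stands in for the real federation sender, which does considerably more than shown here:

    # Illustrative catch-up loop built only from the store methods added above.
    async def catch_up_once(store, send_event_ids) -> None:
        last_destination = None
        while True:
            destinations = await store.get_catch_up_outstanding_destinations(
                last_destination
            )
            if not destinations:
                break

            for destination in destinations:
                last_order = await store.get_destination_last_successful_stream_ordering(
                    destination
                )
                if last_order is None:
                    # Nothing tracked for this destination yet; skip it in this sketch.
                    continue

                event_ids = await store.get_catch_up_room_event_ids(
                    destination, last_order
                )
                # Once the events have been sent, record the new high-water mark.
                new_order = await send_event_ids(destination, event_ids)
                await store.set_destination_last_successful_stream_ordering(
                    destination, new_order
                )

            last_destination = destinations[-1]
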
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index b89668d561..3b9211a6d2 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -23,7 +23,7 @@ from synapse.types import JsonDict
from synapse.util import json_encoder, stringutils
-@attr.s
+@attr.s(slots=True)
class UIAuthSessionData:
session_id = attr.ib(type=str)
# The dictionary from the client root level, not the 'auth' key.
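
The switch to attr.s(slots=True) here (and in the other attrs classes touched by this diff) is behaviour-preserving, except that instances no longer accept undeclared attributes, which turns attribute typos into errors. A small standalone illustration:

    # Illustrative only: demonstrates what attr.s(slots=True) buys us.
    import attr

    @attr.s(slots=True)
    class Example:
        session_id = attr.ib(type=str)

    e = Example(session_id="abc")
    e.session_id = "def"      # fine: declared attribute
    try:
        e.sesion_id = "oops"  # typo: raises AttributeError with slots=True
    except AttributeError:
        print("slotted classes reject undeclared attributes")
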
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f2f9a5799a..5a390ff2f6 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -38,7 +38,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, database: DatabasePool, db_conn, hs):
- super(UserDirectoryBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.server_name = hs.hostname
@@ -564,7 +564,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
SHARE_PRIVATE_WORKING_SET = 500
def __init__(self, database: DatabasePool, db_conn, hs):
- super(UserDirectoryStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
async def remove_from_user_dir(self, user_id: str) -> None:
def _remove_from_user_dir_txn(txn):
diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py
index 2f7c95fc74..f9575b1f1f 100644
--- a/synapse/storage/databases/main/user_erasure_store.py
+++ b/synapse/storage/databases/main/user_erasure_store.py
@@ -100,7 +100,7 @@ class UserErasureStore(UserErasureWorkerStore):
return
# They are there, delete them.
- self.simple_delete_one_txn(
+ self.db_pool.simple_delete_one_txn(
txn, "erased_users", keyvalues={"user_id": user_id}
)
diff --git a/synapse/storage/databases/state/bg_updates.py b/synapse/storage/databases/state/bg_updates.py
index 139085b672..acb24e33af 100644
--- a/synapse/storage/databases/state/bg_updates.py
+++ b/synapse/storage/databases/state/bg_updates.py
@@ -181,7 +181,7 @@ class StateBackgroundUpdateStore(StateGroupBackgroundUpdateStore):
STATE_GROUPS_ROOM_INDEX_UPDATE_NAME = "state_groups_room_id_idx"
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateBackgroundUpdateStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index e924f1ca3b..bec3780a32 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -52,7 +52,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
"""
def __init__(self, database: DatabasePool, db_conn, hs):
- super(StateGroupDataStore, self).__init__(database, db_conn, hs)
+ super().__init__(database, db_conn, hs)
# Originally the state store used a single DictionaryCache to cache the
# event IDs for the state types in a given state group to avoid hammering
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index dbaeef91dd..d89f6ed128 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -18,7 +18,7 @@
import itertools
import logging
from collections import deque, namedtuple
-from typing import Iterable, List, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
from prometheus_client import Counter, Histogram
@@ -31,7 +31,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.databases import Databases
from synapse.storage.databases.main.events import DeltaState
-from synapse.types import StateMap
+from synapse.types import Collection, StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -185,6 +185,8 @@ class EventsPersistenceStorage:
# store for now.
self.main_store = stores.main
self.state_store = stores.state
+
+ assert stores.persist_events
self.persist_events_store = stores.persist_events
self._clock = hs.get_clock()
@@ -208,7 +210,7 @@ class EventsPersistenceStorage:
Returns:
the stream ordering of the latest persisted event
"""
- partitioned = {}
+ partitioned = {} # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, ctx in events_and_contexts:
partitioned.setdefault(event.room_id, []).append((event, ctx))
@@ -305,7 +307,9 @@ class EventsPersistenceStorage:
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
- events_by_room = {}
+ events_by_room = (
+ {}
+ ) # type: Dict[str, List[Tuple[EventBase, EventContext]]]
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
@@ -436,7 +440,7 @@ class EventsPersistenceStorage:
self,
room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]],
- latest_event_ids: List[str],
+ latest_event_ids: Collection[str],
):
"""Calculates the new forward extremities for a room given events to
persist.
@@ -470,7 +474,7 @@ class EventsPersistenceStorage:
# Remove any events which are prev_events of any existing events.
existing_prevs = await self.persist_events_store._get_events_which_are_prevs(
result
- )
+ ) # type: Collection[str]
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index ee60e2a718..4957e77f4c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -19,12 +19,15 @@ import logging
import os
import re
from collections import Counter
-from typing import TextIO
+from typing import Optional, TextIO
import attr
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -63,7 +66,12 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
)
-def prepare_database(db_conn, database_engine, config, databases=["main", "state"]):
+def prepare_database(
+ db_conn: Connection,
+ database_engine: BaseDatabaseEngine,
+ config: Optional[HomeServerConfig],
+ databases: Collection[str] = ["main", "state"],
+):
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -73,16 +81,24 @@ def prepare_database(db_conn, database_engine, config, databases=["main", "state
Args:
db_conn:
database_engine:
- config (synapse.config.homeserver.HomeServerConfig|None):
+ config :
application config, or None if we are connecting to an existing
database which we expect to be configured already
- databases (list[str]): The name of the databases that will be used
+ databases: The name of the databases that will be used
with this physical database. Defaults to all databases.
"""
try:
cur = db_conn.cursor()
+ # sqlite does not automatically start transactions for DDL / SELECT statements,
+ # so we start one before running anything. This ensures that any upgrades
+ # are either applied completely, or not at all.
+ #
+ # (psycopg2 automatically starts a transaction as soon as we run any statements
+ # at all, so this is redundant but harmless there.)
+ cur.execute("BEGIN TRANSACTION")
+
logger.info("%r: Checking existing schema version", databases)
version_info = _get_or_create_schema_state(cur, database_engine)
@@ -622,7 +638,7 @@ def _get_or_create_schema_state(txn, database_engine):
return None
-@attr.s()
+@attr.s(slots=True)
class _DirectoryListing:
"""Helper class to store schema file name and the
absolute path to it.
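
The explicit BEGIN TRANSACTION matters because, with the stdlib sqlite3 driver, DDL and SELECT statements otherwise run in autocommit mode, so a failed upgrade could leave the schema half-applied. A standalone sketch of the behaviour the new comment describes, using sqlite3 directly rather than Synapse's wrappers:

    # Illustrative only: with an explicit BEGIN, a failing "upgrade" rolls back
    # the DDL that ran before it; without one, the CREATE TABLE would already
    # have been committed.
    import sqlite3

    conn = sqlite3.connect(":memory:")
    cur = conn.cursor()

    cur.execute("BEGIN TRANSACTION")
    try:
        cur.execute("CREATE TABLE example (id INTEGER PRIMARY KEY)")
        cur.execute("THIS IS NOT VALID SQL")  # simulated failure mid-upgrade
        conn.commit()
    except sqlite3.OperationalError:
        conn.rollback()

    cur.execute(
        "SELECT count(*) FROM sqlite_master WHERE type = 'table' AND name = 'example'"
    )
    print(cur.fetchone()[0])  # 0: the partial upgrade was rolled back
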
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index d30e3f11e7..cec96ad6a7 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -22,7 +22,7 @@ from synapse.api.errors import SynapseError
logger = logging.getLogger(__name__)
-@attr.s
+@attr.s(slots=True)
class PaginationChunk:
"""Returned by relation pagination APIs.
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b7eb4f8ac9..b0353ac2dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,14 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import contextlib
import heapq
import logging
import threading
from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
+import attr
from typing_extensions import Deque
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -86,7 +86,7 @@ class StreamIdGenerator:
upwards, -1 to grow downwards.
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -101,10 +101,10 @@ class StreamIdGenerator:
)
self._unfinished_ids = deque() # type: Deque[int]
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -113,7 +113,7 @@ class StreamIdGenerator:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_id
@@ -121,12 +121,12 @@ class StreamIdGenerator:
with self._lock:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
- async def get_next_mult(self, n):
+ def get_next_mult(self, n):
"""
Usage:
- with await stream_id_gen.get_next(n) as stream_ids:
+ async with stream_id_gen.get_next(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -140,7 +140,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_ids
@@ -149,7 +149,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
@@ -224,6 +224,10 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # Set of local IDs that we've processed that are larger than the current
+ # position, due to there being smaller unpersisted IDs.
+ self._finished_ids = set() # type: Set[int]
+
# We track the max position where we know everything before has been
# persisted. This is done by a) looking at the min across all instances
# and b) noting that if we have seen a run of persisted positions
@@ -236,8 +240,12 @@ class MultiWriterIdGenerator:
# gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of
# positions.
+ #
+ # We start at 1 here as a) the first generated stream ID will be 2, and
+ # b) other parts of the code assume that stream IDs are strictly greater
+ # than 0.
self._persisted_upto_position = (
- min(self._current_positions.values()) if self._current_positions else 0
+ min(self._current_positions.values()) if self._current_positions else 1
)
self._known_persisted_positions = [] # type: List[int]
@@ -274,59 +282,23 @@ class MultiWriterIdGenerator:
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
- # Assert the fetched ID is actually greater than what we currently
- # believe the ID to be. If not, then the sequence and table have got
- # out of sync somehow.
- with self._lock:
- assert self._current_positions.get(self._instance_name, 0) < next_id
-
- self._unfinished_ids.add(next_id)
-
- @contextlib.contextmanager
- def manager():
- try:
- # Multiply by the return factor so that the ID has correct sign.
- yield self._return_factor * next_id
- finally:
- self._mark_id_as_finished(next_id)
- return manager()
+ return _MultiWriterCtxManager(self)
- async def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int):
"""
Usage:
- with await stream_id_gen.get_next_mult(5) as stream_ids:
+ async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
- next_ids = await self._db.runInteraction(
- "_load_next_mult_id", self._load_next_mult_id_txn, n
- )
- # Assert the fetched ID is actually greater than any ID we've already
- # seen. If not, then the sequence and table have got out of sync
- # somehow.
- with self._lock:
- assert max(self._current_positions.values(), default=0) < min(next_ids)
-
- self._unfinished_ids.update(next_ids)
-
- @contextlib.contextmanager
- def manager():
- try:
- yield [self._return_factor * i for i in next_ids]
- finally:
- for i in next_ids:
- self._mark_id_as_finished(i)
-
- return manager()
+ return _MultiWriterCtxManager(self, n)
def get_next_txn(self, txn: LoggingTransaction):
"""
@@ -348,17 +320,44 @@ class MultiWriterIdGenerator:
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
- current poistion if possible.
+ current position if possible.
"""
with self._lock:
self._unfinished_ids.discard(next_id)
+ self._finished_ids.add(next_id)
+
+ new_cur = None
+
+ if self._unfinished_ids:
+ # If there are unfinished IDs then the new position will be the
+ # largest finished ID less than the minimum unfinished ID.
+
+ finished = set()
+
+ min_unfinished = min(self._unfinished_ids)
+ for s in self._finished_ids:
+ if s < min_unfinished:
+ if new_cur is None or new_cur < s:
+ new_cur = s
+ else:
+ finished.add(s)
+
+ # We clear these out since they're now all less than the new
+ # position.
+ self._finished_ids = finished
+ else:
+ # There are no unfinished IDs so the new position is simply the
+ # largest finished one.
+ new_cur = max(self._finished_ids)
- # Figure out if its safe to advance the position by checking there
- # aren't any lower allocated IDs that are yet to finish.
- if all(c > next_id for c in self._unfinished_ids):
+ # We clear these out since they're now all less than the new
+ # position.
+ self._finished_ids.clear()
+
+ if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
- self._current_positions[self._instance_name] = max(curr, next_id)
+ self._current_positions[self._instance_name] = max(curr, new_cur)
self._add_persisted_position(next_id)
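
The bookkeeping above only advances the writer's position to the largest finished ID that lies below every ID still being processed. A toy simulation of that rule (not Synapse code) for IDs finishing out of order:

# Toy simulation: IDs 1-5 are allocated, then finish in the order 2, 1, 4, 5, 3.
# The position only moves past an ID once everything below it has also finished.
unfinished = {1, 2, 3, 4, 5}
finished = set()
position = 0

def mark_finished(i):
    global position
    unfinished.discard(i)
    finished.add(i)
    # IDs below every still-unfinished ID are safe to advance past.
    below = {s for s in finished if not unfinished or s < min(unfinished)}
    if below:
        position = max(position, max(below))
        finished.difference_update(below)

for i in (2, 1, 4, 5, 3):
    mark_finished(i)
    print(f"finished {i}: position now {position}")
# -> 0, 2, 2, 2, 5
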
@@ -367,9 +366,7 @@ class MultiWriterIdGenerator:
equal to it have been successfully persisted.
"""
- # Currently we don't support this operation, as it's not obvious how to
- # condense the stream positions of multiple writers into a single int.
- raise NotImplementedError()
+ return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
@@ -428,7 +425,7 @@ class MultiWriterIdGenerator:
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
- min_curr = min(self._current_positions.values())
+ min_curr = min(self._current_positions.values(), default=0)
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
@@ -449,3 +446,61 @@ class MultiWriterIdGenerator:
# There was a gap in seen positions, so there is nothing more to
# do.
break
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+ """Helper class to convert a plain context manager to an async one.
+
+ This is mainly useful if you have a plain context manager but the interface
+ requires an async one.
+ """
+
+ inner = attr.ib()
+
+ async def __aenter__(self):
+ return self.inner.__enter__()
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return self.inner.__exit__(exc_type, exc, tb)
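
A short self-contained sketch of how the wrapper above is meant to be used: a plain @contextmanager generator becomes drivable with `async with`. The inline copy of the wrapper and the allocate_id helper exist only to keep the snippet runnable on its own:

# Self-contained sketch mirroring _AsyncCtxManagerWrapper (not the real class).
import asyncio
from contextlib import contextmanager

import attr

@attr.s(slots=True)
class AsyncCtxManagerWrapper:
    inner = attr.ib()

    async def __aenter__(self):
        return self.inner.__enter__()

    async def __aexit__(self, exc_type, exc, tb):
        return self.inner.__exit__(exc_type, exc, tb)

@contextmanager
def allocate_id():
    try:
        yield 42
    finally:
        pass  # release/bookkeeping would go here

async def main():
    async with AsyncCtxManagerWrapper(allocate_id()) as stream_id:
        assert stream_id == 42

asyncio.run(main())
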
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+ """Async context manager returned by MultiWriterIdGenerator
+ """
+
+ id_gen = attr.ib(type=MultiWriterIdGenerator)
+ multiple_ids = attr.ib(type=Optional[int], default=None)
+ stream_ids = attr.ib(type=List[int], factory=list)
+
+ async def __aenter__(self) -> Union[int, List[int]]:
+ self.stream_ids = await self.id_gen._db.runInteraction(
+ "_load_next_mult_id",
+ self.id_gen._load_next_mult_id_txn,
+ self.multiple_ids or 1,
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ with self.id_gen._lock:
+ assert max(self.id_gen._current_positions.values(), default=0) < min(
+ self.stream_ids
+ )
+
+ self.id_gen._unfinished_ids.update(self.stream_ids)
+
+ if self.multiple_ids is None:
+ return self.stream_ids[0] * self.id_gen._return_factor
+ else:
+ return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+ async def __aexit__(self, exc_type, exc, tb):
+ for i in self.stream_ids:
+ self.id_gen._mark_id_as_finished(i)
+
+ if exc_type is not None:
+ return False
+
+ return False
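
Note that __aexit__ above returns False on both paths, so the allocated IDs are always marked finished while any exception from the body still propagates. A toy async context manager demonstrating that contract (names are illustrative):

# Toy async context manager with the same __aexit__ contract: cleanup always
# runs, and returning False means exceptions from the body are not swallowed.
import asyncio

class _Toy:
    async def __aenter__(self):
        return 1

    async def __aexit__(self, exc_type, exc, tb):
        print("bookkeeping runs even on error")
        return False  # never suppress the exception

async def main():
    try:
        async with _Toy():
            raise RuntimeError("boom")
    except RuntimeError:
        print("exception still propagated")

asyncio.run(main())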