diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 72fef1533f..0264dea61d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -63,7 +63,7 @@ if TYPE_CHECKING:
from synapse.server import HomeServer
# python 3 does not have a maximum int value
-MAX_TXN_ID = 2 ** 63 - 1
+MAX_TXN_ID = 2**63 - 1
logger = logging.getLogger(__name__)
@@ -241,9 +241,17 @@ class LoggingTransaction:
self.exception_callbacks = exception_callbacks
def call_after(self, callback: Callable[..., object], *args: Any, **kwargs: Any):
- """Call the given callback on the main twisted thread after the
- transaction has finished. Used to invalidate the caches on the
- correct thread.
+ """Call the given callback on the main twisted thread after the transaction has
+ finished.
+
+ Mostly used to invalidate the caches on the correct thread.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `call_after`
+ will accumulate across transaction attempts and will _all_ be called once a
+ transaction attempt succeeds, regardless of whether previous transaction
+ attempts failed. Otherwise, if all transaction attempts fail, all
+ `call_on_exception` callbacks will be run instead.
"""
# if self.after_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
@@ -254,6 +262,15 @@ class LoggingTransaction:
def call_on_exception(
self, callback: Callable[..., object], *args: Any, **kwargs: Any
):
+ """Call the given callback on the main twisted thread after the transaction has
+ failed.
+
+ Note that transactions may be retried a few times if they encounter database
+ errors such as serialization failures. Callbacks given to `call_on_exception`
+ will accumulate across transaction attempts and will _all_ be called once the
+ final transaction attempt fails. No `call_on_exception` callbacks will be run
+ if any transaction attempt succeeds.
+ """
# if self.exception_callbacks is None, that means that whatever constructed the
# LoggingTransaction isn't expecting there to be any callbacks; assert that
# is not the case.
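
The retry semantics documented in the two docstrings above can be hard to picture from prose alone. Here is a minimal standalone sketch of the behaviour; the retry loop and names are illustrative, not Synapse's actual transaction runner:

```python
# Standalone sketch of the documented callback semantics; the retry loop
# and names here are illustrative, not Synapse's real implementation.
from typing import Callable, List

AttemptFn = Callable[[List[Callable[[], None]], List[Callable[[], None]]], None]


def run_with_retries(attempts: List[AttemptFn]) -> None:
    after_callbacks: List[Callable[[], None]] = []
    exception_callbacks: List[Callable[[], None]] = []

    for attempt in attempts:
        try:
            # Each attempt may register more callbacks; they accumulate
            # across retries rather than being reset per attempt.
            attempt(after_callbacks, exception_callbacks)
        except Exception:
            continue  # retry, e.g. after a serialization failure
        else:
            # First success: run *all* accumulated `call_after` callbacks,
            # including ones registered during failed attempts.
            for cb in after_callbacks:
                cb()
            return

    # Every attempt failed: run the `call_on_exception` callbacks instead.
    for cb in exception_callbacks:
        cb()
```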
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index f024761ba7..d4a38daa9a 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -33,7 +33,7 @@ from .account_data import AccountDataStore
from .appservice import ApplicationServiceStore, ApplicationServiceTransactionStore
from .cache import CacheInvalidationWorkerStore
from .censor_events import CensorEventsStore
-from .client_ips import ClientIpStore
+from .client_ips import ClientIpWorkerStore
from .deviceinbox import DeviceInboxStore
from .devices import DeviceStore
from .directory import DirectoryStore
@@ -49,7 +49,7 @@ from .keys import KeyStore
from .lock import LockStore
from .media_repository import MediaRepositoryStore
from .metrics import ServerMetricsStore
-from .monthly_active_users import MonthlyActiveUsersStore
+from .monthly_active_users import MonthlyActiveUsersWorkerStore
from .openid import OpenIdStore
from .presence import PresenceStore
from .profile import ProfileStore
@@ -112,13 +112,13 @@ class DataStore(
AccountDataStore,
EventPushActionsStore,
OpenIdStore,
- ClientIpStore,
+ ClientIpWorkerStore,
DeviceStore,
DeviceInboxStore,
UserDirectoryStore,
GroupServerStore,
UserErasureStore,
- MonthlyActiveUsersStore,
+ MonthlyActiveUsersWorkerStore,
StatsStore,
RelationsStore,
CensorEventsStore,
@@ -146,6 +146,7 @@ class DataStore(
extra_tables=[
("user_signature_stream", "stream_id"),
("device_lists_outbound_pokes", "stream_id"),
+ ("device_lists_changes_in_room", "stream_id"),
],
)
@@ -183,8 +184,18 @@ class DataStore(
super().__init__(database, db_conn, hs)
device_list_max = self._device_list_id_gen.get_current_token()
+ device_list_prefill, min_device_list_id = self.db_pool.get_cache_dict(
+ db_conn,
+ "device_lists_stream",
+ entity_column="user_id",
+ stream_column="stream_id",
+ max_value=device_list_max,
+ limit=1000,
+ )
self._device_list_stream_cache = StreamChangeCache(
- "DeviceListStreamChangeCache", device_list_max
+ "DeviceListStreamChangeCache",
+ min_device_list_id,
+ prefilled_cache=device_list_prefill,
)
self._user_signature_stream_cache = StreamChangeCache(
"UserSignatureStreamChangeCache", device_list_max
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index 0694446558..eb32c34a85 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -29,7 +29,9 @@ from synapse.storage._base import db_to_json
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
-from synapse.types import JsonDict
+from synapse.storage.types import Cursor
+from synapse.storage.util.sequence import build_sequence_generator
+from synapse.types import DeviceListUpdates, JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import _CacheContext, cached
@@ -72,6 +74,22 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
)
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
+ def get_max_as_txn_id(txn: Cursor) -> int:
+ logger.warning("Falling back to slow query, you should port to postgres")
+ txn.execute(
+ "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
+ )
+ return txn.fetchone()[0] # type: ignore
+
+ self._as_txn_seq_gen = build_sequence_generator(
+ db_conn,
+ database.engine,
+ get_max_as_txn_id,
+ "application_services_txn_id_seq",
+ table="application_services_txns",
+ id_column="txn_id",
+ )
+
super().__init__(database, db_conn, hs)
def get_app_services(self):
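
The `build_sequence_generator` call above abstracts over two strategies: on Postgres it pulls IDs from the new `application_services_txn_id_seq` sequence, while on SQLite it falls back to the "slow query" that `get_max_as_txn_id` warns about. A runnable sketch of that fallback path, with the table and data invented for illustration:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE application_services_txns (as_id TEXT, txn_id INTEGER)")
conn.execute("INSERT INTO application_services_txns VALUES ('as1', 7)")


def next_txn_id_sqlite(conn: sqlite3.Connection) -> int:
    # The fallback: scan for the current maximum and add one. On Postgres
    # this is replaced by SELECT nextval('application_services_txn_id_seq').
    row = conn.execute(
        "SELECT COALESCE(MAX(txn_id), 0) FROM application_services_txns"
    ).fetchone()
    return row[0] + 1


assert next_txn_id_sqlite(conn) == 8
```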
@@ -217,6 +235,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages: List[JsonDict],
one_time_key_counts: TransactionOneTimeKeyCounts,
unused_fallback_keys: TransactionUnusedFallbackKeys,
+ device_list_summary: DeviceListUpdates,
) -> AppServiceTransaction:
"""Atomically creates a new transaction for this application service
with the given list of events. Ephemeral events are NOT persisted to the
@@ -231,27 +250,14 @@ class ApplicationServiceTransactionWorkerStore(
appservice devices in the transaction.
unused_fallback_keys: Lists of unused fallback keys for relevant
appservice devices in the transaction.
+ device_list_summary: The device list summary to include in the transaction.
Returns:
A new transaction.
"""
def _create_appservice_txn(txn):
- # work out new txn id (highest txn id for this service += 1)
- # The highest id may be the last one sent (in which case it is last_txn)
- # or it may be the highest in the txns list (which are waiting to be/are
- # being sent)
- last_txn_id = self._get_last_txn(txn, service.id)
-
- txn.execute(
- "SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
- (service.id,),
- )
- highest_txn_id = txn.fetchone()[0]
- if highest_txn_id is None:
- highest_txn_id = 0
-
- new_txn_id = max(highest_txn_id, last_txn_id) + 1
+ new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)
# Insert new txn into txn table
event_ids = json_encoder.encode([e.event_id for e in events])
@@ -268,6 +274,7 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=to_device_messages,
one_time_key_counts=one_time_key_counts,
unused_fallback_keys=unused_fallback_keys,
+ device_list_summary=device_list_summary,
)
return await self.db_pool.runInteraction(
@@ -283,25 +290,8 @@ class ApplicationServiceTransactionWorkerStore(
txn_id: The transaction ID being completed.
service: The application service which was sent this transaction.
"""
- txn_id = int(txn_id)
def _complete_appservice_txn(txn):
- # Debugging query: Make sure the txn being completed is EXACTLY +1 from
- # what was there before. If it isn't, we've got problems (e.g. the AS
- # has probably missed some events), so whine loudly but still continue,
- # since it shouldn't fail completion of the transaction.
- last_txn_id = self._get_last_txn(txn, service.id)
- if (last_txn_id + 1) != txn_id:
- logger.error(
- "appservice: Completing a transaction which has an ID > 1 from "
- "the last ID sent to this AS. We've either dropped events or "
- "sent it to the AS out of order. FIX ME. last_txn=%s "
- "completing_txn=%s service_id=%s",
- last_txn_id,
- txn_id,
- service.id,
- )
-
# Set current txn_id for AS to 'txn_id'
self.db_pool.simple_upsert_txn(
txn,
@@ -359,8 +349,8 @@ class ApplicationServiceTransactionWorkerStore(
events = await self.get_events_as_list(event_ids)
- # TODO: to-device messages, one-time key counts and unused fallback keys
- # are not yet populated for catch-up transactions.
+ # TODO: to-device messages, one-time key counts, device list summaries and unused
+ # fallback keys are not yet populated for catch-up transactions.
# We likely want to populate those for reliability.
return AppServiceTransaction(
service=service,
@@ -370,19 +360,9 @@ class ApplicationServiceTransactionWorkerStore(
to_device_messages=[],
one_time_key_counts={},
unused_fallback_keys={},
+ device_list_summary=DeviceListUpdates(),
)
- def _get_last_txn(self, txn, service_id: Optional[str]) -> int:
- txn.execute(
- "SELECT last_txn FROM application_services_state WHERE as_id=?",
- (service_id,),
- )
- last_txn_id = txn.fetchone()
- if last_txn_id is None or last_txn_id[0] is None: # no row exists
- return 0
- else:
- return int(last_txn_id[0]) # select 'last_txn' col
-
async def set_appservice_last_pos(self, pos: int) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
@@ -430,7 +410,7 @@ class ApplicationServiceTransactionWorkerStore(
async def get_type_stream_id_for_appservice(
self, service: ApplicationService, type: str
) -> int:
- if type not in ("read_receipt", "presence", "to_device"):
+ if type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (type,)
@@ -446,7 +426,8 @@ class ApplicationServiceTransactionWorkerStore(
)
last_stream_id = txn.fetchone()
if last_stream_id is None or last_stream_id[0] is None: # no row exists
- return 0
+            # Stream tokens always start from 1, to avoid footguns around `0` being falsy.
+ return 1
else:
return int(last_stream_id[0])
@@ -457,7 +438,7 @@ class ApplicationServiceTransactionWorkerStore(
async def set_appservice_stream_type_pos(
self, service: ApplicationService, stream_type: str, pos: Optional[int]
) -> None:
- if stream_type not in ("read_receipt", "presence", "to_device"):
+ if stream_type not in ("read_receipt", "presence", "to_device", "device_list"):
raise ValueError(
"Expected type to be a valid application stream id type, got %s"
% (stream_type,)
diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py
index 8b0c614ece..8480ea4e1c 100644
--- a/synapse/storage/databases/main/client_ips.py
+++ b/synapse/storage/databases/main/client_ips.py
@@ -25,7 +25,9 @@ from synapse.storage.database import (
LoggingTransaction,
make_tuple_comparison_clause,
)
-from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.databases.main.monthly_active_users import (
+ MonthlyActiveUsersWorkerStore,
+)
from synapse.types import JsonDict, UserID
from synapse.util.caches.lrucache import LruCache
@@ -397,7 +399,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
return updated
-class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
+class ClientIpWorkerStore(ClientIpBackgroundUpdateStore, MonthlyActiveUsersWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -406,11 +408,40 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
):
super().__init__(database, db_conn, hs)
+ if hs.config.redis.redis_enabled:
+ # If we're using Redis, we can shift this update process off to
+ # the background worker
+ self._update_on_this_worker = hs.config.worker.run_background_tasks
+ else:
+ # If we're NOT using Redis, this must be handled by the master
+ self._update_on_this_worker = hs.get_instance_name() == "master"
+
self.user_ips_max_age = hs.config.server.user_ips_max_age
+ # (user_id, access_token, ip,) -> last_seen
+ self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
+ cache_name="client_ip_last_seen", max_size=50000
+ )
+
if hs.config.worker.run_background_tasks and self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
+ if self._update_on_this_worker:
+ # This is the designated worker that can write to the client IP
+ # tables.
+
+ # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
+ self._batch_row_update: Dict[
+ Tuple[str, str, str], Tuple[str, Optional[str], int]
+ ] = {}
+
+ self._client_ip_looper = self._clock.looping_call(
+ self._update_client_ips_batch, 5 * 1000
+ )
+ self.hs.get_reactor().addSystemEventTrigger(
+ "before", "shutdown", self._update_client_ips_batch
+ )
+
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self) -> None:
"""Removes entries in user IPs older than the configured period."""
@@ -456,7 +487,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
"_prune_old_user_ips", _prune_old_user_ips_txn
)
- async def get_last_client_ip_by_device(
+ async def _get_last_client_ip_by_device_from_database(
self, user_id: str, device_id: Optional[str]
) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
"""For each device_id listed, give the user_ip it was last seen on.
@@ -487,7 +518,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
return {(d["user_id"], d["device_id"]): d for d in res}
- async def get_user_ip_and_agents(
+ async def _get_user_ip_and_agents_from_database(
self, user: UserID, since_ts: int = 0
) -> List[LastConnectionInfo]:
"""Fetch the IPs and user agents for a user since the given timestamp.
@@ -539,34 +570,6 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
for access_token, ip, user_agent, last_seen in rows
]
-
-class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
-
- # (user_id, access_token, ip,) -> last_seen
- self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
- cache_name="client_ip_last_seen", max_size=50000
- )
-
- super().__init__(database, db_conn, hs)
-
- # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
- self._batch_row_update: Dict[
- Tuple[str, str, str], Tuple[str, Optional[str], int]
- ] = {}
-
- self._client_ip_looper = self._clock.looping_call(
- self._update_client_ips_batch, 5 * 1000
- )
- self.hs.get_reactor().addSystemEventTrigger(
- "before", "shutdown", self._update_client_ips_batch
- )
-
async def insert_client_ip(
self,
user_id: str,
@@ -584,17 +587,27 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
last_seen = self.client_ip_last_seen.get(key)
except KeyError:
last_seen = None
- await self.populate_monthly_active_users(user_id)
+
# Rate-limited inserts
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return
self.client_ip_last_seen.set(key, now)
- self._batch_row_update[key] = (user_agent, device_id, now)
+ if self._update_on_this_worker:
+ await self.populate_monthly_active_users(user_id)
+ self._batch_row_update[key] = (user_agent, device_id, now)
+ else:
+ # We are not the designated writer-worker, so stream over replication
+ self.hs.get_replication_command_handler().send_user_ip(
+ user_id, access_token, ip, user_agent, device_id, now
+ )
@wrap_as_background_process("update_client_ips")
async def _update_client_ips_batch(self) -> None:
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update client IPs"
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@@ -612,6 +625,10 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
txn: LoggingTransaction,
to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
) -> None:
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update client IPs"
+
if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
not self.database_engine.can_native_upsert
):
@@ -662,7 +679,12 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
keys giving the column names from the devices table.
"""
- ret = await super().get_last_client_ip_by_device(user_id, device_id)
+ ret = await self._get_last_client_ip_by_device_from_database(user_id, device_id)
+
+ if not self._update_on_this_worker:
+ # Only the writing-worker has additional in-memory data to enhance
+ # the result
+ return ret
# Update what is retrieved from the database with data which is pending
# insertion, as if it has already been stored in the database.
@@ -707,9 +729,16 @@ class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
Only the latest user agent for each access token and IP address combination
is available.
"""
+ rows_from_db = await self._get_user_ip_and_agents_from_database(user, since_ts)
+
+ if not self._update_on_this_worker:
+ # Only the writing-worker has additional in-memory data to enhance
+ # the result
+ return rows_from_db
+
results: Dict[Tuple[str, str], LastConnectionInfo] = {
(connection["access_token"], connection["ip"]): connection
- for connection in await super().get_user_ip_and_agents(user, since_ts)
+ for connection in rows_from_db
}
# Overlay data that is pending insertion on top of the results from the
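
To summarise the new write path: exactly one worker (the background worker when Redis is enabled, otherwise the master) buffers client-IP rows and flushes them on a timer, and every other worker forwards updates over replication. A simplified standalone sketch of that routing decision; the replication call is a stub, not Synapse's real command handler:

```python
from typing import Dict, Optional, Tuple

# (user_id, access_token, ip) -> (user_agent, device_id, last_seen)
_batch_row_update: Dict[Tuple[str, str, str], Tuple[str, Optional[str], int]] = {}


def record_client_ip(
    update_on_this_worker: bool,
    user_id: str,
    access_token: str,
    ip: str,
    user_agent: str,
    device_id: Optional[str],
    now_ms: int,
) -> None:
    if update_on_this_worker:
        # The designated writer buffers locally; a looping call flushes
        # the batch to the database every few seconds.
        _batch_row_update[(user_id, access_token, ip)] = (user_agent, device_id, now_ms)
    else:
        # Any other worker streams the update to the writer instead.
        send_user_ip_over_replication(user_id, access_token, ip, user_agent, device_id, now_ms)


def send_user_ip_over_replication(*args: object) -> None:
    print("USER_IP", *args)  # stand-in for the replication command
```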
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 3b3a089b76..07eea4b3d2 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -681,42 +681,64 @@ class DeviceWorkerStore(SQLBaseStore):
return self._device_list_stream_cache.get_all_entities_changed(from_key)
async def get_users_whose_devices_changed(
- self, from_key: int, user_ids: Iterable[str]
+ self,
+ from_key: int,
+ user_ids: Optional[Iterable[str]] = None,
+ to_key: Optional[int] = None,
) -> Set[str]:
"""Get set of users whose devices have changed since `from_key` that
are in the given list of user_ids.
Args:
- from_key: The device lists stream token
- user_ids: The user IDs to query for devices.
+ from_key: The minimum device lists stream token to query device list changes for,
+ exclusive.
+ user_ids: If provided, only check if these users have changed their device lists.
+ Otherwise changes from all users are returned.
+ to_key: The maximum device lists stream token to query device list changes for,
+ inclusive.
Returns:
- The set of user_ids whose devices have changed since `from_key`
+ The set of user_ids whose devices have changed since `from_key` (exclusive)
+ until `to_key` (inclusive).
"""
-
# Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed.
- to_check = self._device_list_stream_cache.get_entities_changed(
- user_ids, from_key
- )
+ if user_ids is None:
+ # Get set of all users that have had device list changes since 'from_key'
+ user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed(
+ from_key
+ )
+ else:
+ # The same as above, but filter results to only those users in 'user_ids'
+ user_ids_to_check = self._device_list_stream_cache.get_entities_changed(
+ user_ids, from_key
+ )
- if not to_check:
+ if not user_ids_to_check:
return set()
def _get_users_whose_devices_changed_txn(txn):
changes = set()
- sql = """
+ stream_id_where_clause = "stream_id > ?"
+ sql_args = [from_key]
+
+ if to_key:
+ stream_id_where_clause += " AND stream_id <= ?"
+ sql_args.append(to_key)
+
+ sql = f"""
SELECT DISTINCT user_id FROM device_lists_stream
- WHERE stream_id > ?
+ WHERE {stream_id_where_clause}
AND
"""
- for chunk in batch_iter(to_check, 100):
+            # Query device changes in batches of 100 users at a time
+ for chunk in batch_iter(user_ids_to_check, 100):
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
- txn.execute(sql + clause, (from_key,) + tuple(args))
+ txn.execute(sql + clause, sql_args + args)
changes.update(user_id for user_id, in txn)
return changes
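
The optional `to_key` turns the query from "everything after `from_key`" into a bounded window. A tiny standalone mirror of the clause-building logic above, with its exclusive/inclusive semantics spelled out:

```python
from typing import List, Optional, Tuple


def build_stream_id_clause(from_key: int, to_key: Optional[int]) -> Tuple[str, List[int]]:
    # `from_key` is exclusive, `to_key` (when given) is inclusive,
    # matching the docstring of get_users_whose_devices_changed.
    clause = "stream_id > ?"
    args: List[int] = [from_key]
    if to_key:
        clause += " AND stream_id <= ?"
        args.append(to_key)
    return clause, args


assert build_stream_id_clause(5, None) == ("stream_id > ?", [5])
assert build_stream_id_clause(5, 9) == ("stream_id > ? AND stream_id <= ?", [5, 9])
```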
@@ -788,6 +810,7 @@ class DeviceWorkerStore(SQLBaseStore):
SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
) AS e
WHERE ? < stream_id AND stream_id <= ?
+ ORDER BY stream_id ASC
LIMIT ?
"""
@@ -1506,7 +1529,11 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
async def add_device_change_to_streams(
- self, user_id: str, device_ids: Collection[str], hosts: Collection[str]
+ self,
+ user_id: str,
+ device_ids: Collection[str],
+ hosts: Optional[Collection[str]],
+ room_ids: Collection[str],
) -> Optional[int]:
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
@@ -1515,7 +1542,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: The ID of the user whose device changed.
device_ids: The IDs of any changed devices. If empty, this function will
return None.
- hosts: The remote destinations that should be notified of the change.
+ hosts: The remote destinations that should be notified of the change. If
+            None then the set of hosts has *not* been calculated, and will be
+            calculated later by a background task.
+        room_ids: The rooms that the user is in.
Returns:
The maximum stream ID of device list updates that were added to the database, or
@@ -1524,34 +1554,62 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
if not device_ids:
return None
- async with self._device_list_id_gen.get_next_mult(
- len(device_ids)
- ) as stream_ids:
- await self.db_pool.runInteraction(
- "add_device_change_to_stream",
- self._add_device_change_to_stream_txn,
+ context = get_active_span_text_map()
+
+ def add_device_changes_txn(
+ txn, stream_ids_for_device_change, stream_ids_for_outbound_pokes
+ ):
+ self._add_device_change_to_stream_txn(
+ txn,
user_id,
device_ids,
- stream_ids,
+ stream_ids_for_device_change,
)
- if not hosts:
- return stream_ids[-1]
+ self._add_device_outbound_room_poke_txn(
+ txn,
+ user_id,
+ device_ids,
+ room_ids,
+ stream_ids_for_device_change,
+ context,
+ hosts_have_been_calculated=hosts is not None,
+ )
- context = get_active_span_text_map()
- async with self._device_list_id_gen.get_next_mult(
- len(hosts) * len(device_ids)
- ) as stream_ids:
- await self.db_pool.runInteraction(
- "add_device_outbound_poke_to_stream",
- self._add_device_outbound_poke_to_stream_txn,
+ # If the set of hosts to send to has not been calculated yet (and so
+ # `hosts` is None) or there are no `hosts` to send to, then skip
+ # trying to persist them to the DB.
+ if not hosts:
+ return
+
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
user_id,
device_ids,
hosts,
- stream_ids,
+ stream_ids_for_outbound_pokes,
context,
)
+ # `device_lists_stream` wants a stream ID per device update.
+ num_stream_ids = len(device_ids)
+
+ if hosts:
+ # `device_lists_outbound_pokes` wants a different stream ID for
+ # each row, which is a row per host per device update.
+ num_stream_ids += len(hosts) * len(device_ids)
+
+ async with self._device_list_id_gen.get_next_mult(num_stream_ids) as stream_ids:
+ stream_ids_for_device_change = stream_ids[: len(device_ids)]
+ stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]
+
+ await self.db_pool.runInteraction(
+ "add_device_change_to_stream",
+ add_device_changes_txn,
+ stream_ids_for_device_change,
+ stream_ids_for_outbound_pokes,
+ )
+
return stream_ids[-1]
def _add_device_change_to_stream_txn(
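
The single `get_next_mult` call replaces the previous two separate allocations, so the arithmetic deserves a worked example: one stream ID per device update, plus one per (host, device) pair when hosts are known.

```python
# Worked example of the stream-ID partitioning above (values invented).
device_ids = ["DEV1", "DEV2"]
hosts = ["matrix.org", "example.com"]

num_stream_ids = len(device_ids)  # device_lists_stream rows
if hosts:
    num_stream_ids += len(hosts) * len(device_ids)  # outbound poke rows

stream_ids = list(range(100, 100 + num_stream_ids))  # pretend allocation: 100..105
stream_ids_for_device_change = stream_ids[: len(device_ids)]  # [100, 101]
stream_ids_for_outbound_pokes = stream_ids[len(device_ids) :]  # [102, 103, 104, 105]

assert len(stream_ids_for_outbound_pokes) == len(hosts) * len(device_ids)
```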
@@ -1595,7 +1653,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
user_id: str,
device_ids: Iterable[str],
hosts: Collection[str],
- stream_ids: List[str],
+ stream_ids: List[int],
context: Dict[str, str],
) -> None:
for host in hosts:
@@ -1606,8 +1664,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
)
now = self._clock.time_msec()
- next_stream_id = iter(stream_ids)
+ stream_id_iterator = iter(stream_ids)
+ encoded_context = json_encoder.encode(context)
self.db_pool.simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
@@ -1623,16 +1682,146 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
values=[
(
destination,
- next(next_stream_id),
+ next(stream_id_iterator),
user_id,
device_id,
False,
now,
- json_encoder.encode(context)
- if whitelisted_homeserver(destination)
- else "{}",
+ encoded_context if whitelisted_homeserver(destination) else "{}",
)
for destination in hosts
for device_id in device_ids
],
)
+
+ def _add_device_outbound_room_poke_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Iterable[str],
+ room_ids: Collection[str],
+        stream_ids: List[int],
+ context: Dict[str, str],
+ hosts_have_been_calculated: bool,
+ ) -> None:
+ """Record the user in the room has updated their device.
+
+ Args:
+ hosts_have_been_calculated: True if `device_lists_outbound_pokes`
+ has been updated already with the updates.
+ """
+
+ # We only need to convert to outbound pokes if they are our user.
+ converted_to_destinations = (
+ hosts_have_been_calculated or not self.hs.is_mine_id(user_id)
+ )
+
+ encoded_context = json_encoder.encode(context)
+
+ # The `device_lists_changes_in_room.stream_id` column matches the
+ # corresponding `stream_id` of the update in the `device_lists_stream`
+ # table, i.e. all rows persisted for the same device update will have
+ # the same `stream_id` (but different room IDs).
+ self.db_pool.simple_insert_many_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keys=(
+ "user_id",
+ "device_id",
+ "room_id",
+ "stream_id",
+ "converted_to_destinations",
+ "opentracing_context",
+ ),
+ values=[
+ (
+ user_id,
+ device_id,
+ room_id,
+ stream_id,
+ converted_to_destinations,
+ encoded_context,
+ )
+ for room_id in room_ids
+ for device_id, stream_id in zip(device_ids, stream_ids)
+ ],
+ )
+
+    async def get_unconverted_outbound_room_pokes(
+ self, limit: int = 10
+ ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
+ """Get device list changes by room that have not yet been handled and
+ written to `device_lists_outbound_pokes`.
+
+ Returns:
+ A list of user ID, device ID, room ID, stream ID and optional opentracing context.
+ """
+
+ sql = """
+ SELECT user_id, device_id, room_id, stream_id, opentracing_context
+ FROM device_lists_changes_in_room
+ WHERE NOT converted_to_destinations
+ ORDER BY stream_id
+ LIMIT ?
+ """
+
+        def get_unconverted_outbound_room_pokes_txn(txn):
+ txn.execute(sql, (limit,))
+ return txn.fetchall()
+
+ return await self.db_pool.runInteraction(
+ "get_uncoverted_outbound_room_pokes", get_uncoverted_outbound_room_pokes_txn
+ )
+
+ async def add_device_list_outbound_pokes(
+ self,
+ user_id: str,
+ device_id: str,
+ room_id: str,
+ stream_id: int,
+ hosts: Collection[str],
+ context: Optional[Dict[str, str]],
+ ) -> None:
+ """Queue the device update to be sent to the given set of hosts,
+ calculated from the room ID.
+
+ Marks the associated row in `device_lists_changes_in_room` as handled.
+ """
+
+ def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+ if hosts:
+ self._add_device_outbound_poke_to_stream_txn(
+ txn,
+ user_id=user_id,
+ device_ids=[device_id],
+ hosts=hosts,
+ stream_ids=stream_ids,
+ context=context,
+ )
+
+ self.db_pool.simple_update_txn(
+ txn,
+ table="device_lists_changes_in_room",
+ keyvalues={
+ "user_id": user_id,
+ "device_id": device_id,
+ "stream_id": stream_id,
+ "room_id": room_id,
+ },
+ updatevalues={"converted_to_destinations": True},
+ )
+
+ if not hosts:
+            # If there are no hosts then we don't try to generate stream IDs.
+ return await self.db_pool.runInteraction(
+ "add_device_list_outbound_pokes",
+ add_device_list_outbound_pokes_txn,
+ [],
+ )
+
+ async with self._device_list_id_gen.get_next_mult(len(hosts)) as stream_ids:
+ return await self.db_pool.runInteraction(
+ "add_device_list_outbound_pokes",
+ add_device_list_outbound_pokes_txn,
+ stream_ids,
+ )
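
These two methods are the producer/consumer halves of a hand-off: rows land in `device_lists_changes_in_room` unconverted, and some background task outside this file is expected to drain them. A hedged sketch of such a consumer loop; `store` and `get_hosts_in_room` are stand-ins, not APIs established by this diff:

```python
# Hypothetical consumer of the two methods above; nothing here is
# confirmed by this diff beyond the store methods it calls.
async def convert_room_pokes_once(store, get_hosts_in_room, limit: int = 10) -> bool:
    rows = await store.get_unconverted_outbound_room_pokes(limit=limit)
    for user_id, device_id, room_id, stream_id, context in rows:
        hosts = await get_hosts_in_room(room_id)  # resolve room -> remote hosts
        await store.add_device_list_outbound_pokes(
            user_id=user_id,
            device_id=device_id,
            room_id=room_id,
            stream_id=stream_id,
            hosts=hosts,
            context=context,
        )
    # True if a full batch was returned, i.e. there may be more to convert.
    return len(rows) == limit
```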
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 59454a47df..a60e3f4fdd 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -22,7 +22,6 @@ from typing import (
Dict,
Iterable,
List,
- NoReturn,
Optional,
Set,
Tuple,
@@ -1330,10 +1329,9 @@ class EventsWorkerStore(SQLBaseStore):
return results
@cached(max_entries=100000, tree=True)
- async def have_seen_event(self, room_id: str, event_id: str) -> NoReturn:
- # this only exists for the benefit of the @cachedList descriptor on
- # _have_seen_events_dict
- raise NotImplementedError()
+ async def have_seen_event(self, room_id: str, event_id: str) -> bool:
+ res = await self._have_seen_events_dict(((room_id, event_id),))
+ return res[(room_id, event_id)]
def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str
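
For background on why `have_seen_event` existed as a `NoReturn` stub at all: `@cachedList` batch methods need a per-item `@cached` entry point to hang their cache on. The change makes the single-event path a real delegating wrapper. A standalone approximation of the pattern, with the decorators reduced to a plain dict:

```python
# Minimal approximation of the @cached/@cachedList relationship; the
# decorators and DB access are simulated with a dict and a stub.
from typing import Dict, Iterable, Tuple

_cache: Dict[Tuple[str, str], bool] = {}


def _lookup_in_db(key: Tuple[str, str]) -> bool:
    return False  # stub; in reality one batched SELECT covers all misses


def have_seen_events_dict(
    keys: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], bool]:
    result = {}
    for key in keys:
        if key not in _cache:
            _cache[key] = _lookup_in_db(key)
        result[key] = _cache[key]
    return result


def have_seen_event(room_id: str, event_id: str) -> bool:
    # Now a thin wrapper over the batch path, instead of an
    # unimplemented stub that existed only to host the cache.
    return have_seen_events_dict([(room_id, event_id)])[(room_id, event_id)]
```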
diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py
index 216622964a..4f1c22c71b 100644
--- a/synapse/storage/databases/main/monthly_active_users.py
+++ b/synapse/storage/databases/main/monthly_active_users.py
@@ -15,7 +15,6 @@ import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage._base import SQLBaseStore
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -36,7 +35,7 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 60 * 60 * 1000
-class MonthlyActiveUsersWorkerStore(SQLBaseStore):
+class MonthlyActiveUsersWorkerStore(RegistrationWorkerStore):
def __init__(
self,
database: DatabasePool,
@@ -47,9 +46,30 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
self._clock = hs.get_clock()
self.hs = hs
+ if hs.config.redis.redis_enabled:
+ # If we're using Redis, we can shift this update process off to
+ # the background worker
+ self._update_on_this_worker = hs.config.worker.run_background_tasks
+ else:
+ # If we're NOT using Redis, this must be handled by the master
+ self._update_on_this_worker = hs.get_instance_name() == "master"
+
self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau
self._max_mau_value = hs.config.server.max_mau_value
+ self._mau_stats_only = hs.config.server.mau_stats_only
+
+ if self._update_on_this_worker:
+ # Do not add more reserved users than the total allowable number
+ self.db_pool.new_transaction(
+ db_conn,
+ "initialise_mau_threepids",
+ [],
+ [],
+ self._initialise_reserved_users,
+ hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
+ )
+
@cached(num_args=0)
async def get_monthly_active_count(self) -> int:
"""Generates current count of monthly active users
@@ -222,28 +242,6 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):
"reap_monthly_active_users", _reap_users, reserved_users
)
-
-class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerStore):
- def __init__(
- self,
- database: DatabasePool,
- db_conn: LoggingDatabaseConnection,
- hs: "HomeServer",
- ):
- super().__init__(database, db_conn, hs)
-
- self._mau_stats_only = hs.config.server.mau_stats_only
-
- # Do not add more reserved users than the total allowable number
- self.db_pool.new_transaction(
- db_conn,
- "initialise_mau_threepids",
- [],
- [],
- self._initialise_reserved_users,
- hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value],
- )
-
def _initialise_reserved_users(
self, txn: LoggingTransaction, threepids: List[dict]
) -> None:
@@ -254,6 +252,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn:
threepids: List of threepid dicts to reserve
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
# XXX what is this function trying to achieve? It upserts into
# monthly_active_users for each *registered* reserved mau user, but why?
@@ -287,6 +288,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args:
user_id: user to add/update
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
+
# Support user never to be included in MAU stats. Note I can't easily call this
# from upsert_monthly_active_user_txn because then I need a _txn form of
# is_support_user which is complicated because I want to cache the result.
@@ -322,6 +327,9 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
txn (cursor):
user_id (str): user to add/update
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
# Am consciously deciding to lock the table on the basis that is ought
# never be a big table and alternative approaches (batching multiple
@@ -349,6 +357,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore, RegistrationWorkerS
Args:
user_id(str): the user_id to query
"""
+ assert (
+ self._update_on_this_worker
+ ), "This worker is not designated to update MAUs"
+
if self._limit_usage_by_mau or self._mau_stats_only:
# Trial users and guests should not be included as part of MAU group
is_guest = await self.is_guest(user_id) # type: ignore[attr-defined]
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7f3d190e94..c7634c92fd 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1745,6 +1745,18 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
"replace_refresh_token", _replace_refresh_token_txn
)
+ @cached()
+ async def is_guest(self, user_id: str) -> bool:
+ res = await self.db_pool.simple_select_one_onecol(
+ table="users",
+ keyvalues={"name": user_id},
+ retcol="is_guest",
+ allow_none=True,
+ desc="is_guest",
+ )
+
+ return res if res else False
+
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(
@@ -1887,18 +1899,6 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
txn.call_after(self.is_guest.invalidate, (user_id,))
- @cached()
- async def is_guest(self, user_id: str) -> bool:
- res = await self.db_pool.simple_select_one_onecol(
- table="users",
- keyvalues={"name": user_id},
- retcol="is_guest",
- allow_none=True,
- desc="is_guest",
- )
-
- return res if res else False
-
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index b2295fd51f..64a7808140 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -26,8 +26,6 @@ from typing import (
cast,
)
-import attr
-
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore
@@ -39,8 +37,7 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import RoomStreamToken, StreamToken
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -73,7 +70,7 @@ class RelationsWorkerStore(SQLBaseStore):
direction: str = "b",
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
- ) -> PaginationChunk:
+ ) -> Tuple[List[str], Optional[StreamToken]]:
"""Get a list of relations for an event, ordered by topological ordering.
Args:
@@ -90,8 +87,10 @@ class RelationsWorkerStore(SQLBaseStore):
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
- List of event IDs that match relations requested. The rows are of
- the form `{"event_id": "..."}`.
+ A tuple of:
+ A list of related event IDs
+
+ The next stream token, if one exists.
"""
# We don't use `event_id`, it's there so that we can cache based on
# it. The `event_id` must match the `event.event_id`.
@@ -146,7 +145,7 @@ class RelationsWorkerStore(SQLBaseStore):
def _get_recent_references_for_event_txn(
txn: LoggingTransaction,
- ) -> PaginationChunk:
+ ) -> Tuple[List[str], Optional[StreamToken]]:
txn.execute(sql, where_args + [limit + 1])
last_topo_id = None
@@ -156,7 +155,7 @@ class RelationsWorkerStore(SQLBaseStore):
# Do not include edits for redacted events as they leak event
# content.
if not is_redacted or row[1] != RelationTypes.REPLACE:
- events.append({"event_id": row[0]})
+ events.append(row[0])
last_topo_id = row[2]
last_stream_id = row[3]
@@ -179,9 +178,7 @@ class RelationsWorkerStore(SQLBaseStore):
groups_key=0,
)
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
- )
+ return events[:limit], next_token
return await self.db_pool.runInteraction(
"get_recent_references_for_event", _get_recent_references_for_event_txn
@@ -252,15 +249,8 @@ class RelationsWorkerStore(SQLBaseStore):
@cached(tree=True)
async def get_aggregation_groups_for_event(
- self,
- event_id: str,
- room_id: str,
- event_type: Optional[str] = None,
- limit: int = 5,
- direction: str = "b",
- from_token: Optional[AggregationPaginationToken] = None,
- to_token: Optional[AggregationPaginationToken] = None,
- ) -> PaginationChunk:
+ self, event_id: str, room_id: str, limit: int = 5
+ ) -> List[JsonDict]:
"""Get a list of annotations on the event, grouped by event type and
aggregation key, sorted by count.
@@ -270,79 +260,36 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
room_id: The room the event belongs to.
- event_type: Only fetch events with this event type, if given.
limit: Only fetch the `limit` groups.
- direction: Whether to fetch the highest count first (`"b"`) or
- the lowest count first (`"f"`).
- from_token: Fetch rows from the given token, or from the start if None.
- to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
List of groups of annotations that match. Each row is a dict with
`type`, `key` and `count` fields.
"""
- where_clause = ["relates_to_id = ?", "room_id = ?", "relation_type = ?"]
- where_args: List[Union[str, int]] = [
+ where_args = [
event_id,
room_id,
RelationTypes.ANNOTATION,
+ limit,
]
- if event_type:
- where_clause.append("type = ?")
- where_args.append(event_type)
-
- having_clause = generate_pagination_where_clause(
- direction=direction,
- column_names=("COUNT(*)", "MAX(stream_ordering)"),
- from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
- to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
- engine=self.database_engine,
- )
-
- if direction == "b":
- order = "DESC"
- else:
- order = "ASC"
-
- if having_clause:
- having_clause = "HAVING " + having_clause
- else:
- having_clause = ""
-
sql = """
- SELECT type, aggregation_key, COUNT(DISTINCT sender), MAX(stream_ordering)
+ SELECT type, aggregation_key, COUNT(DISTINCT sender)
FROM event_relations
INNER JOIN events USING (event_id)
- WHERE {where_clause}
+ WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
GROUP BY relation_type, type, aggregation_key
- {having_clause}
- ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
+ ORDER BY COUNT(*) DESC
LIMIT ?
- """.format(
- where_clause=" AND ".join(where_clause),
- order=order,
- having_clause=having_clause,
- )
+ """
def _get_aggregation_groups_for_event_txn(
txn: LoggingTransaction,
- ) -> PaginationChunk:
- txn.execute(sql, where_args + [limit + 1])
-
- next_batch = None
- events = []
- for row in txn:
- events.append({"type": row[0], "key": row[1], "count": row[2]})
- next_batch = AggregationPaginationToken(row[2], row[3])
+ ) -> List[JsonDict]:
+ txn.execute(sql, where_args)
- if len(events) <= limit:
- next_batch = None
-
- return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
- )
+ return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
return await self.db_pool.runInteraction(
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 3248da5356..98d09b3736 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -361,7 +361,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
- self, user_id: str, membership_list: Collection[str]
+ self,
+ user_id: str,
+ membership_list: Collection[str],
+ excluded_rooms: Optional[List[str]] = None,
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -372,6 +375,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
user_id: The user ID.
membership_list: A list of synapse.api.constants.Membership
values which the user must be in.
+ excluded_rooms: A list of rooms to ignore.
Returns:
The RoomsForUser that the user matches the membership types.
@@ -386,12 +390,19 @@ class RoomMemberWorkerStore(EventsWorkerStore):
membership_list,
)
- # Now we filter out forgotten rooms
- forgotten_rooms = await self.get_forgotten_rooms_for_user(user_id)
- return [room for room in rooms if room.room_id not in forgotten_rooms]
+ # Now we filter out forgotten and excluded rooms
+ rooms_to_exclude: Set[str] = await self.get_forgotten_rooms_for_user(user_id)
+
+ if excluded_rooms is not None:
+ rooms_to_exclude.update(set(excluded_rooms))
+
+ return [room for room in rooms if room.room_id not in rooms_to_exclude]
def _get_rooms_for_local_user_where_membership_is_txn(
- self, txn, user_id: str, membership_list: List[str]
+ self,
+ txn,
+ user_id: str,
+ membership_list: List[str],
) -> List[RoomsForUser]:
# Paranoia check.
if not self.hs.is_mine_id(user_id):
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 28460fd364..4a461a0abb 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import collections.abc
import logging
-from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
+
+from frozendict import frozendict
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +30,7 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap
+from synapse.types import JsonDict, JsonMapping, StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList
@@ -132,7 +133,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
- async def get_room_predecessor(self, room_id: str) -> Optional[dict]:
+ async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
@@ -158,9 +159,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
predecessor = create_event.content.get("predecessor", None)
# Ensure the key is a dictionary
- if not isinstance(predecessor, collections.abc.Mapping):
+ if not isinstance(predecessor, (dict, frozendict)):
return None
+ # The keys must be strings since the data is JSON.
return predecessor
async def get_create_event_for_room(self, room_id: str) -> EventBase:
@@ -306,8 +308,14 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
list_name="event_ids",
num_args=1,
)
- async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
- """Returns mapping event_id -> state_group"""
+ async def _get_state_group_for_events(
+ self, event_ids: Collection[str]
+ ) -> Dict[str, int]:
+ """Returns mapping event_id -> state_group.
+
+ Raises:
+ RuntimeError if the state is unknown at any of the given events
+ """
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
@@ -317,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="_get_state_group_for_events",
)
- return {row["event_id"]: row["state_group"] for row in rows}
+ res = {row["event_id"]: row["state_group"] for row in rows}
+ for e in event_ids:
+ if e not in res:
+ raise RuntimeError("No state group for unknown or outlier event %s" % e)
+ return res
async def get_referenced_state_groups(
self, state_groups: Iterable[int]
@@ -521,7 +533,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
)
for user_id in potentially_left_users - joined_users:
- await self.mark_remote_user_device_list_as_unsubscribed(user_id)
+ await self.mark_remote_user_device_list_as_unsubscribed(user_id) # type: ignore[attr-defined]
return batch_size
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 39e1efe373..8e764790db 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -36,7 +36,7 @@ what sort order was used:
"""
import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple
import attr
from frozendict import frozendict
@@ -585,7 +585,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
- self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
+ self,
+ user_id: str,
+ from_key: RoomStreamToken,
+ to_key: RoomStreamToken,
+ excluded_rooms: Optional[List[str]] = None,
) -> List[EventBase]:
"""Fetch membership events for a given user.
@@ -610,23 +614,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
min_from_id = from_key.stream
max_to_id = to_key.get_max_stream_pos()
+ args: List[Any] = [user_id, min_from_id, max_to_id]
+
+ ignore_room_clause = ""
+ if excluded_rooms is not None and len(excluded_rooms) > 0:
+ ignore_room_clause = "AND e.room_id NOT IN (%s)" % ",".join(
+ "?" for _ in excluded_rooms
+ )
+ args = args + excluded_rooms
+
sql = """
SELECT m.event_id, instance_name, topological_ordering, stream_ordering
FROM events AS e, room_memberships AS m
WHERE e.event_id = m.event_id
AND m.user_id = ?
AND e.stream_ordering > ? AND e.stream_ordering <= ?
+ %s
ORDER BY e.stream_ordering ASC
- """
- txn.execute(
- sql,
- (
- user_id,
- min_from_id,
- max_to_id,
- ),
+ """ % (
+ ignore_room_clause,
)
+ txn.execute(sql, args)
+
rows = [
_EventDictReturn(event_id, None, stream_ordering)
for event_id, instance_name, topological_ordering, stream_ordering in txn
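
The exclusion clause is built with one bound `?` per excluded room, so room IDs are never interpolated into the SQL string themselves. A standalone mirror of that construction:

```python
from typing import List, Optional, Tuple


def build_exclusion_clause(excluded_rooms: Optional[List[str]]) -> Tuple[str, List[str]]:
    # Mirrors the hunk above: one bound parameter per excluded room.
    if not excluded_rooms:
        return "", []
    clause = "AND e.room_id NOT IN (%s)" % ",".join("?" for _ in excluded_rooms)
    return clause, list(excluded_rooms)


assert build_exclusion_clause(None) == ("", [])
assert build_exclusion_clause(["!a:x", "!b:y"])[0] == "AND e.room_id NOT IN (?,?)"
```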
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
deleted file mode 100644
index fba270150b..0000000000
--- a/synapse/storage/relations.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
-
-import attr
-
-from synapse.api.errors import SynapseError
-from synapse.types import JsonDict
-
-if TYPE_CHECKING:
- from synapse.storage.databases.main import DataStore
-
-logger = logging.getLogger(__name__)
-
-
-@attr.s(slots=True, auto_attribs=True)
-class PaginationChunk:
- """Returned by relation pagination APIs.
-
- Attributes:
- chunk: The rows returned by pagination
- next_batch: Token to fetch next set of results with, if
- None then there are no more results.
- prev_batch: Token to fetch previous set of results with, if
- None then there are no previous results.
- """
-
- chunk: List[JsonDict]
- next_batch: Optional[Any] = None
- prev_batch: Optional[Any] = None
-
- async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
- d = {"chunk": self.chunk}
-
- if self.next_batch:
- d["next_batch"] = await self.next_batch.to_string(store)
-
- if self.prev_batch:
- d["prev_batch"] = await self.prev_batch.to_string(store)
-
- return d
-
-
-@attr.s(frozen=True, slots=True, auto_attribs=True)
-class AggregationPaginationToken:
- """Pagination token for relation aggregation pagination API.
-
- As the results are order by count and then MAX(stream_ordering) of the
- aggregation groups, we can just use them as our pagination token.
-
- Attributes:
- count: The count of relations in the boundary group.
- stream: The MAX stream ordering in the boundary group.
- """
-
- count: int
- stream: int
-
- @staticmethod
- def from_string(string: str) -> "AggregationPaginationToken":
- try:
- c, s = string.split("-")
- return AggregationPaginationToken(int(c), int(s))
- except ValueError:
- raise SynapseError(400, "Invalid aggregation pagination token")
-
- async def to_string(self, store: "DataStore") -> str:
- return "%d-%d" % (self.count, self.stream)
-
- def as_tuple(self) -> Tuple[Any, ...]:
- return attr.astuple(self)
diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py
index 7b21c1b96d..151f2aa9bb 100644
--- a/synapse/storage/schema/__init__.py
+++ b/synapse/storage/schema/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-SCHEMA_VERSION = 68 # remember to update the list below when updating
+SCHEMA_VERSION = 69 # remember to update the list below when updating
"""Represents the expectations made by the codebase about the database schema
This should be incremented whenever the codebase changes its requirements on the
@@ -58,6 +58,10 @@ Changes in SCHEMA_VERSION = 68:
- event_reference_hashes is no longer read.
- `events` has `state_key` and `rejection_reason` columns, which are populated for
new events.
+
+Changes in SCHEMA_VERSION = 69:
+  - We now write to the `device_lists_changes_in_room` table.
+  - Use a sequence to generate future `application_services_txns.txn_id`s.
"""
diff --git a/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql
new file mode 100644
index 0000000000..7590e34b94
--- /dev/null
+++ b/synapse/storage/schema/main/delta/68/06_msc3202_add_device_list_appservice_stream_type.sql
@@ -0,0 +1,23 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Add a column to track which device list changes stream ID this application
+-- service has been caught up to.
+
+-- We explicitly don't set this field as "NOT NULL", as having NULL as a possible
+-- state is useful for determining if we've ever sent traffic for a stream type
+-- to an appservice. See https://github.com/matrix-org/synapse/issues/10836 for
+-- one way this can be used.
+ALTER TABLE application_services_state ADD COLUMN device_list_stream_id BIGINT;
\ No newline at end of file
diff --git a/synapse/storage/schema/main/delta/69/01as_txn_seq.py b/synapse/storage/schema/main/delta/69/01as_txn_seq.py
new file mode 100644
index 0000000000..24bd4b391e
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/01as_txn_seq.py
@@ -0,0 +1,44 @@
+# Copyright 2022 Beeper
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+Adds a postgres SEQUENCE for generating application service transaction IDs.
+"""
+
+from synapse.storage.engines import PostgresEngine
+
+
+def run_create(cur, database_engine, *args, **kwargs):
+ if isinstance(database_engine, PostgresEngine):
+ # If we already have some AS TXNs we want to start from the current
+ # maximum value. There are two potential places this is stored - the
+ # actual TXNs themselves *and* the AS state table. At time of migration
+ # it is possible the TXNs table is empty so we must include the AS state
+ # last_txn as a potential option, and pick the maximum.
+
+ cur.execute("SELECT COALESCE(max(txn_id), 0) FROM application_services_txns")
+ row = cur.fetchone()
+ txn_max = row[0]
+
+ cur.execute("SELECT COALESCE(max(last_txn), 0) FROM application_services_state")
+ row = cur.fetchone()
+ last_txn_max = row[0]
+
+ start_val = max(last_txn_max, txn_max) + 1
+
+ cur.execute(
+ "CREATE SEQUENCE application_services_txn_id_seq START WITH %s",
+ (start_val,),
+ )
diff --git a/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
new file mode 100644
index 0000000000..b5b1782b2a
--- /dev/null
+++ b/synapse/storage/schema/main/delta/69/01device_list_oubound_by_room.sql
@@ -0,0 +1,38 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+CREATE TABLE device_lists_changes_in_room (
+ user_id TEXT NOT NULL,
+ device_id TEXT NOT NULL,
+ room_id TEXT NOT NULL,
+
+ -- This initially matches `device_lists_stream.stream_id`. Note that we
+ -- delete older values from `device_lists_stream`, so we can't use a foreign
+ -- constraint here.
+ --
+ -- The table will contain rows with the same `stream_id` but different
+ -- `room_id`, as for each device update we store a row per room the user is
+ -- joined to. Therefore `(stream_id, room_id)` gives a unique index.
+ stream_id BIGINT NOT NULL,
+
+ -- We have a background process which goes through this table and converts
+ -- entries into rows in `device_lists_outbound_pokes`. Once we have processed
+ -- a row, we mark it as such by setting `converted_to_destinations=TRUE`.
+ converted_to_destinations BOOLEAN NOT NULL,
+ opentracing_context TEXT
+);
+
+CREATE UNIQUE INDEX device_lists_changes_in_stream_id ON device_lists_changes_in_room(stream_id, room_id);
+CREATE INDEX device_lists_changes_in_stream_id_unconverted ON device_lists_changes_in_room(stream_id) WHERE NOT converted_to_destinations;
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 86f1a5373b..cda194e8c8 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -571,6 +571,10 @@ class StateGroupStorage:
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
"""
if not event_ids:
return {}
@@ -659,6 +663,10 @@ class StateGroupStorage:
Returns:
A dict of (event_id) -> (type, state_key) -> [state_events]
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
@@ -696,6 +704,10 @@ class StateGroupStorage:
Returns:
A dict from event_id -> (type, state_key) -> event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for one or more of the events
+ (ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
@@ -723,6 +735,10 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
"""
state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
@@ -741,6 +757,10 @@ class StateGroupStorage:
Returns:
A dict from (type, state_key) -> state_event_id
+
+ Raises:
+ RuntimeError if we don't have a state group for the event (ie it is an
+ outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
|